use std::sync::Arc;
use std::time::{Duration, Instant};
use beamr::process::registry::ProcessHandle;
use crate::channel::ChannelMode;
use crate::envelope::Envelope;
use crate::error::LiminalError;
use crate::tracing::{ConversationSpan, FinishedSpan, TraceContext};
/// A traced conversation, represented by the [`ConversationSpan`] it owns.
///
/// Every accessor on this type delegates to the underlying span.
#[derive(Debug)]
pub struct Conversation {
    /// Span that carries this conversation's name and trace contexts.
    span: ConversationSpan,
}

impl Conversation {
    /// Starts a new top-level conversation backed by a root span.
    ///
    /// The module's tests show that the resulting conversation reports
    /// `conversation_id` as its name and has no parent trace context.
    #[must_use]
    pub fn start(conversation_id: impl Into<String>) -> Self {
        let span = ConversationSpan::root(conversation_id);
        Self { span }
    }

    /// Creates a conversation whose span is a child of this conversation's span.
    ///
    /// The module's tests show that the child shares the parent's trace id,
    /// gets its own span id, and reports the parent's context as its parent.
    #[must_use]
    pub fn spawn_child(&self, conversation_id: impl Into<String>) -> Self {
        let span = self.span.child(conversation_id);
        Self { span }
    }

    /// Wraps `payload` in a [`ConversationMessage`] stamped with the span's
    /// message context.
    #[must_use]
    pub const fn message<Payload>(&self, payload: Payload) -> ConversationMessage<Payload> {
        let trace_context = self.span.message_context();
        ConversationMessage::new(payload, trace_context)
    }

    /// Returns the name reported by the underlying span.
    #[must_use]
    pub fn name(&self) -> &str {
        self.span.name()
    }

    /// Returns the trace context of this conversation's own span.
    #[must_use]
    pub const fn trace_context(&self) -> TraceContext {
        self.span.context()
    }

    /// Returns the parent span's trace context, if the span has a parent.
    #[must_use]
    pub const fn parent_trace_context(&self) -> Option<TraceContext> {
        self.span.parent()
    }

    /// Consumes the conversation and finishes its span.
    #[must_use]
    pub fn finish(self) -> FinishedSpan {
        let Self { span } = self;
        span.finish()
    }
}
/// A payload paired with the trace context of the conversation that produced it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConversationMessage<Payload> {
    /// The user-supplied message body.
    payload: Payload,
    /// Trace context attached when the message was created.
    trace_context: TraceContext,
}

impl<Payload> ConversationMessage<Payload> {
    /// Pairs `payload` with `trace_context`.
    ///
    /// Private: within this file messages are created through
    /// `Conversation::message`.
    const fn new(payload: Payload, trace_context: TraceContext) -> Self {
        Self {
            payload,
            trace_context,
        }
    }

    /// Returns a copy of the trace context carried by this message.
    #[must_use]
    pub const fn trace_context(&self) -> TraceContext {
        self.trace_context
    }

    /// Borrows the payload.
    #[must_use]
    pub const fn payload(&self) -> &Payload {
        &self.payload
    }

    /// Consumes the message and returns the payload, discarding the trace context.
    #[must_use]
    pub fn into_payload(self) -> Payload {
        let Self { payload, .. } = self;
        payload
    }

    /// Transforms the payload with `map_payload`, keeping the trace context unchanged.
    #[must_use]
    pub fn map<NextPayload>(
        self,
        map_payload: impl FnOnce(Payload) -> NextPayload,
    ) -> ConversationMessage<NextPayload> {
        let Self {
            payload,
            trace_context,
        } = self;
        ConversationMessage {
            payload: map_payload(payload),
            trace_context,
        }
    }
}
/// Newtype around the raw `u64` process id of a conversation participant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ParticipantPid(u64);

impl ParticipantPid {
    /// Wraps a raw pid.
    #[must_use]
    pub const fn new(pid: u64) -> Self {
        Self(pid)
    }

    /// Returns the raw pid.
    #[must_use]
    pub const fn get(self) -> u64 {
        let Self(raw) = self;
        raw
    }
}

impl From<u64> for ParticipantPid {
    /// Equivalent to [`ParticipantPid::new`].
    fn from(pid: u64) -> Self {
        Self(pid)
    }
}
impl From<ProcessHandle> for ParticipantPid {
    /// Builds a participant pid from the pid reported by `handle`.
    fn from(handle: ProcessHandle) -> Self {
        let pid = handle.pid();
        Self::new(pid)
    }
}
/// Policy selected for handling a participant crash.
///
/// This file only stores the policy (in `ConversationConfig::on_crash`) and
/// records it in the conversation context; the code that acts on it lives
/// elsewhere, so the variant notes below are assumptions drawn from the names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CrashPolicy {
    /// Presumably fails the conversation — TODO confirm against the actor.
    Fail,
    /// Presumably routes work to the next participant — TODO confirm.
    RouteToNext,
    /// Presumably triggers compensating actions — TODO confirm.
    Compensate,
}
/// Static configuration from which a `ConversationState` is built.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConversationConfig {
    /// Pids of the participants; each gets an alive status entry in the initial state.
    pub participants: Vec<ParticipantPid>,
    /// Optional time limit; the initial state turns it into a deadline relative
    /// to the instant supplied at construction.
    pub timeout: Option<Duration>,
    /// Channel mode, copied verbatim into the conversation state.
    pub mode: ChannelMode,
    /// Policy to apply when a participant crashes (presumably consulted by the
    /// code driving the conversation — confirm against the caller).
    pub on_crash: CrashPolicy,
}

impl ConversationConfig {
    /// Assembles a configuration from its four parts without any validation.
    #[must_use]
    pub const fn new(
        participants: Vec<ParticipantPid>,
        timeout: Option<Duration>,
        mode: ChannelMode,
        on_crash: CrashPolicy,
    ) -> Self {
        Self { participants, timeout, mode, on_crash }
    }
}
/// Lifecycle phase of a conversation.
///
/// The transition methods on `ConversationState` move through
/// `Created → Active → Completing → Closed`; `Failed` can be entered from any phase.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConversationPhase {
    /// Initial phase of a freshly built state.
    Created,
    /// Entered from `Created` via `ConversationState::activate`.
    Active,
    /// Entered from `Active` via `ConversationState::begin_completing`.
    Completing,
    /// Entered from `Completing` via `ConversationState::close`.
    Closed,
    /// Entered unconditionally via `ConversationState::fail`.
    Failed,
}
/// Liveness of a single conversation participant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ParticipantHealth {
    /// Health assigned by `ParticipantStatus::alive`.
    Alive,
    /// Health assigned by `ParticipantStatus::mark_dead_at`.
    Dead,
}
/// Health record for one participant of a conversation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ParticipantStatus {
    /// The participant this record describes.
    pub participant: ParticipantPid,
    /// Whether the participant is currently considered alive or dead.
    pub health: ParticipantHealth,
    /// Instant recorded by `mark_dead_at`; `None` for a status created by `alive`.
    pub exited_at: Option<Instant>,
}

impl ParticipantStatus {
    /// Creates a status for a live participant with no exit time.
    #[must_use]
    pub const fn alive(participant: ParticipantPid) -> Self {
        Self {
            exited_at: None,
            health: ParticipantHealth::Alive,
            participant,
        }
    }

    /// Marks the participant dead and stores `at` as its exit time.
    ///
    /// Any previously stored exit time is overwritten.
    pub const fn mark_dead_at(&mut self, at: Instant) {
        self.exited_at = Some(at);
        self.health = ParticipantHealth::Dead;
    }
}
/// One entry in a conversation's context log.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ConversationContextEntry {
    /// Envelope logged through `ConversationState::record_sent`.
    Sent(Envelope),
    /// Envelope logged through `ConversationState::record_received`.
    Received(Envelope),
    /// Crash logged through `ConversationState::record_participant_crash`.
    ParticipantCrashed {
        /// The participant reported as crashed.
        participant: ParticipantPid,
        /// The crash policy supplied alongside the crash report.
        policy: CrashPolicy,
    },
}
/// Mutable runtime state of a conversation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConversationState {
    /// Current lifecycle phase.
    pub current_phase: ConversationPhase,
    /// Log of sent/received envelopes and participant crashes, in the order
    /// they were recorded.
    pub context: Vec<ConversationContextEntry>,
    /// Instant derived from the configured timeout; `None` when no deadline applies.
    pub deadline: Option<Instant>,
    /// One status record per configured participant.
    pub participants: Vec<ParticipantStatus>,
    /// Channel mode copied from the configuration.
    pub mode: ChannelMode,
}
impl ConversationState {
    /// Builds the initial state for the conversation described by `config`.
    ///
    /// The state starts in [`ConversationPhase::Created`] with an empty context,
    /// one [`ParticipantStatus::alive`] record per configured participant, and
    /// the configured channel mode.
    ///
    /// The deadline is `now + timeout` when a timeout is configured. If that sum
    /// cannot be represented as an [`Instant`] (for example with
    /// `Duration::MAX`), the deadline is left unset rather than panicking.
    #[must_use]
    pub fn from_config(config: &ConversationConfig, now: Instant) -> Self {
        // `Instant + Duration` panics on overflow; `checked_add` lets an
        // unrepresentably distant deadline degrade to "no deadline".
        let deadline = config
            .timeout
            .and_then(|timeout| now.checked_add(timeout));
        let participants = config
            .participants
            .iter()
            .copied()
            .map(ParticipantStatus::alive)
            .collect();
        Self {
            current_phase: ConversationPhase::Created,
            context: Vec::new(),
            deadline,
            participants,
            mode: config.mode,
        }
    }

    /// Moves the conversation from `Created` to `Active`.
    ///
    /// Calling this while already `Active` is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`LiminalError::ConversationFailed`] when called from any other
    /// phase; the phase is left unchanged.
    pub fn activate(&mut self) -> Result<(), LiminalError> {
        match self.current_phase {
            ConversationPhase::Created => {
                self.current_phase = ConversationPhase::Active;
                Ok(())
            }
            ConversationPhase::Active => Ok(()),
            phase => Err(invalid_transition(phase, ConversationPhase::Active)),
        }
    }

    /// Moves the conversation from `Active` to `Completing`.
    ///
    /// Calling this while already `Completing` is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`LiminalError::ConversationFailed`] when called from any other
    /// phase; the phase is left unchanged.
    pub fn begin_completing(&mut self) -> Result<(), LiminalError> {
        match self.current_phase {
            ConversationPhase::Active => {
                self.current_phase = ConversationPhase::Completing;
                Ok(())
            }
            ConversationPhase::Completing => Ok(()),
            phase => Err(invalid_transition(phase, ConversationPhase::Completing)),
        }
    }

    /// Moves the conversation from `Completing` to `Closed`.
    ///
    /// Calling this while already `Closed` is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`LiminalError::ConversationFailed`] when called from any other
    /// phase; the phase is left unchanged.
    pub fn close(&mut self) -> Result<(), LiminalError> {
        match self.current_phase {
            ConversationPhase::Completing => {
                self.current_phase = ConversationPhase::Closed;
                Ok(())
            }
            ConversationPhase::Closed => Ok(()),
            phase => Err(invalid_transition(phase, ConversationPhase::Closed)),
        }
    }

    /// Unconditionally puts the conversation into the `Failed` phase.
    pub const fn fail(&mut self) {
        self.current_phase = ConversationPhase::Failed;
    }

    /// Appends `envelope` to the context as a sent message.
    pub fn record_sent(&mut self, envelope: Envelope) {
        self.context.push(ConversationContextEntry::Sent(envelope));
    }

    /// Appends `envelope` to the context as a received message.
    pub fn record_received(&mut self, envelope: Envelope) {
        self.context
            .push(ConversationContextEntry::Received(envelope));
    }

    /// Records that `participant` crashed at `exited_at`.
    ///
    /// Every status record whose pid equals `participant` is marked dead with
    /// that exit time. A `ParticipantCrashed` entry carrying `policy` is appended
    /// to the context even when no status record matches.
    pub fn record_participant_crash(
        &mut self,
        participant: ParticipantPid,
        policy: CrashPolicy,
        exited_at: Instant,
    ) {
        for status in &mut self.participants {
            if status.participant == participant {
                status.mark_dead_at(exited_at);
            }
        }
        self.context
            .push(ConversationContextEntry::ParticipantCrashed {
                participant,
                policy,
            });
    }
}
/// Cloneable handle to a conversation.
///
/// All operations are forwarded to a shared backend; clones of a handle share
/// the same backend through the `Arc`.
#[derive(Clone, Debug)]
pub struct ConversationHandle {
    /// Shared implementation that every method delegates to.
    backend: Arc<dyn ConversationHandleBackend>,
}

impl ConversationHandle {
    /// Wraps `backend` in a handle.
    pub(crate) fn new(backend: Arc<dyn ConversationHandleBackend>) -> Self {
        Self { backend }
    }

    /// Converts `message` into an [`Envelope`] and passes it to the backend's `send`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the backend reports.
    pub fn send(&self, message: impl Into<Envelope>) -> Result<(), LiminalError> {
        let envelope: Envelope = message.into();
        self.backend.send(envelope)
    }

    /// Returns the envelope produced by the backend's `receive`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the backend reports.
    pub fn receive(&self) -> Result<Envelope, LiminalError> {
        self.backend.receive()
    }

    /// Forwards to the backend's `close`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the backend reports.
    pub fn close(&self) -> Result<(), LiminalError> {
        self.backend.close()
    }

    /// Returns the [`ConversationState`] produced by the backend's `query_state`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the backend reports.
    pub fn query_state(&self) -> Result<ConversationState, LiminalError> {
        self.backend.query_state()
    }

    /// Returns the pid produced by the backend's `actor_pid`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the backend reports.
    pub fn actor_pid(&self) -> Result<ParticipantPid, LiminalError> {
        self.backend.actor_pid()
    }
}
/// Implementation behind a [`ConversationHandle`].
///
/// Each method corresponds to the `ConversationHandle` method of the same name.
/// The `Debug + Send + Sync` supertraits are what allow the handle, which
/// stores an `Arc<dyn ConversationHandleBackend>`, to be `Debug`, `Send`, and
/// `Sync` itself.
pub(crate) trait ConversationHandleBackend: std::fmt::Debug + Send + Sync {
    /// Handles an envelope submitted through `ConversationHandle::send`.
    fn send(&self, message: Envelope) -> Result<(), LiminalError>;

    /// Supplies the envelope returned by `ConversationHandle::receive`.
    fn receive(&self) -> Result<Envelope, LiminalError>;

    /// Handles `ConversationHandle::close`.
    fn close(&self) -> Result<(), LiminalError>;

    /// Supplies the state returned by `ConversationHandle::query_state`.
    fn query_state(&self) -> Result<ConversationState, LiminalError>;

    /// Supplies the pid returned by `ConversationHandle::actor_pid`.
    fn actor_pid(&self) -> Result<ParticipantPid, LiminalError>;
}
/// Builds the `ConversationFailed` error reported for a disallowed phase change,
/// naming both the current phase (`from`) and the requested one (`to`).
fn invalid_transition(from: ConversationPhase, to: ConversationPhase) -> LiminalError {
    let message = format!("invalid conversation phase transition from {from:?} to {to:?}");
    LiminalError::ConversationFailed { message }
}
#[cfg(test)]
mod tests {
    use super::{Conversation, ConversationHandle};

    /// A root conversation keeps its name, has no parent, and gets non-zero ids
    /// that differ from those of another root conversation.
    #[test]
    fn starting_conversation_creates_named_span_with_fresh_trace_context() {
        let first = Conversation::start("conversation-1");
        let second = Conversation::start("conversation-2");
        let first_context = first.trace_context();
        let second_context = second.trace_context();

        assert_eq!(first.name(), "conversation-1");
        assert_eq!(first.parent_trace_context(), None);
        assert_ne!(first_context.trace_id(), 0);
        assert_ne!(first_context.span_id(), 0);
        assert_ne!(first_context.trace_id(), second_context.trace_id());
    }

    /// Messages created from a conversation carry that conversation's trace context.
    #[test]
    fn messages_inherit_conversation_trace_context_automatically() {
        let conversation = Conversation::start("conversation");
        let message = conversation.message("payload");

        assert_eq!(message.payload(), &"payload");
        assert_eq!(message.trace_context(), conversation.trace_context());
    }

    /// A child shares the parent's trace id but has its own span id and points
    /// back at the parent's context.
    #[test]
    fn child_conversation_references_parent_trace_context() {
        let parent = Conversation::start("parent");
        let child = parent.spawn_child("child");
        let parent_context = parent.trace_context();
        let child_context = child.trace_context();

        assert_eq!(child.name(), "child");
        assert_eq!(child.parent_trace_context(), Some(parent_context));
        assert_eq!(child_context.trace_id(), parent_context.trace_id());
        assert_ne!(child_context.span_id(), parent_context.span_id());
    }

    /// Mapping the payload leaves the trace context untouched.
    #[test]
    fn message_mapping_preserves_trace_context() {
        let conversation = Conversation::start("conversation");
        let context = conversation.trace_context();
        let original = conversation.message(1_u8);
        let mapped = original.map(u16::from);

        assert_eq!(mapped.payload(), &1_u16);
        assert_eq!(mapped.trace_context(), context);
    }

    /// Compile-time check that handles can be cloned and shared across threads.
    #[test]
    fn conversation_handle_is_clone_send_sync() {
        fn assert_clone_send_sync<T: Clone + Send + Sync>() {}

        assert_clone_send_sync::<ConversationHandle>();
    }
}