Skip to main content

atomr_patterns/saga/
runner.rs

//! Saga runner implementation.
//!
//! Design: the saga listens to an event channel (typically wired from a
//! [`crate::cqrs::CqrsPattern`]'s `tap_events`) and dispatches commands via a
//! user-supplied `dispatcher` closure. Per-saga state is kept in a
//! `HashMap<String, Saga::State>` keyed by correlation id inside a tokio task,
//! and can optionally be persisted through a [`crate::saga::SagaStateStore`].
7
8use std::collections::HashMap;
9use std::marker::PhantomData;
10use std::sync::Arc;
11use std::time::Duration;
12
13use async_trait::async_trait;
14use atomr_core::actor::ActorSystem;
15use tokio::sync::mpsc::UnboundedReceiver;
16
17use crate::saga::state_store::{InMemorySagaStateStore, SagaStateStore};
18use crate::topology::Topology;
19use crate::PatternError;
20
/// What a saga decides to do in response to an event.
///
/// [`Saga::handle`] returns a batch of these; the runner executes the batch in order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SagaAction<C> {
    /// Dispatch this command immediately. The runner awaits the dispatch before
    /// moving on to the next action in the batch.
    Send(C),
    /// Dispatch this command once the given delay has elapsed. The runner does
    /// not wait for it: the delayed dispatch happens in a separately spawned task.
    Schedule(C, Duration),
    /// Dispatch each compensating command (rollback), one after another, in order.
    Compensate(Vec<C>),
    /// The saga is done: its state is removed from memory and from the state store.
    /// Actions listed after `Complete` in the same batch are not executed.
    Complete,
}
32
33/// User-defined saga / process manager.
34#[async_trait]
35pub trait Saga: Send + 'static {
36    type Event: Send + Clone + 'static;
37    type Command: Send + 'static;
38    type State: Default + Send + 'static;
39    type Error: std::error::Error + Send + 'static;
40
41    /// Stable correlation key for `event`. `None` means the event is
42    /// not for this saga.
43    fn correlation_id(event: &Self::Event) -> Option<String>;
44
45    /// React to an event. Receives mutable access to the per-saga state
46    /// keyed by `correlation_id`.
47    async fn handle(
48        &mut self,
49        state: &mut Self::State,
50        event: Self::Event,
51    ) -> Result<Vec<SagaAction<Self::Command>>, Self::Error>;
52
53    /// Optional codec for state persistence. `None` keeps state
54    /// in-memory only (default — preserves v1 behavior). Implement to
55    /// participate in [`crate::saga::SagaStateStore`] persistence.
56    fn encode_state(_state: &Self::State) -> Option<Result<Vec<u8>, String>> {
57        None
58    }
59
60    /// Decode a persisted payload back into `State`. Required iff
61    /// [`Self::encode_state`] is implemented.
62    fn decode_state(_bytes: &[u8]) -> Result<Self::State, String> {
63        Err("decode_state not implemented".into())
64    }
65}
66
67/// Public, zero-sized handle for the saga pattern.
68pub struct SagaPattern<S>(PhantomData<S>);
69
70impl<S: Saga> SagaPattern<S> {
71    /// Build a saga around the given event source and command dispatcher.
72    /// `dispatcher` returns `true` on success — used to decide whether
73    /// to invoke compensation.
74    pub fn builder() -> SagaBuilder<S> {
75        SagaBuilder::default()
76    }
77}
78
/// Type-erased, shareable command dispatcher.
///
/// Called with one command; the boxed future resolves to `true` when the
/// dispatch succeeded. (Same type as `futures::future::BoxFuture<'static, bool>`
/// as the return, spelled with std types.)
type SagaDispatcher<C> = Arc<
    dyn Fn(C) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send + 'static>>
        + Send
        + Sync,
>;
80
81/// Fluent builder.
82pub struct SagaBuilder<S: Saga> {
83    name: Option<String>,
84    saga: Option<S>,
85    events: Option<UnboundedReceiver<S::Event>>,
86    dispatcher: Option<SagaDispatcher<S::Command>>,
87    state_store: Option<Arc<dyn SagaStateStore>>,
88}
89
90impl<S: Saga> Default for SagaBuilder<S> {
91    fn default() -> Self {
92        Self { name: None, saga: None, events: None, dispatcher: None, state_store: None }
93    }
94}
95
96impl<S: Saga> SagaBuilder<S> {
97    /// Override the actor name used for tracing / topology display.
98    pub fn name(mut self, n: impl Into<String>) -> Self {
99        self.name = Some(n.into());
100        self
101    }
102
103    /// Provide the saga implementation.
104    pub fn saga(mut self, s: S) -> Self {
105        self.saga = Some(s);
106        self
107    }
108
109    /// Provide the event source. Typically wired from
110    /// `CqrsBuilder::tap_events`.
111    pub fn events(mut self, rx: UnboundedReceiver<S::Event>) -> Self {
112        self.events = Some(rx);
113        self
114    }
115
116    /// Provide the command dispatcher. The closure receives the
117    /// command and returns whether the dispatch succeeded — failures
118    /// cause [`SagaAction::Compensate`] chains to fire (when present).
119    pub fn dispatcher<F, Fut>(mut self, f: F) -> Self
120    where
121        F: Fn(S::Command) -> Fut + Send + Sync + 'static,
122        Fut: std::future::Future<Output = bool> + Send + 'static,
123    {
124        let f = Arc::new(f);
125        self.dispatcher = Some(Arc::new(move |cmd| {
126            let f = f.clone();
127            Box::pin(async move { f(cmd).await })
128        }));
129        self
130    }
131
132    /// Provide a [`SagaStateStore`]. When set together with
133    /// [`Saga::encode_state`] / [`Saga::decode_state`], the runner
134    /// reloads in-flight saga states on startup and persists state
135    /// after each event handle. Default: in-memory.
136    pub fn state_store<T: SagaStateStore>(mut self, store: Arc<T>) -> Self {
137        self.state_store = Some(store);
138        self
139    }
140
141    /// Finalize the builder.
142    pub fn build(self) -> Result<SagaTopology<S>, PatternError<S::Error>> {
143        let state_store: Arc<dyn SagaStateStore> =
144            self.state_store.unwrap_or_else(|| Arc::new(InMemorySagaStateStore::new()));
145        Ok(SagaTopology {
146            name: self.name.unwrap_or_else(|| "saga".into()),
147            saga: self.saga.ok_or(PatternError::NotConfigured("saga"))?,
148            events: self.events.ok_or(PatternError::NotConfigured("events"))?,
149            dispatcher: self.dispatcher.ok_or(PatternError::NotConfigured("dispatcher"))?,
150            state_store,
151        })
152    }
153}
154
155/// Materializable description of a saga.
156pub struct SagaTopology<S: Saga> {
157    name: String,
158    saga: S,
159    events: UnboundedReceiver<S::Event>,
160    dispatcher: SagaDispatcher<S::Command>,
161    state_store: Arc<dyn SagaStateStore>,
162}
163
/// Handles handed back after [`Topology::materialize`].
///
/// The saga's event loop runs as a detached task, so these handles only
/// identify the saga; they do not control it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SagaHandles {
    /// The saga's name (`"saga"` unless overridden on the builder).
    pub name: String,
}
168
169#[async_trait]
170impl<S: Saga> Topology for SagaTopology<S> {
171    type Handles = SagaHandles;
172
173    async fn materialize(self, _system: &ActorSystem) -> Result<SagaHandles, PatternError<()>> {
174        let SagaTopology { name, mut saga, mut events, dispatcher, state_store } = self;
175        let task_name = name.clone();
176        tokio::spawn(async move {
177            let mut states: HashMap<String, S::State> = HashMap::new();
178            // Rehydrate any persisted in-flight saga states.
179            if S::encode_state(&S::State::default()).is_some() {
180                for corr in state_store.keys().await {
181                    if let Some(payload) = state_store.load(&corr).await {
182                        match S::decode_state(&payload) {
183                            Ok(state) => {
184                                states.insert(corr, state);
185                            }
186                            Err(e) => {
187                                tracing::warn!(
188                                    saga = %task_name,
189                                    error = %e,
190                                    "decode saga state failed; dropping"
191                                );
192                            }
193                        }
194                    }
195                }
196            }
197            while let Some(event) = events.recv().await {
198                let Some(corr) = S::correlation_id(&event) else {
199                    continue;
200                };
201                let state = states.entry(corr.clone()).or_default();
202                match saga.handle(state, event).await {
203                    Ok(actions) => {
204                        // Persist updated state before any dispatch so a
205                        // crash mid-dispatch doesn't lose the decision.
206                        if let Some(Ok(payload)) = S::encode_state(state) {
207                            state_store.save(&corr, payload).await;
208                        }
209                        let mut completed = false;
210                        for action in actions {
211                            match action {
212                                SagaAction::Send(c) => {
213                                    let _ = (dispatcher)(c).await;
214                                }
215                                SagaAction::Schedule(c, delay) => {
216                                    let dispatcher = dispatcher.clone();
217                                    tokio::spawn(async move {
218                                        tokio::time::sleep(delay).await;
219                                        let _ = (dispatcher)(c).await;
220                                    });
221                                }
222                                SagaAction::Compensate(cs) => {
223                                    for c in cs {
224                                        let _ = (dispatcher)(c).await;
225                                    }
226                                }
227                                SagaAction::Complete => {
228                                    completed = true;
229                                    break;
230                                }
231                            }
232                        }
233                        if completed {
234                            states.remove(&corr);
235                            state_store.delete(&corr).await;
236                        }
237                    }
238                    Err(e) => {
239                        tracing::warn!(saga = %task_name, error = %e, "saga handle failed");
240                    }
241                }
242            }
243        });
244        Ok(SagaHandles { name })
245    }
246}