atomr_patterns/saga/
runner.rs1use std::collections::HashMap;
9use std::marker::PhantomData;
10use std::sync::Arc;
11use std::time::Duration;
12
13use async_trait::async_trait;
14use atomr_core::actor::ActorSystem;
15use tokio::sync::mpsc::UnboundedReceiver;
16
17use crate::saga::state_store::{InMemorySagaStateStore, SagaStateStore};
18use crate::topology::Topology;
19use crate::PatternError;
20
/// One step a [`Saga`] asks the runner to perform after handling an event.
///
/// The actions returned from a single `handle` call are executed in order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SagaAction<C> {
    /// Dispatch the command now; the runner awaits it before the next action.
    Send(C),
    /// Dispatch the command after the given delay from a detached task, so later
    /// actions in the same batch are not held up.
    Schedule(C, Duration),
    /// Dispatch each command in order, awaiting each one before the next.
    Compensate(Vec<C>),
    /// Finish this saga instance: any remaining actions in the batch are skipped and
    /// its state is removed both from memory and from the state store.
    Complete,
}
32
33#[async_trait]
35pub trait Saga: Send + 'static {
36 type Event: Send + Clone + 'static;
37 type Command: Send + 'static;
38 type State: Default + Send + 'static;
39 type Error: std::error::Error + Send + 'static;
40
41 fn correlation_id(event: &Self::Event) -> Option<String>;
44
45 async fn handle(
48 &mut self,
49 state: &mut Self::State,
50 event: Self::Event,
51 ) -> Result<Vec<SagaAction<Self::Command>>, Self::Error>;
52
53 fn encode_state(_state: &Self::State) -> Option<Result<Vec<u8>, String>> {
57 None
58 }
59
60 fn decode_state(_bytes: &[u8]) -> Result<Self::State, String> {
63 Err("decode_state not implemented".into())
64 }
65}
66
67pub struct SagaPattern<S>(PhantomData<S>);
69
70impl<S: Saga> SagaPattern<S> {
71 pub fn builder() -> SagaBuilder<S> {
75 SagaBuilder::default()
76 }
77}
78
79type SagaDispatcher<C> = Arc<dyn Fn(C) -> futures::future::BoxFuture<'static, bool> + Send + Sync>;
80
81pub struct SagaBuilder<S: Saga> {
83 name: Option<String>,
84 saga: Option<S>,
85 events: Option<UnboundedReceiver<S::Event>>,
86 dispatcher: Option<SagaDispatcher<S::Command>>,
87 state_store: Option<Arc<dyn SagaStateStore>>,
88}
89
90impl<S: Saga> Default for SagaBuilder<S> {
91 fn default() -> Self {
92 Self { name: None, saga: None, events: None, dispatcher: None, state_store: None }
93 }
94}
95
96impl<S: Saga> SagaBuilder<S> {
97 pub fn name(mut self, n: impl Into<String>) -> Self {
99 self.name = Some(n.into());
100 self
101 }
102
103 pub fn saga(mut self, s: S) -> Self {
105 self.saga = Some(s);
106 self
107 }
108
109 pub fn events(mut self, rx: UnboundedReceiver<S::Event>) -> Self {
112 self.events = Some(rx);
113 self
114 }
115
116 pub fn dispatcher<F, Fut>(mut self, f: F) -> Self
120 where
121 F: Fn(S::Command) -> Fut + Send + Sync + 'static,
122 Fut: std::future::Future<Output = bool> + Send + 'static,
123 {
124 let f = Arc::new(f);
125 self.dispatcher = Some(Arc::new(move |cmd| {
126 let f = f.clone();
127 Box::pin(async move { f(cmd).await })
128 }));
129 self
130 }
131
132 pub fn state_store<T: SagaStateStore>(mut self, store: Arc<T>) -> Self {
137 self.state_store = Some(store);
138 self
139 }
140
141 pub fn build(self) -> Result<SagaTopology<S>, PatternError<S::Error>> {
143 let state_store: Arc<dyn SagaStateStore> =
144 self.state_store.unwrap_or_else(|| Arc::new(InMemorySagaStateStore::new()));
145 Ok(SagaTopology {
146 name: self.name.unwrap_or_else(|| "saga".into()),
147 saga: self.saga.ok_or(PatternError::NotConfigured("saga"))?,
148 events: self.events.ok_or(PatternError::NotConfigured("events"))?,
149 dispatcher: self.dispatcher.ok_or(PatternError::NotConfigured("dispatcher"))?,
150 state_store,
151 })
152 }
153}
154
155pub struct SagaTopology<S: Saga> {
157 name: String,
158 saga: S,
159 events: UnboundedReceiver<S::Event>,
160 dispatcher: SagaDispatcher<S::Command>,
161 state_store: Arc<dyn SagaStateStore>,
162}
163
/// Handle returned after a saga topology has been materialized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SagaHandles {
    /// The saga's configured name (`"saga"` when none was set on the builder).
    pub name: String,
}
168
169#[async_trait]
170impl<S: Saga> Topology for SagaTopology<S> {
171 type Handles = SagaHandles;
172
173 async fn materialize(self, _system: &ActorSystem) -> Result<SagaHandles, PatternError<()>> {
174 let SagaTopology { name, mut saga, mut events, dispatcher, state_store } = self;
175 let task_name = name.clone();
176 tokio::spawn(async move {
177 let mut states: HashMap<String, S::State> = HashMap::new();
178 if S::encode_state(&S::State::default()).is_some() {
180 for corr in state_store.keys().await {
181 if let Some(payload) = state_store.load(&corr).await {
182 match S::decode_state(&payload) {
183 Ok(state) => {
184 states.insert(corr, state);
185 }
186 Err(e) => {
187 tracing::warn!(
188 saga = %task_name,
189 error = %e,
190 "decode saga state failed; dropping"
191 );
192 }
193 }
194 }
195 }
196 }
197 while let Some(event) = events.recv().await {
198 let Some(corr) = S::correlation_id(&event) else {
199 continue;
200 };
201 let state = states.entry(corr.clone()).or_default();
202 match saga.handle(state, event).await {
203 Ok(actions) => {
204 if let Some(Ok(payload)) = S::encode_state(state) {
207 state_store.save(&corr, payload).await;
208 }
209 let mut completed = false;
210 for action in actions {
211 match action {
212 SagaAction::Send(c) => {
213 let _ = (dispatcher)(c).await;
214 }
215 SagaAction::Schedule(c, delay) => {
216 let dispatcher = dispatcher.clone();
217 tokio::spawn(async move {
218 tokio::time::sleep(delay).await;
219 let _ = (dispatcher)(c).await;
220 });
221 }
222 SagaAction::Compensate(cs) => {
223 for c in cs {
224 let _ = (dispatcher)(c).await;
225 }
226 }
227 SagaAction::Complete => {
228 completed = true;
229 break;
230 }
231 }
232 }
233 if completed {
234 states.remove(&corr);
235 state_store.delete(&corr).await;
236 }
237 }
238 Err(e) => {
239 tracing::warn!(saga = %task_name, error = %e, "saga handle failed");
240 }
241 }
242 }
243 });
244 Ok(SagaHandles { name })
245 }
246}