Skip to main content

atomr_patterns/saga/
state_store.rs

1//! [`SagaStateStore`] — pluggable per-correlation state storage.
2//!
3//! The saga runner keeps state keyed by correlation id. The default
4//! [`InMemorySagaStateStore`] is fine for tests and single-process
5//! workloads but loses state on restart. Implement this trait against
6//! a durable backend (or build on top of [`atomr_persistence::Journal`]
7//! via [`JournalSagaStateStore`]) for production sagas.
8
9use std::collections::HashMap;
10use std::marker::PhantomData;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use atomr_persistence::{Journal, PersistentRepr};
15use parking_lot::RwLock;
16
17/// Per-correlation state storage. Saga state is opaque (`Vec<u8>`) at
18/// this layer; the saga supplies the codec via
19/// [`crate::saga::Saga::encode_state`] / [`crate::saga::Saga::decode_state`].
20#[async_trait]
21pub trait SagaStateStore: Send + Sync + 'static {
22    /// Load the persisted state for `correlation_id`. `None` means no
23    /// state exists yet (treat as fresh / `Default`).
24    async fn load(&self, correlation_id: &str) -> Option<Vec<u8>>;
25
26    /// Persist `payload` as the latest state for `correlation_id`.
27    async fn save(&self, correlation_id: &str, payload: Vec<u8>);
28
29    /// Drop the state for `correlation_id` (called on `SagaAction::Complete`).
30    async fn delete(&self, correlation_id: &str);
31
32    /// Every correlation id with persisted state — used at startup to
33    /// rehydrate in-flight sagas.
34    async fn keys(&self) -> Vec<String>;
35}
36
37/// Reference in-memory implementation. Survives runner restarts within
38/// the same process; loses everything on process restart.
39pub struct InMemorySagaStateStore {
40    inner: Arc<RwLock<HashMap<String, Vec<u8>>>>,
41}
42
43impl Default for InMemorySagaStateStore {
44    fn default() -> Self {
45        Self { inner: Arc::new(RwLock::new(HashMap::new())) }
46    }
47}
48
49impl InMemorySagaStateStore {
50    pub fn new() -> Self {
51        Self::default()
52    }
53}
54
55#[async_trait]
56impl SagaStateStore for InMemorySagaStateStore {
57    async fn load(&self, correlation_id: &str) -> Option<Vec<u8>> {
58        self.inner.read().get(correlation_id).cloned()
59    }
60    async fn save(&self, correlation_id: &str, payload: Vec<u8>) {
61        self.inner.write().insert(correlation_id.into(), payload);
62    }
63    async fn delete(&self, correlation_id: &str) {
64        self.inner.write().remove(correlation_id);
65    }
66    async fn keys(&self) -> Vec<String> {
67        self.inner.read().keys().cloned().collect()
68    }
69}
70
71/// Journal-backed saga state store.
72///
73/// Each `(saga_name, correlation_id)` pair is treated as a single
74/// persistence id (`saga::<saga_name>::<correlation_id>`). Saves
75/// append a new event; loads replay the stream and return the most
76/// recent state payload. `keys()` is best-effort — it consults
77/// [`Journal::all_persistence_ids`].
78pub struct JournalSagaStateStore<J: Journal> {
79    journal: Arc<J>,
80    saga_name: String,
81    writer_uuid: String,
82    _marker: PhantomData<J>,
83}
84
85impl<J: Journal> JournalSagaStateStore<J> {
86    pub fn new(journal: Arc<J>, saga_name: impl Into<String>) -> Self {
87        Self {
88            journal,
89            saga_name: saga_name.into(),
90            writer_uuid: format!("saga-{}", rand_id()),
91            _marker: PhantomData,
92        }
93    }
94
95    fn pid(&self, correlation_id: &str) -> String {
96        format!("saga::{}::{}", self.saga_name, correlation_id)
97    }
98
99    fn pid_prefix(&self) -> String {
100        format!("saga::{}::", self.saga_name)
101    }
102}
103
104#[async_trait]
105impl<J: Journal> SagaStateStore for JournalSagaStateStore<J> {
106    async fn load(&self, correlation_id: &str) -> Option<Vec<u8>> {
107        let pid = self.pid(correlation_id);
108        let highest = self.journal.highest_sequence_nr(&pid, 0).await.ok()?;
109        if highest == 0 {
110            return None;
111        }
112        let reprs = self.journal.replay_messages(&pid, highest, highest, 1).await.ok()?;
113        reprs.into_iter().last().filter(|r| !r.deleted).map(|r| r.payload)
114    }
115
116    async fn save(&self, correlation_id: &str, payload: Vec<u8>) {
117        let pid = self.pid(correlation_id);
118        let next_seq = self.journal.highest_sequence_nr(&pid, 0).await.unwrap_or(0) + 1;
119        let _ = self
120            .journal
121            .write_messages(vec![PersistentRepr {
122                persistence_id: pid,
123                sequence_nr: next_seq,
124                payload,
125                manifest: "saga-state".into(),
126                writer_uuid: self.writer_uuid.clone(),
127                deleted: false,
128                tags: vec![format!("saga::{}", self.saga_name)],
129            }])
130            .await;
131    }
132
133    async fn delete(&self, correlation_id: &str) {
134        let pid = self.pid(correlation_id);
135        if let Ok(highest) = self.journal.highest_sequence_nr(&pid, 0).await {
136            if highest > 0 {
137                let _ = self.journal.delete_messages_to(&pid, highest).await;
138            }
139        }
140    }
141
142    async fn keys(&self) -> Vec<String> {
143        let prefix = self.pid_prefix();
144        self.journal
145            .all_persistence_ids()
146            .await
147            .unwrap_or_default()
148            .into_iter()
149            .filter_map(|pid| pid.strip_prefix(&prefix).map(|s| s.to_string()))
150            .collect()
151    }
152}
153
/// Produce a process-unique identifier made of lowercase hex fields
/// separated by `-`: `<unix-nanos>-<process-id>-<counter>`.
///
/// The wall clock alone is not enough: two calls can observe the same
/// timestamp on coarse clocks, and a clock set before the Unix epoch
/// collapses the timestamp to 0. The process-wide counter keeps ids
/// distinct within one process, and the process id separates processes
/// that start at the same instant on one host.
fn rand_id() -> String {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Monotonic per-process counter; Relaxed suffices because only the
    // uniqueness of each fetched value matters, not ordering with other memory.
    static NEXT: AtomicUsize = AtomicUsize::new(0);

    let nanos = SystemTime::now().duration_since(UNIX_EPOCH).map(|d| d.as_nanos()).unwrap_or(0);
    let process = std::process::id();
    let seq = NEXT.fetch_add(1, Ordering::Relaxed);
    format!("{nanos:x}-{process:x}-{seq:x}")
}