atomr_patterns/saga/
state_store.rs1use std::collections::HashMap;
10use std::marker::PhantomData;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use atomr_persistence::{Journal, PersistentRepr};
15use parking_lot::RwLock;
16
17#[async_trait]
21pub trait SagaStateStore: Send + Sync + 'static {
22 async fn load(&self, correlation_id: &str) -> Option<Vec<u8>>;
25
26 async fn save(&self, correlation_id: &str, payload: Vec<u8>);
28
29 async fn delete(&self, correlation_id: &str);
31
32 async fn keys(&self) -> Vec<String>;
35}
36
37pub struct InMemorySagaStateStore {
40 inner: Arc<RwLock<HashMap<String, Vec<u8>>>>,
41}
42
43impl Default for InMemorySagaStateStore {
44 fn default() -> Self {
45 Self { inner: Arc::new(RwLock::new(HashMap::new())) }
46 }
47}
48
49impl InMemorySagaStateStore {
50 pub fn new() -> Self {
51 Self::default()
52 }
53}
54
55#[async_trait]
56impl SagaStateStore for InMemorySagaStateStore {
57 async fn load(&self, correlation_id: &str) -> Option<Vec<u8>> {
58 self.inner.read().get(correlation_id).cloned()
59 }
60 async fn save(&self, correlation_id: &str, payload: Vec<u8>) {
61 self.inner.write().insert(correlation_id.into(), payload);
62 }
63 async fn delete(&self, correlation_id: &str) {
64 self.inner.write().remove(correlation_id);
65 }
66 async fn keys(&self) -> Vec<String> {
67 self.inner.read().keys().cloned().collect()
68 }
69}
70
71pub struct JournalSagaStateStore<J: Journal> {
79 journal: Arc<J>,
80 saga_name: String,
81 writer_uuid: String,
82 _marker: PhantomData<J>,
83}
84
85impl<J: Journal> JournalSagaStateStore<J> {
86 pub fn new(journal: Arc<J>, saga_name: impl Into<String>) -> Self {
87 Self {
88 journal,
89 saga_name: saga_name.into(),
90 writer_uuid: format!("saga-{}", rand_id()),
91 _marker: PhantomData,
92 }
93 }
94
95 fn pid(&self, correlation_id: &str) -> String {
96 format!("saga::{}::{}", self.saga_name, correlation_id)
97 }
98
99 fn pid_prefix(&self) -> String {
100 format!("saga::{}::", self.saga_name)
101 }
102}
103
104#[async_trait]
105impl<J: Journal> SagaStateStore for JournalSagaStateStore<J> {
106 async fn load(&self, correlation_id: &str) -> Option<Vec<u8>> {
107 let pid = self.pid(correlation_id);
108 let highest = self.journal.highest_sequence_nr(&pid, 0).await.ok()?;
109 if highest == 0 {
110 return None;
111 }
112 let reprs = self.journal.replay_messages(&pid, highest, highest, 1).await.ok()?;
113 reprs.into_iter().last().filter(|r| !r.deleted).map(|r| r.payload)
114 }
115
116 async fn save(&self, correlation_id: &str, payload: Vec<u8>) {
117 let pid = self.pid(correlation_id);
118 let next_seq = self.journal.highest_sequence_nr(&pid, 0).await.unwrap_or(0) + 1;
119 let _ = self
120 .journal
121 .write_messages(vec![PersistentRepr {
122 persistence_id: pid,
123 sequence_nr: next_seq,
124 payload,
125 manifest: "saga-state".into(),
126 writer_uuid: self.writer_uuid.clone(),
127 deleted: false,
128 tags: vec![format!("saga::{}", self.saga_name)],
129 }])
130 .await;
131 }
132
133 async fn delete(&self, correlation_id: &str) {
134 let pid = self.pid(correlation_id);
135 if let Ok(highest) = self.journal.highest_sequence_nr(&pid, 0).await {
136 if highest > 0 {
137 let _ = self.journal.delete_messages_to(&pid, highest).await;
138 }
139 }
140 }
141
142 async fn keys(&self) -> Vec<String> {
143 let prefix = self.pid_prefix();
144 self.journal
145 .all_persistence_ids()
146 .await
147 .unwrap_or_default()
148 .into_iter()
149 .filter_map(|pid| pid.strip_prefix(&prefix).map(|s| s.to_string()))
150 .collect()
151 }
152}
153
154fn rand_id() -> String {
155 use std::time::{SystemTime, UNIX_EPOCH};
156 let nanos = SystemTime::now().duration_since(UNIX_EPOCH).map(|d| d.as_nanos()).unwrap_or(0);
157 format!("{nanos:x}")
158}