Skip to main content

atomr_persistence_aws/
journal.rs

1//! DynamoDB `Journal` implementation (single-table design).
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use atomr_persistence::{Journal, JournalError, PersistentRepr};
8use aws_sdk_dynamodb::primitives::Blob;
9use aws_sdk_dynamodb::types::{AttributeValue, ReturnValue};
10use aws_sdk_dynamodb::Client;
11
12use crate::config::DynamoConfig;
13use crate::keys::{event_sk, parse_sequence, EVENT_PREFIX};
14use crate::schema::ensure_table;
15
16pub struct DynamoJournal {
17    client: Client,
18    cfg: DynamoConfig,
19}
20
21impl DynamoJournal {
22    pub async fn connect(cfg: DynamoConfig) -> Result<Arc<Self>, JournalError> {
23        let client = build_client(&cfg).await;
24        ensure_table(&client, &cfg).await?;
25        Ok(Arc::new(Self { client, cfg }))
26    }
27
28    pub async fn from_client(client: Client, cfg: DynamoConfig) -> Result<Arc<Self>, JournalError> {
29        ensure_table(&client, &cfg).await?;
30        Ok(Arc::new(Self { client, cfg }))
31    }
32
33    pub fn client(&self) -> &Client {
34        &self.client
35    }
36
37    pub fn config(&self) -> &DynamoConfig {
38        &self.cfg
39    }
40
41    fn to_av(&self, repr: &PersistentRepr) -> HashMap<String, AttributeValue> {
42        let mut av = HashMap::new();
43        av.insert("pid".into(), AttributeValue::S(repr.persistence_id.clone()));
44        av.insert("sk".into(), AttributeValue::S(event_sk(repr.sequence_nr)));
45        av.insert("seq".into(), AttributeValue::N(repr.sequence_nr.to_string()));
46        av.insert("payload".into(), AttributeValue::B(Blob::new(repr.payload.clone())));
47        av.insert("manifest".into(), AttributeValue::S(repr.manifest.clone()));
48        av.insert("writer_uuid".into(), AttributeValue::S(repr.writer_uuid.clone()));
49        av.insert("deleted".into(), AttributeValue::Bool(repr.deleted));
50        if !repr.tags.is_empty() {
51            av.insert("tags".into(), AttributeValue::Ss(repr.tags.clone()));
52        }
53        av
54    }
55
56    async fn current_max(&self, pid: &str) -> Result<u64, JournalError> {
57        let out = self
58            .client
59            .query()
60            .table_name(&self.cfg.table_name)
61            .key_condition_expression("#p = :p AND begins_with(#s, :e)")
62            .expression_attribute_names("#p", "pid")
63            .expression_attribute_names("#s", "sk")
64            .expression_attribute_values(":p", AttributeValue::S(pid.into()))
65            .expression_attribute_values(":e", AttributeValue::S(EVENT_PREFIX.into()))
66            .scan_index_forward(false)
67            .limit(1)
68            .send()
69            .await
70            .map_err(|e| JournalError::backend(format!("{e:?}")))?;
71        let items = out.items();
72        if items.is_empty() {
73            return Ok(0);
74        }
75        let sk = items[0].get("sk").and_then(|v| v.as_s().ok()).cloned().unwrap_or_default();
76        Ok(parse_sequence(&sk).unwrap_or(0))
77    }
78}
79
80async fn build_client(cfg: &DynamoConfig) -> Client {
81    let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
82    if let Some(region) = &cfg.region {
83        loader = loader.region(aws_config::Region::new(region.clone()));
84    }
85    let sdk_cfg = loader.load().await;
86    let mut builder = aws_sdk_dynamodb::config::Builder::from(&sdk_cfg);
87    if let Some(endpoint) = &cfg.endpoint_url {
88        builder = builder.endpoint_url(endpoint);
89    }
90    Client::from_conf(builder.build())
91}
92
93#[async_trait]
94impl Journal for DynamoJournal {
95    async fn write_messages(&self, messages: Vec<PersistentRepr>) -> Result<(), JournalError> {
96        if messages.is_empty() {
97            return Ok(());
98        }
99        let mut by_pid: std::collections::BTreeMap<String, Vec<PersistentRepr>> =
100            std::collections::BTreeMap::new();
101        for m in messages {
102            by_pid.entry(m.persistence_id.clone()).or_default().push(m);
103        }
104        for (pid, batch) in by_pid {
105            let mut expected = self.current_max(&pid).await? + 1;
106            for msg in batch {
107                if msg.sequence_nr != expected {
108                    return Err(JournalError::SequenceOutOfOrder { expected, got: msg.sequence_nr });
109                }
110                expected += 1;
111                let item = self.to_av(&msg);
112                self.client
113                    .put_item()
114                    .table_name(&self.cfg.table_name)
115                    .set_item(Some(item))
116                    .condition_expression("attribute_not_exists(sk)")
117                    .send()
118                    .await
119                    .map_err(|e| JournalError::backend(format!("{e:?}")))?;
120            }
121        }
122        Ok(())
123    }
124
125    async fn delete_messages_to(
126        &self,
127        persistence_id: &str,
128        to_sequence_nr: u64,
129    ) -> Result<(), JournalError> {
130        for seq in 1..=to_sequence_nr {
131            let mut key = HashMap::new();
132            key.insert("pid".into(), AttributeValue::S(persistence_id.into()));
133            key.insert("sk".into(), AttributeValue::S(event_sk(seq)));
134            let _ = self
135                .client
136                .update_item()
137                .table_name(&self.cfg.table_name)
138                .set_key(Some(key))
139                .update_expression("SET #d = :t")
140                .expression_attribute_names("#d", "deleted")
141                .expression_attribute_values(":t", AttributeValue::Bool(true))
142                .return_values(ReturnValue::None)
143                .send()
144                .await;
145        }
146        Ok(())
147    }
148
149    async fn replay_messages(
150        &self,
151        persistence_id: &str,
152        from: u64,
153        to: u64,
154        max: u64,
155    ) -> Result<Vec<PersistentRepr>, JournalError> {
156        let limit = if max > i32::MAX as u64 { i32::MAX } else { max as i32 };
157        let from_sk = event_sk(from);
158        let to_sk = event_sk(to);
159        let out = self
160            .client
161            .query()
162            .table_name(&self.cfg.table_name)
163            .key_condition_expression("#p = :p AND #s BETWEEN :from AND :to")
164            .expression_attribute_names("#p", "pid")
165            .expression_attribute_names("#s", "sk")
166            .expression_attribute_values(":p", AttributeValue::S(persistence_id.into()))
167            .expression_attribute_values(":from", AttributeValue::S(from_sk))
168            .expression_attribute_values(":to", AttributeValue::S(to_sk))
169            .limit(limit)
170            .send()
171            .await
172            .map_err(|e| JournalError::backend(format!("{e:?}")))?;
173        let mut results = Vec::new();
174        for item in out.items() {
175            let deleted = item.get("deleted").and_then(|v| v.as_bool().ok()).copied().unwrap_or(false);
176            if deleted {
177                continue;
178            }
179            let seq =
180                item.get("seq").and_then(|v| v.as_n().ok()).and_then(|s| s.parse::<u64>().ok()).unwrap_or(0);
181            let payload = item
182                .get("payload")
183                .and_then(|v| v.as_b().ok())
184                .map(|b| b.as_ref().to_vec())
185                .unwrap_or_default();
186            let manifest = item.get("manifest").and_then(|v| v.as_s().ok()).cloned().unwrap_or_default();
187            let writer_uuid =
188                item.get("writer_uuid").and_then(|v| v.as_s().ok()).cloned().unwrap_or_default();
189            let tags = item.get("tags").and_then(|v| v.as_ss().ok()).cloned().unwrap_or_default();
190            results.push(PersistentRepr {
191                persistence_id: persistence_id.to_string(),
192                sequence_nr: seq,
193                payload,
194                manifest,
195                writer_uuid,
196                deleted,
197                tags,
198            });
199        }
200        Ok(results)
201    }
202
203    async fn highest_sequence_nr(&self, persistence_id: &str, _from: u64) -> Result<u64, JournalError> {
204        self.current_max(persistence_id).await
205    }
206}