atomr_persistence_aws/
journal.rs1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use atomr_persistence::{Journal, JournalError, PersistentRepr};
8use aws_sdk_dynamodb::primitives::Blob;
9use aws_sdk_dynamodb::types::{AttributeValue, ReturnValue};
10use aws_sdk_dynamodb::Client;
11
12use crate::config::DynamoConfig;
13use crate::keys::{event_sk, parse_sequence, EVENT_PREFIX};
14use crate::schema::ensure_table;
15
16pub struct DynamoJournal {
17 client: Client,
18 cfg: DynamoConfig,
19}
20
21impl DynamoJournal {
22 pub async fn connect(cfg: DynamoConfig) -> Result<Arc<Self>, JournalError> {
23 let client = build_client(&cfg).await;
24 ensure_table(&client, &cfg).await?;
25 Ok(Arc::new(Self { client, cfg }))
26 }
27
28 pub async fn from_client(client: Client, cfg: DynamoConfig) -> Result<Arc<Self>, JournalError> {
29 ensure_table(&client, &cfg).await?;
30 Ok(Arc::new(Self { client, cfg }))
31 }
32
33 pub fn client(&self) -> &Client {
34 &self.client
35 }
36
37 pub fn config(&self) -> &DynamoConfig {
38 &self.cfg
39 }
40
41 fn to_av(&self, repr: &PersistentRepr) -> HashMap<String, AttributeValue> {
42 let mut av = HashMap::new();
43 av.insert("pid".into(), AttributeValue::S(repr.persistence_id.clone()));
44 av.insert("sk".into(), AttributeValue::S(event_sk(repr.sequence_nr)));
45 av.insert("seq".into(), AttributeValue::N(repr.sequence_nr.to_string()));
46 av.insert("payload".into(), AttributeValue::B(Blob::new(repr.payload.clone())));
47 av.insert("manifest".into(), AttributeValue::S(repr.manifest.clone()));
48 av.insert("writer_uuid".into(), AttributeValue::S(repr.writer_uuid.clone()));
49 av.insert("deleted".into(), AttributeValue::Bool(repr.deleted));
50 if !repr.tags.is_empty() {
51 av.insert("tags".into(), AttributeValue::Ss(repr.tags.clone()));
52 }
53 av
54 }
55
56 async fn current_max(&self, pid: &str) -> Result<u64, JournalError> {
57 let out = self
58 .client
59 .query()
60 .table_name(&self.cfg.table_name)
61 .key_condition_expression("#p = :p AND begins_with(#s, :e)")
62 .expression_attribute_names("#p", "pid")
63 .expression_attribute_names("#s", "sk")
64 .expression_attribute_values(":p", AttributeValue::S(pid.into()))
65 .expression_attribute_values(":e", AttributeValue::S(EVENT_PREFIX.into()))
66 .scan_index_forward(false)
67 .limit(1)
68 .send()
69 .await
70 .map_err(|e| JournalError::backend(format!("{e:?}")))?;
71 let items = out.items();
72 if items.is_empty() {
73 return Ok(0);
74 }
75 let sk = items[0].get("sk").and_then(|v| v.as_s().ok()).cloned().unwrap_or_default();
76 Ok(parse_sequence(&sk).unwrap_or(0))
77 }
78}
79
80async fn build_client(cfg: &DynamoConfig) -> Client {
81 let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
82 if let Some(region) = &cfg.region {
83 loader = loader.region(aws_config::Region::new(region.clone()));
84 }
85 let sdk_cfg = loader.load().await;
86 let mut builder = aws_sdk_dynamodb::config::Builder::from(&sdk_cfg);
87 if let Some(endpoint) = &cfg.endpoint_url {
88 builder = builder.endpoint_url(endpoint);
89 }
90 Client::from_conf(builder.build())
91}
92
93#[async_trait]
94impl Journal for DynamoJournal {
95 async fn write_messages(&self, messages: Vec<PersistentRepr>) -> Result<(), JournalError> {
96 if messages.is_empty() {
97 return Ok(());
98 }
99 let mut by_pid: std::collections::BTreeMap<String, Vec<PersistentRepr>> =
100 std::collections::BTreeMap::new();
101 for m in messages {
102 by_pid.entry(m.persistence_id.clone()).or_default().push(m);
103 }
104 for (pid, batch) in by_pid {
105 let mut expected = self.current_max(&pid).await? + 1;
106 for msg in batch {
107 if msg.sequence_nr != expected {
108 return Err(JournalError::SequenceOutOfOrder { expected, got: msg.sequence_nr });
109 }
110 expected += 1;
111 let item = self.to_av(&msg);
112 self.client
113 .put_item()
114 .table_name(&self.cfg.table_name)
115 .set_item(Some(item))
116 .condition_expression("attribute_not_exists(sk)")
117 .send()
118 .await
119 .map_err(|e| JournalError::backend(format!("{e:?}")))?;
120 }
121 }
122 Ok(())
123 }
124
125 async fn delete_messages_to(
126 &self,
127 persistence_id: &str,
128 to_sequence_nr: u64,
129 ) -> Result<(), JournalError> {
130 for seq in 1..=to_sequence_nr {
131 let mut key = HashMap::new();
132 key.insert("pid".into(), AttributeValue::S(persistence_id.into()));
133 key.insert("sk".into(), AttributeValue::S(event_sk(seq)));
134 let _ = self
135 .client
136 .update_item()
137 .table_name(&self.cfg.table_name)
138 .set_key(Some(key))
139 .update_expression("SET #d = :t")
140 .expression_attribute_names("#d", "deleted")
141 .expression_attribute_values(":t", AttributeValue::Bool(true))
142 .return_values(ReturnValue::None)
143 .send()
144 .await;
145 }
146 Ok(())
147 }
148
149 async fn replay_messages(
150 &self,
151 persistence_id: &str,
152 from: u64,
153 to: u64,
154 max: u64,
155 ) -> Result<Vec<PersistentRepr>, JournalError> {
156 let limit = if max > i32::MAX as u64 { i32::MAX } else { max as i32 };
157 let from_sk = event_sk(from);
158 let to_sk = event_sk(to);
159 let out = self
160 .client
161 .query()
162 .table_name(&self.cfg.table_name)
163 .key_condition_expression("#p = :p AND #s BETWEEN :from AND :to")
164 .expression_attribute_names("#p", "pid")
165 .expression_attribute_names("#s", "sk")
166 .expression_attribute_values(":p", AttributeValue::S(persistence_id.into()))
167 .expression_attribute_values(":from", AttributeValue::S(from_sk))
168 .expression_attribute_values(":to", AttributeValue::S(to_sk))
169 .limit(limit)
170 .send()
171 .await
172 .map_err(|e| JournalError::backend(format!("{e:?}")))?;
173 let mut results = Vec::new();
174 for item in out.items() {
175 let deleted = item.get("deleted").and_then(|v| v.as_bool().ok()).copied().unwrap_or(false);
176 if deleted {
177 continue;
178 }
179 let seq =
180 item.get("seq").and_then(|v| v.as_n().ok()).and_then(|s| s.parse::<u64>().ok()).unwrap_or(0);
181 let payload = item
182 .get("payload")
183 .and_then(|v| v.as_b().ok())
184 .map(|b| b.as_ref().to_vec())
185 .unwrap_or_default();
186 let manifest = item.get("manifest").and_then(|v| v.as_s().ok()).cloned().unwrap_or_default();
187 let writer_uuid =
188 item.get("writer_uuid").and_then(|v| v.as_s().ok()).cloned().unwrap_or_default();
189 let tags = item.get("tags").and_then(|v| v.as_ss().ok()).cloned().unwrap_or_default();
190 results.push(PersistentRepr {
191 persistence_id: persistence_id.to_string(),
192 sequence_nr: seq,
193 payload,
194 manifest,
195 writer_uuid,
196 deleted,
197 tags,
198 });
199 }
200 Ok(results)
201 }
202
203 async fn highest_sequence_nr(&self, persistence_id: &str, _from: u64) -> Result<u64, JournalError> {
204 self.current_max(persistence_id).await
205 }
206}