1use std::sync::Arc;
8
9use async_trait::async_trait;
10use atomr_persistence::{Journal, JournalError, PersistentRepr};
11use sqlx::any::AnyPoolOptions;
12use sqlx::AnyPool;
13
14use crate::config::SqlConfig;
15use crate::schema::{ensure_schema, init_drivers};
16use crate::worm::{compute_row_hash, WormConfig};
17
/// Converts `v` to `i64`, saturating at `i64::MAX` when it does not fit.
///
/// Used to bind unsigned bounds/limits as SQL integers without wrapping negative.
fn clamp_i64(v: u64) -> i64 {
    i64::try_from(v).unwrap_or(i64::MAX)
}
27
/// Extracts a valid-time value from the first tag of the form `valid_time:<i64>`.
///
/// Tags carrying the prefix but an unparsable value are skipped, and the search continues
/// with the next tag. Returns `None` when no tag qualifies.
fn parse_valid_time(tags: &[String]) -> Option<i64> {
    tags.iter()
        .find_map(|tag| tag.strip_prefix("valid_time:")?.parse::<i64>().ok())
}
40
/// SQL-backed [`Journal`] implementation running on an sqlx [`AnyPool`].
///
/// Construct it with `SqlJournal::connect` or `SqlJournal::from_pool`. Optionally switch on
/// write-once behaviour afterwards with `SqlJournal::with_worm`.
pub struct SqlJournal {
    /// Pool that every query and write transaction of this journal runs against.
    pool: AnyPool,
    /// Configuration handed to schema setup and WORM trigger installation, and exposed via
    /// `config()`.
    cfg: SqlConfig,
    /// WORM settings. `hash_chain` makes `write_messages` link rows through
    /// `prev_hash`/`row_hash`. `deny_update_delete` makes `with_worm` install triggers.
    worm: WormConfig,
}
46
47impl SqlJournal {
48 pub async fn connect(cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
50 init_drivers();
51 let pool = AnyPoolOptions::new()
52 .max_connections(cfg.max_connections)
53 .connect(&cfg.url)
54 .await
55 .map_err(JournalError::backend)?;
56 ensure_schema(&pool, &cfg).await?;
57 Ok(Arc::new(Self { pool, cfg, worm: WormConfig::default() }))
58 }
59
60 pub async fn from_pool(pool: AnyPool, cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
62 ensure_schema(&pool, &cfg).await?;
63 Ok(Arc::new(Self { pool, cfg, worm: WormConfig::default() }))
64 }
65
66 pub async fn with_worm(self: Arc<Self>, worm: WormConfig) -> Result<Arc<Self>, JournalError> {
73 if worm.deny_update_delete {
74 crate::schema::install_worm_triggers(&self.pool, &self.cfg).await?;
75 }
76 Ok(Arc::new(Self { pool: self.pool.clone(), cfg: self.cfg.clone(), worm }))
77 }
78
79 pub fn pool(&self) -> &AnyPool {
80 &self.pool
81 }
82
83 pub fn config(&self) -> &SqlConfig {
84 &self.cfg
85 }
86
87 pub fn worm_config(&self) -> WormConfig {
88 self.worm
89 }
90
91 async fn current_highest(&self, pid: &str) -> Result<u64, JournalError> {
92 let row: Option<(Option<i64>,)> =
93 sqlx::query_as("SELECT MAX(sequence_nr) FROM event_journal WHERE persistence_id = ?")
94 .bind(pid)
95 .fetch_optional(&self.pool)
96 .await
97 .map_err(JournalError::backend)?;
98 Ok(row.and_then(|(v,)| v).map(|v| v as u64).unwrap_or(0))
99 }
100
101 pub async fn replay_as_of(
106 &self,
107 pid: &str,
108 system_time_nanos: i64,
109 ) -> Result<Vec<PersistentRepr>, JournalError> {
110 let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
111 "SELECT persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted \
112 FROM event_journal \
113 WHERE persistence_id = ? AND deleted = 0 \
114 AND COALESCE(system_time, created_at) <= ? \
115 ORDER BY sequence_nr ASC",
116 )
117 .bind(pid)
118 .bind(system_time_nanos)
119 .fetch_all(&self.pool)
120 .await
121 .map_err(JournalError::backend)?;
122 self.hydrate(rows).await
123 }
124
125 pub async fn replay_valid_as_of(
130 &self,
131 pid: &str,
132 valid_time_nanos: i64,
133 system_time_nanos: i64,
134 ) -> Result<Vec<PersistentRepr>, JournalError> {
135 let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
136 "SELECT persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted \
137 FROM event_journal \
138 WHERE persistence_id = ? AND deleted = 0 \
139 AND COALESCE(system_time, created_at) <= ? \
140 AND COALESCE(valid_time, COALESCE(system_time, created_at)) <= ? \
141 ORDER BY sequence_nr ASC",
142 )
143 .bind(pid)
144 .bind(system_time_nanos)
145 .bind(valid_time_nanos)
146 .fetch_all(&self.pool)
147 .await
148 .map_err(JournalError::backend)?;
149 self.hydrate(rows).await
150 }
151
152 async fn hydrate(
154 &self,
155 rows: Vec<(String, i64, Vec<u8>, String, String, i32)>,
156 ) -> Result<Vec<PersistentRepr>, JournalError> {
157 let mut out = Vec::with_capacity(rows.len());
158 for (pid, seq, payload, manifest, writer, deleted) in rows {
159 let tags: Vec<(String,)> =
160 sqlx::query_as("SELECT tag FROM event_tags WHERE persistence_id = ? AND sequence_nr = ?")
161 .bind(&pid)
162 .bind(seq)
163 .fetch_all(&self.pool)
164 .await
165 .map_err(JournalError::backend)?;
166 out.push(PersistentRepr {
167 persistence_id: pid,
168 sequence_nr: seq as u64,
169 payload,
170 manifest,
171 writer_uuid: writer,
172 deleted: deleted != 0,
173 tags: tags.into_iter().map(|(t,)| t).collect(),
174 });
175 }
176 Ok(out)
177 }
178}
179
180#[async_trait]
181impl Journal for SqlJournal {
182 async fn write_messages(&self, messages: Vec<PersistentRepr>) -> Result<(), JournalError> {
183 if messages.is_empty() {
184 return Ok(());
185 }
186 let mut tx = self.pool.begin().await.map_err(JournalError::backend)?;
187 let mut by_pid: std::collections::BTreeMap<String, Vec<PersistentRepr>> =
188 std::collections::BTreeMap::new();
189 for m in messages {
190 by_pid.entry(m.persistence_id.clone()).or_default().push(m);
191 }
192 for (pid, batch) in by_pid {
193 let row: Option<(Option<i64>,)> =
194 sqlx::query_as("SELECT MAX(sequence_nr) FROM event_journal WHERE persistence_id = ?")
195 .bind(&pid)
196 .fetch_optional(&mut *tx)
197 .await
198 .map_err(JournalError::backend)?;
199 let start = row.and_then(|(v,)| v).map(|v| v as u64 + 1).unwrap_or(1);
200
201 let mut prev_hash: Vec<u8> = if self.worm.hash_chain {
204 let last: Option<(Option<Vec<u8>>,)> = sqlx::query_as(
205 "SELECT row_hash FROM event_journal WHERE persistence_id = ? \
206 ORDER BY sequence_nr DESC LIMIT 1",
207 )
208 .bind(&pid)
209 .fetch_optional(&mut *tx)
210 .await
211 .map_err(JournalError::backend)?;
212 last.and_then(|(h,)| h).unwrap_or_default()
213 } else {
214 Vec::new()
215 };
216
217 for (expected, msg) in (start..).zip(batch) {
218 if msg.sequence_nr != expected {
219 return Err(JournalError::SequenceOutOfOrder { expected, got: msg.sequence_nr });
220 }
221 let created_at = chrono::Utc::now().timestamp_millis();
222 let system_time = created_at;
225 let valid_time = parse_valid_time(&msg.tags);
226
227 let (row_hash_opt, prev_for_insert): (Option<Vec<u8>>, Option<Vec<u8>>) =
229 if self.worm.hash_chain {
230 let prev_for_insert =
231 if prev_hash.is_empty() { None } else { Some(prev_hash.clone()) };
232 let rh = compute_row_hash(
233 &prev_hash,
234 &msg.persistence_id,
235 msg.sequence_nr,
236 &msg.payload,
237 created_at,
238 );
239 prev_hash = rh.clone();
240 (Some(rh), prev_for_insert)
241 } else {
242 (None, None)
243 };
244
245 sqlx::query(
246 "INSERT INTO event_journal (persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted, created_at, prev_hash, row_hash, system_time, valid_time) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
247 )
248 .bind(&msg.persistence_id)
249 .bind(msg.sequence_nr as i64)
250 .bind(msg.payload.clone())
251 .bind(&msg.manifest)
252 .bind(&msg.writer_uuid)
253 .bind(0i32)
254 .bind(created_at)
255 .bind(prev_for_insert)
256 .bind(row_hash_opt)
257 .bind(system_time)
258 .bind(valid_time)
259 .execute(&mut *tx)
260 .await
261 .map_err(JournalError::backend)?;
262 for tag in &msg.tags {
263 sqlx::query("INSERT INTO event_tags (persistence_id, sequence_nr, tag) VALUES (?, ?, ?)")
264 .bind(&msg.persistence_id)
265 .bind(msg.sequence_nr as i64)
266 .bind(tag)
267 .execute(&mut *tx)
268 .await
269 .map_err(JournalError::backend)?;
270 }
271 }
272 }
273 tx.commit().await.map_err(JournalError::backend)?;
274 Ok(())
275 }
276
277 async fn delete_messages_to(
278 &self,
279 persistence_id: &str,
280 to_sequence_nr: u64,
281 ) -> Result<(), JournalError> {
282 sqlx::query("UPDATE event_journal SET deleted = 1 WHERE persistence_id = ? AND sequence_nr <= ?")
283 .bind(persistence_id)
284 .bind(to_sequence_nr as i64)
285 .execute(&self.pool)
286 .await
287 .map_err(JournalError::backend)?;
288 Ok(())
289 }
290
291 async fn replay_messages(
292 &self,
293 persistence_id: &str,
294 from: u64,
295 to: u64,
296 max: u64,
297 ) -> Result<Vec<PersistentRepr>, JournalError> {
298 let limit = clamp_i64(max);
299 let to_bound = clamp_i64(to);
300 let from_bound = clamp_i64(from);
301 let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
302 "SELECT persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted FROM event_journal \
303 WHERE persistence_id = ? AND sequence_nr >= ? AND sequence_nr <= ? AND deleted = 0 \
304 ORDER BY sequence_nr ASC LIMIT ?",
305 )
306 .bind(persistence_id)
307 .bind(from_bound)
308 .bind(to_bound)
309 .bind(limit)
310 .fetch_all(&self.pool)
311 .await
312 .map_err(JournalError::backend)?;
313 let mut out = Vec::with_capacity(rows.len());
314 for (pid, seq, payload, manifest, writer, deleted) in rows {
315 let tags: Vec<(String,)> =
316 sqlx::query_as("SELECT tag FROM event_tags WHERE persistence_id = ? AND sequence_nr = ?")
317 .bind(&pid)
318 .bind(seq)
319 .fetch_all(&self.pool)
320 .await
321 .map_err(JournalError::backend)?;
322 out.push(PersistentRepr {
323 persistence_id: pid,
324 sequence_nr: seq as u64,
325 payload,
326 manifest,
327 writer_uuid: writer,
328 deleted: deleted != 0,
329 tags: tags.into_iter().map(|(t,)| t).collect(),
330 });
331 }
332 Ok(out)
333 }
334
335 async fn highest_sequence_nr(
336 &self,
337 persistence_id: &str,
338 _from_sequence_nr: u64,
339 ) -> Result<u64, JournalError> {
340 self.current_highest(persistence_id).await
341 }
342
343 async fn events_by_tag(
344 &self,
345 tag: &str,
346 from_offset: u64,
347 max: u64,
348 ) -> Result<Vec<PersistentRepr>, JournalError> {
349 let limit = clamp_i64(max);
350 let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
351 "SELECT j.persistence_id, j.sequence_nr, j.payload, j.manifest, j.writer_uuid, j.deleted \
352 FROM event_journal j INNER JOIN event_tags t \
353 ON j.persistence_id = t.persistence_id AND j.sequence_nr = t.sequence_nr \
354 WHERE t.tag = ? AND j.sequence_nr >= ? AND j.deleted = 0 \
355 ORDER BY j.persistence_id, j.sequence_nr ASC LIMIT ?",
356 )
357 .bind(tag)
358 .bind(clamp_i64(from_offset))
359 .bind(limit)
360 .fetch_all(&self.pool)
361 .await
362 .map_err(JournalError::backend)?;
363 Ok(rows
364 .into_iter()
365 .map(|(pid, seq, payload, manifest, writer, deleted)| PersistentRepr {
366 persistence_id: pid,
367 sequence_nr: seq as u64,
368 payload,
369 manifest,
370 writer_uuid: writer,
371 deleted: deleted != 0,
372 tags: vec![tag.to_string()],
373 })
374 .collect())
375 }
376}