Skip to main content

atomr_persistence_sql/
journal.rs

1//! `Journal` implementation backed by sqlx.
2//!
3//! Uses the `sqlx::Any` pool so the same code targets every supported
4//! dialect. Tag writes go to a companion `event_tags` table that powers
5//! `events_by_tag`.
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use atomr_persistence::{Journal, JournalError, PersistentRepr};
11use sqlx::any::AnyPoolOptions;
12use sqlx::AnyPool;
13
14use crate::config::SqlConfig;
15use crate::schema::{ensure_schema, init_drivers};
16use crate::worm::{compute_row_hash, WormConfig};
17
/// Converts a `u64` into an `i64`, saturating at `i64::MAX`.
///
/// Values that do not fit in an `i64` (for example a `u64::MAX` "unbounded"
/// sentinel) become `i64::MAX` rather than wrapping to a negative number.
fn clamp_i64(v: u64) -> i64 {
    i64::try_from(v).unwrap_or(i64::MAX)
}
27
/// FR-8: pulls the `valid_time` (nanos) out of a `valid_time:<nanos>` tag.
///
/// Tags are scanned in order and the first one whose suffix parses as an
/// `i64` wins; tags with the prefix but an unparsable suffix are skipped.
/// Returns `None` when no tag qualifies, so the column stays NULL.
fn parse_valid_time(tags: &[String]) -> Option<i64> {
    tags.iter()
        .filter_map(|tag| tag.strip_prefix("valid_time:"))
        .find_map(|nanos| nanos.parse::<i64>().ok())
}
40
41pub struct SqlJournal {
42    pool: AnyPool,
43    cfg: SqlConfig,
44    worm: WormConfig,
45}
46
47impl SqlJournal {
48    /// Connect, install drivers, and optionally run migrations.
49    pub async fn connect(cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
50        init_drivers();
51        let pool = AnyPoolOptions::new()
52            .max_connections(cfg.max_connections)
53            .connect(&cfg.url)
54            .await
55            .map_err(JournalError::backend)?;
56        ensure_schema(&pool, &cfg).await?;
57        Ok(Arc::new(Self { pool, cfg, worm: WormConfig::default() }))
58    }
59
60    /// Reuse an existing pool (for tests or app-wide sharing).
61    pub async fn from_pool(pool: AnyPool, cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
62        ensure_schema(&pool, &cfg).await?;
63        Ok(Arc::new(Self { pool, cfg, worm: WormConfig::default() }))
64    }
65
66    /// Turn on WORM protections (FR-9).
67    ///
68    /// When `deny_update_delete` is set, this installs the dialect's
69    /// append-only DDL immediately. When `hash_chain` is set, subsequent
70    /// writes maintain the per-pid tamper-evident hash chain. Consumes the
71    /// (typically freshly-built) journal and returns a reconfigured `Arc`.
72    pub async fn with_worm(self: Arc<Self>, worm: WormConfig) -> Result<Arc<Self>, JournalError> {
73        if worm.deny_update_delete {
74            crate::schema::install_worm_triggers(&self.pool, &self.cfg).await?;
75        }
76        Ok(Arc::new(Self { pool: self.pool.clone(), cfg: self.cfg.clone(), worm }))
77    }
78
79    pub fn pool(&self) -> &AnyPool {
80        &self.pool
81    }
82
83    pub fn config(&self) -> &SqlConfig {
84        &self.cfg
85    }
86
87    pub fn worm_config(&self) -> WormConfig {
88        self.worm
89    }
90
91    async fn current_highest(&self, pid: &str) -> Result<u64, JournalError> {
92        let row: Option<(Option<i64>,)> =
93            sqlx::query_as("SELECT MAX(sequence_nr) FROM event_journal WHERE persistence_id = ?")
94                .bind(pid)
95                .fetch_optional(&self.pool)
96                .await
97                .map_err(JournalError::backend)?;
98        Ok(row.and_then(|(v,)| v).map(|v| v as u64).unwrap_or(0))
99    }
100
101    /// FR-8 — system-time as-of: rows recorded at or before
102    /// `system_time_nanos`. `system_time` falls back to `created_at` for rows
103    /// written before the column existed. Later-recorded restatements (whose
104    /// `system_time` is greater) are excluded → no lookahead.
105    pub async fn replay_as_of(
106        &self,
107        pid: &str,
108        system_time_nanos: i64,
109    ) -> Result<Vec<PersistentRepr>, JournalError> {
110        let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
111            "SELECT persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted \
112             FROM event_journal \
113             WHERE persistence_id = ? AND deleted = 0 \
114               AND COALESCE(system_time, created_at) <= ? \
115             ORDER BY sequence_nr ASC",
116        )
117        .bind(pid)
118        .bind(system_time_nanos)
119        .fetch_all(&self.pool)
120        .await
121        .map_err(JournalError::backend)?;
122        self.hydrate(rows).await
123    }
124
125    /// FR-8 — bitemporal slice: rows whose `valid_time` is at or before
126    /// `valid_time_nanos`, restricted to what was known to the system at
127    /// `system_time_nanos`. Rows without a `valid_time` are treated as valid
128    /// from their `system_time` (always-valid) so they remain visible.
129    pub async fn replay_valid_as_of(
130        &self,
131        pid: &str,
132        valid_time_nanos: i64,
133        system_time_nanos: i64,
134    ) -> Result<Vec<PersistentRepr>, JournalError> {
135        let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
136            "SELECT persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted \
137             FROM event_journal \
138             WHERE persistence_id = ? AND deleted = 0 \
139               AND COALESCE(system_time, created_at) <= ? \
140               AND COALESCE(valid_time, COALESCE(system_time, created_at)) <= ? \
141             ORDER BY sequence_nr ASC",
142        )
143        .bind(pid)
144        .bind(system_time_nanos)
145        .bind(valid_time_nanos)
146        .fetch_all(&self.pool)
147        .await
148        .map_err(JournalError::backend)?;
149        self.hydrate(rows).await
150    }
151
152    /// Attach tags to bare journal rows, reproducing `replay_messages` shape.
153    async fn hydrate(
154        &self,
155        rows: Vec<(String, i64, Vec<u8>, String, String, i32)>,
156    ) -> Result<Vec<PersistentRepr>, JournalError> {
157        let mut out = Vec::with_capacity(rows.len());
158        for (pid, seq, payload, manifest, writer, deleted) in rows {
159            let tags: Vec<(String,)> =
160                sqlx::query_as("SELECT tag FROM event_tags WHERE persistence_id = ? AND sequence_nr = ?")
161                    .bind(&pid)
162                    .bind(seq)
163                    .fetch_all(&self.pool)
164                    .await
165                    .map_err(JournalError::backend)?;
166            out.push(PersistentRepr {
167                persistence_id: pid,
168                sequence_nr: seq as u64,
169                payload,
170                manifest,
171                writer_uuid: writer,
172                deleted: deleted != 0,
173                tags: tags.into_iter().map(|(t,)| t).collect(),
174            });
175        }
176        Ok(out)
177    }
178}
179
180#[async_trait]
181impl Journal for SqlJournal {
182    async fn write_messages(&self, messages: Vec<PersistentRepr>) -> Result<(), JournalError> {
183        if messages.is_empty() {
184            return Ok(());
185        }
186        let mut tx = self.pool.begin().await.map_err(JournalError::backend)?;
187        let mut by_pid: std::collections::BTreeMap<String, Vec<PersistentRepr>> =
188            std::collections::BTreeMap::new();
189        for m in messages {
190            by_pid.entry(m.persistence_id.clone()).or_default().push(m);
191        }
192        for (pid, batch) in by_pid {
193            let row: Option<(Option<i64>,)> =
194                sqlx::query_as("SELECT MAX(sequence_nr) FROM event_journal WHERE persistence_id = ?")
195                    .bind(&pid)
196                    .fetch_optional(&mut *tx)
197                    .await
198                    .map_err(JournalError::backend)?;
199            let start = row.and_then(|(v,)| v).map(|v| v as u64 + 1).unwrap_or(1);
200
201            // Seed the running hash from the latest existing row so the chain
202            // survives across separate write batches.
203            let mut prev_hash: Vec<u8> = if self.worm.hash_chain {
204                let last: Option<(Option<Vec<u8>>,)> = sqlx::query_as(
205                    "SELECT row_hash FROM event_journal WHERE persistence_id = ? \
206                     ORDER BY sequence_nr DESC LIMIT 1",
207                )
208                .bind(&pid)
209                .fetch_optional(&mut *tx)
210                .await
211                .map_err(JournalError::backend)?;
212                last.and_then(|(h,)| h).unwrap_or_default()
213            } else {
214                Vec::new()
215            };
216
217            for (expected, msg) in (start..).zip(batch) {
218                if msg.sequence_nr != expected {
219                    return Err(JournalError::SequenceOutOfOrder { expected, got: msg.sequence_nr });
220                }
221                let created_at = chrono::Utc::now().timestamp_millis();
222                // FR-8: system_time is backend-assigned (defaults to created_at);
223                // valid_time is parsed from a `valid_time:<nanos>` tag if present.
224                let system_time = created_at;
225                let valid_time = parse_valid_time(&msg.tags);
226
227                // FR-9: compute the chain hash for this row when enabled.
228                let (row_hash_opt, prev_for_insert): (Option<Vec<u8>>, Option<Vec<u8>>) =
229                    if self.worm.hash_chain {
230                        let prev_for_insert =
231                            if prev_hash.is_empty() { None } else { Some(prev_hash.clone()) };
232                        let rh = compute_row_hash(
233                            &prev_hash,
234                            &msg.persistence_id,
235                            msg.sequence_nr,
236                            &msg.payload,
237                            created_at,
238                        );
239                        prev_hash = rh.clone();
240                        (Some(rh), prev_for_insert)
241                    } else {
242                        (None, None)
243                    };
244
245                sqlx::query(
246                    "INSERT INTO event_journal (persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted, created_at, prev_hash, row_hash, system_time, valid_time) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
247                )
248                .bind(&msg.persistence_id)
249                .bind(msg.sequence_nr as i64)
250                .bind(msg.payload.clone())
251                .bind(&msg.manifest)
252                .bind(&msg.writer_uuid)
253                .bind(0i32)
254                .bind(created_at)
255                .bind(prev_for_insert)
256                .bind(row_hash_opt)
257                .bind(system_time)
258                .bind(valid_time)
259                .execute(&mut *tx)
260                .await
261                .map_err(JournalError::backend)?;
262                for tag in &msg.tags {
263                    sqlx::query("INSERT INTO event_tags (persistence_id, sequence_nr, tag) VALUES (?, ?, ?)")
264                        .bind(&msg.persistence_id)
265                        .bind(msg.sequence_nr as i64)
266                        .bind(tag)
267                        .execute(&mut *tx)
268                        .await
269                        .map_err(JournalError::backend)?;
270                }
271            }
272        }
273        tx.commit().await.map_err(JournalError::backend)?;
274        Ok(())
275    }
276
277    async fn delete_messages_to(
278        &self,
279        persistence_id: &str,
280        to_sequence_nr: u64,
281    ) -> Result<(), JournalError> {
282        sqlx::query("UPDATE event_journal SET deleted = 1 WHERE persistence_id = ? AND sequence_nr <= ?")
283            .bind(persistence_id)
284            .bind(to_sequence_nr as i64)
285            .execute(&self.pool)
286            .await
287            .map_err(JournalError::backend)?;
288        Ok(())
289    }
290
291    async fn replay_messages(
292        &self,
293        persistence_id: &str,
294        from: u64,
295        to: u64,
296        max: u64,
297    ) -> Result<Vec<PersistentRepr>, JournalError> {
298        let limit = clamp_i64(max);
299        let to_bound = clamp_i64(to);
300        let from_bound = clamp_i64(from);
301        let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
302            "SELECT persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted FROM event_journal \
303             WHERE persistence_id = ? AND sequence_nr >= ? AND sequence_nr <= ? AND deleted = 0 \
304             ORDER BY sequence_nr ASC LIMIT ?",
305        )
306        .bind(persistence_id)
307        .bind(from_bound)
308        .bind(to_bound)
309        .bind(limit)
310        .fetch_all(&self.pool)
311        .await
312        .map_err(JournalError::backend)?;
313        let mut out = Vec::with_capacity(rows.len());
314        for (pid, seq, payload, manifest, writer, deleted) in rows {
315            let tags: Vec<(String,)> =
316                sqlx::query_as("SELECT tag FROM event_tags WHERE persistence_id = ? AND sequence_nr = ?")
317                    .bind(&pid)
318                    .bind(seq)
319                    .fetch_all(&self.pool)
320                    .await
321                    .map_err(JournalError::backend)?;
322            out.push(PersistentRepr {
323                persistence_id: pid,
324                sequence_nr: seq as u64,
325                payload,
326                manifest,
327                writer_uuid: writer,
328                deleted: deleted != 0,
329                tags: tags.into_iter().map(|(t,)| t).collect(),
330            });
331        }
332        Ok(out)
333    }
334
335    async fn highest_sequence_nr(
336        &self,
337        persistence_id: &str,
338        _from_sequence_nr: u64,
339    ) -> Result<u64, JournalError> {
340        self.current_highest(persistence_id).await
341    }
342
343    async fn events_by_tag(
344        &self,
345        tag: &str,
346        from_offset: u64,
347        max: u64,
348    ) -> Result<Vec<PersistentRepr>, JournalError> {
349        let limit = clamp_i64(max);
350        let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
351            "SELECT j.persistence_id, j.sequence_nr, j.payload, j.manifest, j.writer_uuid, j.deleted \
352             FROM event_journal j INNER JOIN event_tags t \
353             ON j.persistence_id = t.persistence_id AND j.sequence_nr = t.sequence_nr \
354             WHERE t.tag = ? AND j.sequence_nr >= ? AND j.deleted = 0 \
355             ORDER BY j.persistence_id, j.sequence_nr ASC LIMIT ?",
356        )
357        .bind(tag)
358        .bind(clamp_i64(from_offset))
359        .bind(limit)
360        .fetch_all(&self.pool)
361        .await
362        .map_err(JournalError::backend)?;
363        Ok(rows
364            .into_iter()
365            .map(|(pid, seq, payload, manifest, writer, deleted)| PersistentRepr {
366                persistence_id: pid,
367                sequence_nr: seq as u64,
368                payload,
369                manifest,
370                writer_uuid: writer,
371                deleted: deleted != 0,
372                tags: vec![tag.to_string()],
373            })
374            .collect())
375    }
376}