Skip to main content

atomr_persistence_sql/
journal.rs

1//! `Journal` implementation backed by sqlx.
2//!
3//! Uses the `sqlx::Any` pool so the same code targets every supported
4//! dialect. Tag writes go to a companion `event_tags` table that powers
5//! `events_by_tag`.
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use atomr_persistence::{Journal, JournalError, PersistentRepr};
11use sqlx::any::AnyPoolOptions;
12use sqlx::AnyPool;
13
14use crate::config::SqlConfig;
15use crate::schema::{ensure_schema, init_drivers};
16
/// Saturating cast from `u64` to `i64` so `u64::MAX` sentinels turn into
/// `i64::MAX` instead of wrapping negative.
fn clamp_i64(v: u64) -> i64 {
    // `try_from` fails exactly when `v` exceeds `i64::MAX`; saturate there.
    i64::try_from(v).unwrap_or(i64::MAX)
}
26
/// `Journal` implementation backed by a shared `sqlx::Any` connection pool.
///
/// Construct via [`SqlJournal::connect`] or [`SqlJournal::from_pool`]; both
/// ensure the schema exists before returning the journal.
pub struct SqlJournal {
    // Connection pool shared by every journal operation.
    pool: AnyPool,
    // Configuration kept after connect, exposed read-only via `config()`.
    cfg: SqlConfig,
}
31
impl SqlJournal {
    /// Connect, install drivers, and optionally run migrations.
    ///
    /// Order matters: `init_drivers()` must run before the `Any` pool is
    /// built, and `ensure_schema` must succeed before the journal is handed
    /// out, so callers never see a half-initialized store.
    ///
    /// # Errors
    /// Returns a backend-wrapped [`JournalError`] if the pool cannot connect
    /// or the schema setup fails.
    pub async fn connect(cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
        init_drivers();
        let pool = AnyPoolOptions::new()
            .max_connections(cfg.max_connections)
            .connect(&cfg.url)
            .await
            .map_err(JournalError::backend)?;
        ensure_schema(&pool, &cfg).await?;
        Ok(Arc::new(Self { pool, cfg }))
    }

    /// Reuse an existing pool (for tests or app-wide sharing).
    ///
    /// Still runs `ensure_schema`; callers of this constructor are expected
    /// to have already installed drivers for the pool they pass in.
    pub async fn from_pool(pool: AnyPool, cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
        ensure_schema(&pool, &cfg).await?;
        Ok(Arc::new(Self { pool, cfg }))
    }

    /// Borrow the underlying connection pool.
    pub fn pool(&self) -> &AnyPool {
        &self.pool
    }

    /// Borrow the configuration this journal was built with.
    pub fn config(&self) -> &SqlConfig {
        &self.cfg
    }

    /// Highest stored sequence number for `pid`, or 0 when none exists.
    ///
    /// Note: the MAX has no `deleted = 0` filter, so soft-deleted events
    /// still count — deletion must never roll the sequence counter back.
    async fn current_highest(&self, pid: &str) -> Result<u64, JournalError> {
        let row: Option<(Option<i64>,)> =
            sqlx::query_as("SELECT MAX(sequence_nr) FROM event_journal WHERE persistence_id = ?")
                .bind(pid)
                .fetch_optional(&self.pool)
                .await
                .map_err(JournalError::backend)?;
        // MAX over an empty table yields SQL NULL (inner None); map both the
        // no-row and NULL cases to 0.
        Ok(row.and_then(|(v,)| v).map(|v| v as u64).unwrap_or(0))
    }
}
69
70#[async_trait]
71impl Journal for SqlJournal {
72    async fn write_messages(&self, messages: Vec<PersistentRepr>) -> Result<(), JournalError> {
73        if messages.is_empty() {
74            return Ok(());
75        }
76        let mut tx = self.pool.begin().await.map_err(JournalError::backend)?;
77        let mut by_pid: std::collections::BTreeMap<String, Vec<PersistentRepr>> =
78            std::collections::BTreeMap::new();
79        for m in messages {
80            by_pid.entry(m.persistence_id.clone()).or_default().push(m);
81        }
82        for (pid, batch) in by_pid {
83            let row: Option<(Option<i64>,)> =
84                sqlx::query_as("SELECT MAX(sequence_nr) FROM event_journal WHERE persistence_id = ?")
85                    .bind(&pid)
86                    .fetch_optional(&mut *tx)
87                    .await
88                    .map_err(JournalError::backend)?;
89            let start = row.and_then(|(v,)| v).map(|v| v as u64 + 1).unwrap_or(1);
90            for (expected, msg) in (start..).zip(batch) {
91                if msg.sequence_nr != expected {
92                    return Err(JournalError::SequenceOutOfOrder { expected, got: msg.sequence_nr });
93                }
94                let created_at = chrono::Utc::now().timestamp_millis();
95                sqlx::query(
96                    "INSERT INTO event_journal (persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
97                )
98                .bind(&msg.persistence_id)
99                .bind(msg.sequence_nr as i64)
100                .bind(msg.payload.clone())
101                .bind(&msg.manifest)
102                .bind(&msg.writer_uuid)
103                .bind(0i32)
104                .bind(created_at)
105                .execute(&mut *tx)
106                .await
107                .map_err(JournalError::backend)?;
108                for tag in &msg.tags {
109                    sqlx::query("INSERT INTO event_tags (persistence_id, sequence_nr, tag) VALUES (?, ?, ?)")
110                        .bind(&msg.persistence_id)
111                        .bind(msg.sequence_nr as i64)
112                        .bind(tag)
113                        .execute(&mut *tx)
114                        .await
115                        .map_err(JournalError::backend)?;
116                }
117            }
118        }
119        tx.commit().await.map_err(JournalError::backend)?;
120        Ok(())
121    }
122
123    async fn delete_messages_to(
124        &self,
125        persistence_id: &str,
126        to_sequence_nr: u64,
127    ) -> Result<(), JournalError> {
128        sqlx::query("UPDATE event_journal SET deleted = 1 WHERE persistence_id = ? AND sequence_nr <= ?")
129            .bind(persistence_id)
130            .bind(to_sequence_nr as i64)
131            .execute(&self.pool)
132            .await
133            .map_err(JournalError::backend)?;
134        Ok(())
135    }
136
137    async fn replay_messages(
138        &self,
139        persistence_id: &str,
140        from: u64,
141        to: u64,
142        max: u64,
143    ) -> Result<Vec<PersistentRepr>, JournalError> {
144        let limit = clamp_i64(max);
145        let to_bound = clamp_i64(to);
146        let from_bound = clamp_i64(from);
147        let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
148            "SELECT persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted FROM event_journal \
149             WHERE persistence_id = ? AND sequence_nr >= ? AND sequence_nr <= ? AND deleted = 0 \
150             ORDER BY sequence_nr ASC LIMIT ?",
151        )
152        .bind(persistence_id)
153        .bind(from_bound)
154        .bind(to_bound)
155        .bind(limit)
156        .fetch_all(&self.pool)
157        .await
158        .map_err(JournalError::backend)?;
159        let mut out = Vec::with_capacity(rows.len());
160        for (pid, seq, payload, manifest, writer, deleted) in rows {
161            let tags: Vec<(String,)> =
162                sqlx::query_as("SELECT tag FROM event_tags WHERE persistence_id = ? AND sequence_nr = ?")
163                    .bind(&pid)
164                    .bind(seq)
165                    .fetch_all(&self.pool)
166                    .await
167                    .map_err(JournalError::backend)?;
168            out.push(PersistentRepr {
169                persistence_id: pid,
170                sequence_nr: seq as u64,
171                payload,
172                manifest,
173                writer_uuid: writer,
174                deleted: deleted != 0,
175                tags: tags.into_iter().map(|(t,)| t).collect(),
176            });
177        }
178        Ok(out)
179    }
180
181    async fn highest_sequence_nr(
182        &self,
183        persistence_id: &str,
184        _from_sequence_nr: u64,
185    ) -> Result<u64, JournalError> {
186        self.current_highest(persistence_id).await
187    }
188
189    async fn events_by_tag(
190        &self,
191        tag: &str,
192        from_offset: u64,
193        max: u64,
194    ) -> Result<Vec<PersistentRepr>, JournalError> {
195        let limit = clamp_i64(max);
196        let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
197            "SELECT j.persistence_id, j.sequence_nr, j.payload, j.manifest, j.writer_uuid, j.deleted \
198             FROM event_journal j INNER JOIN event_tags t \
199             ON j.persistence_id = t.persistence_id AND j.sequence_nr = t.sequence_nr \
200             WHERE t.tag = ? AND j.sequence_nr >= ? AND j.deleted = 0 \
201             ORDER BY j.persistence_id, j.sequence_nr ASC LIMIT ?",
202        )
203        .bind(tag)
204        .bind(clamp_i64(from_offset))
205        .bind(limit)
206        .fetch_all(&self.pool)
207        .await
208        .map_err(JournalError::backend)?;
209        Ok(rows
210            .into_iter()
211            .map(|(pid, seq, payload, manifest, writer, deleted)| PersistentRepr {
212                persistence_id: pid,
213                sequence_nr: seq as u64,
214                payload,
215                manifest,
216                writer_uuid: writer,
217                deleted: deleted != 0,
218                tags: vec![tag.to_string()],
219            })
220            .collect())
221    }
222}