1use std::sync::Arc;
8
9use async_trait::async_trait;
10use atomr_persistence::{Journal, JournalError, PersistentRepr};
11use sqlx::any::AnyPoolOptions;
12use sqlx::AnyPool;
13
14use crate::config::SqlConfig;
15use crate::schema::{ensure_schema, init_drivers};
16
/// Converts `v` to `i64`, saturating at `i64::MAX` instead of wrapping.
///
/// The queries in this module bind integer bounds as `i64`. A plain `as` cast would turn
/// large `u64` values such as `u64::MAX` negative, which is why this helper clamps.
fn clamp_i64(v: u64) -> i64 {
    // `i64::MAX as u64` is exactly representable, so after `min` the cast is lossless.
    v.min(i64::MAX as u64) as i64
}
26
27pub struct SqlJournal {
28 pool: AnyPool,
29 cfg: SqlConfig,
30}
31
32impl SqlJournal {
33 pub async fn connect(cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
35 init_drivers();
36 let pool = AnyPoolOptions::new()
37 .max_connections(cfg.max_connections)
38 .connect(&cfg.url)
39 .await
40 .map_err(JournalError::backend)?;
41 ensure_schema(&pool, &cfg).await?;
42 Ok(Arc::new(Self { pool, cfg }))
43 }
44
45 pub async fn from_pool(pool: AnyPool, cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
47 ensure_schema(&pool, &cfg).await?;
48 Ok(Arc::new(Self { pool, cfg }))
49 }
50
51 pub fn pool(&self) -> &AnyPool {
52 &self.pool
53 }
54
55 pub fn config(&self) -> &SqlConfig {
56 &self.cfg
57 }
58
59 async fn current_highest(&self, pid: &str) -> Result<u64, JournalError> {
60 let row: Option<(Option<i64>,)> =
61 sqlx::query_as("SELECT MAX(sequence_nr) FROM event_journal WHERE persistence_id = ?")
62 .bind(pid)
63 .fetch_optional(&self.pool)
64 .await
65 .map_err(JournalError::backend)?;
66 Ok(row.and_then(|(v,)| v).map(|v| v as u64).unwrap_or(0))
67 }
68}
69
70#[async_trait]
71impl Journal for SqlJournal {
72 async fn write_messages(&self, messages: Vec<PersistentRepr>) -> Result<(), JournalError> {
73 if messages.is_empty() {
74 return Ok(());
75 }
76 let mut tx = self.pool.begin().await.map_err(JournalError::backend)?;
77 let mut by_pid: std::collections::BTreeMap<String, Vec<PersistentRepr>> =
78 std::collections::BTreeMap::new();
79 for m in messages {
80 by_pid.entry(m.persistence_id.clone()).or_default().push(m);
81 }
82 for (pid, batch) in by_pid {
83 let row: Option<(Option<i64>,)> =
84 sqlx::query_as("SELECT MAX(sequence_nr) FROM event_journal WHERE persistence_id = ?")
85 .bind(&pid)
86 .fetch_optional(&mut *tx)
87 .await
88 .map_err(JournalError::backend)?;
89 let start = row.and_then(|(v,)| v).map(|v| v as u64 + 1).unwrap_or(1);
90 for (expected, msg) in (start..).zip(batch) {
91 if msg.sequence_nr != expected {
92 return Err(JournalError::SequenceOutOfOrder { expected, got: msg.sequence_nr });
93 }
94 let created_at = chrono::Utc::now().timestamp_millis();
95 sqlx::query(
96 "INSERT INTO event_journal (persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
97 )
98 .bind(&msg.persistence_id)
99 .bind(msg.sequence_nr as i64)
100 .bind(msg.payload.clone())
101 .bind(&msg.manifest)
102 .bind(&msg.writer_uuid)
103 .bind(0i32)
104 .bind(created_at)
105 .execute(&mut *tx)
106 .await
107 .map_err(JournalError::backend)?;
108 for tag in &msg.tags {
109 sqlx::query("INSERT INTO event_tags (persistence_id, sequence_nr, tag) VALUES (?, ?, ?)")
110 .bind(&msg.persistence_id)
111 .bind(msg.sequence_nr as i64)
112 .bind(tag)
113 .execute(&mut *tx)
114 .await
115 .map_err(JournalError::backend)?;
116 }
117 }
118 }
119 tx.commit().await.map_err(JournalError::backend)?;
120 Ok(())
121 }
122
123 async fn delete_messages_to(
124 &self,
125 persistence_id: &str,
126 to_sequence_nr: u64,
127 ) -> Result<(), JournalError> {
128 sqlx::query("UPDATE event_journal SET deleted = 1 WHERE persistence_id = ? AND sequence_nr <= ?")
129 .bind(persistence_id)
130 .bind(to_sequence_nr as i64)
131 .execute(&self.pool)
132 .await
133 .map_err(JournalError::backend)?;
134 Ok(())
135 }
136
137 async fn replay_messages(
138 &self,
139 persistence_id: &str,
140 from: u64,
141 to: u64,
142 max: u64,
143 ) -> Result<Vec<PersistentRepr>, JournalError> {
144 let limit = clamp_i64(max);
145 let to_bound = clamp_i64(to);
146 let from_bound = clamp_i64(from);
147 let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
148 "SELECT persistence_id, sequence_nr, payload, manifest, writer_uuid, deleted FROM event_journal \
149 WHERE persistence_id = ? AND sequence_nr >= ? AND sequence_nr <= ? AND deleted = 0 \
150 ORDER BY sequence_nr ASC LIMIT ?",
151 )
152 .bind(persistence_id)
153 .bind(from_bound)
154 .bind(to_bound)
155 .bind(limit)
156 .fetch_all(&self.pool)
157 .await
158 .map_err(JournalError::backend)?;
159 let mut out = Vec::with_capacity(rows.len());
160 for (pid, seq, payload, manifest, writer, deleted) in rows {
161 let tags: Vec<(String,)> =
162 sqlx::query_as("SELECT tag FROM event_tags WHERE persistence_id = ? AND sequence_nr = ?")
163 .bind(&pid)
164 .bind(seq)
165 .fetch_all(&self.pool)
166 .await
167 .map_err(JournalError::backend)?;
168 out.push(PersistentRepr {
169 persistence_id: pid,
170 sequence_nr: seq as u64,
171 payload,
172 manifest,
173 writer_uuid: writer,
174 deleted: deleted != 0,
175 tags: tags.into_iter().map(|(t,)| t).collect(),
176 });
177 }
178 Ok(out)
179 }
180
181 async fn highest_sequence_nr(
182 &self,
183 persistence_id: &str,
184 _from_sequence_nr: u64,
185 ) -> Result<u64, JournalError> {
186 self.current_highest(persistence_id).await
187 }
188
189 async fn events_by_tag(
190 &self,
191 tag: &str,
192 from_offset: u64,
193 max: u64,
194 ) -> Result<Vec<PersistentRepr>, JournalError> {
195 let limit = clamp_i64(max);
196 let rows: Vec<(String, i64, Vec<u8>, String, String, i32)> = sqlx::query_as(
197 "SELECT j.persistence_id, j.sequence_nr, j.payload, j.manifest, j.writer_uuid, j.deleted \
198 FROM event_journal j INNER JOIN event_tags t \
199 ON j.persistence_id = t.persistence_id AND j.sequence_nr = t.sequence_nr \
200 WHERE t.tag = ? AND j.sequence_nr >= ? AND j.deleted = 0 \
201 ORDER BY j.persistence_id, j.sequence_nr ASC LIMIT ?",
202 )
203 .bind(tag)
204 .bind(clamp_i64(from_offset))
205 .bind(limit)
206 .fetch_all(&self.pool)
207 .await
208 .map_err(JournalError::backend)?;
209 Ok(rows
210 .into_iter()
211 .map(|(pid, seq, payload, manifest, writer, deleted)| PersistentRepr {
212 persistence_id: pid,
213 sequence_nr: seq as u64,
214 payload,
215 manifest,
216 writer_uuid: writer,
217 deleted: deleted != 0,
218 tags: vec![tag.to_string()],
219 })
220 .collect())
221 }
222}