1use std::fmt::{Debug, Display};
2
3use r2d2::Pool;
4use r2d2_sqlite::SqliteConnectionManager;
5use rusqlite::{params, params_from_iter, Statement, Transaction};
6use tracing::{debug, instrument, warn};
7use uuid::Uuid;
8
9use crate::backend::model::Event;
10
11#[derive(Clone)]
12pub struct SqliteBackend {
13 pool: Pool<SqliteConnectionManager>,
14}
15
16#[derive(Debug)]
17pub struct GetAggOpts {
18 pub agg_id: Uuid,
19 pub since_version: u32,
20}
21
22pub enum Error {
23 WithMsg(String),
24 InvalidUUID,
25 NotFound,
26 Sqlite(rusqlite::Error),
27 R2D2Sqlite(r2d2::Error),
28}
29
30impl Display for Error {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Error::InvalidUUID => f.write_fmt(format_args!("could not parse uuid")),
34 Error::Sqlite(err) => f.write_fmt(format_args!("sqlite: {}", err)),
35 Error::R2D2Sqlite(err) => f.write_fmt(format_args!("r2d2_sqlite: {}", err)),
36 Error::WithMsg(msg) => f.write_fmt(format_args!("plain error: {}", msg)),
37 Error::NotFound => f.write_fmt(format_args!("not found")),
38 }
39 }
40}
41
42impl Debug for Error {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 match self {
45 Error::InvalidUUID => f.write_fmt(format_args!("could not parse uuid")),
46 Error::Sqlite(err) => f.write_fmt(format_args!("sqlite: {}", err)),
47 Error::R2D2Sqlite(err) => f.write_fmt(format_args!("r2d2_sqlite: {}", err)),
48 Error::WithMsg(msg) => f.write_fmt(format_args!("plain error: {}", msg)),
49 Error::NotFound => f.write_fmt(format_args!("not found")),
50 }
51 }
52}
53
54impl From<rusqlite::Error> for Error {
55 fn from(value: rusqlite::Error) -> Self {
56 Error::Sqlite(value)
57 }
58}
59
60impl From<r2d2::Error> for Error {
61 fn from(value: r2d2::Error) -> Self {
62 Error::R2D2Sqlite(value)
63 }
64}
65
// Schema DDL. Every statement uses `IF NOT EXISTS`, so initialising the schema against a
// database file created by an earlier run is a no-op instead of a "table already exists"
// error (which `SqliteBackend::new` would turn into a panic).

/// One row per aggregate (primary key `aggregate_id`) holding its latest event version.
static CREATE_AGGREGATE_OVERVIEW_TABLE_STMT: &str = "CREATE TABLE IF NOT EXISTS aggregate_index(
    aggregate_id TEXT PRIMARY KEY,
    type_name TEXT,
    version INTEGER
    )";

/// Append-only event log. Rows are read back by column position (id, data, version).
static CREATE_AGGREGATE_TABLE_STMT: &str = "CREATE TABLE IF NOT EXISTS eventstore(
    aggregate_id TEXT,
    data BLOB,
    version INTEGER
    )";

/// One row per aggregate (primary key `aggregate_id`) holding its latest snapshot version.
static CREATE_SNAPSHOT_OVERVIEW_TABLE_STMT: &str = "CREATE TABLE IF NOT EXISTS snapshot_index(
    aggregate_id TEXT PRIMARY KEY,
    type_name TEXT,
    version INTEGER
    )";

/// Stored snapshots, same column layout as `eventstore`.
static CREATE_SNAPSHOT_TABLE_STMT: &str = "CREATE TABLE IF NOT EXISTS snapshot(
    aggregate_id TEXT,
    data BLOB,
    version INTEGER
    )";
89
90impl Debug for SqliteBackend {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 f.debug_struct("SqliteBackend")
93 .field("pool", &self.pool.state())
94 .finish()
95 }
96}
97
98impl SqliteBackend {
99 pub fn new(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
100 let pool = r2d2::Pool::new(manager).unwrap(); let backend = Self { pool };
103 backend.init_tables().unwrap();
104 backend.init_indices().unwrap();
105 return backend;
106 }
107
108 #[instrument]
109 fn init_tables(&self) -> Result<(), Error> {
110 let _span = tracing::debug_span!("creating tables").entered();
111 for qry in vec![
112 CREATE_AGGREGATE_TABLE_STMT,
113 CREATE_AGGREGATE_OVERVIEW_TABLE_STMT,
114 CREATE_SNAPSHOT_TABLE_STMT,
115 CREATE_SNAPSHOT_OVERVIEW_TABLE_STMT,
116 ] {
117 self.pool.get()?.execute(qry, params![])?;
118 }
119 return Ok(());
120 }
121
122 #[instrument]
123 fn init_indices(&self) -> Result<(), Error> {
124 self.pool.get()?.execute(
125 "CREATE INDEX IF NOT EXISTS eventstore_agg_id_idx ON eventstore (aggregate_id)",
126 params![],
127 )?;
128 self.pool.get()?.execute(
129 "CREATE INDEX IF NOT EXISTS snapshot_agg_id_idx ON snapshot (aggregate_id)",
130 params![],
131 )?;
132 self.pool.get()?.execute(
133 "CREATE UNIQUE INDEX IF NOT EXISTS snapshot_unique_idx ON snapshot (aggregate_id, version)",
134 params![],
135 )?;
136 return Ok(());
137 }
138
139 #[instrument]
140 pub fn get_agg_max_version(&self, tx: &Transaction, agg_id_str: &str) -> Result<u32, Error> {
141 let mut stmt = tx
142 .prepare("SELECT COALESCE(MAX(version), 0) as max_version FROM aggregate_index WHERE aggregate_id = ?")?;
143 let version = stmt.query_row(params![agg_id_str], |row| match row.get(0) {
144 Ok(val) => Ok(val),
145 Err(err) => {
146 warn!(sqlite_error = err.to_string());
147 Err(err)
148 }
149 })?;
150 debug!(current_event_version = version);
151 Ok(version)
152 }
153
154 #[instrument]
161 pub fn save_snapshot(&self, event: &Event) -> Result<(), Error> {
162 let mut conn = self.pool.get()?;
163 let tx = conn.transaction()?;
164 tx.execute(
165 "INSERT INTO snapshot(aggregate_id, version, data) VALUES(?,?,?)
166 ON CONFLICT(aggregate_id, version) DO UPDATE SET version = excluded.version, data = excluded.data",
167 params![&event.id.to_string(), event.version, event.data],
168 )?;
169 let res = tx.execute(
170 "INSERT INTO snapshot_index(version, aggregate_id, type_name) VALUES(?,?, 'todo_implement_type_name')
171 ON CONFLICT(aggregate_id) DO UPDATE SET version = ?",
172 params![event.version, &event.id.to_string(), event.version],
173 );
174 match res {
175 Ok(_) => match tx.commit() {
176 Ok(_) => Ok(()),
177 Err(err) => {
178 warn!(sqlite_error = err.to_string());
179 Err(Error::Sqlite(err))
180 }
181 },
182 Err(err) => {
183 warn!(sqlite_error = err.to_string());
184 Err(Error::Sqlite(err))
185 }
186 }
187 }
188
189 #[instrument]
190 pub fn append_event(&self, event: &Event) -> Result<(), Error> {
191 let mut conn = self.pool.get()?;
192 let tx = match conn.transaction() {
193 Ok(tx) => tx,
194 Err(err) => {
195 warn!(sqlite_error = err.to_string());
196 return Err(Error::Sqlite(err));
197 }
198 };
199 let version = match self.get_agg_max_version(&tx, &event.id.to_string()) {
200 Ok(version) => version,
201 Err(err) => {
202 return Err(err);
203 }
204 };
205 let expected_version = version + 1;
206 if event.version != expected_version {
207 warn!("version mismtach {} != {}", event.version, expected_version);
208 return Err(Error::WithMsg("version mismtach".to_string()));
209 }
210 let res = tx.execute(
211 "INSERT INTO eventstore(aggregate_id, version, data) VALUES(?,?,?)",
212 params![&event.id.to_string(), event.version, event.data],
213 );
214 if let Err(err) = res {
215 warn!(sqlite_error = err.to_string());
216 return Err(Error::Sqlite(err));
217 }
218 let res = tx.execute(
219 "INSERT INTO aggregate_index(version, aggregate_id, type_name) VALUES(?,?, 'todo_implement_type_name')
220 ON CONFLICT(aggregate_id) DO UPDATE SET version = ?",
221 params![event.version, &event.id.to_string(), event.version],
222 );
223 match res {
224 Ok(_) => match tx.commit() {
225 Ok(_) => Ok(()),
226 Err(err) => {
227 warn!(sqlite_error = err.to_string());
228 Err(Error::Sqlite(err))
229 }
230 },
231 Err(err) => {
232 warn!(sqlite_error = err.to_string());
233 Err(Error::Sqlite(err))
234 }
235 }
236 }
237
238 #[instrument]
239 fn result_from_stmt(stmt: &mut Statement, agg_id_str: &str) -> Result<Vec<Event>, Error> {
240 let params = vec![agg_id_str];
241 Self::result_from_stmt_with_params(stmt, ¶ms)
242 }
243
244 fn result_from_stmt_with_params(
245 stmt: &mut Statement,
246 params: &Vec<&str>,
247 ) -> Result<Vec<Event>, Error> {
248 let mut events: Vec<_> = Vec::new();
249 let query_res = stmt.query_and_then(params_from_iter(params), |r| {
250 let id = if let Ok(tmp) = r.get::<_, String>(0) {
251 match uuid::Uuid::parse_str(tmp.as_str()) {
252 Ok(id) => id,
253 Err(_) => return Err(Error::InvalidUUID),
254 }
255 } else {
256 return Err(Error::WithMsg("could not read uuid from row".to_string()));
257 };
258 Ok(Event {
259 id,
260 data: r.get(1)?,
261 version: r.get(2)?,
262 })
263 });
264 match query_res {
265 Ok(iter) => {
266 iter.filter_map(|e| match e {
267 Ok(val) => Some(val),
268 Err(_) => None,
269 })
270 .fold(&mut events, |acc, e| {
271 acc.push(e);
272 acc
273 });
274 Ok(events)
275 }
276 Err(err) => {
277 warn!(sqlite_error = err.to_string());
278 Err(Error::Sqlite(err))
279 }
280 }
281 }
282
283 #[instrument]
284 pub fn get_aggretate(&self, aggregate_id: Uuid) -> Result<Vec<Event>, Error> {
285 let agg_id_str: String = aggregate_id.to_string();
286 let conn = self.pool.get()?;
287 let mut stmt =
288 conn.prepare("SELECT * FROM eventstore WHERE aggregate_id = ? ORDER BY version ASC")?;
289 SqliteBackend::result_from_stmt(&mut stmt, &agg_id_str)
290 }
291
292 #[instrument]
293 pub fn get_snapshots(&self, aggregate_id: Uuid) -> Result<Vec<Event>, Error> {
294 let agg_id_str: String = aggregate_id.to_string();
295 let conn = self.pool.get()?;
296 let mut stmt =
297 conn.prepare("SELECT * FROM snapshot WHERE aggregate_id = ? ORDER BY version ASC")?;
298 SqliteBackend::result_from_stmt(&mut stmt, &agg_id_str)
299 }
300
301 #[instrument]
302 pub fn get_snapshot_by_version(
303 &self,
304 aggregate_id: Uuid,
305 version: u32,
306 ) -> Result<Event, Error> {
307 let agg_id_str: String = aggregate_id.to_string();
308 let conn = self.pool.get()?;
309 let mut stmt = conn.prepare(
310 "SELECT * FROM snapshot WHERE aggregate_id = ? AND version = ? ORDER BY version ASC",
311 )?;
312 SqliteBackend::result_from_stmt_with_params(
313 &mut stmt,
314 &vec![&agg_id_str, &version.to_string()],
315 )?
316 .pop()
317 .ok_or(Error::NotFound)
318 }
319
320 #[instrument]
321 pub fn get_aggretate_with_opts(
322 &self,
323 aggregate_id: Uuid,
324 opts: &GetAggOpts,
325 ) -> Result<Vec<Event>, Error> {
326 let agg_id_str: String = aggregate_id.to_string();
327 let conn = self.pool.get()?;
328 let mut stmt = conn.prepare(
329 "SELECT * FROM eventstore WHERE aggregate_id = ? AND version > ? ORDER BY version ASC",
330 )?;
331 SqliteBackend::result_from_stmt_with_params(
332 &mut stmt,
333 &vec![&agg_id_str, &opts.since_version.to_string()],
334 )
335 }
336}