// eventstore/backend/sqlite.rs

1use std::fmt::{Debug, Display};
2
3use r2d2::Pool;
4use r2d2_sqlite::SqliteConnectionManager;
5use rusqlite::{params, params_from_iter, Statement, Transaction};
6use tracing::{debug, instrument, warn};
7use uuid::Uuid;
8
9use crate::backend::model::Event;
10
11#[derive(Clone)]
12pub struct SqliteBackend {
13    pool: Pool<SqliteConnectionManager>,
14}
15
16#[derive(Debug)]
17pub struct GetAggOpts {
18    pub agg_id: Uuid,
19    pub since_version: u32,
20}
21
22pub enum Error {
23    WithMsg(String),
24    InvalidUUID,
25    NotFound,
26    Sqlite(rusqlite::Error),
27    R2D2Sqlite(r2d2::Error),
28}
29
30impl Display for Error {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            Error::InvalidUUID => f.write_fmt(format_args!("could not parse uuid")),
34            Error::Sqlite(err) => f.write_fmt(format_args!("sqlite: {}", err)),
35            Error::R2D2Sqlite(err) => f.write_fmt(format_args!("r2d2_sqlite: {}", err)),
36            Error::WithMsg(msg) => f.write_fmt(format_args!("plain error: {}", msg)),
37            Error::NotFound => f.write_fmt(format_args!("not found")),
38        }
39    }
40}
41
42impl Debug for Error {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        match self {
45            Error::InvalidUUID => f.write_fmt(format_args!("could not parse uuid")),
46            Error::Sqlite(err) => f.write_fmt(format_args!("sqlite: {}", err)),
47            Error::R2D2Sqlite(err) => f.write_fmt(format_args!("r2d2_sqlite: {}", err)),
48            Error::WithMsg(msg) => f.write_fmt(format_args!("plain error: {}", msg)),
49            Error::NotFound => f.write_fmt(format_args!("not found")),
50        }
51    }
52}
53
54impl From<rusqlite::Error> for Error {
55    fn from(value: rusqlite::Error) -> Self {
56        Error::Sqlite(value)
57    }
58}
59
60impl From<r2d2::Error> for Error {
61    fn from(value: r2d2::Error) -> Self {
62        Error::R2D2Sqlite(value)
63    }
64}
65
// Schema DDL executed by `SqliteBackend::init_tables` every time a backend is constructed.
// `IF NOT EXISTS` makes re-opening an already initialised database a no-op; without it,
// SQLite reports "table ... already exists" and `SqliteBackend::new` panics.

/// One row per aggregate holding the latest appended event version.
const CREATE_AGGREGATE_OVERVIEW_TABLE_STMT: &str = "CREATE TABLE IF NOT EXISTS aggregate_index(
                aggregate_id TEXT PRIMARY KEY,
                type_name TEXT,
                version INTEGER
            )";

/// Append-only event log: one row per (aggregate, version).
const CREATE_AGGREGATE_TABLE_STMT: &str = "CREATE TABLE IF NOT EXISTS eventstore(
                aggregate_id TEXT,
                data BLOB,
                version INTEGER
            )";

/// One row per aggregate holding the version of its most recently saved snapshot.
const CREATE_SNAPSHOT_OVERVIEW_TABLE_STMT: &str = "CREATE TABLE IF NOT EXISTS snapshot_index(
                aggregate_id TEXT PRIMARY KEY,
                type_name TEXT,
                version INTEGER
            )";

/// Stored snapshots, keyed by (aggregate, version) via a unique index created separately.
const CREATE_SNAPSHOT_TABLE_STMT: &str = "CREATE TABLE IF NOT EXISTS snapshot(
                aggregate_id TEXT,
                data BLOB,
                version INTEGER
            )";
89
90impl Debug for SqliteBackend {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        f.debug_struct("SqliteBackend")
93            .field("pool", &self.pool.state())
94            .finish()
95    }
96}
97
98impl SqliteBackend {
99    pub fn new(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
100        let pool = r2d2::Pool::new(manager).unwrap(); // TODO(juf): this should also be the
101                                                      // responsibility of the caller in the future to make this lib even thinner.
102        let backend = Self { pool };
103        backend.init_tables().unwrap();
104        backend.init_indices().unwrap();
105        return backend;
106    }
107
108    #[instrument]
109    fn init_tables(&self) -> Result<(), Error> {
110        let _span = tracing::debug_span!("creating tables").entered();
111        for qry in vec![
112            CREATE_AGGREGATE_TABLE_STMT,
113            CREATE_AGGREGATE_OVERVIEW_TABLE_STMT,
114            CREATE_SNAPSHOT_TABLE_STMT,
115            CREATE_SNAPSHOT_OVERVIEW_TABLE_STMT,
116        ] {
117            self.pool.get()?.execute(qry, params![])?;
118        }
119        return Ok(());
120    }
121
122    #[instrument]
123    fn init_indices(&self) -> Result<(), Error> {
124        self.pool.get()?.execute(
125            "CREATE INDEX IF NOT EXISTS eventstore_agg_id_idx ON eventstore (aggregate_id)",
126            params![],
127        )?;
128        self.pool.get()?.execute(
129            "CREATE INDEX IF NOT EXISTS snapshot_agg_id_idx ON snapshot (aggregate_id)",
130            params![],
131        )?;
132        self.pool.get()?.execute(
133            "CREATE UNIQUE INDEX IF NOT EXISTS snapshot_unique_idx ON snapshot (aggregate_id, version)",
134            params![],
135        )?;
136        return Ok(());
137    }
138
139    #[instrument]
140    pub fn get_agg_max_version(&self, tx: &Transaction, agg_id_str: &str) -> Result<u32, Error> {
141        let mut stmt = tx
142            .prepare("SELECT COALESCE(MAX(version), 0) as max_version FROM aggregate_index WHERE aggregate_id = ?")?;
143        let version = stmt.query_row(params![agg_id_str], |row| match row.get(0) {
144            Ok(val) => Ok(val),
145            Err(err) => {
146                warn!(sqlite_error = err.to_string());
147                Err(err)
148            }
149        })?;
150        debug!(current_event_version = version);
151        Ok(version)
152    }
153
154    /// Save an snapshot to the eventstore.
155    /// Will overwrite existing snapshots.
156    ///
157    /// # Errors
158    ///
159    /// This function will return an error if .
160    #[instrument]
161    pub fn save_snapshot(&self, event: &Event) -> Result<(), Error> {
162        let mut conn = self.pool.get()?;
163        let tx = conn.transaction()?;
164        tx.execute(
165            "INSERT INTO snapshot(aggregate_id, version, data) VALUES(?,?,?)
166                ON CONFLICT(aggregate_id, version) DO UPDATE SET version = excluded.version, data = excluded.data",
167            params![&event.id.to_string(), event.version, event.data],
168        )?;
169        let res = tx.execute(
170            "INSERT INTO snapshot_index(version, aggregate_id, type_name) VALUES(?,?, 'todo_implement_type_name')
171                ON CONFLICT(aggregate_id) DO UPDATE SET version = ?",
172            params![event.version, &event.id.to_string(), event.version],
173        );
174        match res {
175            Ok(_) => match tx.commit() {
176                Ok(_) => Ok(()),
177                Err(err) => {
178                    warn!(sqlite_error = err.to_string());
179                    Err(Error::Sqlite(err))
180                }
181            },
182            Err(err) => {
183                warn!(sqlite_error = err.to_string());
184                Err(Error::Sqlite(err))
185            }
186        }
187    }
188
189    #[instrument]
190    pub fn append_event(&self, event: &Event) -> Result<(), Error> {
191        let mut conn = self.pool.get()?;
192        let tx = match conn.transaction() {
193            Ok(tx) => tx,
194            Err(err) => {
195                warn!(sqlite_error = err.to_string());
196                return Err(Error::Sqlite(err));
197            }
198        };
199        let version = match self.get_agg_max_version(&tx, &event.id.to_string()) {
200            Ok(version) => version,
201            Err(err) => {
202                return Err(err);
203            }
204        };
205        let expected_version = version + 1;
206        if event.version != expected_version {
207            warn!("version mismtach {} != {}", event.version, expected_version);
208            return Err(Error::WithMsg("version mismtach".to_string()));
209        }
210        let res = tx.execute(
211            "INSERT INTO eventstore(aggregate_id, version, data) VALUES(?,?,?)",
212            params![&event.id.to_string(), event.version, event.data],
213        );
214        if let Err(err) = res {
215            warn!(sqlite_error = err.to_string());
216            return Err(Error::Sqlite(err));
217        }
218        let res = tx.execute(
219            "INSERT INTO aggregate_index(version, aggregate_id, type_name) VALUES(?,?, 'todo_implement_type_name')
220                ON CONFLICT(aggregate_id) DO UPDATE SET version = ?",
221            params![event.version, &event.id.to_string(), event.version],
222        );
223        match res {
224            Ok(_) => match tx.commit() {
225                Ok(_) => Ok(()),
226                Err(err) => {
227                    warn!(sqlite_error = err.to_string());
228                    Err(Error::Sqlite(err))
229                }
230            },
231            Err(err) => {
232                warn!(sqlite_error = err.to_string());
233                Err(Error::Sqlite(err))
234            }
235        }
236    }
237
238    #[instrument]
239    fn result_from_stmt(stmt: &mut Statement, agg_id_str: &str) -> Result<Vec<Event>, Error> {
240        let params = vec![agg_id_str];
241        Self::result_from_stmt_with_params(stmt, &params)
242    }
243
244    fn result_from_stmt_with_params(
245        stmt: &mut Statement,
246        params: &Vec<&str>,
247    ) -> Result<Vec<Event>, Error> {
248        let mut events: Vec<_> = Vec::new();
249        let query_res = stmt.query_and_then(params_from_iter(params), |r| {
250            let id = if let Ok(tmp) = r.get::<_, String>(0) {
251                match uuid::Uuid::parse_str(tmp.as_str()) {
252                    Ok(id) => id,
253                    Err(_) => return Err(Error::InvalidUUID),
254                }
255            } else {
256                return Err(Error::WithMsg("could not read uuid from row".to_string()));
257            };
258            Ok(Event {
259                id,
260                data: r.get(1)?,
261                version: r.get(2)?,
262            })
263        });
264        match query_res {
265            Ok(iter) => {
266                iter.filter_map(|e| match e {
267                    Ok(val) => Some(val),
268                    Err(_) => None,
269                })
270                .fold(&mut events, |acc, e| {
271                    acc.push(e);
272                    acc
273                });
274                Ok(events)
275            }
276            Err(err) => {
277                warn!(sqlite_error = err.to_string());
278                Err(Error::Sqlite(err))
279            }
280        }
281    }
282
283    #[instrument]
284    pub fn get_aggretate(&self, aggregate_id: Uuid) -> Result<Vec<Event>, Error> {
285        let agg_id_str: String = aggregate_id.to_string();
286        let conn = self.pool.get()?;
287        let mut stmt =
288            conn.prepare("SELECT * FROM eventstore WHERE aggregate_id = ? ORDER BY version ASC")?;
289        SqliteBackend::result_from_stmt(&mut stmt, &agg_id_str)
290    }
291
292    #[instrument]
293    pub fn get_snapshots(&self, aggregate_id: Uuid) -> Result<Vec<Event>, Error> {
294        let agg_id_str: String = aggregate_id.to_string();
295        let conn = self.pool.get()?;
296        let mut stmt =
297            conn.prepare("SELECT * FROM snapshot WHERE aggregate_id = ? ORDER BY version ASC")?;
298        SqliteBackend::result_from_stmt(&mut stmt, &agg_id_str)
299    }
300
301    #[instrument]
302    pub fn get_snapshot_by_version(
303        &self,
304        aggregate_id: Uuid,
305        version: u32,
306    ) -> Result<Event, Error> {
307        let agg_id_str: String = aggregate_id.to_string();
308        let conn = self.pool.get()?;
309        let mut stmt = conn.prepare(
310            "SELECT * FROM snapshot WHERE aggregate_id = ? AND version = ? ORDER BY version ASC",
311        )?;
312        SqliteBackend::result_from_stmt_with_params(
313            &mut stmt,
314            &vec![&agg_id_str, &version.to_string()],
315        )?
316        .pop()
317        .ok_or(Error::NotFound)
318    }
319
320    #[instrument]
321    pub fn get_aggretate_with_opts(
322        &self,
323        aggregate_id: Uuid,
324        opts: &GetAggOpts,
325    ) -> Result<Vec<Event>, Error> {
326        let agg_id_str: String = aggregate_id.to_string();
327        let conn = self.pool.get()?;
328        let mut stmt = conn.prepare(
329            "SELECT * FROM eventstore WHERE aggregate_id = ? AND version > ? ORDER BY version ASC",
330        )?;
331        SqliteBackend::result_from_stmt_with_params(
332            &mut stmt,
333            &vec![&agg_id_str, &opts.since_version.to_string()],
334        )
335    }
336}