ents_sqlite/
lib.rs

1use std::borrow::BorrowMut;
2
3use ents::{
4    DatabaseError, Edge, EdgeDraft, EdgeQuery, EdgeQueryResult, EdgeValue, Ent,
5    Id, IncomingEdgeProvider, QueryEdge, ReadEnt, SortOrder, Transactional,
6    TransactionProvider,
7};
8use ents_admin::AdminEnt;
9use r2d2::Pool;
10use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
11use r2d2_sqlite::SqliteConnectionManager;
12
13pub struct Txn<'conn>(Transaction<'conn>);
14
15impl<'conn> Txn<'conn> {
16    pub fn new(tx: Transaction<'conn>) -> Self {
17        Self(tx)
18    }
19
20    fn update(
21        &self,
22        id: Id,
23        ent: Box<dyn Ent>,
24        expected_last_updated: Option<u64>,
25    ) -> Result<bool, DatabaseError> {
26        // Serialize the entity to JSON
27        let entity_type = ent.typetag_name().to_string();
28        let data_json =
29            serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
30                source: Box::new(e),
31            })?;
32
33        // Build the UPDATE query with optional CAS check
34        let rows_affected = self
35            .0
36            .execute(
37                r#"
38                UPDATE entities SET data = ?1, type = ?2
39                WHERE
40                    id = ?3 AND
41                    (
42                        JSON_EXTRACT(data, '$.last_updated') = ?4 OR
43                        ?4 IS NULL
44                    )
45                "#,
46                params![
47                    data_json,
48                    entity_type,
49                    id as i64,
50                    expected_last_updated.map(|v| v as i64)
51                ],
52            )
53            .map_err(|e| DatabaseError::Other {
54                source: Box::new(e),
55            })?;
56
57        Ok(rows_affected > 0)
58    }
59}
60
61impl<'conn> Txn<'conn> {
62    fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
63        // Serialize the entity to JSON
64        let entity_type = ent.typetag_name().to_string();
65
66        // Had to cast to &dyn Ent to make sure `type` to be serialized as well.
67        let data_json =
68            serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
69                DatabaseError::Other {
70                    source: Box::new(e),
71                }
72            })?;
73
74        self.0
75            .execute(
76                "INSERT INTO entities (type, data) VALUES (?1, ?2)",
77                params![entity_type, data_json],
78            )
79            .map_err(|e| DatabaseError::Other {
80                source: Box::new(e),
81            })?;
82
83        let inserted_id = self.0.last_insert_rowid() as Id;
84
85        Ok(inserted_id)
86    }
87}
88
89impl<'conn> ReadEnt for Txn<'conn> {
90    fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
91        let mut stmt = self
92            .0
93            .prepare("SELECT id, data FROM entities WHERE id = ?1")
94            .map_err(|e| DatabaseError::Other {
95                source: Box::new(e),
96            })?;
97
98        stmt.query_row(params![id as i64], |row| {
99            let id: Id = row.get::<_, i64>(0)? as Id;
100            let data_json: &str = row.get_ref(1)?.as_str()?;
101            let mut ret = serde_json::from_str::<Box<dyn Ent>>(data_json)
102                .expect("failed to parse JSON");
103            ret.set_id(id);
104            Ok(ret)
105        })
106        .optional()
107        .map_err(|e| DatabaseError::Other {
108            source: Box::new(e),
109        })
110    }
111}
112
113impl<'conn> Transactional for Txn<'conn> {
114    fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
115        let source = edge.source;
116        let sort_key = edge.sort_key;
117        let dest = edge.dest;
118
119        self.0
120            .execute(
121                "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
122                params![source as i64, sort_key, dest as i64],
123            )
124            .map_err(|e| DatabaseError::Other {
125                source: Box::new(e),
126            })?;
127
128        Ok(())
129    }
130
131    fn delete(&self, id: Id) -> Result<(), DatabaseError> {
132        self.0
133            .prepare_cached(
134                r#"
135        DELETE FROM edges WHERE dest = ?1;
136        "#,
137            )
138            .map_err(|e| DatabaseError::Other {
139                source: Box::new(e),
140            })?
141            .execute(params![id as i64])
142            .map_err(|e| DatabaseError::Other {
143                source: Box::new(e),
144            })?;
145
146        self.0
147            .prepare_cached(
148                r#"
149        DELETE FROM entities WHERE id = ?1;
150        "#,
151            )
152            .map_err(|e| DatabaseError::Other {
153                source: Box::new(e),
154            })?
155            .execute(params![id as i64])
156            .map_err(|e| DatabaseError::Other {
157                source: Box::new(e),
158            })?;
159
160        Ok(())
161    }
162
163    fn update<T: Ent, F: FnOnce(&mut T), B: BorrowMut<T>>(
164        &self,
165        mut ent0: B,
166        mutator: F,
167    ) -> Result<bool, DatabaseError> {
168        let ent = ent0.borrow_mut();
169        let draft0 = T::EdgeProvider::draft(ent);
170        let expected_last_updated = ent.last_updated();
171
172        mutator(ent);
173        ent.mark_updated().map_err(|e| DatabaseError::Other {
174            source: Box::new(e),
175        })?;
176
177        let draft1 = T::EdgeProvider::draft(ent);
178
179        // Optimization: if drafts are equal, no edge changes needed
180        if draft0 == draft1 {
181            return self.update(
182                ent.id(),
183                dyn_clone::clone_box(ent as &dyn Ent),
184                Some(expected_last_updated),
185            );
186        }
187
188        let edge0 = draft0.check(self).map_err(|e| DatabaseError::Other {
189            source: Box::new(e),
190        })?;
191        let edge1 = draft1.check(self).map_err(|e| DatabaseError::Other {
192            source: Box::new(e),
193        })?;
194
195        let updated = self.update(
196            ent.id(),
197            dyn_clone::clone_box(ent as &dyn Ent),
198            Some(expected_last_updated),
199        )?;
200
201        if updated {
202            // Remove old edges if they existed
203            for edge in edge0 {
204                self.0
205                    .execute(
206                        "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
207                        params![edge.source as i64, edge.sort_key, edge.dest as i64],
208                    )
209                    .map_err(|e| DatabaseError::Other {
210                        source: Box::new(e),
211                    })?;
212            }
213
214            // Create new edges if they exist
215            for edge in edge1 {
216                self.create_edge(edge)?;
217            }
218        }
219
220        Ok(updated)
221    }
222
223    fn create<E: Ent>(&self, mut ent: E) -> Result<Id, DatabaseError> {
224        let id = self.insert(&ent)?;
225        ent.set_id(id);
226        ent.setup_edges(self).map_err(|e| DatabaseError::Other {
227            source: Box::new(e),
228        })?;
229        Ok(id)
230    }
231
232    fn commit(self) -> Result<(), DatabaseError> {
233        self.0.commit().map_err(|e| DatabaseError::Other {
234            source: Box::new(e),
235        })
236    }
237}
238
239impl<'conn> AdminEnt for Txn<'conn> {
240    fn create_dyn(&self, ent: Box<dyn Ent>) -> Result<Id, DatabaseError> {
241        // Serialize the entity to JSON
242        let entity_type = ent.typetag_name().to_string();
243        let data_json =
244            serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
245                source: Box::new(e),
246            })?;
247
248        self.0
249            .execute(
250                "INSERT INTO entities (type, data) VALUES (?1, ?2)",
251                params![entity_type, data_json],
252            )
253            .map_err(|e| DatabaseError::Other {
254                source: Box::new(e),
255            })?;
256
257        let id = self.0.last_insert_rowid() as Id;
258        Ok(id)
259    }
260
261    fn update_dyn(&self, ent: Box<dyn Ent>) -> Result<(), DatabaseError> {
262        let id = ent.id();
263        let updated = self.update(id, ent, None)?;
264        if !updated {
265            return Err(DatabaseError::Other {
266                source: "Entity not found or update failed".into(),
267            });
268        }
269        Ok(())
270    }
271
272    fn find_edges_by_dest(&self, dest: Id) -> Result<Vec<Edge>, DatabaseError> {
273        let mut stmt = self
274            .0
275            .prepare("SELECT source, type, dest FROM edges WHERE dest = ?1 ORDER BY source, type")
276            .map_err(|e| DatabaseError::Other {
277                source: Box::new(e),
278            })?;
279
280        let rows = stmt
281            .query_map(params![dest as i64], |row| {
282                let source: i64 = row.get(0)?;
283                let sort_key: Vec<u8> = match row.get_ref(1)? {
284                    r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
285                        s.to_vec()
286                    }
287                    r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
288                        b.to_vec()
289                    }
290                    _ => {
291                        return Err(
292                            r2d2_sqlite::rusqlite::Error::InvalidColumnType(
293                                1,
294                                "type".into(),
295                                row.get_ref(1)?.data_type(),
296                            ),
297                        )
298                    }
299                };
300                let dest: i64 = row.get(2)?;
301                Ok(Edge::new(source as Id, sort_key, dest as Id))
302            })
303            .map_err(|e| DatabaseError::Other {
304                source: Box::new(e),
305            })?;
306
307        rows.collect::<Result<Vec<_>, _>>()
308            .map_err(|e| DatabaseError::Other {
309                source: Box::new(e),
310            })
311    }
312
313    fn remove_edges_by_dest(&self, dest: Id) -> Result<(), DatabaseError> {
314        self.0
315            .execute("DELETE FROM edges WHERE dest = ?1", params![dest as i64])
316            .map_err(|e| DatabaseError::Other {
317                source: Box::new(e),
318            })?;
319        Ok(())
320    }
321
322    fn list_entities(
323        &self,
324        entity_type: &str,
325        cursor: Option<Id>,
326        limit: usize,
327    ) -> Result<Vec<Box<dyn Ent>>, DatabaseError> {
328        let sql = "SELECT id, data FROM entities WHERE type = ?1 AND id > ?2 ORDER BY id ASC LIMIT ?3";
329        let cursor_val = cursor.unwrap_or(0);
330
331        let mut stmt =
332            self.0.prepare(sql).map_err(|e| DatabaseError::Other {
333                source: Box::new(e),
334            })?;
335
336        let rows = stmt
337            .query_map(
338                params![entity_type, cursor_val as i64, limit as i64],
339                |row| {
340                    let id: Id = row.get::<_, i64>(0)? as Id;
341                    let data_json: &str = row.get_ref(1)?.as_str()?;
342                    let mut ret =
343                        serde_json::from_str::<Box<dyn Ent>>(data_json)
344                            .expect("failed to parse JSON");
345                    ret.set_id(id);
346                    Ok(ret)
347                },
348            )
349            .map_err(|e| DatabaseError::Other {
350                source: Box::new(e),
351            })?;
352
353        rows.collect::<Result<Vec<_>, _>>()
354            .map_err(|e| DatabaseError::Other {
355                source: Box::new(e),
356            })
357    }
358}
359
360impl<'conn> QueryEdge for Txn<'conn> {
361    fn find_edges(
362        &self,
363        source: Id,
364        query: EdgeQuery,
365    ) -> Result<EdgeQueryResult, DatabaseError> {
366        // Build WHERE clause for edge names filter
367        let name_filter = if query.edge_names.is_empty() {
368            String::new()
369        } else {
370            let placeholders = query
371                .edge_names
372                .iter()
373                .map(|_| "?")
374                .collect::<Vec<_>>()
375                .join(", ");
376            format!(" AND type IN ({})", placeholders)
377        };
378
379        // Build cursor filter based on sort order
380        let cursor_filter = match (&query.cursor, query.order) {
381            (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
382            (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
383            (None, _) => "",
384        };
385
386        // Build ORDER BY clause
387        let order_clause = match query.order {
388            SortOrder::Asc => "ORDER BY type ASC, dest ASC",
389            SortOrder::Desc => "ORDER BY type DESC, dest DESC",
390        };
391
392        // Request one extra row to detect if there are more results
393        let sql = format!(
394            "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 101",
395            name_filter, cursor_filter, order_clause
396        );
397
398        // Build parameters
399        let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
400        params.push(Box::new(source as i64));
401
402        for name in query.edge_names {
403            params.push(Box::new(name.to_vec()));
404        }
405
406        if let Some(cursor) = query.cursor {
407            params.push(Box::new(cursor.sort_key.to_vec()));
408            params.push(Box::new(cursor.destination as i64));
409        }
410
411        let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
412            params.iter().map(|p| p.as_ref()).collect();
413
414        let mut stmt =
415            self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
416                source: Box::new(e),
417            })?;
418
419        let rows = stmt
420            .query_map(params_refs.as_slice(), |row| {
421                let source: i64 = row.get(0)?;
422                let sort_key: Vec<u8> = match row.get_ref(1)? {
423                    r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
424                        s.to_vec()
425                    }
426                    r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
427                        b.to_vec()
428                    }
429                    _ => {
430                        return Err(
431                            r2d2_sqlite::rusqlite::Error::InvalidColumnType(
432                                1,
433                                "type".into(),
434                                row.get_ref(1)?.data_type(),
435                            ),
436                        )
437                    }
438                };
439                let destination: i64 = row.get(2)?;
440                Ok(Edge::new(source as Id, sort_key, destination as Id))
441            })
442            .map_err(|e| DatabaseError::Other {
443                source: Box::new(e),
444            })?;
445
446        let mut edges: Vec<Edge> = rows
447            .collect::<Result<Vec<_>, _>>()
448            .map_err(|e| DatabaseError::Other {
449                source: Box::new(e),
450            })?;
451
452        let has_more = edges.len() > 100;
453        if has_more {
454            edges.truncate(100);
455        }
456
457        Ok(EdgeQueryResult { edges, has_more })
458    }
459}
460
461#[derive(Clone)]
462pub struct SqliteDb {
463    pool: Pool<SqliteConnectionManager>,
464}
465
466impl SqliteDb {
467    pub fn new(pool: Pool<SqliteConnectionManager>) -> Self {
468        Self { pool }
469    }
470}
471
472impl From<Pool<SqliteConnectionManager>> for SqliteDb {
473    fn from(pool: Pool<SqliteConnectionManager>) -> Self {
474        Self::new(pool)
475    }
476}
477
478impl TransactionProvider for SqliteDb {
479    type Tx<'a> = Txn<'a>;
480
481    fn execute<R, F>(&self, func: F) -> Result<R, DatabaseError>
482    where
483        F: for<'b> FnOnce(Self::Tx<'b>) -> R,
484    {
485        let mut conn = self.pool.get().map_err(|e| DatabaseError::Other {
486            source: Box::new(e),
487        })?;
488
489        let ret = func(Txn::new(conn.transaction().map_err(|e| {
490            DatabaseError::Other {
491                source: Box::new(e),
492            }
493        })?));
494
495        Ok(ret)
496    }
497}