// ents_sqlite/lib.rs

1use std::borrow::BorrowMut;
2
3use ents::{
4    DatabaseError, Edge, EdgeDraft, EdgeQuery, EdgeQueryResult, EdgeValue, Ent,
5    Id, IncomingEdgeProvider, QueryEdge, ReadEnt, SortOrder,
6    TransactionProvider, Transactional,
7};
8use ents_admin::AdminEnt;
9use r2d2::Pool;
10use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
11use r2d2_sqlite::SqliteConnectionManager;
12
13pub struct Txn<'conn>(Transaction<'conn>);
14
15impl<'conn> Txn<'conn> {
16    pub fn new(tx: Transaction<'conn>) -> Self {
17        Self(tx)
18    }
19
20    fn update(
21        &self,
22        id: Id,
23        ent: Box<dyn Ent>,
24        expected_last_updated: Option<u64>,
25    ) -> Result<bool, DatabaseError> {
26        // Serialize the entity to JSON
27        let entity_type = ent.typetag_name().to_string();
28        let data_json =
29            serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
30                source: Box::new(e),
31            })?;
32
33        // Build the UPDATE query with optional CAS check
34        let rows_affected = self
35            .0
36            .execute(
37                r#"
38                UPDATE entities SET data = ?1, type = ?2
39                WHERE
40                    id = ?3 AND
41                    (
42                        JSON_EXTRACT(data, '$.last_updated') = ?4 OR
43                        ?4 IS NULL
44                    )
45                "#,
46                params![
47                    data_json,
48                    entity_type,
49                    id as i64,
50                    expected_last_updated.map(|v| v as i64)
51                ],
52            )
53            .map_err(|e| DatabaseError::Other {
54                source: Box::new(e),
55            })?;
56
57        Ok(rows_affected > 0)
58    }
59}
60
61impl<'conn> Txn<'conn> {
62    fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
63        // Serialize the entity to JSON
64        let entity_type = ent.typetag_name().to_string();
65
66        // Had to cast to &dyn Ent to make sure `type` to be serialized as well.
67        let data_json =
68            serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
69                DatabaseError::Other {
70                    source: Box::new(e),
71                }
72            })?;
73
74        self.0
75            .execute(
76                "INSERT INTO entities (type, data) VALUES (?1, ?2)",
77                params![entity_type, data_json],
78            )
79            .map_err(|e| DatabaseError::Other {
80                source: Box::new(e),
81            })?;
82
83        let inserted_id = self.0.last_insert_rowid() as Id;
84
85        Ok(inserted_id)
86    }
87}
88
89impl<'conn> ReadEnt for Txn<'conn> {
90    fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
91        let mut stmt = self
92            .0
93            .prepare("SELECT id, data FROM entities WHERE id = ?1")
94            .map_err(|e| DatabaseError::Other {
95                source: Box::new(e),
96            })?;
97
98        stmt.query_row(params![id as i64], |row| {
99            let id: Id = row.get::<_, i64>(0)? as Id;
100            let data_json: &str = row.get_ref(1)?.as_str()?;
101            let mut ret = serde_json::from_str::<Box<dyn Ent>>(data_json)
102                .expect("failed to parse JSON");
103            ret.set_id(id);
104            Ok(ret)
105        })
106        .optional()
107        .map_err(|e| DatabaseError::Other {
108            source: Box::new(e),
109        })
110    }
111}
112
113impl<'conn> Transactional for Txn<'conn> {
114    fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
115        let source = edge.source;
116        let sort_key = edge.sort_key;
117        let dest = edge.dest;
118
119        self.0
120            .execute(
121                "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
122                params![source as i64, sort_key, dest as i64],
123            )
124            .map_err(|e| DatabaseError::Other {
125                source: Box::new(e),
126            })?;
127
128        Ok(())
129    }
130
131    fn delete(&self, id: Id) -> Result<(), DatabaseError> {
132        self.0
133            .prepare_cached(
134                r#"
135        DELETE FROM edges WHERE dest = ?1;
136        "#,
137            )
138            .map_err(|e| DatabaseError::Other {
139                source: Box::new(e),
140            })?
141            .execute(params![id as i64])
142            .map_err(|e| DatabaseError::Other {
143                source: Box::new(e),
144            })?;
145
146        self.0
147            .prepare_cached(
148                r#"
149        DELETE FROM entities WHERE id = ?1;
150        "#,
151            )
152            .map_err(|e| DatabaseError::Other {
153                source: Box::new(e),
154            })?
155            .execute(params![id as i64])
156            .map_err(|e| DatabaseError::Other {
157                source: Box::new(e),
158            })?;
159
160        Ok(())
161    }
162
163    fn update<T: Ent, F: FnOnce(&mut T), B: BorrowMut<T>>(
164        &self,
165        mut ent0: B,
166        mutator: F,
167    ) -> Result<bool, DatabaseError> {
168        let ent = ent0.borrow_mut();
169        let draft0 = T::EdgeProvider::draft(ent);
170        let ent_id = ent.id();
171        let expected_last_updated = ent.last_updated();
172
173        mutator(ent);
174        ent.mark_updated().map_err(|e| DatabaseError::Other {
175            source: Box::new(e),
176        })?;
177
178        let draft1 = T::EdgeProvider::draft(ent);
179
180        // Optimization: if drafts are equal, no edge changes needed
181        if draft0 == draft1 {
182            return self.update(
183                ent.id(),
184                dyn_clone::clone_box(ent as &dyn Ent),
185                Some(expected_last_updated),
186            );
187        }
188
189        let edge0 = draft0
190            .check(self)
191            .map(|edges| {
192                edges
193                    .into_iter()
194                    .map(|edge| edge.with_dest(ent_id))
195                    .collect::<Vec<_>>()
196            })
197            .map_err(|e| DatabaseError::Other {
198                source: Box::new(e),
199            })?;
200        let edge1 = draft1
201            .check(self)
202            .map(|edges| {
203                edges
204                    .into_iter()
205                    .map(|edge| edge.with_dest(ent_id))
206                    .collect::<Vec<_>>()
207            })
208            .map_err(|e| DatabaseError::Other {
209                source: Box::new(e),
210            })?;
211
212        let updated = self.update(
213            ent.id(),
214            dyn_clone::clone_box(ent as &dyn Ent),
215            Some(expected_last_updated),
216        )?;
217
218        if updated {
219            // Remove old edges if they existed
220            for edge in edge0 {
221                self.0
222                    .execute(
223                        "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
224                        params![edge.source as i64, edge.sort_key, edge.dest as i64],
225                    )
226                    .map_err(|e| DatabaseError::Other {
227                        source: Box::new(e),
228                    })?;
229            }
230
231            // Create new edges if they exist
232            for edge in edge1 {
233                self.create_edge(edge)?;
234            }
235        }
236
237        Ok(updated)
238    }
239
240    fn create<E: Ent>(&self, mut ent: E) -> Result<Id, DatabaseError> {
241        let id = self.insert(&ent)?;
242        ent.set_id(id);
243        ent.setup_edges(self).map_err(|e| DatabaseError::Other {
244            source: Box::new(e),
245        })?;
246        Ok(id)
247    }
248
249    fn commit(self) -> Result<(), DatabaseError> {
250        self.0.commit().map_err(|e| DatabaseError::Other {
251            source: Box::new(e),
252        })
253    }
254}
255
256impl<'conn> AdminEnt for Txn<'conn> {
257    fn create_dyn(&self, ent: Box<dyn Ent>) -> Result<Id, DatabaseError> {
258        // Serialize the entity to JSON
259        let entity_type = ent.typetag_name().to_string();
260        let data_json =
261            serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
262                source: Box::new(e),
263            })?;
264
265        self.0
266            .execute(
267                "INSERT INTO entities (type, data) VALUES (?1, ?2)",
268                params![entity_type, data_json],
269            )
270            .map_err(|e| DatabaseError::Other {
271                source: Box::new(e),
272            })?;
273
274        let id = self.0.last_insert_rowid() as Id;
275        Ok(id)
276    }
277
278    fn update_dyn(&self, ent: Box<dyn Ent>) -> Result<(), DatabaseError> {
279        let id = ent.id();
280        let updated = self.update(id, ent, None)?;
281        if !updated {
282            return Err(DatabaseError::Other {
283                source: "Entity not found or update failed".into(),
284            });
285        }
286        Ok(())
287    }
288
289    fn find_edges_by_dest(&self, dest: Id) -> Result<Vec<Edge>, DatabaseError> {
290        let mut stmt = self
291            .0
292            .prepare("SELECT source, type, dest FROM edges WHERE dest = ?1 ORDER BY source, type")
293            .map_err(|e| DatabaseError::Other {
294                source: Box::new(e),
295            })?;
296
297        let rows = stmt
298            .query_map(params![dest as i64], |row| {
299                let source: i64 = row.get(0)?;
300                let sort_key: Vec<u8> = match row.get_ref(1)? {
301                    r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
302                        s.to_vec()
303                    }
304                    r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
305                        b.to_vec()
306                    }
307                    _ => {
308                        return Err(
309                            r2d2_sqlite::rusqlite::Error::InvalidColumnType(
310                                1,
311                                "type".into(),
312                                row.get_ref(1)?.data_type(),
313                            ),
314                        )
315                    }
316                };
317                let dest: i64 = row.get(2)?;
318                Ok(Edge::new(source as Id, sort_key, dest as Id))
319            })
320            .map_err(|e| DatabaseError::Other {
321                source: Box::new(e),
322            })?;
323
324        rows.collect::<Result<Vec<_>, _>>()
325            .map_err(|e| DatabaseError::Other {
326                source: Box::new(e),
327            })
328    }
329
330    fn remove_edges_by_dest(&self, dest: Id) -> Result<(), DatabaseError> {
331        self.0
332            .execute("DELETE FROM edges WHERE dest = ?1", params![dest as i64])
333            .map_err(|e| DatabaseError::Other {
334                source: Box::new(e),
335            })?;
336        Ok(())
337    }
338
339    fn list_entities(
340        &self,
341        entity_type: &str,
342        cursor: Option<Id>,
343        limit: usize,
344    ) -> Result<Vec<Box<dyn Ent>>, DatabaseError> {
345        let sql = "SELECT id, data FROM entities WHERE type = ?1 AND id > ?2 ORDER BY id ASC LIMIT ?3";
346        let cursor_val = cursor.unwrap_or(0);
347
348        let mut stmt =
349            self.0.prepare(sql).map_err(|e| DatabaseError::Other {
350                source: Box::new(e),
351            })?;
352
353        let rows = stmt
354            .query_map(
355                params![entity_type, cursor_val as i64, limit as i64],
356                |row| {
357                    let id: Id = row.get::<_, i64>(0)? as Id;
358                    let data_json: &str = row.get_ref(1)?.as_str()?;
359                    let mut ret =
360                        serde_json::from_str::<Box<dyn Ent>>(data_json)
361                            .expect("failed to parse JSON");
362                    ret.set_id(id);
363                    Ok(ret)
364                },
365            )
366            .map_err(|e| DatabaseError::Other {
367                source: Box::new(e),
368            })?;
369
370        rows.collect::<Result<Vec<_>, _>>()
371            .map_err(|e| DatabaseError::Other {
372                source: Box::new(e),
373            })
374    }
375}
376
377impl<'conn> QueryEdge for Txn<'conn> {
378    fn find_edges(
379        &self,
380        source: Id,
381        query: EdgeQuery,
382    ) -> Result<EdgeQueryResult, DatabaseError> {
383        // Build WHERE clause for edge names filter
384        let name_filter = if query.edge_names.is_empty() {
385            String::new()
386        } else {
387            let placeholders = query
388                .edge_names
389                .iter()
390                .map(|_| "?")
391                .collect::<Vec<_>>()
392                .join(", ");
393            format!(" AND type IN ({})", placeholders)
394        };
395
396        // Build cursor filter based on sort order
397        let cursor_filter = match (&query.cursor, query.order) {
398            (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
399            (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
400            (None, _) => "",
401        };
402
403        // Build ORDER BY clause
404        let order_clause = match query.order {
405            SortOrder::Asc => "ORDER BY type ASC, dest ASC",
406            SortOrder::Desc => "ORDER BY type DESC, dest DESC",
407        };
408
409        // Request one extra row to detect if there are more results
410        let sql = format!(
411            "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 101",
412            name_filter, cursor_filter, order_clause
413        );
414
415        // Build parameters
416        let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
417        params.push(Box::new(source as i64));
418
419        for name in query.edge_names {
420            params.push(Box::new(name.to_vec()));
421        }
422
423        if let Some(cursor) = query.cursor {
424            params.push(Box::new(cursor.sort_key.to_vec()));
425            params.push(Box::new(cursor.destination as i64));
426        }
427
428        let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
429            params.iter().map(|p| p.as_ref()).collect();
430
431        let mut stmt =
432            self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
433                source: Box::new(e),
434            })?;
435
436        let rows = stmt
437            .query_map(params_refs.as_slice(), |row| {
438                let source: i64 = row.get(0)?;
439                let sort_key: Vec<u8> = match row.get_ref(1)? {
440                    r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
441                        s.to_vec()
442                    }
443                    r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
444                        b.to_vec()
445                    }
446                    _ => {
447                        return Err(
448                            r2d2_sqlite::rusqlite::Error::InvalidColumnType(
449                                1,
450                                "type".into(),
451                                row.get_ref(1)?.data_type(),
452                            ),
453                        )
454                    }
455                };
456                let destination: i64 = row.get(2)?;
457                Ok(Edge::new(source as Id, sort_key, destination as Id))
458            })
459            .map_err(|e| DatabaseError::Other {
460                source: Box::new(e),
461            })?;
462
463        let mut edges: Vec<Edge> = rows
464            .collect::<Result<Vec<_>, _>>()
465            .map_err(|e| DatabaseError::Other {
466                source: Box::new(e),
467            })?;
468
469        let has_more = edges.len() > 100;
470        if has_more {
471            edges.truncate(100);
472        }
473
474        Ok(EdgeQueryResult { edges, has_more })
475    }
476}
477
478#[derive(Clone)]
479pub struct SqliteDb {
480    pool: Pool<SqliteConnectionManager>,
481}
482
483impl SqliteDb {
484    pub fn new(pool: Pool<SqliteConnectionManager>) -> Self {
485        Self { pool }
486    }
487}
488
489impl From<Pool<SqliteConnectionManager>> for SqliteDb {
490    fn from(pool: Pool<SqliteConnectionManager>) -> Self {
491        Self::new(pool)
492    }
493}
494
495impl TransactionProvider for SqliteDb {
496    type Tx<'a> = Txn<'a>;
497
498    fn execute<R, F>(&self, func: F) -> Result<R, DatabaseError>
499    where
500        F: for<'b> FnOnce(Self::Tx<'b>) -> R,
501    {
502        let mut conn = self.pool.get().map_err(|e| DatabaseError::Other {
503            source: Box::new(e),
504        })?;
505
506        let ret = func(Txn::new(conn.transaction().map_err(|e| {
507            DatabaseError::Other {
508                source: Box::new(e),
509            }
510        })?));
511
512        Ok(ret)
513    }
514}