// ents_sqlite/lib.rs

1use std::borrow::BorrowMut;
2
3use ents::{
4    DatabaseError, Edge, EdgeDraft, EdgeQuery, EdgeQueryResult, EdgeValue, Ent,
5    Id, IncomingEdgeProvider, QueryEdge, ReadEnt, SortOrder, Transactional,
6    TransactionProvider,
7};
8use ents_admin::AdminEnt;
9use r2d2::Pool;
10use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
11use r2d2_sqlite::SqliteConnectionManager;
12
13pub struct Txn<'conn>(Transaction<'conn>);
14
15impl<'conn> Txn<'conn> {
16    pub fn new(tx: Transaction<'conn>) -> Self {
17        Self(tx)
18    }
19
20    fn update(
21        &self,
22        id: Id,
23        ent: Box<dyn Ent>,
24        expected_last_updated: Option<u64>,
25    ) -> Result<bool, DatabaseError> {
26        // Serialize the entity to JSON
27        let entity_type = ent.typetag_name().to_string();
28        let data_json =
29            serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
30                source: Box::new(e),
31            })?;
32
33        // Build the UPDATE query with optional CAS check
34        let rows_affected = self
35            .0
36            .execute(
37                r#"
38                UPDATE entities SET data = ?1, type = ?2
39                WHERE
40                    id = ?3 AND
41                    (
42                        JSON_EXTRACT(data, '$.last_updated') = ?4 OR
43                        ?4 IS NULL
44                    )
45                "#,
46                params![
47                    data_json,
48                    entity_type,
49                    id as i64,
50                    expected_last_updated.map(|v| v as i64)
51                ],
52            )
53            .map_err(|e| DatabaseError::Other {
54                source: Box::new(e),
55            })?;
56
57        Ok(rows_affected > 0)
58    }
59}
60
61impl<'conn> Txn<'conn> {
62    fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
63        // Serialize the entity to JSON
64        let entity_type = ent.typetag_name().to_string();
65
66        // Had to cast to &dyn Ent to make sure `type` to be serialized as well.
67        let data_json =
68            serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
69                DatabaseError::Other {
70                    source: Box::new(e),
71                }
72            })?;
73
74        self.0
75            .execute(
76                "INSERT INTO entities (type, data) VALUES (?1, ?2)",
77                params![entity_type, data_json],
78            )
79            .map_err(|e| DatabaseError::Other {
80                source: Box::new(e),
81            })?;
82
83        let inserted_id = self.0.last_insert_rowid() as Id;
84
85        Ok(inserted_id)
86    }
87}
88
89impl<'conn> ReadEnt for Txn<'conn> {
90    fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
91        let mut stmt = self
92            .0
93            .prepare("SELECT id, data FROM entities WHERE id = ?1")
94            .map_err(|e| DatabaseError::Other {
95                source: Box::new(e),
96            })?;
97
98        stmt.query_row(params![id as i64], |row| {
99            let id: Id = row.get::<_, i64>(0)? as Id;
100            let data_json: &str = row.get_ref(1)?.as_str()?;
101            let mut ret = serde_json::from_str::<Box<dyn Ent>>(data_json)
102                .expect("failed to parse JSON");
103            ret.set_id(id);
104            Ok(ret)
105        })
106        .optional()
107        .map_err(|e| DatabaseError::Other {
108            source: Box::new(e),
109        })
110    }
111}
112
113impl<'conn> Transactional for Txn<'conn> {
114    fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
115        let source = edge.source;
116        let sort_key = edge.sort_key;
117        let dest = edge.dest;
118
119        self.0
120            .execute(
121                "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
122                params![source as i64, sort_key, dest as i64],
123            )
124            .map_err(|e| DatabaseError::Other {
125                source: Box::new(e),
126            })?;
127
128        Ok(())
129    }
130
131    fn delete(&self, id: Id) -> Result<(), DatabaseError> {
132        self.0
133            .prepare_cached(
134                r#"
135        DELETE FROM edges WHERE dest = ?1;
136        "#,
137            )
138            .map_err(|e| DatabaseError::Other {
139                source: Box::new(e),
140            })?
141            .execute(params![id as i64])
142            .map_err(|e| DatabaseError::Other {
143                source: Box::new(e),
144            })?;
145
146        self.0
147            .prepare_cached(
148                r#"
149        DELETE FROM entities WHERE id = ?1;
150        "#,
151            )
152            .map_err(|e| DatabaseError::Other {
153                source: Box::new(e),
154            })?
155            .execute(params![id as i64])
156            .map_err(|e| DatabaseError::Other {
157                source: Box::new(e),
158            })?;
159
160        Ok(())
161    }
162
163    fn update<T: Ent, F: FnOnce(&mut T), B: BorrowMut<T>>(
164        &self,
165        mut ent0: B,
166        mutator: F,
167    ) -> Result<bool, DatabaseError> {
168        let ent = ent0.borrow_mut();
169        let draft0 = T::EdgeProvider::draft(ent);
170        let expected_last_updated = ent.last_updated();
171
172        mutator(ent);
173        ent.mark_updated().map_err(|e| DatabaseError::Other {
174            source: Box::new(e),
175        })?;
176
177        let draft1 = T::EdgeProvider::draft(ent);
178
179        // Optimization: if drafts are equal, no edge changes needed
180        if draft0 == draft1 {
181            return self.update(
182                ent.id(),
183                dyn_clone::clone_box(ent as &dyn Ent),
184                Some(expected_last_updated),
185            );
186        }
187
188        let edge0 = draft0.check(self).map_err(|e| DatabaseError::Other {
189            source: Box::new(e),
190        })?;
191        let edge1 = draft1.check(self).map_err(|e| DatabaseError::Other {
192            source: Box::new(e),
193        })?;
194
195        let updated = self.update(
196            ent.id(),
197            dyn_clone::clone_box(ent as &dyn Ent),
198            Some(expected_last_updated),
199        )?;
200
201        if updated {
202            // Remove old edges if they existed
203            for edge in edge0 {
204                self.0
205                    .execute(
206                        "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
207                        params![edge.source as i64, edge.sort_key, edge.dest as i64],
208                    )
209                    .map_err(|e| DatabaseError::Other {
210                        source: Box::new(e),
211                    })?;
212            }
213
214            // Create new edges if they exist
215            for edge in edge1 {
216                self.create_edge(edge)?;
217            }
218        }
219
220        Ok(updated)
221    }
222
223    fn create<E: Ent>(&self, mut ent: E) -> Result<Id, DatabaseError> {
224        let id = self.insert(&ent)?;
225        ent.set_id(id);
226        ent.setup_edges(self).map_err(|e| DatabaseError::Other {
227            source: Box::new(e),
228        })?;
229        Ok(id)
230    }
231
232    fn commit(self) -> Result<(), DatabaseError> {
233        self.0.commit().map_err(|e| DatabaseError::Other {
234            source: Box::new(e),
235        })
236    }
237}
238
239impl<'conn> AdminEnt for Txn<'conn> {
240    fn find_edges_by_dest(&self, dest: Id) -> Result<Vec<Edge>, DatabaseError> {
241        let mut stmt = self
242            .0
243            .prepare("SELECT source, type, dest FROM edges WHERE dest = ?1 ORDER BY source, type")
244            .map_err(|e| DatabaseError::Other {
245                source: Box::new(e),
246            })?;
247
248        let rows = stmt
249            .query_map(params![dest as i64], |row| {
250                let source: i64 = row.get(0)?;
251                let sort_key: Vec<u8> = match row.get_ref(1)? {
252                    r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
253                        s.to_vec()
254                    }
255                    r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
256                        b.to_vec()
257                    }
258                    _ => {
259                        return Err(
260                            r2d2_sqlite::rusqlite::Error::InvalidColumnType(
261                                1,
262                                "type".into(),
263                                row.get_ref(1)?.data_type(),
264                            ),
265                        )
266                    }
267                };
268                let dest: i64 = row.get(2)?;
269                Ok(Edge::new(source as Id, sort_key, dest as Id))
270            })
271            .map_err(|e| DatabaseError::Other {
272                source: Box::new(e),
273            })?;
274
275        rows.collect::<Result<Vec<_>, _>>()
276            .map_err(|e| DatabaseError::Other {
277                source: Box::new(e),
278            })
279    }
280
281    fn remove_edges_by_dest(&self, dest: Id) -> Result<(), DatabaseError> {
282        self.0
283            .execute("DELETE FROM edges WHERE dest = ?1", params![dest as i64])
284            .map_err(|e| DatabaseError::Other {
285                source: Box::new(e),
286            })?;
287        Ok(())
288    }
289
290    fn list_entities(
291        &self,
292        entity_type: &str,
293        cursor: Option<Id>,
294        limit: usize,
295    ) -> Result<Vec<Box<dyn Ent>>, DatabaseError> {
296        let sql = "SELECT id, data FROM entities WHERE type = ?1 AND id > ?2 ORDER BY id ASC LIMIT ?3";
297        let cursor_val = cursor.unwrap_or(0);
298
299        let mut stmt =
300            self.0.prepare(sql).map_err(|e| DatabaseError::Other {
301                source: Box::new(e),
302            })?;
303
304        let rows = stmt
305            .query_map(
306                params![entity_type, cursor_val as i64, limit as i64],
307                |row| {
308                    let id: Id = row.get::<_, i64>(0)? as Id;
309                    let data_json: &str = row.get_ref(1)?.as_str()?;
310                    let mut ret =
311                        serde_json::from_str::<Box<dyn Ent>>(data_json)
312                            .expect("failed to parse JSON");
313                    ret.set_id(id);
314                    Ok(ret)
315                },
316            )
317            .map_err(|e| DatabaseError::Other {
318                source: Box::new(e),
319            })?;
320
321        rows.collect::<Result<Vec<_>, _>>()
322            .map_err(|e| DatabaseError::Other {
323                source: Box::new(e),
324            })
325    }
326}
327
328impl<'conn> QueryEdge for Txn<'conn> {
329    fn find_edges(
330        &self,
331        source: Id,
332        query: EdgeQuery,
333    ) -> Result<EdgeQueryResult, DatabaseError> {
334        // Build WHERE clause for edge names filter
335        let name_filter = if query.edge_names.is_empty() {
336            String::new()
337        } else {
338            let placeholders = query
339                .edge_names
340                .iter()
341                .map(|_| "?")
342                .collect::<Vec<_>>()
343                .join(", ");
344            format!(" AND type IN ({})", placeholders)
345        };
346
347        // Build cursor filter based on sort order
348        let cursor_filter = match (&query.cursor, query.order) {
349            (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
350            (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
351            (None, _) => "",
352        };
353
354        // Build ORDER BY clause
355        let order_clause = match query.order {
356            SortOrder::Asc => "ORDER BY type ASC, dest ASC",
357            SortOrder::Desc => "ORDER BY type DESC, dest DESC",
358        };
359
360        // Request one extra row to detect if there are more results
361        let sql = format!(
362            "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 101",
363            name_filter, cursor_filter, order_clause
364        );
365
366        // Build parameters
367        let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
368        params.push(Box::new(source as i64));
369
370        for name in query.edge_names {
371            params.push(Box::new(name.to_vec()));
372        }
373
374        if let Some(cursor) = query.cursor {
375            params.push(Box::new(cursor.sort_key.to_vec()));
376            params.push(Box::new(cursor.destination as i64));
377        }
378
379        let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
380            params.iter().map(|p| p.as_ref()).collect();
381
382        let mut stmt =
383            self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
384                source: Box::new(e),
385            })?;
386
387        let rows = stmt
388            .query_map(params_refs.as_slice(), |row| {
389                let source: i64 = row.get(0)?;
390                let sort_key: Vec<u8> = match row.get_ref(1)? {
391                    r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
392                        s.to_vec()
393                    }
394                    r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
395                        b.to_vec()
396                    }
397                    _ => {
398                        return Err(
399                            r2d2_sqlite::rusqlite::Error::InvalidColumnType(
400                                1,
401                                "type".into(),
402                                row.get_ref(1)?.data_type(),
403                            ),
404                        )
405                    }
406                };
407                let destination: i64 = row.get(2)?;
408                Ok(Edge::new(source as Id, sort_key, destination as Id))
409            })
410            .map_err(|e| DatabaseError::Other {
411                source: Box::new(e),
412            })?;
413
414        let mut edges: Vec<Edge> = rows
415            .collect::<Result<Vec<_>, _>>()
416            .map_err(|e| DatabaseError::Other {
417                source: Box::new(e),
418            })?;
419
420        let has_more = edges.len() > 100;
421        if has_more {
422            edges.truncate(100);
423        }
424
425        Ok(EdgeQueryResult { edges, has_more })
426    }
427}
428
429#[derive(Clone)]
430pub struct SqliteDb {
431    pool: Pool<SqliteConnectionManager>,
432}
433
434impl SqliteDb {
435    pub fn new(pool: Pool<SqliteConnectionManager>) -> Self {
436        Self { pool }
437    }
438}
439
440impl From<Pool<SqliteConnectionManager>> for SqliteDb {
441    fn from(pool: Pool<SqliteConnectionManager>) -> Self {
442        Self::new(pool)
443    }
444}
445
446impl TransactionProvider for SqliteDb {
447    type Tx<'a> = Txn<'a>;
448
449    fn execute<R, F>(&self, func: F) -> Result<R, DatabaseError>
450    where
451        F: for<'b> FnOnce(Self::Tx<'b>) -> R,
452    {
453        let mut conn = self.pool.get().map_err(|e| DatabaseError::Other {
454            source: Box::new(e),
455        })?;
456
457        let ret = func(Txn::new(conn.transaction().map_err(|e| {
458            DatabaseError::Other {
459                source: Box::new(e),
460            }
461        })?));
462
463        Ok(ret)
464    }
465}