ents_sqlite/
lib.rs

1use std::borrow::BorrowMut;
2
3use ents::{
4    DatabaseError, Edge, EdgeDraft, EdgeQuery, EdgeValue, Ent, Id,
5    IncomingEdgeProvider, QueryEdge, ReadEnt, SortOrder, Transactional,
6};
7use ents_admin::AdminEdgeByDest;
8use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
9
10pub struct Txn<'conn>(Transaction<'conn>);
11
12impl<'conn> Txn<'conn> {
13    pub fn new(tx: Transaction<'conn>) -> Self {
14        Self(tx)
15    }
16
17    fn update(
18        &self,
19        id: Id,
20        ent: Box<dyn Ent>,
21        expected_last_updated: Option<u64>,
22    ) -> Result<bool, DatabaseError> {
23        // Serialize the entity to JSON
24        let entity_type = ent.typetag_name().to_string();
25        let data_json =
26            serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
27                source: Box::new(e),
28            })?;
29
30        // Build the UPDATE query with optional CAS check
31        let rows_affected = self
32            .0
33            .execute(
34                r#"
35                UPDATE entities SET data = ?1, type = ?2
36                WHERE
37                    id = ?3 AND
38                    (
39                        JSON_EXTRACT(data, '$.last_updated') = ?4 OR
40                        ?4 IS NULL
41                    )
42                "#,
43                params![
44                    data_json,
45                    entity_type,
46                    id as i64,
47                    expected_last_updated.map(|v| v as i64)
48                ],
49            )
50            .map_err(|e| DatabaseError::Other {
51                source: Box::new(e),
52            })?;
53
54        Ok(rows_affected > 0)
55    }
56}
57
58impl<'conn> Txn<'conn> {
59    fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
60        // Serialize the entity to JSON
61        let entity_type = ent.typetag_name().to_string();
62
63        // Had to cast to &dyn Ent to make sure `type` to be serialized as well.
64        let data_json =
65            serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
66                DatabaseError::Other {
67                    source: Box::new(e),
68                }
69            })?;
70
71        self.0
72            .execute(
73                "INSERT INTO entities (type, data) VALUES (?1, ?2)",
74                params![entity_type, data_json],
75            )
76            .map_err(|e| DatabaseError::Other {
77                source: Box::new(e),
78            })?;
79
80        let inserted_id = self.0.last_insert_rowid() as Id;
81
82        Ok(inserted_id)
83    }
84}
85
86impl<'conn> ReadEnt for Txn<'conn> {
87    fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
88        let mut stmt = self
89            .0
90            .prepare("SELECT id, data FROM entities WHERE id = ?1")
91            .map_err(|e| DatabaseError::Other {
92                source: Box::new(e),
93            })?;
94
95        stmt.query_row(params![id as i64], |row| {
96            let id: Id = row.get::<_, i64>(0)? as Id;
97            let data_json: &str = row.get_ref(1)?.as_str()?;
98            let mut ret = serde_json::from_str::<Box<dyn Ent>>(data_json)
99                .expect("failed to parse JSON");
100            ret.set_id(id);
101            Ok(ret)
102        })
103        .optional()
104        .map_err(|e| DatabaseError::Other {
105            source: Box::new(e),
106        })
107    }
108}
109
110impl<'conn> Transactional for Txn<'conn> {
111    fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
112        let source = edge.source;
113        let sort_key = edge.sort_key;
114        let dest = edge.dest;
115
116        self.0
117            .execute(
118                "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
119                params![source as i64, sort_key, dest as i64],
120            )
121            .map_err(|e| DatabaseError::Other {
122                source: Box::new(e),
123            })?;
124
125        Ok(())
126    }
127
128    fn delete<E: Ent>(&self, id: Id) -> Result<(), DatabaseError> {
129        self.0
130            .prepare_cached(
131                r#"
132        DELETE FROM edges WHERE dest = ?1;
133        "#,
134            )
135            .map_err(|e| DatabaseError::Other {
136                source: Box::new(e),
137            })?
138            .execute(params![id as i64])
139            .map_err(|e| DatabaseError::Other {
140                source: Box::new(e),
141            })?;
142
143        self.0
144            .prepare_cached(
145                r#"
146        DELETE FROM entities WHERE id = ?1;
147        "#,
148            )
149            .map_err(|e| DatabaseError::Other {
150                source: Box::new(e),
151            })?
152            .execute(params![id as i64])
153            .map_err(|e| DatabaseError::Other {
154                source: Box::new(e),
155            })?;
156
157        Ok(())
158    }
159
160    fn update<T: Ent, F: FnOnce(&mut T), B: BorrowMut<T>>(
161        &self,
162        mut ent0: B,
163        mutator: F,
164    ) -> Result<bool, DatabaseError> {
165        let ent = ent0.borrow_mut();
166        let draft0 = T::EdgeProvider::draft(ent);
167        let expected_last_updated = ent.last_updated();
168
169        mutator(ent);
170        ent.mark_updated().map_err(|e| DatabaseError::Other {
171            source: Box::new(e),
172        })?;
173
174        let draft1 = T::EdgeProvider::draft(ent);
175
176        // Optimization: if drafts are equal, no edge changes needed
177        if draft0 == draft1 {
178            return self.update(
179                ent.id(),
180                dyn_clone::clone_box(ent),
181                Some(expected_last_updated),
182            );
183        }
184
185        let edge0 = draft0.check(self).map_err(|e| DatabaseError::Other {
186            source: Box::new(e),
187        })?;
188        let edge1 = draft1.check(self).map_err(|e| DatabaseError::Other {
189            source: Box::new(e),
190        })?;
191
192        let updated = self.update(
193            ent.id(),
194            dyn_clone::clone_box(ent),
195            Some(expected_last_updated),
196        )?;
197
198        if updated {
199            // Remove old edges if they existed
200            for edge in edge0 {
201                self.0
202                    .execute(
203                        "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
204                        params![edge.source as i64, edge.sort_key, edge.dest as i64],
205                    )
206                    .map_err(|e| DatabaseError::Other {
207                        source: Box::new(e),
208                    })?;
209            }
210
211            // Create new edges if they exist
212            for edge in edge1 {
213                self.create_edge(edge)?;
214            }
215        }
216
217        Ok(updated)
218    }
219
220    fn create<E: Ent>(&self, mut ent: E) -> Result<Id, DatabaseError> {
221        let id = self.insert(&ent)?;
222        ent.set_id(id);
223        ent.setup_edges(self).map_err(|e| DatabaseError::Other {
224            source: Box::new(e),
225        })?;
226        Ok(id)
227    }
228
229    fn commit(self) -> Result<(), DatabaseError> {
230        self.0.commit().map_err(|e| DatabaseError::Other {
231            source: Box::new(e),
232        })
233    }
234}
235
236impl<'conn> AdminEdgeByDest for Txn<'conn> {
237    fn find_edges_by_dest(&self, dest: Id) -> Result<Vec<Edge>, DatabaseError> {
238        let mut stmt = self
239            .0
240            .prepare("SELECT source, type, dest FROM edges WHERE dest = ?1 ORDER BY source, type")
241            .map_err(|e| DatabaseError::Other {
242                source: Box::new(e),
243            })?;
244
245        let rows = stmt
246            .query_map(params![dest as i64], |row| {
247                let source: i64 = row.get(0)?;
248                let sort_key: Vec<u8> = match row.get_ref(1)? {
249                    r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
250                        s.to_vec()
251                    }
252                    r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
253                        b.to_vec()
254                    }
255                    _ => {
256                        return Err(
257                            r2d2_sqlite::rusqlite::Error::InvalidColumnType(
258                                1,
259                                "type".into(),
260                                row.get_ref(1)?.data_type(),
261                            ),
262                        )
263                    }
264                };
265                let dest: i64 = row.get(2)?;
266                Ok(Edge::new(source as Id, sort_key, dest as Id))
267            })
268            .map_err(|e| DatabaseError::Other {
269                source: Box::new(e),
270            })?;
271
272        rows.collect::<Result<Vec<_>, _>>()
273            .map_err(|e| DatabaseError::Other {
274                source: Box::new(e),
275            })
276    }
277
278    fn remove_edges_by_dest(&self, dest: Id) -> Result<(), DatabaseError> {
279        self.0
280            .execute("DELETE FROM edges WHERE dest = ?1", params![dest as i64])
281            .map_err(|e| DatabaseError::Other {
282                source: Box::new(e),
283            })?;
284        Ok(())
285    }
286}
287
288impl<'conn> QueryEdge for Txn<'conn> {
289    fn find_edges(
290        &self,
291        source: Id,
292        query: EdgeQuery,
293    ) -> Result<Vec<Edge>, DatabaseError> {
294        // Build WHERE clause for edge names filter
295        let name_filter = if query.edge_names.is_empty() {
296            String::new()
297        } else {
298            let placeholders = query
299                .edge_names
300                .iter()
301                .map(|_| "?")
302                .collect::<Vec<_>>()
303                .join(", ");
304            format!(" AND type IN ({})", placeholders)
305        };
306
307        // Build cursor filter based on sort order
308        let cursor_filter = match (&query.cursor, query.order) {
309            (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
310            (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
311            (None, _) => "",
312        };
313
314        // Build ORDER BY clause
315        let order_clause = match query.order {
316            SortOrder::Asc => "ORDER BY type ASC, dest ASC",
317            SortOrder::Desc => "ORDER BY type DESC, dest DESC",
318        };
319
320        let sql = format!(
321            "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 100",
322            name_filter, cursor_filter, order_clause
323        );
324
325        // Build parameters
326        let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
327        params.push(Box::new(source as i64));
328
329        for name in query.edge_names {
330            params.push(Box::new(name.to_vec()));
331        }
332
333        if let Some(cursor) = query.cursor {
334            params.push(Box::new(cursor.sort_key.to_vec()));
335            params.push(Box::new(cursor.destination as i64));
336        }
337
338        let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
339            params.iter().map(|p| p.as_ref()).collect();
340
341        let mut stmt =
342            self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
343                source: Box::new(e),
344            })?;
345
346        let rows = stmt
347            .query_map(params_refs.as_slice(), |row| {
348                let source: i64 = row.get(0)?;
349                let sort_key: Vec<u8> = match row.get_ref(1)? {
350                    r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
351                        s.to_vec()
352                    }
353                    r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
354                        b.to_vec()
355                    }
356                    _ => {
357                        return Err(
358                            r2d2_sqlite::rusqlite::Error::InvalidColumnType(
359                                1,
360                                "type".into(),
361                                row.get_ref(1)?.data_type(),
362                            ),
363                        )
364                    }
365                };
366                let destination: i64 = row.get(2)?;
367                Ok(Edge::new(source as Id, sort_key, destination as Id))
368            })
369            .map_err(|e| DatabaseError::Other {
370                source: Box::new(e),
371            })?;
372
373        rows.collect::<Result<Vec<_>, _>>()
374            .map_err(|e| DatabaseError::Other {
375                source: Box::new(e),
376            })
377    }
378}