ents_sqlite/
lib.rs

1use std::borrow::BorrowMut;
2
3use ents::Edge;
4use ents::{
5    DatabaseError, EdgeDraft, EdgeProvider, EdgeQuery, EdgeValue, Ent,
6    EntWithEdges, Id, QueryEdge, SortOrder, Transactional,
7};
8use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
9
10pub struct Txn<'conn>(Transaction<'conn>);
11
12impl<'conn> Txn<'conn> {
13    pub fn new(tx: Transaction<'conn>) -> Self {
14        Self(tx)
15    }
16
17    fn update(
18        &self,
19        id: Id,
20        ent: Box<dyn Ent>,
21        expected_last_updated: Option<u64>,
22    ) -> Result<bool, DatabaseError> {
23        // Serialize the entity to JSON
24        let entity_type = ent.typetag_name().to_string();
25        let data_json =
26            serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
27                source: Box::new(e),
28            })?;
29
30        // Build the UPDATE query with optional CAS check
31        let rows_affected = self
32            .0
33            .execute(
34                r#"
35                UPDATE entities SET data = ?1, type = ?2
36                WHERE
37                    id = ?3 AND
38                    (
39                        JSON_EXTRACT(data, '$.last_updated') = ?4 OR
40                        ?4 IS NULL
41                    )
42                "#,
43                params![
44                    data_json,
45                    entity_type,
46                    id as i64,
47                    expected_last_updated.map(|v| v as i64)
48                ],
49            )
50            .map_err(|e| DatabaseError::Other {
51                source: Box::new(e),
52            })?;
53
54        Ok(rows_affected > 0)
55    }
56}
57
58impl<'conn> Txn<'conn> {
59    fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
60        // Serialize the entity to JSON
61        let entity_type = ent.typetag_name().to_string();
62
63        // Had to cast to &dyn Ent to make sure `type` to be serialized as well.
64        let data_json =
65            serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
66                DatabaseError::Other {
67                    source: Box::new(e),
68                }
69            })?;
70
71        self.0
72            .execute(
73                "INSERT INTO entities (type, data) VALUES (?1, ?2)",
74                params![entity_type, data_json],
75            )
76            .map_err(|e| DatabaseError::Other {
77                source: Box::new(e),
78            })?;
79
80        let inserted_id = self.0.last_insert_rowid() as Id;
81
82        Ok(inserted_id)
83    }
84}
85
86impl<'conn> Transactional for Txn<'conn> {
87    fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
88        let mut stmt = self
89            .0
90            .prepare("SELECT id, data FROM entities WHERE id = ?1")
91            .map_err(|e| DatabaseError::Other {
92                source: Box::new(e),
93            })?;
94
95        stmt.query_row(params![id as i64], |row| {
96            let id: Id = row.get::<_, i64>(0)? as Id;
97            let data_json: &str = row.get_ref(1)?.as_str()?;
98            let mut ret = serde_json::from_str::<Box<dyn Ent>>(data_json)
99                .expect("failed to parse JSON");
100            ret.set_id(id);
101            Ok(ret)
102        })
103        .optional()
104        .map_err(|e| DatabaseError::Other {
105            source: Box::new(e),
106        })
107    }
108
109    fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
110        let source = edge.source;
111        let sort_key = edge.sort_key;
112        let dest = edge.dest;
113
114        self.0
115            .execute(
116                "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
117                params![source as i64, sort_key, dest as i64],
118            )
119            .map_err(|e| DatabaseError::Other {
120                source: Box::new(e),
121            })?;
122
123        Ok(())
124    }
125
126    fn delete<E: Ent + EntWithEdges>(
127        &self,
128        id: Id,
129    ) -> Result<(), DatabaseError> {
130        self.0
131            .prepare_cached(
132                r#"
133        DELETE FROM edges WHERE dest = ?1;
134        "#,
135            )
136            .map_err(|e| DatabaseError::Other {
137                source: Box::new(e),
138            })?
139            .execute(params![id as i64])
140            .map_err(|e| DatabaseError::Other {
141                source: Box::new(e),
142            })?;
143
144        self.0
145            .prepare_cached(
146                r#"
147        DELETE FROM entities WHERE id = ?1;
148        "#,
149            )
150            .map_err(|e| DatabaseError::Other {
151                source: Box::new(e),
152            })?
153            .execute(params![id as i64])
154            .map_err(|e| DatabaseError::Other {
155                source: Box::new(e),
156            })?;
157
158        Ok(())
159    }
160
161    fn update<T: EntWithEdges, F: FnOnce(&mut T), B: BorrowMut<T>>(
162        &self,
163        mut ent0: B,
164        mutator: F,
165    ) -> Result<bool, DatabaseError> {
166        let ent = ent0.borrow_mut();
167        let draft0 = T::EdgeProvider::draft(ent);
168        let expected_last_updated = ent.last_updated();
169
170        mutator(ent);
171        ent.mark_updated().map_err(|e| DatabaseError::Other {
172            source: Box::new(e),
173        })?;
174
175        let draft1 = T::EdgeProvider::draft(ent);
176
177        // Optimization: if drafts are equal, no edge changes needed
178        if draft0 == draft1 {
179            return self.update(
180                ent.id(),
181                dyn_clone::clone_box(ent),
182                Some(expected_last_updated),
183            );
184        }
185
186        let edge0 = draft0.check(self).map_err(|e| DatabaseError::Other {
187            source: Box::new(e),
188        })?;
189        let edge1 = draft1.check(self).map_err(|e| DatabaseError::Other {
190            source: Box::new(e),
191        })?;
192
193        let updated = self.update(
194            ent.id(),
195            dyn_clone::clone_box(ent),
196            Some(expected_last_updated),
197        )?;
198
199        if updated {
200            // Remove old edges if they existed
201            for edge in edge0 {
202                self.0
203                    .execute(
204                        "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
205                        params![edge.source as i64, edge.sort_key, edge.dest as i64],
206                    )
207                    .map_err(|e| DatabaseError::Other {
208                        source: Box::new(e),
209                    })?;
210            }
211
212            // Create new edges if they exist
213            for edge in edge1 {
214                self.create_edge(edge)?;
215            }
216        }
217
218        Ok(updated)
219    }
220
221    fn create<E: Ent + EntWithEdges>(
222        &self,
223        mut ent: E,
224    ) -> Result<Id, DatabaseError> {
225        let id = self.insert(&ent)?;
226        ent.set_id(id);
227        ent.setup_edges(self).map_err(|e| DatabaseError::Other {
228            source: Box::new(e),
229        })?;
230        Ok(id)
231    }
232
233    fn commit(self) -> Result<(), DatabaseError> {
234        self.0.commit().map_err(|e| DatabaseError::Other {
235            source: Box::new(e),
236        })
237    }
238}
239
240impl<'conn> QueryEdge for Txn<'conn> {
241    fn find_edges(
242        &self,
243        source: Id,
244        query: EdgeQuery,
245    ) -> Result<Vec<Edge>, DatabaseError> {
246        // Build WHERE clause for edge names filter
247        let name_filter = if query.edge_names.is_empty() {
248            String::new()
249        } else {
250            let placeholders = query
251                .edge_names
252                .iter()
253                .map(|_| "?")
254                .collect::<Vec<_>>()
255                .join(", ");
256            format!(" AND type IN ({})", placeholders)
257        };
258
259        // Build cursor filter based on sort order
260        let cursor_filter = match (&query.cursor, query.order) {
261            (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
262            (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
263            (None, _) => "",
264        };
265
266        // Build ORDER BY clause
267        let order_clause = match query.order {
268            SortOrder::Asc => "ORDER BY type ASC, dest ASC",
269            SortOrder::Desc => "ORDER BY type DESC, dest DESC",
270        };
271
272        let sql = format!(
273            "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 100",
274            name_filter, cursor_filter, order_clause
275        );
276
277        // Build parameters
278        let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
279        params.push(Box::new(source as i64));
280
281        for name in query.edge_names {
282            params.push(Box::new(name.to_vec()));
283        }
284
285        if let Some(cursor) = query.cursor {
286            params.push(Box::new(cursor.sort_key.to_vec()));
287            params.push(Box::new(cursor.destination as i64));
288        }
289
290        let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
291            params.iter().map(|p| p.as_ref()).collect();
292
293        let mut stmt =
294            self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
295                source: Box::new(e),
296            })?;
297
298        let rows = stmt
299            .query_map(params_refs.as_slice(), |row| {
300                let source: i64 = row.get(0)?;
301                let sort_key: Vec<u8> = match row.get_ref(1)? {
302                    r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
303                        s.to_vec()
304                    }
305                    r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
306                        b.to_vec()
307                    }
308                    _ => {
309                        return Err(
310                            r2d2_sqlite::rusqlite::Error::InvalidColumnType(
311                                1,
312                                "type".into(),
313                                row.get_ref(1)?.data_type(),
314                            ),
315                        )
316                    }
317                };
318                let destination: i64 = row.get(2)?;
319                Ok(Edge::new(source as Id, sort_key, destination as Id))
320            })
321            .map_err(|e| DatabaseError::Other {
322                source: Box::new(e),
323            })?;
324
325        rows.collect::<Result<Vec<_>, _>>()
326            .map_err(|e| DatabaseError::Other {
327                source: Box::new(e),
328            })
329    }
330}