ents_sqlite/
lib.rs

1use std::borrow::BorrowMut;
2
3use ents::Edge;
4use ents::{
5    DatabaseError, EdgeDraft, EdgeProvider, EdgeQuery, EdgeValue, Ent, EntWithEdges, Id, QueryEdge,
6    SortOrder, Transactional,
7};
8use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
9
10pub struct Txn<'conn>(Transaction<'conn>);
11
12impl<'conn> Txn<'conn> {
13    pub fn new(tx: Transaction<'conn>) -> Self {
14        Self(tx)
15    }
16
17    fn update(
18        &self,
19        id: Id,
20        ent: Box<dyn Ent>,
21        expected_last_updated: Option<u64>,
22    ) -> Result<bool, DatabaseError> {
23        // Serialize the entity to JSON
24        let entity_type = ent.typetag_name().to_string();
25        let data_json = serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
26            source: Box::new(e),
27        })?;
28
29        // Build the UPDATE query with optional CAS check
30        let rows_affected = self
31            .0
32            .execute(
33                r#"
34                UPDATE entities SET data = ?1, type = ?2
35                WHERE
36                    id = ?3 AND
37                    (
38                        JSON_EXTRACT(data, '$.last_updated') = ?4 OR
39                        ?4 IS NULL
40                    )
41                "#,
42                params![data_json, entity_type, id as i64, expected_last_updated],
43            )
44            .map_err(|e| DatabaseError::Other {
45                source: Box::new(e),
46            })?;
47
48        dbg!(&rows_affected, &id);
49
50        Ok(rows_affected > 0)
51    }
52}
53
54impl<'conn> Txn<'conn> {
55    fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
56        // Serialize the entity to JSON
57        let entity_type = ent.typetag_name().to_string();
58
59        // Had to cast to &dyn Ent to make sure `type` to be serialized as well.
60        let data_json =
61            serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| DatabaseError::Other {
62                source: Box::new(e),
63            })?;
64
65        self.0
66            .execute(
67                "INSERT INTO entities (type, data) VALUES (?1, ?2)",
68                params![entity_type, data_json],
69            )
70            .map_err(|e| DatabaseError::Other {
71                source: Box::new(e),
72            })?;
73
74        let inserted_id = self.0.last_insert_rowid() as Id;
75
76        Ok(inserted_id)
77    }
78}
79
80impl<'conn> Transactional for Txn<'conn> {
81    fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
82        let mut stmt = self
83            .0
84            .prepare("SELECT id, data FROM entities WHERE id = ?1")
85            .map_err(|e| DatabaseError::Other {
86                source: Box::new(e),
87            })?;
88
89        stmt.query_row(params![id as i64], |row| {
90            let id: Id = row.get(0)?;
91            let data_json: &str = row.get_ref(1)?.as_str()?;
92            let mut ret =
93                serde_json::from_str::<Box<dyn Ent>>(data_json).expect("failed to parse JSON");
94            ret.set_id(id);
95            Ok(ret)
96        })
97        .optional()
98        .map_err(|e| DatabaseError::Other {
99            source: Box::new(e),
100        })
101    }
102
103    fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
104        let source = edge.source;
105        let sort_key = edge.sort_key;
106        let dest = edge.dest;
107
108        self.0
109            .execute(
110                "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
111                params![source as i64, sort_key, dest as i64],
112            )
113            .map_err(|e| DatabaseError::Other {
114                source: Box::new(e),
115            })?;
116
117        Ok(())
118    }
119
120    fn delete<E: Ent + EntWithEdges>(&self, id: Id) -> Result<(), DatabaseError> {
121        self.0
122            .prepare_cached(
123                r#"
124        DELETE FROM edges WHERE dest = ?1;
125        "#,
126            )
127            .map_err(|e| DatabaseError::Other {
128                source: Box::new(e),
129            })?
130            .execute(params![id])
131            .map_err(|e| DatabaseError::Other {
132                source: Box::new(e),
133            })?;
134
135        self.0
136            .prepare_cached(
137                r#"
138        DELETE FROM entities WHERE id = ?1;
139        "#,
140            )
141            .map_err(|e| DatabaseError::Other {
142                source: Box::new(e),
143            })?
144            .execute(params![id])
145            .map_err(|e| DatabaseError::Other {
146                source: Box::new(e),
147            })?;
148
149        Ok(())
150    }
151
152    fn update<T: EntWithEdges, F: FnOnce(&mut T), B: BorrowMut<T>>(
153        &self,
154        mut ent0: B,
155        mutator: F,
156    ) -> Result<bool, DatabaseError> {
157        let ent = ent0.borrow_mut();
158        let draft0 = T::EdgeProvider::draft(ent);
159        let expected_last_updated = ent.last_updated();
160
161        mutator(ent);
162        ent.mark_updated().map_err(|e| DatabaseError::Other {
163            source: Box::new(e),
164        })?;
165
166        let draft1 = T::EdgeProvider::draft(ent);
167
168        // Optimization: if drafts are equal, no edge changes needed
169        if draft0 == draft1 {
170            return self.update(
171                ent.id(),
172                dyn_clone::clone_box(ent),
173                Some(expected_last_updated),
174            );
175        }
176
177        let edge0 = draft0.check(self).map_err(|e| DatabaseError::Other {
178            source: Box::new(e),
179        })?;
180        let edge1 = draft1.check(self).map_err(|e| DatabaseError::Other {
181            source: Box::new(e),
182        })?;
183
184        let updated = self.update(
185            ent.id(),
186            dyn_clone::clone_box(ent),
187            Some(expected_last_updated),
188        )?;
189
190        if updated {
191            // Remove old edges if they existed
192            for edge in edge0 {
193                self.0
194                    .execute(
195                        "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
196                        params![edge.source as i64, edge.sort_key, edge.dest as i64],
197                    )
198                    .map_err(|e| DatabaseError::Other {
199                        source: Box::new(e),
200                    })?;
201            }
202
203            // Create new edges if they exist
204            for edge in edge1 {
205                self.create_edge(edge)?;
206            }
207        }
208
209        Ok(updated)
210    }
211
212    fn create<E: Ent + EntWithEdges>(&self, mut ent: E) -> Result<Id, DatabaseError> {
213        let id = self.insert(&ent)?;
214        ent.set_id(id);
215        ent.setup_edges(self).map_err(|e| DatabaseError::Other {
216            source: Box::new(e),
217        })?;
218        Ok(id)
219    }
220
221    fn commit(self) -> Result<(), DatabaseError> {
222        self.0.commit().map_err(|e| DatabaseError::Other {
223            source: Box::new(e),
224        })
225    }
226}
227
228impl<'conn> QueryEdge for Txn<'conn> {
229    fn find_edges(&self, source: Id, query: EdgeQuery) -> Result<Vec<Edge>, DatabaseError> {
230        // Build WHERE clause for edge names filter
231        let name_filter = if query.edge_names.is_empty() {
232            String::new()
233        } else {
234            let placeholders = query
235                .edge_names
236                .iter()
237                .map(|_| "?")
238                .collect::<Vec<_>>()
239                .join(", ");
240            format!(" AND type IN ({})", placeholders)
241        };
242
243        // Build cursor filter based on sort order
244        let cursor_filter = match (&query.cursor, query.order) {
245            (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
246            (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
247            (None, _) => "",
248        };
249
250        // Build ORDER BY clause
251        let order_clause = match query.order {
252            SortOrder::Asc => "ORDER BY type ASC, dest ASC",
253            SortOrder::Desc => "ORDER BY type DESC, dest DESC",
254        };
255
256        let sql = format!(
257            "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 100",
258            name_filter, cursor_filter, order_clause
259        );
260
261        // Build parameters
262        let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
263        params.push(Box::new(source));
264
265        for name in query.edge_names {
266            params.push(Box::new(name.to_vec()));
267        }
268
269        if let Some(cursor) = query.cursor {
270            params.push(Box::new(cursor.sort_key.to_vec()));
271            params.push(Box::new(cursor.destination));
272        }
273
274        let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
275            params.iter().map(|p| p.as_ref()).collect();
276
277        let mut stmt = self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
278            source: Box::new(e),
279        })?;
280
281        let rows = stmt
282            .query_map(params_refs.as_slice(), |row| {
283                let source: i64 = row.get(0)?;
284                let sort_key: Vec<u8> = match row.get_ref(1)? {
285                    r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => s.to_vec(),
286                    r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => b.to_vec(),
287                    _ => {
288                        return Err(r2d2_sqlite::rusqlite::Error::InvalidColumnType(
289                            1,
290                            "type".into(),
291                            row.get_ref(1)?.data_type(),
292                        ))
293                    }
294                };
295                let destination: i64 = row.get(2)?;
296                Ok(Edge::new(source as Id, sort_key, destination as Id))
297            })
298            .map_err(|e| DatabaseError::Other {
299                source: Box::new(e),
300            })?;
301
302        rows.collect::<Result<Vec<_>, _>>()
303            .map_err(|e| DatabaseError::Other {
304                source: Box::new(e),
305            })
306    }
307}