1use std::borrow::BorrowMut;
2
3use ents::{
4 DatabaseError, Edge, EdgeDraft, EdgeQuery, EdgeQueryResult, EdgeValue, Ent,
5 Id, IncomingEdgeProvider, QueryEdge, ReadEnt, SortOrder, Transactional,
6 TransactionProvider,
7};
8use ents_admin::AdminEnt;
9use r2d2::Pool;
10use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
11use r2d2_sqlite::SqliteConnectionManager;
12
13pub struct Txn<'conn>(Transaction<'conn>);
14
15impl<'conn> Txn<'conn> {
16 pub fn new(tx: Transaction<'conn>) -> Self {
17 Self(tx)
18 }
19
20 fn update(
21 &self,
22 id: Id,
23 ent: Box<dyn Ent>,
24 expected_last_updated: Option<u64>,
25 ) -> Result<bool, DatabaseError> {
26 let entity_type = ent.typetag_name().to_string();
28 let data_json =
29 serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
30 source: Box::new(e),
31 })?;
32
33 let rows_affected = self
35 .0
36 .execute(
37 r#"
38 UPDATE entities SET data = ?1, type = ?2
39 WHERE
40 id = ?3 AND
41 (
42 JSON_EXTRACT(data, '$.last_updated') = ?4 OR
43 ?4 IS NULL
44 )
45 "#,
46 params![
47 data_json,
48 entity_type,
49 id as i64,
50 expected_last_updated.map(|v| v as i64)
51 ],
52 )
53 .map_err(|e| DatabaseError::Other {
54 source: Box::new(e),
55 })?;
56
57 Ok(rows_affected > 0)
58 }
59}
60
61impl<'conn> Txn<'conn> {
62 fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
63 let entity_type = ent.typetag_name().to_string();
65
66 let data_json =
68 serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
69 DatabaseError::Other {
70 source: Box::new(e),
71 }
72 })?;
73
74 self.0
75 .execute(
76 "INSERT INTO entities (type, data) VALUES (?1, ?2)",
77 params![entity_type, data_json],
78 )
79 .map_err(|e| DatabaseError::Other {
80 source: Box::new(e),
81 })?;
82
83 let inserted_id = self.0.last_insert_rowid() as Id;
84
85 Ok(inserted_id)
86 }
87}
88
89impl<'conn> ReadEnt for Txn<'conn> {
90 fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
91 let mut stmt = self
92 .0
93 .prepare("SELECT id, data FROM entities WHERE id = ?1")
94 .map_err(|e| DatabaseError::Other {
95 source: Box::new(e),
96 })?;
97
98 stmt.query_row(params![id as i64], |row| {
99 let id: Id = row.get::<_, i64>(0)? as Id;
100 let data_json: &str = row.get_ref(1)?.as_str()?;
101 let mut ret = serde_json::from_str::<Box<dyn Ent>>(data_json)
102 .expect("failed to parse JSON");
103 ret.set_id(id);
104 Ok(ret)
105 })
106 .optional()
107 .map_err(|e| DatabaseError::Other {
108 source: Box::new(e),
109 })
110 }
111}
112
113impl<'conn> Transactional for Txn<'conn> {
114 fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
115 let source = edge.source;
116 let sort_key = edge.sort_key;
117 let dest = edge.dest;
118
119 self.0
120 .execute(
121 "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
122 params![source as i64, sort_key, dest as i64],
123 )
124 .map_err(|e| DatabaseError::Other {
125 source: Box::new(e),
126 })?;
127
128 Ok(())
129 }
130
131 fn delete(&self, id: Id) -> Result<(), DatabaseError> {
132 self.0
133 .prepare_cached(
134 r#"
135 DELETE FROM edges WHERE dest = ?1;
136 "#,
137 )
138 .map_err(|e| DatabaseError::Other {
139 source: Box::new(e),
140 })?
141 .execute(params![id as i64])
142 .map_err(|e| DatabaseError::Other {
143 source: Box::new(e),
144 })?;
145
146 self.0
147 .prepare_cached(
148 r#"
149 DELETE FROM entities WHERE id = ?1;
150 "#,
151 )
152 .map_err(|e| DatabaseError::Other {
153 source: Box::new(e),
154 })?
155 .execute(params![id as i64])
156 .map_err(|e| DatabaseError::Other {
157 source: Box::new(e),
158 })?;
159
160 Ok(())
161 }
162
163 fn update<T: Ent, F: FnOnce(&mut T), B: BorrowMut<T>>(
164 &self,
165 mut ent0: B,
166 mutator: F,
167 ) -> Result<bool, DatabaseError> {
168 let ent = ent0.borrow_mut();
169 let draft0 = T::EdgeProvider::draft(ent);
170 let expected_last_updated = ent.last_updated();
171
172 mutator(ent);
173 ent.mark_updated().map_err(|e| DatabaseError::Other {
174 source: Box::new(e),
175 })?;
176
177 let draft1 = T::EdgeProvider::draft(ent);
178
179 if draft0 == draft1 {
181 return self.update(
182 ent.id(),
183 dyn_clone::clone_box(ent as &dyn Ent),
184 Some(expected_last_updated),
185 );
186 }
187
188 let edge0 = draft0.check(self).map_err(|e| DatabaseError::Other {
189 source: Box::new(e),
190 })?;
191 let edge1 = draft1.check(self).map_err(|e| DatabaseError::Other {
192 source: Box::new(e),
193 })?;
194
195 let updated = self.update(
196 ent.id(),
197 dyn_clone::clone_box(ent as &dyn Ent),
198 Some(expected_last_updated),
199 )?;
200
201 if updated {
202 for edge in edge0 {
204 self.0
205 .execute(
206 "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
207 params![edge.source as i64, edge.sort_key, edge.dest as i64],
208 )
209 .map_err(|e| DatabaseError::Other {
210 source: Box::new(e),
211 })?;
212 }
213
214 for edge in edge1 {
216 self.create_edge(edge)?;
217 }
218 }
219
220 Ok(updated)
221 }
222
223 fn create<E: Ent>(&self, mut ent: E) -> Result<Id, DatabaseError> {
224 let id = self.insert(&ent)?;
225 ent.set_id(id);
226 ent.setup_edges(self).map_err(|e| DatabaseError::Other {
227 source: Box::new(e),
228 })?;
229 Ok(id)
230 }
231
232 fn commit(self) -> Result<(), DatabaseError> {
233 self.0.commit().map_err(|e| DatabaseError::Other {
234 source: Box::new(e),
235 })
236 }
237}
238
239impl<'conn> AdminEnt for Txn<'conn> {
240 fn create_dyn(&self, ent: Box<dyn Ent>) -> Result<Id, DatabaseError> {
241 let entity_type = ent.typetag_name().to_string();
243 let data_json =
244 serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
245 source: Box::new(e),
246 })?;
247
248 self.0
249 .execute(
250 "INSERT INTO entities (type, data) VALUES (?1, ?2)",
251 params![entity_type, data_json],
252 )
253 .map_err(|e| DatabaseError::Other {
254 source: Box::new(e),
255 })?;
256
257 let id = self.0.last_insert_rowid() as Id;
258 Ok(id)
259 }
260
261 fn update_dyn(&self, ent: Box<dyn Ent>) -> Result<(), DatabaseError> {
262 let id = ent.id();
263 let updated = self.update(id, ent, None)?;
264 if !updated {
265 return Err(DatabaseError::Other {
266 source: "Entity not found or update failed".into(),
267 });
268 }
269 Ok(())
270 }
271
272 fn find_edges_by_dest(&self, dest: Id) -> Result<Vec<Edge>, DatabaseError> {
273 let mut stmt = self
274 .0
275 .prepare("SELECT source, type, dest FROM edges WHERE dest = ?1 ORDER BY source, type")
276 .map_err(|e| DatabaseError::Other {
277 source: Box::new(e),
278 })?;
279
280 let rows = stmt
281 .query_map(params![dest as i64], |row| {
282 let source: i64 = row.get(0)?;
283 let sort_key: Vec<u8> = match row.get_ref(1)? {
284 r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
285 s.to_vec()
286 }
287 r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
288 b.to_vec()
289 }
290 _ => {
291 return Err(
292 r2d2_sqlite::rusqlite::Error::InvalidColumnType(
293 1,
294 "type".into(),
295 row.get_ref(1)?.data_type(),
296 ),
297 )
298 }
299 };
300 let dest: i64 = row.get(2)?;
301 Ok(Edge::new(source as Id, sort_key, dest as Id))
302 })
303 .map_err(|e| DatabaseError::Other {
304 source: Box::new(e),
305 })?;
306
307 rows.collect::<Result<Vec<_>, _>>()
308 .map_err(|e| DatabaseError::Other {
309 source: Box::new(e),
310 })
311 }
312
313 fn remove_edges_by_dest(&self, dest: Id) -> Result<(), DatabaseError> {
314 self.0
315 .execute("DELETE FROM edges WHERE dest = ?1", params![dest as i64])
316 .map_err(|e| DatabaseError::Other {
317 source: Box::new(e),
318 })?;
319 Ok(())
320 }
321
322 fn list_entities(
323 &self,
324 entity_type: &str,
325 cursor: Option<Id>,
326 limit: usize,
327 ) -> Result<Vec<Box<dyn Ent>>, DatabaseError> {
328 let sql = "SELECT id, data FROM entities WHERE type = ?1 AND id > ?2 ORDER BY id ASC LIMIT ?3";
329 let cursor_val = cursor.unwrap_or(0);
330
331 let mut stmt =
332 self.0.prepare(sql).map_err(|e| DatabaseError::Other {
333 source: Box::new(e),
334 })?;
335
336 let rows = stmt
337 .query_map(
338 params![entity_type, cursor_val as i64, limit as i64],
339 |row| {
340 let id: Id = row.get::<_, i64>(0)? as Id;
341 let data_json: &str = row.get_ref(1)?.as_str()?;
342 let mut ret =
343 serde_json::from_str::<Box<dyn Ent>>(data_json)
344 .expect("failed to parse JSON");
345 ret.set_id(id);
346 Ok(ret)
347 },
348 )
349 .map_err(|e| DatabaseError::Other {
350 source: Box::new(e),
351 })?;
352
353 rows.collect::<Result<Vec<_>, _>>()
354 .map_err(|e| DatabaseError::Other {
355 source: Box::new(e),
356 })
357 }
358}
359
360impl<'conn> QueryEdge for Txn<'conn> {
361 fn find_edges(
362 &self,
363 source: Id,
364 query: EdgeQuery,
365 ) -> Result<EdgeQueryResult, DatabaseError> {
366 let name_filter = if query.edge_names.is_empty() {
368 String::new()
369 } else {
370 let placeholders = query
371 .edge_names
372 .iter()
373 .map(|_| "?")
374 .collect::<Vec<_>>()
375 .join(", ");
376 format!(" AND type IN ({})", placeholders)
377 };
378
379 let cursor_filter = match (&query.cursor, query.order) {
381 (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
382 (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
383 (None, _) => "",
384 };
385
386 let order_clause = match query.order {
388 SortOrder::Asc => "ORDER BY type ASC, dest ASC",
389 SortOrder::Desc => "ORDER BY type DESC, dest DESC",
390 };
391
392 let sql = format!(
394 "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 101",
395 name_filter, cursor_filter, order_clause
396 );
397
398 let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
400 params.push(Box::new(source as i64));
401
402 for name in query.edge_names {
403 params.push(Box::new(name.to_vec()));
404 }
405
406 if let Some(cursor) = query.cursor {
407 params.push(Box::new(cursor.sort_key.to_vec()));
408 params.push(Box::new(cursor.destination as i64));
409 }
410
411 let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
412 params.iter().map(|p| p.as_ref()).collect();
413
414 let mut stmt =
415 self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
416 source: Box::new(e),
417 })?;
418
419 let rows = stmt
420 .query_map(params_refs.as_slice(), |row| {
421 let source: i64 = row.get(0)?;
422 let sort_key: Vec<u8> = match row.get_ref(1)? {
423 r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
424 s.to_vec()
425 }
426 r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
427 b.to_vec()
428 }
429 _ => {
430 return Err(
431 r2d2_sqlite::rusqlite::Error::InvalidColumnType(
432 1,
433 "type".into(),
434 row.get_ref(1)?.data_type(),
435 ),
436 )
437 }
438 };
439 let destination: i64 = row.get(2)?;
440 Ok(Edge::new(source as Id, sort_key, destination as Id))
441 })
442 .map_err(|e| DatabaseError::Other {
443 source: Box::new(e),
444 })?;
445
446 let mut edges: Vec<Edge> = rows
447 .collect::<Result<Vec<_>, _>>()
448 .map_err(|e| DatabaseError::Other {
449 source: Box::new(e),
450 })?;
451
452 let has_more = edges.len() > 100;
453 if has_more {
454 edges.truncate(100);
455 }
456
457 Ok(EdgeQueryResult { edges, has_more })
458 }
459}
460
461#[derive(Clone)]
462pub struct SqliteDb {
463 pool: Pool<SqliteConnectionManager>,
464}
465
466impl SqliteDb {
467 pub fn new(pool: Pool<SqliteConnectionManager>) -> Self {
468 Self { pool }
469 }
470}
471
472impl From<Pool<SqliteConnectionManager>> for SqliteDb {
473 fn from(pool: Pool<SqliteConnectionManager>) -> Self {
474 Self::new(pool)
475 }
476}
477
478impl TransactionProvider for SqliteDb {
479 type Tx<'a> = Txn<'a>;
480
481 fn execute<R, F>(&self, func: F) -> Result<R, DatabaseError>
482 where
483 F: for<'b> FnOnce(Self::Tx<'b>) -> R,
484 {
485 let mut conn = self.pool.get().map_err(|e| DatabaseError::Other {
486 source: Box::new(e),
487 })?;
488
489 let ret = func(Txn::new(conn.transaction().map_err(|e| {
490 DatabaseError::Other {
491 source: Box::new(e),
492 }
493 })?));
494
495 Ok(ret)
496 }
497}