1use std::borrow::BorrowMut;
2
3use ents::{
4 DatabaseError, Edge, EdgeDraft, EdgeQuery, EdgeQueryResult, EdgeValue, Ent,
5 Id, IncomingEdgeProvider, QueryEdge, ReadEnt, SortOrder, Transactional,
6 TransactionProvider,
7};
8use ents_admin::AdminEnt;
9use r2d2::Pool;
10use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
11use r2d2_sqlite::SqliteConnectionManager;
12
13pub struct Txn<'conn>(Transaction<'conn>);
14
15impl<'conn> Txn<'conn> {
16 pub fn new(tx: Transaction<'conn>) -> Self {
17 Self(tx)
18 }
19
20 fn update(
21 &self,
22 id: Id,
23 ent: Box<dyn Ent>,
24 expected_last_updated: Option<u64>,
25 ) -> Result<bool, DatabaseError> {
26 let entity_type = ent.typetag_name().to_string();
28 let data_json =
29 serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
30 source: Box::new(e),
31 })?;
32
33 let rows_affected = self
35 .0
36 .execute(
37 r#"
38 UPDATE entities SET data = ?1, type = ?2
39 WHERE
40 id = ?3 AND
41 (
42 JSON_EXTRACT(data, '$.last_updated') = ?4 OR
43 ?4 IS NULL
44 )
45 "#,
46 params![
47 data_json,
48 entity_type,
49 id as i64,
50 expected_last_updated.map(|v| v as i64)
51 ],
52 )
53 .map_err(|e| DatabaseError::Other {
54 source: Box::new(e),
55 })?;
56
57 Ok(rows_affected > 0)
58 }
59}
60
61impl<'conn> Txn<'conn> {
62 fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
63 let entity_type = ent.typetag_name().to_string();
65
66 let data_json =
68 serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
69 DatabaseError::Other {
70 source: Box::new(e),
71 }
72 })?;
73
74 self.0
75 .execute(
76 "INSERT INTO entities (type, data) VALUES (?1, ?2)",
77 params![entity_type, data_json],
78 )
79 .map_err(|e| DatabaseError::Other {
80 source: Box::new(e),
81 })?;
82
83 let inserted_id = self.0.last_insert_rowid() as Id;
84
85 Ok(inserted_id)
86 }
87}
88
89impl<'conn> ReadEnt for Txn<'conn> {
90 fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
91 let mut stmt = self
92 .0
93 .prepare("SELECT id, data FROM entities WHERE id = ?1")
94 .map_err(|e| DatabaseError::Other {
95 source: Box::new(e),
96 })?;
97
98 stmt.query_row(params![id as i64], |row| {
99 let id: Id = row.get::<_, i64>(0)? as Id;
100 let data_json: &str = row.get_ref(1)?.as_str()?;
101 let mut ret = serde_json::from_str::<Box<dyn Ent>>(data_json)
102 .expect("failed to parse JSON");
103 ret.set_id(id);
104 Ok(ret)
105 })
106 .optional()
107 .map_err(|e| DatabaseError::Other {
108 source: Box::new(e),
109 })
110 }
111}
112
113impl<'conn> Transactional for Txn<'conn> {
114 fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
115 let source = edge.source;
116 let sort_key = edge.sort_key;
117 let dest = edge.dest;
118
119 self.0
120 .execute(
121 "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
122 params![source as i64, sort_key, dest as i64],
123 )
124 .map_err(|e| DatabaseError::Other {
125 source: Box::new(e),
126 })?;
127
128 Ok(())
129 }
130
131 fn delete(&self, id: Id) -> Result<(), DatabaseError> {
132 self.0
133 .prepare_cached(
134 r#"
135 DELETE FROM edges WHERE dest = ?1;
136 "#,
137 )
138 .map_err(|e| DatabaseError::Other {
139 source: Box::new(e),
140 })?
141 .execute(params![id as i64])
142 .map_err(|e| DatabaseError::Other {
143 source: Box::new(e),
144 })?;
145
146 self.0
147 .prepare_cached(
148 r#"
149 DELETE FROM entities WHERE id = ?1;
150 "#,
151 )
152 .map_err(|e| DatabaseError::Other {
153 source: Box::new(e),
154 })?
155 .execute(params![id as i64])
156 .map_err(|e| DatabaseError::Other {
157 source: Box::new(e),
158 })?;
159
160 Ok(())
161 }
162
163 fn update<T: Ent, F: FnOnce(&mut T), B: BorrowMut<T>>(
164 &self,
165 mut ent0: B,
166 mutator: F,
167 ) -> Result<bool, DatabaseError> {
168 let ent = ent0.borrow_mut();
169 let draft0 = T::EdgeProvider::draft(ent);
170 let ent_id = ent.id();
171 let expected_last_updated = ent.last_updated();
172
173 mutator(ent);
174 ent.mark_updated().map_err(|e| DatabaseError::Other {
175 source: Box::new(e),
176 })?;
177
178 let draft1 = T::EdgeProvider::draft(ent);
179
180 if draft0 == draft1 {
182 return self.update(
183 ent.id(),
184 dyn_clone::clone_box(ent as &dyn Ent),
185 Some(expected_last_updated),
186 );
187 }
188
189 let edge0 = draft0
190 .check(self)
191 .map(|edges| {
192 edges
193 .into_iter()
194 .map(|edge| edge.with_dest(ent_id))
195 .collect::<Vec<_>>()
196 })
197 .map_err(|e| DatabaseError::Other {
198 source: Box::new(e),
199 })?;
200 let edge1 = draft1
201 .check(self)
202 .map(|edges| {
203 edges
204 .into_iter()
205 .map(|edge| edge.with_dest(ent_id))
206 .collect::<Vec<_>>()
207 })
208 .map_err(|e| DatabaseError::Other {
209 source: Box::new(e),
210 })?;
211
212 let updated = self.update(
213 ent.id(),
214 dyn_clone::clone_box(ent as &dyn Ent),
215 Some(expected_last_updated),
216 )?;
217
218 if updated {
219 for edge in edge0 {
221 self.0
222 .execute(
223 "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
224 params![edge.source as i64, edge.sort_key, edge.dest as i64],
225 )
226 .map_err(|e| DatabaseError::Other {
227 source: Box::new(e),
228 })?;
229 }
230
231 for edge in edge1 {
233 self.create_edge(edge)?;
234 }
235 }
236
237 Ok(updated)
238 }
239
240 fn create<E: Ent>(&self, mut ent: E) -> Result<Id, DatabaseError> {
241 let id = self.insert(&ent)?;
242 ent.set_id(id);
243 ent.setup_edges(self).map_err(|e| DatabaseError::Other {
244 source: Box::new(e),
245 })?;
246 Ok(id)
247 }
248
249 fn commit(self) -> Result<(), DatabaseError> {
250 self.0.commit().map_err(|e| DatabaseError::Other {
251 source: Box::new(e),
252 })
253 }
254}
255
256impl<'conn> AdminEnt for Txn<'conn> {
257 fn create_dyn(&self, ent: Box<dyn Ent>) -> Result<Id, DatabaseError> {
258 let entity_type = ent.typetag_name().to_string();
260 let data_json =
261 serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
262 source: Box::new(e),
263 })?;
264
265 self.0
266 .execute(
267 "INSERT INTO entities (type, data) VALUES (?1, ?2)",
268 params![entity_type, data_json],
269 )
270 .map_err(|e| DatabaseError::Other {
271 source: Box::new(e),
272 })?;
273
274 let id = self.0.last_insert_rowid() as Id;
275 Ok(id)
276 }
277
278 fn update_dyn(&self, ent: Box<dyn Ent>) -> Result<(), DatabaseError> {
279 let id = ent.id();
280 let updated = self.update(id, ent, None)?;
281 if !updated {
282 return Err(DatabaseError::Other {
283 source: "Entity not found or update failed".into(),
284 });
285 }
286 Ok(())
287 }
288
289 fn find_edges_by_dest(&self, dest: Id) -> Result<Vec<Edge>, DatabaseError> {
290 let mut stmt = self
291 .0
292 .prepare("SELECT source, type, dest FROM edges WHERE dest = ?1 ORDER BY source, type")
293 .map_err(|e| DatabaseError::Other {
294 source: Box::new(e),
295 })?;
296
297 let rows = stmt
298 .query_map(params![dest as i64], |row| {
299 let source: i64 = row.get(0)?;
300 let sort_key: Vec<u8> = match row.get_ref(1)? {
301 r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
302 s.to_vec()
303 }
304 r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
305 b.to_vec()
306 }
307 _ => {
308 return Err(
309 r2d2_sqlite::rusqlite::Error::InvalidColumnType(
310 1,
311 "type".into(),
312 row.get_ref(1)?.data_type(),
313 ),
314 )
315 }
316 };
317 let dest: i64 = row.get(2)?;
318 Ok(Edge::new(source as Id, sort_key, dest as Id))
319 })
320 .map_err(|e| DatabaseError::Other {
321 source: Box::new(e),
322 })?;
323
324 rows.collect::<Result<Vec<_>, _>>()
325 .map_err(|e| DatabaseError::Other {
326 source: Box::new(e),
327 })
328 }
329
330 fn remove_edges_by_dest(&self, dest: Id) -> Result<(), DatabaseError> {
331 self.0
332 .execute("DELETE FROM edges WHERE dest = ?1", params![dest as i64])
333 .map_err(|e| DatabaseError::Other {
334 source: Box::new(e),
335 })?;
336 Ok(())
337 }
338
339 fn list_entities(
340 &self,
341 entity_type: &str,
342 cursor: Option<Id>,
343 limit: usize,
344 ) -> Result<Vec<Box<dyn Ent>>, DatabaseError> {
345 let sql = "SELECT id, data FROM entities WHERE type = ?1 AND id > ?2 ORDER BY id ASC LIMIT ?3";
346 let cursor_val = cursor.unwrap_or(0);
347
348 let mut stmt =
349 self.0.prepare(sql).map_err(|e| DatabaseError::Other {
350 source: Box::new(e),
351 })?;
352
353 let rows = stmt
354 .query_map(
355 params![entity_type, cursor_val as i64, limit as i64],
356 |row| {
357 let id: Id = row.get::<_, i64>(0)? as Id;
358 let data_json: &str = row.get_ref(1)?.as_str()?;
359 let mut ret =
360 serde_json::from_str::<Box<dyn Ent>>(data_json)
361 .expect("failed to parse JSON");
362 ret.set_id(id);
363 Ok(ret)
364 },
365 )
366 .map_err(|e| DatabaseError::Other {
367 source: Box::new(e),
368 })?;
369
370 rows.collect::<Result<Vec<_>, _>>()
371 .map_err(|e| DatabaseError::Other {
372 source: Box::new(e),
373 })
374 }
375}
376
377impl<'conn> QueryEdge for Txn<'conn> {
378 fn find_edges(
379 &self,
380 source: Id,
381 query: EdgeQuery,
382 ) -> Result<EdgeQueryResult, DatabaseError> {
383 let name_filter = if query.edge_names.is_empty() {
385 String::new()
386 } else {
387 let placeholders = query
388 .edge_names
389 .iter()
390 .map(|_| "?")
391 .collect::<Vec<_>>()
392 .join(", ");
393 format!(" AND type IN ({})", placeholders)
394 };
395
396 let cursor_filter = match (&query.cursor, query.order) {
398 (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
399 (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
400 (None, _) => "",
401 };
402
403 let order_clause = match query.order {
405 SortOrder::Asc => "ORDER BY type ASC, dest ASC",
406 SortOrder::Desc => "ORDER BY type DESC, dest DESC",
407 };
408
409 let sql = format!(
411 "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 101",
412 name_filter, cursor_filter, order_clause
413 );
414
415 let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
417 params.push(Box::new(source as i64));
418
419 for name in query.edge_names {
420 params.push(Box::new(name.to_vec()));
421 }
422
423 if let Some(cursor) = query.cursor {
424 params.push(Box::new(cursor.sort_key.to_vec()));
425 params.push(Box::new(cursor.destination as i64));
426 }
427
428 let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
429 params.iter().map(|p| p.as_ref()).collect();
430
431 let mut stmt =
432 self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
433 source: Box::new(e),
434 })?;
435
436 let rows = stmt
437 .query_map(params_refs.as_slice(), |row| {
438 let source: i64 = row.get(0)?;
439 let sort_key: Vec<u8> = match row.get_ref(1)? {
440 r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
441 s.to_vec()
442 }
443 r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
444 b.to_vec()
445 }
446 _ => {
447 return Err(
448 r2d2_sqlite::rusqlite::Error::InvalidColumnType(
449 1,
450 "type".into(),
451 row.get_ref(1)?.data_type(),
452 ),
453 )
454 }
455 };
456 let destination: i64 = row.get(2)?;
457 Ok(Edge::new(source as Id, sort_key, destination as Id))
458 })
459 .map_err(|e| DatabaseError::Other {
460 source: Box::new(e),
461 })?;
462
463 let mut edges: Vec<Edge> = rows
464 .collect::<Result<Vec<_>, _>>()
465 .map_err(|e| DatabaseError::Other {
466 source: Box::new(e),
467 })?;
468
469 let has_more = edges.len() > 100;
470 if has_more {
471 edges.truncate(100);
472 }
473
474 Ok(EdgeQueryResult { edges, has_more })
475 }
476}
477
478#[derive(Clone)]
479pub struct SqliteDb {
480 pool: Pool<SqliteConnectionManager>,
481}
482
483impl SqliteDb {
484 pub fn new(pool: Pool<SqliteConnectionManager>) -> Self {
485 Self { pool }
486 }
487}
488
489impl From<Pool<SqliteConnectionManager>> for SqliteDb {
490 fn from(pool: Pool<SqliteConnectionManager>) -> Self {
491 Self::new(pool)
492 }
493}
494
495impl TransactionProvider for SqliteDb {
496 type Tx<'a> = Txn<'a>;
497
498 fn execute<R, F>(&self, func: F) -> Result<R, DatabaseError>
499 where
500 F: for<'b> FnOnce(Self::Tx<'b>) -> R,
501 {
502 let mut conn = self.pool.get().map_err(|e| DatabaseError::Other {
503 source: Box::new(e),
504 })?;
505
506 let ret = func(Txn::new(conn.transaction().map_err(|e| {
507 DatabaseError::Other {
508 source: Box::new(e),
509 }
510 })?));
511
512 Ok(ret)
513 }
514}