1use std::borrow::BorrowMut;
2
3use ents::{
4 DatabaseError, Edge, EdgeDraft, EdgeQuery, EdgeQueryResult, EdgeValue, Ent,
5 Id, IncomingEdgeProvider, QueryEdge, ReadEnt, SortOrder, Transactional,
6 TransactionProvider,
7};
8use ents_admin::AdminEnt;
9use r2d2::Pool;
10use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
11use r2d2_sqlite::SqliteConnectionManager;
12
13pub struct Txn<'conn>(Transaction<'conn>);
14
15impl<'conn> Txn<'conn> {
16 pub fn new(tx: Transaction<'conn>) -> Self {
17 Self(tx)
18 }
19
20 fn update(
21 &self,
22 id: Id,
23 ent: Box<dyn Ent>,
24 expected_last_updated: Option<u64>,
25 ) -> Result<bool, DatabaseError> {
26 let entity_type = ent.typetag_name().to_string();
28 let data_json =
29 serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
30 source: Box::new(e),
31 })?;
32
33 let rows_affected = self
35 .0
36 .execute(
37 r#"
38 UPDATE entities SET data = ?1, type = ?2
39 WHERE
40 id = ?3 AND
41 (
42 JSON_EXTRACT(data, '$.last_updated') = ?4 OR
43 ?4 IS NULL
44 )
45 "#,
46 params![
47 data_json,
48 entity_type,
49 id as i64,
50 expected_last_updated.map(|v| v as i64)
51 ],
52 )
53 .map_err(|e| DatabaseError::Other {
54 source: Box::new(e),
55 })?;
56
57 Ok(rows_affected > 0)
58 }
59}
60
61impl<'conn> Txn<'conn> {
62 fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
63 let entity_type = ent.typetag_name().to_string();
65
66 let data_json =
68 serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
69 DatabaseError::Other {
70 source: Box::new(e),
71 }
72 })?;
73
74 self.0
75 .execute(
76 "INSERT INTO entities (type, data) VALUES (?1, ?2)",
77 params![entity_type, data_json],
78 )
79 .map_err(|e| DatabaseError::Other {
80 source: Box::new(e),
81 })?;
82
83 let inserted_id = self.0.last_insert_rowid() as Id;
84
85 Ok(inserted_id)
86 }
87}
88
89impl<'conn> ReadEnt for Txn<'conn> {
90 fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
91 let mut stmt = self
92 .0
93 .prepare("SELECT id, data FROM entities WHERE id = ?1")
94 .map_err(|e| DatabaseError::Other {
95 source: Box::new(e),
96 })?;
97
98 stmt.query_row(params![id as i64], |row| {
99 let id: Id = row.get::<_, i64>(0)? as Id;
100 let data_json: &str = row.get_ref(1)?.as_str()?;
101 let mut ret = serde_json::from_str::<Box<dyn Ent>>(data_json)
102 .expect("failed to parse JSON");
103 ret.set_id(id);
104 Ok(ret)
105 })
106 .optional()
107 .map_err(|e| DatabaseError::Other {
108 source: Box::new(e),
109 })
110 }
111}
112
113impl<'conn> Transactional for Txn<'conn> {
114 fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
115 let source = edge.source;
116 let sort_key = edge.sort_key;
117 let dest = edge.dest;
118
119 self.0
120 .execute(
121 "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
122 params![source as i64, sort_key, dest as i64],
123 )
124 .map_err(|e| DatabaseError::Other {
125 source: Box::new(e),
126 })?;
127
128 Ok(())
129 }
130
131 fn delete(&self, id: Id) -> Result<(), DatabaseError> {
132 self.0
133 .prepare_cached(
134 r#"
135 DELETE FROM edges WHERE dest = ?1;
136 "#,
137 )
138 .map_err(|e| DatabaseError::Other {
139 source: Box::new(e),
140 })?
141 .execute(params![id as i64])
142 .map_err(|e| DatabaseError::Other {
143 source: Box::new(e),
144 })?;
145
146 self.0
147 .prepare_cached(
148 r#"
149 DELETE FROM entities WHERE id = ?1;
150 "#,
151 )
152 .map_err(|e| DatabaseError::Other {
153 source: Box::new(e),
154 })?
155 .execute(params![id as i64])
156 .map_err(|e| DatabaseError::Other {
157 source: Box::new(e),
158 })?;
159
160 Ok(())
161 }
162
163 fn update<T: Ent, F: FnOnce(&mut T), B: BorrowMut<T>>(
164 &self,
165 mut ent0: B,
166 mutator: F,
167 ) -> Result<bool, DatabaseError> {
168 let ent = ent0.borrow_mut();
169 let draft0 = T::EdgeProvider::draft(ent);
170 let expected_last_updated = ent.last_updated();
171
172 mutator(ent);
173 ent.mark_updated().map_err(|e| DatabaseError::Other {
174 source: Box::new(e),
175 })?;
176
177 let draft1 = T::EdgeProvider::draft(ent);
178
179 if draft0 == draft1 {
181 return self.update(
182 ent.id(),
183 dyn_clone::clone_box(ent as &dyn Ent),
184 Some(expected_last_updated),
185 );
186 }
187
188 let edge0 = draft0.check(self).map_err(|e| DatabaseError::Other {
189 source: Box::new(e),
190 })?;
191 let edge1 = draft1.check(self).map_err(|e| DatabaseError::Other {
192 source: Box::new(e),
193 })?;
194
195 let updated = self.update(
196 ent.id(),
197 dyn_clone::clone_box(ent as &dyn Ent),
198 Some(expected_last_updated),
199 )?;
200
201 if updated {
202 for edge in edge0 {
204 self.0
205 .execute(
206 "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
207 params![edge.source as i64, edge.sort_key, edge.dest as i64],
208 )
209 .map_err(|e| DatabaseError::Other {
210 source: Box::new(e),
211 })?;
212 }
213
214 for edge in edge1 {
216 self.create_edge(edge)?;
217 }
218 }
219
220 Ok(updated)
221 }
222
223 fn create<E: Ent>(&self, mut ent: E) -> Result<Id, DatabaseError> {
224 let id = self.insert(&ent)?;
225 ent.set_id(id);
226 ent.setup_edges(self).map_err(|e| DatabaseError::Other {
227 source: Box::new(e),
228 })?;
229 Ok(id)
230 }
231
232 fn commit(self) -> Result<(), DatabaseError> {
233 self.0.commit().map_err(|e| DatabaseError::Other {
234 source: Box::new(e),
235 })
236 }
237}
238
239impl<'conn> AdminEnt for Txn<'conn> {
240 fn find_edges_by_dest(&self, dest: Id) -> Result<Vec<Edge>, DatabaseError> {
241 let mut stmt = self
242 .0
243 .prepare("SELECT source, type, dest FROM edges WHERE dest = ?1 ORDER BY source, type")
244 .map_err(|e| DatabaseError::Other {
245 source: Box::new(e),
246 })?;
247
248 let rows = stmt
249 .query_map(params![dest as i64], |row| {
250 let source: i64 = row.get(0)?;
251 let sort_key: Vec<u8> = match row.get_ref(1)? {
252 r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
253 s.to_vec()
254 }
255 r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
256 b.to_vec()
257 }
258 _ => {
259 return Err(
260 r2d2_sqlite::rusqlite::Error::InvalidColumnType(
261 1,
262 "type".into(),
263 row.get_ref(1)?.data_type(),
264 ),
265 )
266 }
267 };
268 let dest: i64 = row.get(2)?;
269 Ok(Edge::new(source as Id, sort_key, dest as Id))
270 })
271 .map_err(|e| DatabaseError::Other {
272 source: Box::new(e),
273 })?;
274
275 rows.collect::<Result<Vec<_>, _>>()
276 .map_err(|e| DatabaseError::Other {
277 source: Box::new(e),
278 })
279 }
280
281 fn remove_edges_by_dest(&self, dest: Id) -> Result<(), DatabaseError> {
282 self.0
283 .execute("DELETE FROM edges WHERE dest = ?1", params![dest as i64])
284 .map_err(|e| DatabaseError::Other {
285 source: Box::new(e),
286 })?;
287 Ok(())
288 }
289
290 fn list_entities(
291 &self,
292 entity_type: &str,
293 cursor: Option<Id>,
294 limit: usize,
295 ) -> Result<Vec<Box<dyn Ent>>, DatabaseError> {
296 let sql = "SELECT id, data FROM entities WHERE type = ?1 AND id > ?2 ORDER BY id ASC LIMIT ?3";
297 let cursor_val = cursor.unwrap_or(0);
298
299 let mut stmt =
300 self.0.prepare(sql).map_err(|e| DatabaseError::Other {
301 source: Box::new(e),
302 })?;
303
304 let rows = stmt
305 .query_map(
306 params![entity_type, cursor_val as i64, limit as i64],
307 |row| {
308 let id: Id = row.get::<_, i64>(0)? as Id;
309 let data_json: &str = row.get_ref(1)?.as_str()?;
310 let mut ret =
311 serde_json::from_str::<Box<dyn Ent>>(data_json)
312 .expect("failed to parse JSON");
313 ret.set_id(id);
314 Ok(ret)
315 },
316 )
317 .map_err(|e| DatabaseError::Other {
318 source: Box::new(e),
319 })?;
320
321 rows.collect::<Result<Vec<_>, _>>()
322 .map_err(|e| DatabaseError::Other {
323 source: Box::new(e),
324 })
325 }
326}
327
328impl<'conn> QueryEdge for Txn<'conn> {
329 fn find_edges(
330 &self,
331 source: Id,
332 query: EdgeQuery,
333 ) -> Result<EdgeQueryResult, DatabaseError> {
334 let name_filter = if query.edge_names.is_empty() {
336 String::new()
337 } else {
338 let placeholders = query
339 .edge_names
340 .iter()
341 .map(|_| "?")
342 .collect::<Vec<_>>()
343 .join(", ");
344 format!(" AND type IN ({})", placeholders)
345 };
346
347 let cursor_filter = match (&query.cursor, query.order) {
349 (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
350 (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
351 (None, _) => "",
352 };
353
354 let order_clause = match query.order {
356 SortOrder::Asc => "ORDER BY type ASC, dest ASC",
357 SortOrder::Desc => "ORDER BY type DESC, dest DESC",
358 };
359
360 let sql = format!(
362 "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 101",
363 name_filter, cursor_filter, order_clause
364 );
365
366 let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
368 params.push(Box::new(source as i64));
369
370 for name in query.edge_names {
371 params.push(Box::new(name.to_vec()));
372 }
373
374 if let Some(cursor) = query.cursor {
375 params.push(Box::new(cursor.sort_key.to_vec()));
376 params.push(Box::new(cursor.destination as i64));
377 }
378
379 let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
380 params.iter().map(|p| p.as_ref()).collect();
381
382 let mut stmt =
383 self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
384 source: Box::new(e),
385 })?;
386
387 let rows = stmt
388 .query_map(params_refs.as_slice(), |row| {
389 let source: i64 = row.get(0)?;
390 let sort_key: Vec<u8> = match row.get_ref(1)? {
391 r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
392 s.to_vec()
393 }
394 r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
395 b.to_vec()
396 }
397 _ => {
398 return Err(
399 r2d2_sqlite::rusqlite::Error::InvalidColumnType(
400 1,
401 "type".into(),
402 row.get_ref(1)?.data_type(),
403 ),
404 )
405 }
406 };
407 let destination: i64 = row.get(2)?;
408 Ok(Edge::new(source as Id, sort_key, destination as Id))
409 })
410 .map_err(|e| DatabaseError::Other {
411 source: Box::new(e),
412 })?;
413
414 let mut edges: Vec<Edge> = rows
415 .collect::<Result<Vec<_>, _>>()
416 .map_err(|e| DatabaseError::Other {
417 source: Box::new(e),
418 })?;
419
420 let has_more = edges.len() > 100;
421 if has_more {
422 edges.truncate(100);
423 }
424
425 Ok(EdgeQueryResult { edges, has_more })
426 }
427}
428
429#[derive(Clone)]
430pub struct SqliteDb {
431 pool: Pool<SqliteConnectionManager>,
432}
433
434impl SqliteDb {
435 pub fn new(pool: Pool<SqliteConnectionManager>) -> Self {
436 Self { pool }
437 }
438}
439
440impl From<Pool<SqliteConnectionManager>> for SqliteDb {
441 fn from(pool: Pool<SqliteConnectionManager>) -> Self {
442 Self::new(pool)
443 }
444}
445
446impl TransactionProvider for SqliteDb {
447 type Tx<'a> = Txn<'a>;
448
449 fn execute<R, F>(&self, func: F) -> Result<R, DatabaseError>
450 where
451 F: for<'b> FnOnce(Self::Tx<'b>) -> R,
452 {
453 let mut conn = self.pool.get().map_err(|e| DatabaseError::Other {
454 source: Box::new(e),
455 })?;
456
457 let ret = func(Txn::new(conn.transaction().map_err(|e| {
458 DatabaseError::Other {
459 source: Box::new(e),
460 }
461 })?));
462
463 Ok(ret)
464 }
465}