1use std::borrow::BorrowMut;
2
3use ents::{
4 DatabaseError, Edge, EdgeDraft, EdgeQuery, EdgeValue, Ent, Id,
5 IncomingEdgeProvider, QueryEdge, ReadEnt, SortOrder, Transactional,
6};
7use ents_admin::AdminEdgeByDest;
8use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
9
10pub struct Txn<'conn>(Transaction<'conn>);
11
12impl<'conn> Txn<'conn> {
13 pub fn new(tx: Transaction<'conn>) -> Self {
14 Self(tx)
15 }
16
17 fn update(
18 &self,
19 id: Id,
20 ent: Box<dyn Ent>,
21 expected_last_updated: Option<u64>,
22 ) -> Result<bool, DatabaseError> {
23 let entity_type = ent.typetag_name().to_string();
25 let data_json =
26 serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
27 source: Box::new(e),
28 })?;
29
30 let rows_affected = self
32 .0
33 .execute(
34 r#"
35 UPDATE entities SET data = ?1, type = ?2
36 WHERE
37 id = ?3 AND
38 (
39 JSON_EXTRACT(data, '$.last_updated') = ?4 OR
40 ?4 IS NULL
41 )
42 "#,
43 params![
44 data_json,
45 entity_type,
46 id as i64,
47 expected_last_updated.map(|v| v as i64)
48 ],
49 )
50 .map_err(|e| DatabaseError::Other {
51 source: Box::new(e),
52 })?;
53
54 Ok(rows_affected > 0)
55 }
56}
57
58impl<'conn> Txn<'conn> {
59 fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
60 let entity_type = ent.typetag_name().to_string();
62
63 let data_json =
65 serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
66 DatabaseError::Other {
67 source: Box::new(e),
68 }
69 })?;
70
71 self.0
72 .execute(
73 "INSERT INTO entities (type, data) VALUES (?1, ?2)",
74 params![entity_type, data_json],
75 )
76 .map_err(|e| DatabaseError::Other {
77 source: Box::new(e),
78 })?;
79
80 let inserted_id = self.0.last_insert_rowid() as Id;
81
82 Ok(inserted_id)
83 }
84}
85
86impl<'conn> ReadEnt for Txn<'conn> {
87 fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
88 let mut stmt = self
89 .0
90 .prepare("SELECT id, data FROM entities WHERE id = ?1")
91 .map_err(|e| DatabaseError::Other {
92 source: Box::new(e),
93 })?;
94
95 stmt.query_row(params![id as i64], |row| {
96 let id: Id = row.get::<_, i64>(0)? as Id;
97 let data_json: &str = row.get_ref(1)?.as_str()?;
98 let mut ret = serde_json::from_str::<Box<dyn Ent>>(data_json)
99 .expect("failed to parse JSON");
100 ret.set_id(id);
101 Ok(ret)
102 })
103 .optional()
104 .map_err(|e| DatabaseError::Other {
105 source: Box::new(e),
106 })
107 }
108}
109
110impl<'conn> Transactional for Txn<'conn> {
111 fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
112 let source = edge.source;
113 let sort_key = edge.sort_key;
114 let dest = edge.dest;
115
116 self.0
117 .execute(
118 "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
119 params![source as i64, sort_key, dest as i64],
120 )
121 .map_err(|e| DatabaseError::Other {
122 source: Box::new(e),
123 })?;
124
125 Ok(())
126 }
127
128 fn delete<E: Ent>(&self, id: Id) -> Result<(), DatabaseError> {
129 self.0
130 .prepare_cached(
131 r#"
132 DELETE FROM edges WHERE dest = ?1;
133 "#,
134 )
135 .map_err(|e| DatabaseError::Other {
136 source: Box::new(e),
137 })?
138 .execute(params![id as i64])
139 .map_err(|e| DatabaseError::Other {
140 source: Box::new(e),
141 })?;
142
143 self.0
144 .prepare_cached(
145 r#"
146 DELETE FROM entities WHERE id = ?1;
147 "#,
148 )
149 .map_err(|e| DatabaseError::Other {
150 source: Box::new(e),
151 })?
152 .execute(params![id as i64])
153 .map_err(|e| DatabaseError::Other {
154 source: Box::new(e),
155 })?;
156
157 Ok(())
158 }
159
160 fn update<T: Ent, F: FnOnce(&mut T), B: BorrowMut<T>>(
161 &self,
162 mut ent0: B,
163 mutator: F,
164 ) -> Result<bool, DatabaseError> {
165 let ent = ent0.borrow_mut();
166 let draft0 = T::EdgeProvider::draft(ent);
167 let expected_last_updated = ent.last_updated();
168
169 mutator(ent);
170 ent.mark_updated().map_err(|e| DatabaseError::Other {
171 source: Box::new(e),
172 })?;
173
174 let draft1 = T::EdgeProvider::draft(ent);
175
176 if draft0 == draft1 {
178 return self.update(
179 ent.id(),
180 dyn_clone::clone_box(ent),
181 Some(expected_last_updated),
182 );
183 }
184
185 let edge0 = draft0.check(self).map_err(|e| DatabaseError::Other {
186 source: Box::new(e),
187 })?;
188 let edge1 = draft1.check(self).map_err(|e| DatabaseError::Other {
189 source: Box::new(e),
190 })?;
191
192 let updated = self.update(
193 ent.id(),
194 dyn_clone::clone_box(ent),
195 Some(expected_last_updated),
196 )?;
197
198 if updated {
199 for edge in edge0 {
201 self.0
202 .execute(
203 "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
204 params![edge.source as i64, edge.sort_key, edge.dest as i64],
205 )
206 .map_err(|e| DatabaseError::Other {
207 source: Box::new(e),
208 })?;
209 }
210
211 for edge in edge1 {
213 self.create_edge(edge)?;
214 }
215 }
216
217 Ok(updated)
218 }
219
220 fn create<E: Ent>(&self, mut ent: E) -> Result<Id, DatabaseError> {
221 let id = self.insert(&ent)?;
222 ent.set_id(id);
223 ent.setup_edges(self).map_err(|e| DatabaseError::Other {
224 source: Box::new(e),
225 })?;
226 Ok(id)
227 }
228
229 fn commit(self) -> Result<(), DatabaseError> {
230 self.0.commit().map_err(|e| DatabaseError::Other {
231 source: Box::new(e),
232 })
233 }
234}
235
236impl<'conn> AdminEdgeByDest for Txn<'conn> {
237 fn find_edges_by_dest(&self, dest: Id) -> Result<Vec<Edge>, DatabaseError> {
238 let mut stmt = self
239 .0
240 .prepare("SELECT source, type, dest FROM edges WHERE dest = ?1 ORDER BY source, type")
241 .map_err(|e| DatabaseError::Other {
242 source: Box::new(e),
243 })?;
244
245 let rows = stmt
246 .query_map(params![dest as i64], |row| {
247 let source: i64 = row.get(0)?;
248 let sort_key: Vec<u8> = match row.get_ref(1)? {
249 r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
250 s.to_vec()
251 }
252 r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
253 b.to_vec()
254 }
255 _ => {
256 return Err(
257 r2d2_sqlite::rusqlite::Error::InvalidColumnType(
258 1,
259 "type".into(),
260 row.get_ref(1)?.data_type(),
261 ),
262 )
263 }
264 };
265 let dest: i64 = row.get(2)?;
266 Ok(Edge::new(source as Id, sort_key, dest as Id))
267 })
268 .map_err(|e| DatabaseError::Other {
269 source: Box::new(e),
270 })?;
271
272 rows.collect::<Result<Vec<_>, _>>()
273 .map_err(|e| DatabaseError::Other {
274 source: Box::new(e),
275 })
276 }
277
278 fn remove_edges_by_dest(&self, dest: Id) -> Result<(), DatabaseError> {
279 self.0
280 .execute("DELETE FROM edges WHERE dest = ?1", params![dest as i64])
281 .map_err(|e| DatabaseError::Other {
282 source: Box::new(e),
283 })?;
284 Ok(())
285 }
286}
287
288impl<'conn> QueryEdge for Txn<'conn> {
289 fn find_edges(
290 &self,
291 source: Id,
292 query: EdgeQuery,
293 ) -> Result<Vec<Edge>, DatabaseError> {
294 let name_filter = if query.edge_names.is_empty() {
296 String::new()
297 } else {
298 let placeholders = query
299 .edge_names
300 .iter()
301 .map(|_| "?")
302 .collect::<Vec<_>>()
303 .join(", ");
304 format!(" AND type IN ({})", placeholders)
305 };
306
307 let cursor_filter = match (&query.cursor, query.order) {
309 (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
310 (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
311 (None, _) => "",
312 };
313
314 let order_clause = match query.order {
316 SortOrder::Asc => "ORDER BY type ASC, dest ASC",
317 SortOrder::Desc => "ORDER BY type DESC, dest DESC",
318 };
319
320 let sql = format!(
321 "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 100",
322 name_filter, cursor_filter, order_clause
323 );
324
325 let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
327 params.push(Box::new(source as i64));
328
329 for name in query.edge_names {
330 params.push(Box::new(name.to_vec()));
331 }
332
333 if let Some(cursor) = query.cursor {
334 params.push(Box::new(cursor.sort_key.to_vec()));
335 params.push(Box::new(cursor.destination as i64));
336 }
337
338 let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
339 params.iter().map(|p| p.as_ref()).collect();
340
341 let mut stmt =
342 self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
343 source: Box::new(e),
344 })?;
345
346 let rows = stmt
347 .query_map(params_refs.as_slice(), |row| {
348 let source: i64 = row.get(0)?;
349 let sort_key: Vec<u8> = match row.get_ref(1)? {
350 r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
351 s.to_vec()
352 }
353 r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
354 b.to_vec()
355 }
356 _ => {
357 return Err(
358 r2d2_sqlite::rusqlite::Error::InvalidColumnType(
359 1,
360 "type".into(),
361 row.get_ref(1)?.data_type(),
362 ),
363 )
364 }
365 };
366 let destination: i64 = row.get(2)?;
367 Ok(Edge::new(source as Id, sort_key, destination as Id))
368 })
369 .map_err(|e| DatabaseError::Other {
370 source: Box::new(e),
371 })?;
372
373 rows.collect::<Result<Vec<_>, _>>()
374 .map_err(|e| DatabaseError::Other {
375 source: Box::new(e),
376 })
377 }
378}