1use std::borrow::BorrowMut;
2
3use ents::Edge;
4use ents::{
5 DatabaseError, EdgeDraft, EdgeProvider, EdgeQuery, EdgeValue, Ent,
6 EntWithEdges, Id, QueryEdge, SortOrder, Transactional,
7};
8use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
9
10pub struct Txn<'conn>(Transaction<'conn>);
11
12impl<'conn> Txn<'conn> {
13 pub fn new(tx: Transaction<'conn>) -> Self {
14 Self(tx)
15 }
16
17 fn update(
18 &self,
19 id: Id,
20 ent: Box<dyn Ent>,
21 expected_last_updated: Option<u64>,
22 ) -> Result<bool, DatabaseError> {
23 let entity_type = ent.typetag_name().to_string();
25 let data_json =
26 serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
27 source: Box::new(e),
28 })?;
29
30 let rows_affected = self
32 .0
33 .execute(
34 r#"
35 UPDATE entities SET data = ?1, type = ?2
36 WHERE
37 id = ?3 AND
38 (
39 JSON_EXTRACT(data, '$.last_updated') = ?4 OR
40 ?4 IS NULL
41 )
42 "#,
43 params![
44 data_json,
45 entity_type,
46 id as i64,
47 expected_last_updated.map(|v| v as i64)
48 ],
49 )
50 .map_err(|e| DatabaseError::Other {
51 source: Box::new(e),
52 })?;
53
54 Ok(rows_affected > 0)
55 }
56}
57
58impl<'conn> Txn<'conn> {
59 fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
60 let entity_type = ent.typetag_name().to_string();
62
63 let data_json =
65 serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
66 DatabaseError::Other {
67 source: Box::new(e),
68 }
69 })?;
70
71 self.0
72 .execute(
73 "INSERT INTO entities (type, data) VALUES (?1, ?2)",
74 params![entity_type, data_json],
75 )
76 .map_err(|e| DatabaseError::Other {
77 source: Box::new(e),
78 })?;
79
80 let inserted_id = self.0.last_insert_rowid() as Id;
81
82 Ok(inserted_id)
83 }
84}
85
86impl<'conn> Transactional for Txn<'conn> {
87 fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
88 let mut stmt = self
89 .0
90 .prepare("SELECT id, data FROM entities WHERE id = ?1")
91 .map_err(|e| DatabaseError::Other {
92 source: Box::new(e),
93 })?;
94
95 stmt.query_row(params![id as i64], |row| {
96 let id: Id = row.get::<_, i64>(0)? as Id;
97 let data_json: &str = row.get_ref(1)?.as_str()?;
98 let mut ret = serde_json::from_str::<Box<dyn Ent>>(data_json)
99 .expect("failed to parse JSON");
100 ret.set_id(id);
101 Ok(ret)
102 })
103 .optional()
104 .map_err(|e| DatabaseError::Other {
105 source: Box::new(e),
106 })
107 }
108
109 fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
110 let source = edge.source;
111 let sort_key = edge.sort_key;
112 let dest = edge.dest;
113
114 self.0
115 .execute(
116 "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
117 params![source as i64, sort_key, dest as i64],
118 )
119 .map_err(|e| DatabaseError::Other {
120 source: Box::new(e),
121 })?;
122
123 Ok(())
124 }
125
126 fn delete<E: Ent + EntWithEdges>(
127 &self,
128 id: Id,
129 ) -> Result<(), DatabaseError> {
130 self.0
131 .prepare_cached(
132 r#"
133 DELETE FROM edges WHERE dest = ?1;
134 "#,
135 )
136 .map_err(|e| DatabaseError::Other {
137 source: Box::new(e),
138 })?
139 .execute(params![id as i64])
140 .map_err(|e| DatabaseError::Other {
141 source: Box::new(e),
142 })?;
143
144 self.0
145 .prepare_cached(
146 r#"
147 DELETE FROM entities WHERE id = ?1;
148 "#,
149 )
150 .map_err(|e| DatabaseError::Other {
151 source: Box::new(e),
152 })?
153 .execute(params![id as i64])
154 .map_err(|e| DatabaseError::Other {
155 source: Box::new(e),
156 })?;
157
158 Ok(())
159 }
160
161 fn update<T: EntWithEdges, F: FnOnce(&mut T), B: BorrowMut<T>>(
162 &self,
163 mut ent0: B,
164 mutator: F,
165 ) -> Result<bool, DatabaseError> {
166 let ent = ent0.borrow_mut();
167 let draft0 = T::EdgeProvider::draft(ent);
168 let expected_last_updated = ent.last_updated();
169
170 mutator(ent);
171 ent.mark_updated().map_err(|e| DatabaseError::Other {
172 source: Box::new(e),
173 })?;
174
175 let draft1 = T::EdgeProvider::draft(ent);
176
177 if draft0 == draft1 {
179 return self.update(
180 ent.id(),
181 dyn_clone::clone_box(ent),
182 Some(expected_last_updated),
183 );
184 }
185
186 let edge0 = draft0.check(self).map_err(|e| DatabaseError::Other {
187 source: Box::new(e),
188 })?;
189 let edge1 = draft1.check(self).map_err(|e| DatabaseError::Other {
190 source: Box::new(e),
191 })?;
192
193 let updated = self.update(
194 ent.id(),
195 dyn_clone::clone_box(ent),
196 Some(expected_last_updated),
197 )?;
198
199 if updated {
200 for edge in edge0 {
202 self.0
203 .execute(
204 "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
205 params![edge.source as i64, edge.sort_key, edge.dest as i64],
206 )
207 .map_err(|e| DatabaseError::Other {
208 source: Box::new(e),
209 })?;
210 }
211
212 for edge in edge1 {
214 self.create_edge(edge)?;
215 }
216 }
217
218 Ok(updated)
219 }
220
221 fn create<E: Ent + EntWithEdges>(
222 &self,
223 mut ent: E,
224 ) -> Result<Id, DatabaseError> {
225 let id = self.insert(&ent)?;
226 ent.set_id(id);
227 ent.setup_edges(self).map_err(|e| DatabaseError::Other {
228 source: Box::new(e),
229 })?;
230 Ok(id)
231 }
232
233 fn commit(self) -> Result<(), DatabaseError> {
234 self.0.commit().map_err(|e| DatabaseError::Other {
235 source: Box::new(e),
236 })
237 }
238}
239
240impl<'conn> QueryEdge for Txn<'conn> {
241 fn find_edges(
242 &self,
243 source: Id,
244 query: EdgeQuery,
245 ) -> Result<Vec<Edge>, DatabaseError> {
246 let name_filter = if query.edge_names.is_empty() {
248 String::new()
249 } else {
250 let placeholders = query
251 .edge_names
252 .iter()
253 .map(|_| "?")
254 .collect::<Vec<_>>()
255 .join(", ");
256 format!(" AND type IN ({})", placeholders)
257 };
258
259 let cursor_filter = match (&query.cursor, query.order) {
261 (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
262 (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
263 (None, _) => "",
264 };
265
266 let order_clause = match query.order {
268 SortOrder::Asc => "ORDER BY type ASC, dest ASC",
269 SortOrder::Desc => "ORDER BY type DESC, dest DESC",
270 };
271
272 let sql = format!(
273 "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 100",
274 name_filter, cursor_filter, order_clause
275 );
276
277 let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
279 params.push(Box::new(source as i64));
280
281 for name in query.edge_names {
282 params.push(Box::new(name.to_vec()));
283 }
284
285 if let Some(cursor) = query.cursor {
286 params.push(Box::new(cursor.sort_key.to_vec()));
287 params.push(Box::new(cursor.destination as i64));
288 }
289
290 let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
291 params.iter().map(|p| p.as_ref()).collect();
292
293 let mut stmt =
294 self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
295 source: Box::new(e),
296 })?;
297
298 let rows = stmt
299 .query_map(params_refs.as_slice(), |row| {
300 let source: i64 = row.get(0)?;
301 let sort_key: Vec<u8> = match row.get_ref(1)? {
302 r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => {
303 s.to_vec()
304 }
305 r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => {
306 b.to_vec()
307 }
308 _ => {
309 return Err(
310 r2d2_sqlite::rusqlite::Error::InvalidColumnType(
311 1,
312 "type".into(),
313 row.get_ref(1)?.data_type(),
314 ),
315 )
316 }
317 };
318 let destination: i64 = row.get(2)?;
319 Ok(Edge::new(source as Id, sort_key, destination as Id))
320 })
321 .map_err(|e| DatabaseError::Other {
322 source: Box::new(e),
323 })?;
324
325 rows.collect::<Result<Vec<_>, _>>()
326 .map_err(|e| DatabaseError::Other {
327 source: Box::new(e),
328 })
329 }
330}