1use std::borrow::BorrowMut;
2
3use ents::Edge;
4use ents::{
5 DatabaseError, EdgeDraft, EdgeProvider, EdgeQuery, EdgeValue, Ent, EntWithEdges, Id, QueryEdge,
6 SortOrder, Transactional,
7};
8use r2d2_sqlite::rusqlite::{params, OptionalExtension, Transaction};
9
10pub struct Txn<'conn>(Transaction<'conn>);
11
12impl<'conn> Txn<'conn> {
13 pub fn new(tx: Transaction<'conn>) -> Self {
14 Self(tx)
15 }
16
17 fn update(
18 &self,
19 id: Id,
20 ent: Box<dyn Ent>,
21 expected_last_updated: Option<u64>,
22 ) -> Result<bool, DatabaseError> {
23 let entity_type = ent.typetag_name().to_string();
25 let data_json = serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
26 source: Box::new(e),
27 })?;
28
29 let rows_affected = self
31 .0
32 .execute(
33 r#"
34 UPDATE entities SET data = ?1, type = ?2
35 WHERE
36 id = ?3 AND
37 (
38 JSON_EXTRACT(data, '$.last_updated') = ?4 OR
39 ?4 IS NULL
40 )
41 "#,
42 params![data_json, entity_type, id as i64, expected_last_updated],
43 )
44 .map_err(|e| DatabaseError::Other {
45 source: Box::new(e),
46 })?;
47
48 dbg!(&rows_affected, &id);
49
50 Ok(rows_affected > 0)
51 }
52}
53
54impl<'conn> Txn<'conn> {
55 fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
56 let entity_type = ent.typetag_name().to_string();
58
59 let data_json =
61 serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| DatabaseError::Other {
62 source: Box::new(e),
63 })?;
64
65 self.0
66 .execute(
67 "INSERT INTO entities (type, data) VALUES (?1, ?2)",
68 params![entity_type, data_json],
69 )
70 .map_err(|e| DatabaseError::Other {
71 source: Box::new(e),
72 })?;
73
74 let inserted_id = self.0.last_insert_rowid() as Id;
75
76 Ok(inserted_id)
77 }
78}
79
80impl<'conn> Transactional for Txn<'conn> {
81 fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
82 let mut stmt = self
83 .0
84 .prepare("SELECT id, data FROM entities WHERE id = ?1")
85 .map_err(|e| DatabaseError::Other {
86 source: Box::new(e),
87 })?;
88
89 stmt.query_row(params![id as i64], |row| {
90 let id: Id = row.get(0)?;
91 let data_json: &str = row.get_ref(1)?.as_str()?;
92 let mut ret =
93 serde_json::from_str::<Box<dyn Ent>>(data_json).expect("failed to parse JSON");
94 ret.set_id(id);
95 Ok(ret)
96 })
97 .optional()
98 .map_err(|e| DatabaseError::Other {
99 source: Box::new(e),
100 })
101 }
102
103 fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
104 let source = edge.source;
105 let sort_key = edge.sort_key;
106 let dest = edge.dest;
107
108 self.0
109 .execute(
110 "INSERT INTO edges (source, type, dest) VALUES (?1, ?2, ?3)",
111 params![source as i64, sort_key, dest as i64],
112 )
113 .map_err(|e| DatabaseError::Other {
114 source: Box::new(e),
115 })?;
116
117 Ok(())
118 }
119
120 fn delete<E: Ent + EntWithEdges>(&self, id: Id) -> Result<(), DatabaseError> {
121 self.0
122 .prepare_cached(
123 r#"
124 DELETE FROM edges WHERE dest = ?1;
125 "#,
126 )
127 .map_err(|e| DatabaseError::Other {
128 source: Box::new(e),
129 })?
130 .execute(params![id])
131 .map_err(|e| DatabaseError::Other {
132 source: Box::new(e),
133 })?;
134
135 self.0
136 .prepare_cached(
137 r#"
138 DELETE FROM entities WHERE id = ?1;
139 "#,
140 )
141 .map_err(|e| DatabaseError::Other {
142 source: Box::new(e),
143 })?
144 .execute(params![id])
145 .map_err(|e| DatabaseError::Other {
146 source: Box::new(e),
147 })?;
148
149 Ok(())
150 }
151
152 fn update<T: Ent + EntWithEdges + 'static, F: FnOnce(&mut T), B: BorrowMut<T>>(
153 &self,
154 mut ent0: B,
155 mutator: F,
156 ) -> Result<bool, DatabaseError> {
157 let ent = ent0.borrow_mut();
158 let draft0 = T::EdgeProvider::draft(ent);
159 let expected_last_updated = ent.last_updated();
160
161 mutator(ent);
162 ent.mark_updated().map_err(|e| DatabaseError::Other {
163 source: Box::new(e),
164 })?;
165
166 let draft1 = T::EdgeProvider::draft(ent);
167
168 if draft0 == draft1 {
170 return self.update(
171 ent.id(),
172 dyn_clone::clone_box(ent),
173 Some(expected_last_updated),
174 );
175 }
176
177 let edge0 = draft0.check(self).map_err(|e| DatabaseError::Other {
178 source: Box::new(e),
179 })?;
180 let edge1 = draft1.check(self).map_err(|e| DatabaseError::Other {
181 source: Box::new(e),
182 })?;
183
184 let updated = self.update(
185 ent.id(),
186 dyn_clone::clone_box(ent),
187 Some(expected_last_updated),
188 )?;
189
190 if updated {
191 for edge in edge0 {
193 self.0
194 .execute(
195 "DELETE FROM edges WHERE source = ?1 AND type = ?2 AND dest = ?3",
196 params![edge.source as i64, edge.sort_key, edge.dest as i64],
197 )
198 .map_err(|e| DatabaseError::Other {
199 source: Box::new(e),
200 })?;
201 }
202
203 for edge in edge1 {
205 self.create_edge(edge)?;
206 }
207 }
208
209 Ok(updated)
210 }
211
212 fn create<E: Ent + EntWithEdges>(&self, mut ent: E) -> Result<Id, DatabaseError> {
213 let id = self.insert(&ent)?;
214 ent.set_id(id);
215 ent.setup_edges(self).map_err(|e| DatabaseError::Other {
216 source: Box::new(e),
217 })?;
218 Ok(id)
219 }
220
221 fn commit(self) -> Result<(), DatabaseError> {
222 self.0.commit().map_err(|e| DatabaseError::Other {
223 source: Box::new(e),
224 })
225 }
226}
227
228impl<'conn> QueryEdge for Txn<'conn> {
229 fn find_edges(&self, source: Id, query: EdgeQuery) -> Result<Vec<Edge>, DatabaseError> {
230 let name_filter = if query.edge_names.is_empty() {
232 String::new()
233 } else {
234 let placeholders = query
235 .edge_names
236 .iter()
237 .map(|_| "?")
238 .collect::<Vec<_>>()
239 .join(", ");
240 format!(" AND type IN ({})", placeholders)
241 };
242
243 let cursor_filter = match (&query.cursor, query.order) {
245 (Some(_), SortOrder::Asc) => " AND (type, dest) > (?, ?)",
246 (Some(_), SortOrder::Desc) => " AND (type, dest) < (?, ?)",
247 (None, _) => "",
248 };
249
250 let order_clause = match query.order {
252 SortOrder::Asc => "ORDER BY type ASC, dest ASC",
253 SortOrder::Desc => "ORDER BY type DESC, dest DESC",
254 };
255
256 let sql = format!(
257 "SELECT source, type, dest FROM edges WHERE source = ?{}{} {} LIMIT 100",
258 name_filter, cursor_filter, order_clause
259 );
260
261 let mut params: Vec<Box<dyn r2d2_sqlite::rusqlite::ToSql>> = Vec::new();
263 params.push(Box::new(source));
264
265 for name in query.edge_names {
266 params.push(Box::new(name.to_vec()));
267 }
268
269 if let Some(cursor) = query.cursor {
270 params.push(Box::new(cursor.sort_key.to_vec()));
271 params.push(Box::new(cursor.destination));
272 }
273
274 let params_refs: Vec<&dyn r2d2_sqlite::rusqlite::ToSql> =
275 params.iter().map(|p| p.as_ref()).collect();
276
277 let mut stmt = self.0.prepare(&sql).map_err(|e| DatabaseError::Other {
278 source: Box::new(e),
279 })?;
280
281 let rows = stmt
282 .query_map(params_refs.as_slice(), |row| {
283 let source: i64 = row.get(0)?;
284 let sort_key: Vec<u8> = match row.get_ref(1)? {
285 r2d2_sqlite::rusqlite::types::ValueRef::Text(s) => s.to_vec(),
286 r2d2_sqlite::rusqlite::types::ValueRef::Blob(b) => b.to_vec(),
287 _ => {
288 return Err(r2d2_sqlite::rusqlite::Error::InvalidColumnType(
289 1,
290 "type".into(),
291 row.get_ref(1)?.data_type(),
292 ))
293 }
294 };
295 let destination: i64 = row.get(2)?;
296 Ok(Edge::new(source as Id, sort_key, destination as Id))
297 })
298 .map_err(|e| DatabaseError::Other {
299 source: Box::new(e),
300 })?;
301
302 rows.collect::<Result<Vec<_>, _>>()
303 .map_err(|e| DatabaseError::Other {
304 source: Box::new(e),
305 })
306 }
307}