1use std::path::{Path, PathBuf};
10use std::sync::{Arc, Mutex};
11
12use ::sqlite::Connection;
13use crossbeam::sync::{ShardedLock, ShardedLockReadGuard, ShardedLockWriteGuard};
14use either::{Either, Left, Right};
15use miette::{bail, miette, IntoDiagnostic, Result};
16use sqlite::{ConnectionThreadSafe, State, Statement};
17
18use crate::data::tuple::{check_key_for_validity, Tuple};
19use crate::data::value::ValidityTs;
20use crate::runtime::relation::{decode_tuple_from_kv, extend_tuple_from_v};
21use crate::storage::{Storage, StoreTx};
22use crate::utils::swap_option_result;
23
24#[derive(Clone)]
26pub struct SqliteStorage {
27 lock: Arc<ShardedLock<()>>,
28 name: PathBuf,
29 pool: Arc<Mutex<Vec<ConnectionThreadSafe>>>,
30}
31
32pub fn new_cozo_sqlite(path: impl AsRef<Path>) -> Result<crate::Db<SqliteStorage>> {
38 if path.as_ref().to_str() == Some("") {
39 bail!("empty path for sqlite storage")
40 }
41 let conn = Connection::open_thread_safe(&path).into_diagnostic()?;
42 let query = r#"
43 create table if not exists cozo
44 (
45 k BLOB primary key,
46 v BLOB
47 );
48 "#;
49 let mut statement = conn.prepare(query).unwrap();
50 while statement.next().into_diagnostic()? != State::Done {}
51
52 let ret = crate::Db::new(SqliteStorage {
53 lock: Default::default(),
54 name: PathBuf::from(path.as_ref()),
55 pool: Default::default(),
56 })?;
57
58 ret.initialize()?;
59 Ok(ret)
60}
61
62impl<'s> Storage<'s> for SqliteStorage {
63 type Tx = SqliteTx<'s>;
64
65 fn transact(&'s self, write: bool) -> Result<Self::Tx> {
66 let conn = {
67 match self.pool.lock().unwrap().pop() {
68 None => Connection::open_thread_safe(&self.name).into_diagnostic()?,
69 Some(conn) => conn,
70 }
71 };
72 let lock = if write {
73 Right(self.lock.write().unwrap())
74 } else {
75 Left(self.lock.read().unwrap())
76 };
77 if write {
78 let mut stmt = conn.prepare("begin;").into_diagnostic()?;
79 while stmt.next().into_diagnostic()? != State::Done {}
80 }
81 Ok(SqliteTx {
82 lock,
83 storage: self,
84 conn: Some(conn),
85 stmts: [
86 Mutex::new(None),
87 Mutex::new(None),
88 Mutex::new(None),
89 Mutex::new(None),
90 ],
91 committed: false,
92 })
93 }
94
95 fn batch_put<'a>(
96 &'a self,
97 data: Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>,
98 ) -> Result<()> {
99 let mut tx = self.transact(true)?;
100 for result in data {
101 let (key, val) = result?;
102 tx.put(&key, &val)?;
103 }
104 tx.commit()?;
105 Ok(())
106 }
107
108 fn range_compact(&'_ self, _lower: &[u8], _upper: &[u8]) -> Result<()> {
109 let mut pool = self.pool.lock().unwrap();
110 while pool.pop().is_some() {}
111 Ok(())
112 }
113
114 fn storage_kind(&self) -> &'static str {
115 "sqlite"
116 }
117}
118
/// A transaction running on one connection checked out from a [`SqliteStorage`] pool.
///
/// NOTE: after `Drop::drop` runs, fields are dropped in declaration order, so `lock` is
/// released before whatever is still held in `stmts` is dropped. Keep that in mind before
/// reordering the fields.
pub struct SqliteTx<'a> {
    /// Shared guard for read transactions, exclusive guard for write transactions.
    lock: Either<ShardedLockReadGuard<'a, ()>, ShardedLockWriteGuard<'a, ()>>,
    /// Owning storage; the connection is pushed back into its pool on drop.
    storage: &'a SqliteStorage,
    /// The connection: set to `Some` by `transact` and taken in `Drop`.
    conn: Option<ConnectionThreadSafe>,
    /// Lazily prepared statements; slot `i` caches `QUERIES[i]` (see `ensure_stmt`).
    stmts: [Mutex<Option<Statement<'a>>>; N_CACHED_QUERIES],
    /// Set by a successful `commit`; an uncommitted write transaction is rolled back on drop.
    committed: bool,
}
126
// SAFETY — NOTE(review): this justification is reconstructed from the visible code; confirm.
// `SqliteTx` is used from several threads through `&self` (`par_put`/`par_del`, with
// `supports_par_put` returning true). The cached statements are only reached after locking
// their `Mutex`. The connection comes from `Connection::open_thread_safe`. The range scans,
// however, prepare uncached statements on that same connection through `&self`. This impl
// asserts that all such concurrent use is sound. That depends on the threading mode SQLite
// was opened with, which is not visible here.
unsafe impl Sync for SqliteTx<'_> {}
128
// Positions of each statement inside `QUERIES`. The first `N_CACHED_QUERIES` of them
// (get / put / del / exists) are prepared once per transaction and cached in `stmts`;
// the range queries are prepared on every call.
const GET_QUERY: usize = 0;
const PUT_QUERY: usize = 1;
const DEL_QUERY: usize = 2;
const EXISTS_QUERY: usize = 3;
const RANGE_QUERY: usize = 4;
const SKIP_RANGE_QUERY: usize = 5;
const COUNT_RANGE_QUERY: usize = 6;

/// Number of entries in `QUERIES` (one past the last index above).
const N_QUERIES: usize = COUNT_RANGE_QUERY + 1;
/// Number of leading `QUERIES` whose prepared statements are cached per transaction.
const N_CACHED_QUERIES: usize = EXISTS_QUERY + 1;

/// SQL text of every statement, indexed by the `*_QUERY` constants.
const QUERIES: [&str; N_QUERIES] = [
    /* GET_QUERY */ "select v from cozo where k = ?;",
    /* PUT_QUERY */
    "insert into cozo(k, v) values (?, ?) on conflict(k) do update set v=excluded.v;",
    /* DEL_QUERY */ "delete from cozo where k = ?;",
    /* EXISTS_QUERY */ "select 1 from cozo where k = ?;",
    /* RANGE_QUERY */ "select k, v from cozo where k >= ? and k < ? order by k;",
    /* SKIP_RANGE_QUERY */ "select k, v from cozo where k >= ? and k < ? order by k limit 1;",
    /* COUNT_RANGE_QUERY */ "select count(*) from cozo where k >= ? and k < ?;",
];
148
149impl Drop for SqliteTx<'_> {
150 fn drop(&mut self) {
151 if let Right(ShardedLockWriteGuard { .. }) = self.lock {
152 if !self.committed {
153 let query = r#"rollback;"#;
154 let _ = self.conn.as_ref().unwrap().execute(query);
155 }
156 }
157 let mut pool = self.storage.pool.lock().unwrap();
158 let conn = self.conn.take().unwrap();
159 pool.push(conn)
160 }
161}
162
163impl<'s> SqliteTx<'s> {
164 fn ensure_stmt(&self, idx: usize) {
165 let mut stmt = self.stmts[idx].lock().unwrap();
166 if stmt.is_none() {
167 let query = QUERIES[idx];
168 let prepared = self.conn.as_ref().unwrap().prepare(query).unwrap();
169
170 let prepared = unsafe { std::mem::transmute(prepared) };
174
175 *stmt = Some(prepared)
176 }
177 }
178}
179
180impl<'s> StoreTx<'s> for SqliteTx<'s> {
181 fn get(&self, key: &[u8], _for_update: bool) -> Result<Option<Vec<u8>>> {
182 self.ensure_stmt(GET_QUERY);
183 let mut statement = self.stmts[GET_QUERY].lock().unwrap();
184 let statement = statement.as_mut().unwrap();
185 statement.reset().unwrap();
186
187 statement.bind((1, key)).unwrap();
188 Ok(match statement.next().into_diagnostic()? {
189 State::Row => {
190 let res = statement.read::<Vec<u8>, _>(0).into_diagnostic()?;
191 Some(res)
192 }
193 State::Done => None,
194 })
195 }
196
197 fn put(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
198 self.par_put(key, val)
199 }
200
201 fn supports_par_put(&self) -> bool {
202 true
203 }
204
205 fn par_put(&self, key: &[u8], val: &[u8]) -> Result<()> {
206 self.ensure_stmt(PUT_QUERY);
207 let mut statement = self.stmts[PUT_QUERY].lock().unwrap();
208 let statement = statement.as_mut().unwrap();
209 statement.reset().unwrap();
210
211 statement.bind((1, key)).unwrap();
212 statement.bind((2, val)).unwrap();
213 while statement.next().into_diagnostic()? != State::Done {}
214 Ok(())
215 }
216
217 fn del(&mut self, key: &[u8]) -> Result<()> {
218 self.par_del(key)
219 }
220
221 fn par_del(&self, key: &[u8]) -> Result<()> {
222 self.ensure_stmt(DEL_QUERY);
223 let mut statement = self.stmts[DEL_QUERY].lock().unwrap();
224 let statement = statement.as_mut().unwrap();
225 statement.reset().unwrap();
226
227 statement.bind((1, key)).unwrap();
228 while statement.next().into_diagnostic()? != State::Done {}
229
230 Ok(())
231 }
232
233 fn del_range_from_persisted(&mut self, lower: &[u8], upper: &[u8]) -> Result<()> {
234 let query = r#"
235 delete from cozo where k >= ? and k < ?;
236 "#;
237 let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
238
239 statement.bind((1, lower)).unwrap();
240 statement.bind((2, upper)).unwrap();
241 while statement.next().unwrap() != State::Done {}
242 Ok(())
243 }
244
245 fn exists(&self, key: &[u8], _for_update: bool) -> Result<bool> {
246 self.ensure_stmt(EXISTS_QUERY);
247 let mut statement = self.stmts[EXISTS_QUERY].lock().unwrap();
248 let statement = statement.as_mut().unwrap();
249 statement.reset().unwrap();
250
251 statement.bind((1, key)).unwrap();
252 Ok(match statement.next().into_diagnostic()? {
253 State::Row => true,
254 State::Done => false,
255 })
256 }
257
258 fn commit(&mut self) -> Result<()> {
259 if let Right(ShardedLockWriteGuard { .. }) = self.lock {
260 if !self.committed {
261 let query = r#"commit;"#;
262 let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
263 while statement.next().into_diagnostic()? != State::Done {}
264 self.committed = true;
265 } else {
266 bail!("multiple commits")
267 }
268 }
269 Ok(())
270 }
271
272 fn range_scan_tuple<'a>(
273 &'a self,
274 lower: &[u8],
275 upper: &[u8],
276 ) -> Box<dyn Iterator<Item = Result<Tuple>> + 'a>
277 where
278 's: 'a,
279 {
280 let query = QUERIES[RANGE_QUERY];
283 let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
284 statement.bind((1, lower)).unwrap();
285 statement.bind((2, upper)).unwrap();
286 Box::new(TupleIter(statement))
287 }
288
289 fn range_skip_scan_tuple<'a>(
290 &'a self,
291 lower: &[u8],
292 upper: &[u8],
293 valid_at: ValidityTs,
294 ) -> Box<dyn Iterator<Item = Result<Tuple>> + 'a> {
295 let query = QUERIES[SKIP_RANGE_QUERY];
296 let statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
297 Box::new(SkipIter {
298 stmt: statement,
299 valid_at,
300 next_bound: lower.to_vec(),
301 upper_bound: upper.to_vec(),
302 })
303 }
304
305 fn range_scan<'a>(
306 &'a self,
307 lower: &[u8],
308 upper: &[u8],
309 ) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
310 where
311 's: 'a,
312 {
313 let query = QUERIES[RANGE_QUERY];
314 let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
315 statement.bind((1, lower)).unwrap();
316 statement.bind((2, upper)).unwrap();
317 Box::new(RawIter(statement))
318 }
319
320 fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
321 where
322 's: 'a,
323 {
324 let query = QUERIES[COUNT_RANGE_QUERY];
325 let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
326 statement.bind((1, lower)).unwrap();
327 statement.bind((2, upper)).unwrap();
328 match statement.next() {
329 Ok(State::Done) => bail!("range count query returned no rows"),
330 Ok(State::Row) => {
331 let k = statement.read::<i64, _>(0).unwrap();
332 Ok(k as usize)
333 }
334 Err(err) => bail!(err),
335 }
336 }
337
338 fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
339 where
340 's: 'a,
341 {
342 let statement = self
343 .conn
344 .as_ref()
345 .unwrap()
346 .prepare("select k, v from cozo order by k;")
347 .unwrap();
348 Box::new(RawIter(statement))
349 }
350}
351
352struct TupleIter<'l>(Statement<'l>);
353
354impl<'l> Iterator for TupleIter<'l> {
355 type Item = Result<Tuple>;
356
357 fn next(&mut self) -> Option<Self::Item> {
358 match self.0.next() {
359 Ok(State::Done) => None,
360 Ok(State::Row) => {
361 let k = self.0.read::<Vec<u8>, _>(0).unwrap();
362 let v = self.0.read::<Vec<u8>, _>(1).unwrap();
363 let tuple = decode_tuple_from_kv(&k, &v, None);
364 Some(Ok(tuple))
365 }
366 Err(err) => Some(Err(miette!(err))),
367 }
368 }
369}
370
371struct RawIter<'l>(Statement<'l>);
372
373impl<'l> Iterator for RawIter<'l> {
374 type Item = Result<(Vec<u8>, Vec<u8>)>;
375
376 fn next(&mut self) -> Option<Self::Item> {
377 match self.0.next() {
378 Ok(State::Done) => None,
379 Ok(State::Row) => {
380 let k = self.0.read::<Vec<u8>, _>(0).unwrap();
381 let v = self.0.read::<Vec<u8>, _>(1).unwrap();
382 Some(Ok((k, v)))
383 }
384 Err(err) => Some(Err(miette!(err))),
385 }
386 }
387}
388
389struct SkipIter<'l> {
390 stmt: Statement<'l>,
391 valid_at: ValidityTs,
392 next_bound: Vec<u8>,
393 upper_bound: Vec<u8>,
394}
395
396impl<'l> SkipIter<'l> {
397 fn next_inner(&mut self) -> Result<Option<Tuple>> {
398 loop {
399 self.stmt.reset().into_diagnostic()?;
400 self.stmt.bind((1, &self.next_bound as &[u8])).unwrap();
401 self.stmt.bind((2, &self.upper_bound as &[u8])).unwrap();
402
403 match self.stmt.next().into_diagnostic()? {
404 State::Done => return Ok(None),
405 State::Row => {
406 let k = self.stmt.read::<Vec<u8>, _>(0).unwrap();
407 let (ret, nxt_bound) = check_key_for_validity(&k, self.valid_at, None);
408 self.next_bound = nxt_bound;
409 if let Some(mut tup) = ret {
410 let v = self.stmt.read::<Vec<u8>, _>(1).unwrap();
411 extend_tuple_from_v(&mut tup, &v);
412 return Ok(Some(tup));
413 }
414 }
415 }
416 }
417 }
418}
419
420impl<'l> Iterator for SkipIter<'l> {
421 type Item = Result<Tuple>;
422
423 fn next(&mut self) -> Option<Self::Item> {
424 swap_option_result(self.next_inner())
425 }
426}