cozo_ce/storage/sqlite.rs

1/*
2 * Copyright 2022, The Cozo Project Authors.
3 *
4 * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
5 * If a copy of the MPL was not distributed with this file,
6 * You can obtain one at https://mozilla.org/MPL/2.0/.
7 */
8
9use std::path::{Path, PathBuf};
10use std::sync::{Arc, Mutex};
11
12use ::sqlite::Connection;
13use crossbeam::sync::{ShardedLock, ShardedLockReadGuard, ShardedLockWriteGuard};
14use either::{Either, Left, Right};
15use miette::{bail, miette, IntoDiagnostic, Result};
16use sqlite::{ConnectionThreadSafe, State, Statement};
17
18use crate::data::tuple::{check_key_for_validity, Tuple};
19use crate::data::value::ValidityTs;
20use crate::runtime::relation::{decode_tuple_from_kv, extend_tuple_from_v};
21use crate::storage::{Storage, StoreTx};
22use crate::utils::swap_option_result;
23
/// The Sqlite storage engine.
///
/// Cloning is cheap: the lock and the connection pool are shared through `Arc`, so every clone
/// works with the same database path, lock and pool.
#[derive(Clone)]
pub struct SqliteStorage {
    /// In-process reader/writer lock: write transactions take it exclusively, read transactions
    /// take it shared (see `transact`).
    lock: Arc<ShardedLock<()>>,
    /// Path of the database file; new connections are opened from it when the pool is empty.
    name: PathBuf,
    /// Idle connections handed back by finished transactions and reused by later ones.
    pool: Arc<Mutex<Vec<ConnectionThreadSafe>>>,
}
31
32/// Create a sqlite backed database.
33/// Supports concurrent readers but only a single writer.
34///
35/// You must provide a disk-based path: `:memory:` is not OK.
36/// If you want a pure memory storage, use [`new_cozo_mem`](crate::new_cozo_mem).
37pub fn new_cozo_sqlite(path: impl AsRef<Path>) -> Result<crate::Db<SqliteStorage>> {
38    if path.as_ref().to_str() == Some("") {
39        bail!("empty path for sqlite storage")
40    }
41    let conn = Connection::open_thread_safe(&path).into_diagnostic()?;
42    let query = r#"
43        create table if not exists cozo
44        (
45            k BLOB primary key,
46            v BLOB
47        );
48    "#;
49    let mut statement = conn.prepare(query).unwrap();
50    while statement.next().into_diagnostic()? != State::Done {}
51
52    let ret = crate::Db::new(SqliteStorage {
53        lock: Default::default(),
54        name: PathBuf::from(path.as_ref()),
55        pool: Default::default(),
56    })?;
57
58    ret.initialize()?;
59    Ok(ret)
60}
61
62impl<'s> Storage<'s> for SqliteStorage {
63    type Tx = SqliteTx<'s>;
64
65    fn transact(&'s self, write: bool) -> Result<Self::Tx> {
66        let conn = {
67            match self.pool.lock().unwrap().pop() {
68                None => Connection::open_thread_safe(&self.name).into_diagnostic()?,
69                Some(conn) => conn,
70            }
71        };
72        let lock = if write {
73            Right(self.lock.write().unwrap())
74        } else {
75            Left(self.lock.read().unwrap())
76        };
77        if write {
78            let mut stmt = conn.prepare("begin;").into_diagnostic()?;
79            while stmt.next().into_diagnostic()? != State::Done {}
80        }
81        Ok(SqliteTx {
82            lock,
83            storage: self,
84            conn: Some(conn),
85            stmts: [
86                Mutex::new(None),
87                Mutex::new(None),
88                Mutex::new(None),
89                Mutex::new(None),
90            ],
91            committed: false,
92        })
93    }
94
95    fn batch_put<'a>(
96        &'a self,
97        data: Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>,
98    ) -> Result<()> {
99        let mut tx = self.transact(true)?;
100        for result in data {
101            let (key, val) = result?;
102            tx.put(&key, &val)?;
103        }
104        tx.commit()?;
105        Ok(())
106    }
107
108    fn range_compact(&'_ self, _lower: &[u8], _upper: &[u8]) -> Result<()> {
109        let mut pool = self.pool.lock().unwrap();
110        while pool.pop().is_some() {}
111        Ok(())
112    }
113
114    fn storage_kind(&self) -> &'static str {
115        "sqlite"
116    }
117}
118
/// A transaction on [`SqliteStorage`], holding one connection taken from the pool for its lifetime.
///
/// Write transactions hold the storage lock exclusively and run inside an explicit SQLite
/// transaction; read transactions hold the lock shared.
pub struct SqliteTx<'a> {
    /// Shared guard for read transactions (`Left`), exclusive guard for write transactions (`Right`).
    lock: Either<ShardedLockReadGuard<'a, ()>, ShardedLockWriteGuard<'a, ()>>,
    /// Owning storage; the connection is pushed back into its pool when the transaction is dropped.
    storage: &'a SqliteStorage,
    /// Connection in use; `Some` from creation until it is taken back on drop.
    conn: Option<ConnectionThreadSafe>,
    /// Lazily prepared statements for the first `N_CACHED_QUERIES` entries of `QUERIES`,
    /// each behind its own mutex so `&self` methods can use them.
    stmts: [Mutex<Option<Statement<'a>>>; N_CACHED_QUERIES],
    /// Set once `commit` has run `commit;`; an uncommitted write transaction is rolled back on drop.
    committed: bool,
}

// SAFETY (review note): `&self` methods such as `par_put` are used through a shared reference.
// Cached statements are only touched through their `Mutex`, and each range-scan statement is
// owned by the iterator that prepared it. This presumably relies on `ConnectionThreadSafe` being
// safe to use from several threads at once (SQLite serialized mode) — TODO confirm against the
// `sqlite` crate.
unsafe impl Sync for SqliteTx<'_> {}
128
/// Number of SQL statements in `QUERIES`.
const N_QUERIES: usize = 7;
/// Number of leading `QUERIES` entries that a transaction caches as prepared statements;
/// the remaining (range) queries are prepared afresh on every call.
const N_CACHED_QUERIES: usize = 4;
/// SQL text of every statement used by `SqliteTx`, indexed by the `*_QUERY` constants below.
const QUERIES: [&str; N_QUERIES] = [
    "select v from cozo where k = ?;",
    "insert into cozo(k, v) values (?, ?) on conflict(k) do update set v=excluded.v;",
    "delete from cozo where k = ?;",
    "select 1 from cozo where k = ?;",
    "select k, v from cozo where k >= ? and k < ? order by k;",
    "select k, v from cozo where k >= ? and k < ? order by k limit 1;",
    "select count(*) from cozo where k >= ? and k < ?;",
];

const GET_QUERY: usize = 0;
const PUT_QUERY: usize = 1;
const DEL_QUERY: usize = 2;
const EXISTS_QUERY: usize = 3;
const RANGE_QUERY: usize = 4;
const SKIP_RANGE_QUERY: usize = 5;
const COUNT_RANGE_QUERY: usize = 6;

// Cached statements are looked up as `stmts[idx]`, so every cached query index must fit in the
// cache. Check this at compile time instead of hitting an index panic at run time.
const _: () = assert!(
    GET_QUERY < N_CACHED_QUERIES
        && PUT_QUERY < N_CACHED_QUERIES
        && DEL_QUERY < N_CACHED_QUERIES
        && EXISTS_QUERY < N_CACHED_QUERIES
        && N_CACHED_QUERIES <= N_QUERIES
);
148
149impl Drop for SqliteTx<'_> {
150    fn drop(&mut self) {
151        if let Right(ShardedLockWriteGuard { .. }) = self.lock {
152            if !self.committed {
153                let query = r#"rollback;"#;
154                let _ = self.conn.as_ref().unwrap().execute(query);
155            }
156        }
157        let mut pool = self.storage.pool.lock().unwrap();
158        let conn = self.conn.take().unwrap();
159        pool.push(conn)
160    }
161}
162
163impl<'s> SqliteTx<'s> {
164    fn ensure_stmt(&self, idx: usize) {
165        let mut stmt = self.stmts[idx].lock().unwrap();
166        if stmt.is_none() {
167            let query = QUERIES[idx];
168            let prepared = self.conn.as_ref().unwrap().prepare(query).unwrap();
169
170            // Casting away the lifetime!
171            // This is OK because we are abiding by the contract of the underlying C pointer,
172            // as required by Sqlite's implementation
173            let prepared = unsafe { std::mem::transmute(prepared) };
174
175            *stmt = Some(prepared)
176        }
177    }
178}
179
180impl<'s> StoreTx<'s> for SqliteTx<'s> {
181    fn get(&self, key: &[u8], _for_update: bool) -> Result<Option<Vec<u8>>> {
182        self.ensure_stmt(GET_QUERY);
183        let mut statement = self.stmts[GET_QUERY].lock().unwrap();
184        let statement = statement.as_mut().unwrap();
185        statement.reset().unwrap();
186
187        statement.bind((1, key)).unwrap();
188        Ok(match statement.next().into_diagnostic()? {
189            State::Row => {
190                let res = statement.read::<Vec<u8>, _>(0).into_diagnostic()?;
191                Some(res)
192            }
193            State::Done => None,
194        })
195    }
196
197    fn put(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
198        self.par_put(key, val)
199    }
200
201    fn supports_par_put(&self) -> bool {
202        true
203    }
204
205    fn par_put(&self, key: &[u8], val: &[u8]) -> Result<()> {
206        self.ensure_stmt(PUT_QUERY);
207        let mut statement = self.stmts[PUT_QUERY].lock().unwrap();
208        let statement = statement.as_mut().unwrap();
209        statement.reset().unwrap();
210
211        statement.bind((1, key)).unwrap();
212        statement.bind((2, val)).unwrap();
213        while statement.next().into_diagnostic()? != State::Done {}
214        Ok(())
215    }
216
217    fn del(&mut self, key: &[u8]) -> Result<()> {
218        self.par_del(key)
219    }
220
221    fn par_del(&self, key: &[u8]) -> Result<()> {
222        self.ensure_stmt(DEL_QUERY);
223        let mut statement = self.stmts[DEL_QUERY].lock().unwrap();
224        let statement = statement.as_mut().unwrap();
225        statement.reset().unwrap();
226
227        statement.bind((1, key)).unwrap();
228        while statement.next().into_diagnostic()? != State::Done {}
229
230        Ok(())
231    }
232
233    fn del_range_from_persisted(&mut self, lower: &[u8], upper: &[u8]) -> Result<()> {
234        let query = r#"
235                delete from cozo where k >= ? and k < ?;
236            "#;
237        let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
238
239        statement.bind((1, lower)).unwrap();
240        statement.bind((2, upper)).unwrap();
241        while statement.next().unwrap() != State::Done {}
242        Ok(())
243    }
244
245    fn exists(&self, key: &[u8], _for_update: bool) -> Result<bool> {
246        self.ensure_stmt(EXISTS_QUERY);
247        let mut statement = self.stmts[EXISTS_QUERY].lock().unwrap();
248        let statement = statement.as_mut().unwrap();
249        statement.reset().unwrap();
250
251        statement.bind((1, key)).unwrap();
252        Ok(match statement.next().into_diagnostic()? {
253            State::Row => true,
254            State::Done => false,
255        })
256    }
257
258    fn commit(&mut self) -> Result<()> {
259        if let Right(ShardedLockWriteGuard { .. }) = self.lock {
260            if !self.committed {
261                let query = r#"commit;"#;
262                let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
263                while statement.next().into_diagnostic()? != State::Done {}
264                self.committed = true;
265            } else {
266                bail!("multiple commits")
267            }
268        }
269        Ok(())
270    }
271
272    fn range_scan_tuple<'a>(
273        &'a self,
274        lower: &[u8],
275        upper: &[u8],
276    ) -> Box<dyn Iterator<Item = Result<Tuple>> + 'a>
277    where
278        's: 'a,
279    {
280        // Range scans cannot use cached prepared statements, as several of them
281        // can be used at the same time.
282        let query = QUERIES[RANGE_QUERY];
283        let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
284        statement.bind((1, lower)).unwrap();
285        statement.bind((2, upper)).unwrap();
286        Box::new(TupleIter(statement))
287    }
288
289    fn range_skip_scan_tuple<'a>(
290        &'a self,
291        lower: &[u8],
292        upper: &[u8],
293        valid_at: ValidityTs,
294    ) -> Box<dyn Iterator<Item = Result<Tuple>> + 'a> {
295        let query = QUERIES[SKIP_RANGE_QUERY];
296        let statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
297        Box::new(SkipIter {
298            stmt: statement,
299            valid_at,
300            next_bound: lower.to_vec(),
301            upper_bound: upper.to_vec(),
302        })
303    }
304
305    fn range_scan<'a>(
306        &'a self,
307        lower: &[u8],
308        upper: &[u8],
309    ) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
310    where
311        's: 'a,
312    {
313        let query = QUERIES[RANGE_QUERY];
314        let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
315        statement.bind((1, lower)).unwrap();
316        statement.bind((2, upper)).unwrap();
317        Box::new(RawIter(statement))
318    }
319
320    fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
321    where
322        's: 'a,
323    {
324        let query = QUERIES[COUNT_RANGE_QUERY];
325        let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
326        statement.bind((1, lower)).unwrap();
327        statement.bind((2, upper)).unwrap();
328        match statement.next() {
329            Ok(State::Done) => bail!("range count query returned no rows"),
330            Ok(State::Row) => {
331                let k = statement.read::<i64, _>(0).unwrap();
332                Ok(k as usize)
333            }
334            Err(err) => bail!(err),
335        }
336    }
337
338    fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
339    where
340        's: 'a,
341    {
342        let statement = self
343            .conn
344            .as_ref()
345            .unwrap()
346            .prepare("select k, v from cozo order by k;")
347            .unwrap();
348        Box::new(RawIter(statement))
349    }
350}
351
352struct TupleIter<'l>(Statement<'l>);
353
354impl<'l> Iterator for TupleIter<'l> {
355    type Item = Result<Tuple>;
356
357    fn next(&mut self) -> Option<Self::Item> {
358        match self.0.next() {
359            Ok(State::Done) => None,
360            Ok(State::Row) => {
361                let k = self.0.read::<Vec<u8>, _>(0).unwrap();
362                let v = self.0.read::<Vec<u8>, _>(1).unwrap();
363                let tuple = decode_tuple_from_kv(&k, &v, None);
364                Some(Ok(tuple))
365            }
366            Err(err) => Some(Err(miette!(err))),
367        }
368    }
369}
370
371struct RawIter<'l>(Statement<'l>);
372
373impl<'l> Iterator for RawIter<'l> {
374    type Item = Result<(Vec<u8>, Vec<u8>)>;
375
376    fn next(&mut self) -> Option<Self::Item> {
377        match self.0.next() {
378            Ok(State::Done) => None,
379            Ok(State::Row) => {
380                let k = self.0.read::<Vec<u8>, _>(0).unwrap();
381                let v = self.0.read::<Vec<u8>, _>(1).unwrap();
382                Some(Ok((k, v)))
383            }
384            Err(err) => Some(Err(miette!(err))),
385        }
386    }
387}
388
/// Range iterator that repeatedly re-runs a `limit 1` range query, letting
/// `check_key_for_validity` decide which rows are yielded and where the next query starts.
struct SkipIter<'l> {
    /// Prepared `limit 1` range query; its bounds are re-bound before every step.
    stmt: Statement<'l>,
    /// Timestamp passed to `check_key_for_validity` for every candidate key.
    valid_at: ValidityTs,
    /// Inclusive lower bound of the next query; starts at the scan's lower bound and is replaced
    /// by the bound returned from `check_key_for_validity`.
    next_bound: Vec<u8>,
    /// Exclusive upper bound of the scan (the query uses `k < ?`).
    upper_bound: Vec<u8>,
}
395
396impl<'l> SkipIter<'l> {
397    fn next_inner(&mut self) -> Result<Option<Tuple>> {
398        loop {
399            self.stmt.reset().into_diagnostic()?;
400            self.stmt.bind((1, &self.next_bound as &[u8])).unwrap();
401            self.stmt.bind((2, &self.upper_bound as &[u8])).unwrap();
402
403            match self.stmt.next().into_diagnostic()? {
404                State::Done => return Ok(None),
405                State::Row => {
406                    let k = self.stmt.read::<Vec<u8>, _>(0).unwrap();
407                    let (ret, nxt_bound) = check_key_for_validity(&k, self.valid_at, None);
408                    self.next_bound = nxt_bound;
409                    if let Some(mut tup) = ret {
410                        let v = self.stmt.read::<Vec<u8>, _>(1).unwrap();
411                        extend_tuple_from_v(&mut tup, &v);
412                        return Ok(Some(tup));
413                    }
414                }
415            }
416        }
417    }
418}
419
420impl<'l> Iterator for SkipIter<'l> {
421    type Item = Result<Tuple>;
422
423    fn next(&mut self) -> Option<Self::Item> {
424        swap_option_result(self.next_inner())
425    }
426}