keyvaluedb_sqlite/
lib.rs

1#![deny(clippy::all)]
2
3mod tools;
4
5pub use async_sqlite::rusqlite::OpenFlags;
6use async_sqlite::rusqlite::{params, OptionalExtension as _};
7use async_sqlite::*;
8use keyvaluedb::{
9    DBKeyRef, DBKeyValue, DBKeyValueRef, DBOp, DBTransaction, DBTransactionError, DBValue, IoStats,
10    IoStatsKind, KeyValueDB,
11};
12use parking_lot::Mutex;
13use std::sync::Arc;
14use std::{
15    future::Future,
16    io,
17    path::{Path, PathBuf},
18    pin::Pin,
19    str::FromStr,
20};
21use tools::*;
22
23///////////////////////////////////////////////////////////////////////////////
24
/// How free pages are reclaimed by `Database::vacuum`.
///
/// The mode also selects the SQLite `auto_vacuum` setting applied when the
/// database is opened.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum VacuumMode {
    /// Do not reclaim pages; vacuuming only checkpoints (truncates) the WAL.
    None,
    /// Run `PRAGMA incremental_vacuum` (requires `auto_vacuum = INCREMENTAL`).
    Incremental,
    /// Run a full `VACUUM` of the database file.
    Full,
}
31
32/// Database configuration
33#[derive(Clone)]
34pub struct DatabaseConfig {
35    /// Set number of columns.
36    /// The number of columns must not be zero.
37    pub columns: u32,
38    /// Set flags used to open the database
39    pub flags: OpenFlags,
40    /// Number of connections to open
41    pub num_conns: usize,
42    /// Vacuum mode
43    pub vacuum_mode: VacuumMode,
44}
45
46impl DatabaseConfig {
47    /// Create new `DatabaseConfig` with default parameters
48    pub fn new() -> Self {
49        Default::default()
50    }
51
52    /// Set the number of columns. `columns` must not be zero.
53    pub fn with_columns(self, columns: u32) -> Self {
54        assert!(columns > 0, "the number of columns must not be zero");
55        Self { columns, ..self }
56    }
57
58    /// Sets the flags to 'in-memory database'
59    pub fn with_in_memory(self) -> Self {
60        Self {
61            flags: OpenFlags::SQLITE_OPEN_READ_WRITE
62                | OpenFlags::SQLITE_OPEN_CREATE
63                | OpenFlags::SQLITE_OPEN_NO_MUTEX
64                | OpenFlags::SQLITE_OPEN_MEMORY,
65            ..self
66        }
67    }
68
69    /// Replaces all the flags
70    pub fn with_flags(self, flags: OpenFlags) -> Self {
71        Self { flags, ..self }
72    }
73
74    /// Sets the number of connections for this database
75    pub fn with_num_conns(self, num_conns: usize) -> Self {
76        Self { num_conns, ..self }
77    }
78
79    /// Set the vacuum mode used by 'cleanup'
80    pub fn with_vacuum_mode(self, vacuum_mode: VacuumMode) -> Self {
81        Self {
82            vacuum_mode,
83            ..self
84        }
85    }
86}
87
88impl Default for DatabaseConfig {
89    fn default() -> DatabaseConfig {
90        DatabaseConfig {
91            columns: 1,
92            flags: OpenFlags::SQLITE_OPEN_READ_WRITE
93                | OpenFlags::SQLITE_OPEN_CREATE
94                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
95            num_conns: 1,
96            vacuum_mode: VacuumMode::None,
97        }
98    }
99}
100
101///////////////////////////////////////////////////////////////////////////////
102
/// An SQLite table together with the SQL text of every statement run against it.
///
/// The strings are built once, so callers can hand them to `prepare_cached`
/// without re-formatting on every call.
pub struct DatabaseTable {
    _table: String,
    str_has_value: String,
    str_has_value_like: String,
    str_get_unique_value: String,
    str_get_first_value_like: String,
    str_set_unique_value: String,
    str_remove_unique_value: String,
    str_remove_and_return_unique_value: String,
    str_remove_unique_value_like: String,
    str_iter_with_prefix: String,
    str_iter_no_prefix: String,
    str_iter_keys_with_prefix: String,
    str_iter_keys_no_prefix: String,
}

impl DatabaseTable {
    /// Build all statement strings for the table named `table`.
    ///
    /// `table` is interpolated into the SQL verbatim (not quoted or escaped), so it
    /// must be a trusted identifier.
    pub fn new(table: String) -> Self {
        let t = table.as_str();
        // Prefix predicates use LIKE with backslash as the escape character.
        let like = "[key] LIKE ? ESCAPE '\\'";

        Self {
            str_has_value: format!("SELECT 1 FROM {t} WHERE [key] = ? LIMIT 1"),
            str_has_value_like: format!("SELECT 1 FROM {t} WHERE {like} LIMIT 1"),
            str_get_unique_value: format!("SELECT value FROM {t} WHERE [key] = ? LIMIT 1"),
            str_get_first_value_like: format!("SELECT key, value FROM {t} WHERE {like} LIMIT 1"),
            str_set_unique_value: format!(
                "INSERT OR REPLACE INTO {t} ([key], value) VALUES(?, ?)"
            ),
            str_remove_unique_value: format!("DELETE FROM {t} WHERE [key] = ?"),
            str_remove_and_return_unique_value: format!(
                "DELETE FROM {t} WHERE [key] = ? RETURNING value"
            ),
            str_remove_unique_value_like: format!("DELETE FROM {t} WHERE {like}"),
            str_iter_with_prefix: format!("SELECT key, value FROM {t} WHERE {like}"),
            str_iter_no_prefix: format!("SELECT key, value FROM {t}"),
            str_iter_keys_with_prefix: format!("SELECT key FROM {t} WHERE {like}"),
            str_iter_keys_no_prefix: format!("SELECT key FROM {t}"),
            _table: table,
        }
    }
}
167
168///////////////////////////////////////////////////////////////////////////////
169
170/// An sqlite key-value database fulfilling the `KeyValueDB` trait
171pub struct DatabaseUnlockedInner {
172    path: PathBuf,
173    config: DatabaseConfig,
174    pool: Pool,
175    control_table: Arc<DatabaseTable>,
176    column_tables: Vec<Arc<DatabaseTable>>,
177}
178
179impl Drop for DatabaseUnlockedInner {
180    fn drop(&mut self) {
181        let _ = self.pool.close_blocking();
182    }
183}
184
185pub struct DatabaseInner {
186    overall_stats: IoStats,
187    current_stats: IoStats,
188}
189
190#[derive(Clone)]
191pub struct Database {
192    unlocked_inner: Arc<DatabaseUnlockedInner>,
193    inner: Arc<Mutex<DatabaseInner>>,
194}
195
196impl Database {
197    ////////////////////////////////////////////////////////////////
198    // Initialization
199
200    pub fn open<P: AsRef<Path>>(path: P, config: DatabaseConfig) -> io::Result<Self> {
201        assert_ne!(config.columns, 0, "number of columns must be >= 1");
202
203        let path = PathBuf::from(path.as_ref());
204        let flags = config.flags;
205
206        let mut column_tables = vec![];
207        for n in 0..config.columns {
208            column_tables.push(Arc::new(DatabaseTable::new(get_column_table_name(n))))
209        }
210        let control_table = Arc::new(DatabaseTable::new("control".to_string()));
211
212        let pool_builder = PoolBuilder::new()
213            .path(&path)
214            .flags(flags)
215            .num_conns(config.num_conns);
216
217        let pool = pool_builder.open_blocking().map_err(io::Error::other)?;
218
219        let out = Self {
220            unlocked_inner: Arc::new(DatabaseUnlockedInner {
221                path,
222                config,
223                pool,
224                control_table,
225                column_tables,
226            }),
227            inner: Arc::new(Mutex::new(DatabaseInner {
228                overall_stats: IoStats::empty(),
229                current_stats: IoStats::empty(),
230            })),
231        };
232
233        let vacuum_mode = out.config().vacuum_mode;
234
235        out.conn_blocking(move |conn| {
236            // Don't rely on STATEMENT_CACHE_DEFAULT_CAPACITY in rusqlite, set it explicitly
237            conn.set_prepared_statement_cache_capacity(256);
238
239            conn.pragma_update(None, "case_sensitive_like", "ON")?;
240            conn.pragma_update(None, "journal_mode", "WAL")?;
241            conn.pragma_update(None, "synchronous", "normal")?;
242            conn.pragma_update(None, "journal_size_limit", 6144000)?;
243            conn.pragma_update(None, "wal_checkpoint", "TRUNCATE")?;
244
245            match vacuum_mode {
246                VacuumMode::None | VacuumMode::Full => {
247                    let current: u32 =
248                        conn.pragma_query_value(None, "auto_vacuum", |x| x.get(0))?;
249                    if current != 0 {
250                        conn.execute("VACUUM", [])?;
251                        conn.pragma_update(None, "auto_vacuum", 0)?;
252                    }
253                }
254                VacuumMode::Incremental => {
255                    let current: u32 =
256                        conn.pragma_query_value(None, "auto_vacuum", |x| x.get(0))?;
257                    if current != 2 {
258                        conn.execute("VACUUM", [])?;
259                        conn.pragma_update(None, "auto_vacuum", "2")?;
260                    }
261                }
262            }
263
264            Ok(())
265        })
266        .map_err(io::Error::other)?;
267
268        out.open_resize_columns()?;
269
270        Ok(out)
271    }
272
273    pub fn path(&self) -> PathBuf {
274        self.unlocked_inner.path.clone()
275    }
276
277    pub fn config(&self) -> DatabaseConfig {
278        self.unlocked_inner.config.clone()
279    }
280
281    pub fn columns(&self) -> u32 {
282        self.unlocked_inner.config.columns
283    }
284
285    pub fn control_table(&self) -> Arc<DatabaseTable> {
286        self.unlocked_inner.control_table.clone()
287    }
288
289    pub fn column_table(&self, col: u32) -> Arc<DatabaseTable> {
290        self.unlocked_inner.column_tables[col as usize].clone()
291    }
292
293    pub fn conn_blocking<T, F>(&self, func: F) -> Result<T, Error>
294    where
295        F: FnOnce(&rusqlite::Connection) -> Result<T, rusqlite::Error> + Send + 'static,
296        T: Send + 'static,
297    {
298        self.unlocked_inner.pool.conn_blocking(func)
299    }
300
301    pub async fn conn<T, F>(&self, func: F) -> Result<T, Error>
302    where
303        F: FnOnce(&rusqlite::Connection) -> Result<T, rusqlite::Error> + Send + 'static,
304        T: Send + 'static,
305    {
306        self.unlocked_inner.pool.conn(func).await
307    }
308
309    pub async fn conn_mut<T, F>(&self, func: F) -> Result<T, Error>
310    where
311        F: FnOnce(&mut rusqlite::Connection) -> Result<T, rusqlite::Error> + Send + 'static,
312        T: Send + 'static,
313    {
314        self.unlocked_inner.pool.conn_mut(func).await
315    }
316
317    ////////////////////////////////////////////////////////////////
318    // Low level operations
319
320    /// Remove the last column family in the database. The deletion is definitive.
321    pub fn remove_last_column(&self) -> Result<(), Error> {
322        let this = self.clone();
323        self.conn_blocking(move |conn| {
324            let columns = Self::get_unique_value(conn, this.control_table(), "columns", 0u32)?;
325            if columns == 0 {
326                return Err(rusqlite::Error::QueryReturnedNoRows);
327            }
328            Self::set_unique_value(conn, this.control_table(), "columns", columns - 1)?;
329
330            conn.execute(
331                &format!("DROP TABLE {}", get_column_table_name(columns - 1)),
332                [],
333            )?;
334            Ok(())
335        })
336    }
337
338    /// Add a new column family to the DB.
339    pub fn add_column(&self) -> Result<(), Error> {
340        let this = self.clone();
341
342        self.conn_blocking(move |conn| {
343            let columns = Self::get_unique_value(conn, this.control_table(), "columns", 0u32)?;
344            Self::set_unique_value(conn, this.control_table(), "columns", columns + 1)?;
345            Self::create_column_table(conn, columns)
346        })
347    }
348    /// Helper to create new transaction for this database.
349    pub fn transaction(&self) -> DBTransaction {
350        DBTransaction::new()
351    }
352
353    /// Vacuum database
354    pub async fn vacuum(&self) -> Result<(), Error> {
355        match self.config().vacuum_mode {
356            VacuumMode::None => {
357                self.conn(move |conn| {
358                    conn.pragma_update(None, "wal_checkpoint", "TRUNCATE")?;
359                    Ok(())
360                })
361                .await
362            }
363            VacuumMode::Incremental => {
364                self.conn(move |conn| {
365                    conn.execute("PRAGMA incremental_vacuum", [])?;
366                    conn.pragma_update(None, "wal_checkpoint", "TRUNCATE")?;
367                    Ok(())
368                })
369                .await
370            }
371            VacuumMode::Full => {
372                self.conn(move |conn| {
373                    conn.execute("VACUUM", [])?;
374                    conn.pragma_update(None, "wal_checkpoint", "TRUNCATE")?;
375                    Ok(())
376                })
377                .await
378            }
379        }
380    }
381
382    ////////////////////////////////////////////////////////////////
383    // Internal helpers
384
385    fn validate_column(&self, col: u32) -> rusqlite::Result<()> {
386        if col >= self.columns() {
387            return Err(rusqlite::Error::InvalidColumnIndex(col as usize));
388        }
389        Ok(())
390    }
391
392    fn create_column_table(conn: &rusqlite::Connection, column: u32) -> rusqlite::Result<()> {
393        conn.execute(&format!("CREATE TABLE IF NOT EXISTS {} (id INTEGER PRIMARY KEY AUTOINCREMENT, [key] TEXT UNIQUE, value BLOB)", get_column_table_name(column)), []).map(drop)
394    }
395
396    fn get_unique_value<V>(
397        conn: &rusqlite::Connection,
398        table: Arc<DatabaseTable>,
399        key: &str,
400        default: V,
401    ) -> rusqlite::Result<V>
402    where
403        V: FromStr,
404    {
405        let mut stmt = conn.prepare_cached(&table.str_get_unique_value)?;
406
407        if let Ok(found) = stmt.query_row([key], |row| -> rusqlite::Result<String> { row.get(0) }) {
408            if let Ok(v) = V::from_str(&found) {
409                return Ok(v);
410            }
411        }
412        Ok(default)
413    }
414
415    fn set_unique_value<V>(
416        conn: &rusqlite::Connection,
417        table: Arc<DatabaseTable>,
418        key: &str,
419        value: V,
420    ) -> rusqlite::Result<()>
421    where
422        V: ToString,
423    {
424        let mut stmt = conn.prepare_cached(&table.str_set_unique_value)?;
425
426        let changed = stmt.execute([key, value.to_string().as_str()])?;
427
428        assert!(
429            changed <= 1,
430            "multiple changes to unique key should not occur"
431        );
432        if changed == 0 {
433            return Err(rusqlite::Error::QueryReturnedNoRows);
434        }
435
436        Ok(())
437    }
438
439    fn has_value(
440        conn: &rusqlite::Connection,
441        table: Arc<DatabaseTable>,
442        key: &str,
443    ) -> rusqlite::Result<bool> {
444        let mut stmt = conn.prepare_cached(&table.str_has_value)?;
445        stmt.exists([key])
446    }
447
448    fn has_value_like(
449        conn: &rusqlite::Connection,
450        table: Arc<DatabaseTable>,
451        key: &str,
452    ) -> rusqlite::Result<bool> {
453        let mut stmt = conn.prepare_cached(&table.str_has_value_like)?;
454        stmt.exists([key])
455    }
456
457    fn load_unique_value_blob(
458        conn: &rusqlite::Connection,
459        table: Arc<DatabaseTable>,
460        key: &str,
461    ) -> rusqlite::Result<Option<Vec<u8>>> {
462        let mut stmt = conn.prepare_cached(&table.str_get_unique_value)?;
463
464        stmt.query_row([key], |row| -> rusqlite::Result<Vec<u8>> { row.get(0) })
465            .optional()
466    }
467
468    fn load_first_value_blob_like(
469        conn: &rusqlite::Connection,
470        table: Arc<DatabaseTable>,
471        like: &str,
472    ) -> rusqlite::Result<Option<(String, Vec<u8>)>> {
473        let mut stmt = conn.prepare_cached(&table.str_get_first_value_like)?;
474
475        stmt.query_row([like], |row| -> rusqlite::Result<(String, Vec<u8>)> {
476            Ok((row.get(0)?, row.get(1)?))
477        })
478        .optional()
479    }
480
481    fn store_unique_value_blob(
482        conn: &rusqlite::Connection,
483        table: Arc<DatabaseTable>,
484        key: &str,
485        value: &[u8],
486    ) -> rusqlite::Result<()> {
487        let mut stmt = conn.prepare_cached(&table.str_set_unique_value)?;
488
489        let changed = stmt.execute(params![key, value])?;
490        assert!(
491            changed <= 1,
492            "multiple changes to unique key should not occur"
493        );
494        if changed == 0 {
495            return Err(rusqlite::Error::QueryReturnedNoRows);
496        }
497        Ok(())
498    }
499
500    fn remove_unique_value_blob(
501        conn: &rusqlite::Connection,
502        table: Arc<DatabaseTable>,
503        key: &str,
504    ) -> rusqlite::Result<()> {
505        let mut stmt = conn.prepare_cached(&table.str_remove_unique_value)?;
506
507        let changed = stmt.execute([key])?;
508        assert!(
509            changed <= 1,
510            "multiple deletions of unique key should not occur"
511        );
512        if changed == 0 {
513            return Err(rusqlite::Error::QueryReturnedNoRows);
514        }
515        Ok(())
516    }
517
518    fn remove_and_return_unique_value_blob(
519        conn: &rusqlite::Connection,
520        table: Arc<DatabaseTable>,
521        key: &str,
522    ) -> rusqlite::Result<Option<Vec<u8>>> {
523        let mut stmt = conn.prepare_cached(&table.str_remove_and_return_unique_value)?;
524
525        stmt.query_row([key], |row| -> rusqlite::Result<Vec<u8>> { row.get(0) })
526            .optional()
527    }
528
529    fn remove_unique_value_blob_like(
530        conn: &rusqlite::Connection,
531        table: Arc<DatabaseTable>,
532        like: &str,
533    ) -> rusqlite::Result<usize> {
534        let mut stmt = conn.prepare_cached(&table.str_remove_unique_value_like)?;
535
536        let changed = stmt.execute([like])?;
537        Ok(changed)
538    }
539
540    fn open_resize_columns(&self) -> io::Result<()> {
541        let columns = self.columns();
542        let this = self.clone();
543        self.conn_blocking(move |conn| {
544			// First see if we have a control table with the number of columns
545			conn.execute("CREATE TABLE IF NOT EXISTS control (id INTEGER PRIMARY KEY AUTOINCREMENT, [key] TEXT UNIQUE, value TEXT)", [])?;
546
547            // Get column count
548            let on_disk_columns =
549                Self::get_unique_value(conn, this.control_table(), "columns", 0u32)?;
550
551            // If desired column count is less than or equal to current column count, then allow it, but restrict access to columns
552            if columns <= on_disk_columns {
553                return Ok(());
554            }
555
556            // Otherwise resize and add other columns
557            for cn in on_disk_columns..columns {
558                // Create the column table if we don't have it
559                Self::create_column_table(conn, cn)?;
560            }
561            Self::set_unique_value(
562                conn,
563                this.control_table(),
564                "columns",
565                columns,
566            )?;
567            Ok(())
568        }).map_err(io::Error::other)
569    }
570
571    fn stats_read(&self, count: usize, bytes: usize) {
572        let mut inner = self.inner.lock();
573        inner.current_stats.reads += count as u64;
574        inner.overall_stats.reads += count as u64;
575        inner.current_stats.bytes_read += bytes as u64;
576        inner.overall_stats.bytes_read += bytes as u64;
577    }
578    fn stats_write(&self, count: usize, bytes: usize) {
579        let mut inner = self.inner.lock();
580        inner.current_stats.writes += count as u64;
581        inner.overall_stats.writes += count as u64;
582        inner.current_stats.bytes_written += bytes as u64;
583        inner.overall_stats.bytes_written += bytes as u64;
584    }
585    fn stats_transaction(&self, count: usize) {
586        let mut inner = self.inner.lock();
587        inner.current_stats.transactions += count as u64;
588        inner.overall_stats.transactions += count as u64;
589    }
590}
591
592impl KeyValueDB for Database {
593    fn get(
594        &self,
595        col: u32,
596        key: &[u8],
597    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBValue>>> + Send + '_>> {
598        let key_text = key_to_text(key);
599        let key_len = key.len();
600
601        Box::pin(async move {
602            let that = self.clone();
603            that.validate_column(col).map_err(io::Error::other)?;
604            let someval = self
605                .conn_blocking(move |conn| {
606                    Self::load_unique_value_blob(conn, that.column_table(col), &key_text)
607                })
608                .map_err(io::Error::other)?;
609            {
610                match &someval {
611                    Some(val) => self.stats_read(1, key_len + val.len()),
612                    None => self.stats_read(1, key_len),
613                }
614            }
615
616            Ok(someval)
617        })
618    }
619
620    /// Remove a value by key, returning the old value
621    fn delete(
622        &self,
623        col: u32,
624        key: &[u8],
625    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBValue>>> + Send + '_>> {
626        let key_text = key_to_text(key);
627        let key_len = key.len();
628
629        Box::pin(async move {
630            let that = self.clone();
631            that.validate_column(col).map_err(io::Error::other)?;
632            self.conn_blocking(move |conn| {
633                let someval = Self::remove_and_return_unique_value_blob(
634                    conn,
635                    that.column_table(col),
636                    &key_text,
637                )?;
638
639                match &someval {
640                    Some(val) => {
641                        that.stats_read(1, key_len + val.len());
642                    }
643                    None => that.stats_read(1, key_len),
644                }
645
646                Ok(someval)
647            })
648            .map_err(io::Error::other)
649        })
650    }
651
652    fn write(
653        &self,
654        transaction: DBTransaction,
655    ) -> Pin<Box<dyn Future<Output = Result<(), DBTransactionError>> + Send + '_>> {
656        let transaction = Arc::new(transaction);
657        Box::pin(async move {
658            self.stats_transaction(1);
659
660            let that = self.clone();
661            let transaction_clone = transaction.clone();
662            self.conn_mut(move |conn| {
663                let mut sw = 0usize;
664                let mut sbw = 0usize;
665
666                let tx = conn.transaction()?;
667
668                for op in &transaction_clone.ops {
669                    match op {
670                        DBOp::Insert { col, key, value } => {
671                            that.validate_column(*col)?;
672                            Self::store_unique_value_blob(
673                                &tx,
674                                that.column_table(*col),
675                                &key_to_text(key),
676                                value,
677                            )?;
678                            sw += 1;
679                            sbw += key.len() + value.len();
680                        }
681                        DBOp::Delete { col, key } => {
682                            that.validate_column(*col)?;
683                            Self::remove_unique_value_blob(
684                                &tx,
685                                that.column_table(*col),
686                                &key_to_text(key),
687                            )?;
688                            sw += 1;
689                        }
690                        DBOp::DeletePrefix { col, prefix } => {
691                            that.validate_column(*col)?;
692                            Self::remove_unique_value_blob_like(
693                                &tx,
694                                that.column_table(*col),
695                                &(like_key_to_text(prefix) + "%"),
696                            )?;
697                            sw += 1;
698                        }
699                    }
700                }
701                tx.commit()?;
702
703                that.stats_write(sw, sbw);
704
705                Ok(())
706            })
707            .await
708            .map_err(io::Error::other)
709            .map_err(|error| {
710                let transaction = transaction.as_ref().clone();
711                DBTransactionError { error, transaction }
712            })
713        })
714    }
715
716    fn iter<
717        'a,
718        T: Send + 'static,
719        C: Send + 'static,
720        F: FnMut(&mut C, DBKeyValueRef) -> io::Result<Option<T>> + Send + Sync + 'static,
721    >(
722        &'a self,
723        col: u32,
724        prefix: Option<&'a [u8]>,
725        context: C,
726        mut f: F,
727    ) -> Pin<Box<dyn Future<Output = io::Result<(C, Option<T>)>> + Send + 'a>> {
728        let opt_prefix_query = prefix.map(|p| like_key_to_text(p) + "%");
729        Box::pin(async move {
730            if col >= self.columns() {
731                return Err(io::Error::from(io::ErrorKind::NotFound));
732            }
733
734            let that = self.clone();
735            let context = Arc::new(Mutex::new(Some(context)));
736            let context_ref = context.clone();
737
738            let res = self
739                .conn(move |conn| {
740                    let mut context = context_ref.lock();
741                    let context = context.as_mut().unwrap();
742
743                    let mut stmt;
744                    let mut rows;
745                    if let Some(prefix_query) = opt_prefix_query {
746                        stmt = match conn
747                            .prepare_cached(&that.column_table(col).str_iter_with_prefix)
748                        {
749                            Ok(v) => v,
750                            Err(e) => {
751                                return Ok(Err(io::Error::other(e)));
752                            }
753                        };
754                        rows = match stmt.query([prefix_query]) {
755                            Ok(v) => v,
756                            Err(e) => {
757                                return Ok(Err(io::Error::other(e)));
758                            }
759                        };
760                    } else {
761                        stmt = match conn.prepare_cached(&that.column_table(col).str_iter_no_prefix)
762                        {
763                            Ok(v) => v,
764                            Err(e) => {
765                                return Ok(Err(io::Error::other(e)));
766                            }
767                        };
768                        rows = match stmt.query([]) {
769                            Ok(v) => v,
770                            Err(e) => {
771                                return Ok(Err(io::Error::other(e)));
772                            }
773                        };
774                    }
775
776                    let mut sw = 0usize;
777                    let mut sbw = 0usize;
778
779                    let out = loop {
780                        match rows.next() {
781                            // Iterated value
782                            Ok(Some(row)) => {
783                                let kt: String = match row.get(0) {
784                                    Err(e) => {
785                                        break Err(io::Error::other(e));
786                                    }
787                                    Ok(v) => v,
788                                };
789                                let v: Vec<u8> = match row.get(1) {
790                                    Err(e) => {
791                                        break Err(io::Error::other(e));
792                                    }
793                                    Ok(v) => v,
794                                };
795                                let k: Vec<u8> = match text_to_key(&kt) {
796                                    Err(e) => {
797                                        break Err(io::Error::other(format!(
798                                            "SQLite row get column 0 text convert error: {:?}",
799                                            e
800                                        )));
801                                    }
802                                    Ok(v) => v,
803                                };
804
805                                sw += 1;
806                                sbw += k.len() + v.len();
807
808                                match f(context, (&k, &v)) {
809                                    Ok(None) => (),
810                                    Ok(Some(out)) => {
811                                        // Callback early termination
812                                        that.stats_read(sw, sbw);
813                                        break Ok(Some(out));
814                                    }
815                                    Err(e) => {
816                                        // Callback error termination
817                                        that.stats_read(sw, sbw);
818                                        break Err(e);
819                                    }
820                                }
821                            }
822                            // Natural iterator termination
823                            Ok(None) => {
824                                break Ok(None);
825                            }
826                            // Error iterator termination
827                            Err(_) => {
828                                break Ok(None);
829                            }
830                        }
831                    };
832
833                    that.stats_read(sw, sbw);
834
835                    Ok(out)
836                })
837                .await
838                .map_err(io::Error::other)?;
839
840            let context = context.lock().take().unwrap();
841
842            res.map(|x| (context, x))
843        })
844    }
845
846    fn iter_keys<
847        'a,
848        T: Send + 'static,
849        C: Send + 'static,
850        F: FnMut(&mut C, DBKeyRef) -> io::Result<Option<T>> + Send + Sync + 'static,
851    >(
852        &'a self,
853        col: u32,
854        prefix: Option<&'a [u8]>,
855        context: C,
856        mut f: F,
857    ) -> Pin<Box<dyn Future<Output = io::Result<(C, Option<T>)>> + Send + 'a>> {
858        let opt_prefix_query = prefix.map(|p| like_key_to_text(p) + "%");
859        Box::pin(async move {
860            if col >= self.columns() {
861                return Err(io::Error::from(io::ErrorKind::NotFound));
862            }
863
864            let that = self.clone();
865            let context = Arc::new(Mutex::new(Some(context)));
866            let context_ref = context.clone();
867
868            let res = self
869                .conn(move |conn| {
870                    let mut context = context_ref.lock();
871                    let context = context.as_mut().unwrap();
872
873                    let mut stmt;
874                    let mut rows;
875                    if let Some(prefix_query) = opt_prefix_query {
876                        stmt = match conn
877                            .prepare_cached(&that.column_table(col).str_iter_keys_with_prefix)
878                        {
879                            Ok(v) => v,
880                            Err(e) => {
881                                return Ok(Err(io::Error::other(e)));
882                            }
883                        };
884                        rows = match stmt.query([prefix_query]) {
885                            Ok(v) => v,
886                            Err(e) => {
887                                return Ok(Err(io::Error::other(e)));
888                            }
889                        };
890                    } else {
891                        stmt = match conn
892                            .prepare_cached(&that.column_table(col).str_iter_keys_no_prefix)
893                        {
894                            Ok(v) => v,
895                            Err(e) => {
896                                return Ok(Err(io::Error::other(e)));
897                            }
898                        };
899                        rows = match stmt.query([]) {
900                            Ok(v) => v,
901                            Err(e) => {
902                                return Ok(Err(io::Error::other(e)));
903                            }
904                        };
905                    }
906
907                    let mut sw = 0usize;
908                    let mut sbw = 0usize;
909
910                    let out = loop {
911                        match rows.next() {
912                            // Iterated value
913                            Ok(Some(row)) => {
914                                let kt: String = match row.get(0) {
915                                    Err(e) => {
916                                        break Err(io::Error::other(e));
917                                    }
918                                    Ok(v) => v,
919                                };
920                                let k: Vec<u8> = match text_to_key(&kt) {
921                                    Err(e) => {
922                                        break Err(io::Error::other(format!(
923                                            "SQLite row get column 0 text convert error: {:?}",
924                                            e
925                                        )));
926                                    }
927                                    Ok(v) => v,
928                                };
929
930                                sw += 1;
931                                sbw += k.len();
932
933                                match f(context, &k) {
934                                    Ok(None) => (),
935                                    Ok(Some(out)) => {
936                                        // Callback early termination
937                                        that.stats_read(sw, sbw);
938                                        break Ok(Some(out));
939                                    }
940                                    Err(e) => {
941                                        // Callback error termination
942                                        that.stats_read(sw, sbw);
943                                        break Err(e);
944                                    }
945                                }
946                            }
947                            // Natural iterator termination
948                            Ok(None) => {
949                                break Ok(None);
950                            }
951                            // Error iterator termination
952                            Err(_) => {
953                                break Ok(None);
954                            }
955                        }
956                    };
957
958                    that.stats_read(sw, sbw);
959
960                    Ok(out)
961                })
962                .await
963                .map_err(io::Error::other)?;
964
965            let context = context.lock().take().unwrap();
966
967            res.map(|x| (context, x))
968        })
969    }
970
971    fn io_stats(&self, kind: IoStatsKind) -> IoStats {
972        let mut inner = self.inner.lock();
973        match kind {
974            IoStatsKind::Overall => {
975                let mut stats = inner.overall_stats.clone();
976                stats.span = std::time::SystemTime::now()
977                    .duration_since(stats.started)
978                    .unwrap_or_default();
979                stats
980            }
981            IoStatsKind::SincePrevious => {
982                let mut stats = inner.current_stats.clone();
983                stats.span = std::time::SystemTime::now()
984                    .duration_since(stats.started)
985                    .unwrap_or_default();
986                inner.current_stats = IoStats::empty();
987                stats
988            }
989        }
990    }
991
992    fn num_columns(&self) -> io::Result<u32> {
993        let this = self.clone();
994        self.conn_blocking(move |conn| {
995            Self::get_unique_value(conn, this.control_table(), "columns", 0u32)
996        })
997        .map_err(io::Error::other)
998    }
999
1000    fn num_keys(&self, col: u32) -> Pin<Box<dyn Future<Output = io::Result<u64>> + Send + '_>> {
1001        let this = self.clone();
1002        Box::pin(async move {
1003            this.conn(move |conn| {
1004                conn.query_row(
1005                    &format!("SELECT Count(*) FROM {}", get_column_table_name(col)),
1006                    [],
1007                    |row| -> rusqlite::Result<u64> { row.get(0) },
1008                )
1009            })
1010            .await
1011            .map_err(|_| io::Error::from(io::ErrorKind::NotFound))
1012        })
1013    }
1014
1015    /// Check for the existence of a value by key.
1016    fn has_key<'a>(
1017        &'a self,
1018        col: u32,
1019        key: &'a [u8],
1020    ) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + 'a>> {
1021        let key_text = key_to_text(key);
1022        let key_len = key.len();
1023
1024        Box::pin(async move {
1025            let that = self.clone();
1026            that.validate_column(col).map_err(io::Error::other)?;
1027            let someval = self
1028                .conn_blocking(move |conn| Self::has_value(conn, that.column_table(col), &key_text))
1029                .map_err(io::Error::other)?;
1030
1031            self.stats_read(1, key_len);
1032
1033            Ok(someval)
1034        })
1035    }
1036
1037    /// Check for the existence of a value by prefix.
1038    fn has_prefix<'a>(
1039        &'a self,
1040        col: u32,
1041        prefix: &'a [u8],
1042    ) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + 'a>> {
1043        let prefix_len = prefix.len();
1044        let prefix_text = like_key_to_text(prefix) + "%";
1045
1046        Box::pin(async move {
1047            let that = self.clone();
1048            that.validate_column(col).map_err(io::Error::other)?;
1049            let someval = self
1050                .conn_blocking(move |conn| {
1051                    Self::has_value_like(conn, that.column_table(col), &prefix_text)
1052                })
1053                .map_err(io::Error::other)?;
1054
1055            self.stats_read(1, prefix_len);
1056
1057            Ok(someval)
1058        })
1059    }
1060
1061    /// Get the first value matching the given prefix.
1062    fn first_with_prefix<'a>(
1063        &'a self,
1064        col: u32,
1065        prefix: &'a [u8],
1066    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBKeyValue>>> + Send + 'a>> {
1067        let prefix_len = prefix.len();
1068        let like = like_key_to_text(prefix) + "%";
1069
1070        Box::pin(async move {
1071            let that = self.clone();
1072            that.validate_column(col).map_err(io::Error::other)?;
1073            let someval = self
1074                .conn_blocking(move |conn| {
1075                    Self::load_first_value_blob_like(conn, that.column_table(col), &like)
1076                })
1077                .map_err(io::Error::other)?;
1078
1079            self.stats_read(1, prefix_len);
1080
1081            match someval {
1082                Some((kt, val)) => match text_to_key(&kt) {
1083                    Err(e) => {
1084                        return Err(io::Error::other(format!(
1085                            "SQLite row get column 0 text convert error: {:?}",
1086                            e
1087                        )))
1088                    }
1089                    Ok(k) => Ok(Some((k, val))),
1090                },
1091                None => Ok(None),
1092            }
1093        })
1094    }
1095
1096    /// Vacuum database
1097    fn cleanup(&self) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + '_>> {
1098        Box::pin(async { self.vacuum().await.map_err(io::Error::other) })
1099    }
1100}
1101
#[cfg(test)]
mod tests {

    use super::*;
    use keyvaluedb_shared_tests as st;
    use tempfile::{Builder as TempfileBuilder, TempDir};

    /// Create a fresh temporary directory and return it with a database path inside it.
    ///
    /// Keep the returned [`TempDir`] alive for as long as the database is in use.
    /// Dropping it removes the directory together with the database file and any SQLite
    /// sidecar files created next to it. Unlike creating and immediately dropping a
    /// temp file, this leaves nothing behind and no other process can claim the path
    /// in between.
    fn temp_db_path(prefix: &str) -> io::Result<(TempDir, PathBuf)> {
        let dir = TempfileBuilder::new().prefix(prefix).tempdir()?;
        let path = dir.path().join("db.sqlite");
        Ok((dir, path))
    }

    /// Open a database with `columns` columns in a new temporary directory.
    fn create(columns: u32) -> io::Result<(TempDir, Database)> {
        let (dir, path) = temp_db_path("")?;
        let config = DatabaseConfig::new().with_columns(columns);
        let db = Database::open(path, config)?;
        Ok((dir, db))
    }

    /// Open a database with `columns` columns and the given vacuum mode in a new
    /// temporary directory.
    fn create_vacuum_mode(
        columns: u32,
        vacuum_mode: VacuumMode,
    ) -> io::Result<(TempDir, Database)> {
        let (dir, path) = temp_db_path("")?;
        let config = DatabaseConfig::new()
            .with_columns(columns)
            .with_vacuum_mode(vacuum_mode);
        let db = Database::open(path, config)?;
        Ok((dir, db))
    }

    #[tokio::test]
    async fn get_fails_with_non_existing_column() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_get_fails_with_non_existing_column(db).await
    }

    #[tokio::test]
    async fn num_keys() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_num_keys(db).await
    }

    #[tokio::test]
    async fn put_and_get() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_put_and_get(db).await
    }

    #[tokio::test]
    async fn delete_and_get() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_delete_and_get(db).await
    }

    #[tokio::test]
    async fn delete_and_get_single() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_delete_and_get_single(db).await
    }

    #[tokio::test]
    async fn delete_prefix() -> io::Result<()> {
        let (_dir, db) = create(st::DELETE_PREFIX_NUM_COLUMNS)?;
        st::test_delete_prefix(db).await
    }

    #[tokio::test]
    async fn iter() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter(db).await
    }

    #[tokio::test]
    async fn iter_keys() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter_keys(db).await
    }

    #[tokio::test]
    async fn iter_with_prefix() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter_with_prefix(db).await
    }

    #[tokio::test]
    async fn complex() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_complex(db).await
    }

    #[tokio::test]
    async fn cleanup() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_cleanup(db).await?;

        let (_dir, db) = create_vacuum_mode(1, VacuumMode::None)?;
        st::test_cleanup(db).await?;

        let (_dir, db) = create_vacuum_mode(1, VacuumMode::Incremental)?;
        st::test_cleanup(db).await?;

        let (_dir, db) = create_vacuum_mode(1, VacuumMode::Full)?;
        st::test_cleanup(db).await?;

        // Reopen the same database file with each vacuum mode in turn.
        let (_dir, path) = temp_db_path("")?;
        let config = DatabaseConfig::new().with_vacuum_mode(VacuumMode::None);
        let db = Database::open(&path, config)?;
        st::test_cleanup(db).await?;

        let config = DatabaseConfig::new().with_vacuum_mode(VacuumMode::Incremental);
        let db = Database::open(&path, config)?;
        st::test_cleanup(db).await?;

        let config = DatabaseConfig::new().with_vacuum_mode(VacuumMode::Full);
        let db = Database::open(&path, config)?;
        st::test_cleanup(db).await?;

        let config = DatabaseConfig::new().with_vacuum_mode(VacuumMode::None);
        let db = Database::open(&path, config)?;
        st::test_cleanup(db).await?;

        Ok(())
    }

    #[tokio::test]
    async fn stats() -> io::Result<()> {
        let (_dir, db) = create(st::IO_STATS_NUM_COLUMNS)?;
        st::test_io_stats(db).await
    }

    #[tokio::test]
    #[should_panic]
    async fn db_config_with_zero_columns() {
        let _cfg = DatabaseConfig::new().with_columns(0);
    }

    #[tokio::test]
    #[should_panic]
    async fn open_db_with_zero_columns() {
        // NOTE(review): `with_columns(0)` already panics via its assertion, so this
        // never reaches `Database::open`.
        let cfg = DatabaseConfig::new().with_columns(0);
        let _db = Database::open("", cfg);
    }

    #[tokio::test]
    async fn add_columns() {
        let config_1 = DatabaseConfig::default();
        let config_5 = DatabaseConfig::new().with_columns(5);

        let (_dir, path) = temp_db_path("").unwrap();

        // open 1, add 4.
        {
            let db = Database::open(&path, config_1).unwrap();
            assert_eq!(db.num_columns().unwrap(), 1);

            for i in 2..=5 {
                db.add_column().unwrap();
                assert_eq!(db.num_columns().unwrap(), i);
            }
        }

        // reopen as 5.
        {
            let db = Database::open(&path, config_5).unwrap();
            assert_eq!(db.num_columns().unwrap(), 5);
        }
    }

    #[tokio::test]
    async fn remove_columns() {
        let config_1 = DatabaseConfig::default();
        let config_5 = DatabaseConfig::new().with_columns(5);

        let (_dir, path) = temp_db_path("drop_columns").unwrap();

        // open 5, remove 4.
        {
            let db = Database::open(&path, config_5).expect("open with 5 columns");
            assert_eq!(db.num_columns().unwrap(), 5);

            for i in (1..5).rev() {
                db.remove_last_column().unwrap();
                assert_eq!(db.num_columns().unwrap(), i);
            }
        }

        // reopen as 1.
        {
            let db = Database::open(&path, config_1).unwrap();
            assert_eq!(db.num_columns().unwrap(), 1);
        }
    }

    #[tokio::test]
    async fn test_num_keys() {
        let (_dir, path) = temp_db_path("").unwrap();
        let config = DatabaseConfig::new().with_columns(1);
        let db = Database::open(path, config).unwrap();

        assert_eq!(
            db.num_keys(0).await.unwrap(),
            0,
            "database is empty after creation"
        );
        let key1 = b"beef";
        let mut batch = db.transaction();
        batch.put(0, key1, key1);
        db.write(batch).await.unwrap();
        assert_eq!(
            db.num_keys(0).await.unwrap(),
            1,
            "adding a key increases the count"
        );
    }
}