keyvaluedb_sqlite/
lib.rs

1#![deny(clippy::all)]
2
3mod tools;
4
5pub use async_sqlite::rusqlite::OpenFlags;
6use async_sqlite::rusqlite::{params, OptionalExtension as _};
7use async_sqlite::*;
8use keyvaluedb::{
9    DBKeyRef, DBKeyValue, DBKeyValueRef, DBOp, DBTransaction, DBTransactionError, DBValue, IoStats,
10    IoStatsKind, KeyValueDB,
11};
12use parking_lot::Mutex;
13use std::sync::Arc;
14use std::{
15    future::Future,
16    io,
17    path::{Path, PathBuf},
18    pin::Pin,
19    str::FromStr,
20};
21use tools::*;
22
23///////////////////////////////////////////////////////////////////////////////
24
/// How free pages in the SQLite file are reclaimed.
///
/// The mode selects the SQLite `auto_vacuum` setting when the database is
/// opened, and decides what [`Database::vacuum`] does.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum VacuumMode {
    /// `auto_vacuum = NONE`; `vacuum()` does nothing.
    None,
    /// `auto_vacuum = INCREMENTAL`; `vacuum()` runs `PRAGMA incremental_vacuum`.
    Incremental,
    /// `auto_vacuum = NONE`; `vacuum()` runs a full `VACUUM`.
    Full,
}
31
32/// Database configuration
33#[derive(Clone)]
34pub struct DatabaseConfig {
35    /// Set number of columns.
36    /// The number of columns must not be zero.
37    pub columns: u32,
38    /// Set flags used to open the database
39    pub flags: OpenFlags,
40    /// Number of connections to open
41    pub num_conns: usize,
42    /// Vacuum mode
43    pub vacuum_mode: VacuumMode,
44}
45
46impl DatabaseConfig {
47    /// Create new `DatabaseConfig` with default parameters
48    pub fn new() -> Self {
49        Default::default()
50    }
51
52    /// Set the number of columns. `columns` must not be zero.
53    pub fn with_columns(self, columns: u32) -> Self {
54        assert!(columns > 0, "the number of columns must not be zero");
55        Self { columns, ..self }
56    }
57
58    /// Sets the flags to 'in-memory database'
59    pub fn with_in_memory(self) -> Self {
60        Self {
61            flags: OpenFlags::SQLITE_OPEN_READ_WRITE
62                | OpenFlags::SQLITE_OPEN_CREATE
63                | OpenFlags::SQLITE_OPEN_NO_MUTEX
64                | OpenFlags::SQLITE_OPEN_MEMORY,
65            ..self
66        }
67    }
68
69    /// Replaces all the flags
70    pub fn with_flags(self, flags: OpenFlags) -> Self {
71        Self { flags, ..self }
72    }
73
74    /// Sets the number of connections for this database
75    pub fn with_num_conns(self, num_conns: usize) -> Self {
76        Self { num_conns, ..self }
77    }
78
79    /// Set the vacuum mode used by 'cleanup'
80    pub fn with_vacuum_mode(self, vacuum_mode: VacuumMode) -> Self {
81        Self {
82            vacuum_mode,
83            ..self
84        }
85    }
86}
87
88impl Default for DatabaseConfig {
89    fn default() -> DatabaseConfig {
90        DatabaseConfig {
91            columns: 1,
92            flags: OpenFlags::SQLITE_OPEN_READ_WRITE
93                | OpenFlags::SQLITE_OPEN_CREATE
94                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
95            num_conns: 1,
96            vacuum_mode: VacuumMode::None,
97        }
98    }
99}
100
101///////////////////////////////////////////////////////////////////////////////
102
/// SQL statement text for one SQLite table.
///
/// The strings are built once per table and handed to `prepare_cached` later.
/// The table name is interpolated directly into the SQL, so it must be a
/// trusted identifier.
pub struct DatabaseTable {
    /// Name of the table these statements target.
    _table: String,
    /// Existence check for an exact key.
    str_has_value: String,
    /// Existence check for a `LIKE` pattern on the key.
    str_has_value_like: String,
    /// Fetch the value stored under an exact key.
    str_get_unique_value: String,
    /// Fetch the first (key, value) pair whose key matches a `LIKE` pattern.
    str_get_first_value_like: String,
    /// Insert or overwrite a (key, value) pair.
    str_set_unique_value: String,
    /// Delete an exact key.
    str_remove_unique_value: String,
    /// Delete an exact key and return the removed value.
    str_remove_and_return_unique_value: String,
    /// Delete every key matching a `LIKE` pattern.
    str_remove_unique_value_like: String,
    /// Iterate (key, value) pairs whose key matches a `LIKE` pattern.
    str_iter_with_prefix: String,
    /// Iterate all (key, value) pairs.
    str_iter_no_prefix: String,
    /// Iterate keys matching a `LIKE` pattern.
    str_iter_keys_with_prefix: String,
    /// Iterate all keys.
    str_iter_keys_no_prefix: String,
}

impl DatabaseTable {
    /// Builds every statement string for the table named `table`.
    ///
    /// `LIKE` patterns use `\` as their escape character.
    pub fn new(table: String) -> Self {
        Self {
            str_has_value: format!("SELECT 1 FROM {table} WHERE [key] = ? LIMIT 1"),
            str_has_value_like: format!(
                "SELECT 1 FROM {table} WHERE [key] LIKE ? ESCAPE '\\' LIMIT 1"
            ),
            str_get_unique_value: format!("SELECT value FROM {table} WHERE [key] = ? LIMIT 1"),
            str_get_first_value_like: format!(
                "SELECT key, value FROM {table} WHERE [key] LIKE ? ESCAPE '\\' LIMIT 1"
            ),
            str_set_unique_value: format!(
                "INSERT OR REPLACE INTO {table} ([key], value) VALUES(?, ?)"
            ),
            str_remove_unique_value: format!("DELETE FROM {table} WHERE [key] = ?"),
            str_remove_and_return_unique_value: format!(
                "DELETE FROM {table} WHERE [key] = ? RETURNING value"
            ),
            str_remove_unique_value_like: format!(
                "DELETE FROM {table} WHERE [key] LIKE ? ESCAPE '\\'"
            ),
            str_iter_with_prefix: format!(
                "SELECT key, value FROM {table} WHERE [key] LIKE ? ESCAPE '\\'"
            ),
            str_iter_no_prefix: format!("SELECT key, value FROM {table}"),
            str_iter_keys_with_prefix: format!(
                "SELECT key FROM {table} WHERE [key] LIKE ? ESCAPE '\\'"
            ),
            str_iter_keys_no_prefix: format!("SELECT key FROM {table}"),
            // Moved last: every field above only borrows `table`.
            _table: table,
        }
    }
}
167
168///////////////////////////////////////////////////////////////////////////////
169
170/// An sqlite key-value database fulfilling the `KeyValueDB` trait
171pub struct DatabaseUnlockedInner {
172    path: PathBuf,
173    config: DatabaseConfig,
174    pool: Pool,
175    control_table: Arc<DatabaseTable>,
176    column_tables: Vec<Arc<DatabaseTable>>,
177}
178
179pub struct DatabaseInner {
180    overall_stats: IoStats,
181    current_stats: IoStats,
182}
183
184#[derive(Clone)]
185pub struct Database {
186    unlocked_inner: Arc<DatabaseUnlockedInner>,
187    inner: Arc<Mutex<DatabaseInner>>,
188}
189
190impl Database {
191    ////////////////////////////////////////////////////////////////
192    // Initialization
193
194    pub fn open<P: AsRef<Path>>(path: P, config: DatabaseConfig) -> io::Result<Self> {
195        assert_ne!(config.columns, 0, "number of columns must be >= 1");
196
197        let path = PathBuf::from(path.as_ref());
198        let flags = config.flags;
199
200        let mut column_tables = vec![];
201        for n in 0..config.columns {
202            column_tables.push(Arc::new(DatabaseTable::new(get_column_table_name(n))))
203        }
204        let control_table = Arc::new(DatabaseTable::new("control".to_string()));
205
206        let pool_builder = PoolBuilder::new()
207            .path(&path)
208            .flags(flags)
209            .num_conns(config.num_conns);
210
211        let pool = pool_builder.open_blocking().map_err(io::Error::other)?;
212
213        let out = Self {
214            unlocked_inner: Arc::new(DatabaseUnlockedInner {
215                path,
216                config,
217                pool,
218                control_table,
219                column_tables,
220            }),
221            inner: Arc::new(Mutex::new(DatabaseInner {
222                overall_stats: IoStats::empty(),
223                current_stats: IoStats::empty(),
224            })),
225        };
226
227        let vacuum_mode = out.config().vacuum_mode;
228
229        out.conn_blocking(move |conn| {
230            // Don't rely on STATEMENT_CACHE_DEFAULT_CAPACITY in rusqlite, set it explicitly
231            conn.set_prepared_statement_cache_capacity(256);
232
233            conn.pragma_update(None, "case_sensitive_like", "ON")?;
234            conn.pragma_update(None, "journal_mode", "WAL")?;
235            conn.pragma_update(None, "synchronous", "normal")?;
236            conn.pragma_update(None, "journal_size_limit", 6144000)?;
237
238            match vacuum_mode {
239                VacuumMode::None | VacuumMode::Full => {
240                    let current: u32 =
241                        conn.pragma_query_value(None, "auto_vacuum", |x| x.get(0))?;
242                    if current != 0 {
243                        conn.execute("VACUUM", [])?;
244                        conn.pragma_update(None, "auto_vacuum", 0)?;
245                    }
246                }
247                VacuumMode::Incremental => {
248                    let current: u32 =
249                        conn.pragma_query_value(None, "auto_vacuum", |x| x.get(0))?;
250                    if current != 2 {
251                        conn.execute("VACUUM", [])?;
252                        conn.pragma_update(None, "auto_vacuum", "2")?;
253                    }
254                }
255            }
256
257            Ok(())
258        })
259        .map_err(io::Error::other)?;
260
261        out.open_resize_columns()?;
262
263        Ok(out)
264    }
265
266    pub fn path(&self) -> PathBuf {
267        self.unlocked_inner.path.clone()
268    }
269
270    pub fn config(&self) -> DatabaseConfig {
271        self.unlocked_inner.config.clone()
272    }
273
274    pub fn columns(&self) -> u32 {
275        self.unlocked_inner.config.columns
276    }
277
278    pub fn control_table(&self) -> Arc<DatabaseTable> {
279        self.unlocked_inner.control_table.clone()
280    }
281
282    pub fn column_table(&self, col: u32) -> Arc<DatabaseTable> {
283        self.unlocked_inner.column_tables[col as usize].clone()
284    }
285
286    pub fn conn_blocking<T, F>(&self, func: F) -> Result<T, Error>
287    where
288        F: FnOnce(&rusqlite::Connection) -> Result<T, rusqlite::Error> + Send + 'static,
289        T: Send + 'static,
290    {
291        self.unlocked_inner.pool.conn_blocking(func)
292    }
293
294    pub async fn conn<T, F>(&self, func: F) -> Result<T, Error>
295    where
296        F: FnOnce(&rusqlite::Connection) -> Result<T, rusqlite::Error> + Send + 'static,
297        T: Send + 'static,
298    {
299        self.unlocked_inner.pool.conn(func).await
300    }
301
302    pub async fn conn_mut<T, F>(&self, func: F) -> Result<T, Error>
303    where
304        F: FnOnce(&mut rusqlite::Connection) -> Result<T, rusqlite::Error> + Send + 'static,
305        T: Send + 'static,
306    {
307        self.unlocked_inner.pool.conn_mut(func).await
308    }
309
310    ////////////////////////////////////////////////////////////////
311    // Low level operations
312
313    /// Remove the last column family in the database. The deletion is definitive.
314    pub fn remove_last_column(&self) -> Result<(), Error> {
315        let this = self.clone();
316        self.conn_blocking(move |conn| {
317            let columns = Self::get_unique_value(conn, this.control_table(), "columns", 0u32)?;
318            if columns == 0 {
319                return Err(rusqlite::Error::QueryReturnedNoRows);
320            }
321            Self::set_unique_value(conn, this.control_table(), "columns", columns - 1)?;
322
323            conn.execute(
324                &format!("DROP TABLE {}", get_column_table_name(columns - 1)),
325                [],
326            )?;
327            Ok(())
328        })
329    }
330
331    /// Add a new column family to the DB.
332    pub fn add_column(&self) -> Result<(), Error> {
333        let this = self.clone();
334
335        self.conn_blocking(move |conn| {
336            let columns = Self::get_unique_value(conn, this.control_table(), "columns", 0u32)?;
337            Self::set_unique_value(conn, this.control_table(), "columns", columns + 1)?;
338            Self::create_column_table(conn, columns)
339        })
340    }
341    /// Helper to create new transaction for this database.
342    pub fn transaction(&self) -> DBTransaction {
343        DBTransaction::new()
344    }
345
346    /// Vacuum database
347    pub async fn vacuum(&self) -> Result<(), Error> {
348        match self.config().vacuum_mode {
349            VacuumMode::None => {
350                return Ok(());
351            }
352            VacuumMode::Incremental => {
353                self.conn(move |conn| {
354                    conn.execute("PRAGMA incremental_vacuum", [])?;
355                    Ok(())
356                })
357                .await
358            }
359            VacuumMode::Full => {
360                self.conn(move |conn| {
361                    conn.execute("VACUUM", [])?;
362                    Ok(())
363                })
364                .await
365            }
366        }
367    }
368
369    ////////////////////////////////////////////////////////////////
370    // Internal helpers
371
372    fn validate_column(&self, col: u32) -> rusqlite::Result<()> {
373        if col >= self.columns() {
374            return Err(rusqlite::Error::InvalidColumnIndex(col as usize));
375        }
376        Ok(())
377    }
378
379    fn create_column_table(conn: &rusqlite::Connection, column: u32) -> rusqlite::Result<()> {
380        conn.execute(&format!("CREATE TABLE IF NOT EXISTS {} (id INTEGER PRIMARY KEY AUTOINCREMENT, [key] TEXT UNIQUE, value BLOB)", get_column_table_name(column)), []).map(drop)
381    }
382
383    fn get_unique_value<V>(
384        conn: &rusqlite::Connection,
385        table: Arc<DatabaseTable>,
386        key: &str,
387        default: V,
388    ) -> rusqlite::Result<V>
389    where
390        V: FromStr,
391    {
392        let mut stmt = conn.prepare_cached(&table.str_get_unique_value)?;
393
394        if let Ok(found) = stmt.query_row([key], |row| -> rusqlite::Result<String> { row.get(0) }) {
395            if let Ok(v) = V::from_str(&found) {
396                return Ok(v);
397            }
398        }
399        Ok(default)
400    }
401
402    fn set_unique_value<V>(
403        conn: &rusqlite::Connection,
404        table: Arc<DatabaseTable>,
405        key: &str,
406        value: V,
407    ) -> rusqlite::Result<()>
408    where
409        V: ToString,
410    {
411        let mut stmt = conn.prepare_cached(&table.str_set_unique_value)?;
412
413        let changed = stmt.execute([key, value.to_string().as_str()])?;
414
415        assert!(
416            changed <= 1,
417            "multiple changes to unique key should not occur"
418        );
419        if changed == 0 {
420            return Err(rusqlite::Error::QueryReturnedNoRows);
421        }
422
423        Ok(())
424    }
425
426    fn has_value(
427        conn: &rusqlite::Connection,
428        table: Arc<DatabaseTable>,
429        key: &str,
430    ) -> rusqlite::Result<bool> {
431        let mut stmt = conn.prepare_cached(&table.str_has_value)?;
432        stmt.exists([key])
433    }
434
435    fn has_value_like(
436        conn: &rusqlite::Connection,
437        table: Arc<DatabaseTable>,
438        key: &str,
439    ) -> rusqlite::Result<bool> {
440        let mut stmt = conn.prepare_cached(&table.str_has_value_like)?;
441        stmt.exists([key])
442    }
443
444    fn load_unique_value_blob(
445        conn: &rusqlite::Connection,
446        table: Arc<DatabaseTable>,
447        key: &str,
448    ) -> rusqlite::Result<Option<Vec<u8>>> {
449        let mut stmt = conn.prepare_cached(&table.str_get_unique_value)?;
450
451        stmt.query_row([key], |row| -> rusqlite::Result<Vec<u8>> { row.get(0) })
452            .optional()
453    }
454
455    fn load_first_value_blob_like(
456        conn: &rusqlite::Connection,
457        table: Arc<DatabaseTable>,
458        like: &str,
459    ) -> rusqlite::Result<Option<(String, Vec<u8>)>> {
460        let mut stmt = conn.prepare_cached(&table.str_get_first_value_like)?;
461
462        stmt.query_row([like], |row| -> rusqlite::Result<(String, Vec<u8>)> {
463            Ok((row.get(0)?, row.get(1)?))
464        })
465        .optional()
466    }
467
468    fn store_unique_value_blob(
469        conn: &rusqlite::Connection,
470        table: Arc<DatabaseTable>,
471        key: &str,
472        value: &[u8],
473    ) -> rusqlite::Result<()> {
474        let mut stmt = conn.prepare_cached(&table.str_set_unique_value)?;
475
476        let changed = stmt.execute(params![key, value])?;
477        assert!(
478            changed <= 1,
479            "multiple changes to unique key should not occur"
480        );
481        if changed == 0 {
482            return Err(rusqlite::Error::QueryReturnedNoRows);
483        }
484        Ok(())
485    }
486
487    fn remove_unique_value_blob(
488        conn: &rusqlite::Connection,
489        table: Arc<DatabaseTable>,
490        key: &str,
491    ) -> rusqlite::Result<()> {
492        let mut stmt = conn.prepare_cached(&table.str_remove_unique_value)?;
493
494        let changed = stmt.execute([key])?;
495        assert!(
496            changed <= 1,
497            "multiple deletions of unique key should not occur"
498        );
499        if changed == 0 {
500            return Err(rusqlite::Error::QueryReturnedNoRows);
501        }
502        Ok(())
503    }
504
505    fn remove_and_return_unique_value_blob(
506        conn: &rusqlite::Connection,
507        table: Arc<DatabaseTable>,
508        key: &str,
509    ) -> rusqlite::Result<Option<Vec<u8>>> {
510        let mut stmt = conn.prepare_cached(&table.str_remove_and_return_unique_value)?;
511
512        stmt.query_row([key], |row| -> rusqlite::Result<Vec<u8>> { row.get(0) })
513            .optional()
514    }
515
516    fn remove_unique_value_blob_like(
517        conn: &rusqlite::Connection,
518        table: Arc<DatabaseTable>,
519        like: &str,
520    ) -> rusqlite::Result<usize> {
521        let mut stmt = conn.prepare_cached(&table.str_remove_unique_value_like)?;
522
523        let changed = stmt.execute([like])?;
524        Ok(changed)
525    }
526
527    fn open_resize_columns(&self) -> io::Result<()> {
528        let columns = self.columns();
529        let this = self.clone();
530        self.conn_blocking(move |conn| {
531			// First see if we have a control table with the number of columns
532			conn.execute("CREATE TABLE IF NOT EXISTS control (id INTEGER PRIMARY KEY AUTOINCREMENT, [key] TEXT UNIQUE, value TEXT)", [])?;
533
534            // Get column count
535            let on_disk_columns =
536                Self::get_unique_value(conn, this.control_table(), "columns", 0u32)?;
537
538            // If desired column count is less than or equal to current column count, then allow it, but restrict access to columns
539            if columns <= on_disk_columns {
540                return Ok(());
541            }
542
543            // Otherwise resize and add other columns
544            for cn in on_disk_columns..columns {
545                // Create the column table if we don't have it
546                Self::create_column_table(conn, cn)?;
547            }
548            Self::set_unique_value(
549                conn,
550                this.control_table(),
551                "columns",
552                columns,
553            )?;
554            Ok(())
555        }).map_err(io::Error::other)
556    }
557
558    fn stats_read(&self, count: usize, bytes: usize) {
559        let mut inner = self.inner.lock();
560        inner.current_stats.reads += count as u64;
561        inner.overall_stats.reads += count as u64;
562        inner.current_stats.bytes_read += bytes as u64;
563        inner.overall_stats.bytes_read += bytes as u64;
564    }
565    fn stats_write(&self, count: usize, bytes: usize) {
566        let mut inner = self.inner.lock();
567        inner.current_stats.writes += count as u64;
568        inner.overall_stats.writes += count as u64;
569        inner.current_stats.bytes_written += bytes as u64;
570        inner.overall_stats.bytes_written += bytes as u64;
571    }
572    fn stats_transaction(&self, count: usize) {
573        let mut inner = self.inner.lock();
574        inner.current_stats.transactions += count as u64;
575        inner.overall_stats.transactions += count as u64;
576    }
577}
578
579impl KeyValueDB for Database {
580    fn get(
581        &self,
582        col: u32,
583        key: &[u8],
584    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBValue>>> + Send + '_>> {
585        let key_text = key_to_text(key);
586        let key_len = key.len();
587
588        Box::pin(async move {
589            let that = self.clone();
590            that.validate_column(col).map_err(io::Error::other)?;
591            let someval = self
592                .conn_blocking(move |conn| {
593                    Self::load_unique_value_blob(conn, that.column_table(col), &key_text)
594                })
595                .map_err(io::Error::other)?;
596            {
597                match &someval {
598                    Some(val) => self.stats_read(1, key_len + val.len()),
599                    None => self.stats_read(1, key_len),
600                }
601            }
602
603            Ok(someval)
604        })
605    }
606
607    /// Remove a value by key, returning the old value
608    fn delete(
609        &self,
610        col: u32,
611        key: &[u8],
612    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBValue>>> + Send + '_>> {
613        let key_text = key_to_text(key);
614        let key_len = key.len();
615
616        Box::pin(async move {
617            let that = self.clone();
618            that.validate_column(col).map_err(io::Error::other)?;
619            self.conn_blocking(move |conn| {
620                let someval = Self::remove_and_return_unique_value_blob(
621                    conn,
622                    that.column_table(col),
623                    &key_text,
624                )?;
625
626                match &someval {
627                    Some(val) => {
628                        that.stats_read(1, key_len + val.len());
629                    }
630                    None => that.stats_read(1, key_len),
631                }
632
633                Ok(someval)
634            })
635            .map_err(io::Error::other)
636        })
637    }
638
639    fn write(
640        &self,
641        transaction: DBTransaction,
642    ) -> Pin<Box<dyn Future<Output = Result<(), DBTransactionError>> + Send + '_>> {
643        let transaction = Arc::new(transaction);
644        Box::pin(async move {
645            self.stats_transaction(1);
646
647            let that = self.clone();
648            let transaction_clone = transaction.clone();
649            self.conn_mut(move |conn| {
650                let mut sw = 0usize;
651                let mut sbw = 0usize;
652
653                let tx = conn.transaction()?;
654
655                for op in &transaction_clone.ops {
656                    match op {
657                        DBOp::Insert { col, key, value } => {
658                            that.validate_column(*col)?;
659                            Self::store_unique_value_blob(
660                                &tx,
661                                that.column_table(*col),
662                                &key_to_text(key),
663                                value,
664                            )?;
665                            sw += 1;
666                            sbw += key.len() + value.len();
667                        }
668                        DBOp::Delete { col, key } => {
669                            that.validate_column(*col)?;
670                            Self::remove_unique_value_blob(
671                                &tx,
672                                that.column_table(*col),
673                                &key_to_text(key),
674                            )?;
675                            sw += 1;
676                        }
677                        DBOp::DeletePrefix { col, prefix } => {
678                            that.validate_column(*col)?;
679                            Self::remove_unique_value_blob_like(
680                                &tx,
681                                that.column_table(*col),
682                                &(like_key_to_text(prefix) + "%"),
683                            )?;
684                            sw += 1;
685                        }
686                    }
687                }
688                tx.commit()?;
689
690                that.stats_write(sw, sbw);
691
692                Ok(())
693            })
694            .await
695            .map_err(io::Error::other)
696            .map_err(|error| {
697                let transaction = transaction.as_ref().clone();
698                DBTransactionError { error, transaction }
699            })
700        })
701    }
702
703    fn iter<
704        'a,
705        T: Send + 'static,
706        C: Send + 'static,
707        F: FnMut(&mut C, DBKeyValueRef) -> io::Result<Option<T>> + Send + Sync + 'static,
708    >(
709        &'a self,
710        col: u32,
711        prefix: Option<&'a [u8]>,
712        context: C,
713        mut f: F,
714    ) -> Pin<Box<dyn Future<Output = io::Result<(C, Option<T>)>> + Send + 'a>> {
715        let opt_prefix_query = prefix.map(|p| like_key_to_text(p) + "%");
716        Box::pin(async move {
717            if col >= self.columns() {
718                return Err(io::Error::from(io::ErrorKind::NotFound));
719            }
720
721            let that = self.clone();
722            let context = Arc::new(Mutex::new(Some(context)));
723            let context_ref = context.clone();
724
725            let res = self
726                .conn(move |conn| {
727                    let mut context = context_ref.lock();
728                    let context = context.as_mut().unwrap();
729
730                    let mut stmt;
731                    let mut rows;
732                    if let Some(prefix_query) = opt_prefix_query {
733                        stmt = match conn
734                            .prepare_cached(&that.column_table(col).str_iter_with_prefix)
735                        {
736                            Ok(v) => v,
737                            Err(e) => {
738                                return Ok(Err(io::Error::other(e)));
739                            }
740                        };
741                        rows = match stmt.query([prefix_query]) {
742                            Ok(v) => v,
743                            Err(e) => {
744                                return Ok(Err(io::Error::other(e)));
745                            }
746                        };
747                    } else {
748                        stmt = match conn.prepare_cached(&that.column_table(col).str_iter_no_prefix)
749                        {
750                            Ok(v) => v,
751                            Err(e) => {
752                                return Ok(Err(io::Error::other(e)));
753                            }
754                        };
755                        rows = match stmt.query([]) {
756                            Ok(v) => v,
757                            Err(e) => {
758                                return Ok(Err(io::Error::other(e)));
759                            }
760                        };
761                    }
762
763                    let mut sw = 0usize;
764                    let mut sbw = 0usize;
765
766                    let out = loop {
767                        match rows.next() {
768                            // Iterated value
769                            Ok(Some(row)) => {
770                                let kt: String = match row.get(0) {
771                                    Err(e) => {
772                                        break Err(io::Error::other(e));
773                                    }
774                                    Ok(v) => v,
775                                };
776                                let v: Vec<u8> = match row.get(1) {
777                                    Err(e) => {
778                                        break Err(io::Error::other(e));
779                                    }
780                                    Ok(v) => v,
781                                };
782                                let k: Vec<u8> = match text_to_key(&kt) {
783                                    Err(e) => {
784                                        break Err(io::Error::other(format!(
785                                            "SQLite row get column 0 text convert error: {:?}",
786                                            e
787                                        )));
788                                    }
789                                    Ok(v) => v,
790                                };
791
792                                sw += 1;
793                                sbw += k.len() + v.len();
794
795                                match f(context, (&k, &v)) {
796                                    Ok(None) => (),
797                                    Ok(Some(out)) => {
798                                        // Callback early termination
799                                        that.stats_read(sw, sbw);
800                                        break Ok(Some(out));
801                                    }
802                                    Err(e) => {
803                                        // Callback error termination
804                                        that.stats_read(sw, sbw);
805                                        break Err(e);
806                                    }
807                                }
808                            }
809                            // Natural iterator termination
810                            Ok(None) => {
811                                break Ok(None);
812                            }
813                            // Error iterator termination
814                            Err(_) => {
815                                break Ok(None);
816                            }
817                        }
818                    };
819
820                    that.stats_read(sw, sbw);
821
822                    Ok(out)
823                })
824                .await
825                .map_err(io::Error::other)?;
826
827            let context = context.lock().take().unwrap();
828
829            res.map(|x| (context, x))
830        })
831    }
832
833    fn iter_keys<
834        'a,
835        T: Send + 'static,
836        C: Send + 'static,
837        F: FnMut(&mut C, DBKeyRef) -> io::Result<Option<T>> + Send + Sync + 'static,
838    >(
839        &'a self,
840        col: u32,
841        prefix: Option<&'a [u8]>,
842        context: C,
843        mut f: F,
844    ) -> Pin<Box<dyn Future<Output = io::Result<(C, Option<T>)>> + Send + 'a>> {
845        let opt_prefix_query = prefix.map(|p| like_key_to_text(p) + "%");
846        Box::pin(async move {
847            if col >= self.columns() {
848                return Err(io::Error::from(io::ErrorKind::NotFound));
849            }
850
851            let that = self.clone();
852            let context = Arc::new(Mutex::new(Some(context)));
853            let context_ref = context.clone();
854
855            let res = self
856                .conn(move |conn| {
857                    let mut context = context_ref.lock();
858                    let context = context.as_mut().unwrap();
859
860                    let mut stmt;
861                    let mut rows;
862                    if let Some(prefix_query) = opt_prefix_query {
863                        stmt = match conn
864                            .prepare_cached(&that.column_table(col).str_iter_keys_with_prefix)
865                        {
866                            Ok(v) => v,
867                            Err(e) => {
868                                return Ok(Err(io::Error::other(e)));
869                            }
870                        };
871                        rows = match stmt.query([prefix_query]) {
872                            Ok(v) => v,
873                            Err(e) => {
874                                return Ok(Err(io::Error::other(e)));
875                            }
876                        };
877                    } else {
878                        stmt = match conn
879                            .prepare_cached(&that.column_table(col).str_iter_keys_no_prefix)
880                        {
881                            Ok(v) => v,
882                            Err(e) => {
883                                return Ok(Err(io::Error::other(e)));
884                            }
885                        };
886                        rows = match stmt.query([]) {
887                            Ok(v) => v,
888                            Err(e) => {
889                                return Ok(Err(io::Error::other(e)));
890                            }
891                        };
892                    }
893
894                    let mut sw = 0usize;
895                    let mut sbw = 0usize;
896
897                    let out = loop {
898                        match rows.next() {
899                            // Iterated value
900                            Ok(Some(row)) => {
901                                let kt: String = match row.get(0) {
902                                    Err(e) => {
903                                        break Err(io::Error::other(e));
904                                    }
905                                    Ok(v) => v,
906                                };
907                                let k: Vec<u8> = match text_to_key(&kt) {
908                                    Err(e) => {
909                                        break Err(io::Error::other(format!(
910                                            "SQLite row get column 0 text convert error: {:?}",
911                                            e
912                                        )));
913                                    }
914                                    Ok(v) => v,
915                                };
916
917                                sw += 1;
918                                sbw += k.len();
919
920                                match f(context, &k) {
921                                    Ok(None) => (),
922                                    Ok(Some(out)) => {
923                                        // Callback early termination
924                                        that.stats_read(sw, sbw);
925                                        break Ok(Some(out));
926                                    }
927                                    Err(e) => {
928                                        // Callback error termination
929                                        that.stats_read(sw, sbw);
930                                        break Err(e);
931                                    }
932                                }
933                            }
934                            // Natural iterator termination
935                            Ok(None) => {
936                                break Ok(None);
937                            }
938                            // Error iterator termination
939                            Err(_) => {
940                                break Ok(None);
941                            }
942                        }
943                    };
944
945                    that.stats_read(sw, sbw);
946
947                    Ok(out)
948                })
949                .await
950                .map_err(io::Error::other)?;
951
952            let context = context.lock().take().unwrap();
953
954            res.map(|x| (context, x))
955        })
956    }
957
958    fn io_stats(&self, kind: IoStatsKind) -> IoStats {
959        let mut inner = self.inner.lock();
960        match kind {
961            IoStatsKind::Overall => {
962                let mut stats = inner.overall_stats.clone();
963                stats.span = std::time::SystemTime::now()
964                    .duration_since(stats.started)
965                    .unwrap_or_default();
966                stats
967            }
968            IoStatsKind::SincePrevious => {
969                let mut stats = inner.current_stats.clone();
970                stats.span = std::time::SystemTime::now()
971                    .duration_since(stats.started)
972                    .unwrap_or_default();
973                inner.current_stats = IoStats::empty();
974                stats
975            }
976        }
977    }
978
979    fn num_columns(&self) -> io::Result<u32> {
980        let this = self.clone();
981        self.conn_blocking(move |conn| {
982            Self::get_unique_value(conn, this.control_table(), "columns", 0u32)
983        })
984        .map_err(io::Error::other)
985    }
986
987    fn num_keys(&self, col: u32) -> Pin<Box<dyn Future<Output = io::Result<u64>> + Send + '_>> {
988        let this = self.clone();
989        Box::pin(async move {
990            this.conn(move |conn| {
991                conn.query_row(
992                    &format!("SELECT Count(*) FROM {}", get_column_table_name(col)),
993                    [],
994                    |row| -> rusqlite::Result<u64> { row.get(0) },
995                )
996            })
997            .await
998            .map_err(|_| io::Error::from(io::ErrorKind::NotFound))
999        })
1000    }
1001
1002    /// Check for the existence of a value by key.
1003    fn has_key<'a>(
1004        &'a self,
1005        col: u32,
1006        key: &'a [u8],
1007    ) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + 'a>> {
1008        let key_text = key_to_text(key);
1009        let key_len = key.len();
1010
1011        Box::pin(async move {
1012            let that = self.clone();
1013            that.validate_column(col).map_err(io::Error::other)?;
1014            let someval = self
1015                .conn_blocking(move |conn| Self::has_value(conn, that.column_table(col), &key_text))
1016                .map_err(io::Error::other)?;
1017
1018            self.stats_read(1, key_len);
1019
1020            Ok(someval)
1021        })
1022    }
1023
1024    /// Check for the existence of a value by prefix.
1025    fn has_prefix<'a>(
1026        &'a self,
1027        col: u32,
1028        prefix: &'a [u8],
1029    ) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + 'a>> {
1030        let prefix_len = prefix.len();
1031        let prefix_text = like_key_to_text(prefix) + "%";
1032
1033        Box::pin(async move {
1034            let that = self.clone();
1035            that.validate_column(col).map_err(io::Error::other)?;
1036            let someval = self
1037                .conn_blocking(move |conn| {
1038                    Self::has_value_like(conn, that.column_table(col), &prefix_text)
1039                })
1040                .map_err(io::Error::other)?;
1041
1042            self.stats_read(1, prefix_len);
1043
1044            Ok(someval)
1045        })
1046    }
1047
1048    /// Get the first value matching the given prefix.
1049    fn first_with_prefix<'a>(
1050        &'a self,
1051        col: u32,
1052        prefix: &'a [u8],
1053    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBKeyValue>>> + Send + 'a>> {
1054        let prefix_len = prefix.len();
1055        let like = like_key_to_text(prefix) + "%";
1056
1057        Box::pin(async move {
1058            let that = self.clone();
1059            that.validate_column(col).map_err(io::Error::other)?;
1060            let someval = self
1061                .conn_blocking(move |conn| {
1062                    Self::load_first_value_blob_like(conn, that.column_table(col), &like)
1063                })
1064                .map_err(io::Error::other)?;
1065
1066            self.stats_read(1, prefix_len);
1067
1068            match someval {
1069                Some((kt, val)) => match text_to_key(&kt) {
1070                    Err(e) => {
1071                        return Err(io::Error::other(format!(
1072                            "SQLite row get column 0 text convert error: {:?}",
1073                            e
1074                        )))
1075                    }
1076                    Ok(k) => Ok(Some((k, val))),
1077                },
1078                None => Ok(None),
1079            }
1080        })
1081    }
1082
1083    /// Vacuum database
1084    fn cleanup(&self) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + '_>> {
1085        Box::pin(async { self.vacuum().await.map_err(io::Error::other) })
1086    }
1087}
1088
#[cfg(test)]
mod tests {

    use super::*;
    use keyvaluedb_shared_tests as st;
    use tempfile::Builder as TempfileBuilder;
    use tempfile::TempDir;

    /// Reserve a fresh temporary directory and return it along with the path
    /// of the database file to create inside it.
    ///
    /// When the `TempDir` is dropped at the end of the test, it deletes the
    /// database and any companion files SQLite created next to it. The old
    /// approach created a `NamedTempFile`, kept only its path and dropped the
    /// handle immediately, so the database SQLite later created at that path
    /// was never removed.
    ///
    /// Bind the `TempDir` to a named variable such as `_dir`. Binding it to
    /// `_` would drop it, and the directory, at once.
    fn temp_db_path() -> io::Result<(TempDir, PathBuf)> {
        let dir = TempfileBuilder::new()
            .prefix("keyvaluedb_sqlite")
            .tempdir()?;
        let path = dir.path().join("db.sqlite");
        Ok((dir, path))
    }

    /// Open a database with `config` at a fresh temporary path.
    fn open_temp(config: DatabaseConfig) -> io::Result<(TempDir, Database)> {
        let (dir, path) = temp_db_path()?;
        let db = Database::open(path, config)?;
        Ok((dir, db))
    }

    fn create(columns: u32) -> io::Result<(TempDir, Database)> {
        open_temp(DatabaseConfig::new().with_columns(columns))
    }

    fn create_vacuum_mode(
        columns: u32,
        vacuum_mode: VacuumMode,
    ) -> io::Result<(TempDir, Database)> {
        open_temp(
            DatabaseConfig::new()
                .with_columns(columns)
                .with_vacuum_mode(vacuum_mode),
        )
    }

    #[tokio::test]
    async fn get_fails_with_non_existing_column() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_get_fails_with_non_existing_column(db).await
    }

    #[tokio::test]
    async fn num_keys() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_num_keys(db).await
    }

    #[tokio::test]
    async fn put_and_get() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_put_and_get(db).await
    }

    #[tokio::test]
    async fn delete_and_get() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_delete_and_get(db).await
    }

    #[tokio::test]
    async fn delete_and_get_single() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_delete_and_get_single(db).await
    }

    #[tokio::test]
    async fn delete_prefix() -> io::Result<()> {
        let (_dir, db) = create(st::DELETE_PREFIX_NUM_COLUMNS)?;
        st::test_delete_prefix(db).await
    }

    #[tokio::test]
    async fn iter() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter(db).await
    }

    #[tokio::test]
    async fn iter_keys() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter_keys(db).await
    }

    #[tokio::test]
    async fn iter_with_prefix() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter_with_prefix(db).await
    }

    #[tokio::test]
    async fn complex() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_complex(db).await
    }

    #[tokio::test]
    async fn cleanup() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_cleanup(db).await?;

        // A fresh database per vacuum mode.
        for mode in [VacuumMode::None, VacuumMode::Incremental, VacuumMode::Full] {
            let (_dir, db) = create_vacuum_mode(1, mode)?;
            st::test_cleanup(db).await?;
        }

        // One database file reopened repeatedly, switching vacuum mode each time.
        let (_dir, path) = temp_db_path()?;
        let reopen_modes = [
            VacuumMode::None,
            VacuumMode::Incremental,
            VacuumMode::Full,
            VacuumMode::None,
        ];
        for mode in reopen_modes {
            let config = DatabaseConfig::new().with_vacuum_mode(mode);
            let db = Database::open(&path, config)?;
            st::test_cleanup(db).await?;
        }

        Ok(())
    }

    #[tokio::test]
    async fn stats() -> io::Result<()> {
        let (_dir, db) = create(st::IO_STATS_NUM_COLUMNS)?;
        st::test_io_stats(db).await
    }

    #[tokio::test]
    #[should_panic]
    async fn db_config_with_zero_columns() {
        let _cfg = DatabaseConfig::new().with_columns(0);
    }

    #[tokio::test]
    #[should_panic]
    async fn open_db_with_zero_columns() {
        // NOTE(review): `with_columns(0)` already panics on its
        // `columns > 0` assert, so `Database::open` is never reached. As
        // written this duplicates `db_config_with_zero_columns`.
        let cfg = DatabaseConfig::new().with_columns(0);
        let _db = Database::open("", cfg);
    }

    #[tokio::test]
    async fn add_columns() {
        let config_1 = DatabaseConfig::default();
        let config_5 = DatabaseConfig::new().with_columns(5);

        let (_dir, path) = temp_db_path().unwrap();

        // open 1, add 4.
        {
            let db = Database::open(&path, config_1).unwrap();
            assert_eq!(db.num_columns().unwrap(), 1);

            for i in 2..=5 {
                db.add_column().unwrap();
                assert_eq!(db.num_columns().unwrap(), i);
            }
        }

        // reopen as 5.
        {
            let db = Database::open(&path, config_5).unwrap();
            assert_eq!(db.num_columns().unwrap(), 5);
        }
    }

    #[tokio::test]
    async fn remove_columns() {
        let config_1 = DatabaseConfig::default();
        let config_5 = DatabaseConfig::new().with_columns(5);

        let (_dir, path) = temp_db_path().unwrap();

        // open 5, remove 4.
        {
            let db = Database::open(&path, config_5).expect("open with 5 columns");
            assert_eq!(db.num_columns().unwrap(), 5);

            for i in (1..5).rev() {
                db.remove_last_column().unwrap();
                assert_eq!(db.num_columns().unwrap(), i);
            }
        }

        // reopen as 1.
        {
            let db = Database::open(&path, config_1).unwrap();
            assert_eq!(db.num_columns().unwrap(), 1);
        }
    }

    #[tokio::test]
    async fn test_num_keys() {
        let (_dir, db) = create(1).unwrap();

        assert_eq!(
            db.num_keys(0).await.unwrap(),
            0,
            "database is empty after creation"
        );
        let key1 = b"beef";
        let mut batch = db.transaction();
        batch.put(0, key1, key1);
        db.write(batch).await.unwrap();
        assert_eq!(
            db.num_keys(0).await.unwrap(),
            1,
            "adding a key increases the count"
        );
    }
}