keyvaluedb_sqlite/
lib.rs

1#![deny(clippy::all)]
2
3mod tools;
4
5use keyvaluedb::{
6    DBKeyRef, DBKeyValueRef, DBOp, DBTransaction, DBTransactionError, DBValue, IoStats,
7    IoStatsKind, KeyValueDB,
8};
9use parking_lot::Mutex;
10pub use rusqlite::OpenFlags;
11use rusqlite::*;
12use std::sync::Arc;
13use std::{
14    error,
15    future::Future,
16    io,
17    path::{Path, PathBuf},
18    pin::Pin,
19    str::FromStr,
20};
21use tools::*;
22
/// Convert any error that can be boxed into an `io::Error` of kind
/// [`io::ErrorKind::Other`], keeping the original error as the source.
///
/// Used with `map_err` to funnel rusqlite and string errors into the
/// `io::Result` values this crate returns.
fn other_io_err<E>(e: E) -> io::Error
where
    E: Into<Box<dyn error::Error + Send + Sync>>,
{
    let boxed: Box<dyn error::Error + Send + Sync> = e.into();
    io::Error::new(io::ErrorKind::Other, boxed)
}
29
/// Error adapter for calls that are expected never to fail.
///
/// It has the same signature as `other_io_err`, so it can be swapped into a
/// `map_err(...)` call. Instead of converting the error, it panics with the
/// error's `Debug` representation.
#[allow(dead_code)] // may have no callers; kept as a drop-in alternative to `other_io_err`
fn panic_err<E>(e: E) -> io::Error
where
    E: Into<Box<dyn error::Error + Send + Sync>>,
{
    let boxed: Box<dyn error::Error + Send + Sync> = e.into();
    panic!("should not have failed: {:?}", boxed)
}
37
38///////////////////////////////////////////////////////////////////////////////
39
40/// Database configuration
41#[derive(Clone)]
42pub struct DatabaseConfig {
43    /// Set number of columns.
44    /// The number of columns must not be zero.
45    pub columns: u32,
46    /// Set flags used to open the database
47    pub flags: OpenFlags,
48}
49
50impl DatabaseConfig {
51    /// Create new `DatabaseConfig` with default parameters and specified set of columns.
52    /// The number of `columns` must not be zero.
53    pub fn with_columns(columns: u32) -> Self {
54        assert!(columns > 0, "the number of columns must not be zero");
55        Self {
56            columns,
57            ..Default::default()
58        }
59    }
60    /// Create new in-memory database `DatabaseConfig` with default parameters and specified set of columns.
61    /// The number of `columns` must not be zero.
62    pub fn with_columns_in_memory(columns: u32) -> Self {
63        assert!(columns > 0, "the number of columns must not be zero");
64        Self {
65            columns,
66            flags: OpenFlags::SQLITE_OPEN_READ_WRITE
67                | OpenFlags::SQLITE_OPEN_CREATE
68                | OpenFlags::SQLITE_OPEN_NO_MUTEX
69                | OpenFlags::SQLITE_OPEN_MEMORY,
70            // ..Default::default()
71        }
72    }
73    /// Create new `DatabaseConfig` with default parameters and specified set of columns.
74    /// Replaces the flags as well.
75    /// The number of `columns` must not be zero.
76    pub fn with_columns_and_flags(columns: u32, flags: OpenFlags) -> Self {
77        assert!(columns > 0, "the number of columns must not be zero");
78        Self {
79            columns,
80            flags,
81            // ..Default::default()
82        }
83    }
84}
85
86impl Default for DatabaseConfig {
87    fn default() -> DatabaseConfig {
88        DatabaseConfig {
89            columns: 1,
90            flags: OpenFlags::SQLITE_OPEN_READ_WRITE
91                | OpenFlags::SQLITE_OPEN_CREATE
92                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
93        }
94    }
95}
96
97///////////////////////////////////////////////////////////////////////////////
98
99/// An sqlite key-value database fulfilling the `KeyValueDB` trait
100pub struct DatabaseInner {
101    path: PathBuf,
102    config: DatabaseConfig,
103    connection: rusqlite::Connection,
104    overall_stats: IoStats,
105    current_stats: IoStats,
106}
107
108#[derive(Clone)]
109pub struct Database {
110    inner: Arc<Mutex<DatabaseInner>>,
111}
112
113impl Database {
114    ////////////////////////////////////////////////////////////////
115    // Initialization
116
117    pub fn open<P: AsRef<Path>>(path: P, config: DatabaseConfig) -> io::Result<Self> {
118        assert_ne!(config.columns, 0, "number of columns must be >= 1");
119
120        let flags = config.flags;
121
122        let out = Self {
123            inner: Arc::new(Mutex::new(DatabaseInner {
124                path: PathBuf::from(path.as_ref()),
125                config,
126                connection: rusqlite::Connection::open_with_flags(path, flags)
127                    .map_err(other_io_err)?,
128                overall_stats: IoStats::empty(),
129                current_stats: IoStats::empty(),
130            })),
131        };
132
133        {
134            let lock = out.inner.lock();
135
136            // Don't rely on STATEMENT_CACHE_DEFAULT_CAPACITY in rusqlite, set it explicitly
137            lock.connection.set_prepared_statement_cache_capacity(16);
138
139            lock.connection
140                .execute("PRAGMA case_sensitive_like=ON;", [])
141                .map_err(other_io_err)?;
142        }
143
144        out.open_resize_columns()?;
145
146        Ok(out)
147    }
148    pub fn open_with_columns<P: AsRef<Path>>(path: P, columns: u32) -> io::Result<Self> {
149        assert_ne!(columns, 0, "number of columns must be >= 1");
150        Self::open(path, DatabaseConfig::with_columns(columns))
151    }
152    pub fn open_in_memory_with_columns<P: AsRef<Path>>(path: P, columns: u32) -> io::Result<Self> {
153        assert_ne!(columns, 0, "number of columns must be >= 1");
154        Self::open(path, DatabaseConfig::with_columns_in_memory(columns))
155    }
156
157    pub fn path(&self) -> PathBuf {
158        self.inner.lock().path.clone()
159    }
160
161    pub fn config(&self) -> DatabaseConfig {
162        self.inner.lock().config.clone()
163    }
164
165    ////////////////////////////////////////////////////////////////
166    // Low level operations
167
168    /// Remove the last column family in the database. The deletion is definitive.
169    pub fn remove_last_column(&self) -> io::Result<()> {
170        self.with_inner(|inner| {
171            let columns = Self::get_unique_value(&inner.connection, "control", "columns", 0u32)?;
172            if columns == 0 {
173                return Err(other_io_err("no columns exist"));
174            }
175            Self::set_unique_value(&inner.connection, "control", "columns", columns - 1)?;
176
177            inner
178                .connection
179                .execute(
180                    &format!("DROP TABLE {}", get_column_table_name(columns - 1)),
181                    [],
182                )
183                .map_err(other_io_err)?;
184            Ok(())
185        })
186    }
187
188    /// Add a new column family to the DB.
189    pub fn add_column(&self) -> io::Result<()> {
190        self.with_inner(|inner| {
191            let columns = Self::get_unique_value(&inner.connection, "control", "columns", 0u32)?;
192            Self::set_unique_value(&inner.connection, "control", "columns", columns + 1)?;
193            Self::create_column_table(&inner.connection, columns)
194        })
195    }
196    /// Helper to create new transaction for this database.
197    pub fn transaction(&self) -> DBTransaction {
198        DBTransaction::new()
199    }
200
201    ////////////////////////////////////////////////////////////////
202    // Internal helpers
203
204    fn with_inner<R, F: FnOnce(&DatabaseInner) -> R>(&self, f: F) -> R {
205        let inner = self.inner.lock();
206        f(&inner)
207    }
208
209    fn with_inner_mut<R, F: FnOnce(&mut DatabaseInner) -> R>(&self, f: F) -> R {
210        let mut inner = self.inner.lock();
211        f(&mut inner)
212    }
213
214    fn validate_column(col: u32, max: u32) -> io::Result<()> {
215        if col >= max {
216            return Err(io::Error::from(io::ErrorKind::NotFound));
217        }
218        Ok(())
219    }
220
221    fn create_column_table(conn: &rusqlite::Connection, column: u32) -> io::Result<()> {
222        conn.execute(&format!("CREATE TABLE IF NOT EXISTS {} (id INTEGER PRIMARY KEY AUTOINCREMENT, [key] TEXT UNIQUE, value BLOB)", get_column_table_name(column)), []).map_err(other_io_err).map(drop)
223    }
224
225    fn get_unique_value<T, K, V>(
226        conn: &rusqlite::Connection,
227        table: T,
228        key: K,
229        default: V,
230    ) -> io::Result<V>
231    where
232        T: AsRef<str>,
233        K: AsRef<str>,
234        V: FromStr,
235    {
236        let sql = format!("SELECT value FROM {} WHERE [key] = ?", table.as_ref());
237        let mut stmt = conn.prepare_cached(&sql).map_err(other_io_err)?;
238
239        if let Ok(found) = stmt.query_row([key.as_ref()], |row| -> Result<String> { row.get(0) }) {
240            if let Ok(v) = V::from_str(&found) {
241                return Ok(v);
242            }
243        }
244        Ok(default)
245    }
246    /*
247    fn get_or_set_default_unique_value<T, K, V>(
248        conn: &rusqlite::Connection,
249        table: T,
250        key: K,
251        default: V,
252    ) -> io::Result<V>
253    where
254        T: AsRef<str>,
255        K: AsRef<str>,
256        V: FromStr + ToString,
257    {
258        if let Ok(found) = conn.query_row(
259            &format!("SELECT value FROM {} WHERE [key] = ?", table.as_ref()),
260            [key.as_ref()],
261            |row| -> Result<String> { row.get(0) },
262        ) {
263            if let Ok(v) = V::from_str(&found) {
264                return Ok(v);
265            }
266        }
267        let out = default.to_string();
268        Self::set_unique_value(conn, table, key, out)?;
269        return Ok(default);
270    }
271    */
272    fn set_unique_value<T, K, V>(
273        conn: &rusqlite::Connection,
274        table: T,
275        key: K,
276        value: V,
277    ) -> io::Result<()>
278    where
279        T: AsRef<str>,
280        K: AsRef<str>,
281        V: ToString,
282    {
283        let sql = format!(
284            "INSERT OR REPLACE INTO {} ([key], value) VALUES(?, ?)",
285            table.as_ref()
286        );
287        let mut stmt = conn.prepare_cached(&sql).map_err(other_io_err)?;
288
289        let changed = stmt
290            .execute([key.as_ref(), value.to_string().as_str()])
291            .map_err(other_io_err)?;
292        assert!(
293            changed <= 1,
294            "multiple changes to unique key should not occur"
295        );
296        if changed == 0 {
297            return Err(other_io_err("failed to set unique value"));
298        }
299        Ok(())
300    }
301    fn load_unique_value_blob<T, K>(
302        conn: &rusqlite::Connection,
303        table: T,
304        key: K,
305    ) -> io::Result<Option<Vec<u8>>>
306    where
307        T: AsRef<str>,
308        K: AsRef<str>,
309    {
310        let sql = format!("SELECT value FROM {} WHERE [key] = ?", table.as_ref());
311        let mut stmt = conn.prepare_cached(&sql).map_err(other_io_err)?;
312
313        stmt.query_row([key.as_ref()], |row| -> Result<Vec<u8>> { row.get(0) })
314            .optional()
315            .map_err(other_io_err)
316    }
317    // fn load_unique_value_blob_like<T, K>(
318    // 	conn: &rusqlite::Connection,
319    // 	table: T,
320    // 	like: K,
321    // ) -> io::Result<Option<Vec<u8>>>
322    // where
323    // 	T: AsRef<str>,
324    // 	K: AsRef<str>,
325    // {
326    // 	conn.query_row(
327    // 		&format!(
328    // 			"SELECT value FROM {} WHERE [key] LIKE ? ESCAPE '\\'",
329    // 			table.as_ref()
330    // 		),
331    // 		[like.as_ref()],
332    // 		|row| -> Result<Vec<u8>> { row.get(0) },
333    // 	)
334    // 	.optional()
335    // 	.map_err(other_io_err)
336    // }
337    fn store_unique_value_blob<T, K>(
338        conn: &rusqlite::Connection,
339        table: T,
340        key: K,
341        value: &[u8],
342    ) -> io::Result<()>
343    where
344        T: AsRef<str>,
345        K: AsRef<str>,
346    {
347        let sql = format!(
348            "INSERT OR REPLACE INTO {} ([key], value) VALUES(?, ?)",
349            table.as_ref()
350        );
351        let mut stmt = conn.prepare_cached(&sql).map_err(other_io_err)?;
352
353        let changed = stmt
354            .execute(params![key.as_ref(), value])
355            .map_err(other_io_err)?;
356        assert!(
357            changed <= 1,
358            "multiple changes to unique key should not occur"
359        );
360        if changed == 0 {
361            return Err(other_io_err("failed to set unique value"));
362        }
363        Ok(())
364    }
365    fn remove_unique_value_blob<T, K>(
366        conn: &rusqlite::Connection,
367        table: T,
368        key: K,
369    ) -> io::Result<()>
370    where
371        T: AsRef<str>,
372        K: AsRef<str>,
373    {
374        let sql = format!("DELETE FROM {} WHERE [key] = ?", table.as_ref());
375        let mut stmt = conn.prepare_cached(&sql).map_err(other_io_err)?;
376
377        let changed = stmt.execute([key.as_ref()]).map_err(other_io_err)?;
378        assert!(
379            changed <= 1,
380            "multiple deletions of unique key should not occur"
381        );
382        if changed == 0 {
383            return Err(other_io_err("failed to remove unique value"));
384        }
385        Ok(())
386    }
387    fn remove_unique_value_blob_like<T, K>(
388        conn: &rusqlite::Connection,
389        table: T,
390        key: K,
391    ) -> io::Result<usize>
392    where
393        T: AsRef<str>,
394        K: AsRef<str>,
395    {
396        let sql = format!(
397            "DELETE FROM {} WHERE [key] LIKE ? ESCAPE '\\'",
398            table.as_ref()
399        );
400        let mut stmt = conn.prepare_cached(&sql).map_err(other_io_err)?;
401
402        let changed = stmt.execute([key.as_ref()]).map_err(other_io_err)?;
403        Ok(changed)
404    }
405    fn open_resize_columns(&self) -> io::Result<()> {
406        self.with_inner(|inner| {
407			// First see if we have a control table with the number of columns
408			inner.connection.execute("CREATE TABLE IF NOT EXISTS control (id INTEGER PRIMARY KEY AUTOINCREMENT, [key] TEXT UNIQUE, value TEXT)", []).map_err(other_io_err)?;
409			// Get column count
410			let on_disk_columns =
411				Self::get_unique_value(&inner.connection, "control", "columns", 0u32)?;
412
413			// If desired column count is less than or equal to current column count, then allow it, but restrict access to columns
414			if inner.config.columns <= on_disk_columns {
415				return Ok(());
416			}
417
418			// Otherwise resize and add other columns
419			for cn in on_disk_columns..inner.config.columns {
420				// Create the column table if we don't have it
421				Self::create_column_table(&inner.connection, cn)?;
422			}
423			Self::set_unique_value(
424				&inner.connection,
425				"control",
426				"columns",
427				inner.config.columns,
428			)?;
429
430			Ok(())
431		})
432    }
433
434    fn stats_read(inner: &mut DatabaseInner, count: usize, bytes: usize) {
435        inner.current_stats.reads += count as u64;
436        inner.overall_stats.reads += count as u64;
437        inner.current_stats.bytes_read += bytes as u64;
438        inner.overall_stats.bytes_read += bytes as u64;
439    }
440    fn stats_write(inner: &mut DatabaseInner, count: usize, bytes: usize) {
441        inner.current_stats.writes += count as u64;
442        inner.overall_stats.writes += count as u64;
443        inner.current_stats.bytes_written += bytes as u64;
444        inner.overall_stats.bytes_written += bytes as u64;
445    }
446    fn stats_transaction(inner: &mut DatabaseInner, count: usize) {
447        inner.current_stats.transactions += count as u64;
448        inner.overall_stats.transactions += count as u64;
449    }
450}
451
452impl KeyValueDB for Database {
453    fn get<'a>(
454        &self,
455        col: u32,
456        key: &'a [u8],
457    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBValue>>> + Send + 'a>> {
458        let this = self.clone();
459        Box::pin(async move {
460            this.with_inner_mut(|inner| {
461                Self::validate_column(col, inner.config.columns)?;
462                match Self::load_unique_value_blob(
463                    &inner.connection,
464                    get_column_table_name(col),
465                    key_to_text(key),
466                ) {
467                    Ok(someval) => {
468                        match &someval {
469                            Some(val) => Self::stats_read(inner, 1, key.len() + val.len()),
470                            None => Self::stats_read(inner, 1, key.len()),
471                        };
472                        Ok(someval)
473                    }
474                    Err(e) => Err(e),
475                }
476            })
477        })
478    }
479
480    /// Remove a value by key, returning the old value
481    fn delete<'a>(
482        &self,
483        col: u32,
484        key: &'a [u8],
485    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBValue>>> + Send + 'a>> {
486        let this = self.clone();
487        Box::pin(async move {
488            this.with_inner_mut(|inner| {
489                Self::validate_column(col, inner.config.columns)?;
490                let someval = Self::load_unique_value_blob(
491                    &inner.connection,
492                    get_column_table_name(col),
493                    key_to_text(key),
494                )?;
495
496                match &someval {
497                    Some(val) => {
498                        Self::stats_read(inner, 1, key.len() + val.len());
499
500                        Self::remove_unique_value_blob(
501                            &inner.connection,
502                            get_column_table_name(col),
503                            key_to_text(key),
504                        )?;
505
506                        Self::stats_write(inner, 1, 0);
507                    }
508                    None => Self::stats_read(inner, 1, key.len()),
509                };
510                Ok(someval)
511            })
512        })
513    }
514
515    fn write(
516        &self,
517        transaction: DBTransaction,
518    ) -> Pin<Box<dyn Future<Output = Result<(), DBTransactionError>> + Send + 'static>> {
519        let this = self.clone();
520        Box::pin(async move {
521            this.with_inner_mut(|inner| {
522                Self::stats_transaction(inner, 1);
523
524                let mut sw = 0usize;
525                let mut sbw = 0usize;
526                let out = {
527                    let tx = inner.connection.transaction().map_err(other_io_err)?;
528                    for op in &transaction.ops {
529                        match op {
530                            DBOp::Insert { col, key, value } => {
531                                Self::validate_column(*col, inner.config.columns)?;
532                                Self::store_unique_value_blob(
533                                    &tx,
534                                    get_column_table_name(*col),
535                                    key_to_text(key),
536                                    value,
537                                )?;
538                                sw += 1;
539                                sbw += key.len() + value.len();
540                            }
541                            DBOp::Delete { col, key } => {
542                                Self::validate_column(*col, inner.config.columns)?;
543                                Self::remove_unique_value_blob(
544                                    &tx,
545                                    get_column_table_name(*col),
546                                    key_to_text(key),
547                                )?;
548                                sw += 1;
549                            }
550                            DBOp::DeletePrefix { col, prefix } => {
551                                Self::validate_column(*col, inner.config.columns)?;
552                                Self::remove_unique_value_blob_like(
553                                    &tx,
554                                    get_column_table_name(*col),
555                                    like_key_to_text(prefix) + "%",
556                                )?;
557                                sw += 1;
558                            }
559                        }
560                    }
561                    tx.commit().map_err(other_io_err)
562                };
563                Self::stats_write(inner, sw, sbw);
564                out
565            })
566            .map_err(|error| DBTransactionError { error, transaction })
567        })
568    }
569
570    fn iter<'a, T: 'a, F: FnMut(DBKeyValueRef) -> io::Result<Option<T>> + Send + Sync + 'a>(
571        &self,
572        col: u32,
573        prefix: Option<&'a [u8]>,
574        mut f: F,
575    ) -> Pin<Box<dyn Future<Output = io::Result<Option<T>>> + Send + 'a>> {
576        let this = self.clone();
577        Box::pin(async move {
578            this.with_inner_mut(|inner| {
579                if col >= inner.config.columns {
580                    return Err(io::Error::from(io::ErrorKind::NotFound));
581                }
582
583                let mut sw = 0usize;
584                let mut sbw = 0usize;
585                let mut out = Ok(None);
586                {
587                    let mut stmt;
588                    let mut rows;
589                    if let Some(p) = prefix {
590                        let sql = format!(
591                            "SELECT key, value FROM {} WHERE [key] LIKE ? ESCAPE '\\'",
592                            get_column_table_name(col)
593                        );
594                        stmt = match inner.connection.prepare_cached(&sql) {
595                            Err(e) => {
596                                return Err(other_io_err(format!("SQLite prepare error: {:?}", e)));
597                            }
598                            Ok(v) => v,
599                        };
600                        rows = match stmt.query([like_key_to_text(p) + "%"]) {
601                            Err(e) => {
602                                return Err(other_io_err(format!("SQLite rows error: {:?}", e)));
603                            }
604                            Ok(v) => v,
605                        };
606                    } else {
607                        let sql = format!("SELECT key, value FROM {}", get_column_table_name(col));
608                        stmt = match inner.connection.prepare_cached(&sql) {
609                            Err(e) => {
610                                return Err(other_io_err(format!("SQLite prepare error: {:?}", e)));
611                            }
612                            Ok(v) => v,
613                        };
614                        rows = match stmt.query([]) {
615                            Err(e) => {
616                                return Err(other_io_err(format!("SQLite query error: {:?}", e)));
617                            }
618                            Ok(v) => v,
619                        };
620                    };
621
622                    loop {
623                        match rows.next() {
624                            // Iterated value
625                            Ok(Some(row)) => {
626                                let kt: String = match row.get(0) {
627                                    Err(e) => {
628                                        out = Err(other_io_err(format!(
629                                            "SQLite row get column 0 error: {:?}",
630                                            e
631                                        )));
632                                        break;
633                                    }
634                                    Ok(v) => v,
635                                };
636                                let v: Vec<u8> = match row.get(1) {
637                                    Err(e) => {
638                                        out = Err(other_io_err(format!(
639                                            "SQLite row get column 1 error: {:?}",
640                                            e
641                                        )));
642                                        break;
643                                    }
644                                    Ok(v) => v,
645                                };
646                                let k: Vec<u8> = match text_to_key(&kt) {
647                                    Err(e) => {
648                                        out = Err(other_io_err(format!(
649                                            "SQLite row get column 0 text convert error: {:?}",
650                                            e
651                                        )));
652                                        break;
653                                    }
654                                    Ok(v) => v,
655                                };
656
657                                sw += 1;
658                                sbw += k.len() + v.len();
659
660                                match f((&k, &v)) {
661                                    Ok(None) => (),
662                                    Ok(Some(v)) => {
663                                        // Callback early termination
664                                        out = Ok(Some(v));
665                                        break;
666                                    }
667                                    Err(e) => {
668                                        // Callback error termination
669                                        out = Err(e);
670                                        break;
671                                    }
672                                }
673                            }
674                            // Natural iterator termination
675                            Ok(None) => {
676                                break;
677                            }
678                            // Error iterator termination
679                            Err(e) => {
680                                out = Err(other_io_err(format!("SQLite rows error: {:?}", e)));
681                                break;
682                            }
683                        }
684                    }
685                }
686                Self::stats_read(inner, sw, sbw);
687                out
688            })
689        })
690    }
691
692    fn iter_keys<'a, T: 'a, F: FnMut(DBKeyRef) -> io::Result<Option<T>> + Send + Sync + 'a>(
693        &self,
694        col: u32,
695        prefix: Option<&'a [u8]>,
696        mut f: F,
697    ) -> Pin<Box<dyn Future<Output = io::Result<Option<T>>> + Send + 'a>> {
698        let this = self.clone();
699        Box::pin(async move {
700            this.with_inner_mut(|inner| {
701                if col >= inner.config.columns {
702                    return Err(io::Error::from(io::ErrorKind::NotFound));
703                }
704
705                let mut sw = 0usize;
706                let mut sbw = 0usize;
707                let mut out = Ok(None);
708                {
709                    let mut stmt;
710                    let mut rows;
711                    if let Some(p) = prefix {
712                        let sql = format!(
713                            "SELECT key FROM {} WHERE [key] LIKE ? ESCAPE '\\'",
714                            get_column_table_name(col)
715                        );
716                        stmt = match inner.connection.prepare_cached(&sql) {
717                            Err(e) => {
718                                return Err(other_io_err(format!("SQLite prepare error: {:?}", e)));
719                            }
720                            Ok(v) => v,
721                        };
722                        rows = match stmt.query([like_key_to_text(p) + "%"]) {
723                            Err(e) => {
724                                return Err(other_io_err(format!("SQLite rows error: {:?}", e)));
725                            }
726                            Ok(v) => v,
727                        };
728                    } else {
729                        let sql = format!("SELECT key FROM {}", get_column_table_name(col));
730                        stmt = match inner.connection.prepare_cached(&sql) {
731                            Err(e) => {
732                                return Err(other_io_err(format!("SQLite prepare error: {:?}", e)));
733                            }
734                            Ok(v) => v,
735                        };
736                        rows = match stmt.query([]) {
737                            Err(e) => {
738                                return Err(other_io_err(format!("SQLite query error: {:?}", e)));
739                            }
740                            Ok(v) => v,
741                        };
742                    };
743
744                    loop {
745                        match rows.next() {
746                            // Iterated value
747                            Ok(Some(row)) => {
748                                let kt: String = match row.get(0) {
749                                    Err(e) => {
750                                        out = Err(other_io_err(format!(
751                                            "SQLite row get column 0 error: {:?}",
752                                            e
753                                        )));
754                                        break;
755                                    }
756                                    Ok(v) => v,
757                                };
758                                let k: Vec<u8> = match text_to_key(&kt) {
759                                    Err(e) => {
760                                        out = Err(other_io_err(format!(
761                                            "SQLite row get column 0 text convert error: {:?}",
762                                            e
763                                        )));
764                                        break;
765                                    }
766                                    Ok(v) => v,
767                                };
768
769                                sw += 1;
770                                sbw += k.len();
771
772                                match f(&k) {
773                                    Ok(None) => (),
774                                    Ok(Some(v)) => {
775                                        // Callback early termination
776                                        out = Ok(Some(v));
777                                        break;
778                                    }
779                                    Err(e) => {
780                                        // Callback error termination
781                                        out = Err(e);
782                                        break;
783                                    }
784                                }
785                            }
786                            // Natural iterator termination
787                            Ok(None) => {
788                                break;
789                            }
790                            // Error iterator termination
791                            Err(e) => {
792                                out = Err(other_io_err(format!("SQLite rows error: {:?}", e)));
793                                break;
794                            }
795                        }
796                    }
797                }
798                Self::stats_read(inner, sw, sbw);
799                out
800            })
801        })
802    }
803
804    fn io_stats(&self, kind: IoStatsKind) -> IoStats {
805        self.with_inner_mut(|inner| match kind {
806            IoStatsKind::Overall => {
807                let mut stats = inner.overall_stats.clone();
808                stats.span = std::time::SystemTime::now()
809                    .duration_since(stats.started)
810                    .unwrap_or_default();
811                stats
812            }
813            IoStatsKind::SincePrevious => {
814                let mut stats = inner.current_stats.clone();
815                stats.span = std::time::SystemTime::now()
816                    .duration_since(stats.started)
817                    .unwrap_or_default();
818                inner.current_stats = IoStats::empty();
819                stats
820            }
821        })
822    }
823
824    fn num_columns(&self) -> io::Result<u32> {
825        self.with_inner(|inner| {
826            Self::get_unique_value(&inner.connection, "control", "columns", 0u32)
827        })
828    }
829
830    fn num_keys(&self, col: u32) -> Pin<Box<dyn Future<Output = io::Result<u64>> + Send>> {
831        let this = self.clone();
832        Box::pin(async move {
833            this.with_inner(|inner| {
834                if let Ok(found) = inner.connection.query_row(
835                    &format!("SELECT Count(*) FROM {}", get_column_table_name(col)),
836                    [],
837                    |row| -> Result<u64> { row.get(0) },
838                ) {
839                    return Ok(found);
840                }
841                Err(io::Error::from(io::ErrorKind::NotFound))
842            })
843        })
844    }
845
846    fn restore(&self, _new_db: &str) -> io::Result<()> {
847        Err(other_io_err("Attempted to restore sqlite database"))
848    }
849}
850
#[cfg(test)]
mod tests {

    use super::*;
    use keyvaluedb_shared_tests as st;
    use tempfile::{Builder as TempfileBuilder, TempDir};

    /// Reserve a database path inside a fresh, uniquely named temporary directory.
    ///
    /// The returned [`TempDir`] guard owns the directory. Keep it alive for as long as the
    /// database is in use. Dropping it removes the directory together with the database
    /// file and anything else created inside it, so test runs leave nothing behind.
    /// (Taking only `.path()` of a temporary `NamedTempFile` would delete the file at once
    /// and leak whatever the database later re-creates at that path.)
    fn temp_db_path(prefix: &str) -> io::Result<(TempDir, PathBuf)> {
        let dir = TempfileBuilder::new().prefix(prefix).tempdir()?;
        let path = dir.path().join("db.sqlite");
        Ok((dir, path))
    }

    /// Open a database with `columns` columns at a fresh temporary location.
    ///
    /// Bind the returned guard to a named variable (e.g. `_dir`, not `_`) so the backing
    /// directory survives until the end of the test.
    fn create(columns: u32) -> io::Result<(TempDir, Database)> {
        let (dir, path) = temp_db_path("")?;
        let config = DatabaseConfig::with_columns(columns);
        let db = Database::open(path, config)?;
        Ok((dir, db))
    }

    #[tokio::test]
    async fn get_fails_with_non_existing_column() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_get_fails_with_non_existing_column(db).await
    }

    #[tokio::test]
    async fn num_keys() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_num_keys(db).await
    }

    #[tokio::test]
    async fn put_and_get() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_put_and_get(db).await
    }

    #[tokio::test]
    async fn delete_and_get() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_delete_and_get(db).await
    }

    #[tokio::test]
    async fn delete_and_get_single() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_delete_and_get_single(db).await
    }

    #[tokio::test]
    async fn delete_prefix() -> io::Result<()> {
        let (_dir, db) = create(st::DELETE_PREFIX_NUM_COLUMNS)?;
        st::test_delete_prefix(db).await
    }

    #[tokio::test]
    async fn iter() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter(db).await
    }

    #[tokio::test]
    async fn iter_keys() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter_keys(db).await
    }

    #[tokio::test]
    async fn iter_with_prefix() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter_with_prefix(db).await
    }

    #[tokio::test]
    async fn complex() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_complex(db).await
    }

    #[tokio::test]
    async fn stats() -> io::Result<()> {
        let (_dir, db) = create(st::IO_STATS_NUM_COLUMNS)?;
        st::test_io_stats(db).await
    }

    #[tokio::test]
    #[should_panic]
    async fn db_config_with_zero_columns() {
        let _cfg = DatabaseConfig::with_columns(0);
    }

    #[tokio::test]
    #[should_panic]
    async fn open_db_with_zero_columns() {
        let cfg = DatabaseConfig {
            columns: 0,
            ..Default::default()
        };
        let _db = Database::open("", cfg);
    }

    #[tokio::test]
    async fn add_columns() {
        let config_1 = DatabaseConfig::default();
        let config_5 = DatabaseConfig::with_columns(5);

        // The guard keeps the directory (and the database in it) alive across both opens.
        let (_dir, path) = temp_db_path("").unwrap();

        // open 1, add 4.
        {
            let db = Database::open(&path, config_1).unwrap();
            assert_eq!(db.num_columns().unwrap(), 1);

            for i in 2..=5 {
                db.add_column().unwrap();
                assert_eq!(db.num_columns().unwrap(), i);
            }
        }

        // reopen as 5.
        {
            let db = Database::open(&path, config_5).unwrap();
            assert_eq!(db.num_columns().unwrap(), 5);
        }
    }

    #[tokio::test]
    async fn remove_columns() {
        let config_1 = DatabaseConfig::default();
        let config_5 = DatabaseConfig::with_columns(5);

        // The guard keeps the directory (and the database in it) alive across both opens.
        let (_dir, path) = temp_db_path("drop_columns").unwrap();

        // open 5, remove 4.
        {
            let db = Database::open(&path, config_5).expect("open with 5 columns");
            assert_eq!(db.num_columns().unwrap(), 5);

            for i in (1..5).rev() {
                db.remove_last_column().unwrap();
                assert_eq!(db.num_columns().unwrap(), i);
            }
        }

        // reopen as 1.
        {
            let db = Database::open(&path, config_1).unwrap();
            assert_eq!(db.num_columns().unwrap(), 1);
        }
    }

    #[tokio::test]
    async fn test_num_keys() {
        let (_dir, path) = temp_db_path("").unwrap();
        let config = DatabaseConfig::with_columns(1);
        let db = Database::open(path, config).unwrap();

        assert_eq!(
            db.num_keys(0).await.unwrap(),
            0,
            "database is empty after creation"
        );
        let key1 = b"beef";
        let mut batch = db.transaction();
        batch.put(0, key1, key1);
        db.write(batch).await.unwrap();
        assert_eq!(
            db.num_keys(0).await.unwrap(),
            1,
            "adding a key increases the count"
        );
    }
}