keyvaluedb_sqlite/
lib.rs

1#![deny(clippy::all)]
2
3mod tools;
4
5use keyvaluedb::{
6    DBKeyRef, DBKeyValueRef, DBOp, DBTransaction, DBTransactionError, DBValue, IoStats,
7    IoStatsKind, KeyValueDB,
8};
9use parking_lot::Mutex;
10pub use rusqlite::OpenFlags;
11use rusqlite::*;
12use std::sync::Arc;
13use std::{
14    error,
15    future::Future,
16    io,
17    path::{Path, PathBuf},
18    pin::Pin,
19    str::FromStr,
20};
21use tools::*;
22
/// Wraps any boxable error into an `io::Error` of kind `Other`.
///
/// Used throughout this crate to surface SQLite and validation failures
/// through the `io::Result`-based `KeyValueDB` interface.
fn other_io_err<E>(e: E) -> io::Error
where
    E: Into<Box<dyn error::Error + Send + Sync>>,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, e)
}
29
/// Drop-in replacement for `other_io_err` that panics instead of returning.
///
/// It has the same signature so it can be swapped into a `map_err` call.
/// It never returns; the `io::Error` return type exists only to fit that slot.
// Not referenced by the visible code; presumably kept as a debugging aid
// (hence the lint allowance) — confirm before removing.
#[allow(dead_code)]
fn panic_err<E>(e: E) -> io::Error
where
    E: Into<Box<dyn error::Error + Send + Sync>>,
{
    let boxed: Box<dyn error::Error + Send + Sync> = e.into();
    panic!("should not have failed: {:?}", boxed)
}
37
38///////////////////////////////////////////////////////////////////////////////
39
40/// Database configuration
41#[derive(Clone)]
42pub struct DatabaseConfig {
43    /// Set number of columns.
44    /// The number of columns must not be zero.
45    pub columns: u32,
46    /// Set flags used to open the database
47    pub flags: OpenFlags,
48}
49
50impl DatabaseConfig {
51    /// Create new `DatabaseConfig` with default parameters and specified set of columns.
52    /// The number of `columns` must not be zero.
53    pub fn with_columns(columns: u32) -> Self {
54        assert!(columns > 0, "the number of columns must not be zero");
55        Self {
56            columns,
57            ..Default::default()
58        }
59    }
60    /// Create new in-memory database `DatabaseConfig` with default parameters and specified set of columns.
61    /// The number of `columns` must not be zero.
62    pub fn with_columns_in_memory(columns: u32) -> Self {
63        assert!(columns > 0, "the number of columns must not be zero");
64        Self {
65            columns,
66            flags: OpenFlags::SQLITE_OPEN_READ_WRITE
67                | OpenFlags::SQLITE_OPEN_CREATE
68                | OpenFlags::SQLITE_OPEN_NO_MUTEX
69                | OpenFlags::SQLITE_OPEN_MEMORY,
70            // ..Default::default()
71        }
72    }
73    /// Create new `DatabaseConfig` with default parameters and specified set of columns.
74    /// Replaces the flags as well.
75    /// The number of `columns` must not be zero.
76    pub fn with_columns_and_flags(columns: u32, flags: OpenFlags) -> Self {
77        assert!(columns > 0, "the number of columns must not be zero");
78        Self {
79            columns,
80            flags,
81            // ..Default::default()
82        }
83    }
84}
85
86impl Default for DatabaseConfig {
87    fn default() -> DatabaseConfig {
88        DatabaseConfig {
89            columns: 1,
90            flags: OpenFlags::SQLITE_OPEN_READ_WRITE
91                | OpenFlags::SQLITE_OPEN_CREATE
92                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
93        }
94    }
95}
96
97///////////////////////////////////////////////////////////////////////////////
98
99/// An sqlite key-value database fulfilling the `KeyValueDB` trait
100
101pub struct DatabaseInner {
102    path: PathBuf,
103    config: DatabaseConfig,
104    connection: rusqlite::Connection,
105    overall_stats: IoStats,
106    current_stats: IoStats,
107}
108
109#[derive(Clone)]
110pub struct Database {
111    inner: Arc<Mutex<DatabaseInner>>,
112}
113
114impl Database {
115    ////////////////////////////////////////////////////////////////
116    /// Initialization
117    pub fn open<P: AsRef<Path>>(path: P, config: DatabaseConfig) -> io::Result<Self> {
118        assert_ne!(config.columns, 0, "number of columns must be >= 1");
119
120        let flags = config.flags;
121
122        let out = Self {
123            inner: Arc::new(Mutex::new(DatabaseInner {
124                path: PathBuf::from(path.as_ref()),
125                config,
126                connection: rusqlite::Connection::open_with_flags(path, flags)
127                    .map_err(other_io_err)?,
128                overall_stats: IoStats::empty(),
129                current_stats: IoStats::empty(),
130            })),
131        };
132
133        out.inner
134            .lock()
135            .connection
136            .execute("PRAGMA case_sensitive_like=ON;", [])
137            .map_err(other_io_err)?;
138
139        out.open_resize_columns()?;
140
141        Ok(out)
142    }
143    pub fn open_with_columns<P: AsRef<Path>>(path: P, columns: u32) -> io::Result<Self> {
144        assert_ne!(columns, 0, "number of columns must be >= 1");
145        Self::open(path, DatabaseConfig::with_columns(columns))
146    }
147    pub fn open_in_memory_with_columns<P: AsRef<Path>>(path: P, columns: u32) -> io::Result<Self> {
148        assert_ne!(columns, 0, "number of columns must be >= 1");
149        Self::open(path, DatabaseConfig::with_columns_in_memory(columns))
150    }
151
152    pub fn path(&self) -> PathBuf {
153        self.inner.lock().path.clone()
154    }
155
156    pub fn config(&self) -> DatabaseConfig {
157        self.inner.lock().config.clone()
158    }
159
160    ////////////////////////////////////////////////////////////////
161    /// Low level operations
162
163    /// Remove the last column family in the database. The deletion is definitive.
164    pub fn remove_last_column(&self) -> io::Result<()> {
165        self.with_inner(|inner| {
166            let columns = Self::get_unique_value(&inner.connection, "control", "columns", 0u32)?;
167            if columns == 0 {
168                return Err(other_io_err("no columns exist"));
169            }
170            Self::set_unique_value(&inner.connection, "control", "columns", columns - 1)?;
171
172            inner
173                .connection
174                .execute(
175                    &format!("DROP TABLE {}", get_column_table_name(columns - 1)),
176                    [],
177                )
178                .map_err(other_io_err)?;
179            Ok(())
180        })
181    }
182
183    /// Add a new column family to the DB.
184    pub fn add_column(&self) -> io::Result<()> {
185        self.with_inner(|inner| {
186            let columns = Self::get_unique_value(&inner.connection, "control", "columns", 0u32)?;
187            Self::set_unique_value(&inner.connection, "control", "columns", columns + 1)?;
188            Self::create_column_table(&inner.connection, columns)
189        })
190    }
191    /// Helper to create new transaction for this database.
192    pub fn transaction(&self) -> DBTransaction {
193        DBTransaction::new()
194    }
195
196    ////////////////////////////////////////////////////////////////
197    /// Internal helpers
198
199    fn with_inner<R, F: FnOnce(&DatabaseInner) -> R>(&self, f: F) -> R {
200        let inner = self.inner.lock();
201        f(&inner)
202    }
203
204    fn with_inner_mut<R, F: FnOnce(&mut DatabaseInner) -> R>(&self, f: F) -> R {
205        let mut inner = self.inner.lock();
206        f(&mut inner)
207    }
208
209    fn validate_column(col: u32, max: u32) -> io::Result<()> {
210        if col >= max {
211            return Err(io::Error::from(io::ErrorKind::NotFound));
212        }
213        Ok(())
214    }
215
216    fn create_column_table(conn: &rusqlite::Connection, column: u32) -> io::Result<()> {
217        conn.execute(&format!("CREATE TABLE IF NOT EXISTS {} (id INTEGER PRIMARY KEY AUTOINCREMENT, [key] TEXT UNIQUE, value BLOB)", get_column_table_name(column)), []).map_err(other_io_err).map(drop)
218    }
219
220    fn get_unique_value<T, K, V>(
221        conn: &rusqlite::Connection,
222        table: T,
223        key: K,
224        default: V,
225    ) -> io::Result<V>
226    where
227        T: AsRef<str>,
228        K: AsRef<str>,
229        V: FromStr,
230    {
231        if let Ok(found) = conn.query_row(
232            &format!("SELECT value FROM {} WHERE [key] = ?", table.as_ref()),
233            [key.as_ref()],
234            |row| -> Result<String> { row.get(0) },
235        ) {
236            if let Ok(v) = V::from_str(&found) {
237                return Ok(v);
238            }
239        }
240        Ok(default)
241    }
242    /*
243    fn get_or_set_default_unique_value<T, K, V>(
244        conn: &rusqlite::Connection,
245        table: T,
246        key: K,
247        default: V,
248    ) -> io::Result<V>
249    where
250        T: AsRef<str>,
251        K: AsRef<str>,
252        V: FromStr + ToString,
253    {
254        if let Ok(found) = conn.query_row(
255            &format!("SELECT value FROM {} WHERE [key] = ?", table.as_ref()),
256            [key.as_ref()],
257            |row| -> Result<String> { row.get(0) },
258        ) {
259            if let Ok(v) = V::from_str(&found) {
260                return Ok(v);
261            }
262        }
263        let out = default.to_string();
264        Self::set_unique_value(conn, table, key, out)?;
265        return Ok(default);
266    }
267    */
268    fn set_unique_value<T, K, V>(
269        conn: &rusqlite::Connection,
270        table: T,
271        key: K,
272        value: V,
273    ) -> io::Result<()>
274    where
275        T: AsRef<str>,
276        K: AsRef<str>,
277        V: ToString,
278    {
279        let changed = conn
280            .execute(
281                &format!(
282                    "INSERT OR REPLACE INTO {} ([key], value) VALUES(?, ?)",
283                    table.as_ref()
284                ),
285                [key.as_ref(), value.to_string().as_str()],
286            )
287            .map_err(other_io_err)?;
288        assert!(
289            changed <= 1,
290            "multiple changes to unique key should not occur"
291        );
292        if changed == 0 {
293            return Err(other_io_err("failed to set unique value"));
294        }
295        Ok(())
296    }
297    fn load_unique_value_blob<T, K>(
298        conn: &rusqlite::Connection,
299        table: T,
300        key: K,
301    ) -> io::Result<Option<Vec<u8>>>
302    where
303        T: AsRef<str>,
304        K: AsRef<str>,
305    {
306        conn.query_row(
307            &format!("SELECT value FROM {} WHERE [key] = ?", table.as_ref()),
308            [key.as_ref()],
309            |row| -> Result<Vec<u8>> { row.get(0) },
310        )
311        .optional()
312        .map_err(other_io_err)
313    }
314    // fn load_unique_value_blob_like<T, K>(
315    // 	conn: &rusqlite::Connection,
316    // 	table: T,
317    // 	like: K,
318    // ) -> io::Result<Option<Vec<u8>>>
319    // where
320    // 	T: AsRef<str>,
321    // 	K: AsRef<str>,
322    // {
323    // 	conn.query_row(
324    // 		&format!(
325    // 			"SELECT value FROM {} WHERE [key] LIKE ? ESCAPE '\\'",
326    // 			table.as_ref()
327    // 		),
328    // 		[like.as_ref()],
329    // 		|row| -> Result<Vec<u8>> { row.get(0) },
330    // 	)
331    // 	.optional()
332    // 	.map_err(other_io_err)
333    // }
334    fn store_unique_value_blob<T, K>(
335        conn: &rusqlite::Connection,
336        table: T,
337        key: K,
338        value: &[u8],
339    ) -> io::Result<()>
340    where
341        T: AsRef<str>,
342        K: AsRef<str>,
343    {
344        let changed = conn
345            .execute(
346                &format!(
347                    "INSERT OR REPLACE INTO {} ([key], value) VALUES(?, ?)",
348                    table.as_ref()
349                ),
350                params![key.as_ref(), value],
351            )
352            .map_err(other_io_err)?;
353        assert!(
354            changed <= 1,
355            "multiple changes to unique key should not occur"
356        );
357        if changed == 0 {
358            return Err(other_io_err("failed to set unique value"));
359        }
360        Ok(())
361    }
362    fn remove_unique_value_blob<T, K>(
363        conn: &rusqlite::Connection,
364        table: T,
365        key: K,
366    ) -> io::Result<()>
367    where
368        T: AsRef<str>,
369        K: AsRef<str>,
370    {
371        let changed = conn
372            .execute(
373                &format!("DELETE FROM {} WHERE [key] = ?", table.as_ref()),
374                [key.as_ref()],
375            )
376            .map_err(other_io_err)?;
377        assert!(
378            changed <= 1,
379            "multiple deletions of unique key should not occur"
380        );
381        if changed == 0 {
382            return Err(other_io_err("failed to remove unique value"));
383        }
384        Ok(())
385    }
386    fn remove_unique_value_blob_like<T, K>(
387        conn: &rusqlite::Connection,
388        table: T,
389        key: K,
390    ) -> io::Result<usize>
391    where
392        T: AsRef<str>,
393        K: AsRef<str>,
394    {
395        let changed = conn
396            .execute(
397                &format!(
398                    "DELETE FROM {} WHERE [key] LIKE ? ESCAPE '\\'",
399                    table.as_ref(),
400                ),
401                [key.as_ref()],
402            )
403            .map_err(other_io_err)?;
404        Ok(changed)
405    }
406    fn open_resize_columns(&self) -> io::Result<()> {
407        self.with_inner(|inner| {
408			// First see if we have a control table with the number of columns
409			inner.connection.execute("CREATE TABLE IF NOT EXISTS control (id INTEGER PRIMARY KEY AUTOINCREMENT, [key] TEXT UNIQUE, value TEXT)", []).map_err(other_io_err)?;
410			// Get column count
411			let on_disk_columns =
412				Self::get_unique_value(&inner.connection, "control", "columns", 0u32)?;
413
414			// If desired column count is less than or equal to current column count, then allow it, but restrict access to columns
415			if inner.config.columns <= on_disk_columns {
416				return Ok(());
417			}
418
419			// Otherwise resize and add other columns
420			for cn in on_disk_columns..inner.config.columns {
421				// Create the column table if we don't have it
422				Self::create_column_table(&inner.connection, cn)?;
423			}
424			Self::set_unique_value(
425				&inner.connection,
426				"control",
427				"columns",
428				inner.config.columns,
429			)?;
430
431			Ok(())
432		})
433    }
434
435    fn stats_read(inner: &mut DatabaseInner, count: usize, bytes: usize) {
436        inner.current_stats.reads += count as u64;
437        inner.overall_stats.reads += count as u64;
438        inner.current_stats.bytes_read += bytes as u64;
439        inner.overall_stats.bytes_read += bytes as u64;
440    }
441    fn stats_write(inner: &mut DatabaseInner, count: usize, bytes: usize) {
442        inner.current_stats.writes += count as u64;
443        inner.overall_stats.writes += count as u64;
444        inner.current_stats.bytes_written += bytes as u64;
445        inner.overall_stats.bytes_written += bytes as u64;
446    }
447    fn stats_transaction(inner: &mut DatabaseInner, count: usize) {
448        inner.current_stats.transactions += count as u64;
449        inner.overall_stats.transactions += count as u64;
450    }
451}
452
453impl KeyValueDB for Database {
454    fn get<'a>(
455        &self,
456        col: u32,
457        key: &'a [u8],
458    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBValue>>> + Send + 'a>> {
459        let this = self.clone();
460        Box::pin(async move {
461            this.with_inner_mut(|inner| {
462                Self::validate_column(col, inner.config.columns)?;
463                match Self::load_unique_value_blob(
464                    &inner.connection,
465                    get_column_table_name(col),
466                    key_to_text(key),
467                ) {
468                    Ok(someval) => {
469                        match &someval {
470                            Some(val) => Self::stats_read(inner, 1, key.len() + val.len()),
471                            None => Self::stats_read(inner, 1, key.len()),
472                        };
473                        Ok(someval)
474                    }
475                    Err(e) => Err(e),
476                }
477            })
478        })
479    }
480
481    /// Remove a value by key, returning the old value
482    fn delete<'a>(
483        &self,
484        col: u32,
485        key: &'a [u8],
486    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBValue>>> + Send + 'a>> {
487        let this = self.clone();
488        Box::pin(async move {
489            this.with_inner_mut(|inner| {
490                Self::validate_column(col, inner.config.columns)?;
491                let someval = Self::load_unique_value_blob(
492                    &inner.connection,
493                    get_column_table_name(col),
494                    key_to_text(key),
495                )?;
496
497                match &someval {
498                    Some(val) => {
499                        Self::stats_read(inner, 1, key.len() + val.len());
500
501                        Self::remove_unique_value_blob(
502                            &inner.connection,
503                            get_column_table_name(col),
504                            key_to_text(key),
505                        )?;
506
507                        Self::stats_write(inner, 1, 0);
508                    }
509                    None => Self::stats_read(inner, 1, key.len()),
510                };
511                Ok(someval)
512            })
513        })
514    }
515
516    fn write(
517        &self,
518        transaction: DBTransaction,
519    ) -> Pin<Box<dyn Future<Output = Result<(), DBTransactionError>> + Send + 'static>> {
520        let this = self.clone();
521        Box::pin(async move {
522            this.with_inner_mut(|inner| {
523                Self::stats_transaction(inner, 1);
524
525                let mut sw = 0usize;
526                let mut sbw = 0usize;
527                let out = {
528                    let tx = inner.connection.transaction().map_err(other_io_err)?;
529                    for op in &transaction.ops {
530                        match op {
531                            DBOp::Insert { col, key, value } => {
532                                Self::validate_column(*col, inner.config.columns)?;
533                                Self::store_unique_value_blob(
534                                    &tx,
535                                    get_column_table_name(*col),
536                                    key_to_text(key),
537                                    value,
538                                )?;
539                                sw += 1;
540                                sbw += key.len() + value.len();
541                            }
542                            DBOp::Delete { col, key } => {
543                                Self::validate_column(*col, inner.config.columns)?;
544                                Self::remove_unique_value_blob(
545                                    &tx,
546                                    get_column_table_name(*col),
547                                    key_to_text(key),
548                                )?;
549                                sw += 1;
550                            }
551                            DBOp::DeletePrefix { col, prefix } => {
552                                Self::validate_column(*col, inner.config.columns)?;
553                                Self::remove_unique_value_blob_like(
554                                    &tx,
555                                    get_column_table_name(*col),
556                                    like_key_to_text(prefix) + "%",
557                                )?;
558                                sw += 1;
559                            }
560                        }
561                    }
562                    tx.commit().map_err(other_io_err)
563                };
564                Self::stats_write(inner, sw, sbw);
565                out
566            })
567            .map_err(|error| DBTransactionError { error, transaction })
568        })
569    }
570
571    fn iter<'a, T: 'a, F: FnMut(DBKeyValueRef) -> io::Result<Option<T>> + Send + Sync + 'a>(
572        &self,
573        col: u32,
574        prefix: Option<&'a [u8]>,
575        mut f: F,
576    ) -> Pin<Box<dyn Future<Output = io::Result<Option<T>>> + Send + 'a>> {
577        let this = self.clone();
578        Box::pin(async move {
579            this.with_inner_mut(|inner| {
580                if col >= inner.config.columns {
581                    return Err(io::Error::from(io::ErrorKind::NotFound));
582                }
583
584                let mut sw = 0usize;
585                let mut sbw = 0usize;
586                let mut out = Ok(None);
587                {
588                    let mut stmt;
589                    let mut rows;
590                    if let Some(p) = prefix {
591                        stmt = match inner.connection.prepare(&format!(
592                            "SELECT key, value FROM {} WHERE [key] LIKE ? ESCAPE '\\'",
593                            get_column_table_name(col)
594                        )) {
595                            Err(e) => {
596                                return Err(other_io_err(format!("SQLite prepare error: {:?}", e)));
597                            }
598                            Ok(v) => v,
599                        };
600                        rows = match stmt.query([like_key_to_text(p) + "%"]) {
601                            Err(e) => {
602                                return Err(other_io_err(format!("SQLite rows error: {:?}", e)));
603                            }
604                            Ok(v) => v,
605                        };
606                    } else {
607                        stmt = match inner.connection.prepare(&format!(
608                            "SELECT key, value FROM {}",
609                            get_column_table_name(col)
610                        )) {
611                            Err(e) => {
612                                return Err(other_io_err(format!("SQLite prepare error: {:?}", e)));
613                            }
614                            Ok(v) => v,
615                        };
616                        rows = match stmt.query([]) {
617                            Err(e) => {
618                                return Err(other_io_err(format!("SQLite query error: {:?}", e)));
619                            }
620                            Ok(v) => v,
621                        };
622                    };
623
624                    loop {
625                        match rows.next() {
626                            // Iterated value
627                            Ok(Some(row)) => {
628                                let kt: String = match row.get(0) {
629                                    Err(e) => {
630                                        out = Err(other_io_err(format!(
631                                            "SQLite row get column 0 error: {:?}",
632                                            e
633                                        )));
634                                        break;
635                                    }
636                                    Ok(v) => v,
637                                };
638                                let v: Vec<u8> = match row.get(1) {
639                                    Err(e) => {
640                                        out = Err(other_io_err(format!(
641                                            "SQLite row get column 1 error: {:?}",
642                                            e
643                                        )));
644                                        break;
645                                    }
646                                    Ok(v) => v,
647                                };
648                                let k: Vec<u8> = match text_to_key(&kt) {
649                                    Err(e) => {
650                                        out = Err(other_io_err(format!(
651                                            "SQLite row get column 0 text convert error: {:?}",
652                                            e
653                                        )));
654                                        break;
655                                    }
656                                    Ok(v) => v,
657                                };
658
659                                sw += 1;
660                                sbw += k.len() + v.len();
661
662                                match f((&k, &v)) {
663                                    Ok(None) => (),
664                                    Ok(Some(v)) => {
665                                        // Callback early termination
666                                        out = Ok(Some(v));
667                                        break;
668                                    }
669                                    Err(e) => {
670                                        // Callback error termination
671                                        out = Err(e);
672                                        break;
673                                    }
674                                }
675                            }
676                            // Natural iterator termination
677                            Ok(None) => {
678                                break;
679                            }
680                            // Error iterator termination
681                            Err(e) => {
682                                out = Err(other_io_err(format!("SQLite rows error: {:?}", e)));
683                                break;
684                            }
685                        }
686                    }
687                }
688                Self::stats_read(inner, sw, sbw);
689                out
690            })
691        })
692    }
693
694    fn iter_keys<'a, T: 'a, F: FnMut(DBKeyRef) -> io::Result<Option<T>> + Send + Sync + 'a>(
695        &self,
696        col: u32,
697        prefix: Option<&'a [u8]>,
698        mut f: F,
699    ) -> Pin<Box<dyn Future<Output = io::Result<Option<T>>> + Send + 'a>> {
700        let this = self.clone();
701        Box::pin(async move {
702            this.with_inner_mut(|inner| {
703                if col >= inner.config.columns {
704                    return Err(io::Error::from(io::ErrorKind::NotFound));
705                }
706
707                let mut sw = 0usize;
708                let mut sbw = 0usize;
709                let mut out = Ok(None);
710                {
711                    let mut stmt;
712                    let mut rows;
713                    if let Some(p) = prefix {
714                        stmt = match inner.connection.prepare(&format!(
715                            "SELECT key FROM {} WHERE [key] LIKE ? ESCAPE '\\'",
716                            get_column_table_name(col)
717                        )) {
718                            Err(e) => {
719                                return Err(other_io_err(format!("SQLite prepare error: {:?}", e)));
720                            }
721                            Ok(v) => v,
722                        };
723                        rows = match stmt.query([like_key_to_text(p) + "%"]) {
724                            Err(e) => {
725                                return Err(other_io_err(format!("SQLite rows error: {:?}", e)));
726                            }
727                            Ok(v) => v,
728                        };
729                    } else {
730                        stmt = match inner
731                            .connection
732                            .prepare(&format!("SELECT key FROM {}", get_column_table_name(col)))
733                        {
734                            Err(e) => {
735                                return Err(other_io_err(format!("SQLite prepare error: {:?}", e)));
736                            }
737                            Ok(v) => v,
738                        };
739                        rows = match stmt.query([]) {
740                            Err(e) => {
741                                return Err(other_io_err(format!("SQLite query error: {:?}", e)));
742                            }
743                            Ok(v) => v,
744                        };
745                    };
746
747                    loop {
748                        match rows.next() {
749                            // Iterated value
750                            Ok(Some(row)) => {
751                                let kt: String = match row.get(0) {
752                                    Err(e) => {
753                                        out = Err(other_io_err(format!(
754                                            "SQLite row get column 0 error: {:?}",
755                                            e
756                                        )));
757                                        break;
758                                    }
759                                    Ok(v) => v,
760                                };
761                                let k: Vec<u8> = match text_to_key(&kt) {
762                                    Err(e) => {
763                                        out = Err(other_io_err(format!(
764                                            "SQLite row get column 0 text convert error: {:?}",
765                                            e
766                                        )));
767                                        break;
768                                    }
769                                    Ok(v) => v,
770                                };
771
772                                sw += 1;
773                                sbw += k.len();
774
775                                match f(&k) {
776                                    Ok(None) => (),
777                                    Ok(Some(v)) => {
778                                        // Callback early termination
779                                        out = Ok(Some(v));
780                                        break;
781                                    }
782                                    Err(e) => {
783                                        // Callback error termination
784                                        out = Err(e);
785                                        break;
786                                    }
787                                }
788                            }
789                            // Natural iterator termination
790                            Ok(None) => {
791                                break;
792                            }
793                            // Error iterator termination
794                            Err(e) => {
795                                out = Err(other_io_err(format!("SQLite rows error: {:?}", e)));
796                                break;
797                            }
798                        }
799                    }
800                }
801                Self::stats_read(inner, sw, sbw);
802                out
803            })
804        })
805    }
806
807    fn io_stats(&self, kind: IoStatsKind) -> IoStats {
808        self.with_inner_mut(|inner| match kind {
809            IoStatsKind::Overall => {
810                let mut stats = inner.overall_stats.clone();
811                stats.span = std::time::SystemTime::now()
812                    .duration_since(stats.started)
813                    .unwrap_or_default();
814                stats
815            }
816            IoStatsKind::SincePrevious => {
817                let mut stats = inner.current_stats.clone();
818                stats.span = std::time::SystemTime::now()
819                    .duration_since(stats.started)
820                    .unwrap_or_default();
821                inner.current_stats = IoStats::empty();
822                stats
823            }
824        })
825    }
826
827    fn num_columns(&self) -> io::Result<u32> {
828        self.with_inner(|inner| {
829            Self::get_unique_value(&inner.connection, "control", "columns", 0u32)
830        })
831    }
832
833    fn num_keys(&self, col: u32) -> Pin<Box<dyn Future<Output = io::Result<u64>> + Send>> {
834        let this = self.clone();
835        Box::pin(async move {
836            this.with_inner(|inner| {
837                if let Ok(found) = inner.connection.query_row(
838                    &format!("SELECT Count(*) FROM {}", get_column_table_name(col)),
839                    [],
840                    |row| -> Result<u64> { row.get(0) },
841                ) {
842                    return Ok(found);
843                }
844                Err(io::Error::from(io::ErrorKind::NotFound))
845            })
846        })
847    }
848
849    fn restore(&self, _new_db: &str) -> io::Result<()> {
850        Err(other_io_err("Attempted to restore sqlite database"))
851    }
852}
853
#[cfg(test)]
mod tests {
    use super::*;
    use keyvaluedb_shared_tests as st;
    use tempfile::{Builder as TempfileBuilder, TempDir};

    /// Create a fresh temporary directory and return it together with the
    /// path of a (not yet existing) database file inside it.
    ///
    /// The returned `TempDir` must be kept alive for as long as the database
    /// is used. Dropping it deletes the directory and everything created in
    /// it, database file included.
    ///
    /// Reserving a path via a `NamedTempFile` that is dropped immediately
    /// would leave the recreated database behind in the system temp
    /// directory after every test run.
    fn temp_db_path(prefix: &str) -> io::Result<(TempDir, PathBuf)> {
        let dir = TempfileBuilder::new().prefix(prefix).tempdir()?;
        let path = dir.path().join("db.sqlite");
        Ok((dir, path))
    }

    /// Open a new on-disk database with `columns` columns in its own
    /// temporary directory.
    ///
    /// Bind the returned `TempDir` (e.g. to `_dir`, not `_`) so it lives as
    /// long as the database.
    fn create(columns: u32) -> io::Result<(TempDir, Database)> {
        let (dir, path) = temp_db_path("keyvaluedb_sqlite")?;
        let config = DatabaseConfig::with_columns(columns);
        let db = Database::open(path, config)?;
        Ok((dir, db))
    }

    #[tokio::test]
    async fn get_fails_with_non_existing_column() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_get_fails_with_non_existing_column(db).await
    }

    #[tokio::test]
    async fn num_keys() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_num_keys(db).await
    }

    #[tokio::test]
    async fn put_and_get() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_put_and_get(db).await
    }

    #[tokio::test]
    async fn delete_and_get() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_delete_and_get(db).await
    }

    #[tokio::test]
    async fn delete_and_get_single() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_delete_and_get_single(db).await
    }

    #[tokio::test]
    async fn delete_prefix() -> io::Result<()> {
        let (_dir, db) = create(st::DELETE_PREFIX_NUM_COLUMNS)?;
        st::test_delete_prefix(db).await
    }

    #[tokio::test]
    async fn iter() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter(db).await
    }

    #[tokio::test]
    async fn iter_keys() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter_keys(db).await
    }

    #[tokio::test]
    async fn iter_with_prefix() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_iter_with_prefix(db).await
    }

    #[tokio::test]
    async fn complex() -> io::Result<()> {
        let (_dir, db) = create(1)?;
        st::test_complex(db).await
    }

    #[tokio::test]
    async fn stats() -> io::Result<()> {
        let (_dir, db) = create(st::IO_STATS_NUM_COLUMNS)?;
        st::test_io_stats(db).await
    }

    #[tokio::test]
    #[should_panic]
    async fn db_config_with_zero_columns() {
        let _cfg = DatabaseConfig::with_columns(0);
    }

    #[tokio::test]
    #[should_panic]
    async fn open_db_with_zero_columns() {
        let cfg = DatabaseConfig {
            columns: 0,
            ..Default::default()
        };
        let _db = Database::open("", cfg);
    }

    #[tokio::test]
    async fn add_columns() {
        let config_1 = DatabaseConfig::default();
        let config_5 = DatabaseConfig::with_columns(5);

        let (_dir, path) = temp_db_path("add_columns").unwrap();

        // open 1, add 4.
        {
            let db = Database::open(&path, config_1).unwrap();
            assert_eq!(db.num_columns().unwrap(), 1);

            for i in 2..=5 {
                db.add_column().unwrap();
                assert_eq!(db.num_columns().unwrap(), i);
            }
        }

        // reopen as 5.
        {
            let db = Database::open(&path, config_5).unwrap();
            assert_eq!(db.num_columns().unwrap(), 5);
        }
    }

    #[tokio::test]
    async fn remove_columns() {
        let config_1 = DatabaseConfig::default();
        let config_5 = DatabaseConfig::with_columns(5);

        let (_dir, path) = temp_db_path("drop_columns").unwrap();

        // open 5, remove 4.
        {
            let db = Database::open(&path, config_5).expect("open with 5 columns");
            assert_eq!(db.num_columns().unwrap(), 5);

            for i in (1..5).rev() {
                db.remove_last_column().unwrap();
                assert_eq!(db.num_columns().unwrap(), i);
            }
        }

        // reopen as 1.
        {
            let db = Database::open(&path, config_1).unwrap();
            assert_eq!(db.num_columns().unwrap(), 1);
        }
    }

    #[tokio::test]
    async fn test_num_keys() {
        let (_dir, path) = temp_db_path("num_keys").unwrap();
        let config = DatabaseConfig::with_columns(1);
        let db = Database::open(path, config).unwrap();

        assert_eq!(
            db.num_keys(0).await.unwrap(),
            0,
            "database is empty after creation"
        );
        let key1 = b"beef";
        let mut batch = db.transaction();
        batch.put(0, key1, key1);
        db.write(batch).await.unwrap();
        assert_eq!(
            db.num_keys(0).await.unwrap(),
            1,
            "adding a key increases the count"
        );
    }
}