keyvaluedb-memorydb 0.1.8

A key-value in-memory database that implements the `KeyValueDB` trait
Documentation
//! A key-value in-memory database that implements the `KeyValueDB` trait

#![deny(clippy::all)]

use keyvaluedb::{
    DBKey, DBKeyRef, DBKeyValueRef, DBOp, DBTransaction, DBTransactionError, DBValue, KeyValueDB,
    KeyValueDBPinBoxFuture,
};
use parking_lot::RwLock;
use std::{
    collections::{BTreeMap, HashMap},
    io,
    sync::Arc,
};

/// Handle to a `KeyValueDB` whose data lives entirely in process memory.
///
/// Cloning is cheap: every clone shares the same `Arc`-wrapped column map, so writes made
/// through one handle are visible through all of them. This backend is meant mainly for
/// tests and is not tuned for performance.
#[derive(Clone)]
pub struct InMemory {
    /// Column index -> that column's key/value map, all behind a single lock.
    columns: Arc<RwLock<HashMap<u32, Column>>>,
}

/// Contents of one column, kept ordered by key.
type Column = BTreeMap<DBKey, DBValue>;

/// Build a fresh, empty in-memory database.
///
/// The database gets exactly `num_cols` columns, addressed by the indices `0..num_cols`.
/// Every column starts out with no entries.
pub fn create(num_cols: u32) -> InMemory {
    let columns = (0..num_cols)
        .map(|index| (index, BTreeMap::new()))
        .collect::<HashMap<_, _>>();

    InMemory {
        columns: Arc::new(RwLock::new(columns)),
    }
}

impl KeyValueDB for InMemory {
    fn get<'a>(
        &'a self,
        col: u32,
        key: &'a [u8],
    ) -> KeyValueDBPinBoxFuture<'a, io::Result<Option<DBValue>>> {
        let this = self.clone();
        Box::pin(async move {
            let columns = this.columns.read();
            match columns.get(&col) {
                None => Err(io::Error::from(io::ErrorKind::NotFound)),
                Some(map) => Ok(map.get(key).cloned()),
            }
        })
    }

    /// Remove a value by key, returning the old value
    fn delete<'a>(
        &'a self,
        col: u32,
        key: &'a [u8],
    ) -> KeyValueDBPinBoxFuture<'a, io::Result<Option<DBValue>>> {
        let this = self.clone();
        Box::pin(async move {
            let mut columns = this.columns.write();
            match columns.get_mut(&col) {
                None => Err(io::Error::from(io::ErrorKind::NotFound)),
                Some(map) => Ok(map.remove(key)),
            }
        })
    }

    fn write(
        &self,
        transaction: DBTransaction,
    ) -> KeyValueDBPinBoxFuture<'_, Result<(), DBTransactionError>> {
        let this = self.clone();
        Box::pin(async move {
            let mut columns = this.columns.write();
            let ops = transaction.ops;
            for op in ops {
                match op {
                    DBOp::Insert { col, key, value } => {
                        if let Some(col) = columns.get_mut(&col) {
                            col.insert(key, value);
                        }
                    }
                    DBOp::Delete { col, key } => {
                        if let Some(col) = columns.get_mut(&col) {
                            col.remove(&*key);
                        }
                    }
                    DBOp::DeletePrefix { col, prefix } => {
                        if let Some(col) = columns.get_mut(&col) {
                            use std::ops::Bound;
                            if prefix.is_empty() {
                                col.clear();
                            } else {
                                let start_range = Bound::Included(prefix.to_vec());
                                let keys: Vec<_> =
                                    if let Some(end_range) = keyvaluedb::end_prefix(&prefix[..]) {
                                        col.range((start_range, Bound::Excluded(end_range)))
                                            .map(|(k, _)| k.clone())
                                            .collect()
                                    } else {
                                        col.range((start_range, Bound::Unbounded))
                                            .map(|(k, _)| k.clone())
                                            .collect()
                                    };
                                for key in keys.into_iter() {
                                    col.remove(&key[..]);
                                }
                            }
                        }
                    }
                }
            }
            Ok(())
        })
    }

    fn iter<
        'a,
        T: Send + 'static,
        C: Send + 'static,
        F: FnMut(&mut C, DBKeyValueRef) -> io::Result<Option<T>> + Send + Sync + 'static,
    >(
        &'a self,
        col: u32,
        prefix: Option<&'a [u8]>,
        mut context: C,
        mut f: F,
    ) -> KeyValueDBPinBoxFuture<'a, io::Result<(C, Option<T>)>> {
        let this = self.clone();
        Box::pin(async move {
            match this.columns.read().get(&col) {
                Some(map) => {
                    for (k, v) in map {
                        if let Some(p) = prefix {
                            if !k.starts_with(p) {
                                continue;
                            }
                        }
                        match f(&mut context, (k, v)) {
                            Ok(None) => (),
                            Ok(Some(v)) => return Ok((context, Some(v))),
                            Err(e) => return Err(e),
                        }
                    }
                    Ok((context, None))
                }
                None => Err(io::Error::from(io::ErrorKind::NotFound)),
            }
        })
    }

    fn iter_keys<
        'a,
        T: Send + 'static,
        C: Send + 'static,
        F: FnMut(&mut C, DBKeyRef) -> io::Result<Option<T>> + Send + Sync + 'static,
    >(
        &'a self,
        col: u32,
        prefix: Option<&'a [u8]>,
        mut context: C,
        mut f: F,
    ) -> KeyValueDBPinBoxFuture<'a, io::Result<(C, Option<T>)>> {
        let this = self.clone();
        Box::pin(async move {
            match this.columns.read().get(&col) {
                Some(map) => {
                    for k in map.keys() {
                        if let Some(p) = prefix {
                            if !k.starts_with(p) {
                                continue;
                            }
                        }
                        match f(&mut context, k) {
                            Ok(None) => (),
                            Ok(Some(v)) => return Ok((context, Some(v))),
                            Err(e) => return Err(e),
                        }
                    }
                    Ok((context, None))
                }
                None => Err(io::Error::from(io::ErrorKind::NotFound)),
            }
        })
    }

    fn num_columns(&self) -> io::Result<u32> {
        Ok(self.columns.read().len() as u32)
    }

    fn num_keys(&self, col: u32) -> KeyValueDBPinBoxFuture<'_, io::Result<u64>> {
        let this = self.clone();
        Box::pin(async move {
            let c = this.columns.read();
            let Some(column) = c.get(&col) else {
                return Err(io::Error::from(io::ErrorKind::NotFound));
            };
            Ok(column.len() as u64)
        })
    }
}

#[cfg(all(test, not(all(target_arch = "wasm32", target_os = "unknown"))))]
mod tests {
    //! Runs the shared `keyvaluedb` conformance suite against the in-memory backend.

    use super::create;
    use keyvaluedb_shared_tests as st;
    use std::io;

    /// Declares one `#[tokio::test]` per entry.
    ///
    /// Each entry has the form `name => shared_test(num_columns);`. The generated test builds
    /// a database with `num_columns` columns and hands it to `st::shared_test`.
    macro_rules! shared_tests {
        ($($name:ident => $test:ident($cols:expr);)*) => {
            $(
                #[tokio::test]
                async fn $name() -> io::Result<()> {
                    st::$test(create($cols)).await
                }
            )*
        };
    }

    shared_tests! {
        get_fails_with_non_existing_column => test_get_fails_with_non_existing_column(1);
        put_and_get => test_put_and_get(1);
        num_keys => test_num_keys(1);
        delete_and_get => test_delete_and_get(1);
        delete_prefix => test_delete_prefix(st::DELETE_PREFIX_NUM_COLUMNS);
        iter => test_iter(1);
        iter_keys => test_iter_keys(1);
        iter_with_prefix => test_iter_with_prefix(1);
        complex => test_complex(1);
        cleanup => test_cleanup(1);
    }
}