keyvaluedb_memorydb/
lib.rs

1//! A key-value in-memory database that implements the `KeyValueDB` trait
2
3#![deny(clippy::all)]
4
5use keyvaluedb::{
6    DBKey, DBKeyRef, DBKeyValueRef, DBOp, DBTransaction, DBTransactionError, DBValue, KeyValueDB,
7};
8use parking_lot::RwLock;
9use std::{
10    collections::{BTreeMap, HashMap},
11    future::Future,
12    io,
13    pin::Pin,
14    sync::Arc,
15};
16
17/// A key-value database fulfilling the `KeyValueDB` trait, living in memory.
18/// This is generally intended for tests and is not particularly optimized.
19#[derive(Clone)]
20pub struct InMemory {
21    columns: Arc<RwLock<HashMap<u32, BTreeMap<DBKey, DBValue>>>>,
22}
23
24/// Create an in-memory database with the given number of columns.
25/// Columns will be indexable by 0..`num_cols`
26pub fn create(num_cols: u32) -> InMemory {
27    let mut cols = HashMap::new();
28
29    for idx in 0..num_cols {
30        cols.insert(idx, BTreeMap::new());
31    }
32
33    InMemory {
34        columns: Arc::new(RwLock::new(cols)),
35    }
36}
37
38impl KeyValueDB for InMemory {
39    fn get<'a>(
40        &'a self,
41        col: u32,
42        key: &'a [u8],
43    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBValue>>> + Send + 'a>> {
44        let this = self.clone();
45        Box::pin(async move {
46            let columns = this.columns.read();
47            match columns.get(&col) {
48                None => Err(io::Error::from(io::ErrorKind::NotFound)),
49                Some(map) => Ok(map.get(key).cloned()),
50            }
51        })
52    }
53
54    /// Remove a value by key, returning the old value
55    fn delete<'a>(
56        &'a self,
57        col: u32,
58        key: &'a [u8],
59    ) -> Pin<Box<dyn Future<Output = io::Result<Option<DBValue>>> + Send + 'a>> {
60        let this = self.clone();
61        Box::pin(async move {
62            let mut columns = this.columns.write();
63            match columns.get_mut(&col) {
64                None => Err(io::Error::from(io::ErrorKind::NotFound)),
65                Some(map) => Ok(map.remove(key)),
66            }
67        })
68    }
69
70    fn write(
71        &self,
72        transaction: DBTransaction,
73    ) -> Pin<Box<dyn Future<Output = Result<(), DBTransactionError>> + Send + '_>> {
74        let this = self.clone();
75        Box::pin(async move {
76            let mut columns = this.columns.write();
77            let ops = transaction.ops;
78            for op in ops {
79                match op {
80                    DBOp::Insert { col, key, value } => {
81                        if let Some(col) = columns.get_mut(&col) {
82                            col.insert(key, value);
83                        }
84                    }
85                    DBOp::Delete { col, key } => {
86                        if let Some(col) = columns.get_mut(&col) {
87                            col.remove(&*key);
88                        }
89                    }
90                    DBOp::DeletePrefix { col, prefix } => {
91                        if let Some(col) = columns.get_mut(&col) {
92                            use std::ops::Bound;
93                            if prefix.is_empty() {
94                                col.clear();
95                            } else {
96                                let start_range = Bound::Included(prefix.to_vec());
97                                let keys: Vec<_> =
98                                    if let Some(end_range) = keyvaluedb::end_prefix(&prefix[..]) {
99                                        col.range((start_range, Bound::Excluded(end_range)))
100                                            .map(|(k, _)| k.clone())
101                                            .collect()
102                                    } else {
103                                        col.range((start_range, Bound::Unbounded))
104                                            .map(|(k, _)| k.clone())
105                                            .collect()
106                                    };
107                                for key in keys.into_iter() {
108                                    col.remove(&key[..]);
109                                }
110                            }
111                        }
112                    }
113                }
114            }
115            Ok(())
116        })
117    }
118
119    fn iter<
120        'a,
121        T: Send + 'static,
122        C: Send + 'static,
123        F: FnMut(&mut C, DBKeyValueRef) -> io::Result<Option<T>> + Send + Sync + 'static,
124    >(
125        &'a self,
126        col: u32,
127        prefix: Option<&'a [u8]>,
128        mut context: C,
129        mut f: F,
130    ) -> Pin<Box<dyn Future<Output = io::Result<(C, Option<T>)>> + Send + 'a>> {
131        let this = self.clone();
132        Box::pin(async move {
133            match this.columns.read().get(&col) {
134                Some(map) => {
135                    for (k, v) in map {
136                        if let Some(p) = prefix {
137                            if !k.starts_with(p) {
138                                continue;
139                            }
140                        }
141                        match f(&mut context, (k, v)) {
142                            Ok(None) => (),
143                            Ok(Some(v)) => return Ok((context, Some(v))),
144                            Err(e) => return Err(e),
145                        }
146                    }
147                    Ok((context, None))
148                }
149                None => Err(io::Error::from(io::ErrorKind::NotFound)),
150            }
151        })
152    }
153
154    fn iter_keys<
155        'a,
156        T: Send + 'static,
157        C: Send + 'static,
158        F: FnMut(&mut C, DBKeyRef) -> io::Result<Option<T>> + Send + Sync + 'static,
159    >(
160        &'a self,
161        col: u32,
162        prefix: Option<&'a [u8]>,
163        mut context: C,
164        mut f: F,
165    ) -> Pin<Box<dyn Future<Output = io::Result<(C, Option<T>)>> + Send + 'a>> {
166        let this = self.clone();
167        Box::pin(async move {
168            match this.columns.read().get(&col) {
169                Some(map) => {
170                    for k in map.keys() {
171                        if let Some(p) = prefix {
172                            if !k.starts_with(p) {
173                                continue;
174                            }
175                        }
176                        match f(&mut context, k) {
177                            Ok(None) => (),
178                            Ok(Some(v)) => return Ok((context, Some(v))),
179                            Err(e) => return Err(e),
180                        }
181                    }
182                    Ok((context, None))
183                }
184                None => Err(io::Error::from(io::ErrorKind::NotFound)),
185            }
186        })
187    }
188
189    fn num_columns(&self) -> io::Result<u32> {
190        Ok(self.columns.read().len() as u32)
191    }
192
193    fn num_keys(&self, col: u32) -> Pin<Box<dyn Future<Output = io::Result<u64>> + Send + '_>> {
194        let this = self.clone();
195        Box::pin(async move {
196            let c = this.columns.read();
197            let Some(column) = c.get(&col) else {
198                return Err(io::Error::from(io::ErrorKind::NotFound));
199            };
200            Ok(column.len() as u64)
201        })
202    }
203}
204
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[cfg(test)]
mod tests {
    //! Runs the backend-agnostic suite from `keyvaluedb_shared_tests` against
    //! a brand-new `InMemory` database for every test.

    use super::create;
    use keyvaluedb_shared_tests as st;
    use std::io;

    /// Declare a tokio test named `$name` that hands a fresh database with
    /// `$cols` columns to the shared check `$check`.
    macro_rules! shared_test {
        ($name:ident, $cols:expr, $check:path) => {
            #[tokio::test]
            async fn $name() -> io::Result<()> {
                $check(create($cols)).await
            }
        };
    }

    shared_test!(
        get_fails_with_non_existing_column,
        1,
        st::test_get_fails_with_non_existing_column
    );
    shared_test!(put_and_get, 1, st::test_put_and_get);
    shared_test!(num_keys, 1, st::test_num_keys);
    shared_test!(delete_and_get, 1, st::test_delete_and_get);
    shared_test!(
        delete_prefix,
        st::DELETE_PREFIX_NUM_COLUMNS,
        st::test_delete_prefix
    );
    shared_test!(iter, 1, st::test_iter);
    shared_test!(iter_keys, 1, st::test_iter_keys);
    shared_test!(iter_with_prefix, 1, st::test_iter_with_prefix);
    shared_test!(complex, 1, st::test_complex);
    shared_test!(cleanup, 1, st::test_cleanup);
}