bonerjams_db/
lib.rs

1//! an embedded database using the sled framework
2
3pub mod rpc;
4use serde::Serialize;
5pub mod types;
6use anyhow::{anyhow, Result};
7use bonerjams_config::database::DbOpts;
8use sled::{IVec, Tree};
9use std::sync::Arc;
10
11use self::types::{DbKey, DbTrees};
12
13/// Database is the main embedded database object using the
14/// sled db
15#[derive(Clone)]
16pub struct Database {
17    db: sled::Db,
18}
19
20/// DbTree is a wrapper around the sled::Tree type providing
21/// convenience functions
22#[derive(Clone)]
23pub struct DbTree {
24    pub tree: Tree,
25}
26
27/// DbBatch is a wrapper around the sled::Batch type providing
28/// convenience functions
29#[derive(Default, Clone)]
30pub struct DbBatch {
31    batch: sled::Batch,
32    count: u64,
33}
34
35impl Database {
36    /// returns a new sled database
37    pub fn new(cfg: &DbOpts) -> Result<Arc<Self>> {
38        let sled_config: sled::Config = cfg.into();
39        let db = sled_config.open()?;
40        drop(sled_config);
41        Ok(Arc::new(Database { db }))
42    }
43    /// opens the given database tree
44    pub fn open_tree(self: &Arc<Self>, tree: DbTrees) -> Result<Arc<DbTree>> {
45        DbTree::open(&self.db, tree)
46    }
47    /// opens the given db tree, return a vector of (key, value)
48    pub fn list_values(self: &Arc<Self>, tree: DbTrees) -> Result<Vec<(IVec, IVec)>> {
49        let tree = self.open_tree(tree)?;
50        Ok(tree
51            .iter()
52            .filter_map(|entry| {
53                if let Ok((key, value)) = entry {
54                    Some((key, value))
55                } else {
56                    None
57                }
58            })
59            .collect())
60    }
61    pub async fn flush_async(&self) -> sled::Result<usize> {
62        self.db.flush_async().await
63    }
64    /// flushes teh database
65    pub fn flush(self: &Arc<Self>) -> Result<usize> {
66        Ok(self.db.flush()?)
67    }
68    /// returns a clone of the inner database
69    pub fn inner(self: &Arc<Self>) -> sled::Db {
70        self.db.clone()
71    }
72    /// destroys all trees except the default tree
73    pub fn destroy(self: &Arc<Self>) {
74        const SLED_DEFAULT_TREE: &[u8] = b"__sled__default";
75        self.db
76            .tree_names()
77            .iter()
78            .filter(|tree_name| tree_name.as_ref().ne(SLED_DEFAULT_TREE))
79            .for_each(|tree_name| {
80                if let Err(err) = self.db.drop_tree(tree_name) {
81                    log::error!("failed to drop tree {:?}: {:#?}", tree_name.as_ref(), err);
82                }
83            });
84    }
85    pub fn get<K: AsRef<[u8]>>(&self, key: K) -> sled::Result<Option<sled::IVec>> {
86        self.db.get(key)
87    }
88    pub fn deserialize<K: AsRef<[u8]>, T>(&self, key: K) -> Result<T>
89    where
90        T: serde::de::DeserializeOwned + Clone,
91    {
92        let value = self.get(key)?;
93        if let Some(value) = value {
94            let result = serde_json::from_slice(&value)?;
95            Ok(result)
96        } else {
97            Err(anyhow!("value for key is None"))
98        }
99    }
100    pub fn apply_batch(&self, batch: &mut DbBatch) -> sled::Result<()> {
101        self.db.apply_batch(batch.take_inner())
102    }
103    /// insert raw, untyped bytes
104    pub fn insert_raw(&self, key: &[u8], value: &[u8]) -> Result<Option<sled::IVec>> {
105        Ok(self.db.insert(key, value)?)
106    }
107    /// inserts a value into the default tree
108    pub fn insert<T>(&mut self, value: &T) -> Result<()>
109    where
110        T: Serialize + DbKey,
111    {
112        self.db.insert(value.key()?, serde_json::to_vec(value)?)?;
113        Ok(())
114    }
115    pub fn delete<K: AsRef<[u8]>>(&self, key: K) -> Result<()> {
116        self.db.remove(key)?;
117        Ok(())
118    }
119}
120
121impl DbTree {
122    pub fn open(db: &sled::Db, tree: DbTrees) -> Result<Arc<Self>> {
123        let tree = db.open_tree(&tree.to_string())?;
124        Ok(Arc::new(Self { tree }))
125    }
126    pub fn len(&self) -> usize {
127        self.tree.len()
128    }
129    pub fn is_empty(&self) -> bool {
130        self.tree.is_empty()
131    }
132    pub fn iter(&self) -> sled::Iter {
133        self.tree.iter()
134    }
135    pub fn contains_key<K: AsRef<[u8]>>(&self, key: K) -> sled::Result<bool> {
136        self.tree.contains_key(key)
137    }
138    pub fn flush(&self) -> sled::Result<usize> {
139        self.tree.flush()
140    }
141    pub async fn flush_async(&self) -> sled::Result<usize> {
142        self.tree.flush_async().await
143    }
144    pub fn apply_batch(&self, batch: &mut DbBatch) -> sled::Result<()> {
145        self.tree.apply_batch(batch.take_inner())
146    }
147    pub fn insert<T>(&self, value: &T) -> Result<Option<sled::IVec>>
148    where
149        T: Serialize + DbKey,
150    {
151        Ok(self.tree.insert(value.key()?, serde_json::to_vec(value)?)?)
152    }
153    /// insert raw, untyped bytes
154    pub fn insert_raw(&self, key: &[u8], value: &[u8]) -> Result<Option<sled::IVec>> {
155        Ok(self.tree.insert(key, value)?)
156    }
157    pub fn get<K: AsRef<[u8]>>(&self, key: K) -> sled::Result<Option<sled::IVec>> {
158        self.tree.get(key)
159    }
160    /// currently broken
161    pub fn entries2<K: PartialEq + DbKey, T>(&self, skip_keys: &[K]) -> Result<Vec<T>>
162    where
163        T: serde::de::DeserializeOwned,
164    {
165        Ok(self
166            .tree
167            .into_iter()
168            .keys()
169            .filter_map(|key| match key {
170                Ok(key) => Some(key),
171                Err(_) => None,
172            })
173            .filter(|key| {
174                skip_keys
175                    .iter()
176                    .filter_map(|k| {
177                        if let Ok(key) = k.key() {
178                            Some(key)
179                        } else {
180                            None
181                        }
182                    })
183                    .any(|x| x == key.as_ref().to_vec())
184            })
185            .filter_map(|key| {
186                if let Ok(value) = self.deserialize(key) {
187                    Some(value)
188                } else {
189                    None
190                }
191            })
192            .collect())
193    }
194    pub fn deserialize<K: AsRef<[u8]>, T>(&self, key: K) -> Result<T>
195    where
196        T: serde::de::DeserializeOwned,
197    {
198        let value = self.get(key)?;
199        if let Some(value) = value {
200            Ok(serde_json::from_slice(&value)?)
201        } else {
202            Err(anyhow!("value for key is None"))
203        }
204    }
205    pub fn delete<K: AsRef<[u8]>>(&self, key: K) -> Result<()> {
206        self.tree.remove(key)?;
207        Ok(())
208    }
209}
210
211impl DbBatch {
212    pub fn new() -> DbBatch {
213        DbBatch {
214            batch: Default::default(),
215            count: 0,
216        }
217    }
218    pub fn remove<T>(&mut self, value: &T) -> Result<()>
219    where
220        T: Serialize + DbKey,
221    {
222        self.batch.remove(value.key()?);
223        self.count += 1;
224        Ok(())
225    }
226    /// removes raw, untyped bytes
227    pub fn remove_raw(&mut self, key: &[u8]) -> Result<()> {
228        self.batch.remove(key);
229        self.count += 1;
230        Ok(())
231    }
232    /// inserts raw untyped bytes
233    pub fn insert_raw(&mut self, key: &[u8], value: &[u8]) -> Result<()> {
234        self.batch.insert(key, value);
235        self.count += 1;
236        Ok(())
237    }
238    pub fn insert<T>(&mut self, value: &T) -> Result<()>
239    where
240        T: Serialize + DbKey,
241    {
242        self.batch.insert(value.key()?, serde_json::to_vec(value)?);
243        self.count += 1;
244        Ok(())
245    }
246    /// returns the inner batch, and should only be used when the batch object
247    /// is finished with and the batch needs to be applied, as it replaces the inner
248    /// batch with its default version
249    pub fn take_inner(&mut self) -> sled::Batch {
250        std::mem::take(&mut self.batch)
251    }
252    pub fn inner(&self) -> &sled::Batch {
253        &self.batch
254    }
255    pub fn count(&self) -> u64 {
256        self.count
257    }
258}
259
#[cfg(test)]
mod test {
    use super::*;
    use serde::Deserialize;
    use std::fs::remove_dir_all;

    /// Small JSON-serializable record stored under its own `key` field.
    #[derive(Serialize, Deserialize)]
    pub struct TestData {
        pub key: String,
        pub foo: String,
    }

    impl DbKey for TestData {
        /// The bytes of the `key` string are used as the database key.
        fn key(&self) -> anyhow::Result<Vec<u8>> {
            Ok(self.key.as_bytes().to_vec())
        }
    }

    /// Builds the `n`-th fixture record (`key{n}` / `foo{n}`).
    fn record(n: u32) -> TestData {
        TestData {
            key: format!("key{}", n),
            foo: format!("foo{}", n),
        }
    }

    /// Applies `batch` to the named custom tree and checks the tree's resulting size.
    fn apply_to_tree(
        db: &Arc<Database>,
        name: &'static str,
        batch: &mut DbBatch,
        expected_len: usize,
    ) {
        let tree = db.open_tree(DbTrees::Custom(name)).unwrap();
        tree.apply_batch(batch).unwrap();
        assert_eq!(tree.len(), expected_len);
    }

    /// Reads the record stored under `key` in the named custom tree and checks both fields.
    fn assert_stored(
        db: &Arc<Database>,
        name: &'static str,
        key: &IVec,
        expected_key: &str,
        expected_foo: &str,
    ) {
        let data: TestData = db
            .open_tree(DbTrees::Custom(name))
            .unwrap()
            .deserialize(key.clone())
            .unwrap();
        assert_eq!(data.key, expected_key.to_string());
        assert_eq!(data.foo, expected_foo.to_string());
    }

    // performs very basic database testing
    #[test]
    fn test_db_basic() {
        let db_opts = DbOpts::default();
        let db = Database::new(&db_opts).unwrap();

        // Write phase: a single batch object is reused across every application.
        let mut db_batch = DbBatch::new();

        db_batch.insert(&record(1)).unwrap();
        apply_to_tree(&db, "foobar", &mut db_batch, 1);

        db_batch.insert(&record(2)).unwrap();
        apply_to_tree(&db, "foobar", &mut db_batch, 2);

        db_batch.insert(&record(3)).unwrap();
        apply_to_tree(&db, "foobarbaz", &mut db_batch, 1);

        // The fourth record goes through the database-level batch API.
        db_batch.insert(&record(4)).unwrap();
        db.apply_batch(&mut db_batch).unwrap();

        // Read phase: every tree must hold exactly what was written to it.
        let foobar_values = db.list_values(DbTrees::Custom("foobar")).unwrap();
        assert_eq!(foobar_values.len(), 2);
        assert_stored(&db, "foobar", &foobar_values[0].0, "key1", "foo1");
        assert_stored(&db, "foobar", &foobar_values[1].0, "key2", "foo2");

        let foobarbaz_values = db.list_values(DbTrees::Custom("foobarbaz")).unwrap();
        assert_eq!(foobarbaz_values.len(), 1);
        assert_stored(&db, "foobarbaz", &foobarbaz_values[0].0, "key3", "foo3");

        let default_tree_values = db.list_values(DbTrees::Default).unwrap();
        assert_eq!(default_tree_values.len(), 1);

        // Cleanup: drop the custom trees, then remove the on-disk directory.
        db.destroy();
        remove_dir_all("test_infos.db").unwrap();
    }
}
365}