prost_sled/
lib.rs

1//! prost-sled: An integration layer between [prost] and [sled].
2//!
3//! prost-sled makes it easy to store protobufs in a sled database because
4//! it abstracts away the boilerplate of encoding and decoding the protobufs.
//! If you need to work with the raw bytes directly, or [sled::Db] offers a
//! method that [ProtoDb] does not wrap yet, convert between the wrapper and the
//! underlying sled type with `from`/`into`: [From] (and therefore [Into]) is
//! implemented in both directions between [ProtoDb] and [sled::Db], and
//! likewise between [ProtoTree] and [sled::Tree].
9
10use std::path::Path;
11
12use prost::{bytes::BytesMut, Message};
13use thiserror::Error;
14
15/// Errors that can be returned by this library. It's really a simple
16/// integration layer to encompass the possible errors returned by [sled] and
17/// [prost].
18#[derive(Debug, Error, PartialEq)]
19pub enum Error {
20    /// An error was returned by [sled].
21    #[error(transparent)]
22    SledError(#[from] sled::Error),
23    /// A decoding error ([prost::DecodeError]) occurred in [prost].
24    #[error(transparent)]
25    ProstDecodeError(#[from] prost::DecodeError),
26    /// An encoding error ([prost::EncodeError]) occurred in [prost].
27    #[error(transparent)]
28    ProstEncodeError(#[from] prost::EncodeError),
29}
30
31/// Result of a database action. That is, either some type `T` or an
32/// [enum@Error].
33pub type Result<T> = std::result::Result<T, Error>;
34
35/// Wrapper around [sled::Db] that allows you to use types implementing
36/// [prost::Message] instead of raw bytes.
37#[derive(Clone)]
38pub struct ProtoDb(sled::Db);
39
40/// Escape hatch to get from a [ProtoDb] to a [sled::Db] in-case you need
41/// something more low level.
42impl From<ProtoDb> for sled::Db {
43    fn from(db: ProtoDb) -> Self {
44        db.0
45    }
46}
47
48/// Convenience implementation to convert an existing [sled::Db] to a [ProtoDb].
49impl From<sled::Db> for ProtoDb {
50    fn from(db: sled::Db) -> Self {
51        Self(db)
52    }
53}
54
55/// Wrapper around [sled::Tree] that allows you to use types implementing [prost::Message] instead
56/// of raw bytes.
57#[derive(Clone)]
58pub struct ProtoTree(sled::Tree);
59
60/// Escape hatch to get from a [ProtoTree] to a [sled::Tree] in-case you need
61/// something more low level.
62impl From<ProtoTree> for sled::Tree {
63    fn from(t: ProtoTree) -> Self {
64        t.0
65    }
66}
67
68/// Convenience implementation to convert an existing [sled::Db] to a [ProtoDb].
69impl From<sled::Tree> for ProtoTree {
70    fn from(t: sled::Tree) -> Self {
71        Self(t)
72    }
73}
74
75/// Create a [ProtoDb] and return it.
76pub fn open<P: AsRef<Path>>(path: P) -> Result<ProtoDb> {
77    let db = sled::open(path)?;
78    Ok(db.into())
79}
80
81impl ProtoDb {
82    pub fn open_tree<V: AsRef<[u8]>>(&self, name: V) -> Result<ProtoTree> {
83        Ok(self.0.open_tree(name)?.into())
84    }
85
86    // Ewww. These methods are exactly the same as those in ProtoTree. I'm not sure of a better way
87    // though. sled implements Db as a deref of Tree but we can't also implement it that way.
88    pub fn contains_key<K>(&self, key: K) -> Result<bool>
89    where
90        K: AsRef<[u8]>,
91    {
92        let exists = self.0.contains_key(key)?;
93        Ok(exists)
94    }
95
96    /// Get a value by its key.
97    pub fn get<K, T>(&self, key: K) -> Result<Option<T>>
98    where
99        K: AsRef<[u8]>,
100        T: Message + Default,
101    {
102        let maybe_data = self.0.get(key)?;
103        if let Some(data) = maybe_data {
104            let msg = T::decode(&*data)?;
105            Ok(Some(msg))
106        } else {
107            Ok(None)
108        }
109    }
110
111    /// Atomically retrieve then update a value.
112    pub fn update_and_fetch<K, V, F, T>(&self, key: K, mut f: F) -> Result<Option<T>>
113    where
114        K: AsRef<[u8]>,
115        F: FnMut(Option<T>) -> Option<V>,
116        V: Into<T>,
117        T: Message + Default,
118    {
119        // Escape the scoping of the closure so we can report the error.
120        let mut err: Option<Error> = None;
121        let maybe_data = self.0.update_and_fetch(key, |maybe_data| {
122            let maybe_msg = if let Some(data) = maybe_data {
123                match T::decode(data) {
124                    Ok(value) => Some(value),
125                    Err(e) => {
126                        err = Some(e.into());
127                        None
128                    }
129                }
130            } else {
131                None
132            };
133            if let Some(inserted) = f(maybe_msg) {
134                let mut buf = BytesMut::default();
135                let inserted_msg: T = inserted.into();
136                if let Err(e) = inserted_msg.encode(&mut buf) {
137                    err = Some(e.into());
138                    None
139                } else {
140                    Some(buf.as_bytes())
141                }
142            } else {
143                None
144            }
145        })?;
146
147        if let Some(e) = err {
148            return Err(e);
149        }
150
151        if let Some(data) = maybe_data {
152            let msg = T::decode(&*data)?;
153            Ok(Some(msg))
154        } else {
155            Ok(None)
156        }
157    }
158
159    /// Insert a value into the database.
160    pub fn insert<K, V, T>(&self, key: K, value: V) -> Result<Option<T>>
161    where
162        K: AsRef<[u8]>,
163        V: Into<T>,
164        T: Message + Default,
165    {
166        let mut buf = BytesMut::default();
167        let msg: T = value.into();
168        msg.encode(&mut buf)?;
169        let maybe_inserted = self.0.insert(key, buf.as_bytes())?;
170        if let Some(inserted) = maybe_inserted {
171            let msg = T::decode(&*inserted)?;
172            Ok(Some(msg))
173        } else {
174            Ok(None)
175        }
176    }
177}
178
179impl ProtoTree {
180    pub fn contains_key<K>(&self, key: K) -> Result<bool>
181    where
182        K: AsRef<[u8]>,
183    {
184        let exists = self.0.contains_key(key)?;
185        Ok(exists)
186    }
187
188    /// Get a value by its key.
189    pub fn get<K, T>(&self, key: K) -> Result<Option<T>>
190    where
191        K: AsRef<[u8]>,
192        T: Message + Default,
193    {
194        let maybe_data = self.0.get(key)?;
195        if let Some(data) = maybe_data {
196            let msg = T::decode(&*data)?;
197            Ok(Some(msg))
198        } else {
199            Ok(None)
200        }
201    }
202
203    /// Atomically retrieve then update a value.
204    pub fn update_and_fetch<K, V, F, T>(&self, key: K, mut f: F) -> Result<Option<T>>
205    where
206        K: AsRef<[u8]>,
207        F: FnMut(Option<T>) -> Option<V>,
208        V: Into<T>,
209        T: Message + Default,
210    {
211        // Escape the scoping of the closure so we can report the error.
212        let mut err: Option<Error> = None;
213        let maybe_data = self.0.update_and_fetch(key, |maybe_data| {
214            let maybe_msg = if let Some(data) = maybe_data {
215                match T::decode(data) {
216                    Ok(value) => Some(value),
217                    Err(e) => {
218                        err = Some(e.into());
219                        None
220                    }
221                }
222            } else {
223                None
224            };
225            if let Some(inserted) = f(maybe_msg) {
226                let mut buf = BytesMut::default();
227                let inserted_msg: T = inserted.into();
228                if let Err(e) = inserted_msg.encode(&mut buf) {
229                    err = Some(e.into());
230                    None
231                } else {
232                    Some(buf.as_bytes())
233                }
234            } else {
235                None
236            }
237        })?;
238
239        if let Some(e) = err {
240            return Err(e);
241        }
242
243        if let Some(data) = maybe_data {
244            let msg = T::decode(&*data)?;
245            Ok(Some(msg))
246        } else {
247            Ok(None)
248        }
249    }
250
251    /// Insert a value into the database.
252    pub fn insert<K, V, T>(&self, key: K, value: V) -> Result<Option<T>>
253    where
254        K: AsRef<[u8]>,
255        V: Into<T>,
256        T: Message + Default,
257    {
258        let mut buf = BytesMut::default();
259        let msg: T = value.into();
260        msg.encode(&mut buf)?;
261        let maybe_inserted = self.0.insert(key, buf.as_bytes())?;
262        if let Some(inserted) = maybe_inserted {
263            let msg = T::decode(&*inserted)?;
264            Ok(Some(msg))
265        } else {
266            Ok(None)
267        }
268    }
269}
270
271/// Convenience trait to convert to stdlib bytes. This is intended only for
272/// Convenience trait to convert to stdlib bytes. This is intended only for
273/// internal use (hence not `pub`) and is only implemented for `BytesMut`.
274trait BytesMutAsBytes {
275    fn as_bytes(&self) -> Vec<u8>;
276}
277
278/// Conversion between `BytesMut` and `Vec<u8>`. This is just a convenience as
279/// `sled` works with `Vec<u8>` and `prost` uses `BytesMut`.
280impl BytesMutAsBytes for BytesMut {
281    fn as_bytes(&self) -> Vec<u8> {
282        let bytes: &[u8] = &self;
283        bytes.to_owned()
284    }
285}
286
#[cfg(test)]
mod test {
    use rand::{distributions::Alphanumeric, thread_rng, Rng};

    use super::{ProtoDb, ProtoTree};

    mod messages {
        include!(concat!(env!("OUT_DIR"), "/messages.rs"));
    }

    /// Build a random 5-character alphanumeric string, used both as a key and
    /// as a distinguishable field value.
    fn random_key() -> String {
        thread_rng()
            .sample_iter(&Alphanumeric)
            .take(5)
            .map(char::from)
            .collect()
    }

    /// Open a fresh temporary sled database (sled's `Config::temporary`) for
    /// the calling test.
    ///
    /// Each test gets its own store rather than sharing a fixed on-disk
    /// directory. No test can observe keys written by another test or left
    /// behind by an earlier `cargo test` run, and no Unix-only path is
    /// hard-coded.
    fn proto_db() -> ProtoDb {
        let db = sled::Config::new()
            .temporary(true)
            .open()
            .expect("failed to open temporary sled DB for test");
        ProtoDb(db)
    }

    /// Open a tree named `"test"` inside a fresh temporary database.
    fn proto_tree() -> ProtoTree {
        let tree = proto_db()
            .0
            .open_tree("test")
            .expect("failed to create tree");
        ProtoTree(tree)
    }

    #[test]
    fn get_exists() {
        let db = proto_db();
        let thing = messages::Thing::default();
        let key = random_key();
        let _: Option<messages::Thing> = db.insert(&key, thing.clone()).unwrap();
        let retrieved = db.get(&key).unwrap().unwrap();
        assert_eq!(thing, retrieved);
    }

    #[test]
    fn get_no_exist() {
        let db = proto_db();
        let key = random_key();
        let retrieved: Option<messages::Thing> = db.get(&key).unwrap();
        assert!(retrieved.is_none());
    }

    #[test]
    fn update_and_fetch_existing() {
        let db = proto_db();
        let thing = messages::Thing::default();
        let key = random_key();
        let _: Option<messages::Thing> = db.insert(&key, thing).unwrap();
        let updated = db
            .update_and_fetch(&key, |maybe_msg| {
                let mut msg: messages::Thing = maybe_msg.unwrap();
                msg.i = "test".into();
                Some(msg)
            })
            .unwrap();
        let expected = messages::Thing {
            i: "test".into(),
            ..Default::default()
        };
        assert_eq!(updated, Some(expected));
    }

    #[test]
    fn update_and_fetch_no_existing() {
        let db = proto_db();
        let key = random_key();
        let updated = db
            .update_and_fetch(&key, |maybe_msg| {
                assert!(maybe_msg.is_none());
                Some(messages::Thing {
                    i: "test".into(),
                    ..Default::default()
                })
            })
            .unwrap();
        let expected = messages::Thing {
            i: "test".into(),
            ..Default::default()
        };
        assert_eq!(updated, Some(expected));
    }

    #[test]
    fn insert_not_existing() {
        let db = proto_db();
        let key = random_key();
        let thing = messages::Thing::default();
        let previous: Option<messages::Thing> = db.insert(&key, thing.clone()).unwrap();
        assert!(previous.is_none());
        let inserted: messages::Thing = db.get(&key).unwrap().unwrap();
        assert_eq!(inserted, thing);
    }

    #[test]
    fn insert_existing() {
        let db = proto_db();
        let key = random_key();
        let first_thing = messages::Thing::default();
        let previous: Option<messages::Thing> = db.insert(&key, first_thing.clone()).unwrap();
        assert!(previous.is_none());
        let inserted: messages::Thing = db.get(&key).unwrap().unwrap();
        assert_eq!(inserted, first_thing);

        let second_thing = messages::Thing {
            i: random_key(),
            ..Default::default()
        };
        let second_inserted: messages::Thing =
            db.insert(&key, second_thing.clone()).unwrap().unwrap();
        assert_eq!(first_thing, second_inserted);
        let retrieved: messages::Thing = db.get(&key).unwrap().unwrap();
        assert_eq!(retrieved, second_thing);
    }

    #[test]
    fn tree_get_exists() {
        let tree = proto_tree();
        let thing = messages::Thing::default();
        let key = random_key();
        let _: Option<messages::Thing> = tree.insert(&key, thing.clone()).unwrap();
        let retrieved = tree.get(&key).unwrap().unwrap();
        assert_eq!(thing, retrieved);
    }

    #[test]
    fn tree_get_no_exist() {
        let tree = proto_tree();
        let key = random_key();
        let retrieved: Option<messages::Thing> = tree.get(&key).unwrap();
        assert!(retrieved.is_none());
    }

    #[test]
    fn tree_update_and_fetch_existing() {
        let tree = proto_tree();
        let thing = messages::Thing::default();
        let key = random_key();
        let _: Option<messages::Thing> = tree.insert(&key, thing).unwrap();
        let updated = tree
            .update_and_fetch(&key, |maybe_msg| {
                let mut msg: messages::Thing = maybe_msg.unwrap();
                msg.i = "test".into();
                Some(msg)
            })
            .unwrap();
        let expected = messages::Thing {
            i: "test".into(),
            ..Default::default()
        };
        assert_eq!(updated, Some(expected));
    }

    #[test]
    fn tree_update_and_fetch_no_existing() {
        let tree = proto_tree();
        let key = random_key();
        let updated = tree
            .update_and_fetch(&key, |maybe_msg| {
                assert!(maybe_msg.is_none());
                Some(messages::Thing {
                    i: "test".into(),
                    ..Default::default()
                })
            })
            .unwrap();
        let expected = messages::Thing {
            i: "test".into(),
            ..Default::default()
        };
        assert_eq!(updated, Some(expected));
    }

    #[test]
    fn tree_insert_not_existing() {
        let tree = proto_tree();
        let key = random_key();
        let thing = messages::Thing::default();
        let previous: Option<messages::Thing> = tree.insert(&key, thing.clone()).unwrap();
        assert!(previous.is_none());
        let inserted: messages::Thing = tree.get(&key).unwrap().unwrap();
        assert_eq!(inserted, thing);
    }

    #[test]
    fn tree_insert_existing() {
        let tree = proto_tree();
        let key = random_key();
        let first_thing = messages::Thing::default();
        let previous: Option<messages::Thing> = tree.insert(&key, first_thing.clone()).unwrap();
        assert!(previous.is_none());
        let inserted: messages::Thing = tree.get(&key).unwrap().unwrap();
        assert_eq!(inserted, first_thing);

        let second_thing = messages::Thing {
            i: random_key(),
            ..Default::default()
        };
        let second_inserted: messages::Thing =
            tree.insert(&key, second_thing.clone()).unwrap().unwrap();
        assert_eq!(first_thing, second_inserted);
        let retrieved: messages::Thing = tree.get(&key).unwrap().unwrap();
        assert_eq!(retrieved, second_thing);
    }
}