rocksdb_table/
lib.rs

1//! Some helpers for working with RocksDB databases.
2
3use rocksdb::{DBCompressionType, DBIterator, IteratorMode, Options, DB};
4use std::borrow::Cow;
5use std::marker::PhantomData;
6use std::path::Path;
7use std::sync::Arc;
8
9pub mod error;
10
/// Marker types and traits that encode, at the type level, how a database was opened.
pub mod mode {
    /// Describes the access mode a database was opened in.
    ///
    /// Implemented by the marker structs [`ReadOnly`], [`Secondary`] and [`Writeable`].
    pub trait Mode: 'static {
        /// Returns `true` for the read-only mode.
        fn is_read_only() -> bool;

        /// Returns `true` for the secondary mode.
        fn is_secondary() -> bool;

        /// Returns `true` when the mode is neither read-only nor secondary.
        fn is_primary() -> bool {
            !(Self::is_read_only() || Self::is_secondary())
        }
    }

    /// Implemented by modes in which the database accepts writes.
    pub trait IsWriteable: Mode {}

    /// Implemented by modes in which the database is a secondary instance that can try to
    /// catch up with the primary.
    pub trait IsSecondary: Mode {}

    /// Implemented by modes that are opened from a single path (i.e. every mode except the
    /// secondary one).
    pub trait SinglePath: Mode {}

    /// Marker for a database opened read-only.
    #[derive(Clone, Copy)]
    pub struct ReadOnly;

    impl Mode for ReadOnly {
        fn is_read_only() -> bool {
            true
        }

        fn is_secondary() -> bool {
            false
        }
    }

    impl SinglePath for ReadOnly {}

    /// Marker for a database opened as a secondary instance.
    #[derive(Clone, Copy)]
    pub struct Secondary;

    impl Mode for Secondary {
        fn is_read_only() -> bool {
            false
        }

        fn is_secondary() -> bool {
            true
        }
    }

    impl IsSecondary for Secondary {}

    /// Marker for a database opened for reading and writing.
    #[derive(Clone, Copy)]
    pub struct Writeable;

    impl Mode for Writeable {
        fn is_read_only() -> bool {
            false
        }

        fn is_secondary() -> bool {
            false
        }
    }

    impl IsWriteable for Writeable {}
    impl SinglePath for Writeable {}
}
76
77/// A wrapper for a RocksDB database that maintains type information about whether it was opened
78/// in read-only mode.
79#[derive(Clone)]
80pub struct Database<M> {
81    pub db: Arc<DB>,
82    options: Options,
83    _mode: PhantomData<M>,
84}
85
86/// A database table.
87pub trait Table<M>: Sized {
88    type Counts;
89    type Error: From<error::Error>;
90    type Key;
91    type KeyBytes: AsRef<[u8]>;
92    type Value;
93    type ValueBytes: AsRef<[u8]>;
94    type Index;
95    type IndexBytes: AsRef<[u8]>;
96
97    fn database(&self) -> &Database<M>;
98    fn from_database(database: Database<M>) -> Self;
99    fn get_counts(&self) -> Result<Self::Counts, Self::Error>;
100
101    fn key_to_bytes(key: &Self::Key) -> Result<Self::KeyBytes, Self::Error>;
102    fn value_to_bytes(value: &Self::Value) -> Result<Self::ValueBytes, Self::Error>;
103    fn index_to_bytes(index: &Self::Index) -> Result<Self::IndexBytes, Self::Error>;
104
105    fn bytes_to_key(bytes: Cow<[u8]>) -> Result<Self::Key, Self::Error>;
106    fn bytes_to_value(bytes: Cow<[u8]>) -> Result<Self::Value, Self::Error>;
107
108    fn default_compression_type() -> Option<DBCompressionType> {
109        None
110    }
111
112    fn statistics(&self) -> Option<String> {
113        self.database().options.get_statistics()
114    }
115
116    fn get_estimated_key_count(&self) -> Result<Option<u64>, error::Error> {
117        Ok(self
118            .database()
119            .db
120            .property_int_value("rocksdb.estimate-num-keys")?)
121    }
122
123    fn open_with_defaults<P: AsRef<Path>>(path: P) -> Result<Self, error::Error>
124    where
125        M: mode::SinglePath,
126    {
127        Self::open(path, |mut options| {
128            if let Some(compression_type) = Self::default_compression_type() {
129                options.set_compression_type(compression_type);
130            }
131
132            options
133        })
134    }
135
136    fn open_as_secondary_with_defaults<P: AsRef<Path>, S: AsRef<Path>>(
137        path: P,
138        secondary_path: S,
139    ) -> Result<Self, error::Error>
140    where
141        M: mode::IsSecondary,
142    {
143        Self::open_as_secondary(path, secondary_path, |mut options| {
144            if let Some(compression_type) = Self::default_compression_type() {
145                options.set_compression_type(compression_type);
146            }
147
148            options
149        })
150    }
151
152    fn open<P: AsRef<Path>, F: FnMut(Options) -> Options>(
153        path: P,
154        mut options_init: F,
155    ) -> Result<Self, error::Error>
156    where
157        M: mode::SinglePath,
158    {
159        let mut options = Options::default();
160        options.create_if_missing(true);
161
162        let options = options_init(options);
163
164        let db = if M::is_read_only() {
165            DB::open_for_read_only(&options, path, true)?
166        } else {
167            DB::open(&options, path)?
168        };
169
170        Ok(Self::from_database(Database {
171            db: Arc::new(db),
172            options,
173            _mode: PhantomData,
174        }))
175    }
176
177    fn open_as_secondary<P: AsRef<Path>, S: AsRef<Path>, F: FnMut(Options) -> Options>(
178        path: P,
179        secondary_path: S,
180        mut options_init: F,
181    ) -> Result<Self, error::Error>
182    where
183        M: mode::IsSecondary,
184    {
185        let mut options = Options::default();
186        options.create_if_missing(true);
187
188        let options = options_init(options);
189        let db = DB::open_as_secondary(&options, path.as_ref(), secondary_path.as_ref())?;
190
191        Ok(Self::from_database(Database {
192            db: Arc::new(db),
193            options,
194            _mode: PhantomData,
195        }))
196    }
197
198    fn iter(&self) -> TableIterator<'_, M, Self>
199    where
200        M: 'static,
201    {
202        TableIterator {
203            underlying: self.database().db.iterator(IteratorMode::Start),
204            _mode: PhantomData,
205            _table: PhantomData,
206        }
207    }
208
209    fn iter_selected_values<P: Fn(&Self::Key) -> bool>(
210        &self,
211        pred: P,
212    ) -> SelectedValueTableIterator<'_, M, Self, P>
213    where
214        M: 'static,
215    {
216        SelectedValueTableIterator {
217            underlying: self.database().db.iterator(IteratorMode::Start),
218            pred,
219            _mode: PhantomData,
220            _table: PhantomData,
221        }
222    }
223
224    fn lookup_key(&self, key: &Self::Key) -> Result<Option<Self::Value>, Self::Error> {
225        let key_bytes = Self::key_to_bytes(key)?;
226        self.database()
227            .db
228            .get_pinned(key_bytes)
229            .map_err(error::Error::from)?
230            .map_or(Ok(None), |value_bytes| {
231                Self::bytes_to_value(Cow::from(value_bytes.as_ref())).map(Some)
232            })
233    }
234
235    fn lookup_index(&self, index: &Self::Index) -> IndexIterator<'_, M, Self>
236    where
237        M: 'static,
238    {
239        match Self::index_to_bytes(index) {
240            Ok(index_bytes) => IndexIterator::ValidIndex {
241                underlying: self.database().db.prefix_iterator(&index_bytes),
242                index_bytes,
243                _mode: PhantomData,
244                _table: PhantomData,
245            },
246            Err(error) => IndexIterator::InvalidIndex { error: Some(error) },
247        }
248    }
249
250    fn lookup_index_selected_values<P: Fn(&Self::Key) -> bool>(
251        &self,
252        index: &Self::Index,
253        pred: P,
254    ) -> SelectedValueIndexIterator<'_, M, Self, P>
255    where
256        M: 'static,
257    {
258        match Self::index_to_bytes(index) {
259            Ok(index_bytes) => SelectedValueIndexIterator::ValidIndex {
260                underlying: self.database().db.prefix_iterator(&index_bytes),
261                index_bytes,
262                pred,
263                _mode: PhantomData,
264                _table: PhantomData,
265            },
266            Err(error) => SelectedValueIndexIterator::InvalidIndex { error: Some(error) },
267        }
268    }
269
270    fn put(&self, key: &Self::Key, value: &Self::Value) -> Result<(), Self::Error>
271    where
272        M: mode::IsWriteable,
273    {
274        let key_bytes = Self::key_to_bytes(key)?;
275        let value_bytes = Self::value_to_bytes(value)?;
276        Ok(self
277            .database()
278            .db
279            .put(key_bytes, value_bytes)
280            .map_err(error::Error::from)?)
281    }
282
283    fn catch_up_with_primary(&self) -> Result<(), Self::Error>
284    where
285        M: mode::IsSecondary,
286    {
287        Ok(self
288            .database()
289            .db
290            .try_catch_up_with_primary()
291            .map_err(error::Error::from)?)
292    }
293}
294
295pub struct TableIterator<'a, M, T> {
296    underlying: DBIterator<'a>,
297    _mode: PhantomData<M>,
298    _table: PhantomData<T>,
299}
300
301impl<'a, M: mode::Mode, T: Table<M>> Iterator for TableIterator<'a, M, T> {
302    type Item = Result<(T::Key, T::Value), T::Error>;
303
304    fn next(&mut self) -> Option<Self::Item> {
305        self.underlying.next().map(|result| {
306            result
307                .map_err(|error| T::Error::from(error.into()))
308                .and_then(|(key_bytes, value_bytes)| {
309                    T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
310                        T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
311                            .map(|value| (key, value))
312                    })
313                })
314        })
315    }
316}
317
318/// Allows selection of values to decode (if for example this is expensive).
319pub struct SelectedValueTableIterator<'a, M, T, P> {
320    underlying: DBIterator<'a>,
321    pred: P,
322    _mode: PhantomData<M>,
323    _table: PhantomData<T>,
324}
325
326impl<'a, M: mode::Mode, T: Table<M>, P: Fn(&T::Key) -> bool> Iterator
327    for SelectedValueTableIterator<'a, M, T, P>
328{
329    type Item = Result<(T::Key, Option<T::Value>), T::Error>;
330
331    fn next(&mut self) -> Option<Self::Item> {
332        self.underlying.next().map(|result| {
333            result
334                .map_err(|error| T::Error::from(error.into()))
335                .and_then(|(key_bytes, value_bytes)| {
336                    T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
337                        if (self.pred)(&key) {
338                            T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
339                                .map(|value| (key, Some(value)))
340                        } else {
341                            Ok((key, None))
342                        }
343                    })
344                })
345        })
346    }
347}
348
349pub enum IndexIterator<'a, M, T: Table<M>> {
350    ValidIndex {
351        underlying: DBIterator<'a>,
352        index_bytes: T::IndexBytes,
353        _mode: PhantomData<M>,
354        _table: PhantomData<T>,
355    },
356    InvalidIndex {
357        error: Option<T::Error>,
358    },
359}
360
361impl<'a, M: mode::Mode, T: Table<M>> Iterator for IndexIterator<'a, M, T> {
362    type Item = Result<(T::Key, T::Value), T::Error>;
363
364    fn next(&mut self) -> Option<Self::Item> {
365        match self {
366            IndexIterator::ValidIndex {
367                underlying,
368                index_bytes,
369                ..
370            } => underlying.next().and_then(|result| match result {
371                Ok((key_bytes, value_bytes)) => {
372                    if key_bytes.starts_with(index_bytes.as_ref()) {
373                        Some(
374                            T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
375                                T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
376                                    .map(|value| (key, value))
377                            }),
378                        )
379                    } else {
380                        None
381                    }
382                }
383                Err(error) => Some(Err(T::Error::from(error.into()))),
384            }),
385            IndexIterator::InvalidIndex { error } => error.take().map(Err),
386        }
387    }
388}
389
390/// Allows selection of values to decode (if for example this is expensive).
391pub enum SelectedValueIndexIterator<'a, M, T: Table<M>, P> {
392    ValidIndex {
393        underlying: DBIterator<'a>,
394        index_bytes: T::IndexBytes,
395        pred: P,
396        _mode: PhantomData<M>,
397        _table: PhantomData<T>,
398    },
399    InvalidIndex {
400        error: Option<T::Error>,
401    },
402}
403
404impl<'a, M: mode::Mode, T: Table<M>, P: Fn(&T::Key) -> bool> Iterator
405    for SelectedValueIndexIterator<'a, M, T, P>
406{
407    type Item = Result<(T::Key, Option<T::Value>), T::Error>;
408
409    fn next(&mut self) -> Option<Self::Item> {
410        match self {
411            SelectedValueIndexIterator::ValidIndex {
412                underlying,
413                index_bytes,
414                pred,
415                ..
416            } => underlying.next().and_then(|result| match result {
417                Ok((key_bytes, value_bytes)) => {
418                    if key_bytes.starts_with(index_bytes.as_ref()) {
419                        Some(
420                            T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
421                                if (pred)(&key) {
422                                    T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
423                                        .map(|value| (key, Some(value)))
424                                } else {
425                                    Ok((key, None))
426                                }
427                            }),
428                        )
429                    } else {
430                        None
431                    }
432                }
433                Err(error) => Some(Err(T::Error::from(error.into()))),
434            }),
435            SelectedValueIndexIterator::InvalidIndex { error } => error.take().map(Err),
436        }
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Error type of the test table.
    #[derive(thiserror::Error, Debug)]
    pub enum Error {
        #[error("RocksDb table error")]
        RocksDbTable(#[from] error::Error),
        #[error("String encoding error")]
        Utf8(#[from] std::str::Utf8Error),
    }

    /// A minimal table mapping UTF-8 string keys to big-endian `u64` values.
    struct Dictionary<M> {
        database: Database<M>,
    }

    impl<M: mode::Mode> Table<M> for Dictionary<M> {
        type Counts = usize;
        type Error = Error;
        type Key = String;
        type KeyBytes = Vec<u8>;
        type Value = u64;
        type ValueBytes = [u8; 8];
        type Index = String;
        type IndexBytes = Vec<u8>;

        fn database(&self) -> &Database<M> {
            &self.database
        }

        fn from_database(database: Database<M>) -> Self {
            Self { database }
        }

        fn key_to_bytes(key: &Self::Key) -> Result<Self::KeyBytes, Self::Error> {
            Ok(key.as_bytes().to_vec())
        }

        fn value_to_bytes(value: &Self::Value) -> Result<Self::ValueBytes, Self::Error> {
            Ok(value.to_be_bytes())
        }

        fn index_to_bytes(index: &Self::Index) -> Result<Self::IndexBytes, Self::Error> {
            Ok(index.as_bytes().to_vec())
        }

        fn bytes_to_key(bytes: Cow<[u8]>) -> Result<Self::Key, Self::Error> {
            Ok(std::str::from_utf8(bytes.as_ref())?.to_string())
        }

        /// Decodes a big-endian `u64`; any input that is not exactly eight bytes long is
        /// reported as `InvalidValue`.
        fn bytes_to_value(bytes: Cow<[u8]>) -> Result<Self::Value, Self::Error> {
            let raw: &[u8] = bytes.as_ref();
            // Convert the whole slice instead of indexing `[0..8]` first: indexing would panic
            // on short input and silently truncate long input.
            let array: [u8; 8] = raw
                .try_into()
                .map_err(|_| error::Error::InvalidValue(raw.to_vec()))?;

            Ok(u64::from_be_bytes(array))
        }

        fn get_counts(&self) -> Result<Self::Counts, Error> {
            let mut count = 0;

            for result in self.iter() {
                result?;
                count += 1;
            }

            Ok(count)
        }
    }

    /// The entries written by every test, in insertion (not key) order.
    fn contents() -> Vec<(String, u64)> {
        vec![
            ("bar", 1000),
            ("baz", 98765),
            ("foo", 1),
            ("abc", 23),
            ("qux", 0),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect()
    }

    /// Opens a writeable dictionary in a fresh temporary directory and fills it with
    /// `contents()`.
    ///
    /// The `TempDir` guard is returned alongside the table: it removes the directory when
    /// dropped, so it has to outlive the open database.
    fn populated_dictionary() -> (tempfile::TempDir, Dictionary<mode::Writeable>) {
        let directory = tempfile::tempdir().unwrap();
        // Pass the path by reference so the guard is not moved into (and dropped by) `open`.
        let dictionary =
            Dictionary::<mode::Writeable>::open_with_defaults(directory.path()).unwrap();

        for (key, value) in contents() {
            dictionary.put(&key, &value).unwrap();
        }

        (directory, dictionary)
    }

    #[test]
    fn lookup_key() {
        let (_directory, dictionary) = populated_dictionary();

        assert_eq!(dictionary.lookup_key(&"foo".to_string()).unwrap(), Some(1));
        assert_eq!(
            dictionary.lookup_key(&"bar".to_string()).unwrap(),
            Some(1000)
        );
        assert_eq!(dictionary.lookup_key(&"XYZ".to_string()).unwrap(), None);
    }

    #[test]
    fn lookup_index() {
        let (_directory, dictionary) = populated_dictionary();

        assert_eq!(
            &dictionary
                .lookup_index(&"ba".to_string())
                .collect::<Result<Vec<_>, _>>()
                .unwrap(),
            &contents()[0..2].to_vec()
        );
    }

    #[test]
    fn iter() {
        let (_directory, dictionary) = populated_dictionary();

        let mut expected = contents();
        expected.sort();

        assert_eq!(
            dictionary.iter().collect::<Result<Vec<_>, _>>().unwrap(),
            expected
        );
    }

    #[test]
    fn get_counts() {
        let (_directory, dictionary) = populated_dictionary();

        assert_eq!(dictionary.get_counts().unwrap(), contents().len());
    }

    #[test]
    fn bytes_to_value_validates_length() {
        type Dict = Dictionary<mode::Writeable>;

        let valid = 7u64.to_be_bytes();
        assert_eq!(Dict::bytes_to_value(Cow::from(&valid[..])).unwrap(), 7);

        let too_short = [1u8, 2, 3];
        assert!(matches!(
            Dict::bytes_to_value(Cow::from(&too_short[..])),
            Err(Error::RocksDbTable(error::Error::InvalidValue(_)))
        ));

        let too_long = [0u8; 9];
        assert!(matches!(
            Dict::bytes_to_value(Cow::from(&too_long[..])),
            Err(Error::RocksDbTable(error::Error::InvalidValue(_)))
        ));
    }
}