rocksdb_table/
lib.rs

1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2#![allow(clippy::missing_errors_doc)]
3#![forbid(unsafe_code)]
4//! Some helpers for working with RocksDB databases.
5
6use rocksdb::{DBCompressionType, DBIterator, IteratorMode, Options, DB};
7use std::borrow::Cow;
8use std::marker::PhantomData;
9use std::path::Path;
10use std::sync::Arc;
11
12pub mod error;
13
/// Marker types and traits that encode, at the type level, how a database was opened.
pub mod mode {
    /// Describes the access mode a database handle was opened with.
    ///
    /// The predicates are associated functions (no `self`), so they can be queried from a type
    /// parameter alone, e.g. `M::is_read_only()`.
    pub trait Mode: 'static {
        /// Returns `true` if this mode opens the database read-only.
        #[must_use]
        fn is_read_only() -> bool;

        /// Returns `true` if this mode opens the database as a secondary instance.
        #[must_use]
        fn is_secondary() -> bool;

        /// Returns `true` if this mode is neither read-only nor secondary.
        #[must_use]
        fn is_primary() -> bool {
            !Self::is_read_only() && !Self::is_secondary()
        }
    }

    /// Indicates that a database is opened in write mode.
    pub trait IsWriteable: Mode {}

    /// Indicates that a database is opened in secondary mode and can try to catch up with the
    /// primary.
    pub trait IsSecondary: Mode {}

    /// Indicates that a database is opened in a mode that only requires a single path (i.e. not
    /// secondary mode).
    pub trait SinglePath: Mode {}

    /// Marker for a database opened read-only.
    #[derive(Clone, Copy, Debug)]
    pub struct ReadOnly;

    /// Marker for a database opened as a secondary instance.
    #[derive(Clone, Copy, Debug)]
    pub struct Secondary;

    /// Marker for a database opened for reading and writing.
    #[derive(Clone, Copy, Debug)]
    pub struct Writeable;

    impl Mode for ReadOnly {
        fn is_read_only() -> bool {
            true
        }

        fn is_secondary() -> bool {
            false
        }
    }

    impl Mode for Secondary {
        fn is_read_only() -> bool {
            false
        }

        fn is_secondary() -> bool {
            true
        }
    }

    impl Mode for Writeable {
        fn is_read_only() -> bool {
            false
        }

        fn is_secondary() -> bool {
            false
        }
    }

    impl IsWriteable for Writeable {}
    impl IsSecondary for Secondary {}
    impl SinglePath for ReadOnly {}
    impl SinglePath for Writeable {}
}
80
81/// A wrapper for a RocksDB database that maintains type information about whether it was opened
82/// in read-only mode.
83#[derive(Clone)]
84pub struct Database<M> {
85    pub db: Arc<DB>,
86    options: Options,
87    _mode: PhantomData<M>,
88}
89
90/// A database table.
91pub trait Table<M>: Sized {
92    type Counts;
93    type Error: From<error::Error>;
94    type Key;
95    type KeyBytes: AsRef<[u8]>;
96    type Value;
97    type ValueBytes: AsRef<[u8]>;
98    type Index;
99    type IndexBytes: AsRef<[u8]>;
100
101    fn database(&self) -> &Database<M>;
102    fn from_database(database: Database<M>) -> Self;
103    fn get_counts(&self) -> Result<Self::Counts, Self::Error>;
104
105    fn key_to_bytes(key: &Self::Key) -> Result<Self::KeyBytes, Self::Error>;
106    fn value_to_bytes(value: &Self::Value) -> Result<Self::ValueBytes, Self::Error>;
107    fn index_to_bytes(index: &Self::Index) -> Result<Self::IndexBytes, Self::Error>;
108
109    fn bytes_to_key(bytes: Cow<[u8]>) -> Result<Self::Key, Self::Error>;
110    fn bytes_to_value(bytes: Cow<[u8]>) -> Result<Self::Value, Self::Error>;
111
112    #[must_use]
113    fn default_compression_type() -> Option<DBCompressionType> {
114        None
115    }
116
117    fn statistics(&self) -> Option<String> {
118        self.database().options.get_statistics()
119    }
120
121    fn get_estimated_key_count(&self) -> Result<Option<u64>, error::Error> {
122        Ok(self
123            .database()
124            .db
125            .property_int_value("rocksdb.estimate-num-keys")?)
126    }
127
128    fn open_with_defaults<P: AsRef<Path>>(path: P) -> Result<Self, error::Error>
129    where
130        M: mode::SinglePath,
131    {
132        Self::open(path, |mut options| {
133            if let Some(compression_type) = Self::default_compression_type() {
134                options.set_compression_type(compression_type);
135            }
136
137            options
138        })
139    }
140
141    fn open_as_secondary_with_defaults<P: AsRef<Path>, S: AsRef<Path>>(
142        path: P,
143        secondary_path: S,
144    ) -> Result<Self, error::Error>
145    where
146        M: mode::IsSecondary,
147    {
148        Self::open_as_secondary(path, secondary_path, |mut options| {
149            if let Some(compression_type) = Self::default_compression_type() {
150                options.set_compression_type(compression_type);
151            }
152
153            options
154        })
155    }
156
157    fn open<P: AsRef<Path>, F: FnMut(Options) -> Options>(
158        path: P,
159        mut options_init: F,
160    ) -> Result<Self, error::Error>
161    where
162        M: mode::SinglePath,
163    {
164        let mut options = Options::default();
165        options.create_if_missing(true);
166
167        let options = options_init(options);
168
169        let db = if M::is_read_only() {
170            DB::open_for_read_only(&options, path, true)?
171        } else {
172            DB::open(&options, path)?
173        };
174
175        Ok(Self::from_database(Database {
176            db: Arc::new(db),
177            options,
178            _mode: PhantomData,
179        }))
180    }
181
182    fn open_as_secondary<P: AsRef<Path>, S: AsRef<Path>, F: FnMut(Options) -> Options>(
183        path: P,
184        secondary_path: S,
185        mut options_init: F,
186    ) -> Result<Self, error::Error>
187    where
188        M: mode::IsSecondary,
189    {
190        let mut options = Options::default();
191        options.create_if_missing(true);
192
193        let options = options_init(options);
194        let db = DB::open_as_secondary(&options, path.as_ref(), secondary_path.as_ref())?;
195
196        Ok(Self::from_database(Database {
197            db: Arc::new(db),
198            options,
199            _mode: PhantomData,
200        }))
201    }
202
203    fn iter(&self) -> TableIterator<'_, M, Self>
204    where
205        M: mode::Mode,
206    {
207        TableIterator {
208            underlying: self.database().db.iterator(IteratorMode::Start),
209            _mode: PhantomData,
210            _table: PhantomData,
211        }
212    }
213
214    fn iter_selected_values<P: Fn(&Self::Key) -> bool>(
215        &self,
216        pred: P,
217    ) -> SelectedValueTableIterator<'_, M, Self, P>
218    where
219        M: 'static,
220    {
221        SelectedValueTableIterator {
222            underlying: self.database().db.iterator(IteratorMode::Start),
223            pred,
224            _mode: PhantomData,
225            _table: PhantomData,
226        }
227    }
228
229    fn lookup_key(&self, key: &Self::Key) -> Result<Option<Self::Value>, Self::Error> {
230        let key_bytes = Self::key_to_bytes(key)?;
231        self.database()
232            .db
233            .get_pinned(key_bytes)
234            .map_err(error::Error::from)?
235            .map_or(Ok(None), |value_bytes| {
236                Self::bytes_to_value(Cow::from(value_bytes.as_ref())).map(Some)
237            })
238    }
239
240    fn lookup_index(&self, index: &Self::Index) -> IndexIterator<'_, M, Self>
241    where
242        M: 'static,
243    {
244        match Self::index_to_bytes(index) {
245            Ok(index_bytes) => IndexIterator::ValidIndex {
246                underlying: self.database().db.prefix_iterator(&index_bytes),
247                index_bytes,
248                _mode: PhantomData,
249                _table: PhantomData,
250            },
251            Err(error) => IndexIterator::InvalidIndex { error: Some(error) },
252        }
253    }
254
255    fn lookup_index_selected_values<P: Fn(&Self::Key) -> bool>(
256        &self,
257        index: &Self::Index,
258        pred: P,
259    ) -> SelectedValueIndexIterator<'_, M, Self, P>
260    where
261        M: 'static,
262    {
263        match Self::index_to_bytes(index) {
264            Ok(index_bytes) => SelectedValueIndexIterator::ValidIndex {
265                underlying: self.database().db.prefix_iterator(&index_bytes),
266                index_bytes,
267                pred,
268                _mode: PhantomData,
269                _table: PhantomData,
270            },
271            Err(error) => SelectedValueIndexIterator::InvalidIndex { error: Some(error) },
272        }
273    }
274
275    fn put(&self, key: &Self::Key, value: &Self::Value) -> Result<(), Self::Error>
276    where
277        M: mode::IsWriteable,
278    {
279        let key_bytes = Self::key_to_bytes(key)?;
280        let value_bytes = Self::value_to_bytes(value)?;
281        Ok(self
282            .database()
283            .db
284            .put(key_bytes, value_bytes)
285            .map_err(error::Error::from)?)
286    }
287
288    fn catch_up_with_primary(&self) -> Result<(), Self::Error>
289    where
290        M: mode::IsSecondary,
291    {
292        Ok(self
293            .database()
294            .db
295            .try_catch_up_with_primary()
296            .map_err(error::Error::from)?)
297    }
298}
299
300pub struct TableIterator<'a, M, T> {
301    underlying: DBIterator<'a>,
302    _mode: PhantomData<M>,
303    _table: PhantomData<T>,
304}
305
306impl<M: mode::Mode, T: Table<M>> Iterator for TableIterator<'_, M, T> {
307    type Item = Result<(T::Key, T::Value), T::Error>;
308
309    fn next(&mut self) -> Option<Self::Item> {
310        self.underlying.next().map(|result| {
311            result
312                .map_err(|error| T::Error::from(error.into()))
313                .and_then(|(key_bytes, value_bytes)| {
314                    T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
315                        T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
316                            .map(|value| (key, value))
317                    })
318                })
319        })
320    }
321}
322
323/// Allows selection of values to decode (if for example this is expensive).
324pub struct SelectedValueTableIterator<'a, M, T, P> {
325    underlying: DBIterator<'a>,
326    pred: P,
327    _mode: PhantomData<M>,
328    _table: PhantomData<T>,
329}
330
331impl<M: mode::Mode, T: Table<M>, P: Fn(&T::Key) -> bool> Iterator
332    for SelectedValueTableIterator<'_, M, T, P>
333{
334    type Item = Result<(T::Key, Option<T::Value>), T::Error>;
335
336    fn next(&mut self) -> Option<Self::Item> {
337        self.underlying.next().map(|result| {
338            result
339                .map_err(|error| T::Error::from(error.into()))
340                .and_then(|(key_bytes, value_bytes)| {
341                    T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
342                        if (self.pred)(&key) {
343                            T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
344                                .map(|value| (key, Some(value)))
345                        } else {
346                            Ok((key, None))
347                        }
348                    })
349                })
350        })
351    }
352}
353
354pub enum IndexIterator<'a, M, T: Table<M>> {
355    ValidIndex {
356        underlying: DBIterator<'a>,
357        index_bytes: T::IndexBytes,
358        _mode: PhantomData<M>,
359        _table: PhantomData<T>,
360    },
361    InvalidIndex {
362        error: Option<T::Error>,
363    },
364}
365
366impl<M: mode::Mode, T: Table<M>> Iterator for IndexIterator<'_, M, T> {
367    type Item = Result<(T::Key, T::Value), T::Error>;
368
369    fn next(&mut self) -> Option<Self::Item> {
370        match self {
371            IndexIterator::ValidIndex {
372                underlying,
373                index_bytes,
374                ..
375            } => underlying.next().and_then(|result| match result {
376                Ok((key_bytes, value_bytes)) => {
377                    if key_bytes.starts_with(index_bytes.as_ref()) {
378                        Some(
379                            T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
380                                T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
381                                    .map(|value| (key, value))
382                            }),
383                        )
384                    } else {
385                        None
386                    }
387                }
388                Err(error) => Some(Err(T::Error::from(error.into()))),
389            }),
390            IndexIterator::InvalidIndex { error } => error.take().map(Err),
391        }
392    }
393}
394
395/// Allows selection of values to decode (if for example this is expensive).
396pub enum SelectedValueIndexIterator<'a, M, T: Table<M>, P> {
397    ValidIndex {
398        underlying: DBIterator<'a>,
399        index_bytes: T::IndexBytes,
400        pred: P,
401        _mode: PhantomData<M>,
402        _table: PhantomData<T>,
403    },
404    InvalidIndex {
405        error: Option<T::Error>,
406    },
407}
408
409impl<M: mode::Mode, T: Table<M>, P: Fn(&T::Key) -> bool> Iterator
410    for SelectedValueIndexIterator<'_, M, T, P>
411{
412    type Item = Result<(T::Key, Option<T::Value>), T::Error>;
413
414    fn next(&mut self) -> Option<Self::Item> {
415        match self {
416            SelectedValueIndexIterator::ValidIndex {
417                underlying,
418                index_bytes,
419                pred,
420                ..
421            } => underlying.next().and_then(|result| match result {
422                Ok((key_bytes, value_bytes)) => {
423                    if key_bytes.starts_with(index_bytes.as_ref()) {
424                        Some(
425                            T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
426                                if (pred)(&key) {
427                                    T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
428                                        .map(|value| (key, Some(value)))
429                                } else {
430                                    Ok((key, None))
431                                }
432                            }),
433                        )
434                    } else {
435                        None
436                    }
437                }
438                Err(error) => Some(Err(T::Error::from(error.into()))),
439            }),
440            SelectedValueIndexIterator::InvalidIndex { error } => error.take().map(Err),
441        }
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Error type of the test table: wraps the crate error and UTF-8 key decoding failures.
    #[derive(thiserror::Error, Debug)]
    pub enum Error {
        #[error("RocksDb table error")]
        RocksDbTable(#[from] error::Error),
        #[error("String encoding error")]
        Utf8(#[from] std::str::Utf8Error),
    }

    /// Minimal table mapping UTF-8 string keys to big-endian `u64` values.
    struct Dictionary<M> {
        database: Database<M>,
    }

    impl<M: mode::Mode> Table<M> for Dictionary<M> {
        type Counts = usize;
        type Error = Error;
        type Key = String;
        type KeyBytes = Vec<u8>;
        type Value = u64;
        type ValueBytes = [u8; 8];
        type Index = String;
        type IndexBytes = Vec<u8>;

        fn database(&self) -> &Database<M> {
            &self.database
        }

        fn from_database(database: Database<M>) -> Self {
            Self { database }
        }

        fn key_to_bytes(key: &Self::Key) -> Result<Self::KeyBytes, Self::Error> {
            Ok(key.as_bytes().to_vec())
        }

        fn value_to_bytes(value: &Self::Value) -> Result<Self::ValueBytes, Self::Error> {
            Ok(value.to_be_bytes())
        }

        fn index_to_bytes(index: &Self::Index) -> Result<Self::IndexBytes, Self::Error> {
            Ok(index.as_bytes().to_vec())
        }

        fn bytes_to_key(bytes: Cow<[u8]>) -> Result<Self::Key, Self::Error> {
            Ok(std::str::from_utf8(bytes.as_ref())?.to_string())
        }

        fn bytes_to_value(bytes: Cow<[u8]>) -> Result<Self::Value, Self::Error> {
            // `get` instead of indexing: input shorter than eight bytes must surface as
            // `InvalidValue` rather than panic. As before, only the first eight bytes are read.
            let head: [u8; 8] = bytes
                .get(..8)
                .and_then(|head| head.try_into().ok())
                .ok_or_else(|| error::Error::InvalidValue(bytes.to_vec()))?;

            Ok(u64::from_be_bytes(head))
        }

        fn get_counts(&self) -> Result<Self::Counts, Error> {
            let mut count = 0;

            for result in self.iter() {
                result?;
                count += 1;
            }

            Ok(count)
        }
    }

    /// Fixture entries, deliberately not in key order.
    fn contents() -> Vec<(String, u64)> {
        vec![
            ("bar", 1000),
            ("baz", 98765),
            ("foo", 1),
            ("abc", 23),
            ("qux", 0),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect()
    }

    /// Opens a fresh writeable dictionary in a temporary directory and fills it with
    /// [`contents`].
    ///
    /// The `TempDir` guard is returned alongside the table: it removes the directory when
    /// dropped, so it has to outlive the database. The table is therefore opened through
    /// `directory.path()` instead of moving the guard into `open_with_defaults`.
    fn populated_dictionary() -> (tempfile::TempDir, Dictionary<mode::Writeable>) {
        let directory = tempfile::tempdir().unwrap();
        let dictionary =
            Dictionary::<mode::Writeable>::open_with_defaults(directory.path()).unwrap();

        for (key, value) in contents() {
            dictionary.put(&key, &value).unwrap();
        }

        (directory, dictionary)
    }

    #[test]
    fn lookup_key() {
        let (_directory, dictionary) = populated_dictionary();

        assert_eq!(dictionary.lookup_key(&"foo".to_string()).unwrap(), Some(1));
        assert_eq!(
            dictionary.lookup_key(&"bar".to_string()).unwrap(),
            Some(1000)
        );
        assert_eq!(dictionary.lookup_key(&"XYZ".to_string()).unwrap(), None);
    }

    #[test]
    fn lookup_index() {
        let (_directory, dictionary) = populated_dictionary();

        let found = dictionary
            .lookup_index(&"ba".to_string())
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        // "bar" and "baz" are the first two fixture entries and already in key order.
        assert_eq!(found, contents()[..2].to_vec());
    }

    #[test]
    fn iter() {
        let (_directory, dictionary) = populated_dictionary();

        let mut expected = contents();
        expected.sort();

        assert_eq!(
            dictionary.iter().collect::<Result<Vec<_>, _>>().unwrap(),
            expected
        );
    }

    #[test]
    fn get_counts() {
        let (_directory, dictionary) = populated_dictionary();

        assert_eq!(dictionary.get_counts().unwrap(), contents().len());
    }

    #[test]
    fn bytes_to_value_rejects_short_input() {
        let result = Dictionary::<mode::Writeable>::bytes_to_value(Cow::from(&[1_u8, 2, 3][..]));

        assert!(matches!(
            result,
            Err(Error::RocksDbTable(error::Error::InvalidValue(_)))
        ));
    }
}