1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2#![allow(clippy::missing_errors_doc)]
3#![forbid(unsafe_code)]
4use rocksdb::{DBCompressionType, DBIterator, IteratorMode, Options, DB};
7use std::borrow::Cow;
8use std::marker::PhantomData;
9use std::path::Path;
10use std::sync::Arc;
11
12pub mod error;
13
/// Type-level markers describing how a database is opened.
pub mod mode {
    /// An access mode, resolved at compile time through a marker type.
    pub trait Mode: 'static {
        /// Returns `true` when the mode is read-only.
        fn is_read_only() -> bool;

        /// Returns `true` when the mode is a secondary instance.
        fn is_secondary() -> bool;

        /// Returns `true` when the mode is neither read-only nor secondary.
        #[must_use]
        fn is_primary() -> bool {
            !Self::is_read_only() && !Self::is_secondary()
        }
    }

    /// Marker for modes that permit writes.
    pub trait IsWriteable: Mode {}

    /// Marker for modes that open the database as a secondary instance.
    pub trait IsSecondary: Mode {}

    /// Marker for modes that are opened from a single path.
    pub trait SinglePath: Mode {}

    /// Read-only access.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    pub struct ReadOnly;

    /// Access as a secondary instance.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    pub struct Secondary;

    /// Read-write (primary) access.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    pub struct Writeable;

    impl Mode for ReadOnly {
        fn is_read_only() -> bool {
            true
        }

        fn is_secondary() -> bool {
            false
        }
    }

    impl Mode for Secondary {
        fn is_read_only() -> bool {
            false
        }

        fn is_secondary() -> bool {
            true
        }
    }

    impl Mode for Writeable {
        fn is_read_only() -> bool {
            false
        }

        fn is_secondary() -> bool {
            false
        }
    }

    impl IsWriteable for Writeable {}
    impl IsSecondary for Secondary {}
    impl SinglePath for ReadOnly {}
    impl SinglePath for Writeable {}
}
80
81#[derive(Clone)]
84pub struct Database<M> {
85 pub db: Arc<DB>,
86 options: Options,
87 _mode: PhantomData<M>,
88}
89
90pub trait Table<M>: Sized {
92 type Counts;
93 type Error: From<error::Error>;
94 type Key;
95 type KeyBytes: AsRef<[u8]>;
96 type Value;
97 type ValueBytes: AsRef<[u8]>;
98 type Index;
99 type IndexBytes: AsRef<[u8]>;
100
101 fn database(&self) -> &Database<M>;
102 fn from_database(database: Database<M>) -> Self;
103 fn get_counts(&self) -> Result<Self::Counts, Self::Error>;
104
105 fn key_to_bytes(key: &Self::Key) -> Result<Self::KeyBytes, Self::Error>;
106 fn value_to_bytes(value: &Self::Value) -> Result<Self::ValueBytes, Self::Error>;
107 fn index_to_bytes(index: &Self::Index) -> Result<Self::IndexBytes, Self::Error>;
108
109 fn bytes_to_key(bytes: Cow<[u8]>) -> Result<Self::Key, Self::Error>;
110 fn bytes_to_value(bytes: Cow<[u8]>) -> Result<Self::Value, Self::Error>;
111
112 #[must_use]
113 fn default_compression_type() -> Option<DBCompressionType> {
114 None
115 }
116
117 fn statistics(&self) -> Option<String> {
118 self.database().options.get_statistics()
119 }
120
121 fn get_estimated_key_count(&self) -> Result<Option<u64>, error::Error> {
122 Ok(self
123 .database()
124 .db
125 .property_int_value("rocksdb.estimate-num-keys")?)
126 }
127
128 fn open_with_defaults<P: AsRef<Path>>(path: P) -> Result<Self, error::Error>
129 where
130 M: mode::SinglePath,
131 {
132 Self::open(path, |mut options| {
133 if let Some(compression_type) = Self::default_compression_type() {
134 options.set_compression_type(compression_type);
135 }
136
137 options
138 })
139 }
140
141 fn open_as_secondary_with_defaults<P: AsRef<Path>, S: AsRef<Path>>(
142 path: P,
143 secondary_path: S,
144 ) -> Result<Self, error::Error>
145 where
146 M: mode::IsSecondary,
147 {
148 Self::open_as_secondary(path, secondary_path, |mut options| {
149 if let Some(compression_type) = Self::default_compression_type() {
150 options.set_compression_type(compression_type);
151 }
152
153 options
154 })
155 }
156
157 fn open<P: AsRef<Path>, F: FnMut(Options) -> Options>(
158 path: P,
159 mut options_init: F,
160 ) -> Result<Self, error::Error>
161 where
162 M: mode::SinglePath,
163 {
164 let mut options = Options::default();
165 options.create_if_missing(true);
166
167 let options = options_init(options);
168
169 let db = if M::is_read_only() {
170 DB::open_for_read_only(&options, path, true)?
171 } else {
172 DB::open(&options, path)?
173 };
174
175 Ok(Self::from_database(Database {
176 db: Arc::new(db),
177 options,
178 _mode: PhantomData,
179 }))
180 }
181
182 fn open_as_secondary<P: AsRef<Path>, S: AsRef<Path>, F: FnMut(Options) -> Options>(
183 path: P,
184 secondary_path: S,
185 mut options_init: F,
186 ) -> Result<Self, error::Error>
187 where
188 M: mode::IsSecondary,
189 {
190 let mut options = Options::default();
191 options.create_if_missing(true);
192
193 let options = options_init(options);
194 let db = DB::open_as_secondary(&options, path.as_ref(), secondary_path.as_ref())?;
195
196 Ok(Self::from_database(Database {
197 db: Arc::new(db),
198 options,
199 _mode: PhantomData,
200 }))
201 }
202
203 fn iter(&self) -> TableIterator<'_, M, Self>
204 where
205 M: mode::Mode,
206 {
207 TableIterator {
208 underlying: self.database().db.iterator(IteratorMode::Start),
209 _mode: PhantomData,
210 _table: PhantomData,
211 }
212 }
213
214 fn iter_selected_values<P: Fn(&Self::Key) -> bool>(
215 &self,
216 pred: P,
217 ) -> SelectedValueTableIterator<'_, M, Self, P>
218 where
219 M: 'static,
220 {
221 SelectedValueTableIterator {
222 underlying: self.database().db.iterator(IteratorMode::Start),
223 pred,
224 _mode: PhantomData,
225 _table: PhantomData,
226 }
227 }
228
229 fn lookup_key(&self, key: &Self::Key) -> Result<Option<Self::Value>, Self::Error> {
230 let key_bytes = Self::key_to_bytes(key)?;
231 self.database()
232 .db
233 .get_pinned(key_bytes)
234 .map_err(error::Error::from)?
235 .map_or(Ok(None), |value_bytes| {
236 Self::bytes_to_value(Cow::from(value_bytes.as_ref())).map(Some)
237 })
238 }
239
240 fn lookup_index(&self, index: &Self::Index) -> IndexIterator<'_, M, Self>
241 where
242 M: 'static,
243 {
244 match Self::index_to_bytes(index) {
245 Ok(index_bytes) => IndexIterator::ValidIndex {
246 underlying: self.database().db.prefix_iterator(&index_bytes),
247 index_bytes,
248 _mode: PhantomData,
249 _table: PhantomData,
250 },
251 Err(error) => IndexIterator::InvalidIndex { error: Some(error) },
252 }
253 }
254
255 fn lookup_index_selected_values<P: Fn(&Self::Key) -> bool>(
256 &self,
257 index: &Self::Index,
258 pred: P,
259 ) -> SelectedValueIndexIterator<'_, M, Self, P>
260 where
261 M: 'static,
262 {
263 match Self::index_to_bytes(index) {
264 Ok(index_bytes) => SelectedValueIndexIterator::ValidIndex {
265 underlying: self.database().db.prefix_iterator(&index_bytes),
266 index_bytes,
267 pred,
268 _mode: PhantomData,
269 _table: PhantomData,
270 },
271 Err(error) => SelectedValueIndexIterator::InvalidIndex { error: Some(error) },
272 }
273 }
274
275 fn put(&self, key: &Self::Key, value: &Self::Value) -> Result<(), Self::Error>
276 where
277 M: mode::IsWriteable,
278 {
279 let key_bytes = Self::key_to_bytes(key)?;
280 let value_bytes = Self::value_to_bytes(value)?;
281 Ok(self
282 .database()
283 .db
284 .put(key_bytes, value_bytes)
285 .map_err(error::Error::from)?)
286 }
287
288 fn catch_up_with_primary(&self) -> Result<(), Self::Error>
289 where
290 M: mode::IsSecondary,
291 {
292 Ok(self
293 .database()
294 .db
295 .try_catch_up_with_primary()
296 .map_err(error::Error::from)?)
297 }
298}
299
300pub struct TableIterator<'a, M, T> {
301 underlying: DBIterator<'a>,
302 _mode: PhantomData<M>,
303 _table: PhantomData<T>,
304}
305
306impl<M: mode::Mode, T: Table<M>> Iterator for TableIterator<'_, M, T> {
307 type Item = Result<(T::Key, T::Value), T::Error>;
308
309 fn next(&mut self) -> Option<Self::Item> {
310 self.underlying.next().map(|result| {
311 result
312 .map_err(|error| T::Error::from(error.into()))
313 .and_then(|(key_bytes, value_bytes)| {
314 T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
315 T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
316 .map(|value| (key, value))
317 })
318 })
319 })
320 }
321}
322
323pub struct SelectedValueTableIterator<'a, M, T, P> {
325 underlying: DBIterator<'a>,
326 pred: P,
327 _mode: PhantomData<M>,
328 _table: PhantomData<T>,
329}
330
331impl<M: mode::Mode, T: Table<M>, P: Fn(&T::Key) -> bool> Iterator
332 for SelectedValueTableIterator<'_, M, T, P>
333{
334 type Item = Result<(T::Key, Option<T::Value>), T::Error>;
335
336 fn next(&mut self) -> Option<Self::Item> {
337 self.underlying.next().map(|result| {
338 result
339 .map_err(|error| T::Error::from(error.into()))
340 .and_then(|(key_bytes, value_bytes)| {
341 T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
342 if (self.pred)(&key) {
343 T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
344 .map(|value| (key, Some(value)))
345 } else {
346 Ok((key, None))
347 }
348 })
349 })
350 })
351 }
352}
353
354pub enum IndexIterator<'a, M, T: Table<M>> {
355 ValidIndex {
356 underlying: DBIterator<'a>,
357 index_bytes: T::IndexBytes,
358 _mode: PhantomData<M>,
359 _table: PhantomData<T>,
360 },
361 InvalidIndex {
362 error: Option<T::Error>,
363 },
364}
365
366impl<M: mode::Mode, T: Table<M>> Iterator for IndexIterator<'_, M, T> {
367 type Item = Result<(T::Key, T::Value), T::Error>;
368
369 fn next(&mut self) -> Option<Self::Item> {
370 match self {
371 IndexIterator::ValidIndex {
372 underlying,
373 index_bytes,
374 ..
375 } => underlying.next().and_then(|result| match result {
376 Ok((key_bytes, value_bytes)) => {
377 if key_bytes.starts_with(index_bytes.as_ref()) {
378 Some(
379 T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
380 T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
381 .map(|value| (key, value))
382 }),
383 )
384 } else {
385 None
386 }
387 }
388 Err(error) => Some(Err(T::Error::from(error.into()))),
389 }),
390 IndexIterator::InvalidIndex { error } => error.take().map(Err),
391 }
392 }
393}
394
395pub enum SelectedValueIndexIterator<'a, M, T: Table<M>, P> {
397 ValidIndex {
398 underlying: DBIterator<'a>,
399 index_bytes: T::IndexBytes,
400 pred: P,
401 _mode: PhantomData<M>,
402 _table: PhantomData<T>,
403 },
404 InvalidIndex {
405 error: Option<T::Error>,
406 },
407}
408
409impl<M: mode::Mode, T: Table<M>, P: Fn(&T::Key) -> bool> Iterator
410 for SelectedValueIndexIterator<'_, M, T, P>
411{
412 type Item = Result<(T::Key, Option<T::Value>), T::Error>;
413
414 fn next(&mut self) -> Option<Self::Item> {
415 match self {
416 SelectedValueIndexIterator::ValidIndex {
417 underlying,
418 index_bytes,
419 pred,
420 ..
421 } => underlying.next().and_then(|result| match result {
422 Ok((key_bytes, value_bytes)) => {
423 if key_bytes.starts_with(index_bytes.as_ref()) {
424 Some(
425 T::bytes_to_key(Cow::from(Vec::from(key_bytes))).and_then(|key| {
426 if (pred)(&key) {
427 T::bytes_to_value(Cow::from(Vec::from(value_bytes)))
428 .map(|value| (key, Some(value)))
429 } else {
430 Ok((key, None))
431 }
432 }),
433 )
434 } else {
435 None
436 }
437 }
438 Err(error) => Some(Err(T::Error::from(error.into()))),
439 }),
440 SelectedValueIndexIterator::InvalidIndex { error } => error.take().map(Err),
441 }
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Error type of the test table: table errors plus UTF-8 failures from key decoding.
    #[derive(thiserror::Error, Debug)]
    pub enum Error {
        #[error("RocksDb table error")]
        RocksDbTable(#[from] error::Error),
        #[error("String encoding error")]
        Utf8(#[from] std::str::Utf8Error),
    }

    /// Minimal table mapping UTF-8 strings to big-endian `u64` values.
    struct Dictionary<M> {
        database: Database<M>,
    }

    impl<M: mode::Mode> Table<M> for Dictionary<M> {
        type Counts = usize;
        type Error = Error;
        type Key = String;
        type KeyBytes = Vec<u8>;
        type Value = u64;
        type ValueBytes = [u8; 8];
        type Index = String;
        type IndexBytes = Vec<u8>;

        fn database(&self) -> &Database<M> {
            &self.database
        }

        fn from_database(database: Database<M>) -> Self {
            Self { database }
        }

        fn key_to_bytes(key: &Self::Key) -> Result<Self::KeyBytes, Self::Error> {
            Ok(key.as_bytes().to_vec())
        }

        fn value_to_bytes(value: &Self::Value) -> Result<Self::ValueBytes, Self::Error> {
            Ok(value.to_be_bytes())
        }

        fn index_to_bytes(index: &Self::Index) -> Result<Self::IndexBytes, Self::Error> {
            Ok(index.as_bytes().to_vec())
        }

        fn bytes_to_key(bytes: Cow<[u8]>) -> Result<Self::Key, Self::Error> {
            Ok(std::str::from_utf8(bytes.as_ref())?.to_string())
        }

        /// Decodes the first eight bytes as a big-endian `u64`.
        ///
        /// Input shorter than eight bytes is reported as `InvalidValue`. Checked slicing is
        /// used because indexing with `[0..8]` would panic on short input.
        fn bytes_to_value(bytes: Cow<[u8]>) -> Result<Self::Value, Self::Error> {
            let prefix: Option<[u8; 8]> = bytes
                .as_ref()
                .get(0..8)
                .and_then(|slice| slice.try_into().ok());
            let array =
                prefix.ok_or_else(|| error::Error::InvalidValue(bytes.as_ref().to_vec()))?;

            Ok(u64::from_be_bytes(array))
        }

        fn get_counts(&self) -> Result<Self::Counts, Error> {
            let mut count = 0;

            for result in self.iter() {
                result?;
                count += 1;
            }

            Ok(count)
        }
    }

    /// Sample entries, deliberately not in key order.
    fn contents() -> Vec<(String, u64)> {
        vec![
            ("bar", 1000),
            ("baz", 98765),
            ("foo", 1),
            ("abc", 23),
            ("qux", 0),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect()
    }

    /// Opens a writeable dictionary in a fresh temporary directory and fills it with
    /// [`contents`].
    ///
    /// The directory guard is returned alongside the table so the directory is not removed
    /// while the database is still open. The database is opened from `directory.path()`
    /// because passing the guard by value would drop it inside the open call.
    fn populated_dictionary() -> (tempfile::TempDir, Dictionary<mode::Writeable>) {
        let directory = tempfile::tempdir().unwrap();
        let dictionary =
            Dictionary::<mode::Writeable>::open_with_defaults(directory.path()).unwrap();

        for (key, value) in contents() {
            dictionary.put(&key, &value).unwrap();
        }

        (directory, dictionary)
    }

    #[test]
    fn lookup_key() {
        let (_directory, dictionary) = populated_dictionary();

        assert_eq!(dictionary.lookup_key(&"foo".to_string()).unwrap(), Some(1));
        assert_eq!(
            dictionary.lookup_key(&"bar".to_string()).unwrap(),
            Some(1000)
        );
        assert_eq!(dictionary.lookup_key(&"XYZ".to_string()).unwrap(), None);
    }

    #[test]
    fn lookup_index() {
        let (_directory, dictionary) = populated_dictionary();

        assert_eq!(
            &dictionary
                .lookup_index(&"ba".to_string())
                .collect::<Result<Vec<_>, _>>()
                .unwrap(),
            &contents()[0..2].to_vec()
        );
    }

    #[test]
    fn iter() {
        let (_directory, dictionary) = populated_dictionary();

        let mut expected = contents();
        expected.sort();

        assert_eq!(
            dictionary.iter().collect::<Result<Vec<_>, _>>().unwrap(),
            expected
        );
    }

    #[test]
    fn get_counts() {
        let (_directory, dictionary) = populated_dictionary();

        assert_eq!(dictionary.get_counts().unwrap(), contents().len());
    }

    #[test]
    fn bytes_to_value_rejects_short_input() {
        let result = Dictionary::<mode::Writeable>::bytes_to_value(Cow::from(&[0_u8, 1][..]));

        assert!(matches!(
            result,
            Err(Error::RocksDbTable(error::Error::InvalidValue(_)))
        ));
    }
}