rocksdb_store/
mapper.rs

1use crate::wrapper::Transaction;
2
3use super::wrapper::Db;
4use bincode::serde::OwnedSerdeDecoder;
5use rocksdb::ColumnFamily;
6use std::io::{BufReader, Cursor};
7
8#[derive(thiserror::Error, Debug)]
9pub enum Error {
10    #[error("Unsupported configuration type")]
11    Unsupported,
12    #[error("Invalid RocksDB transaction state")]
13    InvalidTransaction,
14    #[error("Encoding error")]
15    Encoding(#[from] bincode::error::EncodeError),
16    #[error("Decoding error")]
17    Decoding(bincode::error::DecodeError),
18    #[error("Serde error")]
19    Serde(serde::de::value::Error),
20    #[error("RocksDb error")]
21    Db(#[from] rocksdb::Error),
22}
23
24impl serde::ser::Error for Error {
25    fn custom<T: std::fmt::Display>(msg: T) -> Self {
26        Self::Serde(serde::de::value::Error::custom(msg))
27    }
28}
29
30impl serde::de::Error for Error {
31    fn custom<T: std::fmt::Display>(msg: T) -> Self {
32        Self::Serde(serde::de::value::Error::custom(msg))
33    }
34
35    fn duplicate_field(field: &'static str) -> Self {
36        Self::Serde(serde::de::value::Error::duplicate_field(field))
37    }
38
39    fn invalid_length(len: usize, exp: &dyn serde::de::Expected) -> Self {
40        Self::Serde(serde::de::value::Error::invalid_length(len, exp))
41    }
42
43    fn invalid_type(unexp: serde::de::Unexpected, exp: &dyn serde::de::Expected) -> Self {
44        Self::Serde(serde::de::value::Error::invalid_type(unexp, exp))
45    }
46
47    fn invalid_value(unexp: serde::de::Unexpected, exp: &dyn serde::de::Expected) -> Self {
48        Self::Serde(serde::de::value::Error::invalid_value(unexp, exp))
49    }
50
51    fn missing_field(field: &'static str) -> Self {
52        Self::Serde(serde::de::value::Error::missing_field(field))
53    }
54
55    fn unknown_field(field: &str, expected: &'static [&'static str]) -> Self {
56        Self::Serde(serde::de::value::Error::unknown_field(field, expected))
57    }
58
59    fn unknown_variant(variant: &str, expected: &'static [&'static str]) -> Self {
60        Self::Serde(serde::de::value::Error::unknown_variant(variant, expected))
61    }
62}
63
64/// Maps a serializable struct onto a column family.
65pub struct TableMapper<'a, const W: bool, C> {
66    db: &'a Db,
67    tx: Option<Transaction<'a>>,
68    cf: &'a ColumnFamily,
69    bincode_config: C,
70}
71
72impl<'a, const W: bool, C> TableMapper<'a, W, C> {
73    pub(super) fn new(db: &'a Db, cf: &'a ColumnFamily, bincode_config: C) -> Self {
74        Self {
75            db,
76            tx: if W {
77                // Safe because we know the wrapper is writeable.
78                Some(db.transaction().unwrap())
79            } else {
80                None
81            },
82            cf,
83            bincode_config,
84        }
85    }
86}
87
88impl<'a, C: bincode::config::Config> serde::ser::SerializeStruct for TableMapper<'a, true, C> {
89    type Ok = ();
90    type Error = Error;
91
92    fn serialize_field<T: ?Sized + serde::Serialize>(
93        &mut self,
94        key: &'static str,
95        value: &T,
96    ) -> Result<(), Self::Error> {
97        let value_bytes = bincode::serde::encode_to_vec(value, self.bincode_config)?;
98
99        self.tx
100            .as_ref()
101            .ok_or(Error::InvalidTransaction)
102            .and_then(|tx| {
103                tx.put(self.cf, key.as_bytes(), value_bytes)
104                    .map_err(Error::from)
105            })
106    }
107
108    fn end(mut self) -> Result<Self::Ok, Self::Error> {
109        self.tx
110            .take()
111            .ok_or(Error::InvalidTransaction)
112            .and_then(|tx| tx.commit().map_err(Error::from))
113    }
114}
115
116impl<'a, C: bincode::config::Config> serde::ser::Serializer for TableMapper<'a, true, C> {
117    type Ok = ();
118    type Error = Error;
119
120    type SerializeSeq = Self;
121    type SerializeTuple = Self;
122    type SerializeTupleStruct = Self;
123    type SerializeTupleVariant = Self;
124    type SerializeMap = Self;
125    type SerializeStruct = Self;
126    type SerializeStructVariant = Self;
127
128    fn serialize_struct(
129        self,
130        _name: &'static str,
131        _len: usize,
132    ) -> Result<Self::SerializeStruct, Self::Error> {
133        Ok(self)
134    }
135
136    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
137        Ok(())
138    }
139
140    fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
141        Err(Error::Unsupported)
142    }
143
144    fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
145        Err(Error::Unsupported)
146    }
147
148    fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
149        Err(Error::Unsupported)
150    }
151
152    fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
153        Err(Error::Unsupported)
154    }
155
156    fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
157        Err(Error::Unsupported)
158    }
159
160    fn serialize_i128(self, _v: i128) -> Result<Self::Ok, Self::Error> {
161        Err(Error::Unsupported)
162    }
163
164    fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
165        Err(Error::Unsupported)
166    }
167
168    fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
169        Err(Error::Unsupported)
170    }
171
172    fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
173        Err(Error::Unsupported)
174    }
175
176    fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
177        Err(Error::Unsupported)
178    }
179
180    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
181        Err(Error::Unsupported)
182    }
183
184    fn serialize_newtype_struct<T: ?Sized + serde::Serialize>(
185        self,
186        _name: &'static str,
187        _value: &T,
188    ) -> Result<Self::Ok, Self::Error> {
189        Err(Error::Unsupported)
190    }
191
192    fn serialize_newtype_variant<T: ?Sized + serde::Serialize>(
193        self,
194        _name: &'static str,
195        _variant_index: u32,
196        _variant: &'static str,
197        _value: &T,
198    ) -> Result<Self::Ok, Self::Error> {
199        Err(Error::Unsupported)
200    }
201
202    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
203        Err(Error::Unsupported)
204    }
205
206    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
207        Err(Error::Unsupported)
208    }
209
210    fn serialize_some<T: ?Sized + serde::Serialize>(
211        self,
212        _value: &T,
213    ) -> Result<Self::Ok, Self::Error> {
214        Err(Error::Unsupported)
215    }
216
217    fn serialize_str(self, _v: &str) -> Result<Self::Ok, Self::Error> {
218        Err(Error::Unsupported)
219    }
220
221    fn serialize_struct_variant(
222        self,
223        _name: &'static str,
224        _variant_index: u32,
225        _variant: &'static str,
226        _len: usize,
227    ) -> Result<Self::SerializeStructVariant, Self::Error> {
228        Err(Error::Unsupported)
229    }
230
231    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
232        Err(Error::Unsupported)
233    }
234
235    fn serialize_tuple_struct(
236        self,
237        _name: &'static str,
238        _len: usize,
239    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
240        Err(Error::Unsupported)
241    }
242
243    fn serialize_tuple_variant(
244        self,
245        _name: &'static str,
246        _variant_index: u32,
247        _variant: &'static str,
248        _len: usize,
249    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
250        Err(Error::Unsupported)
251    }
252
253    fn serialize_u128(self, _v: u128) -> Result<Self::Ok, Self::Error> {
254        Err(Error::Unsupported)
255    }
256
257    fn serialize_u16(self, _v: u16) -> Result<Self::Ok, Self::Error> {
258        Err(Error::Unsupported)
259    }
260
261    fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
262        Err(Error::Unsupported)
263    }
264
265    fn serialize_u64(self, _v: u64) -> Result<Self::Ok, Self::Error> {
266        Err(Error::Unsupported)
267    }
268
269    fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> {
270        Err(Error::Unsupported)
271    }
272
273    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
274        Ok(())
275    }
276
277    fn serialize_unit_variant(
278        self,
279        _name: &'static str,
280        _variant_index: u32,
281        _variant: &'static str,
282    ) -> Result<Self::Ok, Self::Error> {
283        Err(Error::Unsupported)
284    }
285}
286
287impl<'a, C> serde::ser::SerializeMap for TableMapper<'a, true, C> {
288    type Ok = ();
289    type Error = Error;
290
291    fn end(self) -> Result<Self::Ok, Self::Error> {
292        Err(Error::Unsupported)
293    }
294
295    fn serialize_entry<K: ?Sized + serde::Serialize, V: ?Sized + serde::Serialize>(
296        &mut self,
297        _key: &K,
298        _value: &V,
299    ) -> Result<(), Self::Error> {
300        Err(Error::Unsupported)
301    }
302
303    fn serialize_key<T: ?Sized + serde::Serialize>(&mut self, _key: &T) -> Result<(), Self::Error> {
304        Err(Error::Unsupported)
305    }
306
307    fn serialize_value<T: ?Sized + serde::Serialize>(
308        &mut self,
309        _value: &T,
310    ) -> Result<(), Self::Error> {
311        Err(Error::Unsupported)
312    }
313}
314
315impl<'a, C> serde::ser::SerializeSeq for TableMapper<'a, true, C> {
316    type Ok = ();
317    type Error = Error;
318
319    fn end(self) -> Result<Self::Ok, Self::Error> {
320        Err(Error::Unsupported)
321    }
322
323    fn serialize_element<T: ?Sized + serde::Serialize>(
324        &mut self,
325        _value: &T,
326    ) -> Result<(), Self::Error> {
327        Err(Error::Unsupported)
328    }
329}
330
331impl<'a, C> serde::ser::SerializeStructVariant for TableMapper<'a, true, C> {
332    type Ok = ();
333    type Error = Error;
334
335    fn end(self) -> Result<Self::Ok, Self::Error> {
336        Err(Error::Unsupported)
337    }
338
339    fn serialize_field<T: ?Sized + serde::Serialize>(
340        &mut self,
341        _key: &'static str,
342        _value: &T,
343    ) -> Result<(), Self::Error> {
344        Err(Error::Unsupported)
345    }
346
347    fn skip_field(&mut self, _key: &'static str) -> Result<(), Self::Error> {
348        Err(Error::Unsupported)
349    }
350}
351
352impl<'a, C> serde::ser::SerializeTuple for TableMapper<'a, true, C> {
353    type Ok = ();
354    type Error = Error;
355
356    fn end(self) -> Result<Self::Ok, Self::Error> {
357        Err(Error::Unsupported)
358    }
359
360    fn serialize_element<T: ?Sized + serde::Serialize>(
361        &mut self,
362        _value: &T,
363    ) -> Result<(), Self::Error> {
364        Err(Error::Unsupported)
365    }
366}
367
368impl<'a, C> serde::ser::SerializeTupleStruct for TableMapper<'a, true, C> {
369    type Ok = ();
370    type Error = Error;
371
372    fn end(self) -> Result<Self::Ok, Self::Error> {
373        Err(Error::Unsupported)
374    }
375
376    fn serialize_field<T: ?Sized + serde::Serialize>(
377        &mut self,
378        _value: &T,
379    ) -> Result<(), Self::Error> {
380        Err(Error::Unsupported)
381    }
382}
383
384impl<'a, C> serde::ser::SerializeTupleVariant for TableMapper<'a, true, C> {
385    type Ok = ();
386    type Error = Error;
387
388    fn end(self) -> Result<Self::Ok, Self::Error> {
389        Err(Error::Unsupported)
390    }
391
392    fn serialize_field<T: ?Sized + serde::Serialize>(
393        &mut self,
394        _value: &T,
395    ) -> Result<(), Self::Error> {
396        Err(Error::Unsupported)
397    }
398}
399
400impl<'a, 'de: 'a, const W: bool, C: bincode::config::Config> serde::de::Deserializer<'de>
401    for &TableMapper<'a, W, C>
402{
403    type Error = Error;
404
405    fn deserialize_struct<V: serde::de::Visitor<'de>>(
406        self,
407        _name: &'static str,
408        fields: &'static [&'static str],
409        visitor: V,
410    ) -> Result<V::Value, Self::Error> {
411        visitor.visit_map(TableMapperAccess {
412            table: self,
413            fields,
414        })
415    }
416
417    fn is_human_readable(&self) -> bool {
418        false
419    }
420
421    fn deserialize_any<V: serde::de::Visitor<'de>>(
422        self,
423        _visitor: V,
424    ) -> Result<V::Value, Self::Error> {
425        Err(Error::Unsupported)
426    }
427
428    fn deserialize_bool<V: serde::de::Visitor<'de>>(
429        self,
430        _visitor: V,
431    ) -> Result<V::Value, Self::Error> {
432        Err(Error::Unsupported)
433    }
434
435    fn deserialize_byte_buf<V: serde::de::Visitor<'de>>(
436        self,
437        _visitor: V,
438    ) -> Result<V::Value, Self::Error> {
439        Err(Error::Unsupported)
440    }
441
442    fn deserialize_bytes<V: serde::de::Visitor<'de>>(
443        self,
444        _visitor: V,
445    ) -> Result<V::Value, Self::Error> {
446        Err(Error::Unsupported)
447    }
448
449    fn deserialize_char<V: serde::de::Visitor<'de>>(
450        self,
451        _visitor: V,
452    ) -> Result<V::Value, Self::Error> {
453        Err(Error::Unsupported)
454    }
455
456    fn deserialize_enum<V: serde::de::Visitor<'de>>(
457        self,
458        _name: &'static str,
459        _variants: &'static [&'static str],
460        _visitor: V,
461    ) -> Result<V::Value, Self::Error> {
462        Err(Error::Unsupported)
463    }
464
465    fn deserialize_f32<V: serde::de::Visitor<'de>>(
466        self,
467        _visitor: V,
468    ) -> Result<V::Value, Self::Error> {
469        Err(Error::Unsupported)
470    }
471
472    fn deserialize_f64<V: serde::de::Visitor<'de>>(
473        self,
474        _visitor: V,
475    ) -> Result<V::Value, Self::Error> {
476        Err(Error::Unsupported)
477    }
478
479    fn deserialize_i16<V: serde::de::Visitor<'de>>(
480        self,
481        _visitor: V,
482    ) -> Result<V::Value, Self::Error> {
483        Err(Error::Unsupported)
484    }
485
486    fn deserialize_i32<V: serde::de::Visitor<'de>>(
487        self,
488        _visitor: V,
489    ) -> Result<V::Value, Self::Error> {
490        Err(Error::Unsupported)
491    }
492
493    fn deserialize_i64<V: serde::de::Visitor<'de>>(
494        self,
495        _visitor: V,
496    ) -> Result<V::Value, Self::Error> {
497        Err(Error::Unsupported)
498    }
499
500    fn deserialize_i8<V: serde::de::Visitor<'de>>(
501        self,
502        _visitor: V,
503    ) -> Result<V::Value, Self::Error> {
504        Err(Error::Unsupported)
505    }
506
507    fn deserialize_identifier<V: serde::de::Visitor<'de>>(
508        self,
509        _visitor: V,
510    ) -> Result<V::Value, Self::Error> {
511        Err(Error::Unsupported)
512    }
513
514    fn deserialize_ignored_any<V: serde::de::Visitor<'de>>(
515        self,
516        _visitor: V,
517    ) -> Result<V::Value, Self::Error> {
518        Err(Error::Unsupported)
519    }
520
521    fn deserialize_newtype_struct<V: serde::de::Visitor<'de>>(
522        self,
523        _name: &'static str,
524        _visitor: V,
525    ) -> Result<V::Value, Self::Error> {
526        Err(Error::Unsupported)
527    }
528
529    fn deserialize_map<V: serde::de::Visitor<'de>>(
530        self,
531        _visitor: V,
532    ) -> Result<V::Value, Self::Error> {
533        Err(Error::Unsupported)
534    }
535
536    fn deserialize_option<V: serde::de::Visitor<'de>>(
537        self,
538        _visitor: V,
539    ) -> Result<V::Value, Self::Error> {
540        Err(Error::Unsupported)
541    }
542
543    fn deserialize_seq<V: serde::de::Visitor<'de>>(
544        self,
545        _visitor: V,
546    ) -> Result<V::Value, Self::Error> {
547        Err(Error::Unsupported)
548    }
549
550    fn deserialize_str<V: serde::de::Visitor<'de>>(
551        self,
552        _visitor: V,
553    ) -> Result<V::Value, Self::Error> {
554        Err(Error::Unsupported)
555    }
556
557    fn deserialize_string<V: serde::de::Visitor<'de>>(
558        self,
559        _visitor: V,
560    ) -> Result<V::Value, Self::Error> {
561        Err(Error::Unsupported)
562    }
563
564    fn deserialize_tuple<V: serde::de::Visitor<'de>>(
565        self,
566        _len: usize,
567        _visitor: V,
568    ) -> Result<V::Value, Self::Error> {
569        Err(Error::Unsupported)
570    }
571
572    fn deserialize_tuple_struct<V: serde::de::Visitor<'de>>(
573        self,
574        _name: &'static str,
575        _len: usize,
576        _visitor: V,
577    ) -> Result<V::Value, Self::Error> {
578        Err(Error::Unsupported)
579    }
580
581    fn deserialize_u16<V: serde::de::Visitor<'de>>(
582        self,
583        _visitor: V,
584    ) -> Result<V::Value, Self::Error> {
585        Err(Error::Unsupported)
586    }
587
588    fn deserialize_u32<V: serde::de::Visitor<'de>>(
589        self,
590        _visitor: V,
591    ) -> Result<V::Value, Self::Error> {
592        Err(Error::Unsupported)
593    }
594
595    fn deserialize_u64<V: serde::de::Visitor<'de>>(
596        self,
597        _visitor: V,
598    ) -> Result<V::Value, Self::Error> {
599        Err(Error::Unsupported)
600    }
601
602    fn deserialize_u8<V: serde::de::Visitor<'de>>(
603        self,
604        _visitor: V,
605    ) -> Result<V::Value, Self::Error> {
606        Err(Error::Unsupported)
607    }
608
609    fn deserialize_unit<V: serde::de::Visitor<'de>>(
610        self,
611        visitor: V,
612    ) -> Result<V::Value, Self::Error> {
613        visitor.visit_unit()
614    }
615
616    fn deserialize_unit_struct<V: serde::de::Visitor<'de>>(
617        self,
618        _name: &'static str,
619        visitor: V,
620    ) -> Result<V::Value, Self::Error> {
621        visitor.visit_unit()
622    }
623}
624
625struct TableMapperAccess<'a, const W: bool, C> {
626    table: &'a TableMapper<'a, W, C>,
627    fields: &'static [&'static str],
628}
629
630impl<'a, 'de: 'a, const W: bool, C: bincode::config::Config> serde::de::MapAccess<'de>
631    for TableMapperAccess<'a, W, C>
632{
633    type Error = Error;
634
635    fn next_key_seed<K: serde::de::DeserializeSeed<'de>>(
636        &mut self,
637        seed: K,
638    ) -> Result<Option<K::Value>, Self::Error> {
639        if self.fields.is_empty() {
640            Ok(None)
641        } else {
642            let deserializer = serde::de::value::StrDeserializer::new(self.fields[0]);
643
644            seed.deserialize(deserializer).map(Some)
645        }
646    }
647
648    fn next_value_seed<V: serde::de::DeserializeSeed<'de>>(
649        &mut self,
650        seed: V,
651    ) -> Result<V::Value, Self::Error> {
652        // In the case that the field is not found, we return the Bincode representation for `None`.
653        const BINCODE_NONE_BYTES: [u8; 1] = [0];
654
655        let field_name = self.fields[0].as_bytes();
656        self.fields = &self.fields[1..];
657
658        let bytes = self.table.db.get(self.table.cf, field_name)?;
659
660        match bytes {
661            Some(bytes) => {
662                let mut deserializer = OwnedSerdeDecoder::from_reader(
663                    BufReader::new(Cursor::new(bytes)),
664                    self.table.bincode_config,
665                );
666
667                seed.deserialize(deserializer.as_deserializer())
668                    .map_err(Error::Decoding)
669            }
670            None => {
671                let mut deserializer = OwnedSerdeDecoder::from_reader(
672                    BufReader::new(Cursor::new(BINCODE_NONE_BYTES)),
673                    self.table.bincode_config,
674                );
675
676                seed.deserialize(deserializer.as_deserializer())
677                    .map_err(Error::Decoding)
678            }
679        }
680    }
681}
682
#[cfg(test)]
mod tests {
    use quickcheck_arbitrary_derive::QuickCheck;
    use serde::{de::Deserialize, ser::Serialize};

    #[derive(
        Clone, Debug, Eq, PartialEq, QuickCheck, serde_derive::Deserialize, serde_derive::Serialize,
    )]
    struct Test {
        foo: String,
        bar: Vec<Option<u64>>,
        qux: bool,
    }

    /// Writes an arbitrary `Test`, reads it back, overwrites `foo`, writes again,
    /// and reads again. Both reads must equal what was written.
    #[quickcheck_macros::quickcheck]
    fn round_trip_test(test: Test, new_foo: String) -> bool {
        let mut options = rocksdb::Options::default();
        options.create_if_missing(true);
        options.create_missing_column_families(true);

        // Bind the temporary directory so it outlives the database. Passing the
        // `TempDir` by value would drop it, deleting the directory, as soon as
        // `open_cf_descriptors` returned. Declared first, it is dropped last.
        let dir = tempfile::tempdir().unwrap();

        let db = rocksdb::OptimisticTransactionDB::open_cf_descriptors(
            &options,
            dir.path(),
            vec![rocksdb::ColumnFamilyDescriptor::new(
                "test",
                rocksdb::Options::default(),
            )],
        )
        .unwrap();

        let wrapper = crate::wrapper::Db::from(db);

        let mapper = super::TableMapper::new(
            &wrapper,
            wrapper.handle("test").unwrap(),
            bincode::config::standard(),
        );

        test.serialize(mapper).unwrap();

        // Reads go through `db.get` and never touch a transaction, so use a
        // read-only mapper instead of opening one that is simply discarded.
        let mapper = super::TableMapper::<false, _>::new(
            &wrapper,
            wrapper.handle("test").unwrap(),
            bincode::config::standard(),
        );

        let read_test = Test::deserialize(&mapper).unwrap();

        let mut new_test = read_test.clone();
        new_test.foo = new_foo;

        let mapper = super::TableMapper::new(
            &wrapper,
            wrapper.handle("test").unwrap(),
            bincode::config::standard(),
        );

        new_test.serialize(mapper).unwrap();

        let mapper = super::TableMapper::<false, _>::new(
            &wrapper,
            wrapper.handle("test").unwrap(),
            bincode::config::standard(),
        );

        let new_read_test = Test::deserialize(&mapper).unwrap();

        read_test == test && new_read_test == new_test
    }
}
752}