1use crate::wrapper::Transaction;
2
3use super::wrapper::Db;
4use bincode::serde::OwnedSerdeDecoder;
5use rocksdb::ColumnFamily;
6use std::io::{BufReader, Cursor};
7
8#[derive(thiserror::Error, Debug)]
9pub enum Error {
10 #[error("Unsupported configuration type")]
11 Unsupported,
12 #[error("Invalid RocksDB transaction state")]
13 InvalidTransaction,
14 #[error("Encoding error")]
15 Encoding(#[from] bincode::error::EncodeError),
16 #[error("Decoding error")]
17 Decoding(bincode::error::DecodeError),
18 #[error("Serde error")]
19 Serde(serde::de::value::Error),
20 #[error("RocksDb error")]
21 Db(#[from] rocksdb::Error),
22}
23
24impl serde::ser::Error for Error {
25 fn custom<T: std::fmt::Display>(msg: T) -> Self {
26 Self::Serde(serde::de::value::Error::custom(msg))
27 }
28}
29
30impl serde::de::Error for Error {
31 fn custom<T: std::fmt::Display>(msg: T) -> Self {
32 Self::Serde(serde::de::value::Error::custom(msg))
33 }
34
35 fn duplicate_field(field: &'static str) -> Self {
36 Self::Serde(serde::de::value::Error::duplicate_field(field))
37 }
38
39 fn invalid_length(len: usize, exp: &dyn serde::de::Expected) -> Self {
40 Self::Serde(serde::de::value::Error::invalid_length(len, exp))
41 }
42
43 fn invalid_type(unexp: serde::de::Unexpected, exp: &dyn serde::de::Expected) -> Self {
44 Self::Serde(serde::de::value::Error::invalid_type(unexp, exp))
45 }
46
47 fn invalid_value(unexp: serde::de::Unexpected, exp: &dyn serde::de::Expected) -> Self {
48 Self::Serde(serde::de::value::Error::invalid_value(unexp, exp))
49 }
50
51 fn missing_field(field: &'static str) -> Self {
52 Self::Serde(serde::de::value::Error::missing_field(field))
53 }
54
55 fn unknown_field(field: &str, expected: &'static [&'static str]) -> Self {
56 Self::Serde(serde::de::value::Error::unknown_field(field, expected))
57 }
58
59 fn unknown_variant(variant: &str, expected: &'static [&'static str]) -> Self {
60 Self::Serde(serde::de::value::Error::unknown_variant(variant, expected))
61 }
62}
63
/// Maps the fields of a serde struct onto individual keys of one RocksDB column family.
///
/// The const parameter `W` selects whether a transaction is opened on construction;
/// the `serde::ser` implementations in this file exist only for `W = true`.
pub struct TableMapper<'a, const W: bool, C> {
    /// Database handle; deserialization reads stored values through it.
    db: &'a Db,
    /// Transaction that receives writes; populated at construction when `W` is `true`.
    tx: Option<Transaction<'a>>,
    /// Column family under which every field is stored.
    cf: &'a ColumnFamily,
    /// bincode configuration used to encode and decode field values.
    bincode_config: C,
}
71
72impl<'a, const W: bool, C> TableMapper<'a, W, C> {
73 pub(super) fn new(db: &'a Db, cf: &'a ColumnFamily, bincode_config: C) -> Self {
74 Self {
75 db,
76 tx: if W {
77 Some(db.transaction().unwrap())
79 } else {
80 None
81 },
82 cf,
83 bincode_config,
84 }
85 }
86}
87
88impl<'a, C: bincode::config::Config> serde::ser::SerializeStruct for TableMapper<'a, true, C> {
89 type Ok = ();
90 type Error = Error;
91
92 fn serialize_field<T: ?Sized + serde::Serialize>(
93 &mut self,
94 key: &'static str,
95 value: &T,
96 ) -> Result<(), Self::Error> {
97 let value_bytes = bincode::serde::encode_to_vec(value, self.bincode_config)?;
98
99 self.tx
100 .as_ref()
101 .ok_or(Error::InvalidTransaction)
102 .and_then(|tx| {
103 tx.put(self.cf, key.as_bytes(), value_bytes)
104 .map_err(Error::from)
105 })
106 }
107
108 fn end(mut self) -> Result<Self::Ok, Self::Error> {
109 self.tx
110 .take()
111 .ok_or(Error::InvalidTransaction)
112 .and_then(|tx| tx.commit().map_err(Error::from))
113 }
114}
115
116impl<'a, C: bincode::config::Config> serde::ser::Serializer for TableMapper<'a, true, C> {
117 type Ok = ();
118 type Error = Error;
119
120 type SerializeSeq = Self;
121 type SerializeTuple = Self;
122 type SerializeTupleStruct = Self;
123 type SerializeTupleVariant = Self;
124 type SerializeMap = Self;
125 type SerializeStruct = Self;
126 type SerializeStructVariant = Self;
127
128 fn serialize_struct(
129 self,
130 _name: &'static str,
131 _len: usize,
132 ) -> Result<Self::SerializeStruct, Self::Error> {
133 Ok(self)
134 }
135
136 fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
137 Ok(())
138 }
139
140 fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
141 Err(Error::Unsupported)
142 }
143
144 fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
145 Err(Error::Unsupported)
146 }
147
148 fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
149 Err(Error::Unsupported)
150 }
151
152 fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
153 Err(Error::Unsupported)
154 }
155
156 fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
157 Err(Error::Unsupported)
158 }
159
160 fn serialize_i128(self, _v: i128) -> Result<Self::Ok, Self::Error> {
161 Err(Error::Unsupported)
162 }
163
164 fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
165 Err(Error::Unsupported)
166 }
167
168 fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
169 Err(Error::Unsupported)
170 }
171
172 fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
173 Err(Error::Unsupported)
174 }
175
176 fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
177 Err(Error::Unsupported)
178 }
179
180 fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
181 Err(Error::Unsupported)
182 }
183
184 fn serialize_newtype_struct<T: ?Sized + serde::Serialize>(
185 self,
186 _name: &'static str,
187 _value: &T,
188 ) -> Result<Self::Ok, Self::Error> {
189 Err(Error::Unsupported)
190 }
191
192 fn serialize_newtype_variant<T: ?Sized + serde::Serialize>(
193 self,
194 _name: &'static str,
195 _variant_index: u32,
196 _variant: &'static str,
197 _value: &T,
198 ) -> Result<Self::Ok, Self::Error> {
199 Err(Error::Unsupported)
200 }
201
202 fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
203 Err(Error::Unsupported)
204 }
205
206 fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
207 Err(Error::Unsupported)
208 }
209
210 fn serialize_some<T: ?Sized + serde::Serialize>(
211 self,
212 _value: &T,
213 ) -> Result<Self::Ok, Self::Error> {
214 Err(Error::Unsupported)
215 }
216
217 fn serialize_str(self, _v: &str) -> Result<Self::Ok, Self::Error> {
218 Err(Error::Unsupported)
219 }
220
221 fn serialize_struct_variant(
222 self,
223 _name: &'static str,
224 _variant_index: u32,
225 _variant: &'static str,
226 _len: usize,
227 ) -> Result<Self::SerializeStructVariant, Self::Error> {
228 Err(Error::Unsupported)
229 }
230
231 fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
232 Err(Error::Unsupported)
233 }
234
235 fn serialize_tuple_struct(
236 self,
237 _name: &'static str,
238 _len: usize,
239 ) -> Result<Self::SerializeTupleStruct, Self::Error> {
240 Err(Error::Unsupported)
241 }
242
243 fn serialize_tuple_variant(
244 self,
245 _name: &'static str,
246 _variant_index: u32,
247 _variant: &'static str,
248 _len: usize,
249 ) -> Result<Self::SerializeTupleVariant, Self::Error> {
250 Err(Error::Unsupported)
251 }
252
253 fn serialize_u128(self, _v: u128) -> Result<Self::Ok, Self::Error> {
254 Err(Error::Unsupported)
255 }
256
257 fn serialize_u16(self, _v: u16) -> Result<Self::Ok, Self::Error> {
258 Err(Error::Unsupported)
259 }
260
261 fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
262 Err(Error::Unsupported)
263 }
264
265 fn serialize_u64(self, _v: u64) -> Result<Self::Ok, Self::Error> {
266 Err(Error::Unsupported)
267 }
268
269 fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> {
270 Err(Error::Unsupported)
271 }
272
273 fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
274 Ok(())
275 }
276
277 fn serialize_unit_variant(
278 self,
279 _name: &'static str,
280 _variant_index: u32,
281 _variant: &'static str,
282 ) -> Result<Self::Ok, Self::Error> {
283 Err(Error::Unsupported)
284 }
285}
286
287impl<'a, C> serde::ser::SerializeMap for TableMapper<'a, true, C> {
288 type Ok = ();
289 type Error = Error;
290
291 fn end(self) -> Result<Self::Ok, Self::Error> {
292 Err(Error::Unsupported)
293 }
294
295 fn serialize_entry<K: ?Sized + serde::Serialize, V: ?Sized + serde::Serialize>(
296 &mut self,
297 _key: &K,
298 _value: &V,
299 ) -> Result<(), Self::Error> {
300 Err(Error::Unsupported)
301 }
302
303 fn serialize_key<T: ?Sized + serde::Serialize>(&mut self, _key: &T) -> Result<(), Self::Error> {
304 Err(Error::Unsupported)
305 }
306
307 fn serialize_value<T: ?Sized + serde::Serialize>(
308 &mut self,
309 _value: &T,
310 ) -> Result<(), Self::Error> {
311 Err(Error::Unsupported)
312 }
313}
314
315impl<'a, C> serde::ser::SerializeSeq for TableMapper<'a, true, C> {
316 type Ok = ();
317 type Error = Error;
318
319 fn end(self) -> Result<Self::Ok, Self::Error> {
320 Err(Error::Unsupported)
321 }
322
323 fn serialize_element<T: ?Sized + serde::Serialize>(
324 &mut self,
325 _value: &T,
326 ) -> Result<(), Self::Error> {
327 Err(Error::Unsupported)
328 }
329}
330
331impl<'a, C> serde::ser::SerializeStructVariant for TableMapper<'a, true, C> {
332 type Ok = ();
333 type Error = Error;
334
335 fn end(self) -> Result<Self::Ok, Self::Error> {
336 Err(Error::Unsupported)
337 }
338
339 fn serialize_field<T: ?Sized + serde::Serialize>(
340 &mut self,
341 _key: &'static str,
342 _value: &T,
343 ) -> Result<(), Self::Error> {
344 Err(Error::Unsupported)
345 }
346
347 fn skip_field(&mut self, _key: &'static str) -> Result<(), Self::Error> {
348 Err(Error::Unsupported)
349 }
350}
351
352impl<'a, C> serde::ser::SerializeTuple for TableMapper<'a, true, C> {
353 type Ok = ();
354 type Error = Error;
355
356 fn end(self) -> Result<Self::Ok, Self::Error> {
357 Err(Error::Unsupported)
358 }
359
360 fn serialize_element<T: ?Sized + serde::Serialize>(
361 &mut self,
362 _value: &T,
363 ) -> Result<(), Self::Error> {
364 Err(Error::Unsupported)
365 }
366}
367
368impl<'a, C> serde::ser::SerializeTupleStruct for TableMapper<'a, true, C> {
369 type Ok = ();
370 type Error = Error;
371
372 fn end(self) -> Result<Self::Ok, Self::Error> {
373 Err(Error::Unsupported)
374 }
375
376 fn serialize_field<T: ?Sized + serde::Serialize>(
377 &mut self,
378 _value: &T,
379 ) -> Result<(), Self::Error> {
380 Err(Error::Unsupported)
381 }
382}
383
384impl<'a, C> serde::ser::SerializeTupleVariant for TableMapper<'a, true, C> {
385 type Ok = ();
386 type Error = Error;
387
388 fn end(self) -> Result<Self::Ok, Self::Error> {
389 Err(Error::Unsupported)
390 }
391
392 fn serialize_field<T: ?Sized + serde::Serialize>(
393 &mut self,
394 _value: &T,
395 ) -> Result<(), Self::Error> {
396 Err(Error::Unsupported)
397 }
398}
399
400impl<'a, 'de: 'a, const W: bool, C: bincode::config::Config> serde::de::Deserializer<'de>
401 for &TableMapper<'a, W, C>
402{
403 type Error = Error;
404
405 fn deserialize_struct<V: serde::de::Visitor<'de>>(
406 self,
407 _name: &'static str,
408 fields: &'static [&'static str],
409 visitor: V,
410 ) -> Result<V::Value, Self::Error> {
411 visitor.visit_map(TableMapperAccess {
412 table: self,
413 fields,
414 })
415 }
416
417 fn is_human_readable(&self) -> bool {
418 false
419 }
420
421 fn deserialize_any<V: serde::de::Visitor<'de>>(
422 self,
423 _visitor: V,
424 ) -> Result<V::Value, Self::Error> {
425 Err(Error::Unsupported)
426 }
427
428 fn deserialize_bool<V: serde::de::Visitor<'de>>(
429 self,
430 _visitor: V,
431 ) -> Result<V::Value, Self::Error> {
432 Err(Error::Unsupported)
433 }
434
435 fn deserialize_byte_buf<V: serde::de::Visitor<'de>>(
436 self,
437 _visitor: V,
438 ) -> Result<V::Value, Self::Error> {
439 Err(Error::Unsupported)
440 }
441
442 fn deserialize_bytes<V: serde::de::Visitor<'de>>(
443 self,
444 _visitor: V,
445 ) -> Result<V::Value, Self::Error> {
446 Err(Error::Unsupported)
447 }
448
449 fn deserialize_char<V: serde::de::Visitor<'de>>(
450 self,
451 _visitor: V,
452 ) -> Result<V::Value, Self::Error> {
453 Err(Error::Unsupported)
454 }
455
456 fn deserialize_enum<V: serde::de::Visitor<'de>>(
457 self,
458 _name: &'static str,
459 _variants: &'static [&'static str],
460 _visitor: V,
461 ) -> Result<V::Value, Self::Error> {
462 Err(Error::Unsupported)
463 }
464
465 fn deserialize_f32<V: serde::de::Visitor<'de>>(
466 self,
467 _visitor: V,
468 ) -> Result<V::Value, Self::Error> {
469 Err(Error::Unsupported)
470 }
471
472 fn deserialize_f64<V: serde::de::Visitor<'de>>(
473 self,
474 _visitor: V,
475 ) -> Result<V::Value, Self::Error> {
476 Err(Error::Unsupported)
477 }
478
479 fn deserialize_i16<V: serde::de::Visitor<'de>>(
480 self,
481 _visitor: V,
482 ) -> Result<V::Value, Self::Error> {
483 Err(Error::Unsupported)
484 }
485
486 fn deserialize_i32<V: serde::de::Visitor<'de>>(
487 self,
488 _visitor: V,
489 ) -> Result<V::Value, Self::Error> {
490 Err(Error::Unsupported)
491 }
492
493 fn deserialize_i64<V: serde::de::Visitor<'de>>(
494 self,
495 _visitor: V,
496 ) -> Result<V::Value, Self::Error> {
497 Err(Error::Unsupported)
498 }
499
500 fn deserialize_i8<V: serde::de::Visitor<'de>>(
501 self,
502 _visitor: V,
503 ) -> Result<V::Value, Self::Error> {
504 Err(Error::Unsupported)
505 }
506
507 fn deserialize_identifier<V: serde::de::Visitor<'de>>(
508 self,
509 _visitor: V,
510 ) -> Result<V::Value, Self::Error> {
511 Err(Error::Unsupported)
512 }
513
514 fn deserialize_ignored_any<V: serde::de::Visitor<'de>>(
515 self,
516 _visitor: V,
517 ) -> Result<V::Value, Self::Error> {
518 Err(Error::Unsupported)
519 }
520
521 fn deserialize_newtype_struct<V: serde::de::Visitor<'de>>(
522 self,
523 _name: &'static str,
524 _visitor: V,
525 ) -> Result<V::Value, Self::Error> {
526 Err(Error::Unsupported)
527 }
528
529 fn deserialize_map<V: serde::de::Visitor<'de>>(
530 self,
531 _visitor: V,
532 ) -> Result<V::Value, Self::Error> {
533 Err(Error::Unsupported)
534 }
535
536 fn deserialize_option<V: serde::de::Visitor<'de>>(
537 self,
538 _visitor: V,
539 ) -> Result<V::Value, Self::Error> {
540 Err(Error::Unsupported)
541 }
542
543 fn deserialize_seq<V: serde::de::Visitor<'de>>(
544 self,
545 _visitor: V,
546 ) -> Result<V::Value, Self::Error> {
547 Err(Error::Unsupported)
548 }
549
550 fn deserialize_str<V: serde::de::Visitor<'de>>(
551 self,
552 _visitor: V,
553 ) -> Result<V::Value, Self::Error> {
554 Err(Error::Unsupported)
555 }
556
557 fn deserialize_string<V: serde::de::Visitor<'de>>(
558 self,
559 _visitor: V,
560 ) -> Result<V::Value, Self::Error> {
561 Err(Error::Unsupported)
562 }
563
564 fn deserialize_tuple<V: serde::de::Visitor<'de>>(
565 self,
566 _len: usize,
567 _visitor: V,
568 ) -> Result<V::Value, Self::Error> {
569 Err(Error::Unsupported)
570 }
571
572 fn deserialize_tuple_struct<V: serde::de::Visitor<'de>>(
573 self,
574 _name: &'static str,
575 _len: usize,
576 _visitor: V,
577 ) -> Result<V::Value, Self::Error> {
578 Err(Error::Unsupported)
579 }
580
581 fn deserialize_u16<V: serde::de::Visitor<'de>>(
582 self,
583 _visitor: V,
584 ) -> Result<V::Value, Self::Error> {
585 Err(Error::Unsupported)
586 }
587
588 fn deserialize_u32<V: serde::de::Visitor<'de>>(
589 self,
590 _visitor: V,
591 ) -> Result<V::Value, Self::Error> {
592 Err(Error::Unsupported)
593 }
594
595 fn deserialize_u64<V: serde::de::Visitor<'de>>(
596 self,
597 _visitor: V,
598 ) -> Result<V::Value, Self::Error> {
599 Err(Error::Unsupported)
600 }
601
602 fn deserialize_u8<V: serde::de::Visitor<'de>>(
603 self,
604 _visitor: V,
605 ) -> Result<V::Value, Self::Error> {
606 Err(Error::Unsupported)
607 }
608
609 fn deserialize_unit<V: serde::de::Visitor<'de>>(
610 self,
611 visitor: V,
612 ) -> Result<V::Value, Self::Error> {
613 visitor.visit_unit()
614 }
615
616 fn deserialize_unit_struct<V: serde::de::Visitor<'de>>(
617 self,
618 _name: &'static str,
619 visitor: V,
620 ) -> Result<V::Value, Self::Error> {
621 visitor.visit_unit()
622 }
623}
624
/// Map-style cursor over the fields of a struct being deserialized from a [`TableMapper`].
struct TableMapperAccess<'a, const W: bool, C> {
    /// Mapper whose database, column family and bincode configuration are used for lookups.
    table: &'a TableMapper<'a, W, C>,
    /// Field names still to be visited; the first element is the next one yielded.
    fields: &'static [&'static str],
}
629
630impl<'a, 'de: 'a, const W: bool, C: bincode::config::Config> serde::de::MapAccess<'de>
631 for TableMapperAccess<'a, W, C>
632{
633 type Error = Error;
634
635 fn next_key_seed<K: serde::de::DeserializeSeed<'de>>(
636 &mut self,
637 seed: K,
638 ) -> Result<Option<K::Value>, Self::Error> {
639 if self.fields.is_empty() {
640 Ok(None)
641 } else {
642 let deserializer = serde::de::value::StrDeserializer::new(self.fields[0]);
643
644 seed.deserialize(deserializer).map(Some)
645 }
646 }
647
648 fn next_value_seed<V: serde::de::DeserializeSeed<'de>>(
649 &mut self,
650 seed: V,
651 ) -> Result<V::Value, Self::Error> {
652 const BINCODE_NONE_BYTES: [u8; 1] = [0];
654
655 let field_name = self.fields[0].as_bytes();
656 self.fields = &self.fields[1..];
657
658 let bytes = self.table.db.get(self.table.cf, field_name)?;
659
660 match bytes {
661 Some(bytes) => {
662 let mut deserializer = OwnedSerdeDecoder::from_reader(
663 BufReader::new(Cursor::new(bytes)),
664 self.table.bincode_config,
665 );
666
667 seed.deserialize(deserializer.as_deserializer())
668 .map_err(Error::Decoding)
669 }
670 None => {
671 let mut deserializer = OwnedSerdeDecoder::from_reader(
672 BufReader::new(Cursor::new(BINCODE_NONE_BYTES)),
673 self.table.bincode_config,
674 );
675
676 seed.deserialize(deserializer.as_deserializer())
677 .map_err(Error::Decoding)
678 }
679 }
680 }
681}
682
#[cfg(test)]
mod tests {
    use quickcheck_arbitrary_derive::QuickCheck;
    use serde::{de::Deserialize, ser::Serialize};

    /// Sample record covering a string, a list of optional integers and a flag.
    #[derive(
        Clone, Debug, Eq, PartialEq, QuickCheck, serde_derive::Deserialize, serde_derive::Serialize,
    )]
    struct Test {
        foo: String,
        bar: Vec<Option<u64>>,
        qux: bool,
    }

    /// Writes a record, reads it back, overwrites one field and reads it back again.
    #[quickcheck_macros::quickcheck]
    fn round_trip_test(test: Test, new_foo: String) -> bool {
        let mut options = rocksdb::Options::default();
        options.create_if_missing(true);
        options.create_missing_column_families(true);

        // Hold the guard for the whole test: handing the `TempDir` itself to
        // `open_cf_descriptors` would drop it (and delete the directory) as soon
        // as the open call returned, while the database is still in use.
        let dir = tempfile::tempdir().unwrap();

        let db = rocksdb::OptimisticTransactionDB::open_cf_descriptors(
            &options,
            dir.path(),
            vec![rocksdb::ColumnFamilyDescriptor::new(
                "test",
                rocksdb::Options::default(),
            )],
        )
        .unwrap();

        let wrapper = crate::wrapper::Db::from(db);

        // First write.
        let mapper = super::TableMapper::new(
            &wrapper,
            wrapper.handle("test").unwrap(),
            bincode::config::standard(),
        );

        test.serialize(mapper).unwrap();

        // Read back what was just written.
        let mapper = super::TableMapper::<true, _>::new(
            &wrapper,
            wrapper.handle("test").unwrap(),
            bincode::config::standard(),
        );

        let read_test = Test::deserialize(&mapper).unwrap();

        // Overwrite with a modified record.
        let mut new_test = read_test.clone();
        new_test.foo = new_foo;

        let mapper = super::TableMapper::new(
            &wrapper,
            wrapper.handle("test").unwrap(),
            bincode::config::standard(),
        );

        new_test.serialize(mapper).unwrap();

        // Read back the overwritten record.
        let mapper = super::TableMapper::<true, _>::new(
            &wrapper,
            wrapper.handle("test").unwrap(),
            bincode::config::standard(),
        );

        let new_read_test = Test::deserialize(&mapper).unwrap();

        read_test == test && new_read_test == new_test
    }
}