1use std::marker::PhantomData;
2
3use sled::transaction::{ConflictableTransactionResult, TransactionResult};
4
5use crate::{deserialize, serialize, Batch, Tree, Key, Value};
6
7pub struct TransactionalTree<'a, K, V>
8where K: Key, V:Value{
9 inner: &'a sled::transaction::TransactionalTree,
10 _key: PhantomData<fn() -> K>,
11 _value: PhantomData<fn() -> V>,
12}
13
14impl<'a, K:Key, V:Value> TransactionalTree<'a, K, V> {
15 pub(crate) fn new(sled: &'a sled::transaction::TransactionalTree) -> Self {
16 Self {
17 inner: sled,
18 _key: PhantomData,
19 _value: PhantomData,
20 }
21 }
22
23 pub fn insert(
24 &self,
25 key: &K,
26 value: &V,
27 ) -> std::result::Result<Option<V>, sled::transaction::UnabortableTransactionError>
28 {
29 self.inner
30 .insert(serialize(key), serialize(value))
31 .map(|opt| opt.map(|v| deserialize(&v)))
32 }
33
34 pub fn remove(
35 &self,
36 key: &K,
37 ) -> std::result::Result<Option<V>, sled::transaction::UnabortableTransactionError>
38 {
39 self.inner
40 .remove(serialize(key))
41 .map(|opt| opt.map(|v| deserialize(&v)))
42 }
43
44 pub fn get(
45 &self,
46 key: &K,
47 ) -> std::result::Result<Option<V>, sled::transaction::UnabortableTransactionError>
48 {
49 self.inner
50 .get(serialize(key))
51 .map(|opt| opt.map(|v| deserialize(&v)))
52 }
53
54 pub fn apply_batch(
55 &self,
56 batch: &Batch<K, V>,
57 ) -> std::result::Result<(), sled::transaction::UnabortableTransactionError> {
58 self.inner.apply_batch(&batch.inner)
59 }
60
61 pub fn flush(&self) {
62 self.inner.flush()
63 }
64
65 pub fn generate_id(&self) -> sled::Result<u64> {
66 self.inner.generate_id()
67 }
68}
69
70pub trait Transactional<E = ()> {
71 type View<'a>;
72
73 fn transaction<F, A>(&self, f: F) -> TransactionResult<A, E>
74 where
75 F: for<'a> Fn(Self::View<'a>) -> ConflictableTransactionResult<A, E>;
76}
77
/// Implements [`Transactional`] for a tuple of `&Tree<K, V>` references.
///
/// Input is a comma-separated list of `key_type, value_type, tuple_index`
/// triples, e.g. `impl_transactional!(K0, V0, 0, K1, V1, 1)`. It needs at
/// least two triples: with only one, `(&Tree<K, V>)` would be a parenthesized
/// type rather than a tuple.
macro_rules! impl_transactional {
    ($($key:ident, $value:ident, $idx:tt),+) => {
        impl<E, $($key: Key, $value: Value),+> Transactional<E> for ($(&Tree<$key, $value>),+) {
            type View<'a> = ($(TransactionalTree<'a, $key, $value>),+);

            fn transaction<F, A>(&self, f: F) -> TransactionResult<A, E>
            where
                F: for<'a> Fn(Self::View<'a>) -> ConflictableTransactionResult<A, E>,
            {
                // Import sled's trait anonymously: its `transaction` method
                // becomes callable on the tuple of raw trees without shadowing
                // this crate's `Transactional` name.
                use sled::Transactional as _;

                let raw_trees = ($(&self.$idx.inner),+);
                raw_trees.transaction(|sled_views| {
                    // Re-wrap each raw sled view in its typed counterpart,
                    // keeping the tuple order of the input trees.
                    let typed_views = ($(TransactionalTree::new(&sled_views.$idx)),+);
                    f(typed_views)
                })
            }
        }
    };
}
100
101impl_transactional!(K0, V0, 0, K1, V1, 1);
102impl_transactional!(K0, V0, 0, K1, V1, 1, K2, V2, 2);
103impl_transactional!(K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3);
104impl_transactional!(K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4);
105impl_transactional!(K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5);
106impl_transactional!(K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5, K6, V6, 6);
107impl_transactional!(
108 K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5, K6, V6, 6, K7, V7, 7
109);
110impl_transactional!(
111 K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5, K6, V6, 6, K7, V7, 7, K8, V8,
112 8
113);
114impl_transactional!(
115 K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5, K6, V6, 6, K7, V7, 7, K8, V8,
116 8, K9, V9, 9
117);
118impl_transactional!(
119 K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5, K6, V6, 6, K7, V7, 7, K8, V8,
120 8, K9, V9, 9, K10, V10, 10
121);
122
/// Checks that a multi-tree transaction commits to every tree, and that an
/// aborted one leaves all trees untouched.
#[test]
fn test_multiple_tree_transaction() {
    let db = sled::Config::new().temporary(true).open().unwrap();
    let tree0 = Tree::<u32, i32>::open(&db, "tree0");
    let tree1 = Tree::<u16, i16>::open(&db, "tree1");
    let tree2 = Tree::<u8, i8>::open(&db, "tree2");

    // A committed transaction writes to every tree in the tuple.
    (&tree0, &tree1, &tree2)
        .transaction(|trees| {
            trees.0.insert(&0, &0)?;
            trees.1.insert(&0, &0)?;
            trees.2.insert(&0, &0)?;
            Ok::<(), sled::transaction::ConflictableTransactionError<()>>(())
        })
        .unwrap();

    assert_eq!(tree0.get(&0), Ok(Some(0)));
    assert_eq!(tree1.get(&0), Ok(Some(0)));
    assert_eq!(tree2.get(&0), Ok(Some(0)));

    // An aborted transaction must not persist any of its writes.
    let aborted = (&tree0, &tree1).transaction(|trees| {
        trees.0.insert(&1, &1)?;
        trees.1.insert(&1, &1)?;
        Err::<(), _>(sled::transaction::ConflictableTransactionError::Abort(()))
    });
    assert!(matches!(
        aborted,
        Err(sled::transaction::TransactionError::Abort(()))
    ));
    assert_eq!(tree0.get(&1), Ok(None));
    assert_eq!(tree1.get(&1), Ok(None));
}