1use super::{Deletable, Gettable, Updatable};
4use crate::qmdb::Error;
5use commonware_codec::CodecShared;
6use commonware_utils::Array;
7use std::{collections::BTreeMap, future::Future};
8
/// An in-memory set of staged key/value changes layered over a borrowed database.
///
/// Each entry in the diff maps a key to `Some(value)` for a staged update or to
/// `None` for a staged deletion. Reads through the batch consult the diff first
/// and fall back to the borrowed database for keys the batch has not touched.
///
/// The struct itself declares no trait bounds on its type parameters; the `impl`
/// blocks that need `Array`, `CodecShared` or `Gettable` state those bounds
/// (Rust API guidelines, C-STRUCT-BOUNDS), so naming the type never forces callers
/// to repeat them.
pub struct Batch<'a, K, V, D> {
    /// The database that reads fall back to for keys absent from `diff`.
    db: &'a D,
    /// Staged changes keyed by `K`: `Some(value)` is an update, `None` a deletion.
    diff: BTreeMap<K, Option<V>>,
}
28
29impl<'a, K, V, D> Batch<'a, K, V, D>
30where
31 K: Array,
32 V: CodecShared + Clone,
33 D: Gettable<Key = K, Value = V, Error = Error> + Sync,
34{
35 pub const fn new(db: &'a D) -> Self {
37 Self {
38 db,
39 diff: BTreeMap::new(),
40 }
41 }
42
43 pub async fn delete_unchecked(&mut self, key: K) -> Result<(), Error> {
45 self.diff.insert(key, None);
46
47 Ok(())
48 }
49}
50
51impl<'a, K, V, D> Gettable for Batch<'a, K, V, D>
52where
53 K: Array,
54 V: CodecShared + Clone,
55 D: Gettable<Key = K, Value = V, Error = Error> + Sync,
56{
57 type Key = K;
58 type Value = V;
59 type Error = Error;
60
61 async fn get(&self, key: &K) -> Result<Option<V>, Error> {
64 if let Some(value) = self.diff.get(key) {
65 return Ok(value.clone());
66 }
67
68 self.db.get(key).await
69 }
70}
71
72impl<'a, K, V, D> Updatable for Batch<'a, K, V, D>
73where
74 K: Array,
75 V: CodecShared + Clone,
76 D: Gettable<Key = K, Value = V, Error = Error> + Sync,
77{
78 async fn update(&mut self, key: K, value: V) -> Result<(), Error> {
80 self.diff.insert(key, Some(value));
81
82 Ok(())
83 }
84}
85
86impl<'a, K, V, D> Deletable for Batch<'a, K, V, D>
87where
88 K: Array,
89 V: CodecShared + Clone,
90 D: Gettable<Key = K, Value = V, Error = Error> + Sync,
91{
92 async fn delete(&mut self, key: K) -> Result<bool, Error> {
95 if let Some(entry) = self.diff.get_mut(&key) {
96 match entry {
97 Some(_) => {
98 *entry = None;
99 return Ok(true);
100 }
101 None => return Ok(false),
102 }
103 }
104
105 if self.db.get(&key).await?.is_some() {
106 self.diff.insert(key, None);
107 return Ok(true);
108 }
109
110 Ok(false)
111 }
112}
113
114impl<'a, K, V, D> IntoIterator for Batch<'a, K, V, D>
115where
116 K: Array,
117 V: CodecShared + Clone,
118 D: Gettable<Key = K, Value = V, Error = Error> + Sync,
119{
120 type Item = (K, Option<V>);
121 type IntoIter = std::collections::btree_map::IntoIter<K, Option<V>>;
122
123 fn into_iter(self) -> Self::IntoIter {
124 self.diff.into_iter()
125 }
126}
127
128pub trait Batchable:
130 Gettable<Key: Array, Value: CodecShared + Clone, Error = Error> + Updatable + Deletable
131{
132 fn start_batch(&self) -> Batch<'_, Self::Key, Self::Value, Self>
134 where
135 Self: Sized + Sync,
136 Self::Value: Send + Sync,
137 {
138 Batch {
139 db: self,
140 diff: BTreeMap::new(),
141 }
142 }
143
144 fn write_batch<'a, Iter>(
146 &'a mut self,
147 iter: Iter,
148 ) -> impl Future<Output = Result<(), Error>> + Send + use<'a, Self, Iter>
149 where
150 Self: Send,
151 Iter: Iterator<Item = (Self::Key, Option<Self::Value>)> + Send + 'a,
152 {
153 async move {
154 for (key, value) in iter {
155 if let Some(value) = value {
156 self.update(key, value).await?;
157 } else {
158 self.delete(key).await?;
159 }
160 }
161 Ok(())
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        kv::tests::{assert_deletable, assert_gettable, assert_send, assert_updatable},
        qmdb::store::db::Db,
        translator::TwoCap,
    };
    use commonware_cryptography::sha256::Digest;
    use commonware_runtime::deterministic::Context;

    /// Concrete store used to instantiate the batch in the compile-time checks below.
    type TestStore = Db<Context, Digest, Vec<u8>, TwoCap>;
    /// A batch layered over [`TestStore`].
    type TestBatch<'a> = Batch<'a, Digest, Vec<u8>, TestStore>;

    // Never called: this function only has to type-check, so the compiler verifies
    // that the batch's get/update/delete futures satisfy the `assert_*` helpers'
    // bounds (Send, per the function name). `dead_code` is expected here.
    #[allow(dead_code)]
    fn assert_batch_futures_are_send(b: &mut TestBatch<'_>, k: Digest) {
        assert_gettable(b, &k);
        assert_updatable(b, k, Vec::new());
        assert_deletable(b, k);
    }

    // Never called: compile-time check that the future returned by
    // `delete_unchecked` is accepted by `assert_send`. `dead_code` is expected here.
    #[allow(dead_code)]
    fn assert_batch_delete_unchecked_is_send(b: &mut TestBatch<'_>, k: Digest) {
        let pending = b.delete_unchecked(k);
        assert_send(pending);
    }
}