1use crate::{
4 qmdb::Error,
5 store::{Store, StoreDeletable},
6};
7use commonware_codec::Codec;
8use commonware_utils::Array;
9use core::future::Future;
10use std::collections::BTreeMap;
11
/// Asynchronous, read-only lookup of values of type `V` by keys of type `K`.
///
/// [`Batch`] reads through this trait whenever a key has no pending change of its own.
pub trait Getter<K, V> {
    /// Looks up `key`, resolving to the value associated with it (if any) or to an [`Error`].
    fn get(&self, key: &K) -> impl Future<Output = Result<Option<V>, Error>>;
}
17
18impl<D> Getter<D::Key, D::Value> for D
20where
21 D: Store<Error = Error>,
22 D::Key: Array,
23 D::Value: Codec + Clone,
24{
25 async fn get(&self, key: &D::Key) -> Result<Option<D::Value>, Error> {
26 Store::get(self, key).await
27 }
28}
29
/// A set of pending changes layered over a borrowed [`Getter`].
///
/// Reads consult the pending changes first and fall back to the underlying `db`; nothing in
/// this type writes to `db`. The accumulated changes are obtained by consuming the batch
/// through its [`IntoIterator`] implementation.
pub struct Batch<'a, K, V, D>
where
    K: Array,
    V: Codec + Clone,
    D: Getter<K, V>,
{
    /// The underlying source of values for keys that have no pending change.
    db: &'a D,
    /// Pending changes by key: `Some(value)` is a pending write of `value`, `None` is a
    /// pending deletion.
    diff: BTreeMap<K, Option<V>>,
}
49
50impl<'a, K, V, D> Batch<'a, K, V, D>
51where
52 K: Array,
53 V: Codec + Clone,
54 D: Getter<K, V>,
55{
56 pub const fn new(db: &'a D) -> Self {
58 Self {
59 db,
60 diff: BTreeMap::new(),
61 }
62 }
63
64 pub async fn get(&self, key: &K) -> Result<Option<V>, Error> {
67 if let Some(value) = self.diff.get(key) {
68 return Ok(value.clone());
69 }
70
71 self.db.get(key).await
72 }
73
74 pub async fn create(&mut self, key: K, value: V) -> Result<bool, Error> {
77 if let Some(value_opt) = self.diff.get_mut(&key) {
78 match value_opt {
79 Some(_) => return Ok(false),
80 None => {
81 *value_opt = Some(value);
82 return Ok(true);
83 }
84 }
85 }
86
87 if self.db.get(&key).await?.is_some() {
88 return Ok(false);
89 }
90
91 self.diff.insert(key, Some(value));
92 Ok(true)
93 }
94
95 pub async fn update(&mut self, key: K, value: V) -> Result<(), Error> {
97 self.diff.insert(key, Some(value));
98
99 Ok(())
100 }
101
102 pub async fn delete(&mut self, key: K) -> Result<bool, Error> {
105 if let Some(entry) = self.diff.get_mut(&key) {
106 match entry {
107 Some(_) => {
108 *entry = None;
109 return Ok(true);
110 }
111 None => return Ok(false),
112 }
113 }
114
115 if self.db.get(&key).await?.is_some() {
116 self.diff.insert(key, None);
117 return Ok(true);
118 }
119
120 Ok(false)
121 }
122
123 pub async fn delete_unchecked(&mut self, key: K) -> Result<(), Error> {
125 self.diff.insert(key, None);
126
127 Ok(())
128 }
129}
130
131impl<'a, K, V, D> IntoIterator for Batch<'a, K, V, D>
132where
133 K: Array,
134 V: Codec + Clone,
135 D: Getter<K, V>,
136{
137 type Item = (K, Option<V>);
138 type IntoIter = std::collections::btree_map::IntoIter<K, Option<V>>;
139
140 fn into_iter(self) -> Self::IntoIter {
141 self.diff.into_iter()
142 }
143}
144
145pub trait Batchable: StoreDeletable<Key: Array, Value: Codec + Clone, Error = Error> {
147 fn start_batch(&self) -> Batch<'_, Self::Key, Self::Value, Self>
149 where
150 Self: Sized,
151 {
152 Batch {
153 db: self,
154 diff: BTreeMap::new(),
155 }
156 }
157
158 fn write_batch(
160 &mut self,
161 iter: impl Iterator<Item = (Self::Key, Option<Self::Value>)>,
162 ) -> impl Future<Output = Result<(), Error>> {
163 async {
164 for (key, value) in iter {
165 if let Some(value) = value {
166 self.update(key, value).await?;
167 } else {
168 self.delete(key).await?;
169 }
170 }
171 Ok(())
172 }
173 }
174}
175
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::store::StorePersistable;
    use commonware_cryptography::{blake3, sha256};
    use commonware_runtime::{
        deterministic::{self, Context},
        Runner as _,
    };
    use core::{fmt::Debug, future::Future};
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use std::collections::HashSet;

    /// A key type that can be built from a single seed byte for use in the batch tests.
    pub trait TestKey: Array {
        /// Builds the key that corresponds to `seed`.
        fn from_seed(seed: u8) -> Self;
    }

    /// A value type that can be built from a single seed byte for use in the batch tests.
    pub trait TestValue: Codec + Clone + PartialEq + Debug {
        /// Builds the value that corresponds to `seed`.
        fn from_seed(seed: u8) -> Self;
    }

    /// Runs the whole batch test suite against databases produced by `new_db`.
    ///
    /// The suite is executed twice, each time on a fresh default deterministic runner, and
    /// the auditor state of both runs is asserted to be equal.
    pub fn test_batch<D, F, Fut>(mut new_db: F)
    where
        F: FnMut(Context) -> Fut + Clone,
        Fut: Future<Output = D>,
        D: Batchable + StorePersistable,
        D::Key: TestKey,
        D::Value: TestValue,
    {
        // First run: uses a clone of the factory so the original stays available.
        let mut first_factory = new_db.clone();
        let first_runner = deterministic::Runner::default();
        let first_state = first_runner.start(|context| async move {
            let ctx = context.clone();
            let mut make_db = || first_factory(ctx.clone());
            run_batch_tests::<D, _, Fut>(&mut make_db).await.unwrap();
            ctx.auditor().state()
        });

        // Second run: same suite, fresh runner, original factory.
        let second_runner = deterministic::Runner::default();
        let second_state = second_runner.start(|context| async move {
            let ctx = context.clone();
            let mut make_db = || new_db(ctx.clone());
            run_batch_tests::<D, _, Fut>(&mut make_db).await.unwrap();
            ctx.auditor().state()
        });

        assert_eq!(first_state, second_state);
    }

    /// Runs every batch test in a fixed order, stopping at the first error.
    ///
    /// Each test obtains its own database from `new_db`.
    pub async fn run_batch_tests<D, F, Fut>(new_db: &mut F) -> Result<(), Error>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = D>,
        D: Batchable + StorePersistable,
        D::Key: TestKey,
        D::Value: TestValue,
    {
        test_overlay_reads(new_db).await?;
        test_create(new_db).await?;
        test_delete(new_db).await?;
        test_delete_unchecked(new_db).await?;
        test_write_batch_from_to_empty(new_db).await?;
        test_write_batch(new_db).await?;
        test_update_delete_update(new_db).await
    }

    /// Checks that a batch reads through to the database and that a pending write shadows
    /// the stored value.
    async fn test_overlay_reads<D, F, Fut>(new_db: &mut F) -> Result<(), Error>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = D>,
        D: Batchable + StorePersistable,
        D::Key: TestKey,
        D::Value: TestValue,
    {
        let mut db = new_db().await;
        let stored = D::Key::from_seed(1);
        db.update(stored.clone(), D::Value::from_seed(1)).await?;

        // Without a pending change the batch reports the stored value.
        let mut batch = db.start_batch();
        let seen = batch.get(&stored).await?;
        assert_eq!(seen, Some(D::Value::from_seed(1)));

        // A pending write takes precedence over the stored value.
        batch.update(stored.clone(), D::Value::from_seed(9)).await?;
        let seen = batch.get(&stored).await?;
        assert_eq!(seen, Some(D::Value::from_seed(9)));

        db.destroy().await?;
        Ok(())
    }

    /// Checks `Batch::create` for a fresh key, an already staged key, a key pending
    /// deletion and a key that exists in the database.
    async fn test_create<D, F, Fut>(new_db: &mut F) -> Result<(), Error>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = D>,
        D: Batchable + StorePersistable,
        D::Key: TestKey,
        D::Value: TestValue,
    {
        let mut db = new_db().await;
        let mut batch = db.start_batch();
        let fresh = D::Key::from_seed(2);

        // The first create succeeds, a second one for the same key is rejected.
        let created = batch.create(fresh.clone(), D::Value::from_seed(1)).await?;
        assert!(created);
        let created = batch.create(fresh.clone(), D::Value::from_seed(2)).await?;
        assert!(!created);

        // Once the key is pending deletion it can be created again.
        batch.delete_unchecked(fresh.clone()).await?;
        let created = batch.create(fresh.clone(), D::Value::from_seed(3)).await?;
        assert!(created);
        assert_eq!(batch.get(&fresh).await?, Some(D::Value::from_seed(3)));

        // A key that already exists in the database cannot be created.
        let stored = D::Key::from_seed(3);
        db.update(stored.clone(), D::Value::from_seed(4)).await?;
        let mut batch = db.start_batch();
        let created = batch.create(stored.clone(), D::Value::from_seed(5)).await?;
        assert!(!created);

        db.destroy().await?;
        Ok(())
    }

    /// Checks `Batch::delete` for a key stored in the database and for a key that only has
    /// a pending write.
    async fn test_delete<D, F, Fut>(new_db: &mut F) -> Result<(), Error>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = D>,
        D: Batchable + StorePersistable,
        D::Key: TestKey,
        D::Value: TestValue,
    {
        let mut db = new_db().await;

        // Key that lives in the database: the first delete succeeds, the second does not.
        let stored = D::Key::from_seed(4);
        db.update(stored.clone(), D::Value::from_seed(10)).await?;
        let mut batch = db.start_batch();
        let removed = batch.delete(stored.clone()).await?;
        assert!(removed);
        assert_eq!(batch.get(&stored).await?, None);
        let removed = batch.delete(stored.clone()).await?;
        assert!(!removed);

        // Key that only exists as a pending write in the batch.
        let mut batch = db.start_batch();
        let staged = D::Key::from_seed(5);
        batch.update(staged.clone(), D::Value::from_seed(11)).await?;
        let removed = batch.delete(staged.clone()).await?;
        assert!(removed);
        assert_eq!(batch.get(&staged).await?, None);
        let removed = batch.delete(staged).await?;
        assert!(!removed);

        db.destroy().await?;
        Ok(())
    }

    /// Checks that `Batch::delete_unchecked` hides both a pending write and a value stored
    /// in the database.
    async fn test_delete_unchecked<D, F, Fut>(new_db: &mut F) -> Result<(), Error>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = D>,
        D: Batchable + StorePersistable,
        D::Key: TestKey,
        D::Value: TestValue,
    {
        let mut db = new_db().await;
        let target = D::Key::from_seed(6);

        // Deleting over a pending write.
        let mut batch = db.start_batch();
        batch.update(target.clone(), D::Value::from_seed(12)).await?;
        batch.delete_unchecked(target.clone()).await?;
        let seen = batch.get(&target).await?;
        assert_eq!(seen, None);

        // Deleting over a value stored in the database.
        db.update(target.clone(), D::Value::from_seed(13)).await?;
        let mut batch = db.start_batch();
        batch.delete_unchecked(target.clone()).await?;
        let seen = batch.get(&target).await?;
        assert_eq!(seen, None);

        db.destroy().await?;
        Ok(())
    }

    /// Creates 100 keys, deletes a seeded pseudo-random subset of them through one batch and
    /// re-creates that subset through a second batch, checking the database after each
    /// written batch.
    async fn test_update_delete_update<D, F, Fut>(new_db: &mut F) -> Result<(), Error>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = D>,
        D: Batchable + StorePersistable,
        D::Key: TestKey,
        D::Value: TestValue,
    {
        let mut db = new_db().await;

        // Populate the database directly with keys 0..100.
        for i in 0..100u8 {
            let created = db
                .create(D::Key::from_seed(i), D::Value::from_seed(i))
                .await?;
            assert!(created);
        }
        db.commit().await?;

        // Delete roughly half of the keys, chosen by a fixed-seed generator.
        let mut rng = StdRng::seed_from_u64(1337);
        let mut deleted = HashSet::new();
        let mut batch = db.start_batch();
        for i in 0..100u8 {
            if !rng.gen_bool(0.5) {
                continue;
            }
            deleted.insert(i);
            let removed = batch.delete(D::Key::from_seed(i)).await?;
            assert!(removed);
        }
        // Also stage the deletion of a key that was never created.
        batch.delete_unchecked(D::Key::from_seed(255)).await?;

        let changes = batch.into_iter();
        db.write_batch(changes).await?;
        db.commit().await?;

        // Deleted keys are gone, all others still hold their original value.
        for i in 0..100u8 {
            let found = Store::get(&db, &D::Key::from_seed(i)).await?;
            let expected = (!deleted.contains(&i)).then(|| D::Value::from_seed(i));
            assert_eq!(found, expected);
        }

        // Re-create exactly the deleted keys, in ascending order.
        let mut batch = db.start_batch();
        for i in (0..100u8).filter(|i| deleted.contains(i)) {
            batch
                .create(D::Key::from_seed(i), D::Value::from_seed(i))
                .await?;
        }

        let changes = batch.into_iter();
        db.write_batch(changes).await?;
        db.commit().await?;

        // Every key is present again.
        for i in 0..100u8 {
            let found = Store::get(&db, &D::Key::from_seed(i)).await?;
            assert_eq!(found, Some(D::Value::from_seed(i)));
        }

        db.destroy().await?;

        Ok(())
    }

    /// Writes batches that take a database from empty to populated and back to empty, first
    /// with two keys and then, on a new database, with a single key.
    async fn test_write_batch_from_to_empty<D, F, Fut>(new_db: &mut F) -> Result<(), Error>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = D>,
        D: Batchable + StorePersistable,
        D::Key: TestKey,
        D::Value: TestValue,
    {
        // Two keys: create both, overwrite the first, then delete both.
        let mut db = new_db().await;
        let first = D::Key::from_seed(1);
        let second = D::Key::from_seed(2);
        let mut fill = db.start_batch();
        fill.create(first.clone(), D::Value::from_seed(1)).await?;
        fill.create(second.clone(), D::Value::from_seed(2)).await?;
        fill.update(first.clone(), D::Value::from_seed(3)).await?;
        let changes = fill.into_iter();
        db.write_batch(changes).await?;

        let found = Store::get(&db, &first).await?;
        assert_eq!(found, Some(D::Value::from_seed(3)));
        let found = Store::get(&db, &second).await?;
        assert_eq!(found, Some(D::Value::from_seed(2)));

        let mut drain = db.start_batch();
        drain.delete(first.clone()).await?;
        drain.delete(second.clone()).await?;
        let changes = drain.into_iter();
        db.write_batch(changes).await?;
        assert_eq!(Store::get(&db, &first).await?, None);
        assert_eq!(Store::get(&db, &second).await?, None);

        db.destroy().await?;

        // Single key on a new database: create it, then delete it.
        let mut db = new_db().await;
        let only = D::Key::from_seed(1);
        let mut fill = db.start_batch();
        fill.create(only.clone(), D::Value::from_seed(1)).await?;
        let changes = fill.into_iter();
        db.write_batch(changes).await?;
        let found = Store::get(&db, &only).await?;
        assert_eq!(found, Some(D::Value::from_seed(1)));

        let mut drain = db.start_batch();
        drain.delete(only.clone()).await?;
        let changes = drain.into_iter();
        db.write_batch(changes).await?;
        assert_eq!(Store::get(&db, &only).await?, None);

        db.destroy().await?;

        Ok(())
    }

    /// Writes a batch that overwrites a stored key and creates a new one, then a batch that
    /// deletes the stored key, checking the database after each.
    async fn test_write_batch<D, F, Fut>(new_db: &mut F) -> Result<(), Error>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = D>,
        D: Batchable + StorePersistable,
        D::Key: TestKey,
        D::Value: TestValue,
    {
        let mut db = new_db().await;
        let stored = D::Key::from_seed(7);
        db.update(stored.clone(), D::Value::from_seed(0)).await?;

        // Overwrite the stored key and add a new one in a single batch.
        let added = D::Key::from_seed(8);
        let mut batch = db.start_batch();
        batch.update(stored.clone(), D::Value::from_seed(8)).await?;
        batch.create(added.clone(), D::Value::from_seed(9)).await?;
        let changes = batch.into_iter();
        db.write_batch(changes).await?;

        let found = Store::get(&db, &stored).await?;
        assert_eq!(found, Some(D::Value::from_seed(8)));
        let found = Store::get(&db, &added).await?;
        assert_eq!(found, Some(D::Value::from_seed(9)));

        // Remove the originally stored key through a second batch.
        let mut removal = db.start_batch();
        removal.delete(stored.clone()).await?;
        let changes = removal.into_iter();
        db.write_batch(changes).await?;
        assert_eq!(Store::get(&db, &stored).await?, None);

        db.destroy().await?;
        Ok(())
    }

    /// Returns 32 bytes whose first byte is `seed` and whose remaining bytes are zero.
    fn seed_bytes(seed: u8) -> [u8; 32] {
        core::array::from_fn(|idx| if idx == 0 { seed } else { 0 })
    }

    impl TestKey for blake3::Digest {
        /// Converts the 32 seed bytes into a digest.
        fn from_seed(seed: u8) -> Self {
            seed_bytes(seed).into()
        }
    }

    impl TestKey for sha256::Digest {
        /// Converts the 32 seed bytes into a digest.
        fn from_seed(seed: u8) -> Self {
            seed_bytes(seed).into()
        }
    }

    impl TestValue for Vec<u8> {
        /// Produces a one-byte vector holding `seed`.
        fn from_seed(seed: u8) -> Self {
            Vec::from([seed])
        }
    }

    impl TestValue for blake3::Digest {
        /// Converts the 32 seed bytes into a digest.
        fn from_seed(seed: u8) -> Self {
            seed_bytes(seed).into()
        }
    }

    impl TestValue for sha256::Digest {
        /// Converts the 32 seed bytes into a digest.
        fn from_seed(seed: u8) -> Self {
            seed_bytes(seed).into()
        }
    }
}