1use std::collections::{BTreeMap, BTreeSet};
20
21use sled::{
22 transaction::{ConflictableTransactionError, TransactionError},
23 IVec, Transactional,
24};
25
26use crate::{SledTreeOverlay, SledTreeOverlayIter, SledTreeOverlayStateDiff};
27
28#[derive(Debug, Clone)]
30pub struct SledDbOverlayState {
31 pub initial_tree_names: Vec<IVec>,
33 pub new_tree_names: Vec<IVec>,
35 pub caches: BTreeMap<IVec, SledTreeOverlay>,
37 pub dropped_trees: BTreeMap<IVec, SledTreeOverlayStateDiff>,
39 pub protected_tree_names: Vec<IVec>,
42}
43
44impl SledDbOverlayState {
45 pub fn new(initial_tree_names: Vec<IVec>, protected_tree_names: Vec<IVec>) -> Self {
47 Self {
48 initial_tree_names,
49 new_tree_names: vec![],
50 caches: BTreeMap::new(),
51 dropped_trees: BTreeMap::new(),
52 protected_tree_names,
53 }
54 }
55
56 fn aggregate(&self) -> Result<(Vec<sled::Tree>, Vec<sled::Batch>), sled::Error> {
60 let mut trees = vec![];
61 let mut batches = vec![];
62
63 for (key, cache) in self.caches.iter() {
64 if self.dropped_trees.contains_key(key) {
65 return Err(sled::Error::CollectionNotFound(key.into()));
66 }
67
68 if let Some(batch) = cache.aggregate() {
69 trees.push(cache.tree.clone());
70 batches.push(batch);
71 }
72 }
73
74 Ok((trees, batches))
75 }
76
77 pub fn add_diff(
79 &mut self,
80 db: &sled::Db,
81 diff: &SledDbOverlayStateDiff,
82 ) -> Result<(), sled::Error> {
83 self.initial_tree_names
84 .retain(|x| diff.initial_tree_names.contains(x));
85
86 for (k, (cache, drop)) in diff.caches.iter() {
87 if *drop {
88 assert!(!self.protected_tree_names.contains(k));
89 self.new_tree_names.retain(|x| x != k);
90 self.caches.remove(k);
91 self.dropped_trees.insert(k.clone(), cache.clone());
92 continue;
93 }
94
95 let Some(tree_overlay) = self.caches.get_mut(k) else {
96 if !self.initial_tree_names.contains(k) && !self.new_tree_names.contains(k) {
97 self.new_tree_names.push(k.clone());
98 }
99 let mut overlay = SledTreeOverlay::new(&db.open_tree(k)?);
100 overlay.add_diff(cache);
101 self.caches.insert(k.clone(), overlay);
102 continue;
103 };
104
105 tree_overlay.add_diff(cache);
107 }
108
109 for (k, (cache, restored)) in &diff.dropped_trees {
110 if !restored {
112 if self.dropped_trees.contains_key(k) {
113 continue;
114 }
115 self.new_tree_names.retain(|x| x != k);
116 self.caches.remove(k);
117 self.dropped_trees.insert(k.clone(), cache.clone());
118 continue;
119 }
120 assert!(!self.protected_tree_names.contains(k));
121
122 self.initial_tree_names.retain(|x| x != k);
124 if !self.new_tree_names.contains(k) {
125 self.new_tree_names.push(k.clone());
126 }
127
128 let mut overlay = SledTreeOverlay::new(&db.open_tree(k)?);
129 overlay.add_diff(cache);
130 self.caches.insert(k.clone(), overlay);
131 }
132
133 Ok(())
134 }
135
136 pub fn remove_diff(&mut self, diff: &SledDbOverlayStateDiff) {
138 for (k, (cache, drop)) in diff.caches.iter() {
142 assert!(
144 self.initial_tree_names.contains(k)
145 || self.new_tree_names.contains(k)
146 || self.dropped_trees.contains_key(k)
147 );
148 if !self.initial_tree_names.contains(k) {
149 self.initial_tree_names.push(k.clone());
150 }
151 self.new_tree_names.retain(|x| x != k);
152
153 if *drop {
155 assert!(!self.protected_tree_names.contains(k));
156 self.initial_tree_names.retain(|x| x != k);
157 self.new_tree_names.retain(|x| x != k);
158 self.caches.remove(k);
159 self.dropped_trees.remove(k);
160 continue;
161 }
162
163 let Some(tree_overlay) = self.caches.get_mut(k) else {
166 let Some(tree_overlay) = self.dropped_trees.get_mut(k) else {
167 continue;
168 };
169 tree_overlay.update_values(cache);
170 continue;
171 };
172
173 if tree_overlay.state == cache.into() {
175 if self.protected_tree_names.contains(k) {
177 tree_overlay.state.cache = BTreeMap::new();
178 tree_overlay.state.removed = BTreeSet::new();
179 tree_overlay.checkpoint();
180 continue;
181 }
182
183 self.caches.remove(k);
185 continue;
186 }
187
188 tree_overlay.remove_diff(cache);
190 }
191
192 for (k, (cache, restored)) in diff.dropped_trees.iter() {
194 assert!(
196 self.initial_tree_names.contains(k)
197 || self.new_tree_names.contains(k)
198 || self.dropped_trees.contains_key(k)
199 );
200
201 if !restored {
203 assert!(!self.protected_tree_names.contains(k));
204 self.initial_tree_names.retain(|x| x != k);
205 self.new_tree_names.retain(|x| x != k);
206 self.caches.remove(k);
207 self.dropped_trees.remove(k);
208 continue;
209 }
210
211 self.initial_tree_names.retain(|x| x != k);
213 if !self.new_tree_names.contains(k) {
214 self.new_tree_names.push(k.clone());
215 }
216
217 let Some(tree_overlay) = self.caches.get_mut(k) else {
219 continue;
220 };
221
222 if tree_overlay.state == cache.into() {
224 if self.protected_tree_names.contains(k) {
226 tree_overlay.state.cache = BTreeMap::new();
227 tree_overlay.state.removed = BTreeSet::new();
228 tree_overlay.checkpoint();
229 continue;
230 }
231
232 self.caches.remove(k);
234 continue;
235 }
236
237 tree_overlay.remove_diff(cache);
239 }
240 }
241}
242
243impl Default for SledDbOverlayState {
244 fn default() -> Self {
245 Self::new(vec![], vec![])
246 }
247}
248
249#[derive(Debug, Default, Clone, PartialEq)]
251pub struct SledDbOverlayStateDiff {
252 pub initial_tree_names: Vec<IVec>,
254 pub caches: BTreeMap<IVec, (SledTreeOverlayStateDiff, bool)>,
260 pub dropped_trees: BTreeMap<IVec, (SledTreeOverlayStateDiff, bool)>,
265}
266
267impl SledDbOverlayStateDiff {
268 pub fn new(state: &SledDbOverlayState) -> Result<Self, sled::Error> {
271 let mut caches = BTreeMap::new();
272 let mut dropped_trees = BTreeMap::new();
273
274 for (key, cache) in state.caches.iter() {
275 let diff = cache.diff(&[])?;
276
277 if diff.cache.is_empty()
279 && diff.removed.is_empty()
280 && !state.new_tree_names.contains(key)
281 {
282 continue;
283 }
284
285 caches.insert(key.clone(), (diff, false));
286 }
287
288 for (key, cache) in state.dropped_trees.iter() {
289 dropped_trees.insert(key.clone(), (cache.clone(), false));
290 }
291
292 Ok(Self {
293 initial_tree_names: state.initial_tree_names.clone(),
294 caches,
295 dropped_trees,
296 })
297 }
298
299 fn aggregate(
304 &self,
305 state_trees: &BTreeMap<IVec, sled::Tree>,
306 ) -> Result<(Vec<sled::Tree>, Vec<sled::Batch>), sled::Error> {
307 let mut trees = vec![];
308 let mut batches = vec![];
309
310 for (key, (cache, drop)) in self.caches.iter() {
311 if *drop {
312 continue;
313 }
314
315 let Some(tree) = state_trees.get(key) else {
316 return Err(sled::Error::CollectionNotFound(key.into()));
317 };
318
319 if let Some(batch) = cache.aggregate() {
320 trees.push(tree.clone());
321 batches.push(batch);
322 }
323 }
324
325 for (key, (cache, restored)) in self.dropped_trees.iter() {
326 if !restored {
327 continue;
328 }
329
330 let Some(tree) = state_trees.get(key) else {
331 return Err(sled::Error::CollectionNotFound(key.into()));
332 };
333
334 if let Some(batch) = cache.aggregate() {
335 trees.push(tree.clone());
336 batches.push(batch);
337 }
338 }
339
340 Ok((trees, batches))
341 }
342
343 pub fn inverse(&self) -> Self {
346 let mut diff = Self {
347 initial_tree_names: self.initial_tree_names.clone(),
348 ..Default::default()
349 };
350
351 for (key, (cache, drop)) in self.caches.iter() {
352 let inverse = cache.inverse();
353 let drop = if inverse.cache.is_empty()
356 && inverse.removed.is_empty()
357 && !self.initial_tree_names.contains(key)
358 {
359 !drop
360 } else {
361 inverse.cache.is_empty() && !self.initial_tree_names.contains(key)
362 };
363 diff.caches.insert(key.clone(), (inverse, drop));
364 }
365
366 for (key, (cache, restored)) in self.dropped_trees.iter() {
367 if !self.initial_tree_names.contains(key) {
368 continue;
369 }
370 diff.dropped_trees
371 .insert(key.clone(), (cache.clone(), !restored));
372 }
373
374 diff
375 }
376
377 pub fn remove_diff(&mut self, other: &Self) {
379 for initial_tree_name in &other.initial_tree_names {
383 assert!(self.initial_tree_names.contains(initial_tree_name));
384 }
385
386 for (key, cache_pair) in other.caches.iter() {
388 if !self.initial_tree_names.contains(key) {
389 self.initial_tree_names.push(key.clone());
390 }
391
392 let Some(tree_overlay) = self.caches.get_mut(key) else {
395 let Some((tree_overlay, _)) = self.dropped_trees.get_mut(key) else {
396 continue;
397 };
398 tree_overlay.update_values(&cache_pair.0);
399 continue;
400 };
401
402 if tree_overlay == cache_pair {
404 self.caches.remove(key);
406 continue;
407 }
408
409 tree_overlay.0.remove_diff(&cache_pair.0);
411 }
412
413 for (key, (cache, restored)) in other.dropped_trees.iter() {
416 if let Some(tree_overlay) = self.caches.get_mut(key) {
418 assert!(!self.dropped_trees.contains_key(key));
419 tree_overlay.0.remove_diff(cache);
421 continue;
422 }
423 assert!(self.dropped_trees.contains_key(key));
424
425 if *restored {
427 self.caches.insert(key.clone(), (cache.clone(), false));
428 }
429
430 self.initial_tree_names.retain(|x| x != key);
432 self.dropped_trees.remove(key);
433 }
434 }
435
436 pub fn new_trees(&self) -> Vec<IVec> {
438 let mut new_trees: Vec<IVec> = self.caches.keys().cloned().collect();
439 new_trees.retain(|tree| !self.initial_tree_names.contains(tree));
440 new_trees
441 }
442}
443
444#[derive(Clone)]
446pub struct SledDbOverlay {
447 db: sled::Db,
449 pub state: SledDbOverlayState,
451 checkpoint: SledDbOverlayState,
453}
454
455impl SledDbOverlay {
456 pub fn new(db: &sled::Db, protected_tree_names: Vec<&[u8]>) -> Self {
460 let initial_tree_names = db.tree_names();
461 let protected_tree_names: Vec<IVec> = protected_tree_names
462 .into_iter()
463 .map(|tree_name| tree_name.into())
464 .collect();
465 Self {
466 db: db.clone(),
467 state: SledDbOverlayState::new(
468 initial_tree_names.clone(),
469 protected_tree_names.clone(),
470 ),
471 checkpoint: SledDbOverlayState::new(initial_tree_names, protected_tree_names),
472 }
473 }
474
475 pub fn open_tree(&mut self, tree_name: &[u8], protected: bool) -> Result<(), sled::Error> {
482 let tree_key: IVec = tree_name.into();
483
484 if self.state.caches.contains_key(&tree_key) {
486 return Ok(());
487 }
488
489 let tree = self.db.open_tree(&tree_key)?;
491 let mut cache = SledTreeOverlay::new(&tree);
492
493 if let Some(diff) = self.state.dropped_trees.remove(&tree_key) {
495 cache.state = (&diff).into();
496 }
497
498 if !self.state.initial_tree_names.contains(&tree_key) {
501 self.state.new_tree_names.push(tree_key.clone());
502 }
503
504 self.state.caches.insert(tree_key.clone(), cache);
505
506 if protected && !self.state.protected_tree_names.contains(&tree_key) {
508 self.state.protected_tree_names.push(tree_key);
509 }
510
511 Ok(())
512 }
513
514 pub fn drop_tree(&mut self, tree_name: &[u8]) -> Result<(), sled::Error> {
516 let tree_key: IVec = tree_name.into();
517
518 if self.state.protected_tree_names.contains(&tree_key) {
520 return Err(sled::Error::Unsupported(
521 "Protected tree can't be dropped".to_string(),
522 ));
523 }
524
525 if self.state.dropped_trees.contains_key(&tree_key) {
527 return Err(sled::Error::CollectionNotFound(tree_key));
528 }
529
530 if self.state.new_tree_names.contains(&tree_key) {
532 self.state.new_tree_names.retain(|x| *x != tree_key);
533 let tree = match self.get_cache(&tree_key) {
534 Ok(cache) => &cache.tree,
535 _ => &self.db.open_tree(&tree_key)?,
536 };
537 let diff = SledTreeOverlayStateDiff::new_dropped(tree);
538 self.state.caches.remove(&tree_key);
539 self.state.dropped_trees.insert(tree_key, diff);
540
541 return Ok(());
542 }
543
544 if !self.state.initial_tree_names.contains(&tree_key) {
546 return Err(sled::Error::CollectionNotFound(tree_key));
547 }
548
549 let tree = match self.get_cache(&tree_key) {
550 Ok(cache) => &cache.tree,
551 _ => &self.db.open_tree(&tree_key)?,
552 };
553 let diff = SledTreeOverlayStateDiff::new_dropped(tree);
554 self.state.caches.remove(&tree_key);
555 self.state.dropped_trees.insert(tree_key, diff);
556
557 Ok(())
558 }
559
560 pub fn purge_new_trees(&self) -> Result<(), sled::Error> {
564 for i in &self.state.new_tree_names {
565 self.db.drop_tree(i)?;
566 }
567
568 Ok(())
569 }
570
571 fn get_cache(&self, tree_key: &IVec) -> Result<&SledTreeOverlay, sled::Error> {
573 if self.state.dropped_trees.contains_key(tree_key) {
574 return Err(sled::Error::CollectionNotFound(tree_key.into()));
575 }
576
577 if let Some(v) = self.state.caches.get(tree_key) {
578 return Ok(v);
579 }
580
581 Err(sled::Error::CollectionNotFound(tree_key.into()))
582 }
583
584 fn get_cache_mut(&mut self, tree_key: &IVec) -> Result<&mut SledTreeOverlay, sled::Error> {
586 if self.state.dropped_trees.contains_key(tree_key) {
587 return Err(sled::Error::CollectionNotFound(tree_key.into()));
588 }
589
590 if let Some(v) = self.state.caches.get_mut(tree_key) {
591 return Ok(v);
592 }
593 Err(sled::Error::CollectionNotFound(tree_key.clone()))
594 }
595
596 pub fn get_state_trees(&self) -> BTreeMap<IVec, sled::Tree> {
598 let mut state_trees = BTreeMap::new();
600 for (key, cache) in self.state.caches.iter() {
601 state_trees.insert(key.clone(), cache.tree.clone());
602 }
603
604 state_trees
605 }
606
607 pub fn contains_key(&self, tree_key: &[u8], key: &[u8]) -> Result<bool, sled::Error> {
610 let cache = self.get_cache(&tree_key.into())?;
611 cache.contains_key(key)
612 }
613
614 pub fn get(&self, tree_key: &[u8], key: &[u8]) -> Result<Option<IVec>, sled::Error> {
616 let cache = self.get_cache(&tree_key.into())?;
617 cache.get(key)
618 }
619
620 pub fn is_empty(&self, tree_key: &[u8]) -> Result<bool, sled::Error> {
622 let cache = self.get_cache(&tree_key.into())?;
623 cache.is_empty()
624 }
625
626 pub fn last(&self, tree_key: &[u8]) -> Result<Option<(IVec, IVec)>, sled::Error> {
628 let cache = self.get_cache(&tree_key.into())?;
629 cache.last()
630 }
631
632 pub fn insert(
635 &mut self,
636 tree_key: &[u8],
637 key: &[u8],
638 value: &[u8],
639 ) -> Result<Option<IVec>, sled::Error> {
640 let cache = self.get_cache_mut(&tree_key.into())?;
641 cache.insert(key, value)
642 }
643
644 pub fn remove(&mut self, tree_key: &[u8], key: &[u8]) -> Result<Option<IVec>, sled::Error> {
646 let cache = self.get_cache_mut(&tree_key.into())?;
647 cache.remove(key)
648 }
649
650 pub fn clear(&mut self, tree_key: &[u8]) -> Result<(), sled::Error> {
653 let cache = self.get_cache_mut(&tree_key.into())?;
654 cache.clear()
655 }
656
657 fn aggregate(&self) -> Result<(Vec<sled::Tree>, Vec<sled::Batch>), sled::Error> {
661 self.state.aggregate()
662 }
663
664 pub fn apply(&mut self) -> Result<(), TransactionError<sled::Error>> {
671 let new_tree_names = self.state.new_tree_names.clone();
673 for tree_key in &new_tree_names {
674 let tree = self.db.open_tree(tree_key)?;
675 let cache = self.get_cache_mut(tree_key)?;
677 cache.tree = tree;
678 }
679
680 for tree in self.state.dropped_trees.keys() {
682 self.db.drop_tree(tree)?;
683 }
684
685 let (trees, batches) = self.aggregate()?;
687 if trees.is_empty() {
688 return Ok(());
689 }
690
691 trees.transaction(|trees| {
694 for (index, tree) in trees.iter().enumerate() {
695 tree.apply_batch(&batches[index])?;
696 }
697
698 Ok::<(), ConflictableTransactionError<sled::Error>>(())
699 })?;
700
701 Ok(())
702 }
703
704 pub fn checkpoint(&mut self) {
706 self.checkpoint = self.state.clone();
707 }
708
709 pub fn revert_to_checkpoint(&mut self) {
712 self.state = self.checkpoint.clone();
713 }
714
715 pub fn diff(
721 &self,
722 sequence: &[SledDbOverlayStateDiff],
723 ) -> Result<SledDbOverlayStateDiff, sled::Error> {
724 let mut current = SledDbOverlayStateDiff::new(&self.state)?;
726
727 for diff in sequence {
729 current.remove_diff(diff);
730 }
731
732 Ok(current)
733 }
734
735 pub fn add_diff(&mut self, diff: &SledDbOverlayStateDiff) -> Result<(), sled::Error> {
737 self.state.add_diff(&self.db, diff)
738 }
739
740 pub fn remove_diff(&mut self, diff: &SledDbOverlayStateDiff) {
742 self.state.remove_diff(diff)
743 }
744
745 pub fn apply_diff(
753 &mut self,
754 diff: &SledDbOverlayStateDiff,
755 ) -> Result<(), TransactionError<sled::Error>> {
756 for tree in diff.dropped_trees.keys() {
758 if self.state.protected_tree_names.contains(tree) {
759 return Err(TransactionError::Storage(sled::Error::Unsupported(
760 "Protected tree can't be dropped".to_string(),
761 )));
762 }
763 }
764 for (tree_key, (_, drop)) in diff.caches.iter() {
765 if *drop && self.state.protected_tree_names.contains(tree_key) {
766 return Err(TransactionError::Storage(sled::Error::Unsupported(
767 "Protected tree can't be dropped".to_string(),
768 )));
769 }
770 }
771
772 let mut state_trees = self.get_state_trees();
774
775 for (tree_key, (_, drop)) in diff.caches.iter() {
777 if !self.state.initial_tree_names.contains(tree_key)
779 && !self.state.new_tree_names.contains(tree_key)
780 {
781 self.state.new_tree_names.push(tree_key.clone());
782 }
783
784 if *drop {
786 self.db.drop_tree(tree_key)?;
787 continue;
788 }
789
790 if !state_trees.contains_key(tree_key) {
791 let tree = self.db.open_tree(tree_key)?;
792 state_trees.insert(tree_key.clone(), tree);
793 }
794 }
795
796 for (tree_key, (_, restored)) in diff.dropped_trees.iter() {
798 if !restored {
799 state_trees.remove(tree_key);
800 self.db.drop_tree(tree_key)?;
801 continue;
802 }
803
804 if !self.state.initial_tree_names.contains(tree_key)
806 && !self.state.new_tree_names.contains(tree_key)
807 {
808 self.state.new_tree_names.push(tree_key.clone());
809 }
810
811 if !state_trees.contains_key(tree_key) {
812 let tree = self.db.open_tree(tree_key)?;
813 state_trees.insert(tree_key.clone(), tree);
814 }
815 }
816
817 let (trees, batches) = diff.aggregate(&state_trees)?;
819 if trees.is_empty() {
820 self.remove_diff(diff);
821 return Ok(());
822 }
823
824 trees.transaction(|trees| {
827 for (index, tree) in trees.iter().enumerate() {
828 tree.apply_batch(&batches[index])?;
829 }
830
831 Ok::<(), ConflictableTransactionError<sled::Error>>(())
832 })?;
833
834 self.remove_diff(diff);
836
837 Ok(())
838 }
839
840 pub fn iter(&self, tree_key: &[u8]) -> Result<SledTreeOverlayIter<'_>, sled::Error> {
842 let cache = self.get_cache(&tree_key.into())?;
843 Ok(cache.iter())
844 }
845}