1use std::collections::{BTreeMap, BTreeSet};
20
21use sled::{
22 transaction::{ConflictableTransactionError, TransactionError},
23 IVec, Transactional,
24};
25
26use crate::{SledTreeOverlay, SledTreeOverlayIter, SledTreeOverlayStateDiff};
27
/// Mutable state of a [`SledDbOverlay`]: which trees exist, which are new,
/// which were dropped, and the per-tree overlays holding uncommitted changes.
#[derive(Debug, Clone)]
pub struct SledDbOverlayState {
    /// Names of trees treated as already present in the database
    /// (`SledDbOverlay::new` seeds this from `db.tree_names()`).
    pub initial_tree_names: Vec<IVec>,
    /// Names of trees opened through the overlay that were not in
    /// `initial_tree_names`; `purge_new_trees` drops these from the db.
    pub new_tree_names: Vec<IVec>,
    /// Per-tree overlays with pending changes, keyed by tree name.
    pub caches: BTreeMap<IVec, SledTreeOverlay>,
    /// Trees dropped in this state, with the diff describing the drop.
    pub dropped_trees: BTreeMap<IVec, SledTreeOverlayStateDiff>,
    /// Trees that `drop_tree` / `apply_diff` refuse to drop.
    pub protected_tree_names: Vec<IVec>,
}
43
44impl SledDbOverlayState {
45 pub fn new(initial_tree_names: Vec<IVec>, protected_tree_names: Vec<IVec>) -> Self {
47 Self {
48 initial_tree_names,
49 new_tree_names: vec![],
50 caches: BTreeMap::new(),
51 dropped_trees: BTreeMap::new(),
52 protected_tree_names,
53 }
54 }
55
56 fn aggregate(&self) -> Result<(Vec<sled::Tree>, Vec<sled::Batch>), sled::Error> {
60 let mut trees = vec![];
61 let mut batches = vec![];
62
63 for (key, cache) in self.caches.iter() {
64 if self.dropped_trees.contains_key(key) {
65 return Err(sled::Error::CollectionNotFound(key.into()));
66 }
67
68 if let Some(batch) = cache.aggregate() {
69 trees.push(cache.tree.clone());
70 batches.push(batch);
71 }
72 }
73
74 Ok((trees, batches))
75 }
76
77 pub fn add_diff(
79 &mut self,
80 db: &sled::Db,
81 diff: &SledDbOverlayStateDiff,
82 ) -> Result<(), sled::Error> {
83 self.initial_tree_names
84 .retain(|x| diff.initial_tree_names.contains(x));
85
86 for (k, (cache, drop)) in diff.caches.iter() {
87 if *drop {
88 assert!(!self.protected_tree_names.contains(k));
89 self.new_tree_names.retain(|x| x != k);
90 self.caches.remove(k);
91 self.dropped_trees.insert(k.clone(), cache.clone());
92 continue;
93 }
94
95 let Some(tree_overlay) = self.caches.get_mut(k) else {
96 if !self.initial_tree_names.contains(k) && !self.new_tree_names.contains(k) {
97 self.new_tree_names.push(k.clone());
98 }
99 let mut overlay = SledTreeOverlay::new(&db.open_tree(k)?);
100 overlay.add_diff(cache);
101 self.caches.insert(k.clone(), overlay);
102 continue;
103 };
104
105 tree_overlay.add_diff(cache);
107 }
108
109 for (k, (cache, restored)) in &diff.dropped_trees {
110 if !restored {
112 if self.dropped_trees.contains_key(k) {
113 continue;
114 }
115 self.new_tree_names.retain(|x| x != k);
116 self.caches.remove(k);
117 self.dropped_trees.insert(k.clone(), cache.clone());
118 continue;
119 }
120 assert!(!self.protected_tree_names.contains(k));
121
122 self.initial_tree_names.retain(|x| x != k);
124 if !self.new_tree_names.contains(k) {
125 self.new_tree_names.push(k.clone());
126 }
127
128 let mut overlay = SledTreeOverlay::new(&db.open_tree(k)?);
129 overlay.add_diff(cache);
130 self.caches.insert(k.clone(), overlay);
131 }
132
133 Ok(())
134 }
135
136 pub fn remove_diff(&mut self, diff: &SledDbOverlayStateDiff) {
138 for (k, (cache, drop)) in diff.caches.iter() {
142 assert!(
144 self.initial_tree_names.contains(k)
145 || self.new_tree_names.contains(k)
146 || self.dropped_trees.contains_key(k)
147 );
148 if !self.initial_tree_names.contains(k) {
149 self.initial_tree_names.push(k.clone());
150 }
151 self.new_tree_names.retain(|x| x != k);
152
153 if *drop {
155 assert!(!self.protected_tree_names.contains(k));
156 self.initial_tree_names.retain(|x| x != k);
157 self.new_tree_names.retain(|x| x != k);
158 self.caches.remove(k);
159 self.dropped_trees.remove(k);
160 continue;
161 }
162
163 let Some(tree_overlay) = self.caches.get_mut(k) else {
166 let Some(tree_overlay) = self.dropped_trees.get_mut(k) else {
167 continue;
168 };
169 tree_overlay.update_values(cache);
170 continue;
171 };
172
173 if tree_overlay.state == cache.into() {
175 if self.protected_tree_names.contains(k) {
177 tree_overlay.state.cache = BTreeMap::new();
178 tree_overlay.state.removed = BTreeSet::new();
179 tree_overlay.checkpoint();
180 continue;
181 }
182
183 self.caches.remove(k);
185 continue;
186 }
187
188 tree_overlay.remove_diff(cache);
190 }
191
192 for (k, (cache, restored)) in diff.dropped_trees.iter() {
194 assert!(
196 self.initial_tree_names.contains(k)
197 || self.new_tree_names.contains(k)
198 || self.dropped_trees.contains_key(k)
199 );
200
201 if !restored {
203 assert!(!self.protected_tree_names.contains(k));
204 self.initial_tree_names.retain(|x| x != k);
205 self.new_tree_names.retain(|x| x != k);
206 self.caches.remove(k);
207 self.dropped_trees.remove(k);
208 continue;
209 }
210
211 self.initial_tree_names.retain(|x| x != k);
213 if !self.new_tree_names.contains(k) {
214 self.new_tree_names.push(k.clone());
215 }
216
217 let Some(tree_overlay) = self.caches.get_mut(k) else {
219 continue;
220 };
221
222 if tree_overlay.state == cache.into() {
224 if self.protected_tree_names.contains(k) {
226 tree_overlay.state.cache = BTreeMap::new();
227 tree_overlay.state.removed = BTreeSet::new();
228 tree_overlay.checkpoint();
229 continue;
230 }
231
232 self.caches.remove(k);
234 continue;
235 }
236
237 tree_overlay.remove_diff(cache);
239 }
240 }
241}
242
243impl Default for SledDbOverlayState {
244 fn default() -> Self {
245 Self::new(vec![], vec![])
246 }
247}
248
/// Snapshot of the changes a [`SledDbOverlayState`] holds, which can be
/// applied to the database, inverted, or subtracted from another diff.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct SledDbOverlayStateDiff {
    /// Initial tree names of the state this diff was taken from.
    pub initial_tree_names: Vec<IVec>,
    /// Per-tree changes plus a `drop` flag; when the flag is set,
    /// `SledDbOverlay::apply_diff` drops the tree instead of writing to it.
    pub caches: BTreeMap<IVec, (SledTreeOverlayStateDiff, bool)>,
    /// Dropped trees plus a `restored` flag. When the flag is false,
    /// `apply_diff` drops the tree. When it is true, the tree is re-opened
    /// and its changes are written.
    pub dropped_trees: BTreeMap<IVec, (SledTreeOverlayStateDiff, bool)>,
}
266
267impl SledDbOverlayStateDiff {
268 pub fn new(state: &SledDbOverlayState) -> Result<Self, sled::Error> {
272 let mut caches = BTreeMap::new();
273 let mut dropped_trees = BTreeMap::new();
274
275 for (key, cache) in state.caches.iter() {
276 let mut diff = cache.diff(&[])?;
277
278 if diff.cache.is_empty()
280 && diff.removed.is_empty()
281 && !state.new_tree_names.contains(key)
282 {
283 continue;
284 }
285
286 if state.new_tree_names.contains(key) {
288 diff.removed = BTreeMap::new();
289 }
290
291 caches.insert(key.clone(), (diff, false));
292 }
293
294 for (key, cache) in state.dropped_trees.iter() {
295 dropped_trees.insert(key.clone(), (cache.clone(), false));
296 }
297
298 Ok(Self {
299 initial_tree_names: state.initial_tree_names.clone(),
300 caches,
301 dropped_trees,
302 })
303 }
304
305 fn aggregate(
310 &self,
311 state_trees: &BTreeMap<IVec, sled::Tree>,
312 ) -> Result<(Vec<sled::Tree>, Vec<sled::Batch>), sled::Error> {
313 let mut trees = vec![];
314 let mut batches = vec![];
315
316 for (key, (cache, drop)) in self.caches.iter() {
317 if *drop {
318 continue;
319 }
320
321 let Some(tree) = state_trees.get(key) else {
322 return Err(sled::Error::CollectionNotFound(key.into()));
323 };
324
325 if let Some(batch) = cache.aggregate() {
326 trees.push(tree.clone());
327 batches.push(batch);
328 }
329 }
330
331 for (key, (cache, restored)) in self.dropped_trees.iter() {
332 if !restored {
333 continue;
334 }
335
336 let Some(tree) = state_trees.get(key) else {
337 return Err(sled::Error::CollectionNotFound(key.into()));
338 };
339
340 if let Some(batch) = cache.aggregate() {
341 trees.push(tree.clone());
342 batches.push(batch);
343 }
344 }
345
346 Ok((trees, batches))
347 }
348
349 pub fn inverse(&self) -> Self {
352 let mut diff = Self {
353 initial_tree_names: self.initial_tree_names.clone(),
354 ..Default::default()
355 };
356
357 for (key, (cache, drop)) in self.caches.iter() {
358 let inverse = cache.inverse();
359 let drop = if inverse.cache.is_empty()
362 && inverse.removed.is_empty()
363 && !self.initial_tree_names.contains(key)
364 {
365 !drop
366 } else {
367 inverse.cache.is_empty() && !self.initial_tree_names.contains(key)
368 };
369 diff.caches.insert(key.clone(), (inverse, drop));
370 }
371
372 for (key, (cache, restored)) in self.dropped_trees.iter() {
373 if !self.initial_tree_names.contains(key) {
374 continue;
375 }
376 diff.dropped_trees
377 .insert(key.clone(), (cache.clone(), !restored));
378 }
379
380 diff
381 }
382
383 pub fn remove_diff(&mut self, other: &Self) {
385 for initial_tree_name in &other.initial_tree_names {
389 assert!(self.initial_tree_names.contains(initial_tree_name));
390 }
391
392 for (key, cache_pair) in other.caches.iter() {
394 if !self.initial_tree_names.contains(key) {
395 self.initial_tree_names.push(key.clone());
396 }
397
398 let Some(tree_overlay) = self.caches.get_mut(key) else {
401 let Some((tree_overlay, _)) = self.dropped_trees.get_mut(key) else {
402 continue;
403 };
404 tree_overlay.update_values(&cache_pair.0);
405 continue;
406 };
407
408 if tree_overlay == cache_pair {
410 self.caches.remove(key);
412 continue;
413 }
414
415 tree_overlay.0.remove_diff(&cache_pair.0);
417 }
418
419 for (key, (cache, restored)) in other.dropped_trees.iter() {
422 if let Some(tree_overlay) = self.caches.get_mut(key) {
424 assert!(!self.dropped_trees.contains_key(key));
425 tree_overlay.0.remove_diff(cache);
427 continue;
428 }
429 assert!(self.dropped_trees.contains_key(key));
430
431 if *restored {
433 self.caches.insert(key.clone(), (cache.clone(), false));
434 }
435
436 self.initial_tree_names.retain(|x| x != key);
438 self.dropped_trees.remove(key);
439 }
440 }
441}
442
/// Overlay over a whole `sled::Db` that buffers tree creations, drops and
/// writes in memory until they are applied.
#[derive(Clone)]
pub struct SledDbOverlay {
    /// Handle to the underlying database.
    db: sled::Db,
    /// Current (uncommitted) overlay state.
    pub state: SledDbOverlayState,
    /// Copy of `state` saved by `checkpoint` and restored by
    /// `revert_to_checkpoint`.
    checkpoint: SledDbOverlayState,
}
453
454impl SledDbOverlay {
455 pub fn new(db: &sled::Db, protected_tree_names: Vec<&[u8]>) -> Self {
459 let initial_tree_names = db.tree_names();
460 let protected_tree_names: Vec<IVec> = protected_tree_names
461 .into_iter()
462 .map(|tree_name| tree_name.into())
463 .collect();
464 Self {
465 db: db.clone(),
466 state: SledDbOverlayState::new(
467 initial_tree_names.clone(),
468 protected_tree_names.clone(),
469 ),
470 checkpoint: SledDbOverlayState::new(initial_tree_names, protected_tree_names),
471 }
472 }
473
474 pub fn open_tree(&mut self, tree_name: &[u8], protected: bool) -> Result<(), sled::Error> {
481 let tree_key: IVec = tree_name.into();
482
483 if self.state.caches.contains_key(&tree_key) {
485 return Ok(());
486 }
487
488 let tree = self.db.open_tree(&tree_key)?;
490 let mut cache = SledTreeOverlay::new(&tree);
491
492 if let Some(diff) = self.state.dropped_trees.remove(&tree_key) {
494 cache.state = (&diff).into();
495 }
496
497 if !self.state.initial_tree_names.contains(&tree_key) {
500 self.state.new_tree_names.push(tree_key.clone());
501 }
502
503 self.state.caches.insert(tree_key.clone(), cache);
504
505 if protected && !self.state.protected_tree_names.contains(&tree_key) {
507 self.state.protected_tree_names.push(tree_key);
508 }
509
510 Ok(())
511 }
512
513 pub fn drop_tree(&mut self, tree_name: &[u8]) -> Result<(), sled::Error> {
515 let tree_key: IVec = tree_name.into();
516
517 if self.state.protected_tree_names.contains(&tree_key) {
519 return Err(sled::Error::Unsupported(
520 "Protected tree can't be dropped".to_string(),
521 ));
522 }
523
524 if self.state.dropped_trees.contains_key(&tree_key) {
526 return Err(sled::Error::CollectionNotFound(tree_key));
527 }
528
529 if self.state.new_tree_names.contains(&tree_key) {
531 self.state.new_tree_names.retain(|x| *x != tree_key);
532 let tree = match self.get_cache(&tree_key) {
533 Ok(cache) => &cache.tree,
534 _ => &self.db.open_tree(&tree_key)?,
535 };
536 let diff = SledTreeOverlayStateDiff::new_dropped(tree);
537 self.state.caches.remove(&tree_key);
538 self.state.dropped_trees.insert(tree_key, diff);
539
540 return Ok(());
541 }
542
543 if !self.state.initial_tree_names.contains(&tree_key) {
545 return Err(sled::Error::CollectionNotFound(tree_key));
546 }
547
548 let tree = match self.get_cache(&tree_key) {
549 Ok(cache) => &cache.tree,
550 _ => &self.db.open_tree(&tree_key)?,
551 };
552 let diff = SledTreeOverlayStateDiff::new_dropped(tree);
553 self.state.caches.remove(&tree_key);
554 self.state.dropped_trees.insert(tree_key, diff);
555
556 Ok(())
557 }
558
559 pub fn purge_new_trees(&self) -> Result<(), sled::Error> {
563 for i in &self.state.new_tree_names {
564 self.db.drop_tree(i)?;
565 }
566
567 Ok(())
568 }
569
570 fn get_cache(&self, tree_key: &IVec) -> Result<&SledTreeOverlay, sled::Error> {
572 if self.state.dropped_trees.contains_key(tree_key) {
573 return Err(sled::Error::CollectionNotFound(tree_key.into()));
574 }
575
576 if let Some(v) = self.state.caches.get(tree_key) {
577 return Ok(v);
578 }
579
580 Err(sled::Error::CollectionNotFound(tree_key.into()))
581 }
582
583 fn get_cache_mut(&mut self, tree_key: &IVec) -> Result<&mut SledTreeOverlay, sled::Error> {
585 if self.state.dropped_trees.contains_key(tree_key) {
586 return Err(sled::Error::CollectionNotFound(tree_key.into()));
587 }
588
589 if let Some(v) = self.state.caches.get_mut(tree_key) {
590 return Ok(v);
591 }
592 Err(sled::Error::CollectionNotFound(tree_key.clone()))
593 }
594
595 pub fn get_state_trees(&self) -> BTreeMap<IVec, sled::Tree> {
597 let mut state_trees = BTreeMap::new();
599 for (key, cache) in self.state.caches.iter() {
600 state_trees.insert(key.clone(), cache.tree.clone());
601 }
602
603 state_trees
604 }
605
606 pub fn contains_key(&self, tree_key: &[u8], key: &[u8]) -> Result<bool, sled::Error> {
609 let cache = self.get_cache(&tree_key.into())?;
610 cache.contains_key(key)
611 }
612
613 pub fn get(&self, tree_key: &[u8], key: &[u8]) -> Result<Option<IVec>, sled::Error> {
615 let cache = self.get_cache(&tree_key.into())?;
616 cache.get(key)
617 }
618
619 pub fn is_empty(&self, tree_key: &[u8]) -> Result<bool, sled::Error> {
621 let cache = self.get_cache(&tree_key.into())?;
622 Ok(cache.is_empty())
623 }
624
625 pub fn last(&self, tree_key: &[u8]) -> Result<Option<(IVec, IVec)>, sled::Error> {
627 let cache = self.get_cache(&tree_key.into())?;
628 cache.last()
629 }
630
631 pub fn insert(
634 &mut self,
635 tree_key: &[u8],
636 key: &[u8],
637 value: &[u8],
638 ) -> Result<Option<IVec>, sled::Error> {
639 let cache = self.get_cache_mut(&tree_key.into())?;
640 cache.insert(key, value)
641 }
642
643 pub fn remove(&mut self, tree_key: &[u8], key: &[u8]) -> Result<Option<IVec>, sled::Error> {
645 let cache = self.get_cache_mut(&tree_key.into())?;
646 cache.remove(key)
647 }
648
649 pub fn clear(&mut self, tree_key: &[u8]) -> Result<(), sled::Error> {
652 let cache = self.get_cache_mut(&tree_key.into())?;
653 cache.clear()
654 }
655
656 fn aggregate(&self) -> Result<(Vec<sled::Tree>, Vec<sled::Batch>), sled::Error> {
660 self.state.aggregate()
661 }
662
663 pub fn apply(&mut self) -> Result<(), TransactionError<sled::Error>> {
670 let new_tree_names = self.state.new_tree_names.clone();
672 for tree_key in &new_tree_names {
673 let tree = self.db.open_tree(tree_key)?;
674 let cache = self.get_cache_mut(tree_key)?;
676 cache.tree = tree;
677 }
678
679 for tree in self.state.dropped_trees.keys() {
681 self.db.drop_tree(tree)?;
682 }
683
684 let (trees, batches) = self.aggregate()?;
686 if trees.is_empty() {
687 return Ok(());
688 }
689
690 trees.transaction(|trees| {
693 for (index, tree) in trees.iter().enumerate() {
694 tree.apply_batch(&batches[index])?;
695 }
696
697 Ok::<(), ConflictableTransactionError<sled::Error>>(())
698 })?;
699
700 Ok(())
701 }
702
703 pub fn checkpoint(&mut self) {
705 self.checkpoint = self.state.clone();
706 }
707
708 pub fn revert_to_checkpoint(&mut self) -> Result<(), sled::Error> {
710 let new_trees: Vec<_> = self
712 .state
713 .new_tree_names
714 .iter()
715 .filter(|tree| !self.checkpoint.new_tree_names.contains(tree))
716 .collect();
717 for tree in &new_trees {
718 self.db.drop_tree(tree)?;
719 }
720
721 self.state = self.checkpoint.clone();
722
723 Ok(())
724 }
725
726 pub fn diff(
732 &self,
733 sequence: &[SledDbOverlayStateDiff],
734 ) -> Result<SledDbOverlayStateDiff, sled::Error> {
735 let mut current = SledDbOverlayStateDiff::new(&self.state)?;
737
738 for diff in sequence {
740 current.remove_diff(diff);
741 }
742
743 Ok(current)
744 }
745
746 pub fn add_diff(&mut self, diff: &SledDbOverlayStateDiff) -> Result<(), sled::Error> {
748 self.state.add_diff(&self.db, diff)
749 }
750
751 pub fn remove_diff(&mut self, diff: &SledDbOverlayStateDiff) {
753 self.state.remove_diff(diff)
754 }
755
756 pub fn apply_diff(
764 &mut self,
765 diff: &SledDbOverlayStateDiff,
766 ) -> Result<(), TransactionError<sled::Error>> {
767 for tree in diff.dropped_trees.keys() {
769 if self.state.protected_tree_names.contains(tree) {
770 return Err(TransactionError::Storage(sled::Error::Unsupported(
771 "Protected tree can't be dropped".to_string(),
772 )));
773 }
774 }
775 for (tree_key, (_, drop)) in diff.caches.iter() {
776 if *drop && self.state.protected_tree_names.contains(tree_key) {
777 return Err(TransactionError::Storage(sled::Error::Unsupported(
778 "Protected tree can't be dropped".to_string(),
779 )));
780 }
781 }
782
783 let mut state_trees = self.get_state_trees();
785
786 for (tree_key, (_, drop)) in diff.caches.iter() {
788 if !self.state.initial_tree_names.contains(tree_key)
790 && !self.state.new_tree_names.contains(tree_key)
791 {
792 self.state.new_tree_names.push(tree_key.clone());
793 }
794
795 if *drop {
797 self.db.drop_tree(tree_key)?;
798 continue;
799 }
800
801 if !state_trees.contains_key(tree_key) {
802 let tree = self.db.open_tree(tree_key)?;
803 state_trees.insert(tree_key.clone(), tree);
804 }
805 }
806
807 for (tree_key, (_, restored)) in diff.dropped_trees.iter() {
809 if !restored {
810 state_trees.remove(tree_key);
811 self.db.drop_tree(tree_key)?;
812 continue;
813 }
814
815 if !self.state.initial_tree_names.contains(tree_key)
817 && !self.state.new_tree_names.contains(tree_key)
818 {
819 self.state.new_tree_names.push(tree_key.clone());
820 }
821
822 if !state_trees.contains_key(tree_key) {
823 let tree = self.db.open_tree(tree_key)?;
824 state_trees.insert(tree_key.clone(), tree);
825 }
826 }
827
828 let (trees, batches) = diff.aggregate(&state_trees)?;
830 if trees.is_empty() {
831 self.remove_diff(diff);
832 return Ok(());
833 }
834
835 trees.transaction(|trees| {
838 for (index, tree) in trees.iter().enumerate() {
839 tree.apply_batch(&batches[index])?;
840 }
841
842 Ok::<(), ConflictableTransactionError<sled::Error>>(())
843 })?;
844
845 self.remove_diff(diff);
847
848 Ok(())
849 }
850
851 pub fn iter(&self, tree_key: &[u8]) -> Result<SledTreeOverlayIter<'_>, sled::Error> {
853 let cache = self.get_cache(&tree_key.into())?;
854 Ok(cache.iter())
855 }
856}