Skip to main content

cranpose_core/
snapshot_id_set.rs

//! An optimized bit-set implementation for tracking snapshot IDs.
//!
//! This is based on Jetpack Compose's `SnapshotIdSet` and is optimized for:
//! - O(1) access to the 128 IDs in the window starting at `lower_bound`, where the most
//!   recent snapshot IDs live
//! - O(log N) access to older snapshot IDs (binary search of a sorted array)
//! - Immutable value semantics: every modifying operation returns a new set and leaves the
//!   original untouched
//!
//! The set maintains:
//! - `lower_set`: 64 bits for IDs in the range [lower_bound, lower_bound + 63]
//! - `upper_set`: 64 bits for IDs in the range [lower_bound + 64, lower_bound + 127]
//! - `below_bound`: a sorted array of IDs below `lower_bound`
//!
//! The structure is heavily biased toward recent snapshot IDs being set, with older
//! snapshot IDs mostly or entirely clear.

15use std::fmt;
16
/// Identifier of a snapshot.
pub type SnapshotId = usize;

/// Width, in bits, of each of the two window words (`lower_set` and `upper_set`).
const BITS_PER_SET: usize = u64::BITS as usize;
/// Alignment of `lower_bound` after the window slides upward: the new bound is the target ID
/// rounded down to a multiple of this value.
const SNAPSHOT_ID_SIZE: usize = BITS_PER_SET;
21
22#[derive(Clone, PartialEq, Eq)]
23pub struct SnapshotIdSet {
24    /// Bit set from (lower_bound + 64) to (lower_bound + 127)
25    upper_set: u64,
26    /// Bit set from lower_bound to (lower_bound + 63)
27    lower_set: u64,
28    /// Lower bound of the bit set. All values above lower_bound+127 are clear.
29    lower_bound: SnapshotId,
30    /// Sorted array of snapshot IDs below lower_bound
31    below_bound: Option<Box<[SnapshotId]>>,
32}
33
34impl SnapshotIdSet {
35    /// Empty snapshot ID set.
36    pub const EMPTY: SnapshotIdSet = SnapshotIdSet {
37        upper_set: 0,
38        lower_set: 0,
39        lower_bound: 0,
40        below_bound: None,
41    };
42
43    /// Create a new empty snapshot ID set.
44    pub fn new() -> Self {
45        Self::EMPTY
46    }
47
48    /// Check if a snapshot ID is in the set.
49    pub fn get(&self, id: SnapshotId) -> bool {
50        let offset = id.wrapping_sub(self.lower_bound);
51
52        if offset < BITS_PER_SET {
53            // In lower_set range
54            let mask = 1u64 << offset;
55            (self.lower_set & mask) != 0
56        } else if offset < BITS_PER_SET * 2 {
57            // In upper_set range
58            let mask = 1u64 << (offset - BITS_PER_SET);
59            (self.upper_set & mask) != 0
60        } else if id > self.lower_bound {
61            // Above our tracked range
62            false
63        } else {
64            // Below lower_bound, check the array
65            self.below_bound
66                .as_ref()
67                .map(|arr| arr.binary_search(&id).is_ok())
68                .unwrap_or(false)
69        }
70    }
71
72    /// Add a snapshot ID to the set (returns a new set if modified).
73    pub fn set(&self, id: SnapshotId) -> Self {
74        if id < self.lower_bound {
75            if let Some(ref arr) = self.below_bound {
76                match arr.binary_search(&id) {
77                    Ok(_) => {
78                        // Already present
79                        return self.clone();
80                    }
81                    Err(insert_pos) => {
82                        // Insert at position
83                        let mut new_arr = Vec::with_capacity(arr.len() + 1);
84                        new_arr.extend_from_slice(&arr[..insert_pos]);
85                        new_arr.push(id);
86                        new_arr.extend_from_slice(&arr[insert_pos..]);
87                        return Self {
88                            upper_set: self.upper_set,
89                            lower_set: self.lower_set,
90                            lower_bound: self.lower_bound,
91                            below_bound: Some(new_arr.into_boxed_slice()),
92                        };
93                    }
94                }
95            } else {
96                // First element below bound
97                return Self {
98                    upper_set: self.upper_set,
99                    lower_set: self.lower_set,
100                    lower_bound: self.lower_bound,
101                    below_bound: Some(vec![id].into_boxed_slice()),
102                };
103            }
104        }
105
106        let offset = id - self.lower_bound;
107
108        if offset < BITS_PER_SET {
109            // In lower_set range
110            let mask = 1u64 << offset;
111            if (self.lower_set & mask) == 0 {
112                return Self {
113                    upper_set: self.upper_set,
114                    lower_set: self.lower_set | mask,
115                    lower_bound: self.lower_bound,
116                    below_bound: self.below_bound.clone(),
117                };
118            }
119        } else if offset < BITS_PER_SET * 2 {
120            // In upper_set range
121            let mask = 1u64 << (offset - BITS_PER_SET);
122            if (self.upper_set & mask) == 0 {
123                return Self {
124                    upper_set: self.upper_set | mask,
125                    lower_set: self.lower_set,
126                    lower_bound: self.lower_bound,
127                    below_bound: self.below_bound.clone(),
128                };
129            }
130        } else if offset >= BITS_PER_SET * 2 {
131            // Need to shift the bit arrays
132            if !self.get(id) {
133                return self.shift_and_set(id);
134            }
135        }
136
137        // No change needed
138        self.clone()
139    }
140
141    /// Remove a snapshot ID from the set (returns a new set if modified).
142    pub fn clear(&self, id: SnapshotId) -> Self {
143        let offset = id.wrapping_sub(self.lower_bound);
144
145        if offset < BITS_PER_SET {
146            // In lower_set range
147            let mask = 1u64 << offset;
148            if (self.lower_set & mask) != 0 {
149                return Self {
150                    upper_set: self.upper_set,
151                    lower_set: self.lower_set & !mask,
152                    lower_bound: self.lower_bound,
153                    below_bound: self.below_bound.clone(),
154                };
155            }
156        } else if offset < BITS_PER_SET * 2 {
157            // In upper_set range
158            let mask = 1u64 << (offset - BITS_PER_SET);
159            if (self.upper_set & mask) != 0 {
160                return Self {
161                    upper_set: self.upper_set & !mask,
162                    lower_set: self.lower_set,
163                    lower_bound: self.lower_bound,
164                    below_bound: self.below_bound.clone(),
165                };
166            }
167        } else if id < self.lower_bound {
168            // Below lower_bound
169            if let Some(ref arr) = self.below_bound {
170                if let Ok(pos) = arr.binary_search(&id) {
171                    let mut new_arr = Vec::with_capacity(arr.len() - 1);
172                    new_arr.extend_from_slice(&arr[..pos]);
173                    new_arr.extend_from_slice(&arr[pos + 1..]);
174                    return Self {
175                        upper_set: self.upper_set,
176                        lower_set: self.lower_set,
177                        lower_bound: self.lower_bound,
178                        below_bound: if new_arr.is_empty() {
179                            None
180                        } else {
181                            Some(new_arr.into_boxed_slice())
182                        },
183                    };
184                }
185            }
186        }
187
188        // No change needed
189        self.clone()
190    }
191
192    /// Remove all IDs in `other` from this set (a & ~b).
193    pub fn and_not(&self, other: &Self) -> Self {
194        if other.is_empty() {
195            return self.clone();
196        }
197        if self.is_empty() {
198            return Self::EMPTY;
199        }
200
201        // Fast path: if both have same lower_bound and below_bound, can do bitwise ops
202        if self.lower_bound == other.lower_bound && self.below_bound_equals(&other.below_bound) {
203            return Self {
204                upper_set: self.upper_set & !other.upper_set,
205                lower_set: self.lower_set & !other.lower_set,
206                lower_bound: self.lower_bound,
207                below_bound: self.below_bound.clone(),
208            };
209        }
210
211        // Slow path: iterate and clear each ID
212        let mut result = self.clone();
213        for id in other.iter() {
214            result = result.clear(id);
215        }
216        result
217    }
218
219    /// Union this set with another (a | b).
220    pub fn or(&self, other: &Self) -> Self {
221        if other.is_empty() {
222            return self.clone();
223        }
224        if self.is_empty() {
225            return other.clone();
226        }
227
228        // Fast path: if both have same lower_bound and below_bound
229        if self.lower_bound == other.lower_bound && self.below_bound_equals(&other.below_bound) {
230            return Self {
231                upper_set: self.upper_set | other.upper_set,
232                lower_set: self.lower_set | other.lower_set,
233                lower_bound: self.lower_bound,
234                below_bound: self.below_bound.clone(),
235            };
236        }
237
238        // Slow path: iterate and set each ID
239        let mut result = self.clone();
240        for id in other.iter() {
241            result = result.set(id);
242        }
243        result
244    }
245
246    /// Find the lowest snapshot ID in the set that is <= upper.
247    pub fn lowest(&self, upper: SnapshotId) -> SnapshotId {
248        // Check below_bound array first
249        if let Some(ref arr) = self.below_bound {
250            if let Some(&lowest) = arr.first() {
251                if lowest <= upper {
252                    return lowest;
253                }
254            }
255        }
256
257        // Check lower_set
258        if self.lower_set != 0 {
259            let lowest_in_lower = self.lower_bound + self.lower_set.trailing_zeros() as usize;
260            if lowest_in_lower <= upper {
261                return lowest_in_lower;
262            }
263        }
264
265        // Check upper_set
266        if self.upper_set != 0 {
267            let lowest_in_upper =
268                self.lower_bound + BITS_PER_SET + self.upper_set.trailing_zeros() as usize;
269            if lowest_in_upper <= upper {
270                return lowest_in_upper;
271            }
272        }
273
274        // Nothing found, return upper
275        upper
276    }
277
278    /// Check if the set is empty.
279    pub fn is_empty(&self) -> bool {
280        self.lower_set == 0 && self.upper_set == 0 && self.below_bound.is_none()
281    }
282
283    /// Iterate over all snapshot IDs in the set.
284    pub fn iter(&self) -> SnapshotIdSetIter<'_> {
285        SnapshotIdSetIter::new(self)
286    }
287
288    /// Convert to a Vec of snapshot IDs (for testing/debugging).
289    pub fn to_list(&self) -> Vec<SnapshotId> {
290        self.iter().collect()
291    }
292
293    /// Add a contiguous range of IDs [from, until) to the set.
294    /// Mirrors AndroidX SnapshotIdSet.addRange semantics used by Snapshot.kt.
295    pub fn add_range(&self, from: SnapshotId, until: SnapshotId) -> Self {
296        if from >= until {
297            return self.clone();
298        }
299        let mut result = self.clone();
300        let mut id = from;
301        while id < until {
302            result = result.set(id);
303            id += 1;
304        }
305        result
306    }
307
308    // Helper: check if two below_bound arrays are equal
309    fn below_bound_equals(&self, other: &Option<Box<[SnapshotId]>>) -> bool {
310        match (&self.below_bound, other) {
311            (None, None) => true,
312            (Some(a), Some(b)) => a == b,
313            _ => false,
314        }
315    }
316
317    // Helper: shift the bit arrays and set a new ID
318    fn shift_and_set(&self, id: SnapshotId) -> Self {
319        let target_lower_bound = (id / SNAPSHOT_ID_SIZE) * SNAPSHOT_ID_SIZE;
320
321        let mut new_upper_set = self.upper_set;
322        let mut new_lower_set = self.lower_set;
323        let mut new_lower_bound = self.lower_bound;
324        let mut new_below_bound: Vec<SnapshotId> = if let Some(ref arr) = self.below_bound {
325            arr.to_vec()
326        } else {
327            Vec::new()
328        };
329
330        while new_lower_bound < target_lower_bound {
331            // Shift lower_set into below_bound array
332            if new_lower_set != 0 {
333                for bit_offset in 0..BITS_PER_SET {
334                    if (new_lower_set & (1u64 << bit_offset)) != 0 {
335                        let id_to_add = new_lower_bound + bit_offset;
336                        // Insert in sorted order
337                        match new_below_bound.binary_search(&id_to_add) {
338                            Ok(_) => {} // Already present (shouldn't happen)
339                            Err(pos) => new_below_bound.insert(pos, id_to_add),
340                        }
341                    }
342                }
343            }
344
345            // Shift upper_set down to lower_set
346            if new_upper_set == 0 {
347                new_lower_bound = target_lower_bound;
348                new_lower_set = 0;
349                break;
350            }
351
352            new_lower_set = new_upper_set;
353            new_upper_set = 0;
354            new_lower_bound += BITS_PER_SET;
355        }
356
357        let result = Self {
358            upper_set: new_upper_set,
359            lower_set: new_lower_set,
360            lower_bound: new_lower_bound,
361            below_bound: if new_below_bound.is_empty() {
362                None
363            } else {
364                Some(new_below_bound.into_boxed_slice())
365            },
366        };
367
368        // Now set the ID
369        result.set(id)
370    }
371}
372
373impl Default for SnapshotIdSet {
374    fn default() -> Self {
375        Self::EMPTY
376    }
377}
378
379impl fmt::Debug for SnapshotIdSet {
380    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381        write!(f, "SnapshotIdSet{{")?;
382        let ids: Vec<_> = self.iter().collect();
383        for (i, id) in ids.iter().enumerate() {
384            if i > 0 {
385                write!(f, ", ")?;
386            }
387            write!(f, "{}", id)?;
388        }
389        write!(f, "}}")
390    }
391}
392
393/// Iterator over snapshot IDs in a set.
394pub struct SnapshotIdSetIter<'a> {
395    set: &'a SnapshotIdSet,
396    below_index: usize,
397    lower_set: u64,
398    upper_set: u64,
399    current_offset: usize,
400}
401
402impl<'a> SnapshotIdSetIter<'a> {
403    fn new(set: &'a SnapshotIdSet) -> Self {
404        Self {
405            set,
406            below_index: 0,
407            lower_set: set.lower_set,
408            upper_set: set.upper_set,
409            current_offset: 0,
410        }
411    }
412}
413
414impl<'a> Iterator for SnapshotIdSetIter<'a> {
415    type Item = SnapshotId;
416
417    fn next(&mut self) -> Option<Self::Item> {
418        // First, yield from below_bound array
419        if let Some(ref arr) = self.set.below_bound {
420            if self.below_index < arr.len() {
421                let id = arr[self.below_index];
422                self.below_index += 1;
423                return Some(id);
424            }
425        }
426
427        // Then yield from lower_set
428        while self.current_offset < BITS_PER_SET {
429            if (self.lower_set & (1u64 << self.current_offset)) != 0 {
430                let id = self.set.lower_bound + self.current_offset;
431                self.current_offset += 1;
432                return Some(id);
433            }
434            self.current_offset += 1;
435        }
436
437        // Finally yield from upper_set
438        while self.current_offset < BITS_PER_SET * 2 {
439            let bit_offset = self.current_offset - BITS_PER_SET;
440            if (self.upper_set & (1u64 << bit_offset)) != 0 {
441                let id = self.set.lower_bound + self.current_offset;
442                self.current_offset += 1;
443                return Some(id);
444            }
445            self.current_offset += 1;
446        }
447
448        None
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_set() {
        let set = SnapshotIdSet::EMPTY;
        assert!(set.is_empty());
        assert!(!set.get(0));
        assert!(!set.get(100));
    }

    #[test]
    fn test_set_and_get_lower_range() {
        let set = SnapshotIdSet::new();
        let set = set.set(0);
        assert!(set.get(0));
        assert!(!set.get(1));

        let set = set.set(63);
        assert!(set.get(0));
        assert!(set.get(63));
        assert!(!set.get(64));
    }

    #[test]
    fn test_set_and_get_upper_range() {
        let set = SnapshotIdSet::new();
        let set = set.set(64);
        assert!(set.get(64));
        assert!(!set.get(63));
        assert!(!set.get(128));

        let set = set.set(127);
        assert!(set.get(64));
        assert!(set.get(127));
        assert!(!set.get(128));
    }

    #[test]
    fn test_set_idempotent() {
        let set = SnapshotIdSet::new();
        let set1 = set.set(10);
        let set2 = set1.set(10);
        assert_eq!(set1, set2);
    }

    #[test]
    fn test_clear() {
        let set = SnapshotIdSet::new().set(10).set(20).set(30);
        assert!(set.get(10));
        assert!(set.get(20));
        assert!(set.get(30));

        let set = set.clear(20);
        assert!(set.get(10));
        assert!(!set.get(20));
        assert!(set.get(30));
    }

    #[test]
    fn test_clear_idempotent() {
        let set = SnapshotIdSet::new().set(10);
        let set1 = set.clear(10);
        let set2 = set1.clear(10);
        assert_eq!(set1, set2);
    }

    #[test]
    fn test_below_bound_insertion() {
        // Setting 200 on an empty set slides the window so lower_bound becomes 192.
        let mut set = SnapshotIdSet::new().set(200);
        assert_eq!(set.lower_bound, 192);
        assert!(set.below_bound.is_none());

        // Everything below 192 must now go into the sorted below_bound array.
        set = set.set(50);
        assert!(set.get(50));
        assert!(set.get(200));

        set = set.set(25);
        set = set.set(75);
        assert!(set.get(25));
        assert!(set.get(50));
        assert!(set.get(75));
        assert!(set.get(200));
        assert!(!set.get(60));

        // below_bound stays sorted regardless of insertion order.
        assert_eq!(set.below_bound.as_deref(), Some(&[25, 50, 75][..]));
        assert_eq!(set.to_list(), vec![25, 50, 75, 200]);
    }

    #[test]
    fn test_below_bound_removal() {
        // Setting 200 slides the window to lower_bound 192, moving 25, 50 and 75 into
        // below_bound.
        let set = SnapshotIdSet::new();
        let set = set.set(25);
        let set = set.set(50);
        let set = set.set(75);
        let set = set.set(200);
        assert_eq!(set.below_bound.as_deref(), Some(&[25, 50, 75][..]));

        let set = set.clear(50);
        assert!(set.get(25));
        assert!(!set.get(50));
        assert!(set.get(75));
        assert!(set.get(200));

        let list = set.to_list();
        assert_eq!(list, vec![25, 75, 200]);
    }

    #[test]
    fn test_clear_last_below_bound_entry_drops_array() {
        let set = SnapshotIdSet::new().set(5).set(200);
        assert_eq!(set.below_bound.as_deref(), Some(&[5][..]));

        let set = set.clear(5);
        assert!(set.below_bound.is_none());
        assert_eq!(set.to_list(), vec![200]);
    }

    #[test]
    fn test_shift_and_set() {
        let set = SnapshotIdSet::new();
        let set = set.set(10);
        assert_eq!(set.lower_bound, 0);

        // Setting a value way above should shift the arrays
        let set = set.set(200);
        assert!(set.get(10));
        assert!(set.get(200));

        // 10 should now be in below_bound
        assert!(set.below_bound.is_some());
    }

    #[test]
    fn test_shift_and_set_boundary_values() {
        let mut set = SnapshotIdSet::new();
        let boundary = SNAPSHOT_ID_SIZE * 12 - 1;
        set = set.set(boundary);
        assert!(set.get(boundary));

        set = set.set(boundary + 1);
        assert!(set.get(boundary));
        assert!(set.get(boundary + 1));
    }

    #[test]
    fn test_set_below_lower_bound_inserts() {
        let set = SnapshotIdSet::new().set(200);
        let lower_bound = set.lower_bound;
        assert!(lower_bound > 0);

        let below = lower_bound - 1;
        let set = set.set(below);
        assert!(set.get(below));
        assert!(set.get(200));
    }

    #[test]
    fn test_and_not_fast_path() {
        let set1 = SnapshotIdSet::new().set(10).set(20).set(30);
        let set2 = SnapshotIdSet::new().set(20).set(40);
        assert_eq!(set1.lower_bound, set2.lower_bound);

        let result = set1.and_not(&set2);
        assert!(result.get(10));
        assert!(!result.get(20));
        assert!(result.get(30));
        assert!(!result.get(40));
    }

    #[test]
    fn test_and_not_slow_path() {
        let set1 = SnapshotIdSet::new().set(10).set(20).set(30);
        // Setting 200 first moves set2's lower_bound to 192, so 20 lands in below_bound and
        // the two sets no longer share a window.
        let set2 = SnapshotIdSet::new().set(200).set(20);
        assert_ne!(set1.lower_bound, set2.lower_bound);

        let result = set1.and_not(&set2);
        assert!(result.get(10));
        assert!(!result.get(20));
        assert!(result.get(30));
        assert!(!result.get(200));
    }

    #[test]
    fn test_or_fast_path() {
        let set1 = SnapshotIdSet::new().set(10).set(20);
        let set2 = SnapshotIdSet::new().set(20).set(30);

        let result = set1.or(&set2);
        assert!(result.get(10));
        assert!(result.get(20));
        assert!(result.get(30));
    }

    #[test]
    fn test_or_slow_path() {
        let set1 = SnapshotIdSet::new().set(10).set(20);
        // Setting 200 first gives set2 a different lower_bound, forcing the slow path.
        let set2 = SnapshotIdSet::new().set(200).set(30);
        assert_ne!(set1.lower_bound, set2.lower_bound);

        let result = set1.or(&set2);
        assert!(result.get(10));
        assert!(result.get(20));
        assert!(result.get(30));
        assert!(result.get(200));
        assert!(!result.get(100));
    }

    #[test]
    fn test_lowest_in_below_bound() {
        let set = SnapshotIdSet::new();
        let set = set.set(25);
        let set = set.set(50);
        let set = set.set(200);
        assert_eq!(set.lowest(1000), 25);
        assert_eq!(set.lowest(100), 25);
        assert_eq!(set.lowest(30), 25);
        // Nothing <= 20 is present, so the bound itself is returned.
        assert_eq!(set.lowest(20), 20);
    }

    #[test]
    fn test_lowest_in_lower_set() {
        let set = SnapshotIdSet::new().set(10).set(20).set(30);
        assert_eq!(set.lowest(1000), 10);
        assert_eq!(set.lowest(25), 10);
    }

    #[test]
    fn test_lowest_in_upper_set() {
        let set = SnapshotIdSet::new().set(70).set(80).set(90);
        assert_eq!(set.lowest(1000), 70);
    }

    #[test]
    fn test_lowest_returns_upper_if_none_found() {
        let set = SnapshotIdSet::new().set(100);
        assert_eq!(set.lowest(50), 50);
    }

    #[test]
    fn test_iterator() {
        let set = SnapshotIdSet::new().set(10).set(20).set(5).set(30);
        let list: Vec<_> = set.iter().collect();
        // Should be in sorted order
        assert_eq!(list, vec![5, 10, 20, 30]);
    }

    #[test]
    fn test_iterator_empty() {
        let set = SnapshotIdSet::new();
        let list: Vec<_> = set.iter().collect();
        assert_eq!(list, Vec::<SnapshotId>::new());
    }

    #[test]
    fn test_iterator_all_ranges() {
        // Setting 200 slides the window to lower_bound 192, so 5, 10 and 70 all end up in
        // below_bound and 200 sits in lower_set.
        let set = SnapshotIdSet::new().set(5).set(10).set(70).set(200);

        let list: Vec<_> = set.iter().collect();
        assert_eq!(list, vec![5, 10, 70, 200]);
    }

    #[test]
    fn test_iterator_spans_all_regions() {
        // lower_bound = 192: 5 and 100 in below_bound, 200 in lower_set, 300 in upper_set.
        let set = SnapshotIdSet::new().set(5).set(200).set(100).set(300);
        assert_eq!(set.lower_bound, 192);
        assert_eq!(set.below_bound.as_deref(), Some(&[5, 100][..]));

        assert_eq!(set.to_list(), vec![5, 100, 200, 300]);
    }

    #[test]
    fn test_to_list() {
        let set = SnapshotIdSet::new().set(10).set(20).set(30);
        assert_eq!(set.to_list(), vec![10, 20, 30]);
    }

    #[test]
    fn test_add_range() {
        let set = SnapshotIdSet::new().add_range(3, 6);
        assert_eq!(set.to_list(), vec![3, 4, 5]);
        // Empty and inverted ranges leave the set unchanged.
        assert_eq!(set.add_range(6, 6), set);
        assert_eq!(set.add_range(9, 4), set);
    }

    #[test]
    fn test_debug_format() {
        let set = SnapshotIdSet::new().set(10).set(20);
        let debug_str = format!("{:?}", set);
        assert_eq!(debug_str, "SnapshotIdSet{10, 20}");
        assert_eq!(format!("{:?}", SnapshotIdSet::new()), "SnapshotIdSet{}");
    }

    #[test]
    fn test_large_snapshot_ids() {
        let set = SnapshotIdSet::new();
        let set = set.set(500);
        let set = set.set(1000);
        let set = set.set(2000);

        assert!(set.get(500));
        assert!(set.get(1000));
        assert!(set.get(2000));
        assert!(!set.get(1500));
    }

    #[test]
    fn test_boundary_transitions() {
        let set = SnapshotIdSet::new();

        // Test transition from lower to upper
        let set = set.set(63);
        let set = set.set(64);
        assert!(set.get(63));
        assert!(set.get(64));

        // Test transition from upper to above
        let set = set.set(127);
        let set = set.set(128);
        assert!(set.get(127));
        assert!(set.get(128));
    }
}