pgm_extra/collections/
set.rs

//! An owned, sorted set optimized for read-heavy workloads, backed by a PGM-index.
//!
//! `Set` offers a `BTreeSet`-like API for read-heavy workloads where you
//! build the set once and then perform many lookups.
5
6use alloc::vec::Vec;
7use core::cmp::Ordering;
8use core::fmt;
9use core::hash::{Hash, Hasher};
10use core::iter::FusedIterator;
11use core::ops::RangeBounds;
12
13use crate::error::Error;
14use crate::index::external;
15use crate::index::key::Indexable;
16use crate::util::range::range_to_indices;
17
18/// An owned set optimized for read-heavy workloads, backed by a PGM-index.
19///
20/// This is a BTreeSet-like container that owns its data and provides
21/// efficient lookups using a learned index. Mutations are supported but
22/// trigger O(n) index rebuilds; for frequent updates use [`crate::Dynamic`].
23///
24/// Works with any type that implements [`Indexable`]:
25/// - Numeric types (u64, i32, etc.) are indexed directly
26/// - String/bytes types use prefix extraction
27///
28/// # Example
29///
30/// ```
31/// use pgm_extra::Set;
32///
33/// // Numeric set
34/// let nums: Vec<u64> = (0..10000).collect();
35/// let set = Set::from_sorted_unique(nums, 64, 4).unwrap();
36/// assert!(set.contains(&5000));
37///
38/// // String set
39/// let words = vec!["apple", "banana", "cherry"];
40/// let set = Set::from_sorted_unique(words, 64, 4).unwrap();
41/// assert!(set.contains(&"banana"));
42/// ```
43#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
44#[cfg_attr(
45    feature = "serde",
46    serde(
47        bound = "T: serde::Serialize + serde::de::DeserializeOwned, T::Key: serde::Serialize + serde::de::DeserializeOwned"
48    )
49)]
50pub struct Set<T: Indexable> {
51    data: Vec<T>,
52    index: Option<external::Static<T>>,
53    epsilon: usize,
54    epsilon_recursive: usize,
55}
56
57impl<T: Indexable + Ord> Set<T>
58where
59    T::Key: Ord,
60{
61    /// Create a set from pre-sorted, unique values.
62    ///
63    /// # Panics
64    ///
65    /// Debug builds will panic if values are not sorted or contain duplicates.
66    pub fn from_sorted_unique(
67        data: Vec<T>,
68        epsilon: usize,
69        epsilon_recursive: usize,
70    ) -> Result<Self, Error> {
71        debug_assert!(
72            data.windows(2).all(|w| w[0] < w[1]),
73            "data must be sorted and unique"
74        );
75
76        let index = if data.is_empty() {
77            None
78        } else {
79            Some(external::Static::new(&data, epsilon, epsilon_recursive)?)
80        };
81        Ok(Self {
82            data,
83            index,
84            epsilon,
85            epsilon_recursive,
86        })
87    }
88
89    /// Create a set from an unsorted iterator.
90    ///
91    /// Values are sorted and deduplicated (like `BTreeSet::from_iter`).
92    pub fn build<I>(iter: I, epsilon: usize, epsilon_recursive: usize) -> Result<Self, Error>
93    where
94        I: IntoIterator<Item = T>,
95    {
96        let mut data: Vec<T> = iter.into_iter().collect();
97        data.sort();
98        data.dedup();
99
100        Self::from_sorted_unique(data, epsilon, epsilon_recursive)
101    }
102
103    /// Create an empty set with the given epsilon values.
104    pub fn empty(epsilon: usize, epsilon_recursive: usize) -> Self {
105        Self {
106            data: Vec::new(),
107            index: None,
108            epsilon,
109            epsilon_recursive,
110        }
111    }
112
113    /// Create a set with default epsilon values (64, 4).
114    pub fn new(data: Vec<T>) -> Result<Self, Error> {
115        Self::build(data, 64, 4)
116    }
117
118    /// Check if the set contains the value.
119    #[inline]
120    pub fn contains(&self, value: &T) -> bool {
121        self.get(value).is_some()
122    }
123
124    /// Get a reference to the value if it exists.
125    #[inline]
126    pub fn get(&self, value: &T) -> Option<&T> {
127        let index = self.index.as_ref()?;
128        let approx = index.search(value);
129
130        let lo = approx.lo;
131        let hi = approx.hi.min(self.data.len());
132
133        for i in lo..hi {
134            match self.data[i].cmp(value) {
135                Ordering::Equal => return Some(&self.data[i]),
136                Ordering::Greater => return None,
137                Ordering::Less => continue,
138            }
139        }
140        None
141    }
142
143    /// Find the index of the first value >= the given value.
144    #[inline]
145    pub fn lower_bound(&self, value: &T) -> usize {
146        match &self.index {
147            Some(index) => index.lower_bound(&self.data, value),
148            None => 0,
149        }
150    }
151
152    /// Find the index of the first value > the given value.
153    #[inline]
154    pub fn upper_bound(&self, value: &T) -> usize {
155        match &self.index {
156            Some(index) => index.upper_bound(&self.data, value),
157            None => 0,
158        }
159    }
160
161    /// Returns an iterator over values in the given range.
162    #[inline]
163    pub fn range<R>(&self, range: R) -> impl DoubleEndedIterator<Item = &T>
164    where
165        R: RangeBounds<T>,
166    {
167        let (start, end) = range_to_indices(
168            range,
169            self.data.len(),
170            |v| self.lower_bound(v),
171            |v| self.upper_bound(v),
172        );
173        self.data[start..end].iter()
174    }
175
176    /// Get the first (smallest) value.
177    #[inline]
178    pub fn first(&self) -> Option<&T> {
179        self.data.first()
180    }
181
182    /// Get the last (largest) value.
183    #[inline]
184    pub fn last(&self) -> Option<&T> {
185        self.data.last()
186    }
187
188    /// Iterate over all values in sorted order.
189    #[inline]
190    pub fn iter(&self) -> impl ExactSizeIterator<Item = &T> + DoubleEndedIterator {
191        self.data.iter()
192    }
193
194    /// Get the number of values.
195    #[inline]
196    pub fn len(&self) -> usize {
197        self.data.len()
198    }
199
200    /// Check if the set is empty.
201    #[inline]
202    pub fn is_empty(&self) -> bool {
203        self.data.is_empty()
204    }
205
206    /// Get the number of segments in the underlying index.
207    #[inline]
208    pub fn segments_count(&self) -> usize {
209        self.index.as_ref().map_or(0, |i| i.segments_count())
210    }
211
212    /// Get the height of the underlying index.
213    #[inline]
214    pub fn height(&self) -> usize {
215        self.index.as_ref().map_or(0, |i| i.height())
216    }
217
218    /// Get the epsilon value.
219    #[inline]
220    pub fn epsilon(&self) -> usize {
221        self.epsilon
222    }
223
224    /// Get the epsilon_recursive value.
225    #[inline]
226    pub fn epsilon_recursive(&self) -> usize {
227        self.epsilon_recursive
228    }
229
230    /// Approximate memory usage in bytes.
231    pub fn size_in_bytes(&self) -> usize {
232        self.index.as_ref().map_or(0, |i| i.size_in_bytes())
233            + self.data.capacity() * core::mem::size_of::<T>()
234    }
235
236    /// Get a reference to the underlying data slice.
237    #[inline]
238    pub fn as_slice(&self) -> &[T] {
239        &self.data
240    }
241
242    /// Consume the set and return the underlying data.
243    #[inline]
244    pub fn into_vec(self) -> Vec<T> {
245        self.data
246    }
247
248    /// Get a reference to the underlying index.
249    #[inline]
250    pub fn index(&self) -> Option<&external::Static<T>> {
251        self.index.as_ref()
252    }
253
254    /// Insert a value into the set.
255    ///
256    /// Returns `true` if the value was newly inserted, `false` if it already existed.
257    ///
258    /// **Note**: This rebuilds the entire index, making it O(n) per insertion.
259    /// For bulk insertions, prefer collecting into a new set or using `extend`.
260    /// For frequent mutations, consider using `index::owned::Dynamic` directly.
261    pub fn insert(&mut self, value: T) -> bool {
262        if self.contains(&value) {
263            return false;
264        }
265
266        let mut data = core::mem::take(&mut self.data);
267        data.push(value);
268        data.sort();
269
270        if let Ok(new_set) = Self::from_sorted_unique(data, self.epsilon, self.epsilon_recursive) {
271            *self = new_set;
272        }
273        true
274    }
275
276    /// Returns `true` if `self` has no elements in common with `other`.
277    pub fn is_disjoint(&self, other: &Set<T>) -> bool {
278        if self.is_empty() || other.is_empty() {
279            return true;
280        }
281
282        let (smaller, larger) = if self.len() <= other.len() {
283            (self, other)
284        } else {
285            (other, self)
286        };
287
288        for value in smaller.iter() {
289            if larger.contains(value) {
290                return false;
291            }
292        }
293        true
294    }
295
296    /// Returns `true` if `self` is a subset of `other`.
297    pub fn is_subset(&self, other: &Set<T>) -> bool {
298        if self.len() > other.len() {
299            return false;
300        }
301        self.iter().all(|v| other.contains(v))
302    }
303
304    /// Returns `true` if `self` is a superset of `other`.
305    pub fn is_superset(&self, other: &Set<T>) -> bool {
306        other.is_subset(self)
307    }
308
309    /// Returns an iterator over values in `self` but not in `other`.
310    pub fn difference<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
311        self.iter().filter(move |v| !other.contains(v))
312    }
313
314    /// Returns an iterator over values in `self` or `other` but not both.
315    pub fn symmetric_difference<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
316        self.difference(other).chain(other.difference(self))
317    }
318
319    /// Returns an iterator over values in both `self` and `other`.
320    pub fn intersection<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
321        let (smaller, larger) = if self.len() <= other.len() {
322            (self, other)
323        } else {
324            (other, self)
325        };
326        smaller.iter().filter(move |v| larger.contains(v))
327    }
328
329    /// Returns an iterator over values in either `self` or `other`.
330    pub fn union<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
331        MergeIter::new(self.data.iter(), other.data.iter())
332    }
333}
334
/// Iterator over the union of two sorted slices, yielding shared values once.
///
/// Built by [`Set::union`].
///
/// When both inputs are strictly ascending (as the backing slices of two sets
/// are), the output is strictly ascending. A value present in both inputs is
/// yielded once, using the reference from the first input.
#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct MergeIter<'a, T> {
    a: core::slice::Iter<'a, T>,
    b: core::slice::Iter<'a, T>,
    /// Next unconsumed element of `a`, pulled ahead for comparison.
    peeked_a: Option<&'a T>,
    /// Next unconsumed element of `b`, pulled ahead for comparison.
    peeked_b: Option<&'a T>,
}

impl<'a, T: Ord> MergeIter<'a, T> {
    /// Primes the merge by pulling the first element of each input.
    fn new(mut a: core::slice::Iter<'a, T>, mut b: core::slice::Iter<'a, T>) -> Self {
        let peeked_a = a.next();
        let peeked_b = b.next();
        Self {
            a,
            b,
            peeked_a,
            peeked_b,
        }
    }
}

impl<'a, T: Ord> Iterator for MergeIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        match (self.peeked_a, self.peeked_b) {
            (Some(a), Some(b)) => match a.cmp(b) {
                Ordering::Less => {
                    self.peeked_a = self.a.next();
                    Some(a)
                }
                Ordering::Greater => {
                    self.peeked_b = self.b.next();
                    Some(b)
                }
                Ordering::Equal => {
                    // Shared value: consume it from both sides, emit it once.
                    self.peeked_a = self.a.next();
                    self.peeked_b = self.b.next();
                    Some(a)
                }
            },
            (Some(a), None) => {
                self.peeked_a = self.a.next();
                Some(a)
            }
            (None, Some(b)) => {
                self.peeked_b = self.b.next();
                Some(b)
            }
            (None, None) => None,
        }
    }

    /// Bounds on the remaining union size.
    ///
    /// Each "equal" step consumes one element from each side but emits only
    /// one. The count is therefore at least the larger side's remaining
    /// length and at most the sum of both.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.a.len() + usize::from(self.peeked_a.is_some());
        let right = self.b.len() + usize::from(self.peeked_b.is_some());
        (left.max(right), left.checked_add(right))
    }
}

impl<T: Ord> FusedIterator for MergeIter<'_, T> {}
390
391// Standard trait implementations
392
393impl<T: Indexable + Clone> Clone for Set<T>
394where
395    T::Key: Clone,
396{
397    fn clone(&self) -> Self {
398        Self {
399            data: self.data.clone(),
400            index: self.index.clone(),
401            epsilon: self.epsilon,
402            epsilon_recursive: self.epsilon_recursive,
403        }
404    }
405}
406
407impl<T: Indexable + fmt::Debug> fmt::Debug for Set<T> {
408    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
409        f.debug_set().entries(self.data.iter()).finish()
410    }
411}
412
413impl<T: Indexable + Ord + PartialEq> PartialEq for Set<T> {
414    fn eq(&self, other: &Self) -> bool {
415        self.data == other.data
416    }
417}
418
419impl<T: Indexable + Ord + Eq> Eq for Set<T> {}
420
421impl<T: Indexable + Ord + PartialOrd> PartialOrd for Set<T>
422where
423    T::Key: Ord,
424{
425    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
426        Some(self.cmp(other))
427    }
428}
429
430impl<T: Indexable + Ord> Ord for Set<T>
431where
432    T::Key: Ord,
433{
434    fn cmp(&self, other: &Self) -> Ordering {
435        self.data.cmp(&other.data)
436    }
437}
438
439impl<T: Indexable + Hash> Hash for Set<T> {
440    fn hash<H: Hasher>(&self, state: &mut H) {
441        self.data.hash(state);
442    }
443}
444
445impl<T: Indexable + Ord> IntoIterator for Set<T>
446where
447    T::Key: Ord,
448{
449    type Item = T;
450    type IntoIter = alloc::vec::IntoIter<T>;
451
452    fn into_iter(self) -> Self::IntoIter {
453        self.data.into_iter()
454    }
455}
456
457impl<'a, T: Indexable> IntoIterator for &'a Set<T> {
458    type Item = &'a T;
459    type IntoIter = core::slice::Iter<'a, T>;
460
461    fn into_iter(self) -> Self::IntoIter {
462        self.data.iter()
463    }
464}
465
466impl<T: Indexable + Ord> FromIterator<T> for Set<T>
467where
468    T::Key: Ord,
469{
470    /// Creates a Set from an iterator.
471    ///
472    /// Returns an empty set if the iterator is empty.
473    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
474        Self::build(iter, 64, 4).unwrap_or_else(|_| Self::empty(64, 4))
475    }
476}
477
478impl<T: Indexable + Ord> core::iter::Extend<T> for Set<T>
479where
480    T::Key: Ord,
481{
482    /// Extends the set with elements from an iterator.
483    ///
484    /// **Note**: This rebuilds the entire index, making it O(n) per call.
485    /// For bulk insertions, prefer collecting into a new set.
486    /// For frequent mutations, consider using `index::owned::Dynamic` directly.
487    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
488        let mut data = core::mem::take(&mut self.data);
489        data.extend(iter);
490        data.sort();
491        data.dedup();
492
493        match Self::from_sorted_unique(data, self.epsilon, self.epsilon_recursive) {
494            Ok(new_set) => *self = new_set,
495            Err(_) => {
496                *self = Self::empty(self.epsilon, self.epsilon_recursive);
497            }
498        }
499    }
500}
501
#[cfg(test)]
mod tests {
    //! Unit tests for `Set`: construction, lookups, ranges, set algebra, and
    //! mutation through `insert` / `extend`.

    use super::*;
    use alloc::string::String;
    use alloc::vec;

    #[test]
    fn test_set_numeric() {
        let data: Vec<u64> = (0..1000).collect();
        let set = Set::from_sorted_unique(data, 64, 4).unwrap();

        assert_eq!(set.len(), 1000);
        assert!(set.contains(&500));
        assert!(!set.contains(&1001));
    }

    #[test]
    fn test_set_strings() {
        let data = vec!["apple", "banana", "cherry", "date"];
        let set = Set::from_sorted_unique(data, 64, 4).unwrap();

        assert!(set.contains(&"banana"));
        assert!(set.contains(&"cherry"));
        assert!(!set.contains(&"elderberry"));
    }

    #[test]
    fn test_set_owned_strings() {
        let data: Vec<String> = vec!["alpha", "beta", "gamma"]
            .into_iter()
            .map(String::from)
            .collect();
        let set = Set::from_sorted_unique(data, 64, 4).unwrap();

        assert!(set.contains(&String::from("beta")));
        assert!(!set.contains(&String::from("delta")));
    }

    #[test]
    fn test_set_build() {
        let data = vec![5u64, 3, 1, 4, 1, 5, 9, 2, 6];
        let set = Set::build(data, 4, 2).unwrap();

        assert_eq!(set.len(), 7);
        assert!(set.contains(&1));
        assert!(set.contains(&9));

        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 2, 3, 4, 5, 6, 9]);
    }

    #[test]
    fn test_set_first_last() {
        let data: Vec<u64> = vec![10, 20, 30, 40, 50];
        let set = Set::from_sorted_unique(data, 4, 2).unwrap();

        assert_eq!(set.first(), Some(&10));
        assert_eq!(set.last(), Some(&50));
    }

    #[test]
    fn test_set_get_returns_stored_value() {
        let set = Set::from_sorted_unique(vec![10u64, 20, 30], 4, 2).unwrap();

        assert_eq!(set.get(&20), Some(&20));
        assert_eq!(set.get(&25), None);
        assert_eq!(set.get(&5), None);
        assert_eq!(set.get(&35), None);
    }

    #[test]
    fn test_set_bounds() {
        let set = Set::from_sorted_unique(vec![10u64, 20, 30], 4, 2).unwrap();

        assert_eq!(set.lower_bound(&5), 0);
        assert_eq!(set.lower_bound(&20), 1);
        assert_eq!(set.upper_bound(&20), 2);
        assert_eq!(set.lower_bound(&25), 2);
        assert_eq!(set.upper_bound(&35), 3);

        let empty: Set<u64> = Set::empty(64, 4);
        assert_eq!(empty.lower_bound(&1), 0);
        assert_eq!(empty.upper_bound(&1), 0);
    }

    #[test]
    fn test_set_range() {
        let data: Vec<u64> = (0..100).collect();
        let set = Set::from_sorted_unique(data, 16, 4).unwrap();

        let range: Vec<_> = set.range(10..20).copied().collect();
        assert_eq!(range, (10..20).collect::<Vec<_>>());
    }

    #[test]
    fn test_set_range_inclusive_and_unbounded() {
        let data: Vec<u64> = (0..100).collect();
        let set = Set::from_sorted_unique(data, 16, 4).unwrap();

        let inclusive: Vec<_> = set.range(10..=20).copied().collect();
        assert_eq!(inclusive, (10..=20).collect::<Vec<_>>());

        let tail: Vec<_> = set.range(95..).copied().collect();
        assert_eq!(tail, vec![95, 96, 97, 98, 99]);

        assert_eq!(set.range(..).count(), 100);
    }

    #[test]
    fn test_set_iter() {
        let data: Vec<u64> = (0..10).collect();
        let set = Set::from_sorted_unique(data, 4, 2).unwrap();

        let forward: Vec<_> = set.iter().copied().collect();
        let backward: Vec<_> = set.iter().rev().copied().collect();

        assert_eq!(forward, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        assert_eq!(backward, vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
    }

    #[test]
    fn test_set_operations() {
        let set1 = Set::build(vec![1u64, 2, 3, 4, 5], 4, 2).unwrap();
        let set2 = Set::build(vec![4u64, 5, 6, 7, 8], 4, 2).unwrap();

        let intersection: Vec<_> = set1.intersection(&set2).copied().collect();
        assert_eq!(intersection, vec![4, 5]);

        let difference: Vec<_> = set1.difference(&set2).copied().collect();
        assert_eq!(difference, vec![1, 2, 3]);

        assert!(!set1.is_disjoint(&set2));

        let set3 = Set::build(vec![10u64, 11], 4, 2).unwrap();
        assert!(set1.is_disjoint(&set3));
    }

    #[test]
    fn test_set_union_and_symmetric_difference() {
        let set1 = Set::build(vec![1u64, 2, 3, 4, 5], 4, 2).unwrap();
        let set2 = Set::build(vec![4u64, 5, 6, 7, 8], 4, 2).unwrap();

        let union: Vec<_> = set1.union(&set2).copied().collect();
        assert_eq!(union, vec![1, 2, 3, 4, 5, 6, 7, 8]);

        // Symmetric difference is not globally sorted; compare as a multiset.
        let mut sym: Vec<_> = set1.symmetric_difference(&set2).copied().collect();
        sym.sort_unstable();
        assert_eq!(sym, vec![1, 2, 3, 6, 7, 8]);
    }

    #[test]
    fn test_set_subset_superset() {
        let small = Set::build(vec![2u64, 3], 4, 2).unwrap();
        let big = Set::build(vec![1u64, 2, 3, 4], 4, 2).unwrap();

        assert!(small.is_subset(&big));
        assert!(big.is_superset(&small));
        assert!(!big.is_subset(&small));
        assert!(small.is_subset(&small));
    }

    #[test]
    fn test_set_equality_ignores_construction_details() {
        let a = Set::build(vec![3u64, 1, 2], 4, 2).unwrap();
        let b = Set::build(vec![1u64, 2, 3, 3], 8, 2).unwrap();

        assert_eq!(a, b);
        assert_eq!(a.cmp(&b), Ordering::Equal);
    }

    #[test]
    fn test_set_collect() {
        let set: Set<u64> = (0..100).collect();
        assert_eq!(set.len(), 100);
        assert!(set.contains(&50));
    }

    #[test]
    fn test_set_empty() {
        let set: Set<u64> = Set::empty(64, 4);
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert!(!set.contains(&0));
        assert_eq!(set.first(), None);
        assert_eq!(set.last(), None);
    }

    #[test]
    fn test_set_collect_empty() {
        let set: Set<u64> = core::iter::empty().collect();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
    }

    #[test]
    fn test_set_insert() {
        let mut set = Set::build(vec![1u64, 3, 5], 4, 2).unwrap();
        assert_eq!(set.len(), 3);

        assert!(set.insert(2));
        assert_eq!(set.len(), 4);
        assert!(set.contains(&2));

        assert!(!set.insert(2));
        assert_eq!(set.len(), 4);

        assert!(set.insert(4));
        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_set_insert_at_ends() {
        let mut set = Set::build(vec![5u64, 10], 4, 2).unwrap();

        assert!(set.insert(1));
        assert!(set.insert(20));

        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 5, 10, 20]);
        assert!(set.contains(&1));
        assert!(set.contains(&20));
    }

    #[test]
    fn test_set_insert_into_empty() {
        let mut set: Set<u64> = Set::empty(64, 4);
        assert!(set.insert(42));
        assert_eq!(set.len(), 1);
        assert!(set.contains(&42));
    }

    #[test]
    fn test_set_extend_empty() {
        let mut set: Set<u64> = Set::empty(64, 4);
        set.extend(vec![3, 1, 2]);
        assert_eq!(set.len(), 3);
        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 2, 3]);
    }

    #[test]
    fn test_set_extend_with_duplicates() {
        let mut set = Set::build(vec![1u64, 3], 4, 2).unwrap();
        set.extend(vec![3, 2, 2]);

        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 2, 3]);
        assert!(set.contains(&2));
    }
}