pgm_extra/collections/
set.rs

//! An owned set optimized for read-heavy workloads, backed by a PGM-index.
//!
//! `Set` offers a `BTreeSet`-like API and can largely stand in for `BTreeSet`
//! in read-heavy workloads where you build the set once and perform many lookups.
5
6use alloc::vec::Vec;
7use core::cmp::Ordering;
8use core::fmt;
9use core::hash::{Hash, Hasher};
10use core::iter::FusedIterator;
11use core::ops::RangeBounds;
12
13use crate::error::Error;
14use crate::index::external;
15use crate::index::key::Indexable;
16use crate::util::range::range_to_indices;
17
18/// An owned set optimized for read-heavy workloads, backed by a PGM-index.
19///
20/// This is a BTreeSet-like container that owns its data and provides
21/// efficient lookups using a learned index. Mutations are supported but
22/// trigger O(n) index rebuilds; for frequent updates use [`crate::Dynamic`].
23///
24/// Works with any type that implements [`Indexable`]:
25/// - Numeric types (u64, i32, etc.) are indexed directly
26/// - String/bytes types use prefix extraction
27///
28/// # Example
29///
30/// ```
31/// use pgm_extra::Set;
32///
33/// // Numeric set
34/// let nums: Vec<u64> = (0..10000).collect();
35/// let set = Set::from_sorted_unique(nums, 64, 4).unwrap();
36/// assert!(set.contains(&5000));
37///
38/// // String set
39/// let words = vec!["apple", "banana", "cherry"];
40/// let set = Set::from_sorted_unique(words, 64, 4).unwrap();
41/// assert!(set.contains(&"banana"));
42/// ```
43#[cfg_attr(
44    feature = "rkyv",
45    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
46)]
47#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
48#[cfg_attr(
49    feature = "serde",
50    serde(
51        bound = "T: serde::Serialize + serde::de::DeserializeOwned, T::Key: serde::Serialize + serde::de::DeserializeOwned"
52    )
53)]
54pub struct Set<T: Indexable> {
55    data: Vec<T>,
56    index: Option<external::Static<T>>,
57    epsilon: usize,
58    epsilon_recursive: usize,
59}
60
61impl<T: Indexable + Ord> Set<T>
62where
63    T::Key: Ord,
64{
65    /// Create a set from pre-sorted, unique values.
66    ///
67    /// # Panics
68    ///
69    /// Debug builds will panic if values are not sorted or contain duplicates.
70    pub fn from_sorted_unique(
71        data: Vec<T>,
72        epsilon: usize,
73        epsilon_recursive: usize,
74    ) -> Result<Self, Error> {
75        debug_assert!(
76            data.windows(2).all(|w| w[0] < w[1]),
77            "data must be sorted and unique"
78        );
79
80        let index = if data.is_empty() {
81            None
82        } else {
83            Some(external::Static::new(&data, epsilon, epsilon_recursive)?)
84        };
85        Ok(Self {
86            data,
87            index,
88            epsilon,
89            epsilon_recursive,
90        })
91    }
92
93    /// Create a set from an unsorted iterator.
94    ///
95    /// Values are sorted and deduplicated (like `BTreeSet::from_iter`).
96    pub fn build<I>(iter: I, epsilon: usize, epsilon_recursive: usize) -> Result<Self, Error>
97    where
98        I: IntoIterator<Item = T>,
99    {
100        let mut data: Vec<T> = iter.into_iter().collect();
101        data.sort();
102        data.dedup();
103
104        Self::from_sorted_unique(data, epsilon, epsilon_recursive)
105    }
106
107    /// Create an empty set with the given epsilon values.
108    pub fn empty(epsilon: usize, epsilon_recursive: usize) -> Self {
109        Self {
110            data: Vec::new(),
111            index: None,
112            epsilon,
113            epsilon_recursive,
114        }
115    }
116
117    /// Create a set with default epsilon values (64, 4).
118    pub fn new(data: Vec<T>) -> Result<Self, Error> {
119        Self::build(data, 64, 4)
120    }
121
122    /// Check if the set contains the value.
123    #[inline]
124    pub fn contains(&self, value: &T) -> bool {
125        self.get(value).is_some()
126    }
127
128    /// Get a reference to the value if it exists.
129    #[inline]
130    pub fn get(&self, value: &T) -> Option<&T> {
131        let index = self.index.as_ref()?;
132        let approx = index.search(value);
133
134        let lo = approx.lo;
135        let hi = approx.hi.min(self.data.len());
136
137        for i in lo..hi {
138            match self.data[i].cmp(value) {
139                Ordering::Equal => return Some(&self.data[i]),
140                Ordering::Greater => return None,
141                Ordering::Less => continue,
142            }
143        }
144        None
145    }
146
147    /// Find the index of the first value >= the given value.
148    #[inline]
149    pub fn lower_bound(&self, value: &T) -> usize {
150        match &self.index {
151            Some(index) => index.lower_bound(&self.data, value),
152            None => 0,
153        }
154    }
155
156    /// Find the index of the first value > the given value.
157    #[inline]
158    pub fn upper_bound(&self, value: &T) -> usize {
159        match &self.index {
160            Some(index) => index.upper_bound(&self.data, value),
161            None => 0,
162        }
163    }
164
165    /// Returns an iterator over values in the given range.
166    #[inline]
167    pub fn range<R>(&self, range: R) -> impl DoubleEndedIterator<Item = &T>
168    where
169        R: RangeBounds<T>,
170    {
171        let (start, end) = range_to_indices(
172            range,
173            self.data.len(),
174            |v| self.lower_bound(v),
175            |v| self.upper_bound(v),
176        );
177        self.data[start..end].iter()
178    }
179
180    /// Get the first (smallest) value.
181    #[inline]
182    pub fn first(&self) -> Option<&T> {
183        self.data.first()
184    }
185
186    /// Get the last (largest) value.
187    #[inline]
188    pub fn last(&self) -> Option<&T> {
189        self.data.last()
190    }
191
192    /// Iterate over all values in sorted order.
193    #[inline]
194    pub fn iter(&self) -> impl ExactSizeIterator<Item = &T> + DoubleEndedIterator {
195        self.data.iter()
196    }
197
198    /// Get the number of values.
199    #[inline]
200    pub fn len(&self) -> usize {
201        self.data.len()
202    }
203
204    /// Check if the set is empty.
205    #[inline]
206    pub fn is_empty(&self) -> bool {
207        self.data.is_empty()
208    }
209
210    /// Get the number of segments in the underlying index.
211    #[inline]
212    pub fn segments_count(&self) -> usize {
213        self.index.as_ref().map_or(0, |i| i.segments_count())
214    }
215
216    /// Get the height of the underlying index.
217    #[inline]
218    pub fn height(&self) -> usize {
219        self.index.as_ref().map_or(0, |i| i.height())
220    }
221
222    /// Get the epsilon value.
223    #[inline]
224    pub fn epsilon(&self) -> usize {
225        self.epsilon
226    }
227
228    /// Get the epsilon_recursive value.
229    #[inline]
230    pub fn epsilon_recursive(&self) -> usize {
231        self.epsilon_recursive
232    }
233
234    /// Approximate memory usage in bytes.
235    pub fn size_in_bytes(&self) -> usize {
236        self.index.as_ref().map_or(0, |i| i.size_in_bytes())
237            + self.data.capacity() * core::mem::size_of::<T>()
238    }
239
240    /// Get a reference to the underlying data slice.
241    #[inline]
242    pub fn as_slice(&self) -> &[T] {
243        &self.data
244    }
245
246    /// Consume the set and return the underlying data.
247    #[inline]
248    pub fn into_vec(self) -> Vec<T> {
249        self.data
250    }
251
252    /// Get a reference to the underlying index.
253    #[inline]
254    pub fn index(&self) -> Option<&external::Static<T>> {
255        self.index.as_ref()
256    }
257
258    /// Insert a value into the set.
259    ///
260    /// Returns `true` if the value was newly inserted, `false` if it already existed.
261    ///
262    /// **Note**: This rebuilds the entire index, making it O(n) per insertion.
263    /// For bulk insertions, prefer collecting into a new set or using `extend`.
264    /// For frequent mutations, consider using `index::owned::Dynamic` directly.
265    pub fn insert(&mut self, value: T) -> bool {
266        if self.contains(&value) {
267            return false;
268        }
269
270        let mut data = core::mem::take(&mut self.data);
271        data.push(value);
272        data.sort();
273
274        if let Ok(new_set) = Self::from_sorted_unique(data, self.epsilon, self.epsilon_recursive) {
275            *self = new_set;
276        }
277        true
278    }
279
280    /// Returns `true` if `self` has no elements in common with `other`.
281    pub fn is_disjoint(&self, other: &Set<T>) -> bool {
282        if self.is_empty() || other.is_empty() {
283            return true;
284        }
285
286        let (smaller, larger) = if self.len() <= other.len() {
287            (self, other)
288        } else {
289            (other, self)
290        };
291
292        for value in smaller.iter() {
293            if larger.contains(value) {
294                return false;
295            }
296        }
297        true
298    }
299
300    /// Returns `true` if `self` is a subset of `other`.
301    pub fn is_subset(&self, other: &Set<T>) -> bool {
302        if self.len() > other.len() {
303            return false;
304        }
305        self.iter().all(|v| other.contains(v))
306    }
307
308    /// Returns `true` if `self` is a superset of `other`.
309    pub fn is_superset(&self, other: &Set<T>) -> bool {
310        other.is_subset(self)
311    }
312
313    /// Returns an iterator over values in `self` but not in `other`.
314    pub fn difference<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
315        self.iter().filter(move |v| !other.contains(v))
316    }
317
318    /// Returns an iterator over values in `self` or `other` but not both.
319    pub fn symmetric_difference<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
320        self.difference(other).chain(other.difference(self))
321    }
322
323    /// Returns an iterator over values in both `self` and `other`.
324    pub fn intersection<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
325        let (smaller, larger) = if self.len() <= other.len() {
326            (self, other)
327        } else {
328            (other, self)
329        };
330        smaller.iter().filter(move |v| larger.contains(v))
331    }
332
333    /// Returns an iterator over values in either `self` or `other`.
334    pub fn union<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
335        MergeIter::new(self.data.iter(), other.data.iter())
336    }
337}
338
/// Iterator that merges two sorted iterators, yielding unique elements.
///
/// Returned by `Set::union`. Each side holds one look-ahead element in its
/// `peeked_*` slot.
#[derive(Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct MergeIter<'a, T> {
    a: core::slice::Iter<'a, T>,
    b: core::slice::Iter<'a, T>,
    peeked_a: Option<&'a T>,
    peeked_b: Option<&'a T>,
}

// Implemented by hand so cloning the iterator does not require `T: Clone`.
// Only slice iterators and references are copied, never the elements.
impl<T> Clone for MergeIter<'_, T> {
    fn clone(&self) -> Self {
        Self {
            a: self.a.clone(),
            b: self.b.clone(),
            peeked_a: self.peeked_a,
            peeked_b: self.peeked_b,
        }
    }
}
346
347impl<'a, T: Ord> MergeIter<'a, T> {
348    fn new(mut a: core::slice::Iter<'a, T>, mut b: core::slice::Iter<'a, T>) -> Self {
349        let peeked_a = a.next();
350        let peeked_b = b.next();
351        Self {
352            a,
353            b,
354            peeked_a,
355            peeked_b,
356        }
357    }
358}
359
360impl<'a, T: Ord> Iterator for MergeIter<'a, T> {
361    type Item = &'a T;
362
363    fn next(&mut self) -> Option<Self::Item> {
364        match (self.peeked_a, self.peeked_b) {
365            (Some(a), Some(b)) => match a.cmp(b) {
366                Ordering::Less => {
367                    self.peeked_a = self.a.next();
368                    Some(a)
369                }
370                Ordering::Greater => {
371                    self.peeked_b = self.b.next();
372                    Some(b)
373                }
374                Ordering::Equal => {
375                    self.peeked_a = self.a.next();
376                    self.peeked_b = self.b.next();
377                    Some(a)
378                }
379            },
380            (Some(a), None) => {
381                self.peeked_a = self.a.next();
382                Some(a)
383            }
384            (None, Some(b)) => {
385                self.peeked_b = self.b.next();
386                Some(b)
387            }
388            (None, None) => None,
389        }
390    }
391}
392
393impl<T: Ord> FusedIterator for MergeIter<'_, T> {}
394
395// Standard trait implementations
396
397impl<T: Indexable + Clone> Clone for Set<T>
398where
399    T::Key: Clone,
400{
401    fn clone(&self) -> Self {
402        Self {
403            data: self.data.clone(),
404            index: self.index.clone(),
405            epsilon: self.epsilon,
406            epsilon_recursive: self.epsilon_recursive,
407        }
408    }
409}
410
411impl<T: Indexable + fmt::Debug> fmt::Debug for Set<T> {
412    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
413        f.debug_set().entries(self.data.iter()).finish()
414    }
415}
416
417impl<T: Indexable + Ord + PartialEq> PartialEq for Set<T> {
418    fn eq(&self, other: &Self) -> bool {
419        self.data == other.data
420    }
421}
422
423impl<T: Indexable + Ord + Eq> Eq for Set<T> {}
424
425impl<T: Indexable + Ord + PartialOrd> PartialOrd for Set<T>
426where
427    T::Key: Ord,
428{
429    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
430        Some(self.cmp(other))
431    }
432}
433
434impl<T: Indexable + Ord> Ord for Set<T>
435where
436    T::Key: Ord,
437{
438    fn cmp(&self, other: &Self) -> Ordering {
439        self.data.cmp(&other.data)
440    }
441}
442
443impl<T: Indexable + Hash> Hash for Set<T> {
444    fn hash<H: Hasher>(&self, state: &mut H) {
445        self.data.hash(state);
446    }
447}
448
449impl<T: Indexable + Ord> IntoIterator for Set<T>
450where
451    T::Key: Ord,
452{
453    type Item = T;
454    type IntoIter = alloc::vec::IntoIter<T>;
455
456    fn into_iter(self) -> Self::IntoIter {
457        self.data.into_iter()
458    }
459}
460
461impl<'a, T: Indexable> IntoIterator for &'a Set<T> {
462    type Item = &'a T;
463    type IntoIter = core::slice::Iter<'a, T>;
464
465    fn into_iter(self) -> Self::IntoIter {
466        self.data.iter()
467    }
468}
469
470impl<T: Indexable + Ord> FromIterator<T> for Set<T>
471where
472    T::Key: Ord,
473{
474    /// Creates a Set from an iterator.
475    ///
476    /// Returns an empty set if the iterator is empty.
477    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
478        Self::build(iter, 64, 4).unwrap_or_else(|_| Self::empty(64, 4))
479    }
480}
481
482impl<T: Indexable + Ord> core::iter::Extend<T> for Set<T>
483where
484    T::Key: Ord,
485{
486    /// Extends the set with elements from an iterator.
487    ///
488    /// **Note**: This rebuilds the entire index, making it O(n) per call.
489    /// For bulk insertions, prefer collecting into a new set.
490    /// For frequent mutations, consider using `index::owned::Dynamic` directly.
491    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
492        let mut data = core::mem::take(&mut self.data);
493        data.extend(iter);
494        data.sort();
495        data.dedup();
496
497        match Self::from_sorted_unique(data, self.epsilon, self.epsilon_recursive) {
498            Ok(new_set) => *self = new_set,
499            Err(_) => {
500                *self = Self::empty(self.epsilon, self.epsilon_recursive);
501            }
502        }
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::String;
    use alloc::vec;

    #[test]
    fn test_set_numeric() {
        let data: Vec<u64> = (0..1000).collect();
        let set = Set::from_sorted_unique(data, 64, 4).unwrap();

        assert_eq!(set.len(), 1000);
        assert!(set.contains(&500));
        assert!(!set.contains(&1001));
    }

    #[test]
    fn test_set_strings() {
        let data = vec!["apple", "banana", "cherry", "date"];
        let set = Set::from_sorted_unique(data, 64, 4).unwrap();

        assert!(set.contains(&"banana"));
        assert!(set.contains(&"cherry"));
        assert!(!set.contains(&"elderberry"));
    }

    #[test]
    fn test_set_owned_strings() {
        let data: Vec<String> = vec!["alpha", "beta", "gamma"]
            .into_iter()
            .map(String::from)
            .collect();
        let set = Set::from_sorted_unique(data, 64, 4).unwrap();

        assert!(set.contains(&String::from("beta")));
        assert!(!set.contains(&String::from("delta")));
    }

    #[test]
    fn test_set_build() {
        let data = vec![5u64, 3, 1, 4, 1, 5, 9, 2, 6];
        let set = Set::build(data, 4, 2).unwrap();

        assert_eq!(set.len(), 7);
        assert!(set.contains(&1));
        assert!(set.contains(&9));

        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 2, 3, 4, 5, 6, 9]);
    }

    #[test]
    fn test_set_first_last() {
        let data: Vec<u64> = vec![10, 20, 30, 40, 50];
        let set = Set::from_sorted_unique(data, 4, 2).unwrap();

        assert_eq!(set.first(), Some(&10));
        assert_eq!(set.last(), Some(&50));
    }

    #[test]
    fn test_set_range() {
        let data: Vec<u64> = (0..100).collect();
        let set = Set::from_sorted_unique(data, 16, 4).unwrap();

        let range: Vec<_> = set.range(10..20).copied().collect();
        assert_eq!(range, (10..20).collect::<Vec<_>>());
    }

    #[test]
    fn test_set_range_bound_kinds() {
        let data: Vec<u64> = (0..100).collect();
        let set = Set::from_sorted_unique(data, 16, 4).unwrap();

        let head: Vec<_> = set.range(..5).copied().collect();
        assert_eq!(head, vec![0, 1, 2, 3, 4]);

        let tail: Vec<_> = set.range(95..).copied().collect();
        assert_eq!(tail, vec![95, 96, 97, 98, 99]);

        let inclusive: Vec<_> = set.range(10..=12).copied().collect();
        assert_eq!(inclusive, vec![10, 11, 12]);

        let reversed: Vec<_> = set.range(10..12).rev().copied().collect();
        assert_eq!(reversed, vec![11, 10]);
    }

    #[test]
    fn test_set_bounds() {
        // Even numbers 0, 2, ..., 98.
        let data: Vec<u64> = (0..50).map(|x| x * 2).collect();
        let set = Set::from_sorted_unique(data, 4, 2).unwrap();

        assert_eq!(set.lower_bound(&0), 0);
        assert_eq!(set.lower_bound(&5), 3);
        assert_eq!(set.lower_bound(&6), 3);
        assert_eq!(set.upper_bound(&6), 4);
        assert_eq!(set.lower_bound(&1000), set.len());
        assert_eq!(set.upper_bound(&98), set.len());
    }

    #[test]
    fn test_set_empty_bounds() {
        let set: Set<u64> = Set::empty(64, 4);
        assert_eq!(set.lower_bound(&3), 0);
        assert_eq!(set.upper_bound(&3), 0);
        assert_eq!(set.range(..).count(), 0);
    }

    #[test]
    fn test_set_get() {
        let set = Set::build(vec![10u64, 20, 30], 4, 2).unwrap();
        assert_eq!(set.get(&20), Some(&20));
        assert_eq!(set.get(&25), None);
        assert_eq!(set.get(&5), None);
        assert_eq!(set.get(&35), None);
    }

    #[test]
    fn test_set_iter() {
        let data: Vec<u64> = (0..10).collect();
        let set = Set::from_sorted_unique(data, 4, 2).unwrap();

        let forward: Vec<_> = set.iter().copied().collect();
        let backward: Vec<_> = set.iter().rev().copied().collect();

        assert_eq!(forward, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        assert_eq!(backward, vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
    }

    #[test]
    fn test_set_operations() {
        let set1 = Set::build(vec![1u64, 2, 3, 4, 5], 4, 2).unwrap();
        let set2 = Set::build(vec![4u64, 5, 6, 7, 8], 4, 2).unwrap();

        let intersection: Vec<_> = set1.intersection(&set2).copied().collect();
        assert_eq!(intersection, vec![4, 5]);

        let difference: Vec<_> = set1.difference(&set2).copied().collect();
        assert_eq!(difference, vec![1, 2, 3]);

        let symmetric: Vec<_> = set1.symmetric_difference(&set2).copied().collect();
        assert_eq!(symmetric, vec![1, 2, 3, 6, 7, 8]);

        assert!(!set1.is_disjoint(&set2));

        let set3 = Set::build(vec![10u64, 11], 4, 2).unwrap();
        assert!(set1.is_disjoint(&set3));
    }

    #[test]
    fn test_set_union() {
        let a = Set::build(vec![1u64, 3, 5], 4, 2).unwrap();
        let b = Set::build(vec![2u64, 3, 6], 4, 2).unwrap();

        let union: Vec<_> = a.union(&b).copied().collect();
        assert_eq!(union, vec![1, 2, 3, 5, 6]);

        let empty: Set<u64> = Set::empty(4, 2);
        let with_empty: Vec<_> = a.union(&empty).copied().collect();
        assert_eq!(with_empty, vec![1, 3, 5]);
    }

    #[test]
    fn test_set_subset_superset() {
        let small = Set::build(vec![2u64, 4], 4, 2).unwrap();
        let big = Set::build(vec![1u64, 2, 3, 4], 4, 2).unwrap();

        assert!(small.is_subset(&big));
        assert!(big.is_superset(&small));
        assert!(!big.is_subset(&small));
        assert!(!small.is_superset(&big));
    }

    #[test]
    fn test_set_clone_eq() {
        let a = Set::build(vec![3u64, 1, 2], 4, 2).unwrap();
        let b = a.clone();
        assert_eq!(a, b);

        let c = Set::build(vec![1u64, 2], 4, 2).unwrap();
        assert_ne!(a, c);
    }

    #[test]
    fn test_set_collect() {
        let set: Set<u64> = (0..100).collect();
        assert_eq!(set.len(), 100);
        assert!(set.contains(&50));
    }

    #[test]
    fn test_set_empty() {
        let set: Set<u64> = Set::empty(64, 4);
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert!(!set.contains(&0));
        assert_eq!(set.first(), None);
        assert_eq!(set.last(), None);
    }

    #[test]
    fn test_set_collect_empty() {
        let set: Set<u64> = core::iter::empty().collect();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
    }

    #[test]
    fn test_set_insert() {
        let mut set = Set::build(vec![1u64, 3, 5], 4, 2).unwrap();
        assert_eq!(set.len(), 3);

        assert!(set.insert(2));
        assert_eq!(set.len(), 4);
        assert!(set.contains(&2));

        assert!(!set.insert(2));
        assert_eq!(set.len(), 4);

        assert!(set.insert(4));
        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_set_insert_many_descending() {
        let mut set: Set<u64> = Set::empty(8, 2);
        for v in (0..200u64).rev() {
            assert!(set.insert(v * 3));
        }
        assert_eq!(set.len(), 200);
        assert!((0..200u64).all(|v| set.contains(&(v * 3))));
        assert!(!set.contains(&1));
        assert_eq!(set.first(), Some(&0));
        assert_eq!(set.last(), Some(&597));
    }

    #[test]
    fn test_set_insert_into_empty() {
        let mut set: Set<u64> = Set::empty(64, 4);
        assert!(set.insert(42));
        assert_eq!(set.len(), 1);
        assert!(set.contains(&42));
    }

    #[test]
    fn test_set_extend_empty() {
        let mut set: Set<u64> = Set::empty(64, 4);
        set.extend(vec![3, 1, 2]);
        assert_eq!(set.len(), 3);
        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 2, 3]);
    }

    #[test]
    fn test_set_extend_with_duplicates() {
        let mut set = Set::build(vec![1u64, 5], 4, 2).unwrap();
        set.extend(vec![5, 3, 3, 9]);
        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 3, 5, 9]);
        assert!(set.contains(&9));
    }
}