1use alloc::vec::Vec;
7use core::cmp::Ordering;
8use core::fmt;
9use core::hash::{Hash, Hasher};
10use core::iter::FusedIterator;
11use core::ops::RangeBounds;
12
13use crate::error::Error;
14use crate::index::external;
15use crate::index::key::Indexable;
16use crate::util::range::range_to_indices;
17
/// An ordered set that stores its elements in a sorted `Vec<T>` and locates them through
/// an optional `external::Static` index built over that vector.
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(
        bound = "T: serde::Serialize + serde::de::DeserializeOwned, T::Key: serde::Serialize + serde::de::DeserializeOwned"
    )
)]
pub struct Set<T: Indexable> {
    /// Elements in ascending order without duplicates. The constructors either sort and
    /// deduplicate their input or check this ordering with a debug assertion.
    data: Vec<T>,
    /// Index over `data`; the constructors leave it as `None` when `data` is empty.
    index: Option<external::Static<T>>,
    /// First tuning parameter forwarded to `external::Static::new`.
    /// NOTE(review): presumably the error bound of the bottom index level; confirm.
    epsilon: usize,
    /// Second tuning parameter forwarded to `external::Static::new`.
    /// NOTE(review): presumably the error bound of the recursive index levels; confirm.
    epsilon_recursive: usize,
}
60
61impl<T: Indexable + Ord> Set<T>
62where
63 T::Key: Ord,
64{
65 pub fn from_sorted_unique(
71 data: Vec<T>,
72 epsilon: usize,
73 epsilon_recursive: usize,
74 ) -> Result<Self, Error> {
75 debug_assert!(
76 data.windows(2).all(|w| w[0] < w[1]),
77 "data must be sorted and unique"
78 );
79
80 let index = if data.is_empty() {
81 None
82 } else {
83 Some(external::Static::new(&data, epsilon, epsilon_recursive)?)
84 };
85 Ok(Self {
86 data,
87 index,
88 epsilon,
89 epsilon_recursive,
90 })
91 }
92
93 pub fn build<I>(iter: I, epsilon: usize, epsilon_recursive: usize) -> Result<Self, Error>
97 where
98 I: IntoIterator<Item = T>,
99 {
100 let mut data: Vec<T> = iter.into_iter().collect();
101 data.sort();
102 data.dedup();
103
104 Self::from_sorted_unique(data, epsilon, epsilon_recursive)
105 }
106
107 pub fn empty(epsilon: usize, epsilon_recursive: usize) -> Self {
109 Self {
110 data: Vec::new(),
111 index: None,
112 epsilon,
113 epsilon_recursive,
114 }
115 }
116
117 pub fn new(data: Vec<T>) -> Result<Self, Error> {
119 Self::build(data, 64, 4)
120 }
121
122 #[inline]
124 pub fn contains(&self, value: &T) -> bool {
125 self.get(value).is_some()
126 }
127
128 #[inline]
130 pub fn get(&self, value: &T) -> Option<&T> {
131 let index = self.index.as_ref()?;
132 let approx = index.search(value);
133
134 let lo = approx.lo;
135 let hi = approx.hi.min(self.data.len());
136
137 for i in lo..hi {
138 match self.data[i].cmp(value) {
139 Ordering::Equal => return Some(&self.data[i]),
140 Ordering::Greater => return None,
141 Ordering::Less => continue,
142 }
143 }
144 None
145 }
146
147 #[inline]
149 pub fn lower_bound(&self, value: &T) -> usize {
150 match &self.index {
151 Some(index) => index.lower_bound(&self.data, value),
152 None => 0,
153 }
154 }
155
156 #[inline]
158 pub fn upper_bound(&self, value: &T) -> usize {
159 match &self.index {
160 Some(index) => index.upper_bound(&self.data, value),
161 None => 0,
162 }
163 }
164
165 #[inline]
167 pub fn range<R>(&self, range: R) -> impl DoubleEndedIterator<Item = &T>
168 where
169 R: RangeBounds<T>,
170 {
171 let (start, end) = range_to_indices(
172 range,
173 self.data.len(),
174 |v| self.lower_bound(v),
175 |v| self.upper_bound(v),
176 );
177 self.data[start..end].iter()
178 }
179
180 #[inline]
182 pub fn first(&self) -> Option<&T> {
183 self.data.first()
184 }
185
186 #[inline]
188 pub fn last(&self) -> Option<&T> {
189 self.data.last()
190 }
191
192 #[inline]
194 pub fn iter(&self) -> impl ExactSizeIterator<Item = &T> + DoubleEndedIterator {
195 self.data.iter()
196 }
197
198 #[inline]
200 pub fn len(&self) -> usize {
201 self.data.len()
202 }
203
204 #[inline]
206 pub fn is_empty(&self) -> bool {
207 self.data.is_empty()
208 }
209
210 #[inline]
212 pub fn segments_count(&self) -> usize {
213 self.index.as_ref().map_or(0, |i| i.segments_count())
214 }
215
216 #[inline]
218 pub fn height(&self) -> usize {
219 self.index.as_ref().map_or(0, |i| i.height())
220 }
221
222 #[inline]
224 pub fn epsilon(&self) -> usize {
225 self.epsilon
226 }
227
228 #[inline]
230 pub fn epsilon_recursive(&self) -> usize {
231 self.epsilon_recursive
232 }
233
234 pub fn size_in_bytes(&self) -> usize {
236 self.index.as_ref().map_or(0, |i| i.size_in_bytes())
237 + self.data.capacity() * core::mem::size_of::<T>()
238 }
239
240 #[inline]
242 pub fn as_slice(&self) -> &[T] {
243 &self.data
244 }
245
246 #[inline]
248 pub fn into_vec(self) -> Vec<T> {
249 self.data
250 }
251
252 #[inline]
254 pub fn index(&self) -> Option<&external::Static<T>> {
255 self.index.as_ref()
256 }
257
258 pub fn insert(&mut self, value: T) -> bool {
266 if self.contains(&value) {
267 return false;
268 }
269
270 let mut data = core::mem::take(&mut self.data);
271 data.push(value);
272 data.sort();
273
274 if let Ok(new_set) = Self::from_sorted_unique(data, self.epsilon, self.epsilon_recursive) {
275 *self = new_set;
276 }
277 true
278 }
279
280 pub fn is_disjoint(&self, other: &Set<T>) -> bool {
282 if self.is_empty() || other.is_empty() {
283 return true;
284 }
285
286 let (smaller, larger) = if self.len() <= other.len() {
287 (self, other)
288 } else {
289 (other, self)
290 };
291
292 for value in smaller.iter() {
293 if larger.contains(value) {
294 return false;
295 }
296 }
297 true
298 }
299
300 pub fn is_subset(&self, other: &Set<T>) -> bool {
302 if self.len() > other.len() {
303 return false;
304 }
305 self.iter().all(|v| other.contains(v))
306 }
307
308 pub fn is_superset(&self, other: &Set<T>) -> bool {
310 other.is_subset(self)
311 }
312
313 pub fn difference<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
315 self.iter().filter(move |v| !other.contains(v))
316 }
317
318 pub fn symmetric_difference<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
320 self.difference(other).chain(other.difference(self))
321 }
322
323 pub fn intersection<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
325 let (smaller, larger) = if self.len() <= other.len() {
326 (self, other)
327 } else {
328 (other, self)
329 };
330 smaller.iter().filter(move |v| larger.contains(v))
331 }
332
333 pub fn union<'a>(&'a self, other: &'a Set<T>) -> impl Iterator<Item = &'a T> {
335 MergeIter::new(self.data.iter(), other.data.iter())
336 }
337}
338
/// Iterator that merges two ascending slice iterators into a single ascending stream.
///
/// When the heads of both inputs compare equal, the element is yielded once (taken from
/// the first input) and both inputs advance.
pub struct MergeIter<'a, T> {
    /// Elements of the first input that follow `peeked_a`.
    a: core::slice::Iter<'a, T>,
    /// Elements of the second input that follow `peeked_b`.
    b: core::slice::Iter<'a, T>,
    /// Current head of the first input, `None` once it is exhausted.
    peeked_a: Option<&'a T>,
    /// Current head of the second input, `None` once it is exhausted.
    peeked_b: Option<&'a T>,
}

impl<'a, T: Ord> MergeIter<'a, T> {
    /// Creates the merge iterator and pre-loads the head of each input.
    fn new(mut a: core::slice::Iter<'a, T>, mut b: core::slice::Iter<'a, T>) -> Self {
        let peeked_a = a.next();
        let peeked_b = b.next();
        Self {
            a,
            b,
            peeked_a,
            peeked_b,
        }
    }
}

impl<'a, T: Ord> Iterator for MergeIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        match (self.peeked_a, self.peeked_b) {
            (Some(a), Some(b)) => match a.cmp(b) {
                Ordering::Less => {
                    self.peeked_a = self.a.next();
                    Some(a)
                }
                Ordering::Greater => {
                    self.peeked_b = self.b.next();
                    Some(b)
                }
                // Equal heads: emit one copy and advance both sides.
                Ordering::Equal => {
                    self.peeked_a = self.a.next();
                    self.peeked_b = self.b.next();
                    Some(a)
                }
            },
            (Some(a), None) => {
                self.peeked_a = self.a.next();
                Some(a)
            }
            (None, Some(b)) => {
                self.peeked_b = self.b.next();
                Some(b)
            }
            (None, None) => None,
        }
    }

    /// Each yielded item consumes at least one and at most two pending elements, one per
    /// input, so the remaining count lies between the longer input and the sum of both.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.a.len() + usize::from(self.peeked_a.is_some());
        let right = self.b.len() + usize::from(self.peeked_b.is_some());
        (left.max(right), left.checked_add(right))
    }
}

impl<T: Ord> FusedIterator for MergeIter<'_, T> {}
394
395impl<T: Indexable + Clone> Clone for Set<T>
398where
399 T::Key: Clone,
400{
401 fn clone(&self) -> Self {
402 Self {
403 data: self.data.clone(),
404 index: self.index.clone(),
405 epsilon: self.epsilon,
406 epsilon_recursive: self.epsilon_recursive,
407 }
408 }
409}
410
411impl<T: Indexable + fmt::Debug> fmt::Debug for Set<T> {
412 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
413 f.debug_set().entries(self.data.iter()).finish()
414 }
415}
416
417impl<T: Indexable + Ord + PartialEq> PartialEq for Set<T> {
418 fn eq(&self, other: &Self) -> bool {
419 self.data == other.data
420 }
421}
422
423impl<T: Indexable + Ord + Eq> Eq for Set<T> {}
424
425impl<T: Indexable + Ord + PartialOrd> PartialOrd for Set<T>
426where
427 T::Key: Ord,
428{
429 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
430 Some(self.cmp(other))
431 }
432}
433
434impl<T: Indexable + Ord> Ord for Set<T>
435where
436 T::Key: Ord,
437{
438 fn cmp(&self, other: &Self) -> Ordering {
439 self.data.cmp(&other.data)
440 }
441}
442
443impl<T: Indexable + Hash> Hash for Set<T> {
444 fn hash<H: Hasher>(&self, state: &mut H) {
445 self.data.hash(state);
446 }
447}
448
449impl<T: Indexable + Ord> IntoIterator for Set<T>
450where
451 T::Key: Ord,
452{
453 type Item = T;
454 type IntoIter = alloc::vec::IntoIter<T>;
455
456 fn into_iter(self) -> Self::IntoIter {
457 self.data.into_iter()
458 }
459}
460
461impl<'a, T: Indexable> IntoIterator for &'a Set<T> {
462 type Item = &'a T;
463 type IntoIter = core::slice::Iter<'a, T>;
464
465 fn into_iter(self) -> Self::IntoIter {
466 self.data.iter()
467 }
468}
469
470impl<T: Indexable + Ord> FromIterator<T> for Set<T>
471where
472 T::Key: Ord,
473{
474 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
478 Self::build(iter, 64, 4).unwrap_or_else(|_| Self::empty(64, 4))
479 }
480}
481
482impl<T: Indexable + Ord> core::iter::Extend<T> for Set<T>
483where
484 T::Key: Ord,
485{
486 fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
492 let mut data = core::mem::take(&mut self.data);
493 data.extend(iter);
494 data.sort();
495 data.dedup();
496
497 match Self::from_sorted_unique(data, self.epsilon, self.epsilon_recursive) {
498 Ok(new_set) => *self = new_set,
499 Err(_) => {
500 *self = Self::empty(self.epsilon, self.epsilon_recursive);
501 }
502 }
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::String;
    use alloc::vec;

    // Lookup on a dense numeric set, including a value beyond the stored range.
    #[test]
    fn test_set_numeric() {
        let data: Vec<u64> = (0..1000).collect();
        let set = Set::from_sorted_unique(data, 64, 4).unwrap();

        assert_eq!(set.len(), 1000);
        assert!(set.contains(&500));
        assert!(!set.contains(&1001));
    }

    // Lookup with borrowed string elements.
    #[test]
    fn test_set_strings() {
        let data = vec!["apple", "banana", "cherry", "date"];
        let set = Set::from_sorted_unique(data, 64, 4).unwrap();

        assert!(set.contains(&"banana"));
        assert!(set.contains(&"cherry"));
        assert!(!set.contains(&"elderberry"));
    }

    // Lookup with owned string elements.
    #[test]
    fn test_set_owned_strings() {
        let data: Vec<String> = vec!["alpha", "beta", "gamma"]
            .into_iter()
            .map(String::from)
            .collect();
        let set = Set::from_sorted_unique(data, 64, 4).unwrap();

        assert!(set.contains(&String::from("beta")));
        assert!(!set.contains(&String::from("delta")));
    }

    // `build` must sort and deduplicate unsorted input.
    #[test]
    fn test_set_build() {
        let data = vec![5u64, 3, 1, 4, 1, 5, 9, 2, 6];
        let set = Set::build(data, 4, 2).unwrap();

        assert_eq!(set.len(), 7);
        assert!(set.contains(&1));
        assert!(set.contains(&9));

        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 2, 3, 4, 5, 6, 9]);
    }

    #[test]
    fn test_set_first_last() {
        let data: Vec<u64> = vec![10, 20, 30, 40, 50];
        let set = Set::from_sorted_unique(data, 4, 2).unwrap();

        assert_eq!(set.first(), Some(&10));
        assert_eq!(set.last(), Some(&50));
    }

    #[test]
    fn test_set_range() {
        let data: Vec<u64> = (0..100).collect();
        let set = Set::from_sorted_unique(data, 16, 4).unwrap();

        let range: Vec<_> = set.range(10..20).copied().collect();
        assert_eq!(range, (10..20).collect::<Vec<_>>());
    }

    // Forward and reverse iteration both follow the sorted order.
    #[test]
    fn test_set_iter() {
        let data: Vec<u64> = (0..10).collect();
        let set = Set::from_sorted_unique(data, 4, 2).unwrap();

        let forward: Vec<_> = set.iter().copied().collect();
        let backward: Vec<_> = set.iter().rev().copied().collect();

        assert_eq!(forward, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        assert_eq!(backward, vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
    }

    #[test]
    fn test_set_operations() {
        let set1 = Set::build(vec![1u64, 2, 3, 4, 5], 4, 2).unwrap();
        let set2 = Set::build(vec![4u64, 5, 6, 7, 8], 4, 2).unwrap();

        let intersection: Vec<_> = set1.intersection(&set2).copied().collect();
        assert_eq!(intersection, vec![4, 5]);

        let difference: Vec<_> = set1.difference(&set2).copied().collect();
        assert_eq!(difference, vec![1, 2, 3]);

        assert!(!set1.is_disjoint(&set2));

        let set3 = Set::build(vec![10u64, 11], 4, 2).unwrap();
        assert!(set1.is_disjoint(&set3));
    }

    // Union merges both inputs in order and emits shared elements once.
    #[test]
    fn test_set_union() {
        let set1 = Set::build(vec![1u64, 2, 3, 4, 5], 4, 2).unwrap();
        let set2 = Set::build(vec![4u64, 5, 6, 7, 8], 4, 2).unwrap();

        let union: Vec<_> = set1.union(&set2).copied().collect();
        assert_eq!(union, vec![1, 2, 3, 4, 5, 6, 7, 8]);

        let empty: Set<u64> = Set::empty(4, 2);
        let with_empty: Vec<_> = set1.union(&empty).copied().collect();
        assert_eq!(with_empty, vec![1, 2, 3, 4, 5]);
    }

    // Symmetric difference lists elements unique to the left set, then those unique to
    // the right set.
    #[test]
    fn test_set_symmetric_difference() {
        let set1 = Set::build(vec![1u64, 2, 3, 4, 5], 4, 2).unwrap();
        let set2 = Set::build(vec![4u64, 5, 6, 7, 8], 4, 2).unwrap();

        let sym: Vec<_> = set1.symmetric_difference(&set2).copied().collect();
        assert_eq!(sym, vec![1, 2, 3, 6, 7, 8]);
    }

    #[test]
    fn test_set_subset_superset() {
        let big = Set::build(vec![1u64, 2, 3, 4, 5], 4, 2).unwrap();
        let small = Set::build(vec![2u64, 3], 4, 2).unwrap();
        let other = Set::build(vec![2u64, 9], 4, 2).unwrap();

        assert!(small.is_subset(&big));
        assert!(big.is_superset(&small));
        assert!(!big.is_subset(&small));
        assert!(!other.is_subset(&big));
        assert!(big.is_subset(&big));
    }

    #[test]
    fn test_set_collect() {
        let set: Set<u64> = (0..100).collect();
        assert_eq!(set.len(), 100);
        assert!(set.contains(&50));
    }

    #[test]
    fn test_set_empty() {
        let set: Set<u64> = Set::empty(64, 4);
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert!(!set.contains(&0));
        assert_eq!(set.first(), None);
        assert_eq!(set.last(), None);
    }

    #[test]
    fn test_set_collect_empty() {
        let set: Set<u64> = core::iter::empty().collect();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
    }

    // Inserting new values grows the set in order; re-inserting is a no-op.
    #[test]
    fn test_set_insert() {
        let mut set = Set::build(vec![1u64, 3, 5], 4, 2).unwrap();
        assert_eq!(set.len(), 3);

        assert!(set.insert(2));
        assert_eq!(set.len(), 4);
        assert!(set.contains(&2));

        assert!(!set.insert(2));
        assert_eq!(set.len(), 4);

        assert!(set.insert(4));
        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_set_insert_into_empty() {
        let mut set: Set<u64> = Set::empty(64, 4);
        assert!(set.insert(42));
        assert_eq!(set.len(), 1);
        assert!(set.contains(&42));
    }

    #[test]
    fn test_set_extend_empty() {
        let mut set: Set<u64> = Set::empty(64, 4);
        set.extend(vec![3, 1, 2]);
        assert_eq!(set.len(), 3);
        let collected: Vec<_> = set.iter().copied().collect();
        assert_eq!(collected, vec![1, 2, 3]);
    }
}