Skip to main content

ds_ext/
set.rs

1//! A hash set ordered using a [`Vec`].
2//!
3//! Example:
4//! ```
5//! use ds_ext::OrdHashSet;
6//!
7//! let mut set1 = OrdHashSet::new();
8//! assert!(set1.insert("d"));
9//! assert!(set1.insert("a"));
10//! assert!(set1.insert("c"));
11//! assert!(set1.insert("b"));
12//! assert!(!set1.insert("a"));
13//! assert_eq!(set1.len(), 4);
14//!
15//! let mut set2 = set1.clone();
16//! assert!(set2.remove(&"d"));
17//! assert_eq!(set2.len(), 3);
18//!
19//! assert_eq!(
20//!     set1.into_iter().collect::<Vec<&str>>(),
21//!     ["a", "b", "c", "d"]
22//! );
23//!
24//! assert_eq!(
25//!     set2.into_iter().rev().collect::<Vec<&str>>(),
26//!     ["c", "b", "a"]
27//! );
28//! ```
29
30use std::borrow::Borrow;
31use std::cmp::Ordering;
32use std::collections::HashSet as Inner;
33use std::fmt;
34use std::hash::Hash;
35use std::sync::Arc;
36
37use get_size::GetSize;
38use get_size_derive::*;
39
/// An iterator which drains the contents of an [`OrdHashSet`] in sorted order.
///
/// Every item yielded is removed from both the sorted list and the hash index.
/// If the iterator is dropped before it is exhausted, all remaining items are
/// removed as well, so the drained set is always left empty and consistent.
pub struct Drain<'a, T> {
    /// The hash index of the set being drained.
    inner: &'a mut Inner<Arc<T>>,
    /// A draining iterator over the sorted list of the set being drained.
    order: std::vec::Drain<'a, Arc<T>>,
}

impl<'a, T> Drop for Drain<'a, T> {
    fn drop(&mut self) {
        // `std::vec::Drain` removes every un-yielded element from the sorted list
        // when it is dropped, so the hash index has to be emptied here as well;
        // otherwise the index would keep items which the list no longer contains.
        self.inner.clear();
    }
}

impl<'a, T> Drain<'a, T>
where
    T: Eq + Hash + fmt::Debug,
{
    /// Remove `item` from the hash index and take ownership of its value.
    ///
    /// Panics with the message `item` if another reference to the value is still alive
    /// after the hash index entry has been removed.
    fn release(&mut self, item: Arc<T>) -> T {
        self.inner.remove(&*item);
        Arc::try_unwrap(item).expect("item")
    }
}

impl<'a, T> Iterator for Drain<'a, T>
where
    T: Eq + Hash + fmt::Debug,
{
    type Item = T;

    /// Remove and return the smallest remaining item.
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.order.next()?;
        Some(self.release(item))
    }

    /// The exact number of items left to drain, as reported by the sorted list.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.order.size_hint()
    }
}

impl<'a, T> DoubleEndedIterator for Drain<'a, T>
where
    T: Eq + Hash + fmt::Debug,
{
    /// Remove and return the largest remaining item.
    fn next_back(&mut self) -> Option<Self::Item> {
        let item = self.order.next_back()?;
        Some(self.release(item))
    }
}
73
/// An iterator which drains items from the front of an [`OrdHashSet`]
/// for as long as they satisfy a condition.
pub struct DrainWhile<'a, T, Cond> {
    /// The hash index of the set being drained.
    inner: &'a mut Inner<Arc<T>>,
    /// The sorted list of the set being drained.
    order: &'a mut Vec<Arc<T>>,
    /// The condition which the first item must satisfy in order to be drained.
    cond: Cond,
}

impl<'a, T, Cond> Iterator for DrainWhile<'a, T, Cond>
where
    T: Eq + Hash + fmt::Debug,
    Cond: Fn(&T) -> bool,
{
    type Item = T;

    /// Remove and return the first item of the set if it satisfies the condition.
    ///
    /// Returns `None` if the set is empty or its first item does not satisfy the condition.
    fn next(&mut self) -> Option<Self::Item> {
        let head: &T = self.order.first()?;
        if !(self.cond)(head) {
            return None;
        }

        // NOTE(review): removing from the front of a `Vec` shifts every remaining element
        let item = self.order.remove(0);
        self.inner.remove(&*item);
        Some(Arc::try_unwrap(item).expect("item"))
    }

    /// Anywhere between zero and all of the items in the set may be drained.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.inner.len();
        (0, Some(remaining))
    }
}
102
/// An owning iterator over the contents of an [`OrdHashSet`], in sorted order.
pub struct IntoIter<T> {
    /// The sorted list taken from the set.
    inner: std::vec::IntoIter<Arc<T>>,
}

impl<T: fmt::Debug> Iterator for IntoIter<T> {
    type Item = T;

    /// Return the next item from the front of the sorted list.
    ///
    /// Panics with the message `item` if the item is still referenced elsewhere.
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.inner.next()?;
        Some(Arc::try_unwrap(item).expect("item"))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<T: fmt::Debug> DoubleEndedIterator for IntoIter<T> {
    /// Return the next item from the back of the sorted list.
    ///
    /// Panics with the message `item` if the item is still referenced elsewhere.
    fn next_back(&mut self) -> Option<Self::Item> {
        let item = self.inner.next_back()?;
        Some(Arc::try_unwrap(item).expect("item"))
    }
}
129
/// A borrowing iterator over the items in an [`OrdHashSet`], in sorted order.
pub struct Iter<'a, T> {
    /// An iterator over the sorted list of the set.
    inner: std::slice::Iter<'a, Arc<T>>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    /// Borrow the next item from the front of the sorted list.
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.inner.next()?;
        Some(&**item)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
    /// Borrow the next item from the back of the sorted list.
    fn next_back(&mut self) -> Option<Self::Item> {
        let item = self.inner.next_back()?;
        Some(&**item)
    }
}
152
153/// A [`std::collections::HashSet`] ordered using a [`Vec`].
154#[derive(GetSize)]
155pub struct OrdHashSet<T> {
156    inner: Inner<Arc<T>>,
157    order: Vec<Arc<T>>,
158}
159
160impl<T: Clone + Eq + Hash + Ord + fmt::Debug> Clone for OrdHashSet<T> {
161    fn clone(&self) -> Self {
162        self.iter().cloned().collect()
163    }
164}
165
166impl<T: PartialEq + fmt::Debug> PartialEq for OrdHashSet<T> {
167    fn eq(&self, other: &Self) -> bool {
168        self.order == other.order
169    }
170}
171
172impl<T: Eq + fmt::Debug> Eq for OrdHashSet<T> {}
173
174impl<T> OrdHashSet<T> {
175    /// Construct a new [`OrdHashSet`].
176    pub fn new() -> Self {
177        Self {
178            inner: Inner::new(),
179            order: Vec::new(),
180        }
181    }
182
183    /// Construct a new [`OrdHashSet`] with the given `capacity`.
184    pub fn with_capacity(capacity: usize) -> Self {
185        Self {
186            inner: Inner::with_capacity(capacity),
187            order: Vec::with_capacity(capacity),
188        }
189    }
190
191    /// Construct an iterator over the items in this [`OrdHashSet`].
192    pub fn iter(&self) -> Iter<'_, T> {
193        Iter {
194            inner: self.order.iter(),
195        }
196    }
197
198    /// Return `true` if this [`OrdHashSet`] is empty.
199    pub fn is_empty(&self) -> bool {
200        self.inner.is_empty()
201    }
202
203    /// Return the number of items in this [`OrdHashSet`].
204    pub fn len(&self) -> usize {
205        self.inner.len()
206    }
207}
208
209impl<T> Default for OrdHashSet<T> {
210    fn default() -> Self {
211        Self::new()
212    }
213}
214
215impl<T: Eq + Hash + Ord> OrdHashSet<T> {
216    fn bisect_hi<Cmp>(&self, cmp: Cmp) -> usize
217    where
218        Cmp: Fn(&T) -> Option<Ordering>,
219    {
220        if self.is_empty() {
221            return 0;
222        } else if cmp(self.order.iter().next_back().expect("tail")).is_some() {
223            return self.len();
224        }
225
226        let mut lo = 0;
227        let mut hi = self.len();
228
229        while lo < hi {
230            let mid = (lo + hi) >> 1;
231            let item = self.order.get(mid).expect("item");
232
233            if cmp(&**item).is_some() {
234                lo = mid + 1;
235            } else {
236                hi = mid;
237            }
238        }
239
240        hi
241    }
242
243    fn bisect_lo<Cmp>(&self, cmp: Cmp) -> usize
244    where
245        Cmp: Fn(&T) -> Option<Ordering>,
246    {
247        if self.is_empty() || cmp(&self.order[0]).is_some() {
248            return 0;
249        }
250
251        let mut lo = 0;
252        let mut hi = 1;
253
254        while lo < hi {
255            let mid = (lo + hi) >> 1;
256            let item = self.order.get(mid).expect("item");
257
258            if cmp(&**item).is_some() {
259                hi = mid;
260            } else {
261                lo = mid + 1;
262            }
263        }
264
265        hi
266    }
267
268    fn bisect_inner<Cmp>(&self, cmp: Cmp, mut lo: usize, mut hi: usize) -> Option<&T>
269    where
270        Cmp: Fn(&T) -> Option<Ordering>,
271    {
272        while lo < hi {
273            let mid = (lo + hi) >> 1;
274            let item = self.order.get(mid).expect("item");
275
276            if let Some(order) = cmp(&**item) {
277                match order {
278                    Ordering::Less => hi = mid,
279                    Ordering::Equal => return Some(item),
280                    Ordering::Greater => lo = mid + 1,
281                }
282            } else {
283                panic!("comparison does not match distribution")
284            }
285        }
286
287        None
288    }
289
290    /// Bisect this set to match an item using the provided comparison, and return it (if present).
291    ///
292    /// The first item for which the comparison returns `Some(Ordering::Equal)` will be returned.
293    /// This method assumes that any partially-ordered items (where `cmp(item).is_none()`)
294    /// are ordered at the beginning or end of the set.
295    pub fn bisect<Cmp>(&self, cmp: Cmp) -> Option<&T>
296    where
297        Cmp: Fn(&T) -> Option<Ordering> + Copy,
298    {
299        let lo = self.bisect_lo(cmp);
300        let hi = self.bisect_hi(cmp);
301        self.bisect_inner(cmp, lo, hi)
302    }
303
304    /// Bisect this set to match and remove an item using the provided comparison.
305    ///
306    /// The first item for which the comparison returns `Some(Ordering::Equal)` will be returned.
307    /// This method assumes that any partially-ordered items (where `cmp(item).is_none()`)
308    /// are ordered at the beginning and/or end of the set.
309    pub fn bisect_and_remove<Cmp>(&mut self, cmp: Cmp) -> Option<T>
310    where
311        Cmp: Fn(&T) -> Option<Ordering> + Copy,
312        T: fmt::Debug,
313    {
314        let mut lo = self.bisect_lo(cmp);
315        let mut hi = self.bisect_hi(cmp);
316
317        let item = loop {
318            if lo >= hi {
319                break None;
320            }
321
322            let mid = (lo + hi) >> 1;
323            let item = self.order.get(mid).expect("item");
324
325            if let Some(order) = cmp(&**item) {
326                match order {
327                    Ordering::Less => hi = mid,
328                    Ordering::Equal => {
329                        lo = mid;
330                        break Some(item.clone());
331                    }
332                    Ordering::Greater => lo = mid + 1,
333                }
334            } else {
335                panic!("comparison does not match distribution")
336            }
337        }?;
338
339        self.order.remove(lo);
340        self.inner.remove(&item);
341
342        Some(Arc::try_unwrap(item).expect("item"))
343    }
344
345    /// Remove all items from this [`OrdHashSet`].
346    pub fn clear(&mut self) {
347        self.inner.clear();
348        self.order.clear();
349    }
350
351    /// Return `true` if the given item is present in this [`OrdHashSet`].
352    pub fn contains<Q>(&self, item: &Q) -> bool
353    where
354        Arc<T>: Borrow<Q>,
355        Q: Hash + Eq + ?Sized,
356    {
357        self.inner.contains(item)
358    }
359
360    /// Drain all items from this [`OrdHashSet`].
361    pub fn drain(&mut self) -> Drain<'_, T> {
362        Drain {
363            inner: &mut self.inner,
364            order: self.order.drain(..),
365        }
366    }
367
368    /// Drain items from this [`OrdHashSet`] while they match the given `cond`ition.
369    pub fn drain_while<Cond>(&mut self, cond: Cond) -> DrainWhile<'_, T, Cond>
370    where
371        Cond: Fn(&T) -> bool,
372    {
373        DrainWhile {
374            inner: &mut self.inner,
375            order: &mut self.order,
376            cond,
377        }
378    }
379
380    /// Consume the given `iter` and insert all its items into this [`OrdHashSet`]
381    pub fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
382        for item in iter {
383            self.insert(item);
384        }
385    }
386
387    /// Borrow the first item in this [`OrdHashSet`].
388    pub fn first(&self) -> Option<&T> {
389        self.order.first().map(|item| &**item)
390    }
391
392    /// Insert an `item` into this [`OrdHashSet`] and return `false` if it was already present.
393    pub fn insert(&mut self, item: T) -> bool {
394        let new = if self.inner.contains(&item) {
395            false
396        } else {
397            let item = Arc::new(item);
398
399            let index = bisect(&self.order, &item);
400            if index == self.len() {
401                self.order.insert(index, item.clone());
402            } else {
403                let prior = self.order.get(index).expect("item").clone();
404
405                if prior < item {
406                    self.order.insert(index + 1, item.clone());
407                } else {
408                    self.order.insert(index, item.clone());
409                }
410            }
411
412            self.inner.insert(item)
413        };
414
415        new
416    }
417
418    /// Borrow the last item in this [`OrdHashSet`].
419    pub fn last(&self) -> Option<&T> {
420        self.order.iter().next_back().map(|item| &**item)
421    }
422
423    /// Remove and return the first item in this [`OrdHashSet`].
424    pub fn pop_first(&mut self) -> Option<T>
425    where
426        T: fmt::Debug,
427    {
428        if self.is_empty() {
429            None
430        } else {
431            let item = self.order.remove(0);
432            self.inner.remove(&item);
433            Some(Arc::try_unwrap(item).expect("item"))
434        }
435    }
436
437    /// Remove and return the last item in this [`OrdHashSet`].
438    pub fn pop_last(&mut self) -> Option<T>
439    where
440        T: fmt::Debug,
441    {
442        if let Some(item) = self.order.pop() {
443            self.inner.remove(&item);
444            Some(Arc::try_unwrap(item).expect("item"))
445        } else {
446            None
447        }
448    }
449
450    /// Remove the given `item` from this [`OrdHashSet`] and return `true` if it was present.
451    ///
452    /// The item may be any borrowed form of `T`,
453    /// but the ordering on the borrowed form **must** match the ordering of `T`.
454    pub fn remove<Q>(&mut self, item: &Q) -> bool
455    where
456        Arc<T>: Borrow<Q>,
457        Q: Eq + Hash + Ord,
458    {
459        if self.inner.remove(item) {
460            let index = bisect(&self.order, item);
461            assert!(self.order.remove(index).borrow() == item);
462            true
463        } else {
464            false
465        }
466    }
467
468    /// Return `true` if the first elements in this set are equal to those in the given `iter`.
469    pub fn starts_with<'a, I: IntoIterator<Item = &'a T>>(&'a self, other: I) -> bool
470    where
471        T: PartialEq,
472    {
473        let mut this = self.iter();
474        let that = other.into_iter();
475
476        for item in that {
477            if this.next() != Some(item) {
478                return false;
479            }
480        }
481
482        true
483    }
484}
485
486impl<T: Eq + Hash + Ord + fmt::Debug> OrdHashSet<T> {
487    #[allow(unused)]
488    fn is_valid(&self) -> bool {
489        assert_eq!(self.inner.len(), self.order.len());
490
491        if self.is_empty() {
492            return true;
493        }
494
495        let mut item = self.order.first().expect("item");
496        for i in 1..self.len() {
497            let next = self.order.get(i).expect("next");
498            assert!(*item <= *next, "set out of order: {:?}", self);
499            assert!(*next >= *item);
500            item = next;
501        }
502
503        true
504    }
505}
506
507impl<T: Eq + Hash + Ord + fmt::Debug> fmt::Debug for OrdHashSet<T> {
508    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
509        f.write_str("[ ")?;
510
511        for item in self {
512            write!(f, "{:?} ", item)?;
513        }
514
515        f.write_str("]")
516    }
517}
518
519impl<T: Eq + Hash + Ord + fmt::Debug> FromIterator<T> for OrdHashSet<T> {
520    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
521        let iter = iter.into_iter();
522        let mut set = match iter.size_hint() {
523            (_, Some(max)) => Self::with_capacity(max),
524            (min, None) if min > 0 => Self::with_capacity(min),
525            _ => Self::new(),
526        };
527
528        set.extend(iter);
529        set
530    }
531}
532
533impl<T: fmt::Debug> IntoIterator for OrdHashSet<T> {
534    type Item = T;
535    type IntoIter = IntoIter<T>;
536
537    fn into_iter(self) -> Self::IntoIter {
538        IntoIter {
539            inner: self.order.into_iter(),
540        }
541    }
542}
543
544impl<'a, T> IntoIterator for &'a OrdHashSet<T> {
545    type Item = &'a T;
546    type IntoIter = Iter<'a, T>;
547
548    fn into_iter(self) -> Self::IntoIter {
549        OrdHashSet::iter(self)
550    }
551}
552
/// Binary-search the sorted `list` for `target`.
///
/// Returns the index of an item which compares equal to `target`, if there is one.
/// Otherwise, returns the index at which `target` could be inserted to keep `list` sorted.
#[inline]
fn bisect<T, Q>(list: &[T], target: &Q) -> usize
where
    T: Borrow<Q> + Ord,
    Q: Ord,
{
    // fast paths for a target which lies beyond either end of the list
    match (list.first(), list.last()) {
        (Some(head), _) if target < Borrow::<Q>::borrow(head) => return 0,
        (_, Some(tail)) if target > Borrow::<Q>::borrow(tail) => return list.len(),
        _ => {}
    }

    let (mut start, mut end) = (0, list.len());

    while start < end {
        let middle = (start + end) >> 1;
        let probe: &Q = list[middle].borrow();

        match probe.cmp(target) {
            Ordering::Less => start = middle + 1,
            Ordering::Greater => end = middle,
            Ordering::Equal => return middle,
        }
    }

    start
}
587
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    /// `bisect_and_remove` must remove exactly the item which the comparison matches.
    ///
    /// Each comparison gives the ordering of the search target relative to the item,
    /// i.e. `target.partial_cmp(item)`, which is the orientation that the bisection follows
    /// (`Less` searches earlier items, `Greater` searches later items).
    #[test]
    fn test_bisect_and_remove() {
        let mut set = OrdHashSet::<u8>::new();

        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        set.insert(8);
        assert_eq!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)), Some(8));
        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        set.insert(9);
        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        set.insert(7);
        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        // with more than one item present, the orientation of the comparison matters
        assert_eq!(set.bisect_and_remove(|item| 9u8.partial_cmp(item)), Some(9));
        assert_eq!(set.bisect_and_remove(|item| 7u8.partial_cmp(item)), Some(7));
        assert!(set.is_empty());
    }

    /// Items come out of `into_iter` sorted, and duplicate inserts are rejected.
    #[test]
    fn test_into_iter() {
        let mut set = OrdHashSet::new();
        assert!(set.insert("d"));
        assert!(set.insert("a"));
        assert!(set.insert("c"));
        assert!(set.insert("b"));
        assert!(!set.insert("a"));
        assert_eq!(set.len(), 4);

        assert_eq!(set.into_iter().collect::<Vec<&str>>(), ["a", "b", "c", "d"]);
    }

    /// `drain` yields every item in sorted order.
    #[test]
    fn test_drain() {
        let mut set = OrdHashSet::from_iter(0..10);
        let expected = (0..10).collect::<Vec<_>>();
        let actual = set.drain().collect::<Vec<_>>();
        assert_eq!(expected, actual);
    }

    /// `drain_while` removes only the leading items which satisfy the condition.
    #[test]
    fn test_drain_while() {
        let mut set = OrdHashSet::from_iter(0..10);
        let drained = set.drain_while(|x| *x < 5).collect::<Vec<_>>();
        assert_eq!(drained, vec![0, 1, 2, 3, 4]);
        assert_eq!(set, OrdHashSet::from_iter(5..10));
    }

    /// The sorted order survives reverse-order inserts followed by removals.
    #[test]
    fn test_order_invariants_after_ops() {
        let mut set = OrdHashSet::new();
        for i in (0..100).rev() {
            assert!(set.insert(i));
        }

        let items: Vec<_> = set.iter().cloned().collect();
        assert_eq!(items, (0..100).collect::<Vec<_>>());

        for i in 0..50 {
            assert!(set.remove(&i));
        }

        let items: Vec<_> = set.iter().cloned().collect();
        assert_eq!(items, (50..100).collect::<Vec<_>>());
    }

    /// The set stays valid after every step of a seeded random sequence of inserts and removals.
    #[test]
    fn test_random_ops_invariants() {
        let mut rng = StdRng::seed_from_u64(0x_d5e4);
        let mut set = OrdHashSet::new();

        for _ in 0..5_000 {
            let value = rng.random_range(0..200);
            if rng.random() {
                set.insert(value);
            } else {
                set.remove(&value);
            }

            assert!(set.is_valid());
        }
    }

    /// `bisect` finds the items at both ends of the set and rejects targets beyond either end.
    #[test]
    fn test_bisect_boundaries() {
        let mut set = OrdHashSet::new();
        set.insert(10u32);
        set.insert(20u32);

        assert!(set.bisect(|item| 5u32.partial_cmp(item)).is_none());
        assert_eq!(set.bisect(|item| 10u32.partial_cmp(item)), Some(&10));
        assert_eq!(set.bisect(|item| 20u32.partial_cmp(item)), Some(&20));
        assert!(set.bisect(|item| 25u32.partial_cmp(item)).is_none());
    }

    /// Removing an absent item reports `false` and leaves the set untouched.
    #[test]
    fn test_remove_missing_does_not_mutate() {
        let mut set = OrdHashSet::new();
        set.insert(1u32);
        set.insert(3u32);

        assert!(!set.remove(&2u32));
        assert_eq!(set.iter().cloned().collect::<Vec<_>>(), vec![1, 3]);
    }
}