ds_ext/
set.rs

1//! A hash set ordered using a [`Vec`].
2//!
3//! Example:
4//! ```
5//! use ds_ext::OrdHashSet;
6//!
7//! let mut set1 = OrdHashSet::new();
8//! assert!(set1.insert("d"));
9//! assert!(set1.insert("a"));
10//! assert!(set1.insert("c"));
11//! assert!(set1.insert("b"));
12//! assert!(!set1.insert("a"));
13//! assert_eq!(set1.len(), 4);
14//!
15//! let mut set2 = set1.clone();
16//! assert!(set2.remove(&"d"));
17//! assert_eq!(set2.len(), 3);
18//!
19//! assert_eq!(
20//!     set1.into_iter().collect::<Vec<&str>>(),
21//!     ["a", "b", "c", "d"]
22//! );
23//!
24//! assert_eq!(
25//!     set2.into_iter().rev().collect::<Vec<&str>>(),
26//!     ["c", "b", "a"]
27//! );
28//! ```
29
30use std::borrow::Borrow;
31use std::cmp::Ordering;
32use std::collections::HashSet as Inner;
33use std::fmt;
34use std::hash::Hash;
35use std::sync::Arc;
36
37use get_size::GetSize;
38use get_size_derive::*;
39
/// An iterator which drains the contents of an [`OrdHashSet`] in order.
///
/// Each item it yields is also removed from the set's hash index.
pub struct Drain<'a, T> {
    inner: &'a mut Inner<Arc<T>>,
    order: std::vec::Drain<'a, Arc<T>>,
}

impl<T> Drain<'_, T>
where
    T: Eq + Hash + fmt::Debug,
{
    /// Remove `item` from the hash index and take ownership of its value.
    ///
    /// # Panics
    /// Panics if some other `Arc` still refers to the value once it leaves the hash index.
    fn detach(&mut self, item: Arc<T>) -> T {
        let value: &T = &item;
        self.inner.remove(value);
        Arc::try_unwrap(item).expect("item")
    }
}

impl<T> Iterator for Drain<'_, T>
where
    T: Eq + Hash + fmt::Debug,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let item = self.order.next()?;
        Some(self.detach(item))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.order.size_hint()
    }
}

impl<T> DoubleEndedIterator for Drain<'_, T>
where
    T: Eq + Hash + fmt::Debug,
{
    fn next_back(&mut self) -> Option<T> {
        let item = self.order.next_back()?;
        Some(self.detach(item))
    }
}
73
/// An iterator which removes items from the front of an [`OrdHashSet`] for as long as they
/// satisfy a condition.
///
/// Iteration stops at the first item for which `cond` returns `false`; that item and every
/// later one are left in the set. Each removal takes the first element of the backing `Vec`,
/// which shifts all of the remaining elements.
pub struct DrainWhile<'a, T, Cond> {
    inner: &'a mut Inner<Arc<T>>,
    order: &'a mut Vec<Arc<T>>,
    cond: Cond,
}

impl<T, Cond> Iterator for DrainWhile<'_, T, Cond>
where
    T: Eq + Hash + fmt::Debug,
    Cond: Fn(&T) -> bool,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let head: &T = self.order.first()?;
        if !(self.cond)(head) {
            return None;
        }

        let head = self.order.remove(0);
        self.inner.remove(&*head);
        Some(Arc::try_unwrap(head).expect("item"))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // anywhere from nothing to every remaining item may match
        let remaining = self.inner.len();
        (0, Some(remaining))
    }
}
102
/// An owning iterator over the items of an [`OrdHashSet`], in order.
pub struct IntoIter<T> {
    inner: std::vec::IntoIter<Arc<T>>,
}

impl<T: fmt::Debug> IntoIter<T> {
    /// Take ownership of the value behind `item`.
    ///
    /// # Panics
    /// Panics if `item` is not the only remaining reference to its value.
    fn unwrap_item(item: Arc<T>) -> T {
        Arc::try_unwrap(item).expect("item")
    }
}

impl<T: fmt::Debug> Iterator for IntoIter<T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        self.inner.next().map(Self::unwrap_item)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<T: fmt::Debug> DoubleEndedIterator for IntoIter<T> {
    fn next_back(&mut self) -> Option<T> {
        self.inner.next_back().map(Self::unwrap_item)
    }
}
129
/// A borrowing iterator over the items of an [`OrdHashSet`], in order.
pub struct Iter<'a, T> {
    inner: std::slice::Iter<'a, Arc<T>>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<&'a T> {
        let item = self.inner.next()?;
        Some(item.as_ref())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
    fn next_back(&mut self) -> Option<&'a T> {
        let item = self.inner.next_back()?;
        Some(item.as_ref())
    }
}
152
/// A [`std::collections::HashSet`] ordered using a [`Vec`].
///
/// Each item is allocated once in an [`Arc`], which is shared between the hash index
/// (`inner`) and the sorted list (`order`).
#[derive(GetSize)]
pub struct OrdHashSet<T> {
    // hash index, used for membership tests and the item count
    inner: Inner<Arc<T>>,
    // the same items, kept in ascending order by `insert`
    order: Vec<Arc<T>>,
}
159
160impl<T: Clone + Eq + Hash + Ord + fmt::Debug> Clone for OrdHashSet<T> {
161    fn clone(&self) -> Self {
162        self.iter().cloned().collect()
163    }
164}
165
166impl<T: PartialEq + fmt::Debug> PartialEq for OrdHashSet<T> {
167    fn eq(&self, other: &Self) -> bool {
168        self.order == other.order
169    }
170}
171
172impl<T: Eq + fmt::Debug> Eq for OrdHashSet<T> {}
173
174impl<T> OrdHashSet<T> {
175    /// Construct a new [`OrdHashSet`].
176    pub fn new() -> Self {
177        Self {
178            inner: Inner::new(),
179            order: Vec::new(),
180        }
181    }
182
183    /// Construct a new [`OrdHashSet`] with the given `capacity`.
184    pub fn with_capacity(capacity: usize) -> Self {
185        Self {
186            inner: Inner::with_capacity(capacity),
187            order: Vec::with_capacity(capacity),
188        }
189    }
190
191    /// Construct an iterator over the items in this [`OrdHashSet`].
192    pub fn iter(&self) -> Iter<T> {
193        Iter {
194            inner: self.order.iter(),
195        }
196    }
197
198    /// Return `true` if this [`OrdHashSet`] is empty.
199    pub fn is_empty(&self) -> bool {
200        self.inner.is_empty()
201    }
202
203    /// Return the number of items in this [`OrdHashSet`].
204    pub fn len(&self) -> usize {
205        self.inner.len()
206    }
207}
208
209impl<T: Eq + Hash + Ord> OrdHashSet<T> {
210    fn bisect_hi<Cmp>(&self, cmp: Cmp) -> usize
211    where
212        Cmp: Fn(&T) -> Option<Ordering>,
213    {
214        if self.is_empty() {
215            return 0;
216        } else if cmp(self.order.iter().next_back().expect("tail")).is_some() {
217            return self.len();
218        }
219
220        let mut lo = 0;
221        let mut hi = self.len();
222
223        while lo < hi {
224            let mid = (lo + hi) >> 1;
225            let item = self.order.get(mid).expect("item");
226
227            if cmp(&**item).is_some() {
228                lo = mid + 1;
229            } else {
230                hi = mid;
231            }
232        }
233
234        hi
235    }
236
237    fn bisect_lo<Cmp>(&self, cmp: Cmp) -> usize
238    where
239        Cmp: Fn(&T) -> Option<Ordering>,
240    {
241        if self.is_empty() {
242            return 0;
243        } else if cmp(&self.order[0]).is_some() {
244            return 0;
245        }
246
247        let mut lo = 0;
248        let mut hi = 1;
249
250        while lo < hi {
251            let mid = (lo + hi) >> 1;
252            let item = self.order.get(mid).expect("item");
253
254            if cmp(&**item).is_some() {
255                hi = mid;
256            } else {
257                lo = mid + 1;
258            }
259        }
260
261        hi
262    }
263
264    fn bisect_inner<Cmp>(&self, cmp: Cmp, mut lo: usize, mut hi: usize) -> Option<&T>
265    where
266        Cmp: Fn(&T) -> Option<Ordering>,
267    {
268        while lo < hi {
269            let mid = (lo + hi) >> 1;
270            let item = self.order.get(mid).expect("item");
271
272            if let Some(order) = cmp(&**item) {
273                match order {
274                    Ordering::Less => hi = mid,
275                    Ordering::Equal => return Some(item),
276                    Ordering::Greater => lo = mid + 1,
277                }
278            } else {
279                panic!("comparison does not match distribution")
280            }
281        }
282
283        None
284    }
285
286    /// Bisect this set to match an item using the provided comparison, and return it (if present).
287    ///
288    /// The first item for which the comparison returns `Some(Ordering::Equal)` will be returned.
289    /// This method assumes that any partially-ordered items (where `cmp(item).is_none()`)
290    /// are ordered at the beginning or end of the set.
291    pub fn bisect<Cmp>(&self, cmp: Cmp) -> Option<&T>
292    where
293        Cmp: Fn(&T) -> Option<Ordering> + Copy,
294    {
295        let lo = self.bisect_lo(cmp);
296        let hi = self.bisect_hi(cmp);
297        self.bisect_inner(cmp, lo, hi)
298    }
299
300    /// Bisect this set to match and remove an item using the provided comparison.
301    ///
302    /// The first item for which the comparison returns `Some(Ordering::Equal)` will be returned.
303    /// This method assumes that any partially-ordered items (where `cmp(item).is_none()`)
304    /// are ordered at the beginning and/or end of the set.
305    pub fn bisect_and_remove<Cmp>(&mut self, cmp: Cmp) -> Option<T>
306    where
307        Cmp: Fn(&T) -> Option<Ordering> + Copy,
308        T: fmt::Debug,
309    {
310        let mut lo = self.bisect_lo(cmp);
311        let mut hi = self.bisect_hi(cmp);
312
313        let item = loop {
314            if lo >= hi {
315                break None;
316            }
317
318            let mid = (lo + hi) >> 1;
319            let item = self.order.get(mid).expect("item");
320
321            if let Some(order) = cmp(&**item) {
322                match order {
323                    Ordering::Less => hi = mid,
324                    Ordering::Equal => {
325                        lo = mid;
326                        break Some(item.clone());
327                    }
328                    Ordering::Greater => lo = mid + 1,
329                }
330            } else {
331                panic!("comparison does not match distribution")
332            }
333        }?;
334
335        self.order.remove(lo);
336        self.inner.remove(&item);
337
338        Some(Arc::try_unwrap(item).expect("item"))
339    }
340
341    /// Remove all items from this [`OrdHashSet`].
342    pub fn clear(&mut self) {
343        self.inner.clear();
344        self.order.clear();
345    }
346
347    /// Return `true` if the given item is present in this [`OrdHashSet`].
348    pub fn contains<Q: ?Sized>(&self, item: &Q) -> bool
349    where
350        Arc<T>: Borrow<Q>,
351        Q: Hash + Eq,
352    {
353        self.inner.contains(item)
354    }
355
356    /// Drain all items from this [`OrdHashSet`].
357    pub fn drain(&mut self) -> Drain<T> {
358        Drain {
359            inner: &mut self.inner,
360            order: self.order.drain(..),
361        }
362    }
363
364    /// Drain items from this [`OrdHashSet`] while they match the given `cond`ition.
365    pub fn drain_while<Cond>(&mut self, cond: Cond) -> DrainWhile<T, Cond>
366    where
367        Cond: Fn(&T) -> bool,
368    {
369        DrainWhile {
370            inner: &mut self.inner,
371            order: &mut self.order,
372            cond,
373        }
374    }
375
376    /// Consume the given `iter` and insert all its items into this [`OrdHashSet`]
377    pub fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
378        for item in iter {
379            self.insert(item);
380        }
381    }
382
383    /// Borrow the first item in this [`OrdHashSet`].
384    pub fn first(&self) -> Option<&T> {
385        self.order.iter().next().map(|item| &**item)
386    }
387
388    /// Insert an `item` into this [`OrdHashSet`] and return `false` if it was already present.
389    pub fn insert(&mut self, item: T) -> bool {
390        let new = if self.inner.contains(&item) {
391            false
392        } else {
393            let item = Arc::new(item);
394
395            let index = bisect(&self.order, &item);
396            if index == self.len() {
397                self.order.insert(index, item.clone());
398            } else {
399                let prior = self.order.get(index).expect("item").clone();
400
401                if &prior < &item {
402                    self.order.insert(index + 1, item.clone());
403                } else {
404                    self.order.insert(index, item.clone());
405                }
406            }
407
408            self.inner.insert(item)
409        };
410
411        new
412    }
413
414    /// Borrow the last item in this [`OrdHashSet`].
415    pub fn last(&self) -> Option<&T> {
416        self.order.iter().next_back().map(|item| &**item)
417    }
418
419    /// Remove and return the first item in this [`OrdHashSet`].
420    pub fn pop_first(&mut self) -> Option<T>
421    where
422        T: fmt::Debug,
423    {
424        if self.is_empty() {
425            None
426        } else {
427            let item = self.order.remove(0);
428            self.inner.remove(&item);
429            Some(Arc::try_unwrap(item).expect("item"))
430        }
431    }
432
433    /// Remove and return the last item in this [`OrdHashSet`].
434    pub fn pop_last(&mut self) -> Option<T>
435    where
436        T: fmt::Debug,
437    {
438        if let Some(item) = self.order.pop() {
439            self.inner.remove(&item);
440            Some(Arc::try_unwrap(item).expect("item"))
441        } else {
442            None
443        }
444    }
445
446    /// Remove the given `item` from this [`OrdHashSet`] and return `true` if it was present.
447    ///
448    /// The item may be any borrowed form of `T`,
449    /// but the ordering on the borrowed form **must** match the ordering of `T`.
450    pub fn remove<Q>(&mut self, item: &Q) -> bool
451    where
452        Arc<T>: Borrow<Q>,
453        Q: Eq + Hash + Ord,
454    {
455        if self.inner.remove(item) {
456            let index = bisect(&self.order, item);
457            assert!(self.order.remove(index).borrow() == item);
458            true
459        } else {
460            false
461        }
462    }
463
464    /// Return `true` if the first elements in this set are equal to those in the given `iter`.
465    pub fn starts_with<'a, I: IntoIterator<Item = &'a T>>(&'a self, other: I) -> bool
466    where
467        T: PartialEq,
468    {
469        let mut this = self.iter();
470        let mut that = other.into_iter();
471
472        while let Some(item) = that.next() {
473            if this.next() != Some(item) {
474                return false;
475            }
476        }
477
478        true
479    }
480}
481
482impl<T: Eq + Hash + Ord + fmt::Debug> OrdHashSet<T> {
483    #[allow(unused)]
484    fn is_valid(&self) -> bool {
485        assert_eq!(self.inner.len(), self.order.len());
486
487        if self.is_empty() {
488            return true;
489        }
490
491        let mut item = self.order.get(0).expect("item");
492        for i in 1..self.len() {
493            let next = self.order.get(i).expect("next");
494            assert!(*item <= *next, "set out of order: {:?}", self);
495            assert!(*next >= *item);
496            item = next;
497        }
498
499        true
500    }
501}
502
503impl<T: Eq + Hash + Ord + fmt::Debug> fmt::Debug for OrdHashSet<T> {
504    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
505        f.write_str("[ ")?;
506
507        for item in self {
508            write!(f, "{:?} ", item)?;
509        }
510
511        f.write_str("]")
512    }
513}
514
515impl<T: Eq + Hash + Ord + fmt::Debug> FromIterator<T> for OrdHashSet<T> {
516    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
517        let iter = iter.into_iter();
518        let mut set = match iter.size_hint() {
519            (_, Some(max)) => Self::with_capacity(max),
520            (min, None) if min > 0 => Self::with_capacity(min),
521            _ => Self::new(),
522        };
523
524        set.extend(iter);
525        set
526    }
527}
528
529impl<T: fmt::Debug> IntoIterator for OrdHashSet<T> {
530    type Item = T;
531    type IntoIter = IntoIter<T>;
532
533    fn into_iter(self) -> Self::IntoIter {
534        IntoIter {
535            inner: self.order.into_iter(),
536        }
537    }
538}
539
540impl<'a, T> IntoIterator for &'a OrdHashSet<T> {
541    type Item = &'a T;
542    type IntoIter = Iter<'a, T>;
543
544    fn into_iter(self) -> Self::IntoIter {
545        OrdHashSet::iter(self)
546    }
547}
548
/// Return the index of `target` in the sorted `list`, or, if it is absent, the index at which
/// it would have to be inserted to keep `list` sorted.
///
/// `list` must be sorted in ascending order with respect to its borrowed form `Q`.
/// If several items compare equal to `target`, the index of one of them is returned.
/// The first and last items are checked before bisecting, so a target which sorts before or
/// after every item (e.g. when appending in sorted order) needs at most two comparisons.
fn bisect<T, Q>(list: &[T], target: &Q) -> usize
where
    T: Borrow<Q> + Ord,
    Q: Ord + ?Sized,
{
    if let Some(front) = list.first() {
        let front: &Q = front.borrow();
        if target < front {
            return 0;
        }
    }

    if let Some(last) = list.last() {
        let last: &Q = last.borrow();
        if target > last {
            return list.len();
        }
    }

    let mut lo = 0;
    let mut hi = list.len();

    while lo < hi {
        let mid = lo + (hi - lo) / 2;
        let item: &Q = list[mid].borrow();

        match item.cmp(target) {
            Ordering::Less => lo = mid + 1,
            Ordering::Greater => hi = mid,
            Ordering::Equal => return mid,
        }
    }

    lo
}
583
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bisect_and_remove() {
        let mut set = OrdHashSet::<u8>::new();

        // `cmp(item)` reports the ordering of the target relative to `item`
        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        set.insert(8);
        assert_eq!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)), Some(8));
        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        set.insert(9);
        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        set.insert(7);
        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        // a present item in the upper half of a larger set
        set.extend([8, 10, 11]);
        assert_eq!(set.bisect_and_remove(|item| 10u8.partial_cmp(item)), Some(10));
        assert!(set.is_valid());
        assert_eq!(set.into_iter().collect::<Vec<_>>(), [7, 8, 9, 11]);
    }

    #[test]
    fn test_bisect() {
        let set = OrdHashSet::from_iter([2u8, 4, 6, 8]);
        assert_eq!(set.bisect(|item| 6u8.partial_cmp(item)), Some(&6));
        assert_eq!(set.bisect(|item| 2u8.partial_cmp(item)), Some(&2));
        assert_eq!(set.bisect(|item| 8u8.partial_cmp(item)), Some(&8));
        assert_eq!(set.bisect(|item| 5u8.partial_cmp(item)), None);
        assert_eq!(set.bisect(|item| 1u8.partial_cmp(item)), None);
        assert_eq!(set.bisect(|item| 9u8.partial_cmp(item)), None);
    }

    #[test]
    fn test_into_iter() {
        let mut set = OrdHashSet::new();
        assert!(set.insert("d"));
        assert!(set.insert("a"));
        assert!(set.insert("c"));
        assert!(set.insert("b"));
        assert!(!set.insert("a"));
        assert_eq!(set.len(), 4);

        assert_eq!(set.into_iter().collect::<Vec<&str>>(), ["a", "b", "c", "d"]);
    }

    #[test]
    fn test_drain() {
        let mut set = OrdHashSet::from_iter(0..10);
        let expected = (0..10).collect::<Vec<_>>();
        let actual = set.drain().collect::<Vec<_>>();
        assert_eq!(expected, actual);
        assert!(set.is_empty());
    }

    #[test]
    fn test_drain_while() {
        let mut set = OrdHashSet::from_iter(0..10);
        let drained = set.drain_while(|x| *x < 5).collect::<Vec<_>>();
        assert_eq!(drained, vec![0, 1, 2, 3, 4]);
        assert_eq!(set, OrdHashSet::from_iter(5..10));
    }
}