1use std::borrow::Borrow;
31use std::cmp::Ordering;
32use std::collections::HashSet as Inner;
33use std::fmt;
34use std::hash::Hash;
35use std::sync::Arc;
36
37use get_size::GetSize;
38use get_size_derive::*;
39
/// A draining iterator over the items of an `OrdHashSet`, yielding them in ascending order.
///
/// Every item yielded is also removed from the set's hash index, so the iterator can hand
/// back sole ownership of each value.
pub struct Drain<'a, T> {
    // the set's hash index; each yielded item is removed from it
    inner: &'a mut Inner<Arc<T>>,
    // draining view over the set's ordered list
    order: std::vec::Drain<'a, Arc<T>>,
}

impl<T> Drain<'_, T>
where
    T: Eq + Hash + fmt::Debug,
{
    /// Remove `item` from the hash index and unwrap it into an owned value.
    ///
    /// # Panics
    /// Panics if another strong reference to `item` is still alive.
    fn release(&mut self, item: Arc<T>) -> T {
        self.inner.remove(&*item);
        Arc::try_unwrap(item).expect("item")
    }
}

impl<T> Iterator for Drain<'_, T>
where
    T: Eq + Hash + fmt::Debug,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        self.order.next().map(|item| self.release(item))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.order.size_hint()
    }
}

impl<T> DoubleEndedIterator for Drain<'_, T>
where
    T: Eq + Hash + fmt::Debug,
{
    fn next_back(&mut self) -> Option<T> {
        let item = self.order.next_back()?;
        Some(self.release(item))
    }
}
73
/// A draining iterator that removes items from the front (smallest end) of an `OrdHashSet`
/// for as long as `cond` holds, yielding them in ascending order.
///
/// Iteration stops at the first item for which `cond` returns `false`. That item stays in the
/// set, so further calls keep returning `None`.
pub struct DrainWhile<'a, T, Cond> {
    // the set's hash index; each yielded item is removed from it
    inner: &'a mut Inner<Arc<T>>,
    // the set's ordered list; items are taken from its front
    order: &'a mut Vec<Arc<T>>,
    cond: Cond,
}

impl<T, Cond> Iterator for DrainWhile<'_, T, Cond>
where
    T: Eq + Hash + fmt::Debug,
    Cond: Fn(&T) -> bool,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let head = self.order.first()?;
        if !(self.cond)(&**head) {
            return None;
        }

        // NOTE(review): `Vec::remove(0)` shifts every remaining item, so draining k items
        // from a set of n costs O(k * n).
        let item = self.order.remove(0);
        self.inner.remove(&*item);
        Some(Arc::try_unwrap(item).expect("item"))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // anywhere from nothing to everything still in the set
        (0, Some(self.inner.len()))
    }
}
102
/// An owning iterator over the items of an `OrdHashSet`, yielding them in ascending order.
pub struct IntoIter<T> {
    // the set's ordered list; expected to hold the only remaining reference to each item
    inner: std::vec::IntoIter<Arc<T>>,
}

impl<T: fmt::Debug> Iterator for IntoIter<T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let item = self.inner.next()?;
        Some(Arc::try_unwrap(item).expect("item"))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<T: fmt::Debug> DoubleEndedIterator for IntoIter<T> {
    fn next_back(&mut self) -> Option<T> {
        let item = self.inner.next_back()?;
        Some(Arc::try_unwrap(item).expect("item"))
    }
}
129
/// A borrowing iterator over the items of an `OrdHashSet`, in ascending order.
pub struct Iter<'a, T> {
    // iterator over the set's ordered list
    inner: std::slice::Iter<'a, Arc<T>>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<&'a T> {
        self.inner.next().map(Arc::as_ref)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
    fn next_back(&mut self) -> Option<&'a T> {
        self.inner.next_back().map(Arc::as_ref)
    }
}
152
153#[derive(GetSize)]
155pub struct OrdHashSet<T> {
156 inner: Inner<Arc<T>>,
157 order: Vec<Arc<T>>,
158}
159
160impl<T: Clone + Eq + Hash + Ord + fmt::Debug> Clone for OrdHashSet<T> {
161 fn clone(&self) -> Self {
162 self.iter().cloned().collect()
163 }
164}
165
166impl<T: PartialEq + fmt::Debug> PartialEq for OrdHashSet<T> {
167 fn eq(&self, other: &Self) -> bool {
168 self.order == other.order
169 }
170}
171
172impl<T: Eq + fmt::Debug> Eq for OrdHashSet<T> {}
173
174impl<T> OrdHashSet<T> {
175 pub fn new() -> Self {
177 Self {
178 inner: Inner::new(),
179 order: Vec::new(),
180 }
181 }
182
183 pub fn with_capacity(capacity: usize) -> Self {
185 Self {
186 inner: Inner::with_capacity(capacity),
187 order: Vec::with_capacity(capacity),
188 }
189 }
190
191 pub fn iter(&self) -> Iter<'_, T> {
193 Iter {
194 inner: self.order.iter(),
195 }
196 }
197
198 pub fn is_empty(&self) -> bool {
200 self.inner.is_empty()
201 }
202
203 pub fn len(&self) -> usize {
205 self.inner.len()
206 }
207}
208
209impl<T> Default for OrdHashSet<T> {
210 fn default() -> Self {
211 Self::new()
212 }
213}
214
215impl<T: Eq + Hash + Ord> OrdHashSet<T> {
216 fn bisect_hi<Cmp>(&self, cmp: Cmp) -> usize
217 where
218 Cmp: Fn(&T) -> Option<Ordering>,
219 {
220 if self.is_empty() {
221 return 0;
222 } else if cmp(self.order.iter().next_back().expect("tail")).is_some() {
223 return self.len();
224 }
225
226 let mut lo = 0;
227 let mut hi = self.len();
228
229 while lo < hi {
230 let mid = (lo + hi) >> 1;
231 let item = self.order.get(mid).expect("item");
232
233 if cmp(&**item).is_some() {
234 lo = mid + 1;
235 } else {
236 hi = mid;
237 }
238 }
239
240 hi
241 }
242
243 fn bisect_lo<Cmp>(&self, cmp: Cmp) -> usize
244 where
245 Cmp: Fn(&T) -> Option<Ordering>,
246 {
247 if self.is_empty() || cmp(&self.order[0]).is_some() {
248 return 0;
249 }
250
251 let mut lo = 0;
252 let mut hi = 1;
253
254 while lo < hi {
255 let mid = (lo + hi) >> 1;
256 let item = self.order.get(mid).expect("item");
257
258 if cmp(&**item).is_some() {
259 hi = mid;
260 } else {
261 lo = mid + 1;
262 }
263 }
264
265 hi
266 }
267
268 fn bisect_inner<Cmp>(&self, cmp: Cmp, mut lo: usize, mut hi: usize) -> Option<&T>
269 where
270 Cmp: Fn(&T) -> Option<Ordering>,
271 {
272 while lo < hi {
273 let mid = (lo + hi) >> 1;
274 let item = self.order.get(mid).expect("item");
275
276 if let Some(order) = cmp(&**item) {
277 match order {
278 Ordering::Less => hi = mid,
279 Ordering::Equal => return Some(item),
280 Ordering::Greater => lo = mid + 1,
281 }
282 } else {
283 panic!("comparison does not match distribution")
284 }
285 }
286
287 None
288 }
289
290 pub fn bisect<Cmp>(&self, cmp: Cmp) -> Option<&T>
296 where
297 Cmp: Fn(&T) -> Option<Ordering> + Copy,
298 {
299 let lo = self.bisect_lo(cmp);
300 let hi = self.bisect_hi(cmp);
301 self.bisect_inner(cmp, lo, hi)
302 }
303
304 pub fn bisect_and_remove<Cmp>(&mut self, cmp: Cmp) -> Option<T>
310 where
311 Cmp: Fn(&T) -> Option<Ordering> + Copy,
312 T: fmt::Debug,
313 {
314 let mut lo = self.bisect_lo(cmp);
315 let mut hi = self.bisect_hi(cmp);
316
317 let item = loop {
318 if lo >= hi {
319 break None;
320 }
321
322 let mid = (lo + hi) >> 1;
323 let item = self.order.get(mid).expect("item");
324
325 if let Some(order) = cmp(&**item) {
326 match order {
327 Ordering::Less => hi = mid,
328 Ordering::Equal => {
329 lo = mid;
330 break Some(item.clone());
331 }
332 Ordering::Greater => lo = mid + 1,
333 }
334 } else {
335 panic!("comparison does not match distribution")
336 }
337 }?;
338
339 self.order.remove(lo);
340 self.inner.remove(&item);
341
342 Some(Arc::try_unwrap(item).expect("item"))
343 }
344
345 pub fn clear(&mut self) {
347 self.inner.clear();
348 self.order.clear();
349 }
350
351 pub fn contains<Q>(&self, item: &Q) -> bool
353 where
354 Arc<T>: Borrow<Q>,
355 Q: Hash + Eq + ?Sized,
356 {
357 self.inner.contains(item)
358 }
359
360 pub fn drain(&mut self) -> Drain<'_, T> {
362 Drain {
363 inner: &mut self.inner,
364 order: self.order.drain(..),
365 }
366 }
367
368 pub fn drain_while<Cond>(&mut self, cond: Cond) -> DrainWhile<'_, T, Cond>
370 where
371 Cond: Fn(&T) -> bool,
372 {
373 DrainWhile {
374 inner: &mut self.inner,
375 order: &mut self.order,
376 cond,
377 }
378 }
379
380 pub fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
382 for item in iter {
383 self.insert(item);
384 }
385 }
386
387 pub fn first(&self) -> Option<&T> {
389 self.order.first().map(|item| &**item)
390 }
391
392 pub fn insert(&mut self, item: T) -> bool {
394 let new = if self.inner.contains(&item) {
395 false
396 } else {
397 let item = Arc::new(item);
398
399 let index = bisect(&self.order, &item);
400 if index == self.len() {
401 self.order.insert(index, item.clone());
402 } else {
403 let prior = self.order.get(index).expect("item").clone();
404
405 if prior < item {
406 self.order.insert(index + 1, item.clone());
407 } else {
408 self.order.insert(index, item.clone());
409 }
410 }
411
412 self.inner.insert(item)
413 };
414
415 new
416 }
417
418 pub fn last(&self) -> Option<&T> {
420 self.order.iter().next_back().map(|item| &**item)
421 }
422
423 pub fn pop_first(&mut self) -> Option<T>
425 where
426 T: fmt::Debug,
427 {
428 if self.is_empty() {
429 None
430 } else {
431 let item = self.order.remove(0);
432 self.inner.remove(&item);
433 Some(Arc::try_unwrap(item).expect("item"))
434 }
435 }
436
437 pub fn pop_last(&mut self) -> Option<T>
439 where
440 T: fmt::Debug,
441 {
442 if let Some(item) = self.order.pop() {
443 self.inner.remove(&item);
444 Some(Arc::try_unwrap(item).expect("item"))
445 } else {
446 None
447 }
448 }
449
450 pub fn remove<Q>(&mut self, item: &Q) -> bool
455 where
456 Arc<T>: Borrow<Q>,
457 Q: Eq + Hash + Ord,
458 {
459 if self.inner.remove(item) {
460 let index = bisect(&self.order, item);
461 assert!(self.order.remove(index).borrow() == item);
462 true
463 } else {
464 false
465 }
466 }
467
468 pub fn starts_with<'a, I: IntoIterator<Item = &'a T>>(&'a self, other: I) -> bool
470 where
471 T: PartialEq,
472 {
473 let mut this = self.iter();
474 let that = other.into_iter();
475
476 for item in that {
477 if this.next() != Some(item) {
478 return false;
479 }
480 }
481
482 true
483 }
484}
485
486impl<T: Eq + Hash + Ord + fmt::Debug> OrdHashSet<T> {
487 #[allow(unused)]
488 fn is_valid(&self) -> bool {
489 assert_eq!(self.inner.len(), self.order.len());
490
491 if self.is_empty() {
492 return true;
493 }
494
495 let mut item = self.order.first().expect("item");
496 for i in 1..self.len() {
497 let next = self.order.get(i).expect("next");
498 assert!(*item <= *next, "set out of order: {:?}", self);
499 assert!(*next >= *item);
500 item = next;
501 }
502
503 true
504 }
505}
506
507impl<T: Eq + Hash + Ord + fmt::Debug> fmt::Debug for OrdHashSet<T> {
508 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
509 f.write_str("[ ")?;
510
511 for item in self {
512 write!(f, "{:?} ", item)?;
513 }
514
515 f.write_str("]")
516 }
517}
518
519impl<T: Eq + Hash + Ord + fmt::Debug> FromIterator<T> for OrdHashSet<T> {
520 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
521 let iter = iter.into_iter();
522 let mut set = match iter.size_hint() {
523 (_, Some(max)) => Self::with_capacity(max),
524 (min, None) if min > 0 => Self::with_capacity(min),
525 _ => Self::new(),
526 };
527
528 set.extend(iter);
529 set
530 }
531}
532
533impl<T: fmt::Debug> IntoIterator for OrdHashSet<T> {
534 type Item = T;
535 type IntoIter = IntoIter<T>;
536
537 fn into_iter(self) -> Self::IntoIter {
538 IntoIter {
539 inner: self.order.into_iter(),
540 }
541 }
542}
543
544impl<'a, T> IntoIterator for &'a OrdHashSet<T> {
545 type Item = &'a T;
546 type IntoIter = Iter<'a, T>;
547
548 fn into_iter(self) -> Self::IntoIter {
549 OrdHashSet::iter(self)
550 }
551}
552
/// Find where `target` belongs in the sorted slice `list`.
///
/// Returns the index of an item equal to `target` if there is one. Otherwise it returns the
/// index at which `target` would have to be inserted to keep `list` sorted (`list.len()` if
/// it sorts after every item).
#[inline]
fn bisect<T, Q>(list: &[T], target: &Q) -> usize
where
    T: Borrow<Q>,
    Q: Ord + ?Sized,
{
    // Fast path for appending in ascending order (e.g. when cloning or collecting sorted input).
    match list.last() {
        None => return 0,
        Some(last) => {
            let key: &Q = last.borrow();
            if key < target {
                return list.len();
            }
        }
    }

    list.binary_search_by(|item| {
        let key: &Q = item.borrow();
        key.cmp(target)
    })
    .unwrap_or_else(|index| index)
}
587
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    #[test]
    fn test_bisect_and_remove() {
        let mut set = OrdHashSet::<u8>::new();

        // `cmp(item)` orders the *key* relative to `item`: the same convention as `bisect`
        // (see `test_bisect_boundaries`)
        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        set.insert(8);
        assert_eq!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)), Some(8));
        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        set.insert(9);
        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        set.insert(7);
        assert!(set.bisect_and_remove(|item| 8u8.partial_cmp(item)).is_none());

        // with more items the direction of the comparison matters: a reversed comparison
        // would steer the search away from keys at either end of the set
        set.extend([1, 2, 3, 4, 5]);
        assert_eq!(set.bisect_and_remove(|item| 1u8.partial_cmp(item)), Some(1));
        assert_eq!(set.bisect_and_remove(|item| 9u8.partial_cmp(item)), Some(9));
        assert_eq!(set.iter().copied().collect::<Vec<_>>(), vec![2, 3, 4, 5, 7]);
        assert!(set.is_valid());
    }

    #[test]
    fn test_into_iter() {
        let mut set = OrdHashSet::new();
        assert!(set.insert("d"));
        assert!(set.insert("a"));
        assert!(set.insert("c"));
        assert!(set.insert("b"));
        assert!(!set.insert("a"));
        assert_eq!(set.len(), 4);

        assert_eq!(set.into_iter().collect::<Vec<&str>>(), ["a", "b", "c", "d"]);
    }

    #[test]
    fn test_drain() {
        let mut set = OrdHashSet::from_iter(0..10);
        let expected = (0..10).collect::<Vec<_>>();
        let actual = set.drain().collect::<Vec<_>>();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_drain_while() {
        let mut set = OrdHashSet::from_iter(0..10);
        let drained = set.drain_while(|x| *x < 5).collect::<Vec<_>>();
        assert_eq!(drained, vec![0, 1, 2, 3, 4]);
        assert_eq!(set, OrdHashSet::from_iter(5..10));
    }

    #[test]
    fn test_order_invariants_after_ops() {
        let mut set = OrdHashSet::new();
        for i in (0..100).rev() {
            assert!(set.insert(i));
        }

        let items: Vec<_> = set.iter().cloned().collect();
        assert_eq!(items, (0..100).collect::<Vec<_>>());

        for i in 0..50 {
            assert!(set.remove(&i));
        }

        let items: Vec<_> = set.iter().cloned().collect();
        assert_eq!(items, (50..100).collect::<Vec<_>>());
    }

    #[test]
    fn test_random_ops_invariants() {
        let mut rng = StdRng::seed_from_u64(0x_d5e4);
        let mut set = OrdHashSet::new();

        for _ in 0..5_000 {
            let value = rng.random_range(0..200);
            if rng.random() {
                set.insert(value);
            } else {
                set.remove(&value);
            }

            assert!(set.is_valid());
        }
    }

    #[test]
    fn test_bisect_boundaries() {
        let mut set = OrdHashSet::new();
        set.insert(10u32);
        set.insert(20u32);

        assert!(set.bisect(|item| 5u32.partial_cmp(item)).is_none());
        assert_eq!(set.bisect(|item| 10u32.partial_cmp(item)), Some(&10));
        assert_eq!(set.bisect(|item| 20u32.partial_cmp(item)), Some(&20));
        assert!(set.bisect(|item| 25u32.partial_cmp(item)).is_none());
    }

    #[test]
    fn test_remove_missing_does_not_mutate() {
        let mut set = OrdHashSet::new();
        set.insert(1u32);
        set.insert(3u32);

        assert!(!set.remove(&2u32));
        assert_eq!(set.iter().cloned().collect::<Vec<_>>(), vec![1, 3]);
    }
}