1#![cfg_attr(not(any(test, feature = "std")), no_std)]
2#![doc(html_root_url = "https://docs.rs/kindness/0.5.0")]
3#![deny(missing_docs)]
4#![allow(warnings, dead_code, unused_imports, unused_mut)]
5#![warn(clippy::pedantic)]
6#![warn(clippy::nursery)]
7#![warn(clippy::cargo)]
8
9mod coin_flipper;
45mod unique;
46pub mod uniform;
48
49use coin_flipper::CoinFlipper;
50use core::cmp::Ordering;
51use core::hash::{BuildHasher, Hash};
52use rand::Rng;
53
54impl<T: Iterator + Sized> Kindness for T {}
55
56fn choose_best_by_key<
57 I: Iterator + Sized,
58 B: Ord,
59 R: Rng,
60 F: FnMut(&I::Item) -> B,
61 const MAX: bool,
62>(
63 mut iterator: I,
64 rng: &mut R,
65 mut f: F,
66) -> Option<I::Item> {
67 let Some(first) = iterator.next() else {
68 return None;
69 };
70
71 let mut current_key = f(&first);
72 let mut current = first;
73 let mut coin_flipper = coin_flipper::CoinFlipper::new(rng);
74 let mut consumed = 1;
75
76 for item in iterator {
77 let item_key = f(&item);
78 match item_key.cmp(¤t_key) {
79 core::cmp::Ordering::Equal => {
80 consumed += 1;
81 if coin_flipper.gen_ratio_one_over(consumed) {
83 current = item;
84 }
85 }
86 ordering => {
87 if MAX == (ordering == core::cmp::Ordering::Greater) {
88 current_key = item_key; current = item;
90 consumed = 1;
91 }
92 }
93 }
94 }
95
96 Some(current)
97}
98
99fn choose_best_by<
100 I: Iterator + Sized,
101 R: Rng,
102 F: FnMut(&I::Item, &I::Item) -> Ordering,
103 const MAX: bool,
104>(
105 mut iterator: I,
106 rng: &mut R,
107 mut compare: F,
108) -> Option<I::Item>
109where
110 I::Item: Ord,
111{
112 let Some(first) = iterator.next() else {
113 return None;
114 };
115
116 let mut current = first;
117 let mut coin_flipper = coin_flipper::CoinFlipper::new(rng);
118 let mut consumed = 1;
119
120 for item in iterator {
121 match compare(&item, ¤t) {
122 core::cmp::Ordering::Equal => {
123 consumed += 1;
124 if coin_flipper.gen_ratio_one_over(consumed) {
125 current = item;
126 }
127 }
128 ordering => {
129 if MAX == (ordering == core::cmp::Ordering::Greater) {
130 current = item; consumed = 1;
132 }
133 }
134 }
135 }
136
137 Some(current)
138}
139
140pub trait Kindness: Iterator
143where
144 Self: Sized,
145{
146 #[inline]
151 fn choose_item<R: Rng>(mut self, rng: &mut R) -> Option<Self::Item> {
152 let (mut lower, mut upper) = self.size_hint();
153 let mut result = None;
154
155 if upper == Some(lower) {
159 return if lower == 0 {
160 None
161 } else {
162 self.nth(gen_index(rng, lower))
163 };
164 }
165
166 let mut coin_flipper = CoinFlipper::new(rng);
168 let mut consumed = 0;
169
170 loop {
172 if lower > 1 {
173 let ix = gen_index(coin_flipper.rng, lower + consumed);
174 let skip = if ix < lower {
175 result = self.nth(ix);
176 lower - (ix + 1)
177 } else {
178 lower
179 };
180 if upper == Some(lower) {
181 return result;
182 }
183 consumed += lower;
184 if skip > 0 {
185 self.nth(skip - 1);
186 }
187 } else {
188 consumed += 1;
189 let skip = coin_flipper.try_skip(consumed as u32) as usize;
190 let elem = self.nth(skip);
191 if elem.is_none() {
192 return result;
193 }
194 consumed += skip;
195
196 if coin_flipper.gen_ratio_one_over(consumed) {
197 result = elem;
198 }
199 }
200
201 let hint = self.size_hint();
202 lower = hint.0;
203 upper = hint.1;
204 }
205 }
206
207 fn choose_max<R: Rng>(self, rng: &mut R) -> Option<Self::Item>
256 where
257 Self::Item: Ord,
258 {
259 self.choose_max_by(rng, Ord::cmp)
260 }
261
262 fn choose_max_by_key<B: Ord, R: Rng, F: FnMut(&Self::Item) -> B>(
267 mut self,
268 rng: &mut R,
269 mut f: F,
270 ) -> Option<Self::Item> {
271 choose_best_by_key::<Self, B, R, F, true>(self, rng, f)
272 }
273
274 fn choose_max_by<R: Rng, F: FnMut(&Self::Item, &Self::Item) -> Ordering>(
279 mut self,
280 rng: &mut R,
281 mut compare: F,
282 ) -> Option<Self::Item>
283 where
284 Self::Item: Ord,
285 {
286 choose_best_by::<Self, R, F, true>(self, rng, compare)
287 }
288
289 fn choose_min<R: Rng>(self, rng: &mut R) -> Option<Self::Item>
293 where
294 Self::Item: Ord,
295 {
296 self.choose_min_by(rng, Ord::cmp)
297 }
298
299 fn choose_min_by_key<B: Ord, R: Rng, F: FnMut(&Self::Item) -> B>(
304 mut self,
305 rng: &mut R,
306 mut f: F,
307 ) -> Option<Self::Item> {
308 choose_best_by_key::<Self, B, R, F, false>(self, rng, f)
309 }
310
311 fn choose_min_by<R: Rng, F: FnMut(&Self::Item, &Self::Item) -> Ordering>(
316 mut self,
317 rng: &mut R,
318 mut compare: F,
319 ) -> Option<Self::Item>
320 where
321 Self::Item: Ord,
322 {
323 choose_best_by::<Self, R, F, false>(self, rng, compare)
324 }
325
326 #[cfg(any(test, all(feature = "hashbrown", feature = "std")))]
331 fn choose_unique<R: Rng>(
332 mut self,
333 rng: &mut R,
334 ) -> unique::iterators::Unique<Self::Item, allocator_api2::alloc::Global>
335 where
336 Self::Item: Hash + Eq,
337 {
338 let hash_builder = std::collections::hash_map::RandomState::new();
339 let alloc = allocator_api2::alloc::Global;
340 self.choose_unique_with_hasher_in(rng, hash_builder, alloc)
341 }
342
343 #[cfg(any(test, all(feature = "hashbrown")))]
350 fn choose_unique_with_hasher_in<
351 R: Rng,
352 S: BuildHasher,
353 A: allocator_api2::alloc::Allocator + Clone,
354 >(
355 mut self,
356 rng: &mut R,
357 hash_builder: S,
358 alloc: A,
359 ) -> unique::iterators::Unique<Self::Item, A>
360 where
361 Self::Item: Hash + Eq,
362 {
363 use hashbrown::{hash_map::Entry, HashTable};
364 let mut table: HashTable<(Self::Item, usize), A> =
365 HashTable::new_in(alloc);
366 let mut coin_flipper = CoinFlipper::new(rng);
367 for item in self {
368
369 let hash = hash_builder.hash_one(&item);
370
371 let entry = table.entry(hash, |(other, _)| item.eq(other), |(i,_)| hash_builder.hash_one(i));
372
373 match entry{
374 hashbrown::hash_table::Entry::Occupied(mut occupied_entry) => {
375 let new_count = occupied_entry.get().1+ 1;
376 occupied_entry.get_mut().1 = new_count;
377
378 if coin_flipper.gen_ratio_one_over(new_count) {
379 occupied_entry.get_mut().0 = item;
381 }
382 },
383 hashbrown::hash_table::Entry::Vacant(vacant_entry) => {
384 vacant_entry.insert((item, 1));
385 },
386 }
387 }
388
389 let iter = table.into_iter();
390 unique::iterators::Unique::new(iter)
391
392
393 }
394 #[cfg(any(test, all(feature = "hashbrown", feature = "std")))]
399 fn choose_unique_by_key<R: Rng, K: Eq + Hash, F: FnMut(&Self::Item) -> K>(
400 mut self,
401 rng: &mut R,
402 mut get_key: F,
403 ) -> unique::iterators::UniqueByKey<K, Self::Item, allocator_api2::alloc::Global> {
404 let hash_builder = std::collections::hash_map::RandomState::new();
405 let alloc = allocator_api2::alloc::Global;
406 self.choose_unique_by_key_with_hasher_in(rng, get_key, hash_builder, alloc)
407 }
408
409 #[cfg(any(test, feature = "hashbrown"))]
414 fn choose_unique_by_key_with_hasher_in<
415 R: Rng,
416 K: Eq + Hash,
417 F: FnMut(&Self::Item) -> K,
418 S: BuildHasher,
419 A: allocator_api2::alloc::Allocator + Clone,
420 >(
421 mut self,
422 rng: &mut R,
423 mut get_key: F,
424 hash_builder: S,
425 alloc: A,
426 ) -> unique::iterators::UniqueByKey<K, Self::Item, A> {
427 use hashbrown::{hash_map::Entry, HashMap};
428 let mut map: HashMap<K, (Self::Item, usize), S, A> =
429 HashMap::with_hasher_in(hash_builder, alloc);
430 let mut coin_flipper = CoinFlipper::new(rng);
431 for element in self {
432 let v = get_key(&element);
433 let entry = map.entry(v).and_modify(|(e, c)| *c += 1);
434
435 match entry {
436 Entry::Occupied(mut occupied) => {
437 let (previous, new_count) = occupied.get_mut();
438 if coin_flipper.gen_ratio_one_over(*new_count) {
439 *previous = element;
440 }
441 }
442 Entry::Vacant(vacant) => {
443 vacant.insert((element, 1));
444 }
445 }
446 }
447
448
449
450 unique::iterators::UniqueByKey::new(map.into_values())
451 }
452}
453
454#[inline]
458fn gen_index<R: Rng + ?Sized>(rng: &mut R, ubound: usize) -> usize {
459 if ubound <= (core::u32::MAX as usize) {
460 rng.gen_range(0..ubound as u32) as usize
461 } else {
462 rng.gen_range(0..ubound)
463 }
464}
465
#[cfg(test)]
mod tests {
    use core::{hash::Hash, ops::Range};

    use crate::Kindness;
    use rand::{rngs::StdRng, Rng, RngCore, SeedableRng};

    /// Number of random draws performed by each statistical test.
    const RUNS: usize = 10000;
    /// Number of elements in the sampled range.
    const LENGTH: usize = 100;
    // With RUNS draws over LENGTH elements each element is expected about 100 times;
    // the statistical tests accept anything strictly between these two bounds
    // (scaled by 10 where each element is expected about 1000 times).
    const LOWER_TOLERANCE: usize = 60;
    const UPPER_TOLERANCE: usize = 140;

    #[test]
    fn test_choose_max_empty() {
        let vec: Vec<u8> = vec![];
        let mut rng = get_rng();
        let r = vec.into_iter().choose_max(&mut rng);
        assert_eq!(r, None);
    }

    #[test]
    fn test_choose_max_by_key_empty() {
        let vec: Vec<u8> = vec![];
        let mut rng = get_rng();
        let r = vec.into_iter().choose_max_by_key(&mut rng, |x| *x);
        assert_eq!(r, None);
    }

    #[test]
    fn test_choose_unique() {
        let mut counts: [usize; LENGTH] = [0; LENGTH];
        let mut rng = get_rng();

        for _ in 0..RUNS {
            let range = (0..LENGTH).map(RoughNumber);
            let elements = range.choose_unique(&mut rng);

            for x in elements {
                counts[x.0] += 1;
            }
        }

        insta::assert_debug_snapshot!(counts);
        assert_counts_within(&counts, 10);
    }

    #[test]
    fn test_choose_unique_by_key() {
        let mut counts: [usize; LENGTH] = [0; LENGTH];
        let mut rng = get_rng();

        for _ in 0..RUNS {
            let range = 0..LENGTH;
            let elements = range.choose_unique_by_key(&mut rng, |x| x / 10);

            for x in elements {
                counts[x] += 1;
            }
        }

        insta::assert_debug_snapshot!(counts);
        assert_counts_within(&counts, 10);
    }

    #[test]
    fn test_choose_item_empty() {
        let vec: Vec<usize> = vec![];
        let mut rng = get_rng();

        let item = vec.into_iter().choose_item(&mut rng);
        assert_eq!(item, None);
    }

    #[test]
    fn test_random_element_with_size_hint() {
        let mut counts: [usize; LENGTH] = [0; LENGTH];
        let mut rng = get_rng();

        for _ in 0..RUNS {
            let range = 0..LENGTH;
            assert_eq!((LENGTH, Some(LENGTH)), range.size_hint());
            let element = range.choose_item(&mut rng).unwrap();
            counts[element] += 1;
        }

        insta::assert_debug_snapshot!(counts);
        assert_counts_within(&counts, 1);

        // With an exact size hint each draw should cost at least one but fewer than
        // two calls into the RNG on average.
        assert_contains(RUNS..(RUNS * 2), &rng.count);
    }

    #[test]
    fn test_random_element_unhinted() {
        let mut counts: [usize; LENGTH] = [0; LENGTH];
        let mut rng = get_rng();

        for _ in 0..RUNS {
            let range = UnhintedIterator(0..LENGTH);
            assert_eq!((0, None), range.size_hint());
            let element = range.choose_item(&mut rng).unwrap();
            counts[element] += 1;
        }

        insta::assert_debug_snapshot!(counts);
        assert_counts_within(&counts, 1);
    }

    // NOTE(review): despite its name this test wraps the range in `UnhintedIterator`,
    // exactly like `test_random_element_unhinted`, and `WindowHintedIterator` is never
    // used anywhere. It presumably should iterate a `WindowHintedIterator` instead;
    // switching would change the recorded insta snapshot, so it has to be done
    // together with a snapshot update.
    #[test]
    fn test_random_element_windowed() {
        let mut counts: [usize; LENGTH] = [0; LENGTH];
        let mut rng = get_rng();

        for _ in 0..RUNS {
            let range = UnhintedIterator(0..LENGTH);
            assert_eq!((0, None), range.size_hint());
            let element = range.choose_item(&mut rng).unwrap();
            counts[element] += 1;
        }

        insta::assert_debug_snapshot!(counts);
        assert_counts_within(&counts, 1);
    }

    #[test]
    fn test_random_max() {
        let mut counts: [usize; LENGTH] = [0; LENGTH];
        let mut rng = get_rng();

        for _ in 0..RUNS {
            let range = (0..LENGTH).map(RoughNumber);
            let max = range.choose_max(&mut rng).unwrap();
            counts[max.0] += 1;
        }

        insta::assert_debug_snapshot!(counts);
        assert_only_chosen(&counts, 90..LENGTH);
    }

    #[test]
    fn test_random_max_by() {
        let mut counts: [usize; LENGTH] = [0; LENGTH];
        let mut rng = get_rng();

        for _ in 0..RUNS {
            let range = 0..LENGTH;
            let max = range
                .choose_max_by(&mut rng, |&a, &b| (a / 10).cmp(&(b / 10)))
                .unwrap();
            counts[max] += 1;
        }

        insta::assert_debug_snapshot!(counts);
        assert_only_chosen(&counts, 90..LENGTH);
    }

    #[test]
    fn test_random_max_by_key() {
        let mut counts: [usize; LENGTH] = [0; LENGTH];
        let mut rng = get_rng();

        for _ in 0..RUNS {
            let range = 0..LENGTH;
            let max = range
                .choose_max_by_key(&mut rng, |x| RoughNumber(*x))
                .unwrap();
            counts[max] += 1;
        }

        insta::assert_debug_snapshot!(counts);
        assert_only_chosen(&counts, 90..LENGTH);
    }

    #[test]
    fn test_random_min() {
        let mut counts: [usize; LENGTH] = [0; LENGTH];
        let mut rng = get_rng();

        for _ in 0..RUNS {
            let range = (0..LENGTH).map(RoughNumber);
            let min = range.choose_min(&mut rng).unwrap();
            counts[min.0] += 1;
        }

        insta::assert_debug_snapshot!(counts);
        assert_only_chosen(&counts, 0..10);
    }

    #[test]
    fn test_random_min_by() {
        let mut counts: [usize; LENGTH] = [0; LENGTH];
        let mut rng = get_rng();

        for _ in 0..RUNS {
            let range = 0..LENGTH;
            let min = range
                .choose_min_by(&mut rng, |&a, &b| (a / 10).cmp(&(b / 10)))
                .unwrap();
            counts[min] += 1;
        }

        insta::assert_debug_snapshot!(counts);
        assert_only_chosen(&counts, 0..10);
    }

    #[test]
    fn test_random_min_by_key() {
        let mut counts: [usize; LENGTH] = [0; LENGTH];
        let mut rng = get_rng();

        for _ in 0..RUNS {
            let range = 0..LENGTH;
            let min = range
                .choose_min_by_key(&mut rng, |x| RoughNumber(*x))
                .unwrap();
            counts[min] += 1;
        }

        insta::assert_debug_snapshot!(counts);
        assert_only_chosen(&counts, 0..10);
    }

    /// Iterator wrapper that hides the inner size hint: it does not override
    /// `size_hint`, so the default `(0, None)` is reported.
    #[derive(Clone)]
    struct UnhintedIterator<I: Iterator + Clone>(I);

    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.0.next()
        }
    }

    /// Iterator wrapper that reports only a lower bound: the remaining length capped
    /// at the window size stored in field `1`, with no upper bound.
    #[derive(Clone)]
    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone>(I, usize);

    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.0.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (core::cmp::min(self.0.len(), self.1), None)
        }
    }

    /// Number whose hashing, equality and ordering only look at `value / 10`, so ten
    /// consecutive values form one group of "equal" elements.
    #[derive(Debug, Copy, Clone)]
    struct RoughNumber(pub usize);

    impl Hash for RoughNumber {
        fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
            (self.0 / 10).hash(state);
        }
    }

    impl Eq for RoughNumber {}

    impl PartialEq for RoughNumber {
        fn eq(&self, other: &Self) -> bool {
            (self.0 / 10) == (other.0 / 10)
        }
    }

    impl Ord for RoughNumber {
        fn cmp(&self, other: &Self) -> core::cmp::Ordering {
            (self.0 / 10).cmp(&(other.0 / 10))
        }
    }

    impl PartialOrd for RoughNumber {
        // Delegate to `Ord` so both orderings cannot drift apart.
        fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
            Some(self.cmp(other))
        }
    }

    /// Asserts that every count lies strictly between the tolerance bounds multiplied
    /// by `scale`.
    fn assert_counts_within(counts: &[usize], scale: usize) {
        for &x in counts {
            assert!(x > LOWER_TOLERANCE * scale);
            assert!(x < UPPER_TOLERANCE * scale);
        }
    }

    /// Asserts that indices inside `best` were chosen within the tolerance bounds
    /// (scaled by 10) and that every other index was never chosen.
    fn assert_only_chosen(counts: &[usize], best: Range<usize>) {
        for (i, &x) in counts.iter().enumerate() {
            if best.contains(&i) {
                assert!(x > LOWER_TOLERANCE * 10);
                assert!(x < UPPER_TOLERANCE * 10);
            } else {
                assert_eq!(x, 0);
            }
        }
    }

    /// Panics unless `n` lies inside `range`.
    fn assert_contains(range: Range<usize>, n: &usize) {
        assert!(
            range.contains(n),
            "The range {:?} does not contain {n}",
            range
        );
    }

    /// Creates the deterministic, call-counting RNG shared by all tests.
    fn get_rng() -> CountingRng<StdRng> {
        let inner = StdRng::seed_from_u64(123);
        CountingRng {
            rng: inner,
            count: 0,
        }
    }

    /// RNG wrapper that counts how many times the inner generator is invoked.
    struct CountingRng<Inner: Rng> {
        pub rng: Inner,
        pub count: usize,
    }

    impl<Inner: Rng> RngCore for CountingRng<Inner> {
        fn next_u32(&mut self) -> u32 {
            self.count += 1;
            self.rng.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.count += 1;
            self.rng.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.count += 1;
            self.rng.fill_bytes(dest)
        }
    }
}