try_partialord/
binary_search.rs1use crate::{InvalidOrderError, OrderResult};
2use core::cmp::Ordering;
3
4pub trait TryBinarySearch<T> {
8 #[inline]
10 fn try_binary_search(&self, x: &T) -> OrderResult<Result<usize, usize>>
11 where
12 T: PartialOrd<T>,
13 {
14 self.try_binary_search_by(|a| a.partial_cmp(x))
15 }
16 fn try_binary_search_by<F>(&self, compare: F) -> OrderResult<Result<usize, usize>>
18 where
19 F: FnMut(&T) -> Option<Ordering>;
20 #[inline]
21 fn try_binary_search_by_key<K, F>(&self, b: &K, f: F) -> OrderResult<Result<usize, usize>>
23 where
24 F: FnMut(&T) -> Option<K>,
25 K: PartialOrd<K>,
26 {
27 let mut fk = f;
28 self.try_binary_search_by(|a| fk(a)?.partial_cmp(b))
29 }
30}
31
32impl<T> TryBinarySearch<T> for [T] {
33 #[inline]
34 fn try_binary_search_by<F>(&self, compare: F) -> OrderResult<Result<usize, usize>>
35 where
36 F: FnMut(&T) -> Option<Ordering>,
37 {
38 try_binary_search_by_inner(self, compare).ok_or(InvalidOrderError)
39 }
40}
41
/// Core binary-search loop used by the `[T]` implementation.
///
/// `compare` reports how a probed element orders relative to the target
/// (`Less` means the element sorts before the target). Returns `None` as soon
/// as `compare` yields `None` for a probed element. Otherwise returns
/// `Some(Ok(i))` for a matching index, or `Some(Err(i))` with the insertion
/// point that keeps the slice sorted.
fn try_binary_search_by_inner<T, F>(slice: &[T], mut compare: F) -> Option<Result<usize, usize>>
where
    F: FnMut(&T) -> Option<Ordering>,
{
    let mut size = slice.len();
    let mut left = 0;
    let mut right = size;
    // Invariant (assuming the slice is sorted consistently with `compare`):
    // elements in `slice[..left]` order before the target and elements in
    // `slice[right..]` order after it. `size` always equals `right - left`.
    while size > 0 {
        let mid = left + size / 2;

        // SAFETY: `size == right - left > 0` and `mid = left + size / 2`, so
        // `left <= mid < right`. `right` starts at `slice.len()` and only
        // ever decreases, hence `mid < slice.len()` and the index is in bounds.
        let cmp = compare(unsafe { slice.get_unchecked(mid) })?;

        match cmp {
            Ordering::Less => left = mid + 1,
            Ordering::Greater => right = mid,
            Ordering::Equal => return Some(Ok(mid)),
        }

        size = right - left;
    }
    Some(Err(left))
}
75
#[cfg(test)]
#[cfg(feature = "std")]
mod tests {
    use crate::*;
    use rand::distributions::Standard;
    use rand::prelude::*;
    use std::vec::Vec;

    /// On random sorted data, the returned index must split the slice into
    /// elements below `b` and elements at or above it.
    #[test]
    fn try_binary_search_ok() {
        let rng = thread_rng();
        let mut v: Vec<f32> = Standard.sample_iter(rng).take(100).collect();
        assert!(v.try_sort().is_ok());
        let b = random();
        let i = v.try_binary_search(&b);
        assert!(i.is_ok());
        let ik = i.unwrap().unwrap_or_else(|e| e);
        for sm in v[..ik].iter() {
            assert!(sm < &b);
        }
        for sm in v[ik..].iter() {
            assert!(sm >= &b);
        }
    }

    /// An incomparable element hit by the search must surface as an error
    /// instead of an arbitrary index.
    #[test]
    fn try_binary_search_nan_is_err() {
        // The first probe of a 3-element slice is index 1, which is NaN.
        let v = [1.0f32, f32::NAN, 3.0];
        assert!(v.try_binary_search(&2.0).is_err());
    }

    /// `try_binary_search_by_key` finds exact matches, reports insertion
    /// points, and treats a missing key as an ordering error.
    #[test]
    fn try_binary_search_by_key_cases() {
        let v = [(1u8, 0.5f32), (2, 1.5), (3, 2.5)];
        assert_eq!(
            v.try_binary_search_by_key(&1.5, |&(_, k)| Some(k)).ok(),
            Some(Ok(1))
        );
        assert_eq!(
            v.try_binary_search_by_key(&2.0, |&(_, k)| Some(k)).ok(),
            Some(Err(2))
        );
        assert!(v
            .try_binary_search_by_key(&2.0, |_| None::<f32>)
            .is_err());
    }
}