para_graph/algorithms/rust/
binary_search.rs

1use num::PrimInt;
2use rayon::prelude::*;
3use std::cmp::Ordering;
4use std::ops::Range;
5
6fn binary_search<T, F, Q>(eval: F, range: Range<T>, query: &Q) -> Option<T>
7where
8    T: PrimInt,
9    F: Fn(&T, &Q) -> Ordering,
10{
11    let mut low = range.start;
12    let mut high = range.end;
13    let mut result: Option<T> = None;
14    while low <= high {
15        let mid = low + (high - low) / T::from(2).unwrap();
16        match eval(&mid, query) {
17            Ordering::Less => low = mid + T::one(),
18            Ordering::Greater | Ordering::Equal => {
19                high = mid - T::one();
20                result = Some(mid);
21            }
22        }
23    }
24    result
25}
26
27pub fn binary_search_serial<T, F, Q>(eval: F, range: Range<T>, queries: &[Q]) -> Vec<Option<T>>
28where
29    T: PrimInt,
30    F: Copy + Fn(&T, &Q) -> Ordering,
31{
32    queries
33        .iter()
34        .map(|query| binary_search(eval, range.clone(), query))
35        .collect()
36}
37
38pub fn binary_search_par_cpu<T, F, Q>(eval: F, range: Range<T>, queries: &[Q]) -> Vec<Option<T>>
39where
40    T: PrimInt + Send + Sync,
41    F: Copy + Fn(&T, &Q) -> Ordering + Send + Sync,
42    Q: Sync,
43{
44    queries
45        .par_iter()
46        .map(|query| binary_search(eval, range.clone(), query))
47        .collect()
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use num::BigUint;
    use std::str::FromStr;

    /// Computes `n!` as an arbitrary-precision integer; `0!` is the empty product.
    fn factorial(n: usize) -> BigUint {
        (1..=n).product::<BigUint>()
    }

    /// Queries shared by the factorial tests. The assertions below expect the third
    /// value to be matched at `n = 80` and the fourth (one larger) to be out of reach
    /// within `0..80`.
    fn factorial_queries() -> Vec<BigUint> {
        [
            "2",
            "6",
            "71569457046263802294811533723186532165584657342365752577109445058227039255480148842668944867280814080000000000000000000",
            "71569457046263802294811533723186532165584657342365752577109445058227039255480148842668944867280814080000000000000000001",
        ]
        .iter()
        .map(|text| BigUint::from_str(text).unwrap())
        .collect_vec()
    }

    // ---- serial -------------------------------------------------------------------

    #[test]
    fn test_serial_search_find_equal() {
        let sorted = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let wanted = [3, 4, 7];
        let expected = vec![Some(2), Some(3), Some(6)];
        let found = binary_search_serial(
            |idx, target| sorted[*idx].cmp(target),
            0..sorted.len(),
            &wanted,
        );
        assert_eq!(found, expected);
    }

    #[test]
    fn test_serial_search_partial_greater_equal() {
        let sorted = [1, 3, 5, 7, 9];
        let wanted = [2, 3, 6, 7];
        // Values missing from the data resolve to the first larger element.
        let expected = vec![Some(1), Some(1), Some(3), Some(3)];
        let found = binary_search_serial(
            |idx, target| sorted[*idx].cmp(target),
            0..sorted.len(),
            &wanted,
        );
        assert_eq!(found, expected);
    }

    #[test]
    fn test_serial_search_factorial() {
        let wanted = factorial_queries();
        let expected = vec![Some(2), Some(3), Some(80), None];
        let found = binary_search_serial(
            |n, target| factorial(*n).cmp(target),
            0..80usize,
            &wanted,
        );
        assert_eq!(found, expected);
    }

    // ---- parallel -----------------------------------------------------------------

    #[test]
    fn test_par_search_find_equal() {
        let sorted = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let wanted = [3, 4, 7];
        let expected = vec![Some(2), Some(3), Some(6)];
        let found = binary_search_par_cpu(
            |idx, target| sorted[*idx].cmp(target),
            0..sorted.len(),
            &wanted,
        );
        assert_eq!(found, expected);
    }

    #[test]
    fn test_par_search_partial_greater_equal() {
        let sorted = [1, 3, 5, 7, 9];
        let wanted = [2, 3, 6, 7];
        // Values missing from the data resolve to the first larger element.
        let expected = vec![Some(1), Some(1), Some(3), Some(3)];
        let found = binary_search_par_cpu(
            |idx, target| sorted[*idx].cmp(target),
            0..sorted.len(),
            &wanted,
        );
        assert_eq!(found, expected);
    }

    #[test]
    fn test_par_search_factorial() {
        let wanted = factorial_queries();
        let expected = vec![Some(2), Some(3), Some(80), None];
        let found = binary_search_par_cpu(
            |n, target| factorial(*n).cmp(target),
            0..80usize,
            &wanted,
        );
        assert_eq!(found, expected);
    }
}