para_graph/algorithms/rust/
binary_search.rs1use num::PrimInt;
2use rayon::prelude::*;
3use std::cmp::Ordering;
4use std::ops::Range;
5
6fn binary_search<T, F, Q>(eval: F, range: Range<T>, query: &Q) -> Option<T>
7where
8 T: PrimInt,
9 F: Fn(&T, &Q) -> Ordering,
10{
11 let mut low = range.start;
12 let mut high = range.end;
13 let mut result: Option<T> = None;
14 while low <= high {
15 let mid = low + (high - low) / T::from(2).unwrap();
16 match eval(&mid, query) {
17 Ordering::Less => low = mid + T::one(),
18 Ordering::Greater | Ordering::Equal => {
19 high = mid - T::one();
20 result = Some(mid);
21 }
22 }
23 }
24 result
25}
26
27pub fn binary_search_serial<T, F, Q>(eval: F, range: Range<T>, queries: &[Q]) -> Vec<Option<T>>
28where
29 T: PrimInt,
30 F: Copy + Fn(&T, &Q) -> Ordering,
31{
32 queries
33 .iter()
34 .map(|query| binary_search(eval, range.clone(), query))
35 .collect()
36}
37
38pub fn binary_search_par_cpu<T, F, Q>(eval: F, range: Range<T>, queries: &[Q]) -> Vec<Option<T>>
39where
40 T: PrimInt + Send + Sync,
41 F: Copy + Fn(&T, &Q) -> Ordering + Send + Sync,
42 Q: Sync,
43{
44 queries
45 .par_iter()
46 .map(|query| binary_search(eval, range.clone(), query))
47 .collect()
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use num::BigUint;
    use std::str::FromStr;

    /// Sorted fixture in which every queried value is present.
    const DENSE: [i32; 10] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
    /// Sorted fixture with gaps, so some queries land between elements.
    const SPARSE: [i32; 5] = [1, 3, 5, 7, 9];

    /// Computes `n!` as an arbitrary-precision integer.
    fn factorial(n: usize) -> BigUint {
        (1..=n).product::<BigUint>()
    }

    /// Queries for the factorial tests: 2, 6, 80! and 80! + 1.
    fn factorial_queries() -> Vec<BigUint> {
        [
            "2",
            "6",
            "71569457046263802294811533723186532165584657342365752577109445058227039255480148842668944867280814080000000000000000000",
            "71569457046263802294811533723186532165584657342365752577109445058227039255480148842668944867280814080000000000000000001",
        ]
        .iter()
        .map(|s| BigUint::from_str(s).unwrap())
        .collect_vec()
    }

    #[test]
    fn test_serial_search_find_equal() {
        let wanted = [3, 4, 7];
        let found =
            binary_search_serial(|mid, query| DENSE[*mid].cmp(query), 0..DENSE.len(), &wanted);
        assert_eq!(found, vec![Some(2), Some(3), Some(6)]);
    }

    #[test]
    fn test_serial_search_partial_greater_equal() {
        let wanted = [2, 3, 6, 7];
        let found =
            binary_search_serial(|mid, query| SPARSE[*mid].cmp(query), 0..SPARSE.len(), &wanted);
        assert_eq!(found, vec![Some(1), Some(1), Some(3), Some(3)]);
    }

    #[test]
    fn test_serial_search_factorial() {
        let wanted = factorial_queries();
        let found = binary_search_serial(
            |mid, query| factorial(*mid).cmp(query),
            0..80usize,
            &wanted,
        );
        // The upper bound 80 is itself probed, hence `Some(80)` for 80!.
        assert_eq!(found, vec![Some(2), Some(3), Some(80), None]);
    }

    #[test]
    fn test_par_search_find_equal() {
        let wanted = [3, 4, 7];
        let found =
            binary_search_par_cpu(|mid, query| DENSE[*mid].cmp(query), 0..DENSE.len(), &wanted);
        assert_eq!(found, vec![Some(2), Some(3), Some(6)]);
    }

    #[test]
    fn test_par_search_partial_greater_equal() {
        let wanted = [2, 3, 6, 7];
        let found =
            binary_search_par_cpu(|mid, query| SPARSE[*mid].cmp(query), 0..SPARSE.len(), &wanted);
        assert_eq!(found, vec![Some(1), Some(1), Some(3), Some(3)]);
    }

    #[test]
    fn test_par_search_factorial() {
        let wanted = factorial_queries();
        let found = binary_search_par_cpu(
            |mid, query| factorial(*mid).cmp(query),
            0..80usize,
            &wanted,
        );
        // The upper bound 80 is itself probed, hence `Some(80)` for 80!.
        assert_eq!(found, vec![Some(2), Some(3), Some(80), None]);
    }
}