use std::cmp::Ordering;
use std::error::Error;
use std::fmt::Debug;
/// Selects which position [`binary_search_by_key_with_bounds`] reports relative to the
/// search key.
// Silences unused/dead-code lints; presumably there may be no non-test caller yet — confirm
// before removing.
#[allow(unused)]
#[derive(Clone, Copy, Debug)]
pub(crate) enum Bound {
    /// The first position whose key is greater than or equal to the search key.
    LeastUpper,
    /// The last position whose key is less than or equal to the search key.
    GreatestLower,
}
/// Failure modes of [`binary_search_by_key_with_bounds`].
///
/// `T` is the error type produced by the caller-supplied key function. The type itself
/// carries no trait bound; the `Display` and `Error` impls below add the bounds they need.
// Silences unused/dead-code lints; presumably there may be no non-test caller yet — confirm
// before removing.
#[allow(unused)]
#[derive(Debug)]
pub(crate) enum SearchError<T> {
    /// No element satisfies the requested [`Bound`] (always the case for an empty slice).
    OutOfRange,
    /// The key function failed on a probed element, and the search stopped there.
    KeyFunctionError(T),
}

impl<T: std::fmt::Display> std::fmt::Display for SearchError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SearchError::OutOfRange => {
                write!(f, "no value satisfies the requested search bound")
            }
            SearchError::KeyFunctionError(e) => write!(f, "key function failed: {}", e),
        }
    }
}

impl<T: Error + 'static> Error for SearchError<T> {
    /// Exposes the key function's error as the underlying cause so error chains stay intact.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            SearchError::OutOfRange => None,
            SearchError::KeyFunctionError(e) => Some(e),
        }
    }
}
/// Binary search over `values` for the position selected by `bound` relative to `key`.
///
/// `values` is assumed to be sorted in ascending order by the key that `key_fn` extracts
/// from each element. This is not checked, and the returned index is meaningless otherwise.
///
/// * [`Bound::LeastUpper`] returns the first index whose extracted key is `>= key`.
/// * [`Bound::GreatestLower`] returns the last index whose extracted key is `<= key`.
///
/// When several keys equal `key`, `LeastUpper` therefore yields the first of them and
/// `GreatestLower` the last. Because `key_fn` receives `&'a T`, the extracted key may borrow
/// from `values`. `key_fn` is called once per probed element, i.e. O(log n) times.
///
/// # Errors
///
/// * [`SearchError::OutOfRange`] when no index satisfies `bound` (always the case for an
///   empty slice).
/// * [`SearchError::KeyFunctionError`] carrying the first error returned by `key_fn`; the
///   search stops immediately.
// Silences unused/dead-code lints; presumably there may be no non-test caller yet — confirm
// before removing.
#[allow(unused)]
pub(crate) fn binary_search_by_key_with_bounds<'a, T, K: Ord + Debug, E: Error>(
    values: &'a [T],
    key: K,
    key_fn: impl Fn(&'a T) -> Result<K, E>,
    bound: Bound,
) -> Result<usize, SearchError<E>> {
    let mut left = 0;
    let mut right = values.len();
    // Invariant: every element before `left` moved the search right, and no element at or
    // after `right` did. Each step shrinks the half-open window `[left, right)`.
    while left < right {
        let probe = left + (right - left) / 2;
        debug_assert!(left <= probe && probe < right);
        let probe_key = match key_fn(&values[probe]) {
            Ok(found) => found,
            Err(err) => return Err(SearchError::KeyFunctionError(err)),
        };
        let ordering = key.cmp(&probe_key);
        // On an exact match, `LeastUpper` pulls `right` down onto it (converging on the first
        // match), while `GreatestLower` pushes `left` past it (converging one past the last).
        let move_right = match bound {
            Bound::LeastUpper => ordering == Ordering::Greater,
            Bound::GreatestLower => ordering != Ordering::Less,
        };
        if move_right {
            left = probe + 1;
        } else {
            right = probe;
        }
    }
    // The window is empty here (`left == right`): it points at the first element that did
    // not move the search right.
    let found = match bound {
        Bound::LeastUpper if right < values.len() => Some(right),
        Bound::LeastUpper => None,
        Bound::GreatestLower => left.checked_sub(1),
    };
    found.ok_or(SearchError::OutOfRange)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DeltaResult, Error};

    /// Key function that uses each element as its own key.
    fn get_val(x: &i32) -> DeltaResult<i32> {
        Ok(*x)
    }

    /// Key function whose key borrows from the element. This exercises the `'a` lifetime
    /// that ties `key_fn`'s argument to the searched slice.
    fn get_name(x: &(i32, String)) -> DeltaResult<&str> {
        Ok(x.1.as_str())
    }

    #[rstest::rstest]
    #[case::exact_least_upper(5, Bound::LeastUpper, 2)]
    #[case::exact_greatest_lower(5, Bound::GreatestLower, 2)]
    #[case::no_match_least_upper(4, Bound::LeastUpper, 2)]
    #[case::no_match_greatest_lower(6, Bound::GreatestLower, 2)]
    #[case::no_match_least_upper_above(6, Bound::LeastUpper, 3)]
    #[case::no_match_greatest_lower_below(4, Bound::GreatestLower, 1)]
    #[case::first_element_least_upper(1, Bound::LeastUpper, 0)]
    #[case::last_element_greatest_lower(9, Bound::GreatestLower, 4)]
    fn test_binary_search(
        #[case] search_key: i32,
        #[case] bound: Bound,
        #[case] expected_index: usize,
    ) {
        let values = [1, 3, 5, 7, 9];
        let result = binary_search_by_key_with_bounds(&values, search_key, get_val, bound).unwrap();
        assert_eq!(result, expected_index);
    }

    #[rstest::rstest]
    #[case::least_upper_first_occurrence(5, Bound::LeastUpper, 2)]
    #[case::greatest_lower_last_occurrence(5, Bound::GreatestLower, 4)]
    #[case::least_upper_before_run(4, Bound::LeastUpper, 2)]
    #[case::greatest_lower_after_run(6, Bound::GreatestLower, 4)]
    fn test_duplicate_values(
        #[case] search_key: i32,
        #[case] bound: Bound,
        #[case] expected_index: usize,
    ) {
        let values = [1, 3, 5, 5, 5, 7, 9];
        let result = binary_search_by_key_with_bounds(&values, search_key, get_val, bound).unwrap();
        assert_eq!(result, expected_index);
    }

    #[test]
    fn test_edge_cases() {
        // An empty slice has no candidate for either bound.
        let empty: [i32; 0] = [];
        for bound in [Bound::LeastUpper, Bound::GreatestLower] {
            let result = binary_search_by_key_with_bounds(&empty, 5, get_val, bound);
            assert!(matches!(result, Err(SearchError::OutOfRange)));
        }

        let values = [5, 7, 9];
        let result =
            binary_search_by_key_with_bounds(&values, 3, get_val, Bound::LeastUpper).unwrap();
        assert_eq!(result, 0);
        let result = binary_search_by_key_with_bounds(&values, 3, get_val, Bound::GreatestLower);
        assert!(matches!(result, Err(SearchError::OutOfRange)));
        let result = binary_search_by_key_with_bounds(&values, 10, get_val, Bound::LeastUpper);
        assert!(matches!(result, Err(SearchError::OutOfRange)));
        let result =
            binary_search_by_key_with_bounds(&values, 10, get_val, Bound::GreatestLower).unwrap();
        assert_eq!(result, 2);
    }

    #[test]
    fn test_single_element() {
        let values = [5];
        let result =
            binary_search_by_key_with_bounds(&values, 5, get_val, Bound::LeastUpper).unwrap();
        assert_eq!(result, 0);
        let result =
            binary_search_by_key_with_bounds(&values, 5, get_val, Bound::GreatestLower).unwrap();
        assert_eq!(result, 0);
        let result =
            binary_search_by_key_with_bounds(&values, 4, get_val, Bound::LeastUpper).unwrap();
        assert_eq!(result, 0);
        let result = binary_search_by_key_with_bounds(&values, 4, get_val, Bound::GreatestLower);
        assert!(matches!(result, Err(SearchError::OutOfRange)));
        let result = binary_search_by_key_with_bounds(&values, 6, get_val, Bound::LeastUpper);
        assert!(matches!(result, Err(SearchError::OutOfRange)));
        let result =
            binary_search_by_key_with_bounds(&values, 6, get_val, Bound::GreatestLower).unwrap();
        assert_eq!(result, 0);
    }

    #[test]
    fn test_borrowed_keys() {
        let values = [
            (1, String::from("apple")),
            (2, String::from("banana")),
            (3, String::from("cherry")),
        ];
        let result =
            binary_search_by_key_with_bounds(&values, "banana", get_name, Bound::LeastUpper)
                .unwrap();
        assert_eq!(result, 1);
        let result =
            binary_search_by_key_with_bounds(&values, "blueberry", get_name, Bound::GreatestLower)
                .unwrap();
        assert_eq!(result, 1);
    }

    #[test]
    fn test_error_propagation() {
        let values = [1, 3, 5, 7, 9];
        let failing_key_fn = |x: &i32| -> DeltaResult<i32> {
            if *x == 5 {
                Err(Error::generic("Error extracting key"))
            } else {
                Ok(*x)
            }
        };
        let result =
            binary_search_by_key_with_bounds(&values, 7, failing_key_fn, Bound::LeastUpper);
        assert!(matches!(
            result,
            Err(SearchError::KeyFunctionError(crate::Error::Generic(msg))) if msg.contains("Error extracting key")
        ));
    }
}