use std::collections::btree_map::Range;
use std::collections::BTreeMap;
use std::iter::Chain;
use std::ops::Bound::Excluded;
use std::ops::Bound::Included;
use std::ops::Bound::Unbounded;
/// An ordered map with "nearest key" lookups for `u64` keys.
///
/// A thin wrapper around [`BTreeMap`]. The generic methods mirror the
/// corresponding map operations. The `find_best_*` queries are available when
/// the key type is `u64`; they treat the key space as a ring that wraps from
/// `u64::MAX` back to `0`.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct STree<K, V>(BTreeMap<K, V>);

impl<K: Ord, V> Default for STree<K, V> {
    /// Same as [`STree::new`]: an empty tree.
    fn default() -> Self {
        Self::new()
    }
}

impl<K: Ord, V> STree<K, V> {
    /// Creates an empty tree.
    pub fn new() -> Self {
        Self(BTreeMap::new())
    }

    /// Returns a reference to the value stored under `key`, if present.
    ///
    /// The key is taken by value (unlike [`BTreeMap::get`]) to keep the
    /// existing signature.
    pub fn get(&self, key: K) -> Option<&V> {
        self.0.get(&key)
    }

    /// Inserts `value` under `key`. Returns the previous value if the key was
    /// already present.
    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
        self.0.insert(key, value)
    }

    /// Removes `key`. Returns its value if it was present.
    pub fn remove(&mut self, key: K) -> Option<V> {
        self.0.remove(&key)
    }
}

impl<V> STree<u64, V> {
    /// Returns the stored entry whose key is nearest to `target`, considering
    /// only keys inside the search window selected by `common_bits`.
    ///
    /// The window is every key whose ring distance to `target` (distance
    /// measured with wrap-around at `u64::MAX`) is at most
    /// `2^(63 - common_bits)`:
    /// - `common_bits == 0` searches every key.
    /// - `common_bits >= 64` matches `target` exactly.
    ///
    /// Closeness is also measured around the ring. On a tie, the key below
    /// `target` (that is, `target - d`, wrapping) wins. Returns `None` if no
    /// key falls inside the window.
    pub fn find_best_single(&self, target: u64, common_bits: u32) -> Option<(u64, &V)> {
        // `min_by_key` returns the first minimum, and `find` yields keys in
        // ring order starting below `target`, so ties resolve to the lower side.
        self.find(target, common_bits)
            .min_by_key(|&(&key, _)| ring_distance(key, target))
            .map(|(&key, value)| (key, value))
    }

    /// Returns every entry inside the search window (see
    /// [`STree::find_best_single`]), ordered from nearest to farthest ring
    /// distance from `target`.
    ///
    /// Entries at equal distance keep their window order, so the key below
    /// `target` comes first.
    pub fn find_best_sorted(&self, target: u64, common_bits: u32) -> Vec<(u64, &V)> {
        let mut candidates: Vec<(u64, &V)> = self
            .find(target, common_bits)
            .map(|(&key, value)| (key, value))
            .collect();
        // Stable sort: preserves window order for equal distances.
        candidates.sort_by_key(|&(key, _)| ring_distance(key, target));
        candidates
    }

    /// Iterates over the entries inside the search window around `target`, in
    /// ring order starting from the low edge of the window.
    fn find(&self, target: u64, common_bits: u32) -> Chain<Range<u64, V>, Range<u64, V>> {
        if common_bits == 0 {
            // The window covers the whole ring. Start at the point diametrically
            // opposite `target` and walk all the way around.
            let start = target ^ (1u64 << 63);
            return self
                .0
                .range((Included(start), Unbounded))
                .chain(self.0.range((Unbounded, Excluded(start))));
        }
        // Half-width is 2^(64 - bits) / 2 == 2^(63 - bits). Shifting avoids the
        // `2^64` overflow. More than 63 bits collapses to an exact-match window.
        let deviation = 63u32
            .checked_sub(common_bits)
            .map_or(0, |shift| 1u64 << shift);
        let start = target.wrapping_sub(deviation);
        let end = target.wrapping_add(deviation);
        if start > end {
            // The window wraps past u64::MAX back to 0.
            self.0
                .range((Included(start), Unbounded))
                .chain(self.0.range((Unbounded, Included(end))))
        } else {
            // Split at `target` only so both branches share one iterator type.
            self.0
                .range((Included(start), Excluded(target)))
                .chain(self.0.range((Included(target), Included(end))))
        }
    }
}

/// Shortest distance between `a` and `b` on the `u64` ring (wrapping at
/// `u64::MAX`). The result is at most `2^63`.
fn ring_distance(a: u64, b: u64) -> u64 {
    let forward = a.wrapping_sub(b);
    forward.min(forward.wrapping_neg())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts four keys: two small ones, one just below `u64::MAX`, and one
    /// mid-range. Then checks point lookups and both nearest-key queries at
    /// several window sizes, including windows that wrap past the top of the
    /// key space.
    #[test]
    fn basic_test() {
        let ten = 10u64;
        let fifteen = 15u64;
        let near_max = u64::MAX - 5;
        let nine_nine_nine = 999u64;

        let mut tree = STree::<u64, ()>::new();
        let stored = [ten, fifteen, near_max, nine_nine_nine];
        for &key in stored.iter() {
            assert!(tree.insert(key, ()).is_none());
        }
        for &key in stored.iter() {
            assert!(tree.get(key).is_some());
        }

        // (target, common_bits, expected nearest key)
        let single_cases: [(u64, u32, Option<u64>); 9] = [
            (ten, 64 - 1, Some(ten)),
            (fifteen, 64 - 1, Some(fifteen)),
            (u64::MAX - 999, 64 - 1, None),
            (u64::MAX - 999, 64 - 11, Some(near_max)),
            (2000, 64 - 1, None),
            (2000, 64 - 11, Some(nine_nine_nine)),
            (12, 64 - 2, Some(ten)),
            (0, 64 - 3, None),
            (0, 64 - 4, Some(near_max)),
        ];
        for &(target, bits, expected) in single_cases.iter() {
            assert_eq!(
                tree.find_best_single(target, bits),
                expected.map(|key| (key, &()))
            );
        }

        assert_eq!(tree.find_best_sorted(ten, 64 - 1).len(), 1);

        let four_bit_window = tree.find_best_sorted(ten, 64 - 4);
        assert_eq!(four_bit_window.len(), 2);
        assert_eq!(four_bit_window, vec![(ten, &()), (fifteen, &())]);

        let ten_bit_window = tree.find_best_sorted(ten, 64 - 10);
        assert_eq!(ten_bit_window.len(), 3);
        assert_eq!(
            ten_bit_window,
            vec![(ten, &()), (fifteen, &()), (near_max, &())]
        );
    }
}