use crate::bal::BalIndex;
use std::vec::Vec;
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BalWrites<T: PartialEq + Clone> {
pub writes: Vec<(BalIndex, T)>,
}
impl<T: PartialEq + Clone> BalWrites<T> {
pub fn new(mut writes: Vec<(BalIndex, T)>) -> Self {
writes.sort_by_key(|(index, _)| *index);
Self { writes }
}
#[inline(never)]
pub fn get_linear_search(&self, bal_index: BalIndex) -> Option<T> {
let mut last_item = None;
for (index, item) in self.writes.iter() {
if index >= &bal_index {
return last_item;
}
last_item = Some(item.clone());
}
last_item
}
pub fn get(&self, bal_index: BalIndex) -> Option<T> {
if self.writes.len() < 5 {
return self.get_linear_search(bal_index);
}
let i = match self
.writes
.binary_search_by_key(&bal_index, |(index, _)| *index)
{
Ok(i) => i,
Err(i) => i,
};
(i != 0).then(|| self.writes[i - 1].1.clone())
}
pub fn extend(&mut self, other: BalWrites<T>) {
self.writes.extend(other.writes);
}
pub fn is_empty(&self) -> bool {
self.writes.is_empty()
}
#[inline]
pub fn force_update(&mut self, index: BalIndex, value: T) {
if let Some(last) = self.writes.last_mut() {
if index == last.0 {
last.1 = value;
return;
}
}
self.writes.push((index, value));
}
pub fn update(&mut self, index: BalIndex, original_value: &T, value: T) {
self.update_with_key(index, original_value, value, |i| i);
}
#[inline]
pub fn update_with_key<K: PartialEq, F>(
&mut self,
index: BalIndex,
original_subvalue: &K,
value: T,
f: F,
) where
F: Fn(&T) -> &K,
{
if let Some(last) = self.writes.last_mut() {
if last.0 != index {
if f(&last.1) != f(&value) {
self.writes.push((index, value));
}
return;
}
}
let (previous, last) = match self.writes.as_mut_slice() {
[.., previous, last] => (f(&previous.1), last),
[last] => (original_subvalue, last),
[] => {
if original_subvalue != f(&value) {
self.writes.push((index, value));
}
return;
}
};
if previous == f(&value) {
self.writes.pop();
return;
}
last.1 = value;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `get` returns the value of the last write strictly before the queried index.
    #[test]
    fn test_get() {
        let bal_writes = BalWrites::new(vec![(0, 1), (1, 2), (2, 3)]);
        assert_eq!(bal_writes.get(0), None);
        assert_eq!(bal_writes.get(1), Some(1));
        assert_eq!(bal_writes.get(2), Some(2));
        assert_eq!(bal_writes.get(3), Some(3));
        assert_eq!(bal_writes.get(4), Some(3));
    }

    /// Checks `get` on `threshold` entries: indices `0..threshold - 1` mapping to
    /// `index + 1`, followed by a single entry at `threshold`, leaving a gap at
    /// `threshold - 1`. Lists of 5 or more entries take the binary-search path of
    /// `get`.
    fn get_binary_search(threshold: BalIndex) {
        let entries: Vec<_> = (0..threshold - 1)
            .map(|i| (i, i + 1))
            .chain(std::iter::once((threshold, threshold + 1)))
            .collect();
        let bal_writes = BalWrites::new(entries);
        assert_eq!(bal_writes.get(0), None);
        for i in 1..threshold - 1 {
            assert_eq!(bal_writes.get(i), Some(i));
        }
        // Querying inside the gap still sees the last entry before it.
        assert_eq!(bal_writes.get(threshold - 1), Some(threshold - 1));
        assert_eq!(bal_writes.get(threshold), Some(threshold - 1));
        assert_eq!(bal_writes.get(threshold + 1), Some(threshold + 1));
    }

    /// Exercises list sizes on both sides of the linear/binary search cutoff.
    #[test]
    fn test_get_binary_search() {
        get_binary_search(4);
        get_binary_search(5);
        get_binary_search(6);
        get_binary_search(7);
    }

    /// The binary-search path of `get` agrees with `get_linear_search` everywhere,
    /// including indices between, before and after the stored ones.
    #[test]
    fn test_get_matches_linear_search() {
        let bal_writes = BalWrites::new((0..20).map(|i| (i * 2, i)).collect());
        for probe in 0..45 {
            assert_eq!(bal_writes.get(probe), bal_writes.get_linear_search(probe));
        }
    }

    /// `force_update` overwrites a write at the same index and appends otherwise.
    #[test]
    fn test_force_update() {
        let mut bal_writes: BalWrites<i32> = BalWrites::default();
        bal_writes.force_update(1, 5);
        bal_writes.force_update(1, 6);
        bal_writes.force_update(3, 6);
        assert_eq!(bal_writes.writes, vec![(1, 6), (3, 6)]);
    }

    /// `update` skips no-op writes and drops a write that reverts to the value
    /// visible before it.
    #[test]
    fn test_update() {
        let mut bal_writes: BalWrites<i32> = BalWrites::default();
        // Same as the original value: nothing recorded.
        bal_writes.update(1, &10, 10);
        assert!(bal_writes.is_empty());
        bal_writes.update(1, &10, 11);
        assert_eq!(bal_writes.writes, vec![(1, 11)]);
        // Same index: overwritten in place.
        bal_writes.update(1, &10, 12);
        assert_eq!(bal_writes.writes, vec![(1, 12)]);
        // New index but unchanged value: nothing appended.
        bal_writes.update(2, &10, 12);
        assert_eq!(bal_writes.writes, vec![(1, 12)]);
        bal_writes.update(2, &10, 13);
        assert_eq!(bal_writes.writes, vec![(1, 12), (2, 13)]);
        // Reverting to the previous write's value removes the write.
        bal_writes.update(2, &10, 12);
        assert_eq!(bal_writes.writes, vec![(1, 12)]);
        // Reverting to the original value removes the only write.
        bal_writes.update(1, &10, 10);
        assert!(bal_writes.is_empty());
    }

    /// `update_with_key` compares only the projected key, not the whole value.
    #[test]
    fn test_update_with_key() {
        let mut bal_writes: BalWrites<(u8, u8)> = BalWrites::default();
        bal_writes.update_with_key(1, &0, (0, 5), |v| &v.0);
        assert!(bal_writes.is_empty());
        bal_writes.update_with_key(1, &0, (1, 5), |v| &v.0);
        assert_eq!(bal_writes.writes, vec![(1, (1, 5))]);
        bal_writes.update_with_key(2, &0, (1, 9), |v| &v.0);
        assert_eq!(bal_writes.writes, vec![(1, (1, 5))]);
        bal_writes.update_with_key(1, &0, (2, 9), |v| &v.0);
        assert_eq!(bal_writes.writes, vec![(1, (2, 9))]);
    }
}