use std::collections::VecDeque;
use super::c2;
use super::sorted_ref::SortedReferenceState;
/// A fixed-capacity sliding window of reference curves whose per-time-point
/// values are additionally kept in ascending order.
///
/// Every stored curve has exactly `n_points` samples. Once `capacity` curves
/// are held, pushing a new curve evicts the oldest one (first in, first out).
/// Keeping each time point's values sorted lets [`RollingReference::mbd_one`]
/// count how many reference values lie below / above a query value with two
/// binary searches per time point.
///
/// Invariant: `sorted_columns.len() == n_points`, and `sorted_columns[t]`
/// holds exactly the `t`-th sample of every curve in `curves` (so its length
/// equals `curves.len()`), sorted ascending under `f64`'s `<`. NaN samples are
/// rejected by [`RollingReference::push`] because they cannot be ordered or
/// found again by equality, which would break this invariant.
#[derive(Debug, Clone, PartialEq)]
pub struct RollingReference {
    /// Stored curves, oldest at the front.
    curves: VecDeque<Vec<f64>>,
    /// Maximum number of curves kept before the oldest is evicted.
    capacity: usize,
    /// Number of samples every curve must have.
    n_points: usize,
    /// `sorted_columns[t]` holds the `t`-th sample of every stored curve, ascending.
    sorted_columns: Vec<Vec<f64>>,
}

impl RollingReference {
    /// Creates an empty window that holds at most `capacity` curves of
    /// `n_points` samples each.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0.
    pub fn new(capacity: usize, n_points: usize) -> Self {
        assert!(capacity >= 1, "capacity must be at least 1");
        Self {
            curves: VecDeque::with_capacity(capacity),
            capacity,
            n_points,
            sorted_columns: (0..n_points)
                .map(|_| Vec::with_capacity(capacity))
                .collect(),
        }
    }

    /// Appends a copy of `curve` to the window.
    ///
    /// If the window was already full, the oldest curve is removed from the
    /// window (and each of its samples from the matching sorted column) and
    /// returned; otherwise `None` is returned.
    ///
    /// # Panics
    ///
    /// Panics if `curve.len() != n_points`, or if `curve` contains a NaN.
    /// Both checks run before any state is modified.
    pub fn push(&mut self, curve: &[f64]) -> Option<Vec<f64>> {
        assert_eq!(
            curve.len(),
            self.n_points,
            "curve length {} does not match n_points {}",
            curve.len(),
            self.n_points
        );
        // NaN compares false against everything: it would be mis-placed by the
        // binary search below and could never be located again on eviction,
        // leaving a stale entry that desynchronises the column from `curves`.
        assert!(
            !curve.iter().any(|v| v.is_nan()),
            "curve must not contain NaN values"
        );

        let evicted = if self.curves.len() == self.capacity {
            let old = self
                .curves
                .pop_front()
                .expect("capacity invariant: deque is non-empty");
            for (col, &old_val) in self.sorted_columns.iter_mut().zip(&old) {
                // The column is sorted and contains `old_val`, so the first
                // element that is not less than it must compare equal to it.
                let pos = col.partition_point(|&v| v < old_val);
                assert!(
                    col.get(pos) == Some(&old_val),
                    "sorted column invariant violated: evicted value not found"
                );
                col.remove(pos);
            }
            Some(old)
        } else {
            None
        };

        for (col, &val) in self.sorted_columns.iter_mut().zip(curve) {
            let pos = col.partition_point(|&v| v < val);
            col.insert(pos, val);
        }
        self.curves.push_back(curve.to_vec());
        evicted
    }

    /// Returns an owned copy of the sorted columns, together with the current
    /// number of stored curves (`nori`) and `n_points`.
    #[must_use = "expensive computation whose result should not be discarded"]
    pub fn snapshot(&self) -> SortedReferenceState {
        SortedReferenceState {
            sorted_columns: self.sorted_columns.clone(),
            nori: self.curves.len(),
            n_points: self.n_points,
        }
    }

    /// Computes the modified band depth of `curve` against the stored
    /// reference curves (bands formed by pairs of curves).
    ///
    /// For each time point `t`, every pair of reference curves is counted
    /// except pairs lying entirely strictly below `curve[t]` or entirely
    /// strictly above it. The total is divided by `c2(n) * n_points`, where `n`
    /// is the number of stored curves. (Assumes `c2(k)` is the pair count
    /// `k * (k - 1) / 2` — confirm against its definition.)
    ///
    /// Returns 0.0 when fewer than two curves are stored or `n_points` is 0.
    ///
    /// # Panics
    ///
    /// Panics if `curve.len() != n_points`, regardless of how many curves are
    /// currently stored.
    pub fn mbd_one(&self, curve: &[f64]) -> f64 {
        // Validate before the early return so a wrong-length query is reported
        // even while the reference set is still too small to score against.
        assert_eq!(
            curve.len(),
            self.n_points,
            "curve length {} does not match n_points {}",
            curve.len(),
            self.n_points
        );
        let n = self.curves.len();
        if n < 2 || self.n_points == 0 {
            return 0.0;
        }
        let cn2 = c2(n);
        let total: usize = self
            .sorted_columns
            .iter()
            .zip(curve)
            .map(|(col, &x)| {
                let below = col.partition_point(|&v| v < x);
                let above = n - col.partition_point(|&v| v <= x);
                cn2 - c2(below) - c2(above)
            })
            .sum();
        total as f64 / (cn2 as f64 * self.n_points as f64)
    }

    /// Number of curves currently stored.
    #[inline]
    pub fn len(&self) -> usize {
        self.curves.len()
    }

    /// Returns `true` when no curves are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.curves.is_empty()
    }

    /// Maximum number of curves kept before the oldest is evicted.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}