#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
use itertools::izip;
use ndarray::{Array2, ArrayView1, s};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::{RcfError, Result};
/// Returns the position, inside a shingled point of length `dim`, of the first
/// coordinate of the input that lies `look_ahead` steps before the newest one.
///
/// A shingled point is a run of consecutive inputs of `input_dim` values each,
/// oldest first, so `look_ahead == 0` yields `dim - input_dim` (the newest input)
/// and every extra step moves one input further back.
///
/// # Panics
///
/// Panics if `input_dim * (1 + look_ahead)` overflows or exceeds `dim`, i.e. when
/// `look_ahead` points past the oldest input held in the shingle.
fn lookahead_offset(dim: usize, input_dim: usize, look_ahead: usize) -> usize {
    // Checked arithmetic: an unchecked subtraction wraps in release builds and
    // yields an enormous offset that callers would then add small indices to.
    look_ahead
        .checked_add(1)
        .and_then(|steps| input_dim.checked_mul(steps))
        .and_then(|span| dim.checked_sub(span))
        .unwrap_or_else(|| {
            panic!(
                "look_ahead {look_ahead} out of range for dim {dim} with input_dim {input_dim}"
            )
        })
}
/// Sum of absolute coordinate differences (L1 distance) between `query` and
/// `stored`, accumulated in `f64`.
///
/// Each difference is taken in `f32` and then widened. The slices are compared
/// pairwise up to the shorter length; any extra trailing values are ignored.
fn l1_distance_slices(query: &[f32], stored: &[f32]) -> f64 {
    let per_coordinate = query
        .iter()
        .zip(stored.iter())
        .map(|(&q, &s)| f64::from((q - s).abs()));
    per_coordinate.sum()
}
/// L1 distance between `query` and `stored`, counting only coordinates whose
/// `missing` flag is `false`; flagged coordinates contribute nothing.
///
/// The three slices are walked in lockstep and the walk stops at the end of
/// the shortest one.
fn l1_distance_slices_ignore_missing(query: &[f32], stored: &[f32], missing: &[bool]) -> f64 {
    izip!(query, stored, missing)
        .filter_map(|(q, s, &skip)| {
            if skip {
                None
            } else {
                Some(f64::from((q - s).abs()))
            }
        })
        .sum()
}
/// Capacity the point store grows to once every slot is in use: `2 * capacity + 4`.
///
/// The additive term means even a zero-capacity store gains room.
///
/// # Errors
///
/// Returns [`RcfError::Overflow`] when the grown capacity does not fit in `usize`.
fn checked_grown_capacity(capacity: usize) -> Result<usize> {
    let doubled = capacity.checked_mul(2);
    match doubled.and_then(|d| d.checked_add(4)) {
        Some(grown) => Ok(grown),
        None => Err(RcfError::Overflow(
            "point store capacity growth overflows usize".into(),
        )),
    }
}
/// Dense matrix backing the store: row `i` holds the point stored in slot `i`.
type PointMatrix = Array2<f32>;
/// Slot-based, reference-counted storage for (possibly shingled) points.
///
/// Each slot is one row of `store`. A new point takes the most recently freed
/// slot from `free_list` first, then the never-used slot `next_free`. The matrix
/// is grown when neither is available. `dec_ref` releases a slot back to
/// `free_list` when it leaves the slot with no references.
///
/// `shingle_buf` holds the rolling window of recent raw inputs maintained by
/// `advance_shingle`, oldest input first.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub(super) struct PointStore {
    /// Length of every stored point: `input_dim * shingle_size`.
    dim: usize,
    /// Length of one raw input pushed through `advance_shingle`.
    input_dim: usize,
    /// Number of raw inputs that make up one stored point.
    // Never read in this file (only set in `new`), hence the dead-code allowance.
    // Presumably kept as a record of the configuration; it still appears in
    // Debug/serde output.
    #[allow(dead_code)]
    shingle_size: usize,
    /// Selects, in the test-only `shingled_point`, whether the input is a raw value
    /// to shingle (`true`) or an already-full point (`false`).
    // Only read by that test-only helper, so it is dead code in non-test builds.
    #[allow(dead_code)]
    internal_shingling: bool,
    /// Point matrix; only rows `0..next_free` have ever been handed out.
    store: PointMatrix,
    /// `occupied[i]` is `true` while slot `i` holds a live point.
    occupied: Vec<bool>,
    /// Per-slot reference count; reset to 0 whenever a point is written.
    ref_count: Vec<usize>,
    /// Lowest slot index that has never been handed out.
    next_free: usize,
    /// Released slots awaiting reuse; popped from the end (LIFO).
    free_list: Vec<usize>,
    /// Number of currently occupied slots.
    size: usize,
    /// Number of rows allocated in `store` (normally also the length of
    /// `occupied` and `ref_count`).
    capacity: usize,
    /// Rolling shingle, `dim` values long, oldest input first.
    shingle_buf: Vec<f32>,
    /// Number of adds recorded, including those counted without storing a point.
    entries_seen: u64,
}
impl PointStore {
pub(super) fn new(
input_dim: usize,
shingle_size: usize,
capacity: usize,
internal_shingling: bool,
) -> Self {
let dim = input_dim * shingle_size;
PointStore {
dim,
input_dim,
shingle_size,
internal_shingling,
store: Array2::zeros((capacity, dim)),
occupied: vec![false; capacity],
ref_count: vec![0usize; capacity],
next_free: 0,
free_list: Vec::new(),
size: 0,
capacity,
shingle_buf: vec![0.0f32; dim],
entries_seen: 0,
}
}
#[cfg(test)]
fn shingled_point(&mut self, base: &[f32]) -> Result<Vec<f32>> {
if self.internal_shingling {
self.advance_shingle(base)?;
Ok(self.shingle_buf.clone())
} else {
self.validate_full_point(base)?;
Ok(base.to_vec())
}
}
pub(super) fn advance_shingle(&mut self, base: &[f32]) -> Result<()> {
if base.len() != self.input_dim {
return Err(RcfError::DimensionMismatch {
expected: self.input_dim,
got: base.len(),
});
}
self.shingle_buf.copy_within(self.input_dim.., 0);
let start = self.dim - self.input_dim;
self.shingle_buf[start..].copy_from_slice(base);
Ok(())
}
pub(super) fn current_shingled(&self) -> &[f32] {
&self.shingle_buf
}
#[cfg(test)]
pub(super) fn add(&mut self, point: &[f32]) -> Result<usize> {
self.validate_full_point(point)?;
self.add_validated(point)
}
pub(super) fn add_validated(&mut self, point: &[f32]) -> Result<usize> {
debug_assert_eq!(point.len(), self.dim);
self.store_point(point)
}
pub(super) fn add_current_shingled(&mut self) -> Result<usize> {
let idx = self.allocate_slot()?;
self.store
.row_mut(idx)
.assign(&ArrayView1::from(&self.shingle_buf[..]));
self.finish_add(idx);
Ok(idx)
}
pub(super) fn validate_full_point(&self, point: &[f32]) -> Result<()> {
if point.len() != self.dim {
return Err(RcfError::DimensionMismatch {
expected: self.dim,
got: point.len(),
});
}
Ok(())
}
pub(super) fn ensure_can_allocate_slot(&self) -> Result<()> {
if !self.free_list.is_empty() || self.next_free < self.capacity {
return Ok(());
}
checked_grown_capacity(self.capacity).map(|_| ())
}
#[cfg(test)]
pub(super) fn force_next_allocation_to_overflow(&mut self) {
self.free_list.clear();
self.next_free = usize::MAX;
self.capacity = usize::MAX;
}
fn store_point(&mut self, point: &[f32]) -> Result<usize> {
let idx = self.allocate_slot()?;
self.store.row_mut(idx).assign(&ArrayView1::from(point));
self.finish_add(idx);
Ok(idx)
}
fn finish_add(&mut self, idx: usize) {
self.occupied[idx] = true;
self.ref_count[idx] = 0;
self.size += 1;
self.entries_seen += 1;
}
pub(super) fn record_logical_add_without_storage(&mut self) {
self.entries_seen += 1;
}
fn allocate_slot(&mut self) -> Result<usize> {
if let Some(idx) = self.free_list.pop() {
return Ok(idx);
}
if self.next_free < self.capacity {
let idx = self.next_free;
self.next_free += 1;
return Ok(idx);
}
let new_cap = checked_grown_capacity(self.capacity)?;
let mut new_store = Array2::zeros((new_cap, self.dim));
new_store
.slice_mut(s![..self.capacity, ..])
.assign(&self.store);
self.store = new_store;
self.occupied.resize(new_cap, false);
self.ref_count.resize(new_cap, 0);
let idx = self.next_free;
self.next_free += 1;
self.capacity = new_cap;
Ok(idx)
}
pub(super) fn inc_ref(&mut self, idx: usize) {
self.ref_count[idx] += 1;
}
pub(super) fn dec_ref(&mut self, idx: usize) {
if self.ref_count[idx] > 0 {
self.ref_count[idx] -= 1;
}
if self.ref_count[idx] == 0 && self.occupied[idx] {
self.occupied[idx] = false;
self.size -= 1;
self.free_list.push(idx);
}
}
pub(super) fn get(&self, idx: usize) -> &[f32] {
debug_assert!(self.occupied[idx], "accessing unoccupied slot {idx}");
self.store
.row(idx)
.to_slice()
.expect("store must be contiguous")
}
pub(super) fn is_equal(&self, point: &[f32], idx: usize) -> bool {
self.store.row(idx) == ArrayView1::from(point)
}
pub(super) fn l1_distance(&self, query: &[f32], idx: usize) -> f64 {
let stored = self.get(idx);
l1_distance_slices(query, stored)
}
pub(super) fn l1_distance_ignore_missing(
&self,
query: &[f32],
idx: usize,
missing: &[bool],
) -> f64 {
let stored = self.get(idx);
l1_distance_slices_ignore_missing(query, stored, missing)
}
pub(super) fn copy_point(&self, idx: usize) -> Vec<f32> {
self.get(idx).to_vec()
}
#[cfg(test)]
fn next_indices(&self, look_ahead: usize) -> Vec<usize> {
let offset = lookahead_offset(self.dim, self.input_dim, look_ahead);
(0..self.input_dim).map(|i| offset + i).collect()
}
pub(super) fn next_indices_into(&self, look_ahead: usize, indices: &mut Vec<usize>) {
indices.clear();
let offset = lookahead_offset(self.dim, self.input_dim, look_ahead);
indices.extend((0..self.input_dim).map(|i| offset + i));
}
#[cfg(test)]
fn missing_indices_with_lookahead(
&self,
look_ahead: usize,
missing_base: &[usize],
) -> Vec<usize> {
let offset = lookahead_offset(self.dim, self.input_dim, look_ahead);
missing_base.iter().map(|&i| offset + i).collect()
}
#[cfg(test)]
pub(super) fn num_points(&self) -> usize {
self.size
}
#[cfg(test)]
pub(super) fn ref_count(&self, idx: usize) -> usize {
self.ref_count[idx]
}
#[cfg(test)]
fn dim(&self) -> usize {
self.dim
}
#[cfg(test)]
fn input_dim(&self) -> usize {
self.input_dim
}
#[cfg(test)]
pub(super) fn entries_seen(&self) -> u64 {
self.entries_seen
}
}
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use rstest::rstest;
    use super::*;

    #[test]
    fn add_and_get() {
        let mut ps = PointStore::new(2, 1, 8, false);
        let idx = ps.add(&[1.0, 2.0]).unwrap();
        assert_eq!(ps.get(idx), &[1.0f32, 2.0]);
    }

    #[test]
    fn add_validated_stores_full_point() {
        let mut ps = PointStore::new(2, 2, 8, false);
        let idx = ps.add_validated(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        assert_eq!(idx, 0);
        assert_eq!(ps.get(idx), &[1.0f32, 2.0, 3.0, 4.0]);
        assert_eq!(ps.num_points(), 1);
        assert_eq!(ps.entries_seen(), 1);
    }

    #[test]
    fn add_rejects_wrong_dimension_without_side_effects() {
        let mut ps = PointStore::new(2, 2, 8, false);
        let err = ps.add(&[1.0, 2.0]).unwrap_err();
        assert!(
            matches!(err, RcfError::DimensionMismatch { expected: 4, got: 2 }),
            "unexpected error variant: {err:?}"
        );
        assert_eq!(ps.num_points(), 0);
        assert_eq!(ps.entries_seen(), 0);
    }

    #[rstest]
    #[case::small_capacity(8, 20)]
    #[case::zero_capacity(0, 4)]
    fn checked_grown_capacity_matches_growth_formula(
        #[case] capacity: usize,
        #[case] expected: usize,
    ) {
        assert_eq!(checked_grown_capacity(capacity).unwrap(), expected);
    }

    #[test]
    fn checked_grown_capacity_rejects_overflow() {
        let err = checked_grown_capacity(usize::MAX).unwrap_err();
        assert!(
            matches!(err, RcfError::Overflow(ref msg) if msg.contains("capacity growth")),
            "unexpected error variant: {err:?}"
        );
    }

    #[test]
    fn ensure_can_allocate_slot_rejects_growth_overflow() {
        let mut ps = PointStore::new(2, 1, 1, false);
        ps.force_next_allocation_to_overflow();
        let err = ps.ensure_can_allocate_slot().unwrap_err();
        assert!(
            matches!(err, RcfError::Overflow(ref msg) if msg.contains("capacity growth")),
            "unexpected error variant: {err:?}"
        );
    }

    // Exercises the growth path of slot allocation, which the other tests never
    // reach because they stay well under the initial capacity.
    #[rstest]
    #[case::zero_capacity(0)]
    #[case::single_slot(1)]
    fn growth_preserves_existing_points(#[case] capacity: usize) {
        let mut ps = PointStore::new(2, 1, capacity, false);
        let points = [[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let indices: Vec<usize> = points.iter().map(|p| ps.add(p).unwrap()).collect();
        assert_eq!(indices, vec![0, 1, 2]);
        for (idx, p) in indices.iter().zip(&points) {
            assert_eq!(ps.get(*idx), &p[..]);
        }
        assert_eq!(ps.num_points(), points.len());
    }

    #[test]
    fn skipped_add_preserves_logical_add_count_without_storing_point() {
        let mut ps = PointStore::new(2, 1, 8, false);
        ps.record_logical_add_without_storage();
        assert_eq!(ps.num_points(), 0);
        assert_eq!(ps.entries_seen(), 1);
    }

    #[test]
    fn ref_count_frees_slot() {
        let mut ps = PointStore::new(2, 1, 8, false);
        let idx = ps.add(&[1.0, 2.0]).unwrap();
        ps.inc_ref(idx);
        assert_eq!(ps.ref_count(idx), 1);
        assert_eq!(ps.num_points(), 1);
        ps.dec_ref(idx);
        assert_eq!(ps.num_points(), 0);
    }

    #[test]
    fn freed_slot_is_reused_before_fresh_slots() {
        let mut ps = PointStore::new(2, 1, 8, false);
        let first = ps.add(&[1.0, 2.0]).unwrap();
        let second = ps.add(&[3.0, 4.0]).unwrap();
        ps.inc_ref(first);
        ps.dec_ref(first);
        let reused = ps.add(&[5.0, 6.0]).unwrap();
        assert_eq!(reused, first);
        assert_eq!(ps.get(reused), &[5.0f32, 6.0]);
        assert_eq!(ps.get(second), &[3.0f32, 4.0]);
        assert_eq!(ps.ref_count(reused), 0);
        assert_eq!(ps.num_points(), 2);
    }

    #[rstest]
    #[case::window_2(vec![1.0, 2.0], vec![1.0, 2.0])]
    #[case::window_3(vec![1.0, 2.0, 3.0], vec![1.0, 2.0, 3.0])]
    #[case::window_4(vec![1.0, 2.0, 3.0, 4.0], vec![1.0, 2.0, 3.0, 4.0])]
    fn shingling_shifts_buffer(#[case] series: Vec<f32>, #[case] expected: Vec<f32>) {
        let shingle_size = series.len();
        let mut ps = PointStore::new(1, shingle_size, 8, true);
        for point in &series[..shingle_size - 1] {
            let _ = ps.shingled_point(&[*point]).unwrap();
        }
        let full = ps.shingled_point(&[series[shingle_size - 1]]).unwrap();
        assert_eq!(full, expected);
    }

    #[test]
    fn advance_shingle_rejects_wrong_input_dim() {
        let mut ps = PointStore::new(2, 3, 8, true);
        ps.advance_shingle(&[1.0, 2.0]).unwrap();
        let err = ps.advance_shingle(&[9.0]).unwrap_err();
        assert!(
            matches!(err, RcfError::DimensionMismatch { expected: 2, got: 1 }),
            "unexpected error variant: {err:?}"
        );
        // The rejected input must not disturb the existing shingle.
        assert_eq!(ps.current_shingled(), &[0.0f32, 0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn is_equal_works() {
        let mut ps = PointStore::new(2, 1, 8, false);
        let idx = ps.add(&[3.0, 4.0]).unwrap();
        assert!(ps.is_equal(&[3.0, 4.0], idx));
        assert!(!ps.is_equal(&[3.0, 5.0], idx));
    }

    #[rstest]
    #[case::no_missing([false, false, false, false], 8.0)]
    #[case::alternate_missing([false, true, false, true], 3.0)]
    #[case::all_missing([true, true, true, true], 0.0)]
    fn l1_helpers_match_expected(#[case] missing: [bool; 4], #[case] expected_partial: f64) {
        let q = [1.0f32, -1.0, 3.5, 0.0];
        let s = [0.0f32, 2.0, 1.5, -2.0];
        let full = l1_distance_slices(&q, &s);
        let partial = l1_distance_slices_ignore_missing(&q, &s, &missing);
        assert_abs_diff_eq!(full, 8.0, epsilon = 1e-12);
        assert_abs_diff_eq!(partial, expected_partial, epsilon = 1e-12);
    }

    #[rstest]
    #[case::lookahead_0(0, vec![4, 5])]
    #[case::lookahead_1(1, vec![2, 3])]
    #[case::lookahead_2(2, vec![0, 1])]
    fn lookahead_offset_and_indices_are_consistent(
        #[case] look_ahead: usize,
        #[case] expected_indices: Vec<usize>,
    ) {
        let ps = PointStore::new(2, 3, 8, true);
        let offset = lookahead_offset(ps.dim(), ps.input_dim(), look_ahead);
        assert_eq!(offset, expected_indices[0]);
        assert_eq!(ps.next_indices(look_ahead), expected_indices);
    }

    #[test]
    fn next_indices_into_replaces_previous_contents() {
        let ps = PointStore::new(2, 3, 8, true);
        let mut indices = vec![99, 98, 97];
        ps.next_indices_into(1, &mut indices);
        assert_eq!(indices, vec![2, 3]);
    }

    #[test]
    fn missing_indices_with_lookahead_maps_base_indices() {
        let ps = PointStore::new(2, 3, 8, true);
        assert_eq!(ps.missing_indices_with_lookahead(1, &[0, 1]), vec![2, 3]);
    }

    #[cfg(feature = "std")]
    mod proptest_tests {
        use super::*;
        use approx::abs_diff_eq;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn l1_distance_symmetric(
                a0 in -100f32..100f32,
                a1 in -100f32..100f32,
                b0 in -100f32..100f32,
                b1 in -100f32..100f32,
            ) {
                let a = [a0, a1];
                let b = [b0, b1];
                let d_ab = l1_distance_slices(&a, &b);
                let d_ba = l1_distance_slices(&b, &a);
                prop_assert!(
                    abs_diff_eq!(d_ab, d_ba, epsilon = 1e-9),
                    "d_ab={d_ab} d_ba={d_ba}"
                );
            }

            #[test]
            fn l1_distance_non_negative(
                a0 in -100f32..100f32,
                a1 in -100f32..100f32,
                b0 in -100f32..100f32,
                b1 in -100f32..100f32,
            ) {
                let a = [a0, a1];
                let b = [b0, b1];
                let d = l1_distance_slices(&a, &b);
                prop_assert!(d >= 0.0, "distance={d}");
            }

            #[test]
            fn l1_missing_leq_full(
                a0 in -100f32..100f32,
                a1 in -100f32..100f32,
                b0 in -100f32..100f32,
                b1 in -100f32..100f32,
                m0 in any::<bool>(),
                m1 in any::<bool>(),
            ) {
                let a = [a0, a1];
                let b = [b0, b1];
                let missing = [m0, m1];
                let full = l1_distance_slices(&a, &b);
                let partial = l1_distance_slices_ignore_missing(&a, &b, &missing);
                prop_assert!(partial <= full + 1e-9, "partial={partial} > full={full}");
            }

            #[test]
            fn l1_all_missing_is_zero(
                a0 in -100f32..100f32,
                a1 in -100f32..100f32,
                b0 in -100f32..100f32,
                b1 in -100f32..100f32,
            ) {
                let a = [a0, a1];
                let b = [b0, b1];
                let d = l1_distance_slices_ignore_missing(&a, &b, &[true, true]);
                prop_assert_eq!(d, 0.0);
            }
        }
    }
}