use alloc::format;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::{RcfError, RcfResult};
/// A pair of equal-length `f64` vectors, `high` and `low`, indexed by dimension.
///
/// Every constructor in this type ([`DiVector::zeros`], [`DiVector::from_arrays`])
/// guarantees `high.len() == low.len()`, and the methods below rely on that invariant.
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` fills the
/// fields directly and does not go through [`DiVector::from_arrays`], so the equal-length
/// invariant is not re-checked on deserialization. Consider
/// `#[serde(try_from = "...")]` if untrusted input can reach this type.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DiVector {
    /// Per-dimension "high" component.
    high: Vec<f64>,
    /// Per-dimension "low" component; always the same length as `high`.
    low: Vec<f64>,
}

impl DiVector {
    /// Creates a `dim`-dimensional vector with every `high` and `low` entry set to `0.0`.
    #[must_use]
    pub fn zeros(dim: usize) -> Self {
        Self {
            high: vec![0.0; dim],
            low: vec![0.0; dim],
        }
    }

    /// Builds a vector from explicit `high` and `low` components.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::DimensionMismatch`] with `expected = high.len()` and
    /// `got = low.len()` when the two inputs differ in length.
    pub fn from_arrays(high: Vec<f64>, low: Vec<f64>) -> RcfResult<Self> {
        if high.len() != low.len() {
            return Err(RcfError::DimensionMismatch {
                expected: high.len(),
                got: low.len(),
            });
        }
        Ok(Self { high, low })
    }

    /// Number of dimensions (the length of the `high` component).
    #[must_use]
    pub fn dim(&self) -> usize {
        self.high.len()
    }

    /// Read-only view of the `high` component.
    #[must_use]
    pub fn high(&self) -> &[f64] {
        &self.high
    }

    /// Read-only view of the `low` component.
    #[must_use]
    pub fn low(&self) -> &[f64] {
        &self.low
    }

    /// Sum of every entry in both `high` and `low`.
    #[must_use]
    pub fn total(&self) -> f64 {
        // Sum each side separately, then add: this fixes the floating-point
        // summation order callers observe.
        self.high.iter().sum::<f64>() + self.low.iter().sum::<f64>()
    }

    /// `high[d] + low[d]` for a single dimension `d`.
    ///
    /// # Panics
    ///
    /// Panics if `d >= self.dim()`.
    #[must_use]
    pub fn per_dim_total(&self, d: usize) -> f64 {
        self.high[d] + self.low[d]
    }

    /// Index of the dimension with the largest `high[d] + low[d]`.
    ///
    /// - Ties resolve to the lowest index.
    /// - Dimensions whose total is NaN are ignored, so a NaN cannot shadow a larger
    ///   finite total regardless of its position.
    /// - If every total is NaN, the result is `Some(0)`.
    ///
    /// Returns `None` only when the vector has zero dimensions.
    #[must_use]
    pub fn argmax(&self) -> Option<usize> {
        if self.dim() == 0 {
            return None;
        }
        let best = (0..self.dim())
            .map(|d| (d, self.per_dim_total(d)))
            // NaN compares false against everything, so without this filter the
            // outcome would depend on where a NaN happens to sit.
            .filter(|&(_, total)| !total.is_nan())
            .fold(None, |best: Option<(usize, f64)>, (d, total)| match best {
                // Only a strictly larger total replaces the current best, so ties
                // keep the earliest index.
                Some((_, best_total)) if total <= best_total => best,
                _ => Some((d, total)),
            });
        Some(best.map_or(0, |(d, _)| d))
    }

    /// Adds `value` to `high[d]`.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::OutOfBounds`] when `d` is not a valid dimension index;
    /// the vector is left unchanged in that case.
    pub fn add_high(&mut self, d: usize, value: f64) -> RcfResult<()> {
        if d >= self.high.len() {
            return Err(RcfError::OutOfBounds {
                index: d,
                len: self.high.len(),
            });
        }
        self.high[d] += value;
        Ok(())
    }

    /// Adds `value` to `low[d]`.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::OutOfBounds`] when `d` is not a valid dimension index;
    /// the vector is left unchanged in that case.
    pub fn add_low(&mut self, d: usize, value: f64) -> RcfResult<()> {
        if d >= self.low.len() {
            return Err(RcfError::OutOfBounds {
                index: d,
                len: self.low.len(),
            });
        }
        self.low[d] += value;
        Ok(())
    }

    /// Adds `other` into `self` component-wise (`high += other.high`, `low += other.low`).
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::DimensionMismatch`] (`expected = self.dim()`,
    /// `got = other.dim()`) when the dimensions differ; `self` is left unchanged.
    pub fn accumulate(&mut self, other: &Self) -> RcfResult<()> {
        if other.dim() != self.dim() {
            return Err(RcfError::DimensionMismatch {
                expected: self.dim(),
                got: other.dim(),
            });
        }
        for (dst, src) in self.high.iter_mut().zip(&other.high) {
            *dst += *src;
        }
        for (dst, src) in self.low.iter_mut().zip(&other.low) {
            *dst += *src;
        }
        Ok(())
    }

    /// Divides every `high` and `low` entry by `divisor`.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] when `divisor` is zero (either sign), NaN,
    /// or infinite; `self` is left unchanged.
    pub fn scale(&mut self, divisor: f64) -> RcfResult<()> {
        if divisor == 0.0 || !divisor.is_finite() {
            return Err(RcfError::InvalidConfig(
                format!("DiVector::scale divisor must be non-zero and finite, got {divisor}")
                    .into(),
            ));
        }
        for x in self.high.iter_mut().chain(self.low.iter_mut()) {
            *x /= divisor;
        }
        Ok(())
    }
}
#[cfg(test)]
// Exact float equality is deliberate in these tests: every expected value is a sum or
// quotient of small integers, which f64 represents exactly.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn zeros_creates_dim_sized_vector() {
        let v = DiVector::zeros(5);
        assert_eq!(v.dim(), 5);
        assert_eq!(v.high(), &[0.0; 5]);
        assert_eq!(v.low(), &[0.0; 5]);
        assert_eq!(v.total(), 0.0);
    }

    #[test]
    fn from_arrays_accepts_equal_lengths() {
        let v = DiVector::from_arrays(vec![1.0, 2.0], vec![3.0, 4.0]).unwrap();
        assert_eq!(v.dim(), 2);
        assert_eq!(v.high(), &[1.0, 2.0]);
        assert_eq!(v.low(), &[3.0, 4.0]);
        assert_eq!(v.per_dim_total(1), 6.0);
        assert_eq!(v.total(), 10.0);
    }

    #[test]
    fn from_arrays_rejects_length_mismatch() {
        let err = DiVector::from_arrays(vec![1.0], vec![1.0, 2.0]).unwrap_err();
        assert!(matches!(
            err,
            RcfError::DimensionMismatch {
                expected: 1,
                got: 2
            }
        ));
    }

    #[test]
    fn add_high_and_low_accumulate() {
        let mut v = DiVector::zeros(3);
        v.add_high(0, 1.0).unwrap();
        v.add_high(0, 2.0).unwrap();
        v.add_low(2, 4.0).unwrap();
        assert_eq!(v.high(), &[3.0, 0.0, 0.0]);
        assert_eq!(v.low(), &[0.0, 0.0, 4.0]);
        assert_eq!(v.total(), 7.0);
        assert_eq!(v.per_dim_total(0), 3.0);
        assert_eq!(v.per_dim_total(2), 4.0);
    }

    #[test]
    fn add_high_oob() {
        let mut v = DiVector::zeros(2);
        let err = v.add_high(3, 1.0).unwrap_err();
        assert!(matches!(err, RcfError::OutOfBounds { index: 3, len: 2 }));
        assert_eq!(v, DiVector::zeros(2));
    }

    #[test]
    fn add_low_oob() {
        let mut v = DiVector::zeros(2);
        assert!(matches!(
            v.add_low(99, 1.0).unwrap_err(),
            RcfError::OutOfBounds { index: 99, len: 2 }
        ));
        assert_eq!(v, DiVector::zeros(2));
    }

    #[test]
    fn accumulate_sums_componentwise() {
        let mut a = DiVector::zeros(2);
        a.add_high(0, 1.0).unwrap();
        a.add_low(1, 2.0).unwrap();
        let mut b = DiVector::zeros(2);
        b.add_high(0, 4.0).unwrap();
        b.add_low(1, 8.0).unwrap();
        a.accumulate(&b).unwrap();
        assert_eq!(a.high(), &[5.0, 0.0]);
        assert_eq!(a.low(), &[0.0, 10.0]);
    }

    #[test]
    fn accumulate_rejects_dim_mismatch() {
        let mut a = DiVector::zeros(2);
        let b = DiVector::zeros(3);
        assert!(matches!(
            a.accumulate(&b).unwrap_err(),
            RcfError::DimensionMismatch {
                expected: 2,
                got: 3
            }
        ));
        assert_eq!(a, DiVector::zeros(2));
    }

    #[test]
    fn scale_divides_componentwise() {
        let mut v = DiVector::zeros(2);
        v.add_high(0, 10.0).unwrap();
        v.add_low(1, 6.0).unwrap();
        v.scale(2.0).unwrap();
        assert_eq!(v.high(), &[5.0, 0.0]);
        assert_eq!(v.low(), &[0.0, 3.0]);
    }

    #[test]
    fn scale_rejects_zero() {
        let mut v = DiVector::zeros(1);
        assert!(matches!(
            v.scale(0.0).unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
        // `-0.0 == 0.0` in IEEE 754, so negative zero must be rejected too.
        assert!(matches!(
            v.scale(-0.0).unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
    }

    #[test]
    fn scale_rejects_non_finite_and_leaves_values_untouched() {
        let mut v = DiVector::zeros(1);
        v.add_high(0, 4.0).unwrap();
        assert!(v.scale(f64::NAN).is_err());
        assert!(v.scale(f64::INFINITY).is_err());
        assert!(v.scale(f64::NEG_INFINITY).is_err());
        assert_eq!(v.high(), &[4.0]);
        assert_eq!(v.low(), &[0.0]);
    }

    #[test]
    fn argmax_picks_largest() {
        let mut v = DiVector::zeros(4);
        v.add_high(2, 5.0).unwrap();
        v.add_low(1, 1.0).unwrap();
        assert_eq!(v.argmax(), Some(2));
    }

    #[test]
    fn argmax_considers_low_component() {
        let v = DiVector::from_arrays(vec![1.0, 0.0], vec![0.0, 3.0]).unwrap();
        assert_eq!(v.argmax(), Some(1));
    }

    #[test]
    fn argmax_zero_dim_returns_none() {
        let v = DiVector::zeros(0);
        assert!(v.argmax().is_none());
    }

    #[test]
    fn argmax_ties_returns_first() {
        let mut v = DiVector::zeros(3);
        v.add_high(0, 5.0).unwrap();
        v.add_high(2, 5.0).unwrap();
        assert_eq!(v.argmax(), Some(0));
    }

    #[test]
    fn argmax_all_zero_returns_first() {
        assert_eq!(DiVector::zeros(4).argmax(), Some(0));
    }
}