anomstream_core/domain/
divector.rs1use alloc::format;
12use alloc::vec;
13use alloc::vec::Vec;
14
15use crate::error::{RcfError, RcfResult};
16
/// A pair of per-dimension `f64` component vectors, `high` and `low`.
///
/// Both constructors in this file ([`DiVector::zeros`] and
/// [`DiVector::from_arrays`]) produce `high` and `low` of identical length,
/// and that shared length is what [`DiVector::dim`] reports.
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize`
/// fills the fields directly and therefore does not run the length check
/// that [`DiVector::from_arrays`] performs — confirm that serialized inputs
/// are trusted, or validate after deserializing.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DiVector {
    /// `high` component of every dimension, indexed by dimension.
    high: Vec<f64>,
    /// `low` component of every dimension, indexed by dimension.
    low: Vec<f64>,
}
26
27impl DiVector {
28 #[must_use]
39 pub fn zeros(dim: usize) -> Self {
40 Self {
41 high: vec![0.0; dim],
42 low: vec![0.0; dim],
43 }
44 }
45
46 pub fn from_arrays(high: Vec<f64>, low: Vec<f64>) -> RcfResult<Self> {
56 if high.len() != low.len() {
57 return Err(RcfError::DimensionMismatch {
58 expected: high.len(),
59 got: low.len(),
60 });
61 }
62 Ok(Self { high, low })
63 }
64
65 #[must_use]
67 pub fn dim(&self) -> usize {
68 self.high.len()
69 }
70
71 #[must_use]
73 pub fn high(&self) -> &[f64] {
74 &self.high
75 }
76
77 #[must_use]
79 pub fn low(&self) -> &[f64] {
80 &self.low
81 }
82
83 #[must_use]
85 pub fn total(&self) -> f64 {
86 self.high.iter().sum::<f64>() + self.low.iter().sum::<f64>()
87 }
88
89 #[must_use]
95 pub fn per_dim_total(&self, d: usize) -> f64 {
96 self.high[d] + self.low[d]
97 }
98
99 #[must_use]
102 pub fn argmax(&self) -> Option<usize> {
103 if self.dim() == 0 {
104 return None;
105 }
106 let mut best = 0_usize;
107 let mut best_val = self.per_dim_total(0);
108 for d in 1..self.dim() {
109 let v = self.per_dim_total(d);
110 if v > best_val {
111 best = d;
112 best_val = v;
113 }
114 }
115 Some(best)
116 }
117
118 pub fn add_high(&mut self, d: usize, value: f64) -> RcfResult<()> {
124 if d >= self.high.len() {
125 return Err(RcfError::OutOfBounds {
126 index: d,
127 len: self.high.len(),
128 });
129 }
130 self.high[d] += value;
131 Ok(())
132 }
133
134 pub fn add_low(&mut self, d: usize, value: f64) -> RcfResult<()> {
140 if d >= self.low.len() {
141 return Err(RcfError::OutOfBounds {
142 index: d,
143 len: self.low.len(),
144 });
145 }
146 self.low[d] += value;
147 Ok(())
148 }
149
150 pub fn accumulate(&mut self, other: &Self) -> RcfResult<()> {
156 if other.dim() != self.dim() {
157 return Err(RcfError::DimensionMismatch {
158 expected: self.dim(),
159 got: other.dim(),
160 });
161 }
162 for d in 0..self.dim() {
163 self.high[d] += other.high[d];
164 self.low[d] += other.low[d];
165 }
166 Ok(())
167 }
168
169 pub fn scale(&mut self, divisor: f64) -> RcfResult<()> {
178 if divisor == 0.0 || !divisor.is_finite() {
179 return Err(RcfError::InvalidConfig(
180 format!("DiVector::scale divisor must be non-zero and finite, got {divisor}")
181 .into(),
182 ));
183 }
184 for d in 0..self.dim() {
185 self.high[d] /= divisor;
186 self.low[d] /= divisor;
187 }
188 Ok(())
189 }
190}
191
#[cfg(test)]
// Exact float comparison is intentional here: every expected value is built
// from small integers (and halves of even integers) that `f64` represents
// exactly, so `assert_eq!` on floats cannot be affected by rounding.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn zeros_creates_dim_sized_vector() {
        let v = DiVector::zeros(5);
        assert_eq!(v.dim(), 5);
        assert_eq!(v.high(), &[0.0; 5]);
        assert_eq!(v.low(), &[0.0; 5]);
        assert_eq!(v.total(), 0.0);
    }

    #[test]
    fn from_arrays_keeps_components() {
        let v = DiVector::from_arrays(vec![1.0, 2.0], vec![3.0, 4.0]).unwrap();
        assert_eq!(v.dim(), 2);
        assert_eq!(v.high(), &[1.0, 2.0]);
        assert_eq!(v.low(), &[3.0, 4.0]);
        assert_eq!(v.total(), 10.0);
    }

    #[test]
    fn from_arrays_rejects_len_mismatch() {
        let err = DiVector::from_arrays(vec![1.0, 2.0], vec![3.0]).unwrap_err();
        // `expected` is the length of `high`, `got` the length of `low`.
        assert!(matches!(
            err,
            RcfError::DimensionMismatch {
                expected: 2,
                got: 1
            }
        ));
    }

    #[test]
    fn add_high_and_low_accumulate() {
        let mut v = DiVector::zeros(3);
        v.add_high(0, 1.0).unwrap();
        v.add_high(0, 2.0).unwrap();
        v.add_low(2, 4.0).unwrap();
        assert_eq!(v.high(), &[3.0, 0.0, 0.0]);
        assert_eq!(v.low(), &[0.0, 0.0, 4.0]);
        assert_eq!(v.total(), 7.0);
        assert_eq!(v.per_dim_total(0), 3.0);
        assert_eq!(v.per_dim_total(2), 4.0);
    }

    #[test]
    fn add_high_oob() {
        let mut v = DiVector::zeros(2);
        let err = v.add_high(3, 1.0).unwrap_err();
        assert!(matches!(err, RcfError::OutOfBounds { index: 3, len: 2 }));
    }

    #[test]
    fn add_low_oob() {
        let mut v = DiVector::zeros(2);
        assert!(matches!(
            v.add_low(99, 1.0).unwrap_err(),
            RcfError::OutOfBounds { .. }
        ));
    }

    #[test]
    fn rejected_add_leaves_vector_unchanged() {
        let mut v = DiVector::zeros(2);
        v.add_high(1, 3.0).unwrap();
        let before = v.clone();
        assert!(v.add_high(2, 1.0).is_err());
        assert!(v.add_low(2, 1.0).is_err());
        assert_eq!(v, before);
    }

    #[test]
    fn accumulate_sums_componentwise() {
        let mut a = DiVector::zeros(2);
        a.add_high(0, 1.0).unwrap();
        a.add_low(1, 2.0).unwrap();
        let mut b = DiVector::zeros(2);
        b.add_high(0, 4.0).unwrap();
        b.add_low(1, 8.0).unwrap();
        a.accumulate(&b).unwrap();
        assert_eq!(a.high(), &[5.0, 0.0]);
        assert_eq!(a.low(), &[0.0, 10.0]);
    }

    #[test]
    fn accumulate_rejects_dim_mismatch() {
        let mut a = DiVector::zeros(2);
        let b = DiVector::zeros(3);
        assert!(matches!(
            a.accumulate(&b).unwrap_err(),
            RcfError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn scale_divides_componentwise() {
        let mut v = DiVector::zeros(2);
        v.add_high(0, 10.0).unwrap();
        v.add_low(1, 6.0).unwrap();
        v.scale(2.0).unwrap();
        assert_eq!(v.high(), &[5.0, 0.0]);
        assert_eq!(v.low(), &[0.0, 3.0]);
    }

    #[test]
    fn scale_rejects_zero() {
        let mut v = DiVector::zeros(1);
        assert!(matches!(
            v.scale(0.0).unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
    }

    #[test]
    fn scale_rejects_nan_infinity() {
        let mut v = DiVector::zeros(1);
        assert!(v.scale(f64::NAN).is_err());
        assert!(v.scale(f64::INFINITY).is_err());
    }

    #[test]
    fn rejected_scale_leaves_vector_unchanged() {
        let mut v = DiVector::zeros(2);
        v.add_high(0, 10.0).unwrap();
        v.add_low(1, 6.0).unwrap();
        let before = v.clone();
        assert!(v.scale(0.0).is_err());
        assert_eq!(v, before);
    }

    #[test]
    fn argmax_picks_largest() {
        let mut v = DiVector::zeros(4);
        v.add_high(2, 5.0).unwrap();
        v.add_low(1, 1.0).unwrap();
        assert_eq!(v.argmax(), Some(2));
    }

    #[test]
    fn argmax_zero_dim_returns_none() {
        let v = DiVector::zeros(0);
        assert!(v.argmax().is_none());
    }

    #[test]
    fn argmax_ties_returns_first() {
        let mut v = DiVector::zeros(3);
        v.add_high(0, 5.0).unwrap();
        v.add_high(2, 5.0).unwrap();
        assert_eq!(v.argmax(), Some(0));
    }
}