use crate::error::InterpolateError;
/// Inverse multiquadric radial basis function `1 / sqrt(1 + (eps * r)^2)`.
///
/// `r` is the distance between two points and `eps` the shape parameter;
/// the result is 1 at `r == 0` and decays towards 0 as `eps * r` grows.
#[inline]
fn inv_multiquadric(r: f64, eps: f64) -> f64 {
    let scaled = eps * r;
    (1.0 + scaled * scaled).sqrt().recip()
}
/// Tuning parameters for the streaming RBF interpolator.
#[derive(Debug, Clone, PartialEq)]
pub struct StreamingRbfConfig {
    /// Maximum number of samples retained. Once the window is full, the oldest
    /// sample is evicted before each new one is added. Should be at least 1.
    pub window_size: usize,
    /// Shape parameter `eps` of the inverse multiquadric kernel
    /// `1 / sqrt(1 + (eps * r)^2)`.
    pub shape_param: f64,
    /// Factor by which every retained target value is multiplied on each update
    /// (only applied when `< 1.0`); `1.0` disables forgetting.
    pub forget_factor: f64,
}
impl Default for StreamingRbfConfig {
fn default() -> Self {
Self {
window_size: 200,
shape_param: 1.0,
forget_factor: 0.99,
}
}
}
/// Sliding-window radial-basis-function interpolator (inverse multiquadric
/// kernel) that is updated one sample at a time.
#[derive(Debug, Clone)]
pub struct StreamingRbf {
/// Window size, kernel shape parameter and forgetting factor.
config: StreamingRbfConfig,
/// Stored sample locations, oldest first (index 0 is the next to be evicted).
points: Vec<Vec<f64>>,
/// Target values paired index-for-index with `points`; they may have been
/// scaled by the forgetting factor since insertion.
values: Vec<f64>,
/// Lower-triangular Cholesky factor of the kernel matrix over `points`.
l: Vec<Vec<f64>>,
/// Interpolation weights, one per point; recomputed lazily from `l` and `values`.
coeffs: Vec<f64>,
/// True when `coeffs` is stale and must be re-solved before predicting.
coeffs_dirty: bool,
}
impl StreamingRbf {
    /// Creates an empty interpolator with the given configuration.
    ///
    /// No samples are stored until [`StreamingRbf::update`] is called.
    pub fn new(config: StreamingRbfConfig) -> Self {
        Self {
            config,
            points: Vec::new(),
            values: Vec::new(),
            l: Vec::new(),
            coeffs: Vec::new(),
            coeffs_dirty: false,
        }
    }

    /// Number of samples currently held in the sliding window.
    pub fn n_points(&self) -> usize {
        self.points.len()
    }

    /// Euclidean distance between `a` and `b`.
    ///
    /// The coordinate iterators are zipped, so only the first
    /// `min(a.len(), b.len())` coordinates are compared; callers must ensure the
    /// dimensions agree (see `check_dim`).
    fn dist(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(&ai, &bi)| (ai - bi) * (ai - bi))
            .sum::<f64>()
            .sqrt()
    }

    /// Kernel value between two points: inverse multiquadric of their distance
    /// with the configured shape parameter.
    fn phi(&self, a: &[f64], b: &[f64]) -> f64 {
        inv_multiquadric(Self::dist(a, b), self.config.shape_param)
    }

    /// Returns `DimensionMismatch` if `x` does not have the same length as the
    /// stored points. Always succeeds while the window is empty: the first
    /// stored point fixes the dimension.
    fn check_dim(&self, x: &[f64], context: &str) -> Result<(), InterpolateError> {
        match self.points.first() {
            Some(p) if p.len() != x.len() => Err(InterpolateError::DimensionMismatch(format!(
                "{context}: point has dimension {} but stored points have dimension {}",
                x.len(),
                p.len()
            ))),
            _ => Ok(()),
        }
    }

    /// Builds the kernel (Gram) matrix `G[i][j] = phi(p_i, p_j)` over all stored points.
    ///
    /// Each pair is evaluated once and mirrored, since the distance, and hence
    /// the kernel, is symmetric in its arguments.
    fn gram_matrix(&self) -> Vec<Vec<f64>> {
        let n = self.points.len();
        let mut g = vec![vec![0.0_f64; n]; n];
        for i in 0..n {
            for j in 0..=i {
                let k = self.phi(&self.points[i], &self.points[j]);
                g[i][j] = k;
                g[j][i] = k;
            }
        }
        g
    }

    /// Re-factors the kernel matrix of the current point set from scratch and
    /// marks the coefficients as stale.
    fn refactor(&mut self) -> Result<(), InterpolateError> {
        let g = self.gram_matrix();
        self.l = Self::full_cholesky(&g)?;
        self.coeffs_dirty = true;
        Ok(())
    }

    /// Dense Cholesky factorization `G + ridge * I = L L^T`, with a tiny ridge
    /// (`1e-12`) added to the diagonal for numerical stability.
    ///
    /// Only the lower triangle of `g` is read. A non-positive pivot is clamped
    /// to the ridge value instead of failing, so this currently always returns `Ok`.
    fn full_cholesky(g: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, InterpolateError> {
        let n = g.len();
        let ridge = 1e-12;
        let mut l = vec![vec![0.0_f64; n]; n];
        for i in 0..n {
            for j in 0..=i {
                let mut s = g[i][j] + if i == j { ridge } else { 0.0 };
                for k in 0..j {
                    s -= l[i][k] * l[j][k];
                }
                if i == j {
                    // Loss of positive-definiteness (e.g. near-duplicate points):
                    // clamp rather than produce NaN from sqrt of a negative.
                    if s <= 0.0 {
                        s = ridge;
                    }
                    l[i][j] = s.sqrt();
                } else {
                    l[i][j] = s / l[j][j];
                }
            }
        }
        Ok(l)
    }

    /// Solves `L y = b` for lower-triangular `L` by forward substitution.
    fn forward_sub(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
        let n = l.len();
        let mut y = vec![0.0; n];
        for i in 0..n {
            let mut s = b[i];
            for k in 0..i {
                s -= l[i][k] * y[k];
            }
            y[i] = s / l[i][i];
        }
        y
    }

    /// Solves `L^T x = y` by back substitution, reading `L` transposed
    /// (row `i` of `L^T` is column `i` of `L`).
    fn back_sub(l: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
        let n = l.len();
        let mut x = vec![0.0; n];
        for i in (0..n).rev() {
            let mut s = y[i];
            for k in (i + 1)..n {
                s -= l[k][i] * x[k];
            }
            x[i] = s / l[i][i];
        }
        x
    }

    /// Solves `(L L^T) x = b` given the Cholesky factor `L`.
    fn solve_cholesky(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
        let y = Self::forward_sub(l, b);
        Self::back_sub(l, &y)
    }

    /// Grows the Cholesky factor `l` of an `n x n` matrix `G` into the factor of
    /// the `(n + 1) x (n + 1)` matrix obtained by appending one row/column.
    ///
    /// `v` is the new last row of the grown matrix: `v[..n]` holds the
    /// off-diagonal entries and `v[n]` the new diagonal entry. Despite the
    /// name, this is a bordering (append) update, not a rank-1 modification.
    /// `l` is assumed to be square and lower-triangular.
    ///
    /// If the new squared pivot `v[n] - |w|^2` is not positive, the pivot is set
    /// to `1e-10` instead of reporting an error.
    ///
    /// # Errors
    /// Returns `InterpolateError::DimensionMismatch` if `v.len() != l.len() + 1`;
    /// `l` is left unchanged in that case.
    pub fn cholesky_rank1_update(l: &mut Vec<Vec<f64>>, v: &[f64]) -> Result<(), InterpolateError> {
        let old_n = l.len();
        let new_n = old_n + 1;
        if v.len() != new_n {
            return Err(InterpolateError::DimensionMismatch(format!(
                "cholesky_rank1_update: v has length {} but expected {}",
                v.len(),
                new_n
            )));
        }
        // Solve L w = v[..n] against the existing factor *before* growing it, so
        // no truncated copy of L is needed.
        let w = Self::forward_sub(&l[..], &v[..old_n]);
        let w_norm2: f64 = w.iter().map(|&wi| wi * wi).sum();
        let diag2 = v[old_n] - w_norm2;
        let diag = if diag2 <= 0.0 { 1e-10_f64 } else { diag2.sqrt() };

        for row in l.iter_mut() {
            row.push(0.0);
        }
        let mut last_row = w;
        last_row.push(diag);
        l.push(last_row);
        Ok(())
    }

    /// Evicts the oldest sample and re-factors the kernel matrix.
    ///
    /// Does nothing when the window is empty. When the last sample is removed,
    /// the factor and coefficients are cleared.
    // Not called by `update` (which evicts inline before a single re-factor);
    // kept as a standalone eviction helper, hence the lint allowance.
    #[allow(dead_code)]
    fn remove_first_point(&mut self) -> Result<(), InterpolateError> {
        if self.points.is_empty() {
            return Ok(());
        }
        self.points.remove(0);
        self.values.remove(0);
        if self.points.is_empty() {
            self.l.clear();
            self.coeffs.clear();
            self.coeffs_dirty = false;
            return Ok(());
        }
        self.refactor()
    }

    /// Re-solves the interpolation weights if they are stale. No-op when they
    /// are up to date or when no samples are stored.
    fn refresh_coeffs(&mut self) {
        if !self.coeffs_dirty || self.points.is_empty() {
            return;
        }
        self.coeffs = Self::solve_cholesky(&self.l, &self.values);
        self.coeffs_dirty = false;
    }

    /// Adds the sample `(x, y)` to the window and re-factors the kernel matrix.
    ///
    /// If the window is full, the oldest sample is evicted first. When
    /// `forget_factor < 1.0`, every retained target value is then multiplied by
    /// it before the new sample is appended. The kernel matrix is rebuilt and
    /// factored from scratch (O(n^2) kernel evaluations plus an O(n^3)
    /// factorization). Coefficients are recomputed lazily by `predict`.
    ///
    /// NOTE(review): forgetting scales the stored *targets* toward zero instead
    /// of down-weighting old samples. With `forget_factor < 1.0` the interpolant
    /// therefore no longer reproduces the observed values (a constant signal
    /// decays). Confirm this is the intended semantics.
    ///
    /// # Errors
    /// - `InterpolateError::InvalidState` if `window_size` is 0.
    /// - `InterpolateError::DimensionMismatch` if `x` differs in length from the
    ///   stored points.
    ///
    /// Neither error modifies the stored samples.
    pub fn update(&mut self, x: Vec<f64>, y: f64) -> Result<(), InterpolateError> {
        if self.config.window_size == 0 {
            return Err(InterpolateError::InvalidState(
                "StreamingRbf window_size must be at least 1".to_string(),
            ));
        }
        self.check_dim(&x, "update")?;

        if self.points.len() >= self.config.window_size {
            self.points.remove(0);
            self.values.remove(0);
        }
        let gamma = self.config.forget_factor;
        if gamma < 1.0 {
            for v in &mut self.values {
                *v *= gamma;
            }
        }
        self.points.push(x);
        self.values.push(y);
        self.refactor()
    }

    /// Evaluates the interpolant `sum_i c_i * phi(x, p_i)` at `x`.
    ///
    /// Takes `&mut self` because stale coefficients are re-solved and cached here.
    ///
    /// # Errors
    /// - `InterpolateError::InvalidState` when no samples have been added.
    /// - `InterpolateError::DimensionMismatch` when `x` differs in length from
    ///   the stored points.
    pub fn predict(&mut self, x: &[f64]) -> Result<f64, InterpolateError> {
        if self.points.is_empty() {
            return Err(InterpolateError::InvalidState(
                "StreamingRbf has no data yet".to_string(),
            ));
        }
        self.check_dim(x, "predict")?;
        self.refresh_coeffs();
        let val: f64 = self
            .points
            .iter()
            .zip(&self.coeffs)
            .map(|(pt, &alpha)| alpha * self.phi(x, pt))
            .sum();
        Ok(val)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Config with unit shape parameter and the given window / forgetting factor.
    fn config(window_size: usize, forget_factor: f64) -> StreamingRbfConfig {
        StreamingRbfConfig {
            window_size,
            shape_param: 1.0,
            forget_factor,
        }
    }

    #[test]
    fn test_streaming_sin_prediction() {
        let mut rbf = StreamingRbf::new(config(20, 1.0));
        // Ten evenly spaced samples of sin on [0, pi].
        for xi in (0..10).map(|i| (i as f64) * PI / 9.0) {
            rbf.update(vec![xi], xi.sin())
                .expect("update should succeed");
        }
        assert_eq!(rbf.n_points(), 10);

        let probe = PI / 4.0;
        let truth = probe.sin();
        let guess = rbf.predict(&[probe]).expect("predict should succeed");
        assert!(
            (guess - truth).abs() < 0.1,
            "sin prediction off: pred={:.4}, expected={:.4}",
            guess,
            truth
        );
    }

    #[test]
    fn test_window_eviction() {
        let window = 5;
        let mut rbf = StreamingRbf::new(config(window, 1.0));
        for xi in (0..10).map(|i| i as f64 * 0.1) {
            rbf.update(vec![xi], xi * xi).expect("update");
        }
        assert_eq!(
            rbf.n_points(),
            window,
            "window should be capped at {window}"
        );
    }

    #[test]
    fn test_predict_empty_returns_error() {
        let mut rbf = StreamingRbf::new(StreamingRbfConfig::default());
        assert!(rbf.predict(&[0.5]).is_err());
    }

    #[test]
    fn test_forget_factor_effect() {
        let mut rbf = StreamingRbf::new(config(100, 0.5));
        // Older samples carry 0.0, newer (slightly shifted) samples carry 1.0.
        for xi in (0..5).map(|i| i as f64 * 0.2) {
            rbf.update(vec![xi], 0.0).expect("update old");
        }
        for xi in (0..5).map(|i| i as f64 * 0.2 + 0.05) {
            rbf.update(vec![xi], 1.0)
                .expect("update new");
        }
        let pred = rbf.predict(&[0.5]).expect("predict");
        assert!(
            pred > 0.0,
            "prediction with forgetting should lean towards recent data, got {pred}"
        );
    }

    #[test]
    fn test_cholesky_rank1_update_basic() {
        let (g00, g01, g11) = (4.0_f64, 1.0_f64, 4.0_f64);
        let mut factor = vec![vec![g00.sqrt()]];
        StreamingRbf::cholesky_rank1_update(&mut factor, &[g01, g11]).expect("rank-1 update");
        assert_eq!(factor.len(), 2);

        // Rebuild G = L L^T entry by entry and compare with the inputs.
        let rebuilt_00 = factor[0][0] * factor[0][0];
        let rebuilt_10 = factor[1][0] * factor[0][0];
        let rebuilt_11 = factor[1][0] * factor[1][0] + factor[1][1] * factor[1][1];
        assert!(
            (rebuilt_00 - g00).abs() < 1e-10,
            "G[0][0] mismatch: {} vs {}",
            rebuilt_00,
            g00
        );
        assert!(
            (rebuilt_10 - g01).abs() < 1e-10,
            "G[1][0] mismatch: {} vs {}",
            rebuilt_10,
            g01
        );
        assert!(
            (rebuilt_11 - g11).abs() < 1e-10,
            "G[1][1] mismatch: {} vs {}",
            rebuilt_11,
            g11
        );
    }
}