use crate::params::DsfbParams;
use crate::state::DsfbState;
use crate::trust::{calculate_trust_weights, TrustStats};
/// Observer that tracks a three-component state (`phi`, `omega`, `alpha`)
/// from several measurement channels, weighting each channel's residual by
/// the trust weight returned from `calculate_trust_weights`.
pub struct DsfbObserver {
/// Gains (`k_phi`, `k_omega`, `k_alpha`) applied to the aggregate residual,
/// plus the `rho` / `sigma0` values forwarded to `calculate_trust_weights`.
params: DsfbParams,
/// Number of measurement channels; `step` asserts that exactly this many
/// measurements are supplied.
channels: usize,
/// Current state estimate, replaced on every `step` (or by `init`).
state: DsfbState,
/// Per-channel buffer handed mutably to `calculate_trust_weights` on each
/// step. NOTE(review): presumably an exponential moving average of the
/// residuals, judging by the name -- confirm against the trust module.
ema_residuals: Vec<f64>,
/// Per-channel record of the latest `residual_ema` and `weight`, refreshed
/// at the end of the trust computation in `step`.
trust_stats: Vec<TrustStats>,
}
impl DsfbObserver {
    /// Creates an observer for `channels` measurement channels.
    ///
    /// The state starts at [`DsfbState::zero`]; the per-channel residual
    /// buffer is filled with `0.0` and the per-channel statistics with
    /// [`TrustStats::new`].
    pub fn new(params: DsfbParams, channels: usize) -> Self {
        let state = DsfbState::zero();
        let ema_residuals = vec![0.0; channels];
        let trust_stats = vec![TrustStats::new(); channels];
        Self {
            params,
            channels,
            state,
            ema_residuals,
            trust_stats,
        }
    }

    /// Overwrites the current state estimate with `initial_state`.
    ///
    /// Only the state is replaced; the per-channel residual buffer and trust
    /// statistics are left untouched.
    pub fn init(&mut self, initial_state: DsfbState) {
        self.state = initial_state;
    }

    /// Advances the observer by one step of length `dt` and returns the new state.
    ///
    /// The step proceeds in four stages:
    /// 1. predict `phi` and `omega` forward over `dt` (`alpha` is held constant);
    /// 2. form one residual per channel as `measurement - predicted phi`;
    /// 3. obtain per-channel weights from `calculate_trust_weights` and record
    ///    them, together with the residual buffer values, in the trust statistics;
    /// 4. correct the prediction with the weighted sum of residuals, scaled by
    ///    the gains `k_phi`, `k_omega` and `k_alpha`.
    ///
    /// # Panics
    ///
    /// Panics if `measurements.len()` differs from the channel count given to
    /// [`DsfbObserver::new`].
    pub fn step(&mut self, measurements: &[f64], dt: f64) -> DsfbState {
        assert_eq!(
            measurements.len(),
            self.channels,
            "Measurement count mismatch"
        );

        // Prediction from the previous estimate.
        let prev = self.state;
        let phi_pred = prev.phi + prev.omega * dt;
        let omega_pred = prev.omega + prev.alpha * dt;
        let alpha_pred = prev.alpha;

        // Every channel is compared against the predicted phi directly.
        let residuals: Vec<f64> = measurements
            .iter()
            .copied()
            .map(|y| y - phi_pred)
            .collect();

        let weights = calculate_trust_weights(
            &residuals,
            &mut self.ema_residuals,
            self.params.rho,
            self.params.sigma0,
        );

        // Snapshot the residual buffer and the weight of each channel so the
        // accessors can report them after the step.
        for (k, (stats, &weight)) in self
            .trust_stats
            .iter_mut()
            .zip(weights.iter())
            .enumerate()
        {
            stats.residual_ema = self.ema_residuals[k];
            stats.weight = weight;
        }

        // Trust-weighted sum of the residuals drives all three corrections.
        let correction: f64 = residuals
            .iter()
            .zip(weights.iter())
            .map(|(r, w)| w * r)
            .sum();

        let gains = &self.params;
        self.state = DsfbState::new(
            phi_pred + gains.k_phi * correction,
            omega_pred + gains.k_omega * correction,
            alpha_pred + gains.k_alpha * correction,
        );
        self.state
    }

    /// Returns a copy of the current state estimate.
    pub fn state(&self) -> DsfbState {
        self.state
    }

    /// Returns the per-channel trust statistics recorded by the latest `step`.
    pub fn trust_stats(&self) -> &[TrustStats] {
        self.trust_stats.as_slice()
    }

    /// Returns the weight recorded for `channel` by the latest `step`.
    ///
    /// # Panics
    ///
    /// Panics if `channel` is not a valid channel index.
    pub fn trust_weight(&self, channel: usize) -> f64 {
        let stats = &self.trust_stats[channel];
        stats.weight
    }

    /// Returns the residual buffer value recorded for `channel` by the latest `step`.
    ///
    /// # Panics
    ///
    /// Panics if `channel` is not a valid channel index.
    pub fn ema_residual(&self, channel: usize) -> f64 {
        let stats = &self.trust_stats[channel];
        stats.residual_ema
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new observer keeps the requested channel count and starts with `phi == 0.0`.
    #[test]
    fn test_observer_creation() {
        let observer = DsfbObserver::new(DsfbParams::default(), 2);

        assert_eq!(observer.channels, 2);
        assert_eq!(observer.state.phi, 0.0);
    }

    /// Measurements that match the predicted phi leave the estimate above its
    /// starting value.
    #[test]
    fn test_observer_step_no_residual() {
        let mut observer = DsfbObserver::new(DsfbParams::new(0.5, 0.1, 0.01, 0.9, 0.1), 2);
        observer.init(DsfbState::new(1.0, 0.1, 0.0));

        // Predicted phi is 1.0 + 0.1 * 0.1 = 1.01, so both channels report a
        // residual of (approximately) zero.
        let state = observer.step(&[1.01, 1.01], 0.1);

        assert!(state.phi > 1.0);
    }

    /// After one step the per-channel trust weights add up to 1 (within 1e-10).
    #[test]
    fn test_observer_trust_weights_sum() {
        let mut observer = DsfbObserver::new(DsfbParams::default(), 3);
        observer.step(&[0.5, 1.5, 2.5], 0.1);

        let total = (0..3)
            .map(|channel| observer.trust_weight(channel))
            .sum::<f64>();
        assert!((total - 1.0).abs() < 1e-10);
    }
}