use crate::TimeSeries;
use torsh_core::error::Result;
use torsh_tensor::{
creation::{eye, ones, zeros},
Tensor,
};
/// Linear Kalman filter over the standard state-space model
///   x_t = F·x_{t-1} + w,  w ~ N(0, Q)
///   z_t = H·x_t     + v,  v ~ N(0, R)
/// with state dimension `state_dim` and observation dimension `obs_dim`.
pub struct KalmanFilter {
/// Length of the state vector x.
state_dim: usize,
/// Length of each observation z.
obs_dim: usize,
/// State-transition matrix F, shape [state_dim, state_dim] (see `new`).
transition: Tensor,
/// Observation matrix H, shape [obs_dim, state_dim] (see `new`).
observation: Tensor,
/// Process-noise covariance Q, shape [state_dim, state_dim].
process_noise: Tensor,
/// Measurement-noise covariance R, shape [obs_dim, obs_dim].
measurement_noise: Tensor,
/// Current state estimate x, shape [state_dim, 1].
state: Tensor,
/// Current state covariance P, shape [state_dim, state_dim].
covariance: Tensor,
}
impl KalmanFilter {
/// Creates a filter with default model matrices: identity transition,
/// all-ones observation, small diagonal process noise (0.01·I) and
/// measurement noise (0.1·I). The state starts at zero with identity
/// covariance.
pub fn new(state_dim: usize, obs_dim: usize) -> Self {
    let transition = eye(state_dim).expect("tensor creation should succeed");
    let observation = ones(&[obs_dim, state_dim]).expect("tensor creation should succeed");
    let process_noise = eye(state_dim)
        .expect("tensor creation should succeed")
        .mul_scalar(0.01)
        .expect("scalar mul should succeed");
    let measurement_noise = eye(obs_dim)
        .expect("tensor creation should succeed")
        .mul_scalar(0.1)
        .expect("scalar mul should succeed");
    let state = zeros(&[state_dim, 1]).expect("tensor creation should succeed");
    let covariance = eye(state_dim).expect("tensor creation should succeed");
    Self {
        state_dim,
        obs_dim,
        transition,
        observation,
        process_noise,
        measurement_noise,
        state,
        covariance,
    }
}
/// Creates a filter from caller-supplied model matrices. The state starts
/// at zero with identity covariance, exactly as in [`KalmanFilter::new`].
pub fn with_matrices(
    state_dim: usize,
    obs_dim: usize,
    transition: Tensor,
    observation: Tensor,
    process_noise: Tensor,
    measurement_noise: Tensor,
) -> Self {
    // Default initial belief: x = 0, P = I.
    let state = zeros(&[state_dim, 1]).expect("tensor creation should succeed");
    let covariance = eye(state_dim).expect("tensor creation should succeed");
    Self {
        state_dim,
        obs_dim,
        transition,
        observation,
        process_noise,
        measurement_noise,
        state,
        covariance,
    }
}
/// Returns the `(state_dim, obs_dim)` pair this filter was built with.
pub fn dimensions(&self) -> (usize, usize) {
(self.state_dim, self.obs_dim)
}
/// Replaces the state-transition matrix F.
/// NOTE(review): the shape is not validated here — confirm callers pass
/// [state_dim, state_dim].
pub fn set_transition(&mut self, matrix: Tensor) {
self.transition = matrix;
}
/// Replaces the observation matrix H.
/// NOTE(review): the shape is not validated here — confirm callers pass
/// [obs_dim, state_dim].
pub fn set_observation(&mut self, matrix: Tensor) {
self.observation = matrix;
}
/// Replaces the process-noise covariance Q (shape not validated).
pub fn set_process_noise(&mut self, matrix: Tensor) {
self.process_noise = matrix;
}
/// Replaces the measurement-noise covariance R (shape not validated).
pub fn set_measurement_noise(&mut self, matrix: Tensor) {
self.measurement_noise = matrix;
}
/// Borrows the state-transition matrix F.
pub fn transition_matrix(&self) -> &Tensor {
&self.transition
}
/// Borrows the observation matrix H.
pub fn observation_matrix(&self) -> &Tensor {
&self.observation
}
/// Borrows the current state estimate x ([state_dim, 1]).
pub fn state(&self) -> &Tensor {
&self.state
}
/// Borrows the current state covariance P ([state_dim, state_dim]).
pub fn covariance(&self) -> &Tensor {
&self.covariance
}
/// Overwrites the current state estimate and covariance, e.g. to seed the
/// filter before the first `predict`.
/// NOTE(review): shapes are not validated — callers should pass
/// [state_dim, 1] and [state_dim, state_dim]; confirm at call sites.
pub fn set_initial_state(&mut self, state: Tensor, covariance: Tensor) {
self.state = state;
self.covariance = covariance;
}
/// Time-update step: propagates the state estimate and covariance through
/// the transition model.
///
/// x ← F·x and P ← F·P·Fᵀ + Q. Returns a clone of the predicted state.
pub fn predict(&mut self) -> Result<Tensor> {
    // Propagate the mean through the transition matrix.
    self.state = self.transition.matmul(&self.state)?;
    // Propagate the covariance: P = F P Fᵀ + Q.
    let f_t = self.transition.transpose(0, 1)?;
    let propagated = self.transition.matmul(&self.covariance)?.matmul(&f_t)?;
    self.covariance = propagated.add(&self.process_noise)?;
    Ok(self.state.clone())
}
/// Measurement-update step: corrects the state with an observation.
///
/// Computes the innovation y = z − H·x, the regularized innovation
/// covariance S = H·P·Hᵀ + R + λI, the gain K = P·Hᵀ / S[0,0], then
/// applies x ← x + K·y and P ← (I − K·H)·P.
///
/// NOTE(review): the gain divides by only the first element of S, which is
/// exact for obs_dim == 1 but a scalar approximation for multivariate
/// observations — a proper fix needs a matrix solve/inverse. The previous
/// code had byte-identical `if`/`else` branches here; they are collapsed
/// without any numerical change.
pub fn update(&mut self, observation: &Tensor) -> Result<()> {
    // Accept either a flat [obs_dim] vector or an already-shaped column.
    let obs_reshaped = if observation.ndim() == 1 {
        observation.view(&[self.obs_dim as i32, 1])?
    } else {
        observation.clone()
    };
    // Innovation y = z - H x (implemented as z + (-1)·H x).
    let h_x = self.observation.matmul(&self.state)?;
    let innovation = obs_reshaped.add(&h_x.mul_scalar(-1.0)?)?;
    // Innovation covariance S = H P Hᵀ + R, plus a small ridge λI for
    // numerical stability. Hᵀ is computed once and reused below.
    let h_t = self.observation.transpose(0, 1)?;
    let innovation_cov = self
        .observation
        .matmul(&self.covariance)?
        .matmul(&h_t)?
        .add(&self.measurement_noise)?;
    let lambda = 1e-6f32;
    let reg_eye = eye(self.obs_dim)?.mul_scalar(lambda)?;
    let innovation_cov_reg = innovation_cov.add(&reg_eye)?;
    // Kalman gain K = P Hᵀ / S[0,0] (see NOTE above for obs_dim > 1).
    let p_ht = self.covariance.matmul(&h_t)?;
    let s_scalar = innovation_cov_reg.get_item_flat(0)? + 1e-10f32;
    let kalman_gain = p_ht.div_scalar(s_scalar)?;
    // State correction x ← x + K y.
    self.state = self.state.add(&kalman_gain.matmul(&innovation)?)?;
    // Covariance correction P ← (I - K H) P.
    let k_h = kalman_gain.matmul(&self.observation)?;
    let i_minus_kh = eye(self.state_dim)?.add(&k_h.mul_scalar(-1.0)?)?;
    self.covariance = i_minus_kh.matmul(&self.covariance)?;
    Ok(())
}
/// Runs the filter over `series`, one predict/update cycle per step.
///
/// Each step's observation is the scalar at index `t` of `series.values`
/// (this path feeds a 1-element observation, i.e. it assumes
/// obs_dim == 1 usage). Returns the filtered state vectors as a
/// [len, state_dim] series.
pub fn filter(&mut self, series: &TimeSeries) -> Result<TimeSeries> {
    // Preallocate: one state vector of `state_dim` entries per time step
    // (previously grown incrementally via an intermediate Vec + extend).
    let mut filtered_states = Vec::with_capacity(series.len() * self.state_dim);
    for t in 0..series.len() {
        self.predict()?;
        let obs_value = series.values.get_item_flat(t)?;
        let obs = Tensor::from_vec(vec![obs_value], &[1])?;
        self.update(&obs)?;
        // Record the corrected state for this step.
        for i in 0..self.state_dim {
            filtered_states.push(self.state.get_item_flat(i)?);
        }
    }
    let values = Tensor::from_vec(filtered_states, &[series.len(), self.state_dim])?;
    Ok(TimeSeries::new(values))
}
/// Smooths `series`.
///
/// NOTE(review): currently this only forward-filters — no backward
/// (Rauch–Tung–Striebel) pass is performed, so the result is identical to
/// `filter(series)`.
pub fn smooth(&mut self, series: &TimeSeries) -> Result<TimeSeries> {
    self.filter(series)
}
/// Returns the innovation y = z − H·x for `observation` against the
/// current (not-yet-updated) state estimate.
pub fn innovation(&self, observation: &Tensor) -> Result<Tensor> {
    // Normalize a flat [obs_dim] vector into a column [obs_dim, 1].
    let z = if observation.ndim() == 1 {
        observation.view(&[self.obs_dim as i32, 1])?
    } else {
        observation.clone()
    };
    let predicted = self.observation.matmul(&self.state)?;
    // y = z - H x, expressed as z + (-1)·H x.
    z.add(&predicted.mul_scalar(-1.0)?)
}
/// Returns the innovation covariance S = H·P·Hᵀ + R.
pub fn innovation_covariance(&self) -> Result<Tensor> {
    let h_t = self.observation.transpose(0, 1)?;
    self.observation
        .matmul(&self.covariance)?
        .matmul(&h_t)?
        .add(&self.measurement_noise)
}
/// Returns the Kalman gain K = P·Hᵀ / S[0,0] for the current covariance.
///
/// NOTE(review): as in `update`, the division uses only the first element
/// of the regularized S — exact for obs_dim == 1, a scalar approximation
/// otherwise (a proper multivariate gain needs S⁻¹). The previously
/// duplicated, byte-identical `if`/`else` branches are collapsed without
/// any numerical change.
pub fn kalman_gain(&self) -> Result<Tensor> {
    let innovation_cov = self.innovation_covariance()?;
    let p_ht = self.covariance.matmul(&self.observation.transpose(0, 1)?)?;
    // Ridge-regularize S before the (scalar) division for stability.
    let lambda = 1e-6f32;
    let reg_eye = eye(self.obs_dim)?.mul_scalar(lambda)?;
    let innovation_cov_reg = innovation_cov.add(&reg_eye)?;
    let s_scalar = innovation_cov_reg.get_item_flat(0)? + 1e-10f32;
    p_ht.div_scalar(s_scalar)
}
/// Computes the average per-step Gaussian log-likelihood of `series`
/// under the current model, resetting the filter first.
///
/// For obs_dim == 1 this is the exact univariate Gaussian log-density of
/// each innovation. For obs_dim > 1 it substitutes tr(S) for |S| and a
/// diagonal quadratic form for yᵀS⁻¹y — an approximation, not the exact
/// multivariate density.
pub fn log_likelihood(&mut self, series: &TimeSeries) -> Result<f32> {
    self.reset();
    let mut log_likelihood = 0.0f32;
    let n = series.len() as f32;
    let two_pi = 2.0 * std::f32::consts::PI;
    let obs_dim_f32 = self.obs_dim as f32;
    // Constant -k/2 · ln(2π) term of the Gaussian log-density.
    let log_normalization = -0.5 * obs_dim_f32 * two_pi.ln();
    // Loop-invariant ridge λI added to S for numerical stability
    // (hoisted; previously rebuilt on every iteration).
    let lambda = 1e-6f32;
    let reg_eye = eye(self.obs_dim)?.mul_scalar(lambda)?;
    for t in 0..series.len() {
        self.predict()?;
        let obs_value = series.values.get_item_flat(t)?;
        let obs = Tensor::from_vec(vec![obs_value], &[1])?;
        let innovation = self.innovation(&obs)?;
        let innovation_cov_reg = self.innovation_covariance()?.add(&reg_eye)?;
        if self.obs_dim == 1 {
            // Exact scalar Gaussian log-density of the innovation.
            let innovation_val = innovation.get_item_flat(0)?;
            let cov_val = innovation_cov_reg.get_item_flat(0)?.max(1e-10f32);
            let log_det_term = cov_val.ln();
            let quadratic_term = (innovation_val * innovation_val) / cov_val;
            log_likelihood += log_normalization - 0.5 * (log_det_term + quadratic_term);
        } else {
            // Approximate multivariate terms: ‖y‖² and tr(S) stand in
            // for yᵀS⁻¹y and ln|S| (see doc comment).
            let innovation_norm_sq: f32 = (0..self.obs_dim)
                .map(|i| {
                    let val = innovation.get_item_flat(i).unwrap_or(0.0);
                    val * val
                })
                .sum();
            let cov_trace: f32 = (0..self.obs_dim)
                .map(|i| {
                    innovation_cov_reg
                        .get_item_flat(i * self.obs_dim + i)
                        .unwrap_or(1.0)
                })
                .sum();
            log_likelihood += log_normalization
                - 0.5 * (cov_trace.ln() + innovation_norm_sq / cov_trace.max(1e-10f32));
        }
        self.update(&obs)?;
    }
    Ok(log_likelihood / n)
}
/// Resets the belief to x = 0, P = I; the model matrices (F, H, Q, R)
/// are left untouched.
pub fn reset(&mut self) {
    self.state = zeros(&[self.state_dim, 1]).expect("tensor creation should succeed");
    self.covariance = eye(self.state_dim).expect("tensor creation should succeed");
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 5-point series [1, 2, 3, 4, 5] shared by the tests below.
    fn create_test_series() -> TimeSeries {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let tensor = Tensor::from_vec(data, &[5]).expect("Tensor should succeed");
        TimeSeries::new(tensor)
    }

    #[test]
    fn test_kalman_filter_creation() {
        let kf = KalmanFilter::new(2, 1);
        let (state_dim, obs_dim) = kf.dimensions();
        assert_eq!(state_dim, 2);
        assert_eq!(obs_dim, 1);
    }

    #[test]
    fn test_kalman_filter_with_matrices() {
        let transition = eye(2).expect("eye should succeed");
        let observation = ones(&[1, 2]).expect("ones should succeed");
        let process_noise = eye(2).expect("eye should succeed");
        let measurement_noise = eye(1).expect("eye should succeed");
        let kf = KalmanFilter::with_matrices(
            2,
            1,
            transition,
            observation,
            process_noise,
            measurement_noise,
        );
        let (state_dim, obs_dim) = kf.dimensions();
        assert_eq!(state_dim, 2);
        assert_eq!(obs_dim, 1);
    }

    #[test]
    fn test_kalman_filter_matrices() {
        // Replacing F must preserve the expected [2, 2] shape.
        let mut kf = KalmanFilter::new(2, 1);
        let new_transition = eye(2).expect("eye should succeed");
        kf.set_transition(new_transition);
        assert_eq!(kf.transition_matrix().shape().dims(), [2, 2]);
    }

    #[test]
    fn test_kalman_filter_state() {
        let mut kf = KalmanFilter::new(2, 1);
        let initial_state = zeros(&[2, 1]).expect("zeros should succeed");
        let initial_cov = eye(2).expect("eye should succeed");
        kf.set_initial_state(initial_state, initial_cov);
        assert_eq!(kf.state().shape().dims(), [2, 1]);
        assert_eq!(kf.covariance().shape().dims(), [2, 2]);
    }

    #[test]
    fn test_kalman_filter_predict() {
        let mut kf = KalmanFilter::new(2, 1);
        let prediction = kf.predict().expect("prediction should succeed");
        assert_eq!(prediction.shape().dims(), [2, 1]);
    }

    #[test]
    fn test_kalman_filter_update() {
        // A zero observation must be accepted without error.
        let mut kf = KalmanFilter::new(2, 1);
        let obs = zeros(&[1]).expect("zeros should succeed");
        kf.update(&obs).expect("update operation should succeed");
    }

    #[test]
    fn test_kalman_filter_filter() {
        let series = create_test_series();
        let mut kf = KalmanFilter::new(1, 1);
        let filtered = kf.filter(&series).expect("filter operation should succeed");
        assert_eq!(filtered.len(), series.len());
    }

    #[test]
    fn test_kalman_filter_smooth() {
        let series = create_test_series();
        let mut kf = KalmanFilter::new(1, 1);
        let smoothed = kf.smooth(&series).expect("smoothing should succeed");
        assert_eq!(smoothed.len(), series.len());
    }

    #[test]
    fn test_kalman_filter_innovation() {
        let kf = KalmanFilter::new(1, 1);
        let obs = ones(&[1]).expect("ones should succeed");
        let innovation = kf
            .innovation(&obs)
            .expect("innovation computation should succeed");
        // Flat input must come back as a [1, 1] column.
        assert_eq!(innovation.shape().dims(), [1, 1]);
    }

    #[test]
    fn test_kalman_filter_log_likelihood() {
        let series = create_test_series();
        let mut kf = KalmanFilter::new(1, 1);
        let ll = kf
            .log_likelihood(&series)
            .expect("log-likelihood computation should succeed");
        assert!(ll < 0.0);
        assert!(ll.is_finite());
    }

    #[test]
    fn test_kalman_filter_reset() {
        let mut kf = KalmanFilter::new(2, 1);
        kf.reset();
        assert_eq!(kf.state().shape().dims(), [2, 1]);
        assert_eq!(kf.covariance().shape().dims(), [2, 2]);
    }
}