#![allow(clippy::needless_range_loop)]
use nalgebra::{Const, DimName, OMatrix, OVector};
use super::types::UpdateStats;
/// Recursive least-squares estimator that propagates a square-root factor of
/// the information matrix rather than a covariance matrix.
///
/// Estimates an `N × P` parameter matrix `x` such that `y ≈ xᵀ a` for
/// regressors `a` of length `N` and outputs `y` of length `P`, with exponential
/// forgetting controlled by `lambda`.
#[derive(Clone, Debug)]
pub struct InverseQrRls<const N: usize, const P: usize> {
    /// Parameter estimates; column `p` models output `y[p]`.
    x: OMatrix<f32, Const<N>, Const<P>>,
    /// Lower-triangular factor `L` of the information matrix `L Lᵀ`
    /// (initialised to `I / sqrt(gamma)` by `new`).
    l: OMatrix<f32, Const<N>, Const<N>>,
    /// Forgetting factor applied to `L Lᵀ` on every update.
    lambda: f32,
    /// Cached `sqrt(lambda)`: the per-update scale applied to `L` itself.
    sqrt_lambda: f32,
    /// Count of updates performed.
    samples: u32,
}
impl<const N: usize, const P: usize> InverseQrRls<N, P>
where
    Const<N>: DimName,
    Const<P>: DimName,
{
    /// Creates an estimator with all parameters at zero.
    ///
    /// * `gamma` – initial covariance scale. The information factor starts as
    ///   `I / sqrt(gamma)`, i.e. `L Lᵀ = I / gamma`, so a larger `gamma` means
    ///   less confidence in the zero initial estimate. Expected to be `> 0`;
    ///   otherwise the factor is infinite or NaN.
    /// * `lambda` – forgetting factor applied to the information matrix on
    ///   every update. Its square root is cached for the update loop.
    pub fn new(gamma: f32, lambda: f32) -> Self {
        let inv_sqrt_gamma = 1.0 / libm::sqrtf(gamma);
        let mut l = OMatrix::<f32, Const<N>, Const<N>>::zeros();
        for i in 0..N {
            l[(i, i)] = inv_sqrt_gamma;
        }
        Self {
            x: OMatrix::<f32, Const<N>, Const<P>>::zeros(),
            l,
            lambda,
            sqrt_lambda: libm::sqrtf(lambda),
            samples: 0,
        }
    }

    /// Creates an estimator whose forgetting factor is derived from a sample
    /// period `ts` and a characteristic time `t_char`, both in the same unit.
    ///
    /// `lambda = (1 - ln 2)^(ts / t_char)`. After `t_char / ts` updates, the
    /// accumulated information has therefore been scaled by
    /// `1 - ln 2 ≈ 0.307`. `t_char` is expected to be `> 0`.
    pub fn from_time_constant(gamma: f32, ts: f32, t_char: f32) -> Self {
        let lambda = libm::powf(1.0 - core::f32::consts::LN_2, ts / t_char);
        Self::new(gamma, lambda)
    }

    /// Folds one observation `y ≈ xᵀ a` into the estimate.
    ///
    /// * `a` – regressor vector (length `N`).
    /// * `y` – measured outputs (length `P`). Column `p` of the parameter
    ///   matrix is corrected with its own prediction error `y[p] - a·x[:, p]`.
    ///
    /// The information factor is first discounted so that `L Lᵀ ← lambda·L Lᵀ`.
    /// It is then updated with Givens rotations so that `L Lᵀ ← L Lᵀ + a aᵀ`.
    /// The gain applied to the parameters is `(L Lᵀ)⁻¹ a`, computed with the
    /// *updated* factor.
    ///
    /// Returns [`UpdateStats`] carrying `ExitCode::Success` and the running
    /// sample count. The count saturates at `u32::MAX`.
    #[inline]
    pub fn update(
        &mut self,
        a: &OVector<f32, Const<N>>,
        y: &OVector<f32, Const<P>>,
    ) -> UpdateStats {
        // `+= 1` would panic in debug builds once a long-running loop reaches
        // u32::MAX; saturate instead.
        self.samples = self.samples.saturating_add(1);
        self.discount();
        let b = self.absorb_regressor(a);
        let gain = self.solve_gain(b);
        self.correct_params(a, y, &gain);
        UpdateStats {
            exit_code: super::types::ExitCode::Success,
            samples: self.samples,
        }
    }

    /// Scales the lower triangle of `L` by `sqrt(lambda)`, i.e. `L Lᵀ ← lambda·L Lᵀ`.
    fn discount(&mut self) {
        let sqrt_lambda = self.sqrt_lambda;
        for col in 0..N {
            for row in col..N {
                self.l[(row, col)] *= sqrt_lambda;
            }
        }
    }

    /// Rotates `a` into `L` so that `L Lᵀ` gains `a aᵀ`, and returns
    /// `b = L⁻¹ a` for the updated `L`.
    ///
    /// `b` is accumulated from the rotations themselves: it is the last row of
    /// the augmented factor `[[L, a], [0, 1]]` after rotation, so no separate
    /// triangular solve is needed.
    fn absorb_regressor(&mut self, a: &OVector<f32, Const<N>>) -> [f32; N] {
        let mut a_work: [f32; N] = core::array::from_fn(|i| a[i]);
        let mut b = [0.0f32; N];
        // Running product of rotation cosines: the bottom-right entry of the
        // augmented factor after the rotations applied so far.
        let mut cos_prod: f32 = 1.0;
        for j in 0..N {
            // NOTE(review): assumes givens(x, y) returns (c, s) with
            // s*x + c*y == 0, so that a_work[j] is annihilated below — confirm
            // against crate::givens.
            let (c, s) = crate::givens::givens(self.l[(j, j)], a_work[j]);
            for k in j..N {
                let lkj = self.l[(k, j)];
                let ak = a_work[k];
                self.l[(k, j)] = c * lkj - s * ak;
                a_work[k] = s * lkj + c * ak;
            }
            b[j] = -s * cos_prod;
            cos_prod *= c;
        }
        b
    }

    /// Solves `Lᵀ g = b` by back-substitution (`L` is lower-triangular).
    ///
    /// When `b = L⁻¹ a`, this gives `g = (L Lᵀ)⁻¹ a`.
    ///
    /// NOTE(review): divides by the diagonal of `L` without a guard. If any
    /// diagonal entry is zero, the gain (and hence the parameters) becomes
    /// non-finite.
    fn solve_gain(&self, b: [f32; N]) -> [f32; N] {
        let mut gain = b;
        for j in (0..N).rev() {
            for k in (j + 1)..N {
                gain[j] -= self.l[(k, j)] * gain[k];
            }
            gain[j] /= self.l[(j, j)];
        }
        gain
    }

    /// Applies `x[:, p] += gain · (y[p] - a·x[:, p])` to every output column.
    ///
    /// The prediction error uses the parameters as they were before this
    /// correction.
    fn correct_params(
        &mut self,
        a: &OVector<f32, Const<N>>,
        y: &OVector<f32, Const<P>>,
        gain: &[f32; N],
    ) {
        for p in 0..P {
            let mut e = y[p];
            for i in 0..N {
                e -= a[i] * self.x[(i, p)];
            }
            for i in 0..N {
                self.x[(i, p)] += gain[i] * e;
            }
        }
    }

    /// Current parameter estimates; column `p` models output `y[p]`.
    #[inline]
    pub fn params(&self) -> &OMatrix<f32, Const<N>, Const<P>> {
        &self.x
    }

    /// Mutable access to the parameter estimates, e.g. to seed or reset them.
    #[inline]
    pub fn params_mut(&mut self) -> &mut OMatrix<f32, Const<N>, Const<P>> {
        &mut self.x
    }

    /// Lower-triangular factor `L` whose product `L Lᵀ` is the
    /// forgetting-weighted information matrix.
    #[inline]
    pub fn info_factor(&self) -> &OMatrix<f32, Const<N>, Const<N>> {
        &self.l
    }

    /// Current forgetting factor.
    #[inline]
    pub fn lambda(&self) -> f32 {
        self.lambda
    }

    /// Replaces the forgetting factor used by subsequent updates.
    ///
    /// Information that has already been accumulated is not rescaled.
    #[inline]
    pub fn set_lambda(&mut self, lambda: f32) {
        self.lambda = lambda;
        self.sqrt_lambda = libm::sqrtf(lambda);
    }

    /// Number of updates performed (saturates at `u32::MAX`).
    #[inline]
    pub fn samples(&self) -> u32 {
        self.samples
    }
}