Skip to main content

scirs2_stats/
state_space.rs

//! State Space Models
//!
//! This module provides linear and nonlinear state space model implementations:
//!
//! - **Kalman Filter** – optimal linear-Gaussian filter with Rauch-Tung-Striebel
//!   smoother and full log-likelihood computation.
//! - **Unscented Kalman Filter (UKF)** – sigma-point method for nonlinear systems.
//! - **Structural Time Series** – local-level (random walk + noise) and local-linear
//!   trend models with Kalman-filter fitting.
//!
//! # Notation
//!
//! The standard state-space model is:
//! ```text
//! x_{t+1} = F * x_t + w_t,   w_t ~ N(0, Q)   (state transition)
//! y_t     = H * x_t + v_t,   v_t ~ N(0, R)   (observation)
//! ```
//!
//! where `x_t ∈ ℝ^n` is the latent state, `y_t ∈ ℝ^m` is the observation,
//! `F ∈ ℝ^{n×n}`, `H ∈ ℝ^{m×n}`, `Q ∈ ℝ^{n×n}`, `R ∈ ℝ^{m×m}`.
//!
//! # References
//! - Kalman, R.E. (1960). "A new approach to linear filtering and prediction problems."
//!   *J. Basic Engineering* 82(1).
//! - Rauch, H.E., Tung, F., & Striebel, C.T. (1965). "Maximum likelihood estimates
//!   of linear dynamic systems." *AIAA J.* 3(8).
//! - Julier, S.J. & Uhlmann, J.K. (1997). "New extension of the Kalman filter to
//!   nonlinear systems." *Proc. SPIE* 3068.
//! - Durbin, J. & Koopman, S.J. (2012). *Time Series Analysis by State Space Methods*.
30
31use crate::error::{StatsError, StatsResult};
32use scirs2_core::ndarray::{Array1, Array2, Axis};
33
34// ─────────────────────────────────────────────────────────────────────────────
35// KalmanState
36// ─────────────────────────────────────────────────────────────────────────────
37
38/// Kalman filter state: mean vector and covariance matrix.
39#[derive(Clone, Debug)]
40pub struct KalmanState {
41    /// State mean `x` of dimension `n`.
42    pub x: Array1<f64>,
43    /// State covariance `P` of shape `n × n`.
44    pub p: Array2<f64>,
45}
46
47impl KalmanState {
48    /// Create a new `KalmanState`.
49    pub fn new(x: Array1<f64>, p: Array2<f64>) -> StatsResult<Self> {
50        let n = x.len();
51        if p.nrows() != n || p.ncols() != n {
52            return Err(StatsError::DimensionMismatch(format!(
53                "KalmanState: x has length {} but P is {}×{}",
54                n,
55                p.nrows(),
56                p.ncols()
57            )));
58        }
59        Ok(Self { x, p })
60    }
61
62    /// State dimension.
63    #[inline]
64    pub fn dim(&self) -> usize {
65        self.x.len()
66    }
67}
68
69// ─────────────────────────────────────────────────────────────────────────────
70// Kalman Filter
71// ─────────────────────────────────────────────────────────────────────────────
72
73/// Linear Kalman filter.
74///
75/// Provides `predict` and `update` steps that can be composed into a full
76/// filtering pass, together with Rauch-Tung-Striebel (RTS) fixed-interval
77/// smoothing.
78pub struct KalmanFilter;
79
80impl KalmanFilter {
81    /// **Predict** step: propagate state through the transition model.
82    ///
83    /// ```text
84    /// x̄_{t+1} = F * x_t
85    /// P̄_{t+1} = F * P_t * F^T + Q
86    /// ```
87    ///
88    /// # Arguments
89    /// * `state` – current filtered state.
90    /// * `f`     – `n × n` state transition matrix.
91    /// * `q`     – `n × n` process noise covariance.
92    pub fn predict(
93        state: &KalmanState,
94        f: &Array2<f64>,
95        q: &Array2<f64>,
96    ) -> StatsResult<KalmanState> {
97        let n = state.dim();
98        check_square(f, n, "F")?;
99        check_square(q, n, "Q")?;
100
101        let x_pred = mat_vec_mul(f, &state.x)?;
102        let fp = mat_mat_mul(f, &state.p)?;
103        let p_pred = mat_mat_mul_at(&fp, f)? + q;
104
105        KalmanState::new(x_pred, p_pred)
106    }
107
108    /// **Update** (correct) step: incorporate a new observation.
109    ///
110    /// ```text
111    /// v   = y - H * x̄            (innovation)
112    /// S   = H * P̄ * H^T + R      (innovation covariance)
113    /// K   = P̄ * H^T * S^{-1}     (Kalman gain)
114    /// x_t = x̄ + K * v
115    /// P_t = (I - K*H) * P̄
116    /// ```
117    ///
118    /// # Arguments
119    /// * `state` – prior (predicted) state.
120    /// * `y`     – observation vector of dimension `m`.
121    /// * `h`     – `m × n` observation matrix.
122    /// * `r`     – `m × m` observation noise covariance.
123    pub fn update(
124        state: &KalmanState,
125        y: &Array1<f64>,
126        h: &Array2<f64>,
127        r: &Array2<f64>,
128    ) -> StatsResult<KalmanState> {
129        let n = state.dim();
130        let m = y.len();
131
132        if h.nrows() != m || h.ncols() != n {
133            return Err(StatsError::DimensionMismatch(format!(
134                "H must be {}×{}, got {}×{}",
135                m,
136                n,
137                h.nrows(),
138                h.ncols()
139            )));
140        }
141        if r.nrows() != m || r.ncols() != m {
142            return Err(StatsError::DimensionMismatch(format!(
143                "R must be {}×{}, got {}×{}",
144                m,
145                m,
146                r.nrows(),
147                r.ncols()
148            )));
149        }
150
151        // Innovation: v = y - H x̄
152        let hx = mat_vec_mul(h, &state.x)?;
153        let innovation = y - &hx;
154
155        // S = H P̄ H^T + R
156        let hp = mat_mat_mul(h, &state.p)?; // m × n
157        let s = mat_mat_mul_at(&hp, h)? + r; // m × m
158
159        // K = P̄ H^T S^{-1}
160        let ph_t = mat_mat_mul_bt(&state.p, h)?; // n × m
161        let s_inv = inv_symmetric(s)?; // m × m
162        let k = mat_mat_mul(&ph_t, &s_inv)?; // n × m
163
164        // x updated
165        let kv = mat_vec_mul(&k, &innovation)?;
166        let x_upd = &state.x + &kv;
167
168        // P updated: (I - K H) P̄  (Joseph form for numerical stability)
169        let kh = mat_mat_mul(&k, h)?; // n × n
170        let i_kh = eye_minus(kh)?; // n × n
171        let p_upd = mat_mat_mul(&i_kh, &state.p)?; // n × n
172
173        KalmanState::new(x_upd, p_upd)
174    }
175
176    /// Run the full Kalman filter over an observation time series.
177    ///
178    /// Observations may be univariate (`y_t ∈ ℝ`) or multivariate
179    /// (`y_t ∈ ℝ^m`). Pass each row of your observation matrix as a
180    /// separate element in `observations`.
181    ///
182    /// # Arguments
183    /// * `observations` – length-T slice of observation vectors.
184    /// * `f`            – `n × n` transition matrix.
185    /// * `h`            – `m × n` observation matrix.
186    /// * `q`            – `n × n` process noise covariance.
187    /// * `r`            – `m × m` observation noise covariance.
188    /// * `x0`           – initial state mean.
189    /// * `p0`           – initial state covariance.
190    ///
191    /// # Returns
192    /// `(filtered_states, log_likelihood)`.
193    pub fn filter_series(
194        observations: &[Array1<f64>],
195        f: &Array2<f64>,
196        h: &Array2<f64>,
197        q: &Array2<f64>,
198        r: &Array2<f64>,
199        x0: Array1<f64>,
200        p0: Array2<f64>,
201    ) -> StatsResult<(Vec<KalmanState>, f64)> {
202        if observations.is_empty() {
203            return Err(StatsError::InsufficientData(
204                "filter_series: observation list is empty".into(),
205            ));
206        }
207
208        let n = x0.len();
209        let m = observations[0].len();
210
211        if p0.nrows() != n || p0.ncols() != n {
212            return Err(StatsError::DimensionMismatch("p0 must be n×n".into()));
213        }
214
215        let log2pi = (2.0 * std::f64::consts::PI).ln();
216        let mut log_lik = 0.0_f64;
217        let mut states = Vec::with_capacity(observations.len());
218        let mut state = KalmanState::new(x0, p0)?;
219
220        for (t, y) in observations.iter().enumerate() {
221            if y.len() != m {
222                return Err(StatsError::DimensionMismatch(format!(
223                    "Observation {} has length {}, expected {}",
224                    t,
225                    y.len(),
226                    m
227                )));
228            }
229
230            // Predict
231            let pred = Self::predict(&state, f, q)?;
232
233            // Compute innovation and its covariance for log-likelihood
234            let hx = mat_vec_mul(h, &pred.x)?;
235            let innovation = y - &hx;
236            let hp = mat_mat_mul(h, &pred.p)?;
237            let s = mat_mat_mul_at(&hp, h)? + r;
238
239            // Log-likelihood contribution: -0.5 * (m*log2pi + log|S| + v^T S^{-1} v)
240            let s_inv = inv_symmetric(s.clone())?;
241            let log_det_s = log_det_posdef(&s)?;
242            let sv = mat_vec_mul(&s_inv, &innovation)?;
243            let quad: f64 = innovation.iter().zip(sv.iter()).map(|(&a, &b)| a * b).sum();
244            log_lik += -0.5 * (m as f64 * log2pi + log_det_s + quad);
245
246            // Update
247            state = Self::update(&pred, y, h, r)?;
248            states.push(state.clone());
249        }
250
251        Ok((states, log_lik))
252    }
253
254    /// Rauch-Tung-Striebel (RTS) fixed-interval smoother.
255    ///
256    /// Given a sequence of *filtered* states (output of [`KalmanFilter::filter_series`])
257    /// and the transition matrix, computes the smoothed state estimates
258    /// E[x_t | y_{1:T}] for all t.
259    ///
260    /// # Arguments
261    /// * `filtered` – filtered states in forward time order.
262    /// * `f`        – `n × n` transition matrix.
263    /// * `q`        – `n × n` process noise covariance.
264    ///
265    /// # Returns
266    /// Smoothed states in forward time order.
267    pub fn smooth(
268        filtered: &[KalmanState],
269        f: &Array2<f64>,
270        q: &Array2<f64>,
271    ) -> StatsResult<Vec<KalmanState>> {
272        let t_len = filtered.len();
273        if t_len == 0 {
274            return Ok(Vec::new());
275        }
276
277        let mut smoothed = filtered.to_vec();
278
279        for t in (0..t_len - 1).rev() {
280            let n = filtered[t].dim();
281            // Predicted state at t+1 (given filtered state at t)
282            let x_pred = mat_vec_mul(f, &filtered[t].x)?;
283            let fp = mat_mat_mul(f, &filtered[t].p)?;
284            let p_pred = mat_mat_mul_at(&fp, f)? + q;
285
286            // Smoother gain: G_t = P_t * F^T * P̄_{t+1}^{-1}
287            let pf_t = mat_mat_mul_bt(&filtered[t].p, f)?; // n × n
288            let p_pred_inv = inv_symmetric(p_pred)?;
289            let g = mat_mat_mul(&pf_t, &p_pred_inv)?; // n × n
290
291            // Smoothed mean: x_s_t = x_t + G_t * (x_s_{t+1} - x̄_{t+1})
292            let diff = &smoothed[t + 1].x - &x_pred;
293            let g_diff = mat_vec_mul(&g, &diff)?;
294            let x_smooth = &filtered[t].x + &g_diff;
295
296            // Smoothed covariance: P_s_t = P_t + G_t * (P_s_{t+1} - P̄_{t+1}) * G_t^T
297            let dp = &smoothed[t + 1].p
298                - &{
299                    // Reconstruct P̄_{t+1} from filtered[t]
300                    let fp2 = mat_mat_mul(f, &filtered[t].p)?;
301                    mat_mat_mul_at(&fp2, f)? + q
302                };
303            let g_dp = mat_mat_mul(&g, &dp)?;
304            let correction = mat_mat_mul_at(&g_dp, &g)?;
305            let p_smooth = &filtered[t].p + &correction;
306
307            smoothed[t] = KalmanState::new(x_smooth, p_smooth)?;
308        }
309
310        Ok(smoothed)
311    }
312}
313
// ─────────────────────────────────────────────────────────────────────────────
// Unscented Kalman Filter (UKF)
// ─────────────────────────────────────────────────────────────────────────────

/// Parameters for the scaled unscented transform.
///
/// The defaults (`alpha = 1e-3`, `beta = 2`, `kappa = 0`) are the usual
/// choice for Gaussian state distributions.
#[derive(Clone, Debug, PartialEq)]
pub struct UkfParams {
    /// Spread of sigma points around the mean (small positive value, typically 1e-3).
    pub alpha: f64,
    /// Incorporates prior knowledge of the distribution; 2.0 is optimal for a
    /// Gaussian distribution (and is the default).
    pub beta: f64,
    /// Secondary scaling parameter (typically 0).
    pub kappa: f64,
}

impl Default for UkfParams {
    fn default() -> Self {
        Self {
            alpha: 1e-3,
            beta: 2.0,
            kappa: 0.0,
        }
    }
}
338
339/// Sigma-point weights for mean and covariance estimation.
340#[derive(Clone, Debug)]
341struct SigmaWeights {
342    /// Weights for mean estimation.
343    wm: Vec<f64>,
344    /// Weights for covariance estimation.
345    wc: Vec<f64>,
346    /// Sigma-point scaling factor.
347    lambda: f64,
348}
349
350impl SigmaWeights {
351    fn compute(n: usize, params: &UkfParams) -> Self {
352        let n_f = n as f64;
353        let lambda = params.alpha * params.alpha * (n_f + params.kappa) - n_f;
354        let n_sigma = 2 * n + 1;
355        let mut wm = vec![0.0_f64; n_sigma];
356        let mut wc = vec![0.0_f64; n_sigma];
357
358        wm[0] = lambda / (n_f + lambda);
359        wc[0] = wm[0] + 1.0 - params.alpha * params.alpha + params.beta;
360
361        let w_rest = 1.0 / (2.0 * (n_f + lambda));
362        for i in 1..n_sigma {
363            wm[i] = w_rest;
364            wc[i] = w_rest;
365        }
366
367        Self { wm, wc, lambda }
368    }
369}
370
371/// Generate `2n+1` sigma points from a state (mean + Cholesky of covariance).
372///
373/// Returns a `Vec<Array1<f64>>` of length `2n+1`.
374fn sigma_points(
375    state: &KalmanState,
376    params: &UkfParams,
377    weights: &SigmaWeights,
378) -> StatsResult<Vec<Array1<f64>>> {
379    let n = state.dim();
380    // Scaled covariance matrix: (n + λ) P
381    let scale = n as f64 + weights.lambda;
382    let scaled_p = state.p.mapv(|v| v * scale);
383
384    // Cholesky decomposition of scaled_p
385    let sqrt_p = cholesky_lower(scaled_p)?;
386
387    let mut sigmas = Vec::with_capacity(2 * n + 1);
388    sigmas.push(state.x.clone()); // σ_0 = x̄
389
390    for i in 0..n {
391        let col = sqrt_p.column(i).to_owned();
392        sigmas.push(&state.x + &col); // σ_{i+1}   = x̄ + sqrt_col_i
393        sigmas.push(&state.x - &col); // σ_{n+i+1} = x̄ - sqrt_col_i
394    }
395
396    Ok(sigmas)
397}
398
399/// Unscented Kalman Filter for nonlinear systems.
400pub struct UnscentedKalmanFilter {
401    /// Sigma-point parameters.
402    pub params: UkfParams,
403}
404
405impl UnscentedKalmanFilter {
406    /// Create a new UKF with the given sigma-point parameters.
407    pub fn new(params: UkfParams) -> Self {
408        Self { params }
409    }
410
411    /// Create a new UKF with default sigma-point parameters.
412    pub fn default() -> Self {
413        Self {
414            params: UkfParams::default(),
415        }
416    }
417
418    /// **Predict** step for a nonlinear transition function `f_fn(x) -> x'`.
419    ///
420    /// # Arguments
421    /// * `state`  – current posterior state.
422    /// * `f_fn`   – state transition function `x_t -> x_{t+1}`.
423    /// * `q`      – process noise covariance (n × n).
424    pub fn predict<F>(
425        &self,
426        state: &KalmanState,
427        f_fn: F,
428        q: &Array2<f64>,
429    ) -> StatsResult<KalmanState>
430    where
431        F: Fn(&Array1<f64>) -> Array1<f64>,
432    {
433        let n = state.dim();
434        let weights = SigmaWeights::compute(n, &self.params);
435        let sigmas = sigma_points(state, &self.params, &weights)?;
436
437        // Propagate sigma points through f
438        let propagated: Vec<Array1<f64>> = sigmas.iter().map(|s| f_fn(s)).collect();
439
440        // Compute predicted mean
441        let x_pred = weighted_mean(&propagated, &weights.wm)?;
442
443        // Compute predicted covariance
444        let p_pred = weighted_covariance(&propagated, &x_pred, &weights.wc, Some(q))?;
445
446        KalmanState::new(x_pred, p_pred)
447    }
448
449    /// **Update** step for a nonlinear observation function `h_fn(x) -> y`.
450    ///
451    /// # Arguments
452    /// * `state`  – prior (predicted) state.
453    /// * `y`      – actual observation vector.
454    /// * `h_fn`   – observation function `x -> y`.
455    /// * `r`      – observation noise covariance (m × m).
456    pub fn update<H>(
457        &self,
458        state: &KalmanState,
459        y: &Array1<f64>,
460        h_fn: H,
461        r: &Array2<f64>,
462    ) -> StatsResult<KalmanState>
463    where
464        H: Fn(&Array1<f64>) -> Array1<f64>,
465    {
466        let n = state.dim();
467        let weights = SigmaWeights::compute(n, &self.params);
468        let sigmas = sigma_points(state, &self.params, &weights)?;
469
470        // Propagate sigma points through h
471        let y_sigmas: Vec<Array1<f64>> = sigmas.iter().map(|s| h_fn(s)).collect();
472
473        // Predicted measurement mean
474        let y_pred = weighted_mean(&y_sigmas, &weights.wm)?;
475
476        // Innovation covariance S_yy = Σ wc_i (z_i - ȳ)(z_i - ȳ)^T + R
477        let s_yy = weighted_covariance(&y_sigmas, &y_pred, &weights.wc, Some(r))?;
478
479        // Cross covariance P_xy = Σ wc_i (σ_i - x̄)(z_i - ȳ)^T
480        let p_xy = weighted_cross_covariance(&sigmas, &state.x, &y_sigmas, &y_pred, &weights.wc)?;
481
482        // Kalman gain K = P_xy * S_yy^{-1}
483        let s_inv = inv_symmetric(s_yy)?;
484        let k = mat_mat_mul(&p_xy, &s_inv)?; // n × m
485
486        // Update mean
487        let innovation = y - &y_pred;
488        let kv = mat_vec_mul(&k, &innovation)?;
489        let x_upd = &state.x + &kv;
490
491        // Update covariance P = P̄ - K * S_yy * K^T
492        let ks = mat_mat_mul(&k, &{
493            let s_yy2 = weighted_covariance(&y_sigmas, &y_pred, &weights.wc, Some(r))?;
494            s_yy2
495        })?;
496        let correction = mat_mat_mul_at(&ks, &k)?;
497        let p_upd = &state.p - &correction;
498
499        KalmanState::new(x_upd, p_upd)
500    }
501}
502
// ─────────────────────────────────────────────────────────────────────────────
// Structural Time Series
// ─────────────────────────────────────────────────────────────────────────────

/// Which structural time-series model was fitted, together with its variances.
#[derive(Clone, Debug, PartialEq)]
pub enum StsModel {
    /// Local level: the level follows a random walk and is observed with noise.
    LocalLevel {
        /// Level-innovation variance σ²_η.
        level_variance: f64,
        /// Observation-noise variance σ²_ε.
        obs_variance: f64,
    },
    /// Local linear trend: the level and its slope both follow random walks.
    LocalLinearTrend {
        /// Level-innovation variance σ²_η.
        level_variance: f64,
        /// Slope-innovation variance σ²_ζ.
        slope_variance: f64,
        /// Observation-noise variance σ²_ε.
        obs_variance: f64,
    },
}
527
528/// Result of fitting a Structural Time Series model.
529#[derive(Clone, Debug)]
530pub struct StsFitResult {
531    /// The fitted model variant.
532    pub model: StsModel,
533    /// Smoothed state estimates (one per time point).
534    pub smoothed_states: Vec<KalmanState>,
535    /// Filtered state estimates (one per time point).
536    pub filtered_states: Vec<KalmanState>,
537    /// Log-likelihood of the observations under the model.
538    pub log_likelihood: f64,
539    /// One-step-ahead forecasts (level component).
540    pub fitted_values: Vec<f64>,
541    /// Standardised prediction residuals.
542    pub residuals: Vec<f64>,
543}
544
545/// Structural Time Series builder and fitter.
546pub struct StructuralTimeSeries;
547
548impl StructuralTimeSeries {
549    /// Fit a **local level** model to a univariate time series.
550    ///
551    /// The local level model is:
552    /// ```text
553    /// y_t = μ_t + ε_t,  ε_t ~ N(0, σ²_ε)
554    /// μ_{t+1} = μ_t + η_t,  η_t ~ N(0, σ²_η)
555    /// ```
556    ///
557    /// State `x_t = [μ_t]`.
558    ///
559    /// # Arguments
560    /// * `y`             – univariate time series (length T).
561    /// * `level_var`     – process noise variance σ²_η (must be > 0).
562    /// * `obs_var`       – observation noise variance σ²_ε (must be > 0).
563    /// * `init_level`    – initial level estimate (defaults to `y[0]`).
564    /// * `init_var`      – initial state variance (defaults to 1e6 for diffuse init).
565    pub fn fit_local_level(
566        y: &[f64],
567        level_var: f64,
568        obs_var: f64,
569        init_level: Option<f64>,
570        init_var: Option<f64>,
571    ) -> StatsResult<StsFitResult> {
572        if y.is_empty() {
573            return Err(StatsError::InsufficientData(
574                "Time series must not be empty".into(),
575            ));
576        }
577        if level_var <= 0.0 {
578            return Err(StatsError::DomainError("level_var must be positive".into()));
579        }
580        if obs_var <= 0.0 {
581            return Err(StatsError::DomainError("obs_var must be positive".into()));
582        }
583
584        use scirs2_core::ndarray::{array, Array1, Array2};
585
586        // State: [level]
587        let f = array![[1.0_f64]]; // transition
588        let h = array![[1.0_f64]]; // observation
589        let q = array![[level_var]]; // process noise
590        let r = array![[obs_var]]; // observation noise
591
592        let x0 = Array1::from_elem(1, init_level.unwrap_or(y[0]));
593        let p0 = Array2::from_elem((1, 1), init_var.unwrap_or(1e6));
594
595        let obs_vecs: Vec<Array1<f64>> = y.iter().map(|&yi| Array1::from_elem(1, yi)).collect();
596
597        let (filtered, log_lik) = KalmanFilter::filter_series(&obs_vecs, &f, &h, &q, &r, x0, p0)?;
598
599        let smoothed = KalmanFilter::smooth(&filtered, &f, &q)?;
600
601        let fitted_values: Vec<f64> = filtered.iter().map(|s| s.x[0]).collect();
602        let residuals: Vec<f64> = y
603            .iter()
604            .zip(fitted_values.iter())
605            .map(|(&yi, &fi)| (yi - fi) / obs_var.sqrt())
606            .collect();
607
608        Ok(StsFitResult {
609            model: StsModel::LocalLevel {
610                level_variance: level_var,
611                obs_variance: obs_var,
612            },
613            smoothed_states: smoothed,
614            filtered_states: filtered,
615            log_likelihood: log_lik,
616            fitted_values,
617            residuals,
618        })
619    }
620
621    /// Fit a **local linear trend** model to a univariate time series.
622    ///
623    /// The local linear trend model is:
624    /// ```text
625    /// y_t     = μ_t + ε_t,       ε_t  ~ N(0, σ²_ε)
626    /// μ_{t+1} = μ_t + ν_t + η_t, η_t  ~ N(0, σ²_η)
627    /// ν_{t+1} = ν_t + ζ_t,       ζ_t  ~ N(0, σ²_ζ)
628    /// ```
629    ///
630    /// State `x_t = [μ_t, ν_t]`.
631    ///
632    /// # Arguments
633    /// * `y`           – univariate time series.
634    /// * `level_var`   – level innovation variance σ²_η.
635    /// * `slope_var`   – slope innovation variance σ²_ζ.
636    /// * `obs_var`     – observation noise variance σ²_ε.
637    /// * `init_level`  – initial level (defaults to `y[0]`).
638    /// * `init_slope`  – initial slope (defaults to 0).
639    /// * `init_var`    – initial state variance (defaults to 1e6).
640    pub fn fit_local_linear_trend(
641        y: &[f64],
642        level_var: f64,
643        slope_var: f64,
644        obs_var: f64,
645        init_level: Option<f64>,
646        init_slope: Option<f64>,
647        init_var: Option<f64>,
648    ) -> StatsResult<StsFitResult> {
649        if y.is_empty() {
650            return Err(StatsError::InsufficientData(
651                "Time series must not be empty".into(),
652            ));
653        }
654        if level_var < 0.0 || slope_var < 0.0 {
655            return Err(StatsError::DomainError(
656                "Variance parameters must be non-negative".into(),
657            ));
658        }
659        if obs_var <= 0.0 {
660            return Err(StatsError::DomainError("obs_var must be positive".into()));
661        }
662
663        use scirs2_core::ndarray::{array, Array1, Array2};
664
665        // State: [level, slope]
666        let f = array![[1.0_f64, 1.0], [0.0, 1.0]];
667        let h = array![[1.0_f64, 0.0]];
668        let q = array![[level_var, 0.0], [0.0, slope_var]];
669        let r = array![[obs_var]];
670
671        let iv = init_var.unwrap_or(1e6);
672        let x0 = Array1::from_vec(vec![init_level.unwrap_or(y[0]), init_slope.unwrap_or(0.0)]);
673        let p0 = Array2::from_diag(&Array1::from_vec(vec![iv, iv]));
674
675        let obs_vecs: Vec<Array1<f64>> = y.iter().map(|&yi| Array1::from_elem(1, yi)).collect();
676
677        let (filtered, log_lik) = KalmanFilter::filter_series(&obs_vecs, &f, &h, &q, &r, x0, p0)?;
678
679        let smoothed = KalmanFilter::smooth(&filtered, &f, &q)?;
680
681        let fitted_values: Vec<f64> = filtered.iter().map(|s| s.x[0]).collect();
682        let residuals: Vec<f64> = y
683            .iter()
684            .zip(fitted_values.iter())
685            .map(|(&yi, &fi)| (yi - fi) / obs_var.sqrt())
686            .collect();
687
688        Ok(StsFitResult {
689            model: StsModel::LocalLinearTrend {
690                level_variance: level_var,
691                slope_variance: slope_var,
692                obs_variance: obs_var,
693            },
694            smoothed_states: smoothed,
695            filtered_states: filtered,
696            log_likelihood: log_lik,
697            fitted_values,
698            residuals,
699        })
700    }
701}
702
703// ─────────────────────────────────────────────────────────────────────────────
704// Numerical linear-algebra helpers (private)
705// ─────────────────────────────────────────────────────────────────────────────
706
707/// Matrix-vector product y = A * x.
708fn mat_vec_mul(a: &Array2<f64>, x: &Array1<f64>) -> StatsResult<Array1<f64>> {
709    if a.ncols() != x.len() {
710        return Err(StatsError::DimensionMismatch(format!(
711            "mat_vec_mul: A is {}×{} but x has len {}",
712            a.nrows(),
713            a.ncols(),
714            x.len()
715        )));
716    }
717    let n = a.nrows();
718    let m = a.ncols();
719    let mut y = Array1::<f64>::zeros(n);
720    for i in 0..n {
721        let mut s = 0.0_f64;
722        for k in 0..m {
723            s += a[[i, k]] * x[k];
724        }
725        y[i] = s;
726    }
727    Ok(y)
728}
729
730/// Matrix-matrix product C = A * B.
731fn mat_mat_mul(a: &Array2<f64>, b: &Array2<f64>) -> StatsResult<Array2<f64>> {
732    if a.ncols() != b.nrows() {
733        return Err(StatsError::DimensionMismatch(format!(
734            "mat_mat_mul: A is {}×{} but B is {}×{}",
735            a.nrows(),
736            a.ncols(),
737            b.nrows(),
738            b.ncols()
739        )));
740    }
741    let n = a.nrows();
742    let k = a.ncols();
743    let m = b.ncols();
744    let mut c = Array2::<f64>::zeros((n, m));
745    for i in 0..n {
746        for j in 0..m {
747            let mut s = 0.0_f64;
748            for l in 0..k {
749                s += a[[i, l]] * b[[l, j]];
750            }
751            c[[i, j]] = s;
752        }
753    }
754    Ok(c)
755}
756
757/// Compute C = A * B^T.
758fn mat_mat_mul_bt(a: &Array2<f64>, b: &Array2<f64>) -> StatsResult<Array2<f64>> {
759    if a.ncols() != b.ncols() {
760        return Err(StatsError::DimensionMismatch(format!(
761            "mat_mat_mul_bt: A is {}×{} but B^T is {}×{}",
762            a.nrows(),
763            a.ncols(),
764            b.ncols(),
765            b.nrows()
766        )));
767    }
768    let n = a.nrows();
769    let k = a.ncols();
770    let m = b.nrows();
771    let mut c = Array2::<f64>::zeros((n, m));
772    for i in 0..n {
773        for j in 0..m {
774            let mut s = 0.0_f64;
775            for l in 0..k {
776                s += a[[i, l]] * b[[j, l]];
777            }
778            c[[i, j]] = s;
779        }
780    }
781    Ok(c)
782}
783
784/// Compute C = A * B^T (where A is a result matrix and B is the original).
785/// This is a synonym for mat_mat_mul_bt used in RTS smoother notation.
786#[allow(dead_code)]
787fn mat_mat_mul_at(a: &Array2<f64>, b: &Array2<f64>) -> StatsResult<Array2<f64>> {
788    // C = A * B^T
789    mat_mat_mul_bt(a, b)
790}
791
792/// Compute I - A (identity minus matrix), check square.
793fn eye_minus(a: Array2<f64>) -> StatsResult<Array2<f64>> {
794    let n = a.nrows();
795    if a.ncols() != n {
796        return Err(StatsError::DimensionMismatch(
797            "eye_minus: matrix not square".into(),
798        ));
799    }
800    let mut result = -a;
801    for i in 0..n {
802        result[[i, i]] += 1.0;
803    }
804    Ok(result)
805}
806
807/// Assert that a matrix is square with the expected dimension.
808fn check_square(m: &Array2<f64>, expected: usize, name: &str) -> StatsResult<()> {
809    if m.nrows() != expected || m.ncols() != expected {
810        Err(StatsError::DimensionMismatch(format!(
811            "{} must be {}×{}, got {}×{}",
812            name,
813            expected,
814            expected,
815            m.nrows(),
816            m.ncols()
817        )))
818    } else {
819        Ok(())
820    }
821}
822
823/// Invert a symmetric positive-definite matrix using Cholesky decomposition.
824fn inv_symmetric(a: Array2<f64>) -> StatsResult<Array2<f64>> {
825    let n = a.nrows();
826    if a.ncols() != n {
827        return Err(StatsError::DimensionMismatch(
828            "inv_symmetric: matrix must be square".into(),
829        ));
830    }
831
832    if n == 1 {
833        let val = a[[0, 0]];
834        if val.abs() < 1e-15 {
835            return Err(StatsError::ComputationError(
836                "inv_symmetric: 1×1 matrix is singular".into(),
837            ));
838        }
839        let mut inv = Array2::<f64>::zeros((1, 1));
840        inv[[0, 0]] = 1.0 / val;
841        return Ok(inv);
842    }
843
844    // Cholesky L such that A = L L^T
845    let l = cholesky_lower(a)?;
846
847    // Invert L by forward substitution, then L^{-T} = (L^{-1})^T
848    let l_inv = lower_tri_inv(&l)?;
849
850    // A^{-1} = (L L^T)^{-1} = L^{-T} L^{-1} = (L^{-1})^T * L^{-1}
851    let l_inv_t = l_inv.t().to_owned();
852    mat_mat_mul(&l_inv_t, &l_inv)
853}
854
855/// Compute the lower Cholesky factor L of a positive-definite symmetric matrix A.
856fn cholesky_lower(a: Array2<f64>) -> StatsResult<Array2<f64>> {
857    let n = a.nrows();
858    if a.ncols() != n {
859        return Err(StatsError::DimensionMismatch(
860            "cholesky_lower: matrix must be square".into(),
861        ));
862    }
863    let mut l = Array2::<f64>::zeros((n, n));
864
865    for i in 0..n {
866        for j in 0..=i {
867            let mut s = a[[i, j]];
868            for k in 0..j {
869                s -= l[[i, k]] * l[[j, k]];
870            }
871            if i == j {
872                if s <= 0.0 {
873                    // Apply a small regularisation for near-PSD matrices
874                    let eps = 1e-10_f64.max(s.abs() * 1e-8);
875                    let s_reg = s + eps;
876                    if s_reg <= 0.0 {
877                        return Err(StatsError::ComputationError(format!(
878                            "Cholesky failed at ({},{}): diagonal entry {} is non-positive",
879                            i, j, s
880                        )));
881                    }
882                    l[[i, j]] = s_reg.sqrt();
883                } else {
884                    l[[i, j]] = s.sqrt();
885                }
886            } else {
887                let ljj = l[[j, j]];
888                if ljj.abs() < 1e-15 {
889                    return Err(StatsError::ComputationError(
890                        "Cholesky: near-zero diagonal encountered".into(),
891                    ));
892                }
893                l[[i, j]] = s / ljj;
894            }
895        }
896    }
897
898    Ok(l)
899}
900
901/// Invert a lower-triangular matrix L by forward substitution.
902fn lower_tri_inv(l: &Array2<f64>) -> StatsResult<Array2<f64>> {
903    let n = l.nrows();
904    let mut inv = Array2::<f64>::zeros((n, n));
905
906    for j in 0..n {
907        inv[[j, j]] = 1.0 / l[[j, j]];
908        for i in j + 1..n {
909            let mut s = 0.0_f64;
910            for k in j..i {
911                s += l[[i, k]] * inv[[k, j]];
912            }
913            inv[[i, j]] = -s / l[[i, i]];
914        }
915    }
916
917    Ok(inv)
918}
919
920/// Compute the log-determinant of a symmetric positive-definite matrix.
921fn log_det_posdef(a: &Array2<f64>) -> StatsResult<f64> {
922    let l = cholesky_lower(a.clone())?;
923    let n = l.nrows();
924    let log_det: f64 = (0..n).map(|i| 2.0 * l[[i, i]].ln()).sum();
925    Ok(log_det)
926}
927
928// ─────────────────────────────────────────────────────────────────────────────
929// UKF utilities (private)
930// ─────────────────────────────────────────────────────────────────────────────
931
932/// Compute the weighted mean of a set of vectors.
933fn weighted_mean(vecs: &[Array1<f64>], weights: &[f64]) -> StatsResult<Array1<f64>> {
934    if vecs.is_empty() {
935        return Err(StatsError::InsufficientData(
936            "weighted_mean: empty vector set".into(),
937        ));
938    }
939    if vecs.len() != weights.len() {
940        return Err(StatsError::DimensionMismatch(
941            "weighted_mean: vecs and weights must have same length".into(),
942        ));
943    }
944    let d = vecs[0].len();
945    let mut mean = Array1::<f64>::zeros(d);
946    for (v, &w) in vecs.iter().zip(weights.iter()) {
947        mean = mean + v.mapv(|x| x * w);
948    }
949    Ok(mean)
950}
951
952/// Compute the weighted covariance of a set of vectors around a given mean.
953/// If `additive` is `Some(matrix)`, it is added to the covariance.
954fn weighted_covariance(
955    vecs: &[Array1<f64>],
956    mean: &Array1<f64>,
957    weights: &[f64],
958    additive: Option<&Array2<f64>>,
959) -> StatsResult<Array2<f64>> {
960    let d = mean.len();
961    let mut cov = Array2::<f64>::zeros((d, d));
962    for (v, &w) in vecs.iter().zip(weights.iter()) {
963        let diff = v - mean;
964        for i in 0..d {
965            for j in 0..d {
966                cov[[i, j]] += w * diff[i] * diff[j];
967            }
968        }
969    }
970    if let Some(add) = additive {
971        cov = cov + add;
972    }
973    Ok(cov)
974}
975
976/// Compute the weighted cross-covariance between two sets of vectors.
977fn weighted_cross_covariance(
978    xs: &[Array1<f64>],
979    x_mean: &Array1<f64>,
980    ys: &[Array1<f64>],
981    y_mean: &Array1<f64>,
982    weights: &[f64],
983) -> StatsResult<Array2<f64>> {
984    let dx = x_mean.len();
985    let dy = y_mean.len();
986    if xs.len() != ys.len() || xs.len() != weights.len() {
987        return Err(StatsError::DimensionMismatch(
988            "weighted_cross_covariance: dimension mismatch".into(),
989        ));
990    }
991    let mut cov = Array2::<f64>::zeros((dx, dy));
992    for ((x, y), &w) in xs.iter().zip(ys.iter()).zip(weights.iter()) {
993        let xd = x - x_mean;
994        let yd = y - y_mean;
995        for i in 0..dx {
996            for j in 0..dy {
997                cov[[i, j]] += w * xd[i] * yd[j];
998            }
999        }
1000    }
1001    Ok(cov)
1002}
1003
1004// ─────────────────────────────────────────────────────────────────────────────
1005// Unit tests
1006// ─────────────────────────────────────────────────────────────────────────────
1007
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    // ── Kalman filter ──────────────────────────────────────────────────────

    #[test]
    fn test_kalman_predict() {
        // 1-D constant velocity example
        let x = Array1::from_vec(vec![1.0, 0.5]);
        let p = Array2::from_diag(&Array1::from_vec(vec![0.1, 0.1]));
        let state = KalmanState::new(x, p).expect("state");

        let f = array![[1.0, 1.0], [0.0, 1.0]];
        let q = array![[0.01, 0.0], [0.0, 0.01]];

        let pred = KalmanFilter::predict(&state, &f, &q).expect("predict");
        // x_pred = F * [1, 0.5] = [1.5, 0.5]
        assert!((pred.x[0] - 1.5).abs() < 1e-12);
        assert!((pred.x[1] - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_kalman_update_reduces_variance() {
        let x0 = Array1::from_vec(vec![0.0]);
        let p0 = Array2::from_elem((1, 1), 1.0);
        let state = KalmanState::new(x0, p0).expect("state");

        let h = array![[1.0_f64]];
        let r = array![[0.1_f64]];
        let y = Array1::from_vec(vec![1.0_f64]);

        let upd = KalmanFilter::update(&state, &y, &h, &r).expect("update");
        // Posterior variance < prior variance
        assert!(upd.p[[0, 0]] < 1.0);
        // Posterior mean should shift toward observation
        assert!(upd.x[0] > 0.0);
    }

    #[test]
    fn test_filter_series_scalar() {
        // Constant signal + noise
        let obs: Vec<Array1<f64>> = (0..20)
            .map(|i| Array1::from_vec(vec![3.0 + 0.1 * (i as f64 % 3.0 - 1.0)]))
            .collect();

        let f = array![[1.0_f64]];
        let h = array![[1.0_f64]];
        let q = array![[0.01_f64]];
        let r = array![[0.5_f64]];
        let x0 = Array1::from_vec(vec![0.0_f64]);
        let p0 = Array2::from_elem((1, 1), 100.0_f64);

        let (states, log_lik) =
            KalmanFilter::filter_series(&obs, &f, &h, &q, &r, x0, p0).expect("filter_series");
        assert_eq!(states.len(), 20);
        assert!(log_lik.is_finite());
        // After 20 observations the estimate should be close to 3.0
        let final_est = states.last().map(|s| s.x[0]).unwrap_or(0.0);
        assert!((final_est - 3.0).abs() < 0.5, "final_est = {}", final_est);
    }

    #[test]
    fn test_rts_smoother() {
        let obs: Vec<Array1<f64>> = (0..10).map(|_| Array1::from_vec(vec![2.0_f64])).collect();

        let f = array![[1.0_f64]];
        let h = array![[1.0_f64]];
        let q = array![[0.1_f64]];
        let r = array![[1.0_f64]];
        let x0 = Array1::from_vec(vec![0.0_f64]);
        let p0 = Array2::from_elem((1, 1), 10.0_f64);

        let (filtered, _) =
            KalmanFilter::filter_series(&obs, &f, &h, &q, &r, x0, p0).expect("filter_series");
        let smoothed = KalmanFilter::smooth(&filtered, &f, &q).expect("smooth");

        assert_eq!(smoothed.len(), 10);
        // Smoother should give tighter (or equal) uncertainty
        let filtered_var = filtered[5].p[[0, 0]];
        let smoothed_var = smoothed[5].p[[0, 0]];
        assert!(smoothed_var <= filtered_var + 1e-10);
    }

    // ── Structural time series ─────────────────────────────────────────────

    #[test]
    fn test_local_level_fit() {
        let y: Vec<f64> = (0..30)
            .map(|i| 2.0 + 0.1 * (i as f64 % 5.0 - 2.0))
            .collect();
        let result = StructuralTimeSeries::fit_local_level(&y, 0.1, 0.5, None, None)
            .expect("fit_local_level");
        assert_eq!(result.fitted_values.len(), 30);
        assert!(result.log_likelihood.is_finite());
    }

    #[test]
    fn test_local_linear_trend_fit() {
        let y: Vec<f64> = (0..30)
            .map(|i| i as f64 + 0.1 * (i as f64 % 3.0 - 1.0))
            .collect();
        let result =
            StructuralTimeSeries::fit_local_linear_trend(&y, 0.01, 0.001, 0.5, None, None, None)
                .expect("fit_local_linear_trend");
        assert_eq!(result.fitted_values.len(), 30);
        assert!(result.log_likelihood.is_finite());
        // Slope component should be approximately 1.0 by the end
        let final_slope = result.smoothed_states.last().map(|s| s.x[1]).unwrap_or(0.0);
        assert!(final_slope > 0.5, "slope = {}", final_slope);
    }

    // ── UKF ───────────────────────────────────────────────────────────────

    #[test]
    fn test_ukf_linear_matches_kf() {
        // For a linear system, UKF should closely match the Kalman filter.
        let x0 = Array1::from_vec(vec![0.0_f64]);
        let p0 = Array2::from_elem((1, 1), 1.0_f64);
        let f_mat = array![[1.0_f64]];
        let h_mat = array![[1.0_f64]];
        let q_mat = array![[0.1_f64]];
        let r_mat = array![[0.5_f64]];
        let y = Array1::from_vec(vec![1.0_f64]);

        let state = KalmanState::new(x0.clone(), p0.clone()).expect("state");

        // KF result
        let kf_pred = KalmanFilter::predict(&state, &f_mat, &q_mat).expect("kf_pred");
        let kf_upd = KalmanFilter::update(&kf_pred, &y, &h_mat, &r_mat).expect("kf_upd");

        // UKF result (linear functions)
        let ukf = UnscentedKalmanFilter::default();
        let f_fn = |x: &Array1<f64>| x.clone(); // identity transition
        let h_fn = |x: &Array1<f64>| x.clone(); // identity observation

        let ukf_pred = ukf.predict(&state, f_fn, &q_mat).expect("ukf_pred");
        let ukf_upd = ukf.update(&ukf_pred, &y, h_fn, &r_mat).expect("ukf_upd");

        // Means should match closely
        assert!(
            (kf_upd.x[0] - ukf_upd.x[0]).abs() < 1e-6,
            "KF x={}, UKF x={}",
            kf_upd.x[0],
            ukf_upd.x[0]
        );
    }

    // ── Numerical helpers ──────────────────────────────────────────────────

    #[test]
    fn test_cholesky_correctness() {
        // A = [[4, 2], [2, 3]] → L = [[2,0],[1, sqrt(2)]]
        let a = array![[4.0_f64, 2.0], [2.0, 3.0]];
        let l = cholesky_lower(a.clone()).expect("cholesky");
        // Verify L L^T ≈ A
        let lt = l.t().to_owned();
        let a_reconstructed = mat_mat_mul(&l, &lt).expect("mul");
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (a_reconstructed[[i, j]] - a[[i, j]]).abs() < 1e-12,
                    "({},{}) mismatch",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_cholesky_rejects_indefinite() {
        // [[1, 2], [2, 1]] has eigenvalues 3 and -1, so it is not positive-definite.
        let a = array![[1.0_f64, 2.0], [2.0, 1.0]];
        assert!(cholesky_lower(a).is_err());
    }

    #[test]
    fn test_inv_symmetric_2x2() {
        let a = array![[2.0_f64, 1.0], [1.0, 3.0]];
        let a_inv = inv_symmetric(a.clone()).expect("inv");
        // A * A^{-1} should be ~ I
        let prod = mat_mat_mul(&a, &a_inv).expect("mul");
        assert!((prod[[0, 0]] - 1.0).abs() < 1e-10);
        assert!((prod[[1, 1]] - 1.0).abs() < 1e-10);
        assert!(prod[[0, 1]].abs() < 1e-10);
        assert!(prod[[1, 0]].abs() < 1e-10);
    }

    #[test]
    fn test_lower_tri_inv() {
        // L = [[2, 0], [1, 3]] → L^{-1} = [[1/2, 0], [-1/6, 1/3]]
        let l = array![[2.0_f64, 0.0], [1.0, 3.0]];
        let l_inv = lower_tri_inv(&l).expect("lower_tri_inv");
        let prod = mat_mat_mul(&l, &l_inv).expect("mul");
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (prod[[i, j]] - expected).abs() < 1e-12,
                    "({},{}) mismatch",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_log_det_posdef() {
        // det([[4, 2], [2, 3]]) = 12 - 4 = 8
        let a = array![[4.0_f64, 2.0], [2.0, 3.0]];
        let log_det = log_det_posdef(&a).expect("log_det_posdef");
        assert!((log_det - 8.0_f64.ln()).abs() < 1e-12, "log_det = {}", log_det);
    }

    // ── UKF helpers ────────────────────────────────────────────────────────

    #[test]
    fn test_weighted_moments() {
        let vecs = vec![
            Array1::from_vec(vec![1.0_f64, 0.0]),
            Array1::from_vec(vec![-1.0_f64, 2.0]),
        ];
        let w = [0.5_f64, 0.5];

        // Mean = 0.5 * [1, 0] + 0.5 * [-1, 2] = [0, 1]
        let mean = weighted_mean(&vecs, &w).expect("mean");
        assert!(mean[0].abs() < 1e-12);
        assert!((mean[1] - 1.0).abs() < 1e-12);

        // Deviations are [1, -1] and [-1, 1] → covariance [[1, -1], [-1, 1]]
        let expected = array![[1.0_f64, -1.0], [-1.0, 1.0]];
        let cov = weighted_covariance(&vecs, &mean, &w, None).expect("cov");
        let cross =
            weighted_cross_covariance(&vecs, &mean, &vecs, &mean, &w).expect("cross_cov");
        let eye = Array2::<f64>::eye(2);
        let cov_plus_eye =
            weighted_covariance(&vecs, &mean, &w, Some(&eye)).expect("cov + I");
        for i in 0..2 {
            for j in 0..2 {
                assert!((cov[[i, j]] - expected[[i, j]]).abs() < 1e-12);
                assert!((cross[[i, j]] - expected[[i, j]]).abs() < 1e-12);
                assert!(
                    (cov_plus_eye[[i, j]] - expected[[i, j]] - eye[[i, j]]).abs() < 1e-12
                );
            }
        }
    }

    #[test]
    fn test_weighted_mean_rejects_bad_input() {
        let empty: Vec<Array1<f64>> = Vec::new();
        assert!(weighted_mean(&empty, &[]).is_err());
        let one = vec![Array1::from_vec(vec![1.0_f64])];
        assert!(weighted_mean(&one, &[0.5, 0.5]).is_err());
    }
}