1use ndarray::{Array1, Array2, Array3};
30use serde::{Deserialize, Serialize};
31use so_core::error::Result;
32use so_linalg;
33
34#[derive(Debug, Clone)]
36pub struct StateSpaceModel {
37 pub observation_matrix: Array2<f64>,
39 pub transition_matrix: Array2<f64>,
41 pub selection_matrix: Array2<f64>,
43 pub observation_cov: Array2<f64>,
45 pub state_cov: Array2<f64>,
47 pub initial_state_mean: Array1<f64>,
49 pub initial_state_cov: Array2<f64>,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct KalmanFilterResults {
56 pub filtered_state_means: Array2<f64>,
58 pub filtered_state_covs: Array3<f64>,
60 pub predicted_state_means: Array2<f64>,
62 pub predicted_state_covs: Array3<f64>,
64 pub innovations: Array1<f64>,
66 pub innovation_variances: Array1<f64>,
68 pub kalman_gains: Array3<f64>,
70 pub log_likelihood: f64,
72}
73
74impl StateSpaceModel {
75 pub fn local_level(obs_var: f64, level_var: f64) -> Self {
77 let observation_matrix = ndarray::array![[1.0]];
81 let transition_matrix = ndarray::array![[1.0]];
82 let selection_matrix = ndarray::array![[1.0]];
83 let observation_cov = ndarray::array![[obs_var]];
84 let state_cov = ndarray::array![[level_var]];
85 let initial_state_mean = ndarray::array![0.0];
86 let initial_state_cov = ndarray::array![[1e6]]; Self {
89 observation_matrix,
90 transition_matrix,
91 selection_matrix,
92 observation_cov,
93 state_cov,
94 initial_state_mean,
95 initial_state_cov,
96 }
97 }
98
99 pub fn local_linear_trend(obs_var: f64, level_var: f64, slope_var: f64) -> Self {
101 let observation_matrix = ndarray::array![[1.0, 0.0]];
106 let transition_matrix = ndarray::array![[1.0, 1.0], [0.0, 1.0]];
107 let selection_matrix = ndarray::array![[1.0, 0.0], [0.0, 1.0]];
108 let observation_cov = ndarray::array![[obs_var]];
109 let state_cov = ndarray::array![[level_var, 0.0], [0.0, slope_var]];
110 let initial_state_mean = ndarray::array![0.0, 0.0];
111 let initial_state_cov = ndarray::array![[1e6, 0.0], [0.0, 1e6]];
112
113 Self {
114 observation_matrix,
115 transition_matrix,
116 selection_matrix,
117 observation_cov,
118 state_cov,
119 initial_state_mean,
120 initial_state_cov,
121 }
122 }
123
124 pub fn arma(ar_coef: &[f64], ma_coef: &[f64], sigma2: f64) -> Self {
126 let p = ar_coef.len();
127 let q = ma_coef.len();
128 let r = p.max(q + 1);
129 let n_states = r;
130
131 let mut transition = Array2::zeros((n_states, n_states));
133
134 if p > 0 {
135 for j in 0..p {
137 transition[(0, j)] = ar_coef[j];
138 }
139 }
140
141 for i in 1..n_states {
143 transition[(i, i - 1)] = 1.0;
144 }
145
146 let mut observation = Array1::zeros(n_states);
148 observation[0] = 1.0;
149 if q > 0 {
150 for j in 0..q.min(n_states - 1) {
152 observation[j + 1] = ma_coef[j];
153 }
154 }
155 let observation_matrix = observation.insert_axis(ndarray::Axis(0));
156
157 let mut selection = Array2::zeros((n_states, 1));
159 selection[(0, 0)] = 1.0;
160 for j in 1..q.min(n_states - 1) {
161 selection[(j, 0)] = ma_coef[j - 1];
162 }
163
164 let observation_cov = ndarray::array![[0.0]]; let state_cov = ndarray::array![[sigma2]];
167
168 let initial_state_mean = Array1::zeros(n_states);
170 let mut initial_state_cov = Array2::zeros((n_states, n_states));
171 for i in 0..n_states {
172 initial_state_cov[(i, i)] = 1e6;
173 }
174
175 Self {
176 observation_matrix,
177 transition_matrix: transition,
178 selection_matrix: selection,
179 observation_cov,
180 state_cov,
181 initial_state_mean,
182 initial_state_cov,
183 }
184 }
185
186 pub fn filter(&self, y: &Array1<f64>) -> Result<KalmanFilterResults> {
188 let n = y.len();
189 let n_states = self.observation_matrix.ncols();
190 let n_obs = self.observation_matrix.nrows();
191
192 let mut filtered_means = Array2::zeros((n, n_states));
194 let mut filtered_covs = Array3::zeros((n, n_states, n_states));
195 let mut predicted_means = Array2::zeros((n, n_states));
196 let mut predicted_covs = Array3::zeros((n, n_states, n_states));
197 let mut innovations = Array1::zeros(n);
198 let mut innovation_variances = Array1::zeros(n);
199 let mut kalman_gains = Array3::zeros((n, n_states, n_obs));
200
201 let mut log_likelihood = 0.0;
202
203 let mut pred_mean = self.initial_state_mean.clone();
205 let mut pred_cov = self.initial_state_cov.clone();
206
207 for t in 0..n {
208 predicted_means.row_mut(t).assign(&pred_mean);
210 predicted_covs
211 .slice_mut(ndarray::s![t, .., ..])
212 .assign(&pred_cov);
213
214 let obs_pred = self.observation_matrix.dot(&pred_mean);
216 let innovation = y[t] - obs_pred[0];
217 innovations[t] = innovation;
218
219 let innovation_var = self
221 .observation_matrix
222 .dot(&pred_cov.dot(&self.observation_matrix.t()))
223 + &self.observation_cov;
224 let innovation_var_scalar = innovation_var[(0, 0)];
225 innovation_variances[t] = innovation_var_scalar;
226
227 if innovation_var_scalar > 0.0 {
229 log_likelihood += -0.5 * innovation_var_scalar.ln()
230 - 0.5 * innovation.powi(2) / innovation_var_scalar;
231 }
232
233 let kalman_gain = if innovation_var_scalar > 0.0 {
235 pred_cov.dot(&self.observation_matrix.t()) / innovation_var_scalar
236 } else {
237 Array2::zeros((n_states, n_obs))
238 };
239
240 kalman_gains
241 .slice_mut(ndarray::s![t, .., ..])
242 .assign(&kalman_gain);
243
244 let filtered_mean = &pred_mean + kalman_gain.dot(&ndarray::array![innovation]);
246 let filtered_cov = &pred_cov - kalman_gain.dot(&self.observation_matrix.dot(&pred_cov));
247
248 filtered_means.row_mut(t).assign(&filtered_mean);
250 filtered_covs
251 .slice_mut(ndarray::s![t, .., ..])
252 .assign(&filtered_cov);
253
254 if t < n - 1 {
256 pred_mean = self.transition_matrix.dot(&filtered_mean);
257 pred_cov = self
258 .transition_matrix
259 .dot(&filtered_cov.dot(&self.transition_matrix.t()))
260 + self
261 .selection_matrix
262 .dot(&self.state_cov.dot(&self.selection_matrix.t()));
263 }
264 }
265
266 Ok(KalmanFilterResults {
267 filtered_state_means: filtered_means,
268 filtered_state_covs: filtered_covs,
269 predicted_state_means: predicted_means,
270 predicted_state_covs: predicted_covs,
271 innovations,
272 innovation_variances,
273 kalman_gains,
274 log_likelihood,
275 })
276 }
277
278 pub fn smooth(&self, filter_results: &KalmanFilterResults) -> KalmanFilterResults {
280 let n = filter_results.filtered_state_means.nrows();
281 let n_states = self.observation_matrix.ncols();
282
283 let mut smoothed_means = filter_results.filtered_state_means.clone();
285 let mut smoothed_covs = filter_results.filtered_state_covs.clone();
286
287 let mut smoother_gain = Array2::zeros((n_states, n_states));
289
290 for t in (0..n - 1).rev() {
291 let pred_cov = filter_results
293 .predicted_state_covs
294 .slice(ndarray::s![t + 1, .., ..]);
295 let filtered_cov = filter_results
296 .filtered_state_covs
297 .slice(ndarray::s![t, .., ..]);
298
299 let pred_cov_inv =
300 so_linalg::inv(&pred_cov.to_owned()).unwrap_or_else(|_| pred_cov.to_owned());
301 smoother_gain.assign(
302 &filtered_cov
303 .dot(&self.transition_matrix.t())
304 .dot(&pred_cov_inv),
305 );
306
307 let filtered_mean = filter_results.filtered_state_means.row(t);
309 let next_smoothed_mean = smoothed_means.row(t + 1);
310 let next_pred_mean = filter_results.predicted_state_means.row(t + 1);
311
312 let mut smoothed_mean = filtered_mean.to_owned();
313 smoothed_mean += &smoother_gain.dot(&(&next_smoothed_mean - &next_pred_mean));
314
315 let next_smoothed_cov = smoothed_covs.slice(ndarray::s![t + 1, .., ..]);
317 let next_pred_cov =
318 filter_results
319 .predicted_state_covs
320 .slice(ndarray::s![t + 1, .., ..]);
321
322 let mut smoothed_cov = filtered_cov.to_owned();
323 let diff_cov = &next_smoothed_cov - &next_pred_cov;
324 smoothed_cov += &smoother_gain.dot(&diff_cov.dot(&smoother_gain.t()));
325
326 smoothed_means.row_mut(t).assign(&smoothed_mean);
328 smoothed_covs
329 .slice_mut(ndarray::s![t, .., ..])
330 .assign(&smoothed_cov);
331 }
332
333 KalmanFilterResults {
334 filtered_state_means: smoothed_means,
335 filtered_state_covs: smoothed_covs,
336 ..filter_results.clone()
337 }
338 }
339
340 pub fn forecast(
342 &self,
343 filter_results: &KalmanFilterResults,
344 steps: usize,
345 ) -> (Array2<f64>, Array3<f64>) {
346 let n = filter_results.filtered_state_means.nrows();
347 let n_states = self.observation_matrix.ncols();
348
349 let mut forecast_means = Array2::zeros((steps, n_states));
350 let mut forecast_covs = Array3::zeros((steps, n_states, n_states));
351
352 let mut current_mean = filter_results.filtered_state_means.row(n - 1).to_owned();
354 let mut current_cov = filter_results
355 .filtered_state_covs
356 .slice(ndarray::s![n - 1, .., ..])
357 .to_owned();
358
359 for h in 0..steps {
360 current_mean = self.transition_matrix.dot(¤t_mean);
362 current_cov = self
363 .transition_matrix
364 .dot(¤t_cov.dot(&self.transition_matrix.t()))
365 + self
366 .selection_matrix
367 .dot(&self.state_cov.dot(&self.selection_matrix.t()));
368
369 forecast_means.row_mut(h).assign(¤t_mean);
371 forecast_covs
372 .slice_mut(ndarray::s![h, .., ..])
373 .assign(¤t_cov);
374 }
375
376 (forecast_means, forecast_covs)
377 }
378
379 pub fn log_likelihood(&self, y: &Array1<f64>) -> Result<f64> {
381 let results = self.filter(y)?;
382 Ok(results.log_likelihood)
383 }
384
385 pub fn estimate(&mut self, _y: &Array1<f64>) -> Result<()> {
387 Ok(())
390 }
391}
392
/// Stateless convenience front-end over the filtering, smoothing and forecasting
/// methods of `StateSpaceModel`.
///
/// Derives `Default` to pair with the `new()` constructor (clippy
/// `new_without_default`), and `Debug`/`Clone` like the other public types here.
#[derive(Debug, Clone, Default)]
pub struct KalmanFilter;
395
396impl KalmanFilter {
397 pub fn new() -> Self {
399 Self
400 }
401
402 pub fn filter(&self, model: &StateSpaceModel, y: &Array1<f64>) -> Result<KalmanFilterResults> {
404 model.filter(y)
405 }
406
407 pub fn smooth(
409 &self,
410 model: &StateSpaceModel,
411 results: &KalmanFilterResults,
412 ) -> KalmanFilterResults {
413 model.smooth(results)
414 }
415
416 pub fn filter_smooth(
418 &self,
419 model: &StateSpaceModel,
420 y: &Array1<f64>,
421 ) -> Result<KalmanFilterResults> {
422 let filtered = model.filter(y)?;
423 Ok(model.smooth(&filtered))
424 }
425
426 pub fn forecast(
428 &self,
429 model: &StateSpaceModel,
430 results: &KalmanFilterResults,
431 steps: usize,
432 ) -> (Array1<f64>, Array1<f64>) {
433 let (state_means, state_covs) = model.forecast(results, steps);
434
435 let mut forecast_means = Array1::zeros(steps);
436 let mut forecast_variances = Array1::zeros(steps);
437
438 for h in 0..steps {
439 let state_mean = state_means.row(h);
440 let state_cov = state_covs.slice(ndarray::s![h, .., ..]);
441
442 let obs_mean = model.observation_matrix.dot(&state_mean);
444 forecast_means[h] = obs_mean[0];
445
446 let obs_var = model
448 .observation_matrix
449 .dot(&state_cov.dot(&model.observation_matrix.t()))
450 + &model.observation_cov;
451 forecast_variances[h] = obs_var[(0, 0)];
452 }
453
454 (forecast_means, forecast_variances)
455 }
456}