use crate::error::{Error, Result};
use faer::prelude::*;
use faer::{Mat, MatRef};
/// Newtype around a dense `f64` matrix that `reconcile` uses as its summing
/// matrix `S`.
///
/// The row count is exposed as `m()` and the column count as `n()`. No
/// structural validation is performed on the wrapped matrix.
#[derive(Debug, Clone)]
pub struct SummingMatrix {
/// The wrapped matrix; private, read through `as_ref()`, `m()` and `n()`.
inner: Mat<f64>,
}
impl SummingMatrix {
    /// Wraps `inner` as a summing matrix. The matrix is taken as-is: its
    /// contents and shape are not checked.
    pub fn new(inner: Mat<f64>) -> Self {
        Self { inner }
    }

    /// Number of rows of the wrapped matrix.
    pub fn m(&self) -> usize {
        self.inner.nrows()
    }

    /// Number of columns of the wrapped matrix.
    pub fn n(&self) -> usize {
        self.inner.ncols()
    }

    /// Borrowed, read-only view of the wrapped matrix.
    pub fn as_ref(&self) -> MatRef<'_, f64> {
        self.inner.as_ref()
    }

    /// Builds the `(n_leaves + 1) x n_leaves` matrix whose first row is all
    /// ones and whose remaining rows form the identity.
    ///
    /// With `n_leaves == 0` the result is a `1 x 0` matrix.
    pub fn simple_star(n_leaves: usize) -> Self {
        let mut star = Mat::<f64>::zeros(n_leaves + 1, n_leaves);
        for leaf in 0..n_leaves {
            // Row 0 sums every leaf column.
            star[(0, leaf)] = 1.0;
            // Row `leaf + 1` selects exactly this leaf column.
            star[(leaf + 1, leaf)] = 1.0;
        }
        Self::new(star)
    }
}
/// Selects how `reconcile` weights the rows of the summing matrix `S` and the
/// base forecasts `y` when solving for the coefficients `b` in `S * b`.
#[derive(Debug, Clone)]
pub enum ReconciliationMethod {
/// Unweighted: `reconcile` solves `(SᵀS) b = Sᵀ y`.
Ols,
/// Diagonal weighting: `reconcile` divides row `i` of `S` and of `y` by
/// `weights[i]` before forming the normal equations.
Wls {
/// One entry per row of `S`; `reconcile` rejects a length mismatch and
/// any entry that is not strictly positive.
weights: Vec<f64>,
},
/// Full-matrix weighting: `reconcile` applies the inverse of `covariance`
/// (through an LU solve) to `S` and `y` before forming the normal equations.
MinT {
/// Must be square with as many rows as `S`; `reconcile` rejects any
/// other shape. Symmetry / definiteness are not checked there.
covariance: Mat<f64>,
},
}
/// Reconciles `base_forecasts` against the summing matrix `s`.
///
/// Writing `S` for the `m x n` matrix wrapped by `s` and `y` for
/// `base_forecasts` (`m` rows, one column per forecast series/horizon), the
/// function solves
///
/// ```text
/// (Sᵀ W⁻¹ S) b = Sᵀ W⁻¹ y
/// ```
///
/// for `b` and returns `S * b`. `W` depends on `method`:
///
/// * `Ols`: `W` is the identity.
/// * `Wls { weights }`: `W` is diagonal with `weights` on the diagonal, so row
///   `i` of `S` and `y` is divided by `weights[i]`.
/// * `MinT { covariance }`: `W` is `covariance`; `W⁻¹ S` and `W⁻¹ y` are
///   obtained from a full-pivoting LU factorization of it.
///
/// # Errors
///
/// * `Error::ShapeMismatch` if `base_forecasts` does not have `m` rows, if
///   `weights` does not have `m` entries, or if `covariance` is not `m x m`.
/// * `Error::InvalidParameter` if any weight is NaN, zero or negative.
///
/// Singular systems are not detected here: the LU solves are called
/// unconditionally. NOTE(review): confirm how faer's solve behaves on a
/// rank-deficient `Sᵀ W⁻¹ S` or a singular `covariance`.
pub fn reconcile(
    s: &SummingMatrix,
    base_forecasts: &Mat<f64>,
    method: ReconciliationMethod,
) -> Result<Mat<f64>> {
    let s_mat = s.as_ref();
    let m = s.m();
    let n = s.n();

    if base_forecasts.nrows() != m {
        return Err(Error::ShapeMismatch {
            expected: format!("{} rows", m),
            actual: format!("{} rows", base_forecasts.nrows()),
        });
    }

    // Every method left-multiplies by Sᵀ, so take the transposed view once.
    let st = s_mat.transpose();

    let b = match method {
        ReconciliationMethod::Ols => {
            let sts = st * s_mat;
            let sty = st * base_forecasts;
            sts.full_piv_lu().solve(&sty)
        }
        ReconciliationMethod::Wls { weights } => {
            if weights.len() != m {
                return Err(Error::ShapeMismatch {
                    expected: format!("{} weights", m),
                    actual: format!("{} weights", weights.len()),
                });
            }
            // `w <= 0.0` alone is false for NaN, which would let a NaN weight
            // through and poison the whole result; reject it explicitly.
            if weights.iter().any(|&w| w.is_nan() || w <= 0.0) {
                return Err(Error::InvalidParameter {
                    name: "weights",
                    message: "all weights must be positive",
                });
            }

            let n_series = base_forecasts.ncols();
            let mut winv_s = Mat::<f64>::zeros(m, n);
            let mut winv_y = Mat::<f64>::zeros(m, n_series);
            // W is diagonal, so W⁻¹ S and W⁻¹ y are plain row scalings.
            for (i, &w) in weights.iter().enumerate() {
                let w_i_inv = 1.0 / w;
                for j in 0..n {
                    winv_s[(i, j)] = w_i_inv * s_mat[(i, j)];
                }
                for k in 0..n_series {
                    winv_y[(i, k)] = w_i_inv * base_forecasts[(i, k)];
                }
            }

            let st_winv_s = st * &winv_s;
            let st_winv_y = st * &winv_y;
            st_winv_s.full_piv_lu().solve(&st_winv_y)
        }
        ReconciliationMethod::MinT { covariance } => {
            if covariance.nrows() != m || covariance.ncols() != m {
                return Err(Error::ShapeMismatch {
                    expected: format!("{}x{} covariance", m, m),
                    actual: format!("{}x{} covariance", covariance.nrows(), covariance.ncols()),
                });
            }

            // Factorize once and reuse it for both right-hand sides instead of
            // forming the explicit inverse.
            let lu = covariance.full_piv_lu();
            let sigmainv_s = lu.solve(s_mat);
            let sigmainv_y = lu.solve(base_forecasts);

            let lhs = st * &sigmainv_s;
            let rhs = st * &sigmainv_y;
            lhs.full_piv_lu().solve(&rhs)
        }
    };

    Ok(s_mat * b)
}