use crate::core::{IntervalType, PredictionResult};
use faer::{Col, Mat};
use statrs::distribution::{ContinuousCDF, StudentsT};
/// Computes confidence or prediction intervals for new observations of a linear model.
///
/// Each row of `x_new` is turned into an observation vector `x0` (prefixed with `1.0` when
/// `has_intercept` is true) and its leverage `h = x0ᵀ · xtx_inv · x0` is computed. The
/// standard error is `sqrt(mse · h)` for [`IntervalType::Confidence`] and
/// `sqrt(mse · (1 + h))` for [`IntervalType::Prediction`]. The bounds are
/// `prediction - t · se` and `prediction + t · se`, where `t` is the two-sided Student-t
/// critical value with `df` degrees of freedom at `confidence_level`.
///
/// # Arguments
///
/// * `x_new` - new observations, one per row, without an intercept column.
/// * `xtx_inv` - inverse cross-product matrix, read as a `p × p` matrix with
///   `p = x_new.ncols() + 1` when `has_intercept` is true and `p = x_new.ncols()` otherwise.
/// * `predictions` - point predictions; entry `i` belongs to row `i` of `x_new`.
/// * `mse` - error-variance estimate that scales the leverage term.
/// * `df` - degrees of freedom of the Student-t distribution used for the critical value.
/// * `confidence_level` - interval coverage, e.g. `0.95`.
/// * `interval_type` - whether to build confidence or prediction intervals.
/// * `has_intercept` - whether `xtx_inv` includes a leading intercept row/column.
///
/// # Returns
///
/// A [`PredictionResult`] holding a copy of `predictions` plus the lower bounds, upper
/// bounds and standard errors. When `df` is NaN or not positive, `mse` is negative, or
/// `confidence_level` is NaN or outside `[0, 1]`, all bounds and standard errors are NaN.
///
/// # Panics
///
/// May panic on out-of-bounds indexing if `predictions` has fewer than `x_new.nrows()`
/// entries or `xtx_inv` is smaller than `p × p`.
// Public signature kept stable for callers, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
pub fn compute_prediction_intervals(
    x_new: &Mat<f64>,
    xtx_inv: &Mat<f64>,
    predictions: &Col<f64>,
    mse: f64,
    df: f64,
    confidence_level: f64,
    interval_type: IntervalType,
    has_intercept: bool,
) -> PredictionResult {
    let n_new = x_new.nrows();
    // `NaN <= 0.0` is false, so NaN must be rejected explicitly: otherwise it reaches the
    // Student-t constructor in `compute_t_critical`, whose `expect` would fire.
    let invalid_df = df.is_nan() || df <= 0.0;
    // A confidence level outside [0, 1] (or NaN) gives an upper-tail probability outside
    // [0, 1], where the t quantile is undefined.
    let invalid_level = !(0.0..=1.0).contains(&confidence_level);
    if invalid_df || mse < 0.0 || invalid_level {
        return create_nan_result(predictions, n_new);
    }
    let t_crit = compute_t_critical(df, confidence_level);
    let (se, lower, upper) = compute_all_intervals(
        x_new,
        xtx_inv,
        predictions,
        mse,
        t_crit,
        interval_type,
        has_intercept,
    );
    PredictionResult::with_intervals(predictions.clone(), lower, upper, se)
}
/// Builds a [`PredictionResult`] whose bounds and standard errors are all NaN.
///
/// `predictions` is cloned unchanged; each of the three NaN columns has length `n`.
fn create_nan_result(predictions: &Col<f64>, n: usize) -> PredictionResult {
    let all_nan = || Col::from_fn(n, |_| f64::NAN);
    let (lower, upper, se) = (all_nan(), all_nan(), all_nan());
    PredictionResult::with_intervals(predictions.clone(), lower, upper, se)
}
/// Returns the two-sided Student-t critical value: the `1 - (1 - confidence_level) / 2`
/// quantile of a t distribution with location 0, scale 1 and `df` degrees of freedom.
///
/// # Panics
///
/// Panics if `StudentsT::new` rejects `df`; the only caller filters out `df <= 0.0` first.
/// `inverse_cdf` may also reject a `confidence_level` outside `[0, 1]`.
fn compute_t_critical(df: f64, confidence_level: f64) -> f64 {
    let tail_mass = 1.0 - confidence_level;
    let upper_quantile = 1.0 - tail_mass / 2.0;
    StudentsT::new(0.0, 1.0, df)
        .expect("valid t-distribution parameters")
        .inverse_cdf(upper_quantile)
}
/// Computes the standard error and interval bounds for every row of `x_new`.
///
/// Returns `(se, lower, upper)`, each of length `x_new.nrows()`. Entry `i` is produced by
/// `compute_single_interval` for row `i` of `x_new` and `predictions[i]`.
fn compute_all_intervals(
    x_new: &Mat<f64>,
    xtx_inv: &Mat<f64>,
    predictions: &Col<f64>,
    mse: f64,
    t_crit: f64,
    interval_type: IntervalType,
    has_intercept: bool,
) -> (Col<f64>, Col<f64>, Col<f64>) {
    let per_row: Vec<(f64, f64, f64)> = (0..x_new.nrows())
        .map(|row| {
            compute_single_interval(
                x_new,
                xtx_inv,
                predictions[row],
                mse,
                t_crit,
                interval_type,
                has_intercept,
                row,
            )
        })
        .collect();
    let count = per_row.len();
    let se = Col::from_fn(count, |i| per_row[i].0);
    let lower = Col::from_fn(count, |i| per_row[i].1);
    let upper = Col::from_fn(count, |i| per_row[i].2);
    (se, lower, upper)
}
/// Computes `(se, lower, upper)` for row `row` of `x_new`.
///
/// The row's leverage is converted to a variance by `compute_interval_variance`. A
/// variance that is negative or NaN yields a NaN standard error. The bounds are
/// `prediction - t_crit · se` and `prediction + t_crit · se`.
// Private helper mirroring the public entry point's parameters, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
fn compute_single_interval(
    x_new: &Mat<f64>,
    xtx_inv: &Mat<f64>,
    prediction: f64,
    mse: f64,
    t_crit: f64,
    interval_type: IntervalType,
    has_intercept: bool,
    row: usize,
) -> (f64, f64, f64) {
    let observation = build_observation_vector(x_new, row, has_intercept);
    let leverage = compute_leverage_single(&observation, xtx_inv);
    let variance = compute_interval_variance(mse, leverage, interval_type);
    // A negative variance (e.g. from round-off in the leverage) has no real square root and
    // is flagged as NaN. A NaN variance falls through to `sqrt`, which also returns NaN.
    let std_err = if variance < 0.0 {
        f64::NAN
    } else {
        variance.sqrt()
    };
    let half_width = t_crit * std_err;
    (std_err, prediction - half_width, prediction + half_width)
}
/// Copies row `row` of `x_new` into a column vector, prepending a `1.0` intercept entry
/// when `has_intercept` is true.
fn build_observation_vector(x_new: &Mat<f64>, row: usize, has_intercept: bool) -> Col<f64> {
    // Number of leading slots reserved for the intercept (0 or 1).
    let offset = usize::from(has_intercept);
    Col::from_fn(x_new.ncols() + offset, |k| {
        if k < offset {
            1.0
        } else {
            x_new[(row, k - offset)]
        }
    })
}
/// Returns the interval variance: `mse · leverage` for [`IntervalType::Confidence`] and
/// `mse · (1 + leverage)` for [`IntervalType::Prediction`].
fn compute_interval_variance(mse: f64, leverage: f64, interval_type: IntervalType) -> f64 {
    // Prediction intervals carry one extra unit of `mse` on top of the leverage term.
    let scale = match interval_type {
        IntervalType::Confidence => leverage,
        IntervalType::Prediction => 1.0 + leverage,
    };
    mse * scale
}
fn compute_leverage_single(x0: &Col<f64>, xtx_inv: &Mat<f64>) -> f64 {
let p = x0.nrows();
let mut xtx_inv_x0 = Col::zeros(p);
for i in 0..p {
let mut sum = 0.0;
for j in 0..p {
sum += xtx_inv[(i, j)] * x0[j];
}
xtx_inv_x0[i] = sum;
}
let mut h = 0.0;
for i in 0..p {
h += x0[i] * xtx_inv_x0[i];
}
h
}
/// Returns `(X_augᵀ X_aug)⁻¹`, where `X_aug` is `x` with a leading column of ones.
///
/// # Errors
///
/// Propagates the error of `compute_matrix_inverse` when the cross-product is singular.
pub fn compute_xtx_inverse_augmented(x: &Mat<f64>) -> Result<Mat<f64>, &'static str> {
    let x_aug = Mat::from_fn(x.nrows(), x.ncols() + 1, |i, c| {
        if c == 0 {
            1.0
        } else {
            x[(i, c - 1)]
        }
    });
    let cross = x_aug.transpose() * &x_aug;
    compute_matrix_inverse(&cross)
}
/// Returns `(X_augᵀ W X_aug)⁻¹`, where `X_aug` is `x` with a leading column of ones and
/// `W = diag(weights)`.
///
/// The weighted cross-product is accumulated sample by sample, without materializing
/// `X_aug` or `W`.
///
/// # Errors
///
/// * `"Weights length must match number of rows"` if `weights.nrows() != x.nrows()`.
/// * The error of `compute_matrix_inverse` when the weighted cross-product is singular.
pub fn compute_xtwx_inverse_augmented(
    x: &Mat<f64>,
    weights: &Col<f64>,
) -> Result<Mat<f64>, &'static str> {
    let n_samples = x.nrows();
    let n_features = x.ncols();
    // Without this check a short `weights` would be indexed out of bounds and a long one
    // would have its trailing entries silently ignored.
    if weights.nrows() != n_samples {
        return Err("Weights length must match number of rows");
    }
    let aug_size = n_features + 1;
    let mut xtwx_aug: Mat<f64> = Mat::zeros(aug_size, aug_size);
    for i in 0..n_samples {
        let w = weights[i];
        // Intercept row/column: total weight and weighted column sums.
        xtwx_aug[(0, 0)] += w;
        for j in 0..n_features {
            let wx = w * x[(i, j)];
            xtwx_aug[(0, j + 1)] += wx;
            xtwx_aug[(j + 1, 0)] += wx;
        }
        // Feature block: weighted products of every feature pair.
        for j in 0..n_features {
            for k in 0..n_features {
                xtwx_aug[(j + 1, k + 1)] += w * x[(i, j)] * x[(i, k)];
            }
        }
    }
    compute_matrix_inverse(&xtwx_aug)
}
/// Returns `(XᵀX)⁻¹` for `x` as given; no intercept column is added.
///
/// # Errors
///
/// Propagates the error of `compute_matrix_inverse` when `XᵀX` is singular.
pub fn compute_xtx_inverse(x: &Mat<f64>) -> Result<Mat<f64>, &'static str> {
    compute_matrix_inverse(&(x.transpose() * x))
}
/// Returns `(X_augᵀ X_aug)⁻¹`, where `X_aug` is a leading column of ones followed by the
/// columns of `x` whose `aliased` flag is `false`, in their original order.
///
/// # Errors
///
/// * `"Aliased mask length must match number of columns"` if `aliased.len() != x.ncols()`.
/// * The error of `compute_matrix_inverse` when the reduced cross-product is singular.
pub fn compute_xtx_inverse_augmented_reduced(
    x: &Mat<f64>,
    aliased: &[bool],
) -> Result<Mat<f64>, &'static str> {
    let n_samples = x.nrows();
    let n_features = x.ncols();
    // A short mask would be indexed out of bounds; extra entries in a long one would be
    // silently ignored.
    if aliased.len() != n_features {
        return Err("Aliased mask length must match number of columns");
    }
    let non_aliased_cols: Vec<usize> = (0..n_features).filter(|&j| !aliased[j]).collect();
    let aug_size = non_aliased_cols.len() + 1;
    let mut x_aug = Mat::zeros(n_samples, aug_size);
    for i in 0..n_samples {
        x_aug[(i, 0)] = 1.0;
        for (k, &j) in non_aliased_cols.iter().enumerate() {
            x_aug[(i, k + 1)] = x[(i, j)];
        }
    }
    let xtx_aug = x_aug.transpose() * &x_aug;
    compute_matrix_inverse(&xtx_aug)
}
/// Returns `(X_rᵀ X_r)⁻¹`, where `X_r` keeps only the columns of `x` whose `aliased` flag
/// is `false`, in their original order. No intercept column is added.
///
/// # Errors
///
/// * `"Aliased mask length must match number of columns"` if `aliased.len() != x.ncols()`.
/// * `"All columns are aliased"` if no column survives the mask.
/// * The error of `compute_matrix_inverse` when the reduced cross-product is singular.
pub fn compute_xtx_inverse_reduced(
    x: &Mat<f64>,
    aliased: &[bool],
) -> Result<Mat<f64>, &'static str> {
    let n_samples = x.nrows();
    let n_features = x.ncols();
    // A short mask would be indexed out of bounds; extra entries in a long one would be
    // silently ignored.
    if aliased.len() != n_features {
        return Err("Aliased mask length must match number of columns");
    }
    let non_aliased_cols: Vec<usize> = (0..n_features).filter(|&j| !aliased[j]).collect();
    let n_reduced = non_aliased_cols.len();
    // Without an intercept column there would be nothing left to invert.
    if n_reduced == 0 {
        return Err("All columns are aliased");
    }
    let mut x_reduced = Mat::zeros(n_samples, n_reduced);
    for i in 0..n_samples {
        for (k, &j) in non_aliased_cols.iter().enumerate() {
            x_reduced[(i, k)] = x[(i, j)];
        }
    }
    let xtx = x_reduced.transpose() * &x_reduced;
    compute_matrix_inverse(&xtx)
}
/// Returns `(X_augᵀ W X_aug)⁻¹`, where `X_aug` is a leading column of ones followed by the
/// columns of `x` whose `aliased` flag is `false`, and `W = diag(weights)`.
///
/// The weighted cross-product is accumulated sample by sample, without materializing
/// `X_aug` or `W`.
///
/// # Errors
///
/// * `"Aliased mask length must match number of columns"` if `aliased.len() != x.ncols()`.
/// * `"Weights length must match number of rows"` if `weights.nrows() != x.nrows()`.
/// * The error of `compute_matrix_inverse` when the weighted cross-product is singular.
pub fn compute_xtwx_inverse_augmented_reduced(
    x: &Mat<f64>,
    weights: &Col<f64>,
    aliased: &[bool],
) -> Result<Mat<f64>, &'static str> {
    let n_samples = x.nrows();
    let n_features = x.ncols();
    // Mismatched lengths would either index out of bounds or silently ignore extra entries.
    if aliased.len() != n_features {
        return Err("Aliased mask length must match number of columns");
    }
    if weights.nrows() != n_samples {
        return Err("Weights length must match number of rows");
    }
    let non_aliased_cols: Vec<usize> = (0..n_features).filter(|&j| !aliased[j]).collect();
    let aug_size = non_aliased_cols.len() + 1;
    let mut xtwx_aug: Mat<f64> = Mat::zeros(aug_size, aug_size);
    for i in 0..n_samples {
        let w = weights[i];
        // Intercept row/column: total weight and weighted sums of the kept columns.
        xtwx_aug[(0, 0)] += w;
        for (k, &j) in non_aliased_cols.iter().enumerate() {
            let wx = w * x[(i, j)];
            xtwx_aug[(0, k + 1)] += wx;
            xtwx_aug[(k + 1, 0)] += wx;
        }
        // Feature block: weighted products of every pair of kept columns.
        for (k, &j) in non_aliased_cols.iter().enumerate() {
            for (m, &l) in non_aliased_cols.iter().enumerate() {
                xtwx_aug[(k + 1, m + 1)] += w * x[(i, j)] * x[(i, l)];
            }
        }
    }
    compute_matrix_inverse(&xtwx_aug)
}
/// Returns `(X_rᵀ W X_r)⁻¹`, where `X_r` keeps only the columns of `x` whose `aliased` flag
/// is `false` and `W = diag(weights)`. No intercept column is added.
///
/// # Errors
///
/// * `"Aliased mask length must match number of columns"` if `aliased.len() != x.ncols()`.
/// * `"All columns are aliased"` if no column survives the mask.
/// * `"Weights length must match number of rows"` if `weights.nrows() != x.nrows()`.
/// * The error of `compute_matrix_inverse` when the weighted cross-product is singular.
pub fn compute_xtwx_inverse_reduced(
    x: &Mat<f64>,
    weights: &Col<f64>,
    aliased: &[bool],
) -> Result<Mat<f64>, &'static str> {
    let n_samples = x.nrows();
    let n_features = x.ncols();
    // A short mask would be indexed out of bounds; extra entries in a long one would be
    // silently ignored.
    if aliased.len() != n_features {
        return Err("Aliased mask length must match number of columns");
    }
    let non_aliased_cols: Vec<usize> = (0..n_features).filter(|&j| !aliased[j]).collect();
    let n_reduced = non_aliased_cols.len();
    // Without an intercept column there would be nothing left to invert.
    if n_reduced == 0 {
        return Err("All columns are aliased");
    }
    // Same reasoning as for the mask: each weight must pair with exactly one row.
    if weights.nrows() != n_samples {
        return Err("Weights length must match number of rows");
    }
    let mut xtwx: Mat<f64> = Mat::zeros(n_reduced, n_reduced);
    for i in 0..n_samples {
        let w = weights[i];
        for (k, &j) in non_aliased_cols.iter().enumerate() {
            for (m, &l) in non_aliased_cols.iter().enumerate() {
                xtwx[(k, m)] += w * x[(i, j)] * x[(i, l)];
            }
        }
    }
    compute_matrix_inverse(&xtwx)
}
/// Inverts a square matrix using a QR decomposition.
///
/// With `A = QR`, the inverse is `R⁻¹ Qᵀ`. Each column of the result is obtained by
/// back-substituting the matching column of `Qᵀ` through the upper-triangular `R`.
///
/// # Errors
///
/// * `"Matrix must be square"` if `matrix.nrows() != matrix.ncols()`.
/// * `"Matrix is singular"` if any diagonal entry of `R` has magnitude below `1e-10`.
///   NOTE(review): this tolerance is absolute, not relative to the scale of `matrix`.
///   Very small-scale inputs may be rejected, and large-scale near-singular ones accepted.
pub(crate) fn compute_matrix_inverse(matrix: &Mat<f64>) -> Result<Mat<f64>, &'static str> {
    let n = matrix.nrows();
    // The loops below treat `R` and `Qᵀ` as `n × n`. For a non-square input they would
    // either index out of bounds or silently invert only part of the factorization.
    if matrix.ncols() != n {
        return Err("Matrix must be square");
    }
    let qr: faer::linalg::solvers::Qr<f64> = matrix.qr();
    let q = qr.compute_Q();
    let r = qr.R();
    for i in 0..n {
        if r[(i, i)].abs() < 1e-10 {
            return Err("Matrix is singular");
        }
    }
    let mut inv = Mat::zeros(n, n);
    let qt = q.transpose();
    // Back substitution R · inv[:, col] = Qᵀ[:, col], solving from the last row upward.
    for col in 0..n {
        for i in (0..n).rev() {
            let mut sum = qt[(i, col)];
            for j in (i + 1)..n {
                sum -= r[(i, j)] * inv[(j, col)];
            }
            inv[(i, col)] = sum / r[(i, i)];
        }
    }
    Ok(inv)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_leverage_single() {
        // With an identity inverse the leverage is the squared norm: 1² + 2² = 5.
        let x0 = Col::from_fn(2, |i| (i + 1) as f64);
        let xtx_inv = Mat::identity(2, 2);
        let h = compute_leverage_single(&x0, &xtx_inv);
        assert!((h - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_prediction_interval_wider_than_confidence() {
        let x_new = Mat::from_fn(3, 1, |_, _| 1.0);
        let xtx_inv = Mat::identity(2, 2);
        let predictions = Col::from_fn(3, |i| i as f64);
        let mse = 1.0;
        let df = 10.0;
        let ci = compute_prediction_intervals(
            &x_new,
            &xtx_inv,
            &predictions,
            mse,
            df,
            0.95,
            IntervalType::Confidence,
            true,
        );
        let pi = compute_prediction_intervals(
            &x_new,
            &xtx_inv,
            &predictions,
            mse,
            df,
            0.95,
            IntervalType::Prediction,
            true,
        );
        for i in 0..3 {
            let ci_width = ci.upper[i] - ci.lower[i];
            let pi_width = pi.upper[i] - pi.lower[i];
            assert!(pi_width > ci_width, "PI should be wider than CI");
        }
    }

    #[test]
    fn test_invalid_df_or_mse_yields_nan_intervals() {
        let x_new = Mat::from_fn(2, 1, |i, _| i as f64);
        let xtx_inv = Mat::identity(2, 2);
        let predictions = Col::from_fn(2, |i| i as f64);
        // Non-positive degrees of freedom and a negative MSE are both rejected.
        for &(mse, df) in &[(1.0, 0.0), (1.0, -3.0), (-1.0, 10.0)] {
            let result = compute_prediction_intervals(
                &x_new,
                &xtx_inv,
                &predictions,
                mse,
                df,
                0.95,
                IntervalType::Prediction,
                true,
            );
            for i in 0..2 {
                assert!(result.lower[i].is_nan());
                assert!(result.upper[i].is_nan());
            }
        }
    }

    #[test]
    fn test_compute_matrix_inverse_1x1() {
        let x = Mat::from_fn(10, 1, |i, _| (i + 1) as f64);
        let xtx = x.transpose() * &x;
        assert_eq!(xtx.nrows(), 1);
        assert_eq!(xtx.ncols(), 1);
        let expected_sum = (1..=10).map(|i| (i * i) as f64).sum::<f64>();
        assert!((xtx[(0, 0)] - expected_sum).abs() < 1e-10);
        let inv = compute_matrix_inverse(&xtx).expect("Should not fail");
        assert_eq!(inv.nrows(), 1);
        assert_eq!(inv.ncols(), 1);
        let expected_inv = 1.0 / expected_sum;
        assert!(
            (inv[(0, 0)] - expected_inv).abs() < 1e-10,
            "Expected inv = {}, got {}",
            expected_inv,
            inv[(0, 0)]
        );
    }

    #[test]
    fn test_compute_matrix_inverse_2x2() {
        // [[4, 7], [2, 6]] has determinant 10 and inverse [[0.6, -0.7], [-0.2, 0.4]].
        let entries = [[4.0, 7.0], [2.0, 6.0]];
        let a = Mat::from_fn(2, 2, |i, j| entries[i][j]);
        let inv = compute_matrix_inverse(&a).expect("matrix is invertible");
        let expected = [[0.6, -0.7], [-0.2, 0.4]];
        for i in 0..2 {
            for j in 0..2 {
                assert!((inv[(i, j)] - expected[i][j]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_compute_matrix_inverse_singular() {
        // The second column is exactly twice the first.
        let entries = [[1.0, 2.0], [2.0, 4.0]];
        let a = Mat::from_fn(2, 2, |i, j| entries[i][j]);
        assert!(compute_matrix_inverse(&a).is_err());
    }

    #[test]
    fn test_unit_weights_match_unweighted_inverse() {
        let x = Mat::from_fn(5, 1, |i, _| (i + 1) as f64);
        let weights = Col::from_fn(5, |_| 1.0);
        let unweighted = compute_xtx_inverse_augmented(&x).expect("full rank");
        let weighted = compute_xtwx_inverse_augmented(&x, &weights).expect("full rank");
        for i in 0..2 {
            for j in 0..2 {
                assert!((unweighted[(i, j)] - weighted[(i, j)]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_reduced_inverse_variants() {
        // Columns [1, 2, 3, 4] and [1, 4, 9, 16] are linearly independent.
        let x = Mat::from_fn(4, 2, |i, j| ((i + 1) as f64).powi(j as i32 + 1));
        let full = compute_xtx_inverse(&x).expect("full rank");
        let reduced = compute_xtx_inverse_reduced(&x, &[false, false]).expect("full rank");
        for i in 0..2 {
            for j in 0..2 {
                assert!((full[(i, j)] - reduced[(i, j)]).abs() < 1e-10);
            }
        }
        assert!(compute_xtx_inverse_reduced(&x, &[true, true]).is_err());
        let weights = Col::from_fn(4, |_| 1.0);
        assert!(compute_xtwx_inverse_reduced(&x, &weights, &[true, true]).is_err());
    }
}