Skip to main content

linreg_core/prediction_intervals/
mod.rs

1//! Prediction Intervals Module
2//!
3//! Provides prediction interval functionality for OLS regression.
4//! Prediction intervals quantify uncertainty around individual future observations,
5//! as opposed to confidence intervals which quantify uncertainty around the mean response.
6//!
7//! # Formula
8//!
9//! For a prediction at point x₀:
10//! ```text
11//! PI = ŷ₀ ± t(α/2, df) × SE_pred
12//!
13//! SE_pred = √(MSE × (1 + h₀))
14//!
15//! where h₀ = x₀ᵀ(XᵀX)⁻¹x₀ is the leverage of the new point
16//! ```
17
18use crate::core::{compute_leverage, ols_regression, RegressionOutput};
19use crate::distributions::student_t_inverse_cdf;
20use crate::error::{Error, Result};
21use crate::linalg::Matrix;
22use crate::regularized::elastic_net::ElasticNetFit;
23use crate::regularized::lasso::LassoFit;
24use crate::regularized::ridge::RidgeFit;
25use serde::{Deserialize, Serialize};
26
27/// Output from prediction interval computation.
28#[derive(Serialize, Deserialize)]
29pub struct PredictionIntervalOutput {
30    /// Point predictions (fitted values) for new observations
31    pub predicted: Vec<f64>,
32    /// Lower bounds of prediction intervals
33    pub lower_bound: Vec<f64>,
34    /// Upper bounds of prediction intervals
35    pub upper_bound: Vec<f64>,
36    /// Standard errors for predictions (includes both parameter uncertainty and residual variance)
37    pub se_pred: Vec<f64>,
38    /// Leverage values for the new observations
39    pub leverage: Vec<f64>,
40    /// Significance level used (e.g., 0.05 for 95% PI)
41    pub alpha: f64,
42    /// Residual degrees of freedom from the fitted model
43    pub df_residuals: f64,
44}
45
46/// Computes prediction intervals for new observations from raw training data.
47///
48/// Fits an OLS model internally and then computes prediction intervals for the
49/// new observations. This follows the same pattern as diagnostic test functions.
50///
51/// # Arguments
52///
53/// * `y` - Response variable from training data
54/// * `x_vars` - Predictor variables from training data (each inner slice is one variable)
55/// * `new_x` - New predictor values to generate predictions for (each inner slice is one variable)
56/// * `alpha` - Significance level (e.g., 0.05 for 95% prediction interval)
57///
58/// # Returns
59///
60/// A [`PredictionIntervalOutput`] with predictions and interval bounds.
61pub fn prediction_intervals(
62    y: &[f64],
63    x_vars: &[Vec<f64>],
64    new_x: &[&[f64]],
65    alpha: f64,
66) -> Result<PredictionIntervalOutput> {
67    // Build names for OLS
68    let mut names = vec!["Intercept".to_string()];
69    for i in 0..x_vars.len() {
70        names.push(format!("X{}", i + 1));
71    }
72
73    let x_refs: Vec<Vec<f64>> = x_vars.to_vec();
74    let fit = ols_regression(y, &x_refs, &names)?;
75
76    compute_from_fit(&fit, x_vars, new_x, alpha)
77}
78
79/// Computes prediction intervals for new observations using a pre-fitted OLS model.
80///
81/// Requires the original training predictors to reconstruct (X'X)^{-1} for
82/// computing leverage of new points.
83///
84/// # Arguments
85///
86/// * `fit_result` - Reference to the fitted OLS model
87/// * `x_vars` - Original training predictor variables (each inner slice is one variable)
88/// * `new_x` - New predictor values to generate predictions for (each inner slice is one variable)
89/// * `alpha` - Significance level (e.g., 0.05 for 95% prediction interval)
90///
91/// # Returns
92///
93/// A [`PredictionIntervalOutput`] with predictions and interval bounds.
94pub fn compute_from_fit(
95    fit_result: &RegressionOutput,
96    x_vars: &[Vec<f64>],
97    new_x: &[&[f64]],
98    alpha: f64,
99) -> Result<PredictionIntervalOutput> {
100    let n = fit_result.n;
101    let k = fit_result.k;
102    let p = k + 1; // number of coefficients (including intercept)
103
104    // Validate alpha
105    if alpha <= 0.0 || alpha >= 1.0 {
106        return Err(Error::InvalidInput(
107            "alpha must be between 0 and 1 (exclusive)".to_string(),
108        ));
109    }
110
111    // Validate new_x dimensions
112    if new_x.len() != k {
113        return Err(Error::InvalidInput(format!(
114            "new_x has {} variables but model has {} predictors",
115            new_x.len(),
116            k
117        )));
118    }
119
120    if new_x.is_empty() {
121        return Err(Error::InvalidInput("new_x is empty".to_string()));
122    }
123
124    // Get number of new observations and validate consistent lengths
125    let n_new = new_x[0].len();
126    if n_new == 0 {
127        return Err(Error::InvalidInput(
128            "new_x variables have zero observations".to_string(),
129        ));
130    }
131    for (i, var) in new_x.iter().enumerate() {
132        if var.len() != n_new {
133            return Err(Error::InvalidInput(format!(
134                "new_x variable {} has {} observations but variable 0 has {}",
135                i,
136                var.len(),
137                n_new
138            )));
139        }
140        for val in var.iter() {
141            if !val.is_finite() {
142                return Err(Error::InvalidInput(
143                    "new_x contains non-finite values".to_string(),
144                ));
145            }
146        }
147    }
148
149    // Validate x_vars match the model
150    if x_vars.len() != k {
151        return Err(Error::InvalidInput(format!(
152            "x_vars has {} variables but model has {} predictors",
153            x_vars.len(),
154            k
155        )));
156    }
157
158    // Build the training design matrix X (n × p) with intercept column
159    let mut x_data = Vec::with_capacity(n * p);
160    for i in 0..n {
161        x_data.push(1.0); // intercept
162        for var in x_vars.iter() {
163            x_data.push(var[i]);
164        }
165    }
166    let x_matrix = Matrix::new(n, p, x_data);
167
168    // Compute (X'X)^{-1}
169    let xtx = x_matrix.transpose().matmul(&x_matrix);
170    let xtx_inv = match xtx.invert() {
171        Some(inv) => inv,
172        None => {
173            return Err(Error::InvalidInput(
174                "X'X is singular; cannot compute prediction intervals".to_string(),
175            ))
176        }
177    };
178
179    // Build the new observation design matrix (n_new × p)
180    let mut new_x_data = Vec::with_capacity(n_new * p);
181    for i in 0..n_new {
182        new_x_data.push(1.0); // intercept
183        for var in new_x.iter() {
184            new_x_data.push(var[i]);
185        }
186    }
187    let new_x_matrix = Matrix::new(n_new, p, new_x_data);
188
189    // Compute leverage for new points: h₀ = x₀ᵀ(XᵀX)⁻¹x₀
190    let new_leverage = compute_leverage(&new_x_matrix, &xtx_inv);
191
192    // Extract model parameters
193    let df_residuals = fit_result.df as f64;
194    let mse = fit_result.mse;
195    let beta = &fit_result.coefficients;
196
197    // Critical t-value
198    let t_critical = student_t_inverse_cdf(1.0 - alpha / 2.0, df_residuals);
199
200    // Compute predictions and intervals
201    let mut predicted = Vec::with_capacity(n_new);
202    let mut lower_bound = Vec::with_capacity(n_new);
203    let mut upper_bound = Vec::with_capacity(n_new);
204    let mut se_pred = Vec::with_capacity(n_new);
205
206    for i in 0..n_new {
207        // Compute predicted value: ŷ = x₀ᵀβ
208        let mut y_hat = 0.0;
209        for j in 0..p {
210            let x_val = new_x_matrix.get(i, j);
211            let b = beta[j];
212            if !b.is_nan() {
213                y_hat += x_val * b;
214            }
215        }
216        predicted.push(y_hat);
217
218        // Prediction standard error: SE_pred = √(MSE × (1 + h₀))
219        let h = new_leverage[i];
220        let se = (mse * (1.0 + h)).sqrt();
221        se_pred.push(se);
222
223        // Prediction interval bounds
224        let margin = t_critical * se;
225        lower_bound.push(y_hat - margin);
226        upper_bound.push(y_hat + margin);
227    }
228
229    Ok(PredictionIntervalOutput {
230        predicted,
231        lower_bound,
232        upper_bound,
233        se_pred,
234        leverage: new_leverage,
235        alpha,
236        df_residuals,
237    })
238}
239
240/// Shared helper for computing prediction intervals from regularized regression fits.
241///
242/// Uses the conservative approximation: leverage from unpenalized X'X, MSE from the
243/// regularized fit, and effective df from the fit.
244fn compute_regularized_pi(
245    intercept: f64,
246    coefficients: &[f64],
247    mse: f64,
248    df_residual: f64,
249    x_vars: &[Vec<f64>],
250    new_x: &[&[f64]],
251    alpha: f64,
252) -> Result<PredictionIntervalOutput> {
253    let k = x_vars.len(); // number of predictors (excluding intercept)
254
255    // Validate alpha
256    if alpha <= 0.0 || alpha >= 1.0 {
257        return Err(Error::InvalidInput(
258            "alpha must be between 0 and 1 (exclusive)".to_string(),
259        ));
260    }
261
262    // Validate dimensions
263    if new_x.len() != k {
264        return Err(Error::InvalidInput(format!(
265            "new_x has {} variables but model has {} predictors",
266            new_x.len(),
267            k
268        )));
269    }
270    if k == 0 || new_x.is_empty() {
271        return Err(Error::InvalidInput("new_x is empty".to_string()));
272    }
273
274    let n_new = new_x[0].len();
275    if n_new == 0 {
276        return Err(Error::InvalidInput(
277            "new_x variables have zero observations".to_string(),
278        ));
279    }
280    for (i, var) in new_x.iter().enumerate() {
281        if var.len() != n_new {
282            return Err(Error::InvalidInput(format!(
283                "new_x variable {} has {} observations but variable 0 has {}",
284                i,
285                var.len(),
286                n_new
287            )));
288        }
289        for val in var.iter() {
290            if !val.is_finite() {
291                return Err(Error::InvalidInput(
292                    "new_x contains non-finite values".to_string(),
293                ));
294            }
295        }
296    }
297
298    if coefficients.len() != k {
299        return Err(Error::InvalidInput(format!(
300            "coefficients has {} values but model has {} predictors",
301            coefficients.len(),
302            k
303        )));
304    }
305
306    // Validate df_residual
307    if df_residual <= 0.0 {
308        return Err(Error::InvalidInput(
309            "Effective degrees of freedom must be positive".to_string(),
310        ));
311    }
312
313    let n = x_vars[0].len();
314    let p = k + 1; // intercept + predictors
315
316    // Build training design matrix (n × p) with intercept column
317    let mut x_data = Vec::with_capacity(n * p);
318    for i in 0..n {
319        x_data.push(1.0);
320        for var in x_vars.iter() {
321            x_data.push(var[i]);
322        }
323    }
324    let x_matrix = Matrix::new(n, p, x_data);
325
326    // Compute (X'X)^{-1}
327    let xtx = x_matrix.transpose().matmul(&x_matrix);
328    let xtx_inv = match xtx.invert() {
329        Some(inv) => inv,
330        None => {
331            return Err(Error::InvalidInput(
332                "X'X is singular; cannot compute prediction intervals".to_string(),
333            ))
334        }
335    };
336
337    // Build new observation design matrix (n_new × p)
338    let mut new_x_data = Vec::with_capacity(n_new * p);
339    for i in 0..n_new {
340        new_x_data.push(1.0);
341        for var in new_x.iter() {
342            new_x_data.push(var[i]);
343        }
344    }
345    let new_x_matrix = Matrix::new(n_new, p, new_x_data);
346
347    // Compute leverage for new points
348    let new_leverage = compute_leverage(&new_x_matrix, &xtx_inv);
349
350    // Critical t-value
351    let t_critical = student_t_inverse_cdf(1.0 - alpha / 2.0, df_residual);
352
353    // Compute predictions and intervals
354    let mut predicted = Vec::with_capacity(n_new);
355    let mut lower_bound = Vec::with_capacity(n_new);
356    let mut upper_bound = Vec::with_capacity(n_new);
357    let mut se_pred = Vec::with_capacity(n_new);
358
359    for i in 0..n_new {
360        // ŷ = intercept + Σ(coef_j × x_j)
361        let mut y_hat = intercept;
362        for (j, coef) in coefficients.iter().enumerate() {
363            y_hat += coef * new_x[j][i];
364        }
365        predicted.push(y_hat);
366
367        let h = new_leverage[i];
368        let se = (mse * (1.0 + h)).sqrt();
369        se_pred.push(se);
370
371        let margin = t_critical * se;
372        lower_bound.push(y_hat - margin);
373        upper_bound.push(y_hat + margin);
374    }
375
376    Ok(PredictionIntervalOutput {
377        predicted,
378        lower_bound,
379        upper_bound,
380        se_pred,
381        leverage: new_leverage,
382        alpha,
383        df_residuals: df_residual,
384    })
385}
386
387/// Computes approximate prediction intervals for Ridge regression.
388///
389/// Uses the conservative approximation with leverage from unpenalized X'X,
390/// MSE from the ridge fit, and effective degrees of freedom from `fit.df`.
391///
392/// # Arguments
393///
394/// * `fit` - Reference to the fitted Ridge model
395/// * `x_vars` - Original training predictor variables (each inner vec is one variable)
396/// * `new_x` - New predictor values (each inner slice is one variable)
397/// * `alpha` - Significance level (e.g., 0.05 for 95% prediction interval)
398pub fn ridge_prediction_intervals(
399    fit: &RidgeFit,
400    x_vars: &[Vec<f64>],
401    new_x: &[&[f64]],
402    alpha: f64,
403) -> Result<PredictionIntervalOutput> {
404    let n = x_vars.get(0).map_or(0, |v| v.len()) as f64;
405    // df_residual = n - 1 - effective_df (where fit.df is approximate effective df)
406    let df_residual = n - 1.0 - fit.df;
407    compute_regularized_pi(fit.intercept, &fit.coefficients, fit.mse, df_residual, x_vars, new_x, alpha)
408}
409
410/// Computes approximate prediction intervals for Lasso regression.
411///
412/// Uses the conservative approximation with leverage from unpenalized X'X,
413/// MSE from the lasso fit, and `n_nonzero` as the effective degrees of freedom.
414///
415/// # Arguments
416///
417/// * `fit` - Reference to the fitted Lasso model
418/// * `x_vars` - Original training predictor variables (each inner vec is one variable)
419/// * `new_x` - New predictor values (each inner slice is one variable)
420/// * `alpha` - Significance level (e.g., 0.05 for 95% prediction interval)
421pub fn lasso_prediction_intervals(
422    fit: &LassoFit,
423    x_vars: &[Vec<f64>],
424    new_x: &[&[f64]],
425    alpha: f64,
426) -> Result<PredictionIntervalOutput> {
427    let n = x_vars.get(0).map_or(0, |v| v.len()) as f64;
428    let df_residual = n - 1.0 - fit.n_nonzero as f64;
429    compute_regularized_pi(fit.intercept, &fit.coefficients, fit.mse, df_residual, x_vars, new_x, alpha)
430}
431
432/// Computes approximate prediction intervals for Elastic Net regression.
433///
434/// Uses the conservative approximation with leverage from unpenalized X'X,
435/// MSE from the elastic net fit, and `n_nonzero` as the effective degrees of freedom.
436///
437/// # Arguments
438///
439/// * `fit` - Reference to the fitted Elastic Net model
440/// * `x_vars` - Original training predictor variables (each inner vec is one variable)
441/// * `new_x` - New predictor values (each inner slice is one variable)
442/// * `alpha` - Significance level (e.g., 0.05 for 95% prediction interval)
443pub fn elastic_net_prediction_intervals(
444    fit: &ElasticNetFit,
445    x_vars: &[Vec<f64>],
446    new_x: &[&[f64]],
447    alpha: f64,
448) -> Result<PredictionIntervalOutput> {
449    let n = x_vars.get(0).map_or(0, |v| v.len()) as f64;
450    let df_residual = n - 1.0 - fit.n_nonzero as f64;
451    compute_regularized_pi(fit.intercept, &fit.coefficients, fit.mse, df_residual, x_vars, new_x, alpha)
452}
453
// Unit tests: OLS prediction intervals (direct and from a pre-fitted model)
// plus the approximate intervals for Ridge/Lasso/Elastic Net fits.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prediction_intervals_simple() {
        // y = 2x + noise
        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        let new_x1 = [6.0];
        let result = compute_from_fit(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        // PI bounds should bracket the prediction
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
        assert!(result.se_pred[0] > 0.0);
        assert!((result.alpha - 0.05).abs() < 1e-10);
    }

    #[test]
    fn test_prediction_intervals_multiple_observations() {
        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        // Predict at multiple new points
        let new_x1 = [6.0, 7.0, 3.0];
        let result = compute_from_fit(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 3);
        assert_eq!(result.lower_bound.len(), 3);
        assert_eq!(result.upper_bound.len(), 3);

        // Every interval must bracket its own point prediction.
        for i in 0..3 {
            assert!(result.lower_bound[i] < result.predicted[i]);
            assert!(result.upper_bound[i] > result.predicted[i]);
        }
    }

    #[test]
    fn test_prediction_intervals_multiple_predictors() {
        let y = vec![3.0, 5.5, 7.0, 9.5, 11.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x2 = vec![2.0, 4.0, 5.0, 6.0, 8.0];

        let names = vec![
            "Intercept".to_string(),
            "X1".to_string(),
            "X2".to_string(),
        ];
        let fit = ols_regression(&y, &[x1.clone(), x2.clone()], &names).unwrap();

        let new_x1 = [6.0];
        let new_x2 = [9.0];
        let result =
            compute_from_fit(&fit, &[x1, x2], &[&new_x1, &new_x2], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
    }

    #[test]
    fn test_wider_pi_for_lower_alpha() {
        // Lower alpha (higher confidence) should give wider intervals
        let y = vec![1.2, 2.1, 2.8, 4.1, 4.9];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        let new_x1 = [3.0];

        let result_95 =
            compute_from_fit(&fit, &[x1.clone()], &[&new_x1], 0.05).unwrap();
        let result_99 =
            compute_from_fit(&fit, &[x1], &[&new_x1], 0.01).unwrap();

        let width_95 = result_95.upper_bound[0] - result_95.lower_bound[0];
        let width_99 = result_99.upper_bound[0] - result_99.lower_bound[0];

        // 99% PI should be wider than 95% PI
        assert!(width_99 > width_95);
    }

    #[test]
    fn test_extrapolation_has_higher_leverage() {
        // Points far from the training data center should have higher leverage
        let y = vec![1.2, 2.1, 2.8, 4.1, 4.9];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        // x=3 is at the center, x=10 is extrapolation
        let new_center = [3.0];
        let new_extrap = [10.0];

        let result_center =
            compute_from_fit(&fit, &[x1.clone()], &[&new_center], 0.05).unwrap();
        let result_extrap =
            compute_from_fit(&fit, &[x1], &[&new_extrap], 0.05).unwrap();

        // Extrapolation point should have higher leverage and wider PI
        assert!(result_extrap.leverage[0] > result_center.leverage[0]);
        assert!(result_extrap.se_pred[0] > result_center.se_pred[0]);

        let width_center = result_center.upper_bound[0] - result_center.lower_bound[0];
        let width_extrap = result_extrap.upper_bound[0] - result_extrap.lower_bound[0];
        assert!(width_extrap > width_center);
    }

    #[test]
    fn test_prediction_intervals_convenience_function() {
        // The one-shot entry point (fit + intervals) should agree in shape
        // with the compute_from_fit path exercised above.
        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1];
        let x_vars = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];

        let new_x1 = [6.0];
        let result = prediction_intervals(&y, &x_vars, &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        // Wrong number of predictor variables in new_x
        let new_x1 = [6.0];
        let new_x2 = [7.0];
        let result = compute_from_fit(&fit, &[x1], &[&new_x1, &new_x2], 0.05);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_alpha() {
        // alpha must lie strictly inside (0, 1); boundary and negative
        // values are rejected.
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        let new_x1 = [6.0];
        assert!(compute_from_fit(&fit, &[x1.clone()], &[&new_x1], 0.0).is_err());
        assert!(compute_from_fit(&fit, &[x1.clone()], &[&new_x1], 1.0).is_err());
        assert!(compute_from_fit(&fit, &[x1], &[&new_x1], -0.1).is_err());
    }

    #[test]
    fn test_se_pred_includes_residual_variance() {
        // SE_pred should always be >= sqrt(MSE) since SE_pred = sqrt(MSE * (1 + h))
        // and h >= 0
        let y = vec![1.2, 2.1, 2.8, 4.1, 4.9];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        let new_x1 = [3.0];
        let result = compute_from_fit(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        let sqrt_mse = fit.mse.sqrt();
        assert!(result.se_pred[0] >= sqrt_mse);
    }

    // =========================================================================
    // Regularized prediction interval tests
    // =========================================================================

    #[test]
    fn test_ridge_prediction_intervals_simple() {
        use crate::regularized::ridge::{ridge_fit, RidgeFitOptions};

        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1, 12.9, 15.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

        // Build design matrix with intercept
        let mut x_data = Vec::new();
        for i in 0..y.len() {
            x_data.push(1.0);
            x_data.push(x1[i]);
        }
        let x = Matrix::new(y.len(), 2, x_data);

        let options = RidgeFitOptions {
            lambda: 0.1,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let fit = ridge_fit(&x, &y, &options).unwrap();

        let new_x1 = [8.0];
        let result = ridge_prediction_intervals(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
        assert!(result.se_pred[0] > 0.0);
        // Prediction should be roughly 2*8 + 1 = 17
        assert!((result.predicted[0] - 17.0).abs() < 2.0);
    }

    #[test]
    fn test_lasso_prediction_intervals_basic() {
        use crate::regularized::lasso::{lasso_fit, LassoFitOptions};

        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1, 12.9, 15.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

        let mut x_data = Vec::new();
        for i in 0..y.len() {
            x_data.push(1.0);
            x_data.push(x1[i]);
        }
        let x = Matrix::new(y.len(), 2, x_data);

        let options = LassoFitOptions {
            lambda: 0.01,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let fit = lasso_fit(&x, &y, &options).unwrap();

        let new_x1 = [8.0];
        let result = lasso_prediction_intervals(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
        assert!(result.se_pred[0] > 0.0);
    }

    #[test]
    fn test_elastic_net_prediction_intervals_basic() {
        use crate::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};

        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1, 12.9, 15.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

        let mut x_data = Vec::new();
        for i in 0..y.len() {
            x_data.push(1.0);
            x_data.push(x1[i]);
        }
        let x = Matrix::new(y.len(), 2, x_data);

        let options = ElasticNetOptions {
            lambda: 0.01,
            alpha: 0.5,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let fit = elastic_net_fit(&x, &y, &options).unwrap();

        let new_x1 = [8.0];
        let result = elastic_net_prediction_intervals(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
    }

    #[test]
    fn test_regularized_pi_extrapolation_wider() {
        use crate::regularized::ridge::{ridge_fit, RidgeFitOptions};

        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1, 12.9, 15.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

        let mut x_data = Vec::new();
        for i in 0..y.len() {
            x_data.push(1.0);
            x_data.push(x1[i]);
        }
        let x = Matrix::new(y.len(), 2, x_data);

        let options = RidgeFitOptions {
            lambda: 0.1,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let fit = ridge_fit(&x, &y, &options).unwrap();

        // Center vs far extrapolation
        let center = [4.0];
        let extrap = [20.0];

        let result_center = ridge_prediction_intervals(&fit, &[x1.clone()], &[&center], 0.05).unwrap();
        let result_extrap = ridge_prediction_intervals(&fit, &[x1], &[&extrap], 0.05).unwrap();

        let width_center = result_center.upper_bound[0] - result_center.lower_bound[0];
        let width_extrap = result_extrap.upper_bound[0] - result_extrap.lower_bound[0];

        assert!(width_extrap > width_center, "Extrapolation PI should be wider");
    }

    #[test]
    fn test_regularized_pi_alpha_comparison() {
        use crate::regularized::ridge::{ridge_fit, RidgeFitOptions};

        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1, 12.9, 15.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

        let mut x_data = Vec::new();
        for i in 0..y.len() {
            x_data.push(1.0);
            x_data.push(x1[i]);
        }
        let x = Matrix::new(y.len(), 2, x_data);

        let options = RidgeFitOptions {
            lambda: 0.1,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let fit = ridge_fit(&x, &y, &options).unwrap();

        let new_x1 = [8.0];
        let result_95 = ridge_prediction_intervals(&fit, &[x1.clone()], &[&new_x1], 0.05).unwrap();
        let result_99 = ridge_prediction_intervals(&fit, &[x1], &[&new_x1], 0.01).unwrap();

        let width_95 = result_95.upper_bound[0] - result_95.lower_bound[0];
        let width_99 = result_99.upper_bound[0] - result_99.lower_bound[0];

        assert!(width_99 > width_95, "99% PI should be wider than 95% PI");
    }
}