1use crate::core::{compute_leverage, ols_regression, RegressionOutput};
19use crate::distributions::student_t_inverse_cdf;
20use crate::error::{Error, Result};
21use crate::linalg::Matrix;
22use crate::regularized::elastic_net::ElasticNetFit;
23use crate::regularized::lasso::LassoFit;
24use crate::regularized::ridge::RidgeFit;
25use serde::{Deserialize, Serialize};
26
27#[derive(Serialize, Deserialize)]
29pub struct PredictionIntervalOutput {
30 pub predicted: Vec<f64>,
32 pub lower_bound: Vec<f64>,
34 pub upper_bound: Vec<f64>,
36 pub se_pred: Vec<f64>,
38 pub leverage: Vec<f64>,
40 pub alpha: f64,
42 pub df_residuals: f64,
44}
45
46pub fn prediction_intervals(
62 y: &[f64],
63 x_vars: &[Vec<f64>],
64 new_x: &[&[f64]],
65 alpha: f64,
66) -> Result<PredictionIntervalOutput> {
67 let mut names = vec!["Intercept".to_string()];
69 for i in 0..x_vars.len() {
70 names.push(format!("X{}", i + 1));
71 }
72
73 let x_refs: Vec<Vec<f64>> = x_vars.to_vec();
74 let fit = ols_regression(y, &x_refs, &names)?;
75
76 compute_from_fit(&fit, x_vars, new_x, alpha)
77}
78
79pub fn compute_from_fit(
95 fit_result: &RegressionOutput,
96 x_vars: &[Vec<f64>],
97 new_x: &[&[f64]],
98 alpha: f64,
99) -> Result<PredictionIntervalOutput> {
100 let n = fit_result.n;
101 let k = fit_result.k;
102 let p = k + 1; if alpha <= 0.0 || alpha >= 1.0 {
106 return Err(Error::InvalidInput(
107 "alpha must be between 0 and 1 (exclusive)".to_string(),
108 ));
109 }
110
111 if new_x.len() != k {
113 return Err(Error::InvalidInput(format!(
114 "new_x has {} variables but model has {} predictors",
115 new_x.len(),
116 k
117 )));
118 }
119
120 if new_x.is_empty() {
121 return Err(Error::InvalidInput("new_x is empty".to_string()));
122 }
123
124 let n_new = new_x[0].len();
126 if n_new == 0 {
127 return Err(Error::InvalidInput(
128 "new_x variables have zero observations".to_string(),
129 ));
130 }
131 for (i, var) in new_x.iter().enumerate() {
132 if var.len() != n_new {
133 return Err(Error::InvalidInput(format!(
134 "new_x variable {} has {} observations but variable 0 has {}",
135 i,
136 var.len(),
137 n_new
138 )));
139 }
140 for val in var.iter() {
141 if !val.is_finite() {
142 return Err(Error::InvalidInput(
143 "new_x contains non-finite values".to_string(),
144 ));
145 }
146 }
147 }
148
149 if x_vars.len() != k {
151 return Err(Error::InvalidInput(format!(
152 "x_vars has {} variables but model has {} predictors",
153 x_vars.len(),
154 k
155 )));
156 }
157
158 let mut x_data = Vec::with_capacity(n * p);
160 for i in 0..n {
161 x_data.push(1.0); for var in x_vars.iter() {
163 x_data.push(var[i]);
164 }
165 }
166 let x_matrix = Matrix::new(n, p, x_data);
167
168 let xtx = x_matrix.transpose().matmul(&x_matrix);
170 let xtx_inv = match xtx.invert() {
171 Some(inv) => inv,
172 None => {
173 return Err(Error::InvalidInput(
174 "X'X is singular; cannot compute prediction intervals".to_string(),
175 ))
176 }
177 };
178
179 let mut new_x_data = Vec::with_capacity(n_new * p);
181 for i in 0..n_new {
182 new_x_data.push(1.0); for var in new_x.iter() {
184 new_x_data.push(var[i]);
185 }
186 }
187 let new_x_matrix = Matrix::new(n_new, p, new_x_data);
188
189 let new_leverage = compute_leverage(&new_x_matrix, &xtx_inv);
191
192 let df_residuals = fit_result.df as f64;
194 let mse = fit_result.mse;
195 let beta = &fit_result.coefficients;
196
197 let t_critical = student_t_inverse_cdf(1.0 - alpha / 2.0, df_residuals);
199
200 let mut predicted = Vec::with_capacity(n_new);
202 let mut lower_bound = Vec::with_capacity(n_new);
203 let mut upper_bound = Vec::with_capacity(n_new);
204 let mut se_pred = Vec::with_capacity(n_new);
205
206 for i in 0..n_new {
207 let mut y_hat = 0.0;
209 for j in 0..p {
210 let x_val = new_x_matrix.get(i, j);
211 let b = beta[j];
212 if !b.is_nan() {
213 y_hat += x_val * b;
214 }
215 }
216 predicted.push(y_hat);
217
218 let h = new_leverage[i];
220 let se = (mse * (1.0 + h)).sqrt();
221 se_pred.push(se);
222
223 let margin = t_critical * se;
225 lower_bound.push(y_hat - margin);
226 upper_bound.push(y_hat + margin);
227 }
228
229 Ok(PredictionIntervalOutput {
230 predicted,
231 lower_bound,
232 upper_bound,
233 se_pred,
234 leverage: new_leverage,
235 alpha,
236 df_residuals,
237 })
238}
239
240fn compute_regularized_pi(
245 intercept: f64,
246 coefficients: &[f64],
247 mse: f64,
248 df_residual: f64,
249 x_vars: &[Vec<f64>],
250 new_x: &[&[f64]],
251 alpha: f64,
252) -> Result<PredictionIntervalOutput> {
253 let k = x_vars.len(); if alpha <= 0.0 || alpha >= 1.0 {
257 return Err(Error::InvalidInput(
258 "alpha must be between 0 and 1 (exclusive)".to_string(),
259 ));
260 }
261
262 if new_x.len() != k {
264 return Err(Error::InvalidInput(format!(
265 "new_x has {} variables but model has {} predictors",
266 new_x.len(),
267 k
268 )));
269 }
270 if k == 0 || new_x.is_empty() {
271 return Err(Error::InvalidInput("new_x is empty".to_string()));
272 }
273
274 let n_new = new_x[0].len();
275 if n_new == 0 {
276 return Err(Error::InvalidInput(
277 "new_x variables have zero observations".to_string(),
278 ));
279 }
280 for (i, var) in new_x.iter().enumerate() {
281 if var.len() != n_new {
282 return Err(Error::InvalidInput(format!(
283 "new_x variable {} has {} observations but variable 0 has {}",
284 i,
285 var.len(),
286 n_new
287 )));
288 }
289 for val in var.iter() {
290 if !val.is_finite() {
291 return Err(Error::InvalidInput(
292 "new_x contains non-finite values".to_string(),
293 ));
294 }
295 }
296 }
297
298 if coefficients.len() != k {
299 return Err(Error::InvalidInput(format!(
300 "coefficients has {} values but model has {} predictors",
301 coefficients.len(),
302 k
303 )));
304 }
305
306 if df_residual <= 0.0 {
308 return Err(Error::InvalidInput(
309 "Effective degrees of freedom must be positive".to_string(),
310 ));
311 }
312
313 let n = x_vars[0].len();
314 let p = k + 1; let mut x_data = Vec::with_capacity(n * p);
318 for i in 0..n {
319 x_data.push(1.0);
320 for var in x_vars.iter() {
321 x_data.push(var[i]);
322 }
323 }
324 let x_matrix = Matrix::new(n, p, x_data);
325
326 let xtx = x_matrix.transpose().matmul(&x_matrix);
328 let xtx_inv = match xtx.invert() {
329 Some(inv) => inv,
330 None => {
331 return Err(Error::InvalidInput(
332 "X'X is singular; cannot compute prediction intervals".to_string(),
333 ))
334 }
335 };
336
337 let mut new_x_data = Vec::with_capacity(n_new * p);
339 for i in 0..n_new {
340 new_x_data.push(1.0);
341 for var in new_x.iter() {
342 new_x_data.push(var[i]);
343 }
344 }
345 let new_x_matrix = Matrix::new(n_new, p, new_x_data);
346
347 let new_leverage = compute_leverage(&new_x_matrix, &xtx_inv);
349
350 let t_critical = student_t_inverse_cdf(1.0 - alpha / 2.0, df_residual);
352
353 let mut predicted = Vec::with_capacity(n_new);
355 let mut lower_bound = Vec::with_capacity(n_new);
356 let mut upper_bound = Vec::with_capacity(n_new);
357 let mut se_pred = Vec::with_capacity(n_new);
358
359 for i in 0..n_new {
360 let mut y_hat = intercept;
362 for (j, coef) in coefficients.iter().enumerate() {
363 y_hat += coef * new_x[j][i];
364 }
365 predicted.push(y_hat);
366
367 let h = new_leverage[i];
368 let se = (mse * (1.0 + h)).sqrt();
369 se_pred.push(se);
370
371 let margin = t_critical * se;
372 lower_bound.push(y_hat - margin);
373 upper_bound.push(y_hat + margin);
374 }
375
376 Ok(PredictionIntervalOutput {
377 predicted,
378 lower_bound,
379 upper_bound,
380 se_pred,
381 leverage: new_leverage,
382 alpha,
383 df_residuals: df_residual,
384 })
385}
386
387pub fn ridge_prediction_intervals(
399 fit: &RidgeFit,
400 x_vars: &[Vec<f64>],
401 new_x: &[&[f64]],
402 alpha: f64,
403) -> Result<PredictionIntervalOutput> {
404 let n = x_vars.get(0).map_or(0, |v| v.len()) as f64;
405 let df_residual = n - 1.0 - fit.df;
407 compute_regularized_pi(fit.intercept, &fit.coefficients, fit.mse, df_residual, x_vars, new_x, alpha)
408}
409
410pub fn lasso_prediction_intervals(
422 fit: &LassoFit,
423 x_vars: &[Vec<f64>],
424 new_x: &[&[f64]],
425 alpha: f64,
426) -> Result<PredictionIntervalOutput> {
427 let n = x_vars.get(0).map_or(0, |v| v.len()) as f64;
428 let df_residual = n - 1.0 - fit.n_nonzero as f64;
429 compute_regularized_pi(fit.intercept, &fit.coefficients, fit.mse, df_residual, x_vars, new_x, alpha)
430}
431
432pub fn elastic_net_prediction_intervals(
444 fit: &ElasticNetFit,
445 x_vars: &[Vec<f64>],
446 new_x: &[&[f64]],
447 alpha: f64,
448) -> Result<PredictionIntervalOutput> {
449 let n = x_vars.get(0).map_or(0, |v| v.len()) as f64;
450 let df_residual = n - 1.0 - fit.n_nonzero as f64;
451 compute_regularized_pi(fit.intercept, &fit.coefficients, fit.mse, df_residual, x_vars, new_x, alpha)
452}
453
#[cfg(test)]
mod tests {
    use super::*;

    // Basic sanity: one new point, interval brackets the prediction.
    #[test]
    fn test_prediction_intervals_simple() {
        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        let new_x1 = [6.0];
        let result = compute_from_fit(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
        assert!(result.se_pred[0] > 0.0);
        assert!((result.alpha - 0.05).abs() < 1e-10);
    }

    // Multiple new observations: output vectors are parallel and each
    // interval brackets its prediction.
    #[test]
    fn test_prediction_intervals_multiple_observations() {
        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        let new_x1 = [6.0, 7.0, 3.0];
        let result = compute_from_fit(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 3);
        assert_eq!(result.lower_bound.len(), 3);
        assert_eq!(result.upper_bound.len(), 3);

        for i in 0..3 {
            assert!(result.lower_bound[i] < result.predicted[i]);
            assert!(result.upper_bound[i] > result.predicted[i]);
        }
    }

    // Two-predictor model works end to end.
    #[test]
    fn test_prediction_intervals_multiple_predictors() {
        let y = vec![3.0, 5.5, 7.0, 9.5, 11.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x2 = vec![2.0, 4.0, 5.0, 6.0, 8.0];

        let names = vec![
            "Intercept".to_string(),
            "X1".to_string(),
            "X2".to_string(),
        ];
        let fit = ols_regression(&y, &[x1.clone(), x2.clone()], &names).unwrap();

        let new_x1 = [6.0];
        let new_x2 = [9.0];
        let result =
            compute_from_fit(&fit, &[x1, x2], &[&new_x1, &new_x2], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
    }

    // A 99% interval (alpha = 0.01) must be wider than a 95% one.
    #[test]
    fn test_wider_pi_for_lower_alpha() {
        let y = vec![1.2, 2.1, 2.8, 4.1, 4.9];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        let new_x1 = [3.0];

        let result_95 =
            compute_from_fit(&fit, &[x1.clone()], &[&new_x1], 0.05).unwrap();
        let result_99 =
            compute_from_fit(&fit, &[x1], &[&new_x1], 0.01).unwrap();

        let width_95 = result_95.upper_bound[0] - result_95.lower_bound[0];
        let width_99 = result_99.upper_bound[0] - result_99.lower_bound[0];

        assert!(width_99 > width_95);
    }

    // A point far outside the training range (x = 10 vs 1..5) has higher
    // leverage, a larger SE, and a wider interval than the center point.
    #[test]
    fn test_extrapolation_has_higher_leverage() {
        let y = vec![1.2, 2.1, 2.8, 4.1, 4.9];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        let new_center = [3.0];
        let new_extrap = [10.0];

        let result_center =
            compute_from_fit(&fit, &[x1.clone()], &[&new_center], 0.05).unwrap();
        let result_extrap =
            compute_from_fit(&fit, &[x1], &[&new_extrap], 0.05).unwrap();

        assert!(result_extrap.leverage[0] > result_center.leverage[0]);
        assert!(result_extrap.se_pred[0] > result_center.se_pred[0]);

        let width_center = result_center.upper_bound[0] - result_center.lower_bound[0];
        let width_extrap = result_extrap.upper_bound[0] - result_extrap.lower_bound[0];
        assert!(width_extrap > width_center);
    }

    // The fit-and-predict convenience wrapper produces a sane interval.
    #[test]
    fn test_prediction_intervals_convenience_function() {
        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1];
        let x_vars = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];

        let new_x1 = [6.0];
        let result = prediction_intervals(&y, &x_vars, &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
    }

    // Passing two new-x variables to a one-predictor model is rejected.
    #[test]
    fn test_dimension_mismatch_error() {
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        let new_x1 = [6.0];
        let new_x2 = [7.0];
        let result = compute_from_fit(&fit, &[x1], &[&new_x1, &new_x2], 0.05);
        assert!(result.is_err());
    }

    // alpha must lie strictly inside (0, 1): 0.0, 1.0 and negatives fail.
    #[test]
    fn test_invalid_alpha() {
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        let new_x1 = [6.0];
        assert!(compute_from_fit(&fit, &[x1.clone()], &[&new_x1], 0.0).is_err());
        assert!(compute_from_fit(&fit, &[x1.clone()], &[&new_x1], 1.0).is_err());
        assert!(compute_from_fit(&fit, &[x1], &[&new_x1], -0.1).is_err());
    }

    // se_pred = sqrt(mse * (1 + h)) with h >= 0, so it can never be below
    // sqrt(mse).
    #[test]
    fn test_se_pred_includes_residual_variance() {
        let y = vec![1.2, 2.1, 2.8, 4.1, 4.9];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let names = vec!["Intercept".to_string(), "X1".to_string()];
        let fit = ols_regression(&y, &[x1.clone()], &names).unwrap();

        let new_x1 = [3.0];
        let result = compute_from_fit(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        let sqrt_mse = fit.mse.sqrt();
        assert!(result.se_pred[0] >= sqrt_mse);
    }

    // Ridge path: interval brackets the prediction, and the prediction at
    // x = 8 is near the underlying line y ≈ 2x + 1 (i.e. ~17).
    #[test]
    fn test_ridge_prediction_intervals_simple() {
        use crate::regularized::ridge::{ridge_fit, RidgeFitOptions};

        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1, 12.9, 15.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

        let mut x_data = Vec::new();
        for i in 0..y.len() {
            x_data.push(1.0);
            x_data.push(x1[i]);
        }
        let x = Matrix::new(y.len(), 2, x_data);

        let options = RidgeFitOptions {
            lambda: 0.1,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let fit = ridge_fit(&x, &y, &options).unwrap();

        let new_x1 = [8.0];
        let result = ridge_prediction_intervals(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
        assert!(result.se_pred[0] > 0.0);
        assert!((result.predicted[0] - 17.0).abs() < 2.0);
    }

    // Lasso path: interval brackets the prediction.
    #[test]
    fn test_lasso_prediction_intervals_basic() {
        use crate::regularized::lasso::{lasso_fit, LassoFitOptions};

        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1, 12.9, 15.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

        let mut x_data = Vec::new();
        for i in 0..y.len() {
            x_data.push(1.0);
            x_data.push(x1[i]);
        }
        let x = Matrix::new(y.len(), 2, x_data);

        let options = LassoFitOptions {
            lambda: 0.01,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let fit = lasso_fit(&x, &y, &options).unwrap();

        let new_x1 = [8.0];
        let result = lasso_prediction_intervals(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
        assert!(result.se_pred[0] > 0.0);
    }

    // Elastic-net path: interval brackets the prediction.
    #[test]
    fn test_elastic_net_prediction_intervals_basic() {
        use crate::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};

        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1, 12.9, 15.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

        let mut x_data = Vec::new();
        for i in 0..y.len() {
            x_data.push(1.0);
            x_data.push(x1[i]);
        }
        let x = Matrix::new(y.len(), 2, x_data);

        let options = ElasticNetOptions {
            lambda: 0.01,
            alpha: 0.5,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let fit = elastic_net_fit(&x, &y, &options).unwrap();

        let new_x1 = [8.0];
        let result = elastic_net_prediction_intervals(&fit, &[x1], &[&new_x1], 0.05).unwrap();

        assert_eq!(result.predicted.len(), 1);
        assert!(result.lower_bound[0] < result.predicted[0]);
        assert!(result.upper_bound[0] > result.predicted[0]);
    }

    // Regularized leverage behaves like OLS leverage: extrapolation
    // (x = 20 vs training 1..7) yields a wider interval than the center.
    #[test]
    fn test_regularized_pi_extrapolation_wider() {
        use crate::regularized::ridge::{ridge_fit, RidgeFitOptions};

        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1, 12.9, 15.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

        let mut x_data = Vec::new();
        for i in 0..y.len() {
            x_data.push(1.0);
            x_data.push(x1[i]);
        }
        let x = Matrix::new(y.len(), 2, x_data);

        let options = RidgeFitOptions {
            lambda: 0.1,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let fit = ridge_fit(&x, &y, &options).unwrap();

        let center = [4.0];
        let extrap = [20.0];

        let result_center = ridge_prediction_intervals(&fit, &[x1.clone()], &[&center], 0.05).unwrap();
        let result_extrap = ridge_prediction_intervals(&fit, &[x1], &[&extrap], 0.05).unwrap();

        let width_center = result_center.upper_bound[0] - result_center.lower_bound[0];
        let width_extrap = result_extrap.upper_bound[0] - result_extrap.lower_bound[0];

        assert!(width_extrap > width_center, "Extrapolation PI should be wider");
    }

    // Regularized path honours alpha: 99% interval wider than 95%.
    #[test]
    fn test_regularized_pi_alpha_comparison() {
        use crate::regularized::ridge::{ridge_fit, RidgeFitOptions};

        let y = vec![3.1, 4.9, 7.2, 8.8, 11.1, 12.9, 15.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

        let mut x_data = Vec::new();
        for i in 0..y.len() {
            x_data.push(1.0);
            x_data.push(x1[i]);
        }
        let x = Matrix::new(y.len(), 2, x_data);

        let options = RidgeFitOptions {
            lambda: 0.1,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let fit = ridge_fit(&x, &y, &options).unwrap();

        let new_x1 = [8.0];
        let result_95 = ridge_prediction_intervals(&fit, &[x1.clone()], &[&new_x1], 0.05).unwrap();
        let result_99 = ridge_prediction_intervals(&fit, &[x1], &[&new_x1], 0.01).unwrap();

        let width_95 = result_95.upper_bound[0] - result_95.lower_bound[0];
        let width_99 = result_99.upper_bound[0] - result_99.lower_bound[0];

        assert!(width_99 > width_95, "99% PI should be wider than 95% PI");
    }
}