1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
//! Core traits for regression estimators.
use crate::core::{IntervalType, PredictionResult, RegressionResult};
use faer::{Col, Mat};
use thiserror::Error;
/// Errors that can occur during regression fitting.
#[derive(Debug, Error)]
pub enum RegressionError {
#[error("dimension mismatch: X has {x_rows} rows but y has {y_len} elements")]
DimensionMismatch { x_rows: usize, y_len: usize },
#[error("insufficient observations: need at least {needed}, got {got}")]
InsufficientObservations { needed: usize, got: usize },
#[error("matrix is singular or nearly singular")]
SingularMatrix,
#[error("all features are constant")]
AllFeaturesConstant,
#[error("invalid options: {0}")]
InvalidOptions(#[from] crate::core::OptionsError),
#[error("convergence failed after {iterations} iterations")]
ConvergenceFailed { iterations: usize },
#[error("invalid weights: all weights must be non-negative")]
InvalidWeights,
#[error("numerical error: {0}")]
NumericalError(String),
}
/// A regression estimator that can be fit to data.
///
/// This trait follows the sklearn pattern where fitting returns a fitted model
/// that can then make predictions.
pub trait Regressor {
/// The type of the fitted model.
type Fitted: FittedRegressor;
/// Fit the model to the data.
///
/// # Arguments
/// * `x` - Design matrix of shape (n_samples, n_features)
/// * `y` - Target vector of length n_samples
///
/// # Returns
/// A fitted model that can make predictions.
fn fit(&self, x: &Mat<f64>, y: &Col<f64>) -> Result<Self::Fitted, RegressionError>;
}
/// A fitted regression model that can make predictions.
pub trait FittedRegressor {
/// Make predictions on new data.
///
/// # Arguments
/// * `x` - Design matrix of shape (n_samples, n_features)
///
/// # Returns
/// Predicted values vector of length n_samples.
fn predict(&self, x: &Mat<f64>) -> Col<f64>;
/// Access the regression results (coefficients, statistics, etc.).
fn result(&self) -> &RegressionResult;
/// Get the coefficients (convenience method).
fn coefficients(&self) -> &Col<f64> {
&self.result().coefficients
}
/// Get the intercept (convenience method).
fn intercept(&self) -> Option<f64> {
self.result().intercept
}
/// Get R² (convenience method).
fn r_squared(&self) -> f64 {
self.result().r_squared
}
/// Calculate the score (R²) on new data.
///
/// # Arguments
/// * `x` - Design matrix
/// * `y` - True target values
fn score(&self, x: &Mat<f64>, y: &Col<f64>) -> f64 {
let predictions = self.predict(x);
let n = y.nrows();
// Calculate mean of y
let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
// Calculate TSS and RSS
let tss: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
let rss: f64 = y
.iter()
.zip(predictions.iter())
.map(|(&yi, &pi)| (yi - pi).powi(2))
.sum();
if tss == 0.0 {
// Perfect prediction of constant target
if rss == 0.0 {
1.0
} else {
0.0
}
} else {
1.0 - rss / tss
}
}
/// Make predictions with confidence or prediction intervals.
///
/// This method follows R's `predict(..., interval = "confidence" | "prediction")` API.
///
/// # Arguments
/// * `x` - Design matrix of shape (n_samples, n_features)
/// * `interval` - Type of interval: `None` for point predictions only,
/// `Some(IntervalType::Confidence)` for confidence intervals on the mean response,
/// `Some(IntervalType::Prediction)` for prediction intervals on new observations
/// * `level` - Confidence level (e.g., 0.95 for 95% intervals)
///
/// # Returns
/// A `PredictionResult` containing:
/// - `fit`: Point predictions
/// - `lower`: Lower bounds (same as fit if interval is None)
/// - `upper`: Upper bounds (same as fit if interval is None)
/// - `se`: Standard errors (zeros if interval is None)
///
/// # Example
/// ```ignore
/// use statistics::prelude::*;
///
/// let fitted = OlsRegressor::builder().build().fit(&x, &y)?;
///
/// // Point predictions only
/// let pred = fitted.predict_with_interval(&x_new, None, 0.95);
///
/// // 95% prediction intervals
/// let pred = fitted.predict_with_interval(&x_new, Some(IntervalType::Prediction), 0.95);
/// println!("Lower: {:?}, Upper: {:?}", pred.lower, pred.upper);
///
/// // 99% confidence intervals for mean response
/// let pred = fitted.predict_with_interval(&x_new, Some(IntervalType::Confidence), 0.99);
/// ```
fn predict_with_interval(
&self,
x: &Mat<f64>,
interval: Option<IntervalType>,
level: f64,
) -> PredictionResult;
}