// nabled_ml/regression.rs
1//! Linear regression over ndarray matrices.
2
3use std::fmt;
4
5use nabled_core::scalar::NabledReal;
6use nabled_linalg::lu::{self as lu, LUError};
7use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use num_complex::Complex64;
9
10/// Regression result for ndarray inputs.
11#[derive(Debug, Clone)]
12pub struct NdarrayRegressionResult<T = f64> {
13    /// Regression coefficients.
14    pub coefficients:  Array1<T>,
15    /// Model fitted values.
16    pub fitted_values: Array1<T>,
17    /// Residuals (`y - y_hat`).
18    pub residuals:     Array1<T>,
19    /// Coefficient of determination.
20    pub r_squared:     T,
21}
22
23/// Complex regression result for ndarray inputs.
24#[derive(Debug, Clone)]
25pub struct NdarrayComplexRegressionResult {
26    /// Regression coefficients.
27    pub coefficients:  Array1<Complex64>,
28    /// Model fitted values.
29    pub fitted_values: Array1<Complex64>,
30    /// Residuals (`y - y_hat`).
31    pub residuals:     Array1<Complex64>,
32    /// Coefficient of determination (real-valued).
33    pub r_squared:     f64,
34}
35
/// Errors produced by the regression routines in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum RegressionError {
    /// One of the input arrays had no elements.
    EmptyInput,
    /// The row count of `x` did not match the length of `y`.
    DimensionMismatch,
    /// The normal-equations system could not be solved.
    Singular,
    /// Caller-supplied input was otherwise invalid; carries a description.
    InvalidInput(String),
}
48
49impl fmt::Display for RegressionError {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        match self {
52            RegressionError::EmptyInput => write!(f, "Input arrays cannot be empty"),
53            RegressionError::DimensionMismatch => write!(f, "Input dimensions are incompatible"),
54            RegressionError::Singular => write!(f, "Regression system is singular"),
55            RegressionError::InvalidInput(message) => write!(f, "Invalid input: {message}"),
56        }
57    }
58}
59
// Marker impl: `Debug` and `Display` exist above, so the default `Error` methods suffice.
impl std::error::Error for RegressionError {}
61
62fn usize_to_scalar<T: NabledReal>(value: usize) -> T {
63    T::from_usize(value).unwrap_or(T::max_value())
64}
65
66fn map_lu_error(error: LUError) -> RegressionError {
67    match error {
68        LUError::EmptyMatrix => RegressionError::EmptyInput,
69        LUError::NotSquare => RegressionError::InvalidInput("normal matrix was not square".into()),
70        LUError::InvalidInput(message) => RegressionError::InvalidInput(message),
71        LUError::SingularMatrix | LUError::NumericalInstability => RegressionError::Singular,
72    }
73}
74
#[cfg(feature = "lapack-provider")]
fn linear_regression_impl<T>(
    x: &ArrayView2<'_, T>,
    y: &ArrayView1<'_, T>,
    add_intercept: bool,
) -> Result<NdarrayRegressionResult<T>, RegressionError>
where
    T: NabledReal + ndarray_linalg::Lapack,
{
    // Validate inputs before any allocation.
    if x.is_empty() || y.is_empty() {
        return Err(RegressionError::EmptyInput);
    }
    if x.nrows() != y.len() {
        return Err(RegressionError::DimensionMismatch);
    }

    // Optionally prepend a column of ones so the first coefficient is the intercept.
    let maybe_design = add_intercept.then(|| {
        Array2::<T>::from_shape_fn((x.nrows(), x.ncols() + 1), |(row, col)| {
            if col == 0 { T::one() } else { x[[row, col - 1]] }
        })
    });
    let design = maybe_design.as_ref().map_or_else(|| x.view(), |owned| owned.view());

    // Solve the normal equations (X^T X) beta = X^T y via LU decomposition.
    let xt = design.t();
    let normal_matrix = xt.dot(&design);
    let normal_rhs = xt.dot(y);
    let coefficients = lu::solve(&normal_matrix, &normal_rhs).map_err(map_lu_error)?;

    let fitted_values = design.dot(&coefficients);
    let residuals = y - &fitted_values;

    // Accumulate sums in iteration order (same order as the previous fold-based code).
    let mut y_sum = T::zero();
    for value in y.iter().copied() {
        y_sum = y_sum + value;
    }
    let y_mean = y_sum / usize_to_scalar::<T>(y.len());

    let mut ss_total = T::zero();
    for value in y.iter().copied() {
        let centered = value - y_mean;
        ss_total = ss_total + centered * centered;
    }

    let mut ss_residual = T::zero();
    for value in residuals.iter().copied() {
        ss_residual = ss_residual + value * value;
    }

    // Constant responses (SS_tot ~ 0) are reported as a perfect fit.
    let r_squared = if ss_total <= T::epsilon() {
        T::one()
    } else {
        T::one() - ss_residual / ss_total
    };

    Ok(NdarrayRegressionResult { coefficients, fitted_values, residuals, r_squared })
}
135
136#[cfg(not(feature = "lapack-provider"))]
137fn linear_regression_impl<T>(
138    x: &ArrayView2<'_, T>,
139    y: &ArrayView1<'_, T>,
140    add_intercept: bool,
141) -> Result<NdarrayRegressionResult<T>, RegressionError>
142where
143    T: NabledReal,
144{
145    if x.is_empty() || y.is_empty() {
146        return Err(RegressionError::EmptyInput);
147    }
148    if x.nrows() != y.len() {
149        return Err(RegressionError::DimensionMismatch);
150    }
151
152    let maybe_design = if add_intercept {
153        let mut with_intercept = Array2::<T>::zeros((x.nrows(), x.ncols() + 1));
154        for row in 0..x.nrows() {
155            with_intercept[[row, 0]] = T::one();
156            for col in 0..x.ncols() {
157                with_intercept[[row, col + 1]] = x[[row, col]];
158            }
159        }
160        Some(with_intercept)
161    } else {
162        None
163    };
164    let design = maybe_design.as_ref().map_or_else(|| x.view(), |owned| owned.view());
165
166    let xt = design.t();
167    let normal_matrix = xt.dot(&design);
168    let normal_rhs = xt.dot(y);
169    let coefficients = lu::solve(&normal_matrix, &normal_rhs).map_err(map_lu_error)?;
170
171    let fitted_values = design.dot(&coefficients);
172    let residuals = y - &fitted_values;
173
174    let y_sum = y.iter().copied().fold(T::zero(), |acc, value| acc + value);
175    let y_mean = y_sum / usize_to_scalar::<T>(y.len());
176
177    let ss_total = y
178        .iter()
179        .copied()
180        .map(|value| {
181            let centered = value - y_mean;
182            centered * centered
183        })
184        .fold(T::zero(), |acc, value| acc + value);
185
186    let ss_residual = residuals
187        .iter()
188        .copied()
189        .map(|value| value * value)
190        .fold(T::zero(), |acc, value| acc + value);
191    let r_squared =
192        if ss_total <= T::epsilon() { T::one() } else { T::one() - ss_residual / ss_total };
193
194    Ok(NdarrayRegressionResult { coefficients, fitted_values, residuals, r_squared })
195}
196
197/// Solve linear regression with optional intercept.
198///
199/// # Errors
200/// Returns an error for invalid dimensions or singular design matrix.
201#[cfg(not(feature = "lapack-provider"))]
202pub fn linear_regression<T>(
203    x: &Array2<T>,
204    y: &Array1<T>,
205    add_intercept: bool,
206) -> Result<NdarrayRegressionResult<T>, RegressionError>
207where
208    T: NabledReal,
209{
210    linear_regression_impl(&x.view(), &y.view(), add_intercept)
211}
212
/// Solve linear regression with optional intercept.
///
/// # Errors
/// Returns an error for invalid dimensions or singular design matrix.
#[cfg(feature = "lapack-provider")]
pub fn linear_regression<T>(
    x: &Array2<T>,
    y: &Array1<T>,
    add_intercept: bool,
) -> Result<NdarrayRegressionResult<T>, RegressionError>
where
    T: NabledReal + ndarray_linalg::Lapack,
{
    // Borrow views of the owned arrays and delegate to the shared implementation.
    let x_view = x.view();
    let y_view = y.view();
    linear_regression_impl(&x_view, &y_view, add_intercept)
}
228
229/// Solve linear regression with optional intercept from matrix/vector views.
230///
231/// # Errors
232/// Returns an error for invalid dimensions or singular design matrix.
233#[cfg(not(feature = "lapack-provider"))]
234pub fn linear_regression_view<T>(
235    x: &ArrayView2<'_, T>,
236    y: &ArrayView1<'_, T>,
237    add_intercept: bool,
238) -> Result<NdarrayRegressionResult<T>, RegressionError>
239where
240    T: NabledReal,
241{
242    linear_regression_impl(x, y, add_intercept)
243}
244
/// Solve linear regression with optional intercept from matrix/vector views.
///
/// # Errors
/// Returns an error for invalid dimensions or singular design matrix.
#[cfg(feature = "lapack-provider")]
pub fn linear_regression_view<T>(
    x: &ArrayView2<'_, T>,
    y: &ArrayView1<'_, T>,
    add_intercept: bool,
) -> Result<NdarrayRegressionResult<T>, RegressionError>
where
    T: NabledReal + ndarray_linalg::Lapack,
{
    // Thin forwarding wrapper; all work happens in the shared implementation.
    linear_regression_impl(x, y, add_intercept)
}
260
261fn linear_regression_complex_impl(
262    x: &ArrayView2<'_, Complex64>,
263    y: &ArrayView1<'_, Complex64>,
264    add_intercept: bool,
265) -> Result<NdarrayComplexRegressionResult, RegressionError> {
266    if x.is_empty() || y.is_empty() {
267        return Err(RegressionError::EmptyInput);
268    }
269    if x.nrows() != y.len() {
270        return Err(RegressionError::DimensionMismatch);
271    }
272
273    let maybe_design = if add_intercept {
274        let mut with_intercept = Array2::<Complex64>::zeros((x.nrows(), x.ncols() + 1));
275        for row in 0..x.nrows() {
276            with_intercept[[row, 0]] = Complex64::new(1.0, 0.0);
277            for col in 0..x.ncols() {
278                with_intercept[[row, col + 1]] = x[[row, col]];
279            }
280        }
281        Some(with_intercept)
282    } else {
283        None
284    };
285    let design = maybe_design.as_ref().map_or_else(|| x.view(), |owned| owned.view());
286
287    let xh = design.t().mapv(|value| value.conj());
288    let normal_matrix = xh.dot(&design);
289    let normal_rhs = xh.dot(y);
290    let coefficients = lu::solve_complex(&normal_matrix, &normal_rhs).map_err(map_lu_error)?;
291
292    let fitted_values = design.dot(&coefficients);
293    let residuals = y - &fitted_values;
294
295    let y_mean = y.iter().copied().sum::<Complex64>() / usize_to_scalar::<f64>(y.len());
296    let ss_total = y.iter().map(|value| (*value - y_mean).norm_sqr()).sum::<f64>();
297    let ss_residual = residuals.iter().map(Complex64::norm_sqr).sum::<f64>();
298    let r_squared = if ss_total <= f64::EPSILON { 1.0 } else { 1.0 - ss_residual / ss_total };
299
300    Ok(NdarrayComplexRegressionResult { coefficients, fitted_values, residuals, r_squared })
301}
302
303/// Solve complex linear regression with optional intercept.
304///
305/// # Errors
306/// Returns an error for invalid dimensions or singular design matrix.
307pub fn linear_regression_complex(
308    x: &Array2<Complex64>,
309    y: &Array1<Complex64>,
310    add_intercept: bool,
311) -> Result<NdarrayComplexRegressionResult, RegressionError> {
312    linear_regression_complex_impl(&x.view(), &y.view(), add_intercept)
313}
314
315/// Solve complex linear regression with optional intercept from matrix/vector views.
316///
317/// # Errors
318/// Returns an error for invalid dimensions or singular design matrix.
319pub fn linear_regression_complex_view(
320    x: &ArrayView2<'_, Complex64>,
321    y: &ArrayView1<'_, Complex64>,
322    add_intercept: bool,
323) -> Result<NdarrayComplexRegressionResult, RegressionError> {
324    linear_regression_complex_impl(x, y, add_intercept)
325}
326
// Unit tests: assertions pin exact numeric expectations, so the code is kept
// byte-identical; comments only describe each scenario.
#[cfg(test)]
mod tests {
    use ndarray::{Array1, Array2};
    use num_complex::Complex64;

    use super::*;

    // y = 1 + 2x: intercept and slope recovered to solver tolerance, R² ≈ 1.
    #[test]
    fn linear_regression_fits_known_line() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![3.0_f64, 5.0, 7.0, 9.0]);
        let result = linear_regression(&x, &y, true).unwrap();
        assert!((result.coefficients[0] - 1.0_f64).abs() < 1e-8);
        assert!((result.coefficients[1] - 2.0_f64).abs() < 1e-8);
        assert!(result.r_squared > 0.999_999);
    }

    // Without an intercept there is a single coefficient: the slope through the origin.
    #[test]
    fn regression_without_intercept_fits_origin_line() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0_f64, 4.0, 6.0, 8.0]);
        let result = linear_regression(&x, &y, false).unwrap();
        assert_eq!(result.coefficients.len(), 1);
        assert!((result.coefficients[0] - 2.0_f64).abs() < 1e-8);
    }

    // 2 rows of x versus 3 elements of y must be rejected up front.
    #[test]
    fn regression_rejects_dimension_mismatch() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = linear_regression(&x, &y, true);
        assert!(matches!(result, Err(RegressionError::DimensionMismatch)));
    }

    // Zero-sized inputs short-circuit with `EmptyInput`.
    #[test]
    fn regression_rejects_empty_inputs() {
        let x = Array2::<f64>::zeros((0, 0));
        let y = Array1::<f64>::zeros(0);
        let result = linear_regression(&x, &y, true);
        assert!(matches!(result, Err(RegressionError::EmptyInput)));
    }

    // A constant predictor column plus an intercept makes the normal matrix singular.
    #[test]
    fn regression_reports_singular_system() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = linear_regression(&x, &y, true);
        assert!(matches!(result, Err(RegressionError::Singular)));
    }

    // Constant y makes SS_total ~ 0, which is deliberately reported as R² = 1.
    #[test]
    fn regression_constant_response_has_unit_r_squared() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![3.0_f64, 3.0, 3.0, 3.0]);
        let result = linear_regression(&x, &y, true).unwrap();
        assert!((result.r_squared - 1.0_f64).abs() < 1e-12);
        assert_eq!(result.fitted_values.len(), y.len());
        assert_eq!(result.residuals.len(), y.len());
    }

    // The view-based entry point must agree with the owned-array entry point.
    #[test]
    fn regression_view_matches_owned() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![3.0_f64, 5.0, 7.0, 9.0]);
        let owned = linear_regression(&x, &y, true).unwrap();
        let viewed = linear_regression_view(&x.view(), &y.view(), true).unwrap();

        assert_eq!(owned.coefficients.len(), viewed.coefficients.len());
        for i in 0..owned.coefficients.len() {
            assert!((owned.coefficients[i] - viewed.coefficients[i]).abs() < 1e-12);
        }
        assert!((owned.r_squared - viewed.r_squared).abs() < 1e-12);
    }

    // Complex fit: real x, y = (1+i) + 2x; the imaginary offset lands in the intercept.
    #[test]
    fn complex_regression_fits_known_line() {
        let x = Array2::from_shape_vec((4, 1), vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
        ])
        .unwrap();
        let y = Array1::from_vec(vec![
            Complex64::new(3.0, 1.0),
            Complex64::new(5.0, 1.0),
            Complex64::new(7.0, 1.0),
            Complex64::new(9.0, 1.0),
        ]);

        let result = linear_regression_complex(&x, &y, true).unwrap();
        assert!((result.coefficients[0] - Complex64::new(1.0, 1.0)).norm() < 1e-8);
        assert!((result.coefficients[1] - Complex64::new(2.0, 0.0)).norm() < 1e-8);
        assert!(result.r_squared > 0.999_999);
    }

    // Complex view-based entry point must agree with the owned-array entry point.
    #[test]
    fn complex_regression_view_matches_owned() {
        let x = Array2::from_shape_vec((4, 1), vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
        ])
        .unwrap();
        let y = Array1::from_vec(vec![
            Complex64::new(3.0, 1.0),
            Complex64::new(5.0, 1.0),
            Complex64::new(7.0, 1.0),
            Complex64::new(9.0, 1.0),
        ]);

        let owned = linear_regression_complex(&x, &y, true).unwrap();
        let viewed = linear_regression_complex_view(&x.view(), &y.view(), true).unwrap();

        assert_eq!(owned.coefficients.len(), viewed.coefficients.len());
        for i in 0..owned.coefficients.len() {
            assert!((owned.coefficients[i] - viewed.coefficients[i]).norm() < 1e-12);
        }
        assert!((owned.r_squared - viewed.r_squared).abs() < 1e-12);
    }
}