1use std::fmt;
4
5use nabled_core::scalar::NabledReal;
6use nabled_linalg::lu::{self as lu, LUError};
7use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use num_complex::Complex64;
9
/// Output of a real-valued least-squares fit.
#[derive(Debug, Clone)]
pub struct NdarrayRegressionResult<T = f64> {
    /// Estimated coefficients; when an intercept was requested it is the first entry.
    pub coefficients: Array1<T>,
    /// Model predictions (design matrix times `coefficients`) for each observation.
    pub fitted_values: Array1<T>,
    /// Per-observation errors `y - fitted_values`.
    pub residuals: Array1<T>,
    /// Coefficient of determination; defined as 1 when the response is (near-)constant.
    pub r_squared: T,
}
22
/// Output of a complex-valued least-squares fit.
#[derive(Debug, Clone)]
pub struct NdarrayComplexRegressionResult {
    /// Estimated complex coefficients; when an intercept was requested it is the first entry.
    pub coefficients: Array1<Complex64>,
    /// Model predictions (design matrix times `coefficients`) for each observation.
    pub fitted_values: Array1<Complex64>,
    /// Per-observation errors `y - fitted_values`.
    pub residuals: Array1<Complex64>,
    /// Real coefficient of determination built from squared norms; 1 for a
    /// (near-)constant response.
    pub r_squared: f64,
}
35
/// Errors reported by the regression routines in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum RegressionError {
    /// One of the input arrays had zero elements.
    EmptyInput,
    /// `x.nrows()` did not match `y.len()`.
    DimensionMismatch,
    /// The normal-equations system was singular or numerically unstable.
    Singular,
    /// Other invalid input; carries a human-readable message.
    InvalidInput(String),
}
48
49impl fmt::Display for RegressionError {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 match self {
52 RegressionError::EmptyInput => write!(f, "Input arrays cannot be empty"),
53 RegressionError::DimensionMismatch => write!(f, "Input dimensions are incompatible"),
54 RegressionError::Singular => write!(f, "Regression system is singular"),
55 RegressionError::InvalidInput(message) => write!(f, "Invalid input: {message}"),
56 }
57 }
58}
59
// Marker impl: the `Display` and `Debug` impls above satisfy the `Error`
// contract, and the default `source()` (no underlying cause) is correct here.
impl std::error::Error for RegressionError {}
61
62fn usize_to_scalar<T: NabledReal>(value: usize) -> T {
63 T::from_usize(value).unwrap_or(T::max_value())
64}
65
66fn map_lu_error(error: LUError) -> RegressionError {
67 match error {
68 LUError::EmptyMatrix => RegressionError::EmptyInput,
69 LUError::NotSquare => RegressionError::InvalidInput("normal matrix was not square".into()),
70 LUError::InvalidInput(message) => RegressionError::InvalidInput(message),
71 LUError::SingularMatrix | LUError::NumericalInstability => RegressionError::Singular,
72 }
73}
74
#[cfg(feature = "lapack-provider")]
fn linear_regression_impl<T>(
    x: &ArrayView2<'_, T>,
    y: &ArrayView1<'_, T>,
    add_intercept: bool,
) -> Result<NdarrayRegressionResult<T>, RegressionError>
where
    T: NabledReal + ndarray_linalg::Lapack,
{
    // NOTE(review): this feature-gated path currently solves the normal
    // equations with the same in-crate LU routine as the non-LAPACK build;
    // only the trait bound differs. Confirm whether a LAPACK-backed solver
    // was intended here.

    // Reject degenerate inputs before touching the data.
    if x.is_empty() || y.is_empty() {
        return Err(RegressionError::EmptyInput);
    }
    if x.nrows() != y.len() {
        return Err(RegressionError::DimensionMismatch);
    }

    // Optionally prepend a column of ones so the first coefficient is the intercept.
    let owned_design = add_intercept.then(|| {
        Array2::from_shape_fn((x.nrows(), x.ncols() + 1), |(row, col)| {
            if col == 0 { T::one() } else { x[[row, col - 1]] }
        })
    });
    let design = owned_design.as_ref().map_or_else(|| x.view(), Array2::view);

    // Solve the normal equations (Xᵀ X) β = Xᵀ y via LU decomposition.
    let transposed = design.t();
    let normal_matrix = transposed.dot(&design);
    let normal_rhs = transposed.dot(y);
    let coefficients = lu::solve(&normal_matrix, &normal_rhs).map_err(map_lu_error)?;

    let fitted_values = design.dot(&coefficients);
    let residuals = y - &fitted_values;

    // R² = 1 - SS_res / SS_tot, with a guard for a (near-)constant response.
    let total: T = y.iter().copied().fold(T::zero(), |acc, value| acc + value);
    let mean = total / usize_to_scalar::<T>(y.len());
    let ss_total = y.iter().copied().fold(T::zero(), |acc, value| {
        let centered = value - mean;
        acc + centered * centered
    });
    let ss_residual = residuals.iter().copied().fold(T::zero(), |acc, value| acc + value * value);
    let r_squared =
        if ss_total <= T::epsilon() { T::one() } else { T::one() - ss_residual / ss_total };

    Ok(NdarrayRegressionResult { coefficients, fitted_values, residuals, r_squared })
}
135
136#[cfg(not(feature = "lapack-provider"))]
137fn linear_regression_impl<T>(
138 x: &ArrayView2<'_, T>,
139 y: &ArrayView1<'_, T>,
140 add_intercept: bool,
141) -> Result<NdarrayRegressionResult<T>, RegressionError>
142where
143 T: NabledReal,
144{
145 if x.is_empty() || y.is_empty() {
146 return Err(RegressionError::EmptyInput);
147 }
148 if x.nrows() != y.len() {
149 return Err(RegressionError::DimensionMismatch);
150 }
151
152 let maybe_design = if add_intercept {
153 let mut with_intercept = Array2::<T>::zeros((x.nrows(), x.ncols() + 1));
154 for row in 0..x.nrows() {
155 with_intercept[[row, 0]] = T::one();
156 for col in 0..x.ncols() {
157 with_intercept[[row, col + 1]] = x[[row, col]];
158 }
159 }
160 Some(with_intercept)
161 } else {
162 None
163 };
164 let design = maybe_design.as_ref().map_or_else(|| x.view(), |owned| owned.view());
165
166 let xt = design.t();
167 let normal_matrix = xt.dot(&design);
168 let normal_rhs = xt.dot(y);
169 let coefficients = lu::solve(&normal_matrix, &normal_rhs).map_err(map_lu_error)?;
170
171 let fitted_values = design.dot(&coefficients);
172 let residuals = y - &fitted_values;
173
174 let y_sum = y.iter().copied().fold(T::zero(), |acc, value| acc + value);
175 let y_mean = y_sum / usize_to_scalar::<T>(y.len());
176
177 let ss_total = y
178 .iter()
179 .copied()
180 .map(|value| {
181 let centered = value - y_mean;
182 centered * centered
183 })
184 .fold(T::zero(), |acc, value| acc + value);
185
186 let ss_residual = residuals
187 .iter()
188 .copied()
189 .map(|value| value * value)
190 .fold(T::zero(), |acc, value| acc + value);
191 let r_squared =
192 if ss_total <= T::epsilon() { T::one() } else { T::one() - ss_residual / ss_total };
193
194 Ok(NdarrayRegressionResult { coefficients, fitted_values, residuals, r_squared })
195}
196
197#[cfg(not(feature = "lapack-provider"))]
202pub fn linear_regression<T>(
203 x: &Array2<T>,
204 y: &Array1<T>,
205 add_intercept: bool,
206) -> Result<NdarrayRegressionResult<T>, RegressionError>
207where
208 T: NabledReal,
209{
210 linear_regression_impl(&x.view(), &y.view(), add_intercept)
211}
212
/// Ordinary least-squares fit of `y` on the columns of `x`.
///
/// When `add_intercept` is true, a constant column is prepended to `x`
/// before solving, so the first returned coefficient is the intercept.
///
/// # Errors
/// Returns [`RegressionError`] when the inputs are empty, their dimensions
/// do not match, or the normal-equations system is singular.
#[cfg(feature = "lapack-provider")]
pub fn linear_regression<T>(
    x: &Array2<T>,
    y: &Array1<T>,
    add_intercept: bool,
) -> Result<NdarrayRegressionResult<T>, RegressionError>
where
    T: NabledReal + ndarray_linalg::Lapack,
{
    let (x_view, y_view) = (x.view(), y.view());
    linear_regression_impl(&x_view, &y_view, add_intercept)
}
228
/// View-based variant of [`linear_regression`] for callers that already
/// hold array views.
///
/// # Errors
/// Same error conditions as [`linear_regression`].
#[cfg(not(feature = "lapack-provider"))]
pub fn linear_regression_view<T>(
    x: &ArrayView2<'_, T>,
    y: &ArrayView1<'_, T>,
    add_intercept: bool,
) -> Result<NdarrayRegressionResult<T>, RegressionError>
where
    T: NabledReal,
{
    linear_regression_impl(x, y, add_intercept)
}
244
/// View-based variant of [`linear_regression`] for callers that already
/// hold array views.
///
/// # Errors
/// Same error conditions as [`linear_regression`].
#[cfg(feature = "lapack-provider")]
pub fn linear_regression_view<T>(
    x: &ArrayView2<'_, T>,
    y: &ArrayView1<'_, T>,
    add_intercept: bool,
) -> Result<NdarrayRegressionResult<T>, RegressionError>
where
    T: NabledReal + ndarray_linalg::Lapack,
{
    linear_regression_impl(x, y, add_intercept)
}
260
261fn linear_regression_complex_impl(
262 x: &ArrayView2<'_, Complex64>,
263 y: &ArrayView1<'_, Complex64>,
264 add_intercept: bool,
265) -> Result<NdarrayComplexRegressionResult, RegressionError> {
266 if x.is_empty() || y.is_empty() {
267 return Err(RegressionError::EmptyInput);
268 }
269 if x.nrows() != y.len() {
270 return Err(RegressionError::DimensionMismatch);
271 }
272
273 let maybe_design = if add_intercept {
274 let mut with_intercept = Array2::<Complex64>::zeros((x.nrows(), x.ncols() + 1));
275 for row in 0..x.nrows() {
276 with_intercept[[row, 0]] = Complex64::new(1.0, 0.0);
277 for col in 0..x.ncols() {
278 with_intercept[[row, col + 1]] = x[[row, col]];
279 }
280 }
281 Some(with_intercept)
282 } else {
283 None
284 };
285 let design = maybe_design.as_ref().map_or_else(|| x.view(), |owned| owned.view());
286
287 let xh = design.t().mapv(|value| value.conj());
288 let normal_matrix = xh.dot(&design);
289 let normal_rhs = xh.dot(y);
290 let coefficients = lu::solve_complex(&normal_matrix, &normal_rhs).map_err(map_lu_error)?;
291
292 let fitted_values = design.dot(&coefficients);
293 let residuals = y - &fitted_values;
294
295 let y_mean = y.iter().copied().sum::<Complex64>() / usize_to_scalar::<f64>(y.len());
296 let ss_total = y.iter().map(|value| (*value - y_mean).norm_sqr()).sum::<f64>();
297 let ss_residual = residuals.iter().map(Complex64::norm_sqr).sum::<f64>();
298 let r_squared = if ss_total <= f64::EPSILON { 1.0 } else { 1.0 - ss_residual / ss_total };
299
300 Ok(NdarrayComplexRegressionResult { coefficients, fitted_values, residuals, r_squared })
301}
302
303pub fn linear_regression_complex(
308 x: &Array2<Complex64>,
309 y: &Array1<Complex64>,
310 add_intercept: bool,
311) -> Result<NdarrayComplexRegressionResult, RegressionError> {
312 linear_regression_complex_impl(&x.view(), &y.view(), add_intercept)
313}
314
/// View-based variant of [`linear_regression_complex`] for callers that
/// already hold array views.
///
/// # Errors
/// Same error conditions as [`linear_regression_complex`].
pub fn linear_regression_complex_view(
    x: &ArrayView2<'_, Complex64>,
    y: &ArrayView1<'_, Complex64>,
    add_intercept: bool,
) -> Result<NdarrayComplexRegressionResult, RegressionError> {
    linear_regression_complex_impl(x, y, add_intercept)
}
326
#[cfg(test)]
mod tests {
    use ndarray::{Array1, Array2};
    use num_complex::Complex64;

    use super::*;

    // y = 1 + 2x: the fit should recover intercept 1 and slope 2.
    #[test]
    fn linear_regression_fits_known_line() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![3.0_f64, 5.0, 7.0, 9.0]);
        let result = linear_regression(&x, &y, true).unwrap();
        assert!((result.coefficients[0] - 1.0_f64).abs() < 1e-8);
        assert!((result.coefficients[1] - 2.0_f64).abs() < 1e-8);
        assert!(result.r_squared > 0.999_999);
    }

    // Without an intercept, only one coefficient (the slope) is estimated.
    #[test]
    fn regression_without_intercept_fits_origin_line() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0_f64, 4.0, 6.0, 8.0]);
        let result = linear_regression(&x, &y, false).unwrap();
        assert_eq!(result.coefficients.len(), 1);
        assert!((result.coefficients[0] - 2.0_f64).abs() < 1e-8);
    }

    // 2 rows of x vs 3 elements of y must be rejected before solving.
    #[test]
    fn regression_rejects_dimension_mismatch() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = linear_regression(&x, &y, true);
        assert!(matches!(result, Err(RegressionError::DimensionMismatch)));
    }

    // Zero-sized inputs map to EmptyInput, not a solver failure.
    #[test]
    fn regression_rejects_empty_inputs() {
        let x = Array2::<f64>::zeros((0, 0));
        let y = Array1::<f64>::zeros(0);
        let result = linear_regression(&x, &y, true);
        assert!(matches!(result, Err(RegressionError::EmptyInput)));
    }

    // A constant x column plus the intercept column makes the design rank-deficient.
    #[test]
    fn regression_reports_singular_system() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 1.0, 1.0]).unwrap();
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = linear_regression(&x, &y, true);
        assert!(matches!(result, Err(RegressionError::Singular)));
    }

    // A constant response has zero total variance; R² is defined as 1 there.
    #[test]
    fn regression_constant_response_has_unit_r_squared() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![3.0_f64, 3.0, 3.0, 3.0]);
        let result = linear_regression(&x, &y, true).unwrap();
        assert!((result.r_squared - 1.0_f64).abs() < 1e-12);
        assert_eq!(result.fitted_values.len(), y.len());
        assert_eq!(result.residuals.len(), y.len());
    }

    // The owned-array and view-based entry points must agree on the same data.
    #[test]
    fn regression_view_matches_owned() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![3.0_f64, 5.0, 7.0, 9.0]);
        let owned = linear_regression(&x, &y, true).unwrap();
        let viewed = linear_regression_view(&x.view(), &y.view(), true).unwrap();

        assert_eq!(owned.coefficients.len(), viewed.coefficients.len());
        for i in 0..owned.coefficients.len() {
            assert!((owned.coefficients[i] - viewed.coefficients[i]).abs() < 1e-12);
        }
        assert!((owned.r_squared - viewed.r_squared).abs() < 1e-12);
    }

    // Complex fit: real slope 2 with a constant imaginary offset in the intercept.
    #[test]
    fn complex_regression_fits_known_line() {
        let x = Array2::from_shape_vec((4, 1), vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
        ])
        .unwrap();
        let y = Array1::from_vec(vec![
            Complex64::new(3.0, 1.0),
            Complex64::new(5.0, 1.0),
            Complex64::new(7.0, 1.0),
            Complex64::new(9.0, 1.0),
        ]);

        let result = linear_regression_complex(&x, &y, true).unwrap();
        assert!((result.coefficients[0] - Complex64::new(1.0, 1.0)).norm() < 1e-8);
        assert!((result.coefficients[1] - Complex64::new(2.0, 0.0)).norm() < 1e-8);
        assert!(result.r_squared > 0.999_999);
    }

    // Complex owned-array and view-based entry points must agree on the same data.
    #[test]
    fn complex_regression_view_matches_owned() {
        let x = Array2::from_shape_vec((4, 1), vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
        ])
        .unwrap();
        let y = Array1::from_vec(vec![
            Complex64::new(3.0, 1.0),
            Complex64::new(5.0, 1.0),
            Complex64::new(7.0, 1.0),
            Complex64::new(9.0, 1.0),
        ]);

        let owned = linear_regression_complex(&x, &y, true).unwrap();
        let viewed = linear_regression_complex_view(&x.view(), &y.view(), true).unwrap();

        assert_eq!(owned.coefficients.len(), viewed.coefficients.len());
        for i in 0..owned.coefficients.len() {
            assert!((owned.coefficients[i] - viewed.coefficients[i]).norm() < 1e-12);
        }
        assert!((owned.r_squared - viewed.r_squared).abs() < 1e-12);
    }
}