1use crate::accel;
5use crate::dataset::Dataset;
6use crate::error::{Result, ScryLearnError};
7use crate::sparse::{CscMatrix, CsrMatrix};
8
/// Strategy used by [`LinearRegression::fit`] to solve the least-squares system.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub enum LinRegSolver {
    /// Normal equations (`XᵀX β = Xᵀy`) solved by Gauss-Jordan elimination.
    /// Cheapest path, but fails on (near-)singular Gram matrices.
    Normal,
    /// QR decomposition of the (optionally ridge-augmented) design matrix.
    Qr,
    /// Singular value decomposition of the design matrix; the most robust
    /// option for ill-conditioned or under-determined problems.
    Svd,
    /// Picks SVD when `n_features >= n_samples`, otherwise tries the normal
    /// equations first and falls back to SVD if they fail.
    #[default]
    Auto,
}
23
/// Ordinary least squares / ridge linear regression.
///
/// Fit with [`LinearRegression::fit`] (dense, auto-dispatching to the sparse
/// path when the dataset carries a CSC matrix) or `fit_sparse`; predict with
/// `predict` / `predict_sparse`.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct LinearRegression {
    // One weight per feature, in feature order; empty until fitted.
    coefficients: Vec<f64>,
    // Bias term (beta[0] of the augmented solution).
    intercept: f64,
    // L2 (ridge) penalty strength; 0.0 means plain OLS. Never applied to the
    // intercept (see the diagonal loops in the fit paths).
    alpha: f64,
    // Not serialized; serde's `skip` restores it via `Default`, so a
    // deserialized model always uses the `Auto` solver.
    #[cfg_attr(feature = "serde", serde(skip))]
    solver: LinRegSolver,
    // Set to true only after a successful fit; predictions require it.
    fitted: bool,
    // Schema version stamped at construction and checked in `predict`.
    // `serde(default)` means payloads missing the field deserialize as 0 —
    // presumably rejected by `check_schema_version`; confirm against
    // crate::version.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
51
52impl LinearRegression {
53 pub fn new() -> Self {
55 Self {
56 coefficients: Vec::new(),
57 intercept: 0.0,
58 alpha: 0.0,
59 solver: LinRegSolver::Auto,
60 fitted: false,
61 _schema_version: crate::version::SCHEMA_VERSION,
62 }
63 }
64
65 pub fn alpha(mut self, a: f64) -> Self {
67 self.alpha = a;
68 self
69 }
70
71 pub fn solver(mut self, s: LinRegSolver) -> Self {
73 self.solver = s;
74 self
75 }
76
77 pub fn fit(&mut self, data: &Dataset) -> Result<()> {
79 data.validate_finite()?;
80 if let Some(csc) = data.sparse_csc() {
81 return self.fit_sparse(csc, &data.target);
82 }
83 let n = data.n_samples();
84 let m = data.n_features();
85 if n == 0 {
86 return Err(ScryLearnError::EmptyDataset);
87 }
88
89 match &self.solver {
90 LinRegSolver::Normal => self.fit_normal(data),
91 LinRegSolver::Qr => self.fit_qr(data),
92 LinRegSolver::Svd => self.fit_svd(data),
93 LinRegSolver::Auto => {
94 if m >= n {
97 return self.fit_svd(data);
98 }
99 match self.fit_normal(data) {
100 Ok(()) => Ok(()),
101 Err(_) => self.fit_svd(data),
102 }
103 }
104 }
105 }
106
107 fn fit_normal(&mut self, data: &Dataset) -> Result<()> {
109 let n = data.n_samples();
110 let m = data.n_features();
111 let dim = m + 1;
112
113 let backend = accel::auto();
114 let mat = data.matrix();
115 let (mut xtx, mut xty) = backend.xtx_xty_contiguous(mat.as_slice(), &data.target, n, m);
116
117 for j in 1..dim {
118 xtx[j * dim + j] += self.alpha;
119 }
120
121 let beta = solve_linear(dim, &mut xtx, &mut xty)?;
122
123 self.intercept = beta[0];
124 self.coefficients = beta[1..].to_vec();
125 self.fitted = true;
126 Ok(())
127 }
128
129 fn build_augmented(data: &Dataset) -> (Vec<f64>, usize, usize) {
131 let n = data.n_samples();
132 let m = data.n_features();
133 let dim = m + 1;
134 let mat = data.matrix();
135 let mut x = vec![0.0; n * dim];
136 for i in 0..n {
137 x[i] = 1.0;
138 }
139 for j in 0..m {
140 let offset = (j + 1) * n;
141 x[offset..offset + n].copy_from_slice(mat.col(j));
142 }
143 (x, n, dim)
144 }
145
146 fn build_regularized(data: &Dataset, alpha: f64) -> (Vec<f64>, Vec<f64>, usize, usize) {
148 let n = data.n_samples();
149 let m = data.n_features();
150 let dim = m + 1;
151 let mat = data.matrix();
152 let sqrt_a = alpha.sqrt();
153 let aug_rows = n + m;
154 let mut x_aug = vec![0.0; aug_rows * dim];
155 let mut y_aug = vec![0.0; aug_rows];
156
157 for i in 0..n {
158 x_aug[i] = 1.0;
159 }
160 for j in 0..m {
161 let offset = (j + 1) * aug_rows;
162 x_aug[offset..offset + n].copy_from_slice(mat.col(j));
163 }
164 y_aug[..n].copy_from_slice(&data.target);
165
166 for j in 0..m {
167 x_aug[(j + 1) * aug_rows + n + j] = sqrt_a;
168 }
169
170 (x_aug, y_aug, aug_rows, dim)
171 }
172
173 fn fit_qr(&mut self, data: &Dataset) -> Result<()> {
175 if self.alpha > 0.0 {
176 let (x_aug, y_aug, aug_rows, dim) = Self::build_regularized(data, self.alpha);
177 let beta = super::qr::qr_solve(&x_aug, &y_aug, aug_rows, dim)?;
178 self.intercept = beta[0];
179 self.coefficients = beta[1..].to_vec();
180 } else {
181 let (x, n, dim) = Self::build_augmented(data);
182 let beta = super::qr::qr_solve(&x, &data.target, n, dim)?;
183 self.intercept = beta[0];
184 self.coefficients = beta[1..].to_vec();
185 }
186
187 self.fitted = true;
188 Ok(())
189 }
190
191 fn fit_svd(&mut self, data: &Dataset) -> Result<()> {
193 if self.alpha > 0.0 {
194 let (x_aug, y_aug, aug_rows, dim) = Self::build_regularized(data, self.alpha);
195 let result = super::svd::svd_solve(&x_aug, &y_aug, aug_rows, dim)?;
196 self.intercept = result.coefficients[0];
197 self.coefficients = result.coefficients[1..].to_vec();
198 } else {
199 let (x, n, dim) = Self::build_augmented(data);
200 let result = super::svd::svd_solve(&x, &data.target, n, dim)?;
201 self.intercept = result.coefficients[0];
202 self.coefficients = result.coefficients[1..].to_vec();
203 }
204
205 self.fitted = true;
206 Ok(())
207 }
208
209 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
211 crate::version::check_schema_version(self._schema_version)?;
212 if !self.fitted {
213 return Err(ScryLearnError::NotFitted);
214 }
215 Ok(features
216 .iter()
217 .map(|row| {
218 let mut y = self.intercept;
219 for (j, &coeff) in self.coefficients.iter().enumerate() {
220 if j < row.len() {
221 y += coeff * row[j];
222 }
223 }
224 y
225 })
226 .collect())
227 }
228
229 pub fn fit_sparse(&mut self, features: &CscMatrix, target: &[f64]) -> Result<()> {
233 let n = features.n_rows();
234 let m = features.n_cols();
235 if n == 0 {
236 return Err(ScryLearnError::EmptyDataset);
237 }
238 if target.len() != n {
239 return Err(ScryLearnError::InvalidParameter(format!(
240 "target length {} != n_rows {}",
241 target.len(),
242 n
243 )));
244 }
245
246 let dim = m + 1; let mut xtx = vec![0.0; dim * dim];
250 let mut xty = vec![0.0; dim];
251
252 xtx[0] = n as f64;
254
255 xty[0] = target.iter().sum();
257
258 for j in 0..m {
260 let col = features.col(j);
261 let sum: f64 = col.iter().map(|(_, v)| v).sum();
262 xtx[j + 1] = sum;
263 xtx[(j + 1) * dim] = sum;
264
265 let mut dot = 0.0;
267 for (row_idx, val) in col.iter() {
268 dot += val * target[row_idx];
269 }
270 xty[j + 1] = dot;
271 }
272
273 let mut dense_col = vec![0.0; n];
277 for j in 0..m {
278 for (row_idx, val) in features.col(j).iter() {
280 dense_col[row_idx] = val;
281 }
282
283 let mut diag = 0.0;
285 for (row_idx, val) in features.col(j).iter() {
286 diag += val * dense_col[row_idx];
287 }
288 xtx[(j + 1) * dim + j + 1] = diag;
289
290 for i in 0..j {
292 let mut dot = 0.0;
293 for (row_idx, val) in features.col(i).iter() {
294 dot += val * dense_col[row_idx];
295 }
296 xtx[(i + 1) * dim + j + 1] = dot;
297 xtx[(j + 1) * dim + i + 1] = dot;
298 }
299
300 for (row_idx, _) in features.col(j).iter() {
302 dense_col[row_idx] = 0.0;
303 }
304 }
305
306 for j in 1..dim {
308 xtx[j * dim + j] += self.alpha;
309 }
310
311 let beta = solve_linear(dim, &mut xtx, &mut xty)?;
312 self.intercept = beta[0];
313 self.coefficients = beta[1..].to_vec();
314 self.fitted = true;
315 Ok(())
316 }
317
318 pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
320 if !self.fitted {
321 return Err(ScryLearnError::NotFitted);
322 }
323 Ok((0..features.n_rows())
324 .map(|i| {
325 let mut y = self.intercept;
326 for (col, val) in features.row(i).iter() {
327 if col < self.coefficients.len() {
328 y += self.coefficients[col] * val;
329 }
330 }
331 y
332 })
333 .collect())
334 }
335
336 pub fn coefficients(&self) -> &[f64] {
338 &self.coefficients
339 }
340
341 pub fn intercept(&self) -> f64 {
343 self.intercept
344 }
345}
346
347impl Default for LinearRegression {
348 fn default() -> Self {
349 Self::new()
350 }
351}
352
353fn solve_linear(n: usize, a: &mut [f64], b: &mut [f64]) -> Result<Vec<f64>> {
355 for col in 0..n {
356 let mut max_row = col;
357 let mut max_val = a[col * n + col].abs();
358 for row in (col + 1)..n {
359 let val = a[row * n + col].abs();
360 if val > max_val {
361 max_val = val;
362 max_row = row;
363 }
364 }
365 if max_val < crate::constants::SINGULAR_THRESHOLD {
366 return Err(ScryLearnError::InvalidParameter(
367 "singular matrix — features may be linearly dependent".into(),
368 ));
369 }
370
371 if max_row != col {
372 for k in 0..n {
373 a.swap(col * n + k, max_row * n + k);
374 }
375 b.swap(col, max_row);
376 }
377
378 let pivot = a[col * n + col];
379 for k in col..n {
380 a[col * n + k] /= pivot;
381 }
382 b[col] /= pivot;
383
384 for row in 0..n {
385 if row == col {
386 continue;
387 }
388 let factor = a[row * n + col];
389 for k in col..n {
390 a[row * n + k] -= factor * a[col * n + k];
391 }
392 b[row] -= factor * b[col];
393 }
394 }
395
396 Ok(b.to_vec())
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact line y = 2x + 3 recovered by the default (Auto) solver.
    #[test]
    fn test_linear_regression_y_equals_x() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr = LinearRegression::new();
        lr.fit(&data).unwrap();

        assert!(
            (lr.coefficients()[0] - 2.0).abs() < 1e-6,
            "coefficient should be ~2.0, got {}",
            lr.coefficients()[0]
        );
        assert!(
            (lr.intercept() - 3.0).abs() < 1e-6,
            "intercept should be ~3.0, got {}",
            lr.intercept()
        );
    }

    /// Ridge shrinkage: with alpha = 1.0 the slope must land strictly between
    /// the OLS value (2.0) and a heavily-shrunk value (1.0).
    #[test]
    fn test_ridge_regression() {
        let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let target = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr = LinearRegression::new().alpha(1.0);
        lr.fit(&data).unwrap();

        assert!(lr.coefficients()[0] < 2.0);
        assert!(lr.coefficients()[0] > 1.0);
    }

    /// On a well-conditioned problem, SVD and the normal equations agree.
    #[test]
    fn test_svd_solver_matches_normal() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr_normal = LinearRegression::new();
        lr_normal.fit(&data).unwrap();

        let mut lr_svd = LinearRegression::new().solver(LinRegSolver::Svd);
        lr_svd.fit(&data).unwrap();

        assert!(
            (lr_normal.coefficients()[0] - lr_svd.coefficients()[0]).abs() < 1e-6,
            "Normal={} vs SVD={}",
            lr_normal.coefficients()[0],
            lr_svd.coefficients()[0]
        );
        assert!(
            (lr_normal.intercept() - lr_svd.intercept()).abs() < 1e-6,
            "Normal intercept={} vs SVD={}",
            lr_normal.intercept(),
            lr_svd.intercept()
        );
    }

    /// On a well-conditioned problem, QR and the normal equations agree.
    #[test]
    fn test_qr_solver_matches_normal() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr_normal = LinearRegression::new();
        lr_normal.fit(&data).unwrap();

        let mut lr_qr = LinearRegression::new().solver(LinRegSolver::Qr);
        lr_qr.fit(&data).unwrap();

        assert!(
            (lr_normal.coefficients()[0] - lr_qr.coefficients()[0]).abs() < 1e-6,
            "Normal={} vs QR={}",
            lr_normal.coefficients()[0],
            lr_qr.coefficients()[0]
        );
        assert!(
            (lr_normal.intercept() - lr_qr.intercept()).abs() < 1e-6,
            "Normal intercept={} vs QR={}",
            lr_normal.intercept(),
            lr_qr.intercept()
        );
    }

    /// Hilbert-matrix features (notoriously ill-conditioned): the SVD path
    /// should still recover the all-ones coefficient vector roughly.
    #[test]
    fn test_svd_handles_ill_conditioned() {
        let n = 5;
        let mut features = vec![vec![0.0; n]; n];
        for j in 0..n {
            for i in 0..n {
                features[j][i] = 1.0 / (i + j + 1) as f64;
            }
        }
        let true_beta = vec![1.0; n];
        // Target built without noise, so the system is exactly consistent.
        let target: Vec<f64> = (0..n)
            .map(|i| (0..n).map(|j| features[j][i] * true_beta[j]).sum())
            .collect();
        let names: Vec<String> = (0..n).map(|j| format!("f{j}")).collect();
        let data = Dataset::new(features, target, names, "y");

        let mut lr = LinearRegression::new().solver(LinRegSolver::Svd);
        lr.fit(&data).unwrap();

        for (i, &c) in lr.coefficients().iter().enumerate() {
            assert!(
                (c - 1.0).abs() < 0.5,
                "SVD Hilbert coeff[{}] = {}, expected ~1.0",
                i,
                c
            );
        }
    }

    /// Ridge via the normal equations and ridge via the SVD-augmented system
    /// should agree (loose 0.1 tolerance for the two formulations).
    #[test]
    fn test_ridge_with_svd() {
        let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let target = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr_normal = LinearRegression::new().alpha(1.0);
        lr_normal.fit(&data).unwrap();

        let mut lr_svd = LinearRegression::new().alpha(1.0).solver(LinRegSolver::Svd);
        lr_svd.fit(&data).unwrap();

        assert!(
            (lr_normal.coefficients()[0] - lr_svd.coefficients()[0]).abs() < 0.1,
            "Ridge Normal={} vs SVD={}",
            lr_normal.coefficients()[0],
            lr_svd.coefficients()[0]
        );
    }

    /// The Auto solver must still recover the exact line on easy data.
    #[test]
    fn test_auto_solver() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr = LinearRegression::new().solver(LinRegSolver::Auto);
        lr.fit(&data).unwrap();

        assert!(
            (lr.coefficients()[0] - 2.0).abs() < 1e-6,
            "Auto solver coefficient should be ~2.0, got {}",
            lr.coefficients()[0]
        );
    }

    /// Fitting from a CSC matrix must match the dense fit on the same data.
    #[test]
    fn test_sparse_fit_matches_dense() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features.clone(), target.clone(), vec!["x".into()], "y");

        let mut lr_dense = LinearRegression::new();
        lr_dense.fit(&data).unwrap();

        let csc = CscMatrix::from_dense(&features);
        let mut lr_sparse = LinearRegression::new();
        lr_sparse.fit_sparse(&csc, &target).unwrap();

        assert!(
            (lr_dense.coefficients()[0] - lr_sparse.coefficients()[0]).abs() < 1e-6,
            "Dense={} vs Sparse={}",
            lr_dense.coefficients()[0],
            lr_sparse.coefficients()[0]
        );
        assert!(
            (lr_dense.intercept() - lr_sparse.intercept()).abs() < 1e-6,
            "Dense intercept={} vs Sparse={}",
            lr_dense.intercept(),
            lr_sparse.intercept()
        );
    }

    /// CSR predictions must match dense predictions row-for-row.
    #[test]
    fn test_sparse_predict_matches_dense() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr = LinearRegression::new();
        lr.fit(&data).unwrap();

        let test_rows = vec![vec![3.0], vec![10.0], vec![15.0]];
        let preds_dense = lr.predict(&test_rows).unwrap();

        let csr = CsrMatrix::from_dense(&test_rows);
        let preds_sparse = lr.predict_sparse(&csr).unwrap();

        for (d, s) in preds_dense.iter().zip(preds_sparse.iter()) {
            assert!((d - s).abs() < 1e-6, "Dense pred={d} vs Sparse pred={s}");
        }
    }

    /// `fit` should transparently route a sparse-backed Dataset to
    /// `fit_sparse` and still recover the slope.
    #[test]
    fn test_auto_dispatch_sparse_fit() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let csc = CscMatrix::from_dense(&features);
        let data = crate::dataset::Dataset::from_sparse(csc, target, vec!["x".into()], "y");

        let mut lr = LinearRegression::new();
        lr.fit(&data).unwrap();

        assert!(
            (lr.coefficients()[0] - 2.0).abs() < 1e-4,
            "Auto-dispatched sparse fit: coefficient should be ~2.0, got {}",
            lr.coefficients()[0]
        );
    }
}
619
#[cfg(all(test, feature = "scry-gpu"))]
mod gpu_tests {
    use super::*;

    /// Smoke test for the accelerated backend on a 500×50 ridge problem.
    ///
    /// NOTE(review): despite the name, this runs a single `fit` (which picks
    /// the accel backend internally via `accel::auto()`) and only checks the
    /// coefficient count and that a prediction is finite — it does not
    /// explicitly compare GPU output against a CPU fit. Consider tightening.
    #[test]
    fn gpu_linear_regression_matches_cpu() {
        let n = 500;
        let m = 50;
        let mut features = Vec::with_capacity(m);
        for j in 0..m {
            // Pseudo-random-ish but deterministic column values.
            let col: Vec<f64> = (0..n).map(|i| ((i * (j + 1)) % 97) as f64 * 0.1).collect();
            features.push(col);
        }
        // Target depends only on the first three features plus a bias.
        let target: Vec<f64> = (0..n)
            .map(|i| features[0][i] * 2.0 + features[1][i] * 0.5 + features[2][i] + 3.0)
            .collect();
        let names: Vec<String> = (0..m).map(|j| format!("f{j}")).collect();
        let data = Dataset::new(features, target, names, "y");

        let mut lr = LinearRegression::new().alpha(0.1);
        lr.fit(&data).unwrap();

        assert!(lr.coefficients().len() == m);
        let preds = lr.predict(&[vec![1.0; m]]).unwrap();
        assert!(
            preds[0].is_finite(),
            "prediction must be finite, got {}",
            preds[0]
        );
    }
}