#[derive(Debug, Clone)]
pub struct LinearRegression {
    /// Fitted coefficients. When the model was fit with an intercept, the
    /// intercept is stored at index 0, followed by one weight per feature.
    pub coefficients: Vec<f64>,
}
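
// Ordinary least squares via the normal equations, beta = (X^T X)^{-1} X^T y,
// solved with the dependency-free Gauss-Jordan inverse defined below. For a
// badly conditioned X^T X, a QR- or SVD-based solver is the usual more robust
// alternative; the direct inverse is kept here for simplicity.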
impl LinearRegression {
    /// Fits the model by solving the normal equations on the given samples.
    /// `features` holds one row per sample; `target` holds the responses.
    pub fn fit(features: &[Vec<f64>], target: &[f64], fit_intercept: bool) -> Self {
        let n = features.len();
        if n == 0 {
            panic!("No training samples provided.");
        }
        if n != target.len() {
            panic!("features and target must have the same number of rows.");
        }
        let d = features[0].len();
        if d == 0 && !fit_intercept {
            panic!("No features provided and fit_intercept = false.");
        }
        // Build the design matrix X in row-major order, prepending a column
        // of ones when an intercept is requested.
        let x_cols = if fit_intercept { d + 1 } else { d };
        let mut x_matrix = vec![0.0; n * x_cols];
        for (i, row) in features.iter().enumerate() {
            if row.len() != d {
                panic!("Inconsistent feature dimension on row {}.", i);
            }
            if fit_intercept {
                x_matrix[i * x_cols] = 1.0;
                for (j, val) in row.iter().enumerate() {
                    x_matrix[i * x_cols + (j + 1)] = *val;
                }
            } else {
                for (j, val) in row.iter().enumerate() {
                    x_matrix[i * x_cols + j] = *val;
                }
            }
        }
        // Solve beta = (X^T X)^{-1} X^T y.
        let xtx = matmul_transpose_a(&x_matrix, n, x_cols);
        let xtx_inv = invert_matrix(xtx, x_cols)
            .expect("Matrix inversion failed (X^T X might be singular).");
        let xty = matvec_transpose_a(&x_matrix, n, x_cols, target);
        let mut beta = vec![0.0; x_cols];
        for i in 0..x_cols {
            let mut sum = 0.0;
            for j in 0..x_cols {
                sum += xtx_inv[i * x_cols + j] * xty[j];
            }
            beta[i] = sum;
        }
        Self { coefficients: beta }
    }

    /// Predicts a single response. Because the model does not record whether
    /// it was fit with an intercept, this uses a length heuristic: an input
    /// exactly one element shorter than `coefficients` is assumed to omit the
    /// intercept column, which is then supplied implicitly.
    pub fn predict(&self, x: &[f64]) -> f64 {
        let dim = self.coefficients.len();
        let has_intercept = x.len() + 1 == dim;
        if has_intercept {
            let mut result = self.coefficients[0];
            for (j, &val) in x.iter().enumerate() {
                result += self.coefficients[j + 1] * val;
            }
            result
        } else {
            if x.len() != dim {
                panic!(
                    "Input feature length mismatch. Expected {}, got {}",
                    dim,
                    x.len()
                );
            }
            let mut result = 0.0;
            for (j, &val) in x.iter().enumerate() {
                result += self.coefficients[j] * val;
            }
            result
        }
    }

    /// Predicts one response per input row.
    pub fn predict_batch(&self, xs: &[Vec<f64>]) -> Vec<f64> {
        xs.iter().map(|row| self.predict(row)).collect()
    }
}
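
// The two helpers below form X^T X and X^T y directly from the row-major
// design matrix, accumulating over the sample index k so that X^T is never
// materialized.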
fn matmul_transpose_a(x: &[f64], n: usize, x_cols: usize) -> Vec<f64> {
    let mut result = vec![0.0; x_cols * x_cols];
    for i in 0..x_cols {
        for j in 0..x_cols {
            let mut sum = 0.0;
            for k in 0..n {
                sum += x[k * x_cols + i] * x[k * x_cols + j];
            }
            result[i * x_cols + j] = sum;
        }
    }
    result
}

fn matvec_transpose_a(x: &[f64], n: usize, x_cols: usize, y: &[f64]) -> Vec<f64> {
    let mut result = vec![0.0; x_cols];
    for i in 0..x_cols {
        let mut sum = 0.0;
        for k in 0..n {
            sum += x[k * x_cols + i] * y[k];
        }
        result[i] = sum;
    }
    result
}
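
// Inverts a dense `dim` x `dim` matrix via Gauss-Jordan elimination with
// partial pivoting, returning None when the best available pivot falls below
// 1e-15 (i.e. the matrix is singular or numerically close to it).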
fn invert_matrix(mut mat: Vec<f64>, dim: usize) -> Option<Vec<f64>> {
    // Start from the identity; the same row operations that reduce `mat` to
    // the identity turn `inv` into the inverse.
    let mut inv = vec![0.0; dim * dim];
    for i in 0..dim {
        inv[i * dim + i] = 1.0;
    }
    for i in 0..dim {
        // Partial pivoting: pick the row with the largest entry in column i.
        let mut pivot_row = i;
        let mut pivot_val = mat[i * dim + i].abs();
        for r in (i + 1)..dim {
            let val = mat[r * dim + i].abs();
            if val > pivot_val {
                pivot_row = r;
                pivot_val = val;
            }
        }
        if pivot_val < 1e-15 {
            return None;
        }
        if pivot_row != i {
            for c in 0..dim {
                mat.swap(i * dim + c, pivot_row * dim + c);
                inv.swap(i * dim + c, pivot_row * dim + c);
            }
        }
        // Normalize the pivot row, then eliminate column i from every other row.
        let pivot = mat[i * dim + i];
        for c in 0..dim {
            mat[i * dim + c] /= pivot;
            inv[i * dim + c] /= pivot;
        }
        for r in 0..dim {
            if r != i {
                let factor = mat[r * dim + i];
                for c in 0..dim {
                    mat[r * dim + c] -= factor * mat[i * dim + c];
                    inv[r * dim + c] -= factor * inv[i * dim + c];
                }
            }
        }
    }
    Some(inv)
}
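
// Usage sketch (illustrative only; `rows`, `targets`, and `new_rows` are
// placeholder names, and the calls mirror the tests below):
//
//     let model = LinearRegression::fit(&rows, &targets, true);
//     let y_hat = model.predict_batch(&new_rows);
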
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_regression() {
        let x = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0]];
        let y = vec![2.0, 5.0, 8.0, 11.0];
        let model = LinearRegression::fit(&x, &y, true);
        let intercept = model.coefficients[0];
        let slope = model.coefficients[1];
        assert!((intercept - 2.0).abs() < 1e-7);
        assert!((slope - 3.0).abs() < 1e-7);
        let pred = model.predict(&[4.0]);
        assert!((pred - 14.0).abs() < 1e-7);
    }

    #[test]
    fn test_no_intercept() {
        let x = vec![vec![1.0], vec![2.0], vec![3.0]];
        let y = vec![2.0, 4.0, 6.0];
        let model = LinearRegression::fit(&x, &y, false);
        assert_eq!(model.coefficients.len(), 1);
        let slope = model.coefficients[0];
        assert!((slope - 2.0).abs() < 1e-7);
        let pred = model.predict(&[4.0]);
        assert!((pred - 8.0).abs() < 1e-7);
    }
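
    // Additional checks not in the original tests. The first exercises
    // predict_batch on the same noise-free y = 2 + 3x data, where the
    // least-squares fit recovers the exact coefficients; the second exercises
    // the 3x3 normal-equation solve on a noise-free plane y = 1 + 2*x1 + 3*x2.
    #[test]
    fn test_predict_batch() {
        let x = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0]];
        let y = vec![2.0, 5.0, 8.0, 11.0];
        let model = LinearRegression::fit(&x, &y, true);
        let preds = model.predict_batch(&[vec![0.0], vec![4.0]]);
        assert!((preds[0] - 2.0).abs() < 1e-7);
        assert!((preds[1] - 14.0).abs() < 1e-7);
    }

    #[test]
    fn test_two_features() {
        let x = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![2.0, 1.0],
        ];
        let y = vec![1.0, 3.0, 4.0, 6.0, 8.0];
        let model = LinearRegression::fit(&x, &y, true);
        assert!((model.coefficients[0] - 1.0).abs() < 1e-7);
        assert!((model.coefficients[1] - 2.0).abs() < 1e-7);
        assert!((model.coefficients[2] - 3.0).abs() < 1e-7);
    }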
}