1use num_traits::{Float, NumCast};
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::fs::File;
7use std::io::{self};
8use std::path::Path;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct LinearRegression<T = f64>
13where
14 T: Float + Debug + Default + Serialize,
15{
16 pub slope: T,
18 pub intercept: T,
20 pub r_squared: T,
22 pub standard_error: T,
24 pub n: usize,
26}
27
28impl<T> Default for LinearRegression<T>
29where
30 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
31{
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl<T> LinearRegression<T>
38where
39 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
40{
41 pub fn new() -> Self {
43 Self {
44 slope: T::zero(),
45 intercept: T::zero(),
46 r_squared: T::zero(),
47 standard_error: T::zero(),
48 n: 0,
49 }
50 }
51
52 pub fn fit<U, V>(&mut self, x_values: &[U], y_values: &[V]) -> Result<(), String>
61 where
62 U: NumCast + Copy,
63 V: NumCast + Copy,
64 {
65 if x_values.len() != y_values.len() {
67 return Err("X and Y arrays must have the same length".to_string());
68 }
69
70 if x_values.is_empty() {
71 return Err("Cannot fit regression with empty arrays".to_string());
72 }
73
74 let n = x_values.len();
75 self.n = n;
76
77 let x_cast: Vec<T> = x_values
79 .iter()
80 .map(|&x| T::from(x).ok_or_else(|| "Failed to cast X value".to_string()))
81 .collect::<Result<Vec<T>, String>>()?;
82
83 let y_cast: Vec<T> = y_values
84 .iter()
85 .map(|&y| T::from(y).ok_or_else(|| "Failed to cast Y value".to_string()))
86 .collect::<Result<Vec<T>, String>>()?;
87
88 let x_mean = x_cast.iter().fold(T::zero(), |acc, &x| acc + x) / T::from(n).unwrap();
90 let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / T::from(n).unwrap();
91
92 let mut sum_xy = T::zero();
94 let mut sum_xx = T::zero();
95 let mut sum_yy = T::zero();
96
97 for i in 0..n {
98 let x_diff = x_cast[i] - x_mean;
99 let y_diff = y_cast[i] - y_mean;
100
101 sum_xy = sum_xy + (x_diff * y_diff);
102 sum_xx = sum_xx + (x_diff * x_diff);
103 sum_yy = sum_yy + (y_diff * y_diff);
104 }
105
106 if sum_xx == T::zero() {
108 return Err("No variance in X values, cannot fit regression line".to_string());
109 }
110
111 self.slope = sum_xy / sum_xx;
113 self.intercept = y_mean - (self.slope * x_mean);
114
115 self.r_squared = (sum_xy * sum_xy) / (sum_xx * sum_yy);
117
118 let mut sum_squared_residuals = T::zero();
120 for i in 0..n {
121 let predicted = self.predict_t(x_cast[i]);
122 let residual = y_cast[i] - predicted;
123 sum_squared_residuals = sum_squared_residuals + (residual * residual);
124 }
125
126 if n > 2 {
128 let two = T::from(2).unwrap();
129 self.standard_error = (sum_squared_residuals / (T::from(n).unwrap() - two)).sqrt();
130 } else {
131 self.standard_error = T::zero();
132 }
133
134 Ok(())
135 }
136
137 fn predict_t(&self, x: T) -> T {
139 self.intercept + (self.slope * x)
140 }
141
142 pub fn predict<U>(&self, x: U) -> T
150 where
151 U: NumCast + Copy,
152 {
153 let x_cast: T = match T::from(x) {
154 Some(val) => val,
155 None => return T::nan(),
156 };
157
158 self.predict_t(x_cast)
159 }
160
161 pub fn predict_many<U>(&self, x_values: &[U]) -> Vec<T>
169 where
170 U: NumCast + Copy,
171 {
172 x_values.iter().map(|&x| self.predict(x)).collect()
173 }
174
175 pub fn confidence_interval<U>(&self, x: U, confidence_level: f64) -> Option<(T, T)>
184 where
185 U: NumCast + Copy,
186 {
187 if self.n < 3 {
188 return None;
189 }
190
191 let x_cast: T = T::from(x)?;
192
193 let z_score: T = match confidence_level {
196 0.90 => T::from(1.645).unwrap(),
197 0.95 => T::from(1.96).unwrap(),
198 0.99 => T::from(2.576).unwrap(),
199 _ => return None, };
201
202 let predicted = self.predict_t(x_cast);
203 let margin = z_score * self.standard_error;
204
205 Some((predicted - margin, predicted + margin))
206 }
207
208 pub fn correlation_coefficient(&self) -> T {
210 let r = self.r_squared.sqrt();
211 if self.slope >= T::zero() { r } else { -r }
212 }
213
214 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
222 let file = File::create(path)?;
223 serde_json::to_writer(file, self).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
225 }
226
227 pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
235 let file = File::create(path)?;
236 bincode::serialize_into(file, self).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
238 }
239
240 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
248 let file = File::open(path)?;
249 serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
251 }
252
253 pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
261 let file = File::open(path)?;
262 bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
264 }
265
266 pub fn to_json(&self) -> Result<String, String> {
271 serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
272 }
273
274 pub fn from_json(json: &str) -> Result<Self, String> {
282 serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::numeric::approx_equal;
    use tempfile::tempdir;

    /// Asserts that two `f64` values agree to within `1e-6`.
    fn assert_approx(actual: f64, expected: f64) {
        assert!(
            approx_equal(actual, expected, Some(1e-6)),
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Returns a model fitted to the exact line `y = 2x` sampled at x = 1..=5.
    fn fitted_doubling_model() -> LinearRegression<f64> {
        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0])
            .expect("fitting y = 2x must succeed");
        model
    }

    /// Asserts that a reloaded model carries the same statistics as the original.
    fn assert_same_model(loaded: &LinearRegression<f64>, original: &LinearRegression<f64>) {
        assert_approx(loaded.slope, original.slope);
        assert_approx(loaded.intercept, original.intercept);
        assert_approx(loaded.r_squared, original.r_squared);
        assert_approx(loaded.standard_error, original.standard_error);
        assert_eq!(loaded.n, original.n);
    }

    #[test]
    fn test_simple_regression_f64() {
        let model = fitted_doubling_model();

        assert_approx(model.slope, 2.0);
        assert_approx(model.intercept, 0.0);
        assert_approx(model.r_squared, 1.0);
    }

    #[test]
    fn test_simple_regression_f32() {
        let x = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let y = [2.0f32, 4.0, 6.0, 8.0, 10.0];

        let mut model = LinearRegression::<f32>::new();
        model.fit(&x, &y).expect("f32 fit must succeed");

        assert!(approx_equal(model.slope, 2.0f32, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0f32, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-6)));
    }

    #[test]
    fn test_integer_data() {
        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1, 2, 3, 4, 5], &[2, 4, 6, 8, 10])
            .expect("integer fit must succeed");

        assert_approx(model.slope, 2.0);
        assert_approx(model.intercept, 0.0);
        assert_approx(model.r_squared, 1.0);
    }

    #[test]
    fn test_mixed_types() {
        let x = [1u32, 2, 3, 4, 5];
        let y = [2.1, 3.9, 6.2, 7.8, 10.1];

        let mut model = LinearRegression::<f64>::new();
        model.fit(&x, &y).expect("mixed-type fit must succeed");

        assert!(model.slope > 1.9 && model.slope < 2.1);
        assert!(model.intercept > -0.1 && model.intercept < 0.1);
        assert!(model.r_squared > 0.99);
    }

    #[test]
    fn test_prediction() {
        let model = fitted_doubling_model();

        assert_approx(model.predict(6u32), 12.0);
        assert_approx(model.predict(0i32), 0.0);
    }

    #[test]
    fn test_predict_many() {
        let model = fitted_doubling_model();
        let predictions = model.predict_many(&[0u32, 6]);

        assert_eq!(predictions.len(), 2);
        assert_approx(predictions[0], 0.0);
        assert_approx(predictions[1], 12.0);
    }

    #[test]
    fn test_negative_correlation() {
        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1, 2, 3, 4, 5], &[10, 8, 6, 4, 2])
            .expect("decreasing fit must succeed");

        assert_approx(model.slope, -2.0);
        assert_approx(model.correlation_coefficient(), -1.0);
    }

    #[test]
    fn test_confidence_interval() {
        let model = fitted_doubling_model();

        // Exact data: zero residual error, so the interval collapses onto the prediction.
        let (low, high) = model
            .confidence_interval(6u32, 0.95)
            .expect("0.95 is a supported level");
        assert_approx(low, 12.0);
        assert_approx(high, 12.0);

        assert!(model.confidence_interval(6u32, 0.5).is_none());

        let mut tiny = LinearRegression::<f64>::new();
        tiny.fit(&[1.0, 2.0], &[2.0, 4.0])
            .expect("two points define a line");
        assert!(tiny.confidence_interval(1u32, 0.95).is_none());
    }

    #[test]
    fn test_invalid_inputs() {
        let mut model = LinearRegression::<f64>::new();

        assert!(model.fit(&[1, 2, 3], &[2, 4]).is_err());
    }

    #[test]
    fn test_empty_input() {
        let empty: [f64; 0] = [];
        let mut model = LinearRegression::<f64>::new();

        assert!(model.fit(&empty, &empty).is_err());
    }

    #[test]
    fn test_constant_x() {
        let mut model = LinearRegression::<f64>::new();

        assert!(model.fit(&[1, 1, 1], &[2, 3, 4]).is_err());
    }

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().expect("create temp dir");
        let file_path = dir.path().join("model.json");
        let model = fitted_doubling_model();

        model.save(&file_path).expect("save JSON model");
        let loaded = LinearRegression::<f64>::load(&file_path).expect("load JSON model");

        assert_same_model(&loaded, &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().expect("create temp dir");
        let file_path = dir.path().join("model.bin");
        let model = fitted_doubling_model();

        model.save_binary(&file_path).expect("save binary model");
        let loaded =
            LinearRegression::<f64>::load_binary(&file_path).expect("load binary model");

        assert_same_model(&loaded, &model);
    }

    #[test]
    fn test_json_serialization() {
        let model = fitted_doubling_model();

        let json = model.to_json().expect("serialize model");
        let loaded = LinearRegression::<f64>::from_json(&json).expect("deserialize model");

        assert_same_model(&loaded, &model);
    }
}