use anyhow::Result;
use linfa::prelude::*;
use linfa_linear::LinearRegression;
use ndarray::{Array1, Array2, Ix1};
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use std::fs::File;
use std::path::Path;
use crate::data_utils;
/// Runs an end-to-end linear-regression demo with linfa.
///
/// Data source, in order of preference:
/// 1. `data/sample_regression.csv`, if it exists.
/// 2. `data/sample_regression.json`, if it exists.
/// 3. A synthetic dataset from `data_utils::create_synthetic_regression_dataset`.
///
/// The data is shuffled with a fixed seed (42) and split 70/30 into train/test. A
/// `LinearRegression` model is fitted on the training part. The function then prints:
/// - the predictions on the test part next to the actual targets;
/// - the mean squared error;
/// - the fitted parameters.
///
/// When the model has exactly one coefficient, it also prints the line equation and a
/// few sample predictions.
///
/// # Errors
///
/// Returns an error if loading the data or fitting the model fails.
pub fn run_regression_example() -> Result<()> {
    println!("Linfa 0.7.1 Linear Regression Example");

    let csv_path = Path::new("data/sample_regression.csv");
    let json_path = Path::new("data/sample_regression.json");
    let dataset = if csv_path.exists() {
        println!("Loading data from CSV file: {}", csv_path.display());
        data_utils::load_csv_dataset(csv_path)?
    } else if json_path.exists() {
        println!("Loading data from JSON file: {}", json_path.display());
        data_utils::load_json_dataset(json_path)?
    } else {
        println!("No data files found, using synthetic data");
        data_utils::create_synthetic_regression_dataset()?
    };

    // Fixed seed so the shuffle, and therefore the train/test split, is reproducible.
    let mut rng = Xoshiro256Plus::seed_from_u64(42);
    let (train, test) = dataset.shuffle(&mut rng).split_with_ratio(0.7);
    println!("Training dataset: {} samples", train.nsamples());
    println!("Testing dataset: {} samples", test.nsamples());

    println!("Training LinearRegression model...");
    let model = LinearRegression::default().fit(&train)?;

    println!("Making predictions...");
    let predictions = model.predict(test.records());

    println!("Predictions vs Actual:");
    // Pair each prediction with its target directly rather than indexing + unwrap.
    for (pred, actual) in predictions.iter().zip(test.targets().iter()) {
        println!(" Predicted: {:.2}, Actual: {:.2}", pred, actual);
    }

    match mean_squared_error(predictions.iter(), test.targets().iter()) {
        Some(mse) => println!("Mean Squared Error: {:.4}", mse),
        // A very small dataset can leave the test split empty; report that instead of NaN.
        None => println!("Mean Squared Error: n/a (empty test set)"),
    }

    println!("Model parameters:");
    let params = model.params();
    println!(" Parameters shape: {:?}", params.shape());
    println!(" Parameters values: {:?}", params);

    // The y = m * x + b printout only makes sense for a single-feature model.
    if params.len() == 1 {
        let m = params[0];
        let b = model.intercept();
        println!(" Estimated coefficient (m): {:.4}", m);
        println!(" Estimated intercept (b): {:.4}", b);
        println!(" Estimated model equation: y = {:.4} * x + {:.4}", m, b);

        println!("\nPredicting on new data:");
        let new_x_values = [0.5, 7.0, 10.0];
        for &x in &new_x_values {
            let y = m * x + b;
            println!(" x = {:.1}, predicted y = {:.2}", x, y);
        }
    }

    Ok(())
}

/// Mean of the squared differences between paired `predicted` and `actual` values.
///
/// Pairs are formed with `zip`, so any surplus elements in the longer sequence are
/// ignored and the mean is taken over the number of pairs actually compared.
/// Returns `None` when there are no pairs instead of dividing by zero.
fn mean_squared_error<'a, 'b>(
    predicted: impl IntoIterator<Item = &'a f64>,
    actual: impl IntoIterator<Item = &'b f64>,
) -> Option<f64> {
    let (sum, count) = predicted
        .into_iter()
        .zip(actual)
        .fold((0.0_f64, 0_usize), |(sum, count), (&p, &a)| {
            (sum + (p - a) * (p - a), count + 1)
        });
    if count == 0 {
        None
    } else {
        Some(sum / count as f64)
    }
}
/// Builds a tiny in-memory single-feature regression dataset.
///
/// The inputs are x = 1.0, 2.0, …, 6.0, stored as a 6×1 record matrix. The targets
/// are hand-picked values lying close to the line y = 2x + 1.
///
/// # Errors
///
/// Propagates the `ndarray` shape error from building the record matrix. This cannot
/// happen here because the vector length always equals the number of rows.
fn create_synthetic_dataset() -> Result<Dataset<f64, f64, Ix1>> {
    let inputs: Vec<f64> = (1..=6_u32).map(f64::from).collect();
    let outputs = vec![3.1, 5.2, 7.0, 8.9, 10.8, 13.1];

    // One column per feature; there is a single feature, so one column.
    let rows = inputs.len();
    let records = Array2::from_shape_vec((rows, 1), inputs)?;

    Ok(Dataset::new(records, Array1::from(outputs)))
}
fn load_csv_dataset(path: &Path) -> Result<Dataset<f64, f64, Ix1>> {
let file = File::open(path)?;
let mut reader = csv::Reader::from_reader(file);
let mut features_data = Vec::new();
let mut targets_data = Vec::new();
for result in reader.records() {
let record = result?;
if record.len() >= 2 {
let x = record[0].parse::<f64>()?;
features_data.push(x);
let y = record[1].parse::<f64>()?;
targets_data.push(y);
}
}
let num_samples = targets_data.len();
let features = Array2::from_shape_vec((num_samples, 1), features_data)?;
let targets = Array1::from_vec(targets_data);
Ok(Dataset::new(features, targets))
}