use anyhow::Result;
use linfa::prelude::*;
use linfa_logistic::LogisticRegression;
use ndarray::{Array1, Array2, Ix1};
use rand::SeedableRng;
use std::path::Path;
use std::fs::File;
use crate::data_utils;
pub fn run_logistic_regression_example() -> Result<()> {
println!("Linfa 0.7.1 Logistic Regression Example");
let csv_path = Path::new("data/sample_classification.csv");
let json_path = Path::new("data/sample_classification.json");
let dataset = if csv_path.exists() {
println!("Loading data from CSV file: {}", csv_path.display());
load_classification_dataset(csv_path)?
} else if json_path.exists() {
println!("Loading data from JSON file: {}", json_path.display());
load_classification_dataset(json_path)?
} else {
println!("No data files found, using synthetic data");
create_synthetic_classification_dataset()?
};
println!("Dataset shape: [{}, {}]", dataset.nsamples(), dataset.nfeatures());
println!("Number of samples: {}", dataset.nsamples());
let (train, test) = dataset.split_with_ratio(0.8);
println!("Training set size: {}", train.nsamples());
println!("Testing set size: {}", test.nsamples());
let model = LogisticRegression::default()
.max_iterations(100)
.fit(&train)?;
println!("Model trained successfully");
let predictions = model.predict(&test);
let cm = confusion_matrix(&predictions, test.targets())?;
let accuracy = cm.accuracy();
println!("Confusion Matrix:\n{}", cm);
println!("Accuracy: {:.2}%", accuracy * 100.0);
Ok(())
}
/// Builds a reproducible two-class, two-feature toy dataset.
///
/// Class 0 is centred at (1, 1) and class 1 at (5, 5). Each feature gets uniform noise in
/// `[-1, 1)`, so the two clusters do not overlap. Rows are emitted grouped by class (all
/// class-0 rows first), so callers should shuffle before splitting.
///
/// # Errors
/// Only if the feature buffer cannot be reshaped, which would indicate a bug in this function.
fn create_synthetic_classification_dataset() -> Result<Dataset<f64, usize, Ix1>> {
    use rand::Rng;

    // Fixed seed so repeated runs of the example see identical data (and identical accuracy).
    const SEED: u64 = 7;
    const NUM_CLASSES: usize = 2;
    const SAMPLES_PER_CLASS: usize = 50;
    const NUM_FEATURES: usize = 2;

    let mut rng = rand::rngs::StdRng::seed_from_u64(SEED);
    let num_samples = NUM_CLASSES * SAMPLES_PER_CLASS;
    let mut features_data = Vec::with_capacity(num_samples * NUM_FEATURES);
    let mut targets_data = Vec::with_capacity(num_samples);

    for class in 0..NUM_CLASSES {
        // Class 0 clusters around 1.0 on every axis; any other class around 5.0.
        let center: f64 = if class == 0 { 1.0 } else { 5.0 };
        for _ in 0..SAMPLES_PER_CLASS {
            for _ in 0..NUM_FEATURES {
                features_data.push(center + rng.gen_range(-1.0_f64..1.0));
            }
            targets_data.push(class);
        }
    }

    let features = Array2::from_shape_vec((num_samples, NUM_FEATURES), features_data)?;
    let targets = Array1::from_vec(targets_data);
    Ok(Dataset::new(features, targets))
}
/// Loads a two-feature classification dataset from `path`.
///
/// The loader is chosen from the format string reported by
/// `data_utils::detect_file_format`: `"csv"` and `"json"` are supported.
///
/// # Errors
/// Propagates format-detection and loader errors, and fails for any other format.
fn load_classification_dataset(path: &Path) -> Result<Dataset<f64, usize, Ix1>> {
    match data_utils::detect_file_format(path)? {
        "csv" => load_csv_classification_dataset(path),
        "json" => load_json_classification_dataset(path),
        other => Err(anyhow::anyhow!("Unsupported file format: {}", other)),
    }
}
fn load_csv_classification_dataset(path: &Path) -> Result<Dataset<f64, usize, Ix1>> {
use csv;
let file = File::open(path)?;
let mut reader = csv::Reader::from_reader(file);
let mut features_data = Vec::new();
let mut targets_data = Vec::new();
for result in reader.records() {
let record = result?;
if record.len() >= 3 {
let x = record[0].parse::<f64>()?;
let y = record[1].parse::<f64>()?;
let target = record[2].parse::<usize>()?;
features_data.push(x);
features_data.push(y);
targets_data.push(target);
}
}
let num_samples = targets_data.len();
let features = Array2::from_shape_vec((num_samples, 2), features_data)?;
let targets = Array1::from_vec(targets_data);
Ok(Dataset::new(features, targets))
}
fn load_json_classification_dataset(path: &Path) -> Result<Dataset<f64, usize, Ix1>> {
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
struct ClassificationPoint {
x: f64,
y: f64,
target: usize,
}
#[derive(Debug, Serialize, Deserialize)]
struct ClassificationDataSet {
data: Vec<ClassificationPoint>,
}
let file = File::open(path)?;
let dataset: ClassificationDataSet = serde_json::from_reader(file)?;
let mut features_data = Vec::new();
let mut targets_data = Vec::new();
for point in dataset.data {
features_data.push(point.x);
features_data.push(point.y);
targets_data.push(point.target);
}
let num_samples = targets_data.len();
let features = Array2::from_shape_vec((num_samples, 2), features_data)?;
let targets = Array1::from_vec(targets_data);
Ok(Dataset::new(features, targets))
}
fn confusion_matrix(predictions: &Array1<usize>, targets: &Array1<usize>) -> Result<ConfusionMatrix<usize>> {
let mut classes = Vec::new();
for &target in targets.iter() {
if !classes.contains(&target) {
classes.push(target);
}
}
for &pred in predictions.iter() {
if !classes.contains(&pred) {
classes.push(pred);
}
}
classes.sort();
let mut cm = ConfusionMatrix::new(classes)?;
for (pred, actual) in predictions.iter().zip(targets.iter()) {
cm.increment(*pred, *actual)?;
}
Ok(cm)
}
/// A square confusion matrix over an explicit, ordered list of class labels.
///
/// `matrix[[p, a]]` counts samples whose predicted class is `classes[p]` and whose actual
/// class is `classes[a]`. Rows are predictions and columns are ground truth.
#[derive(Debug)]
struct ConfusionMatrix<T> {
    // Class labels; index `i` labels both row `i` and column `i` of `matrix`.
    classes: Vec<T>,
    // `classes.len() x classes.len()` count table, indexed `[predicted, actual]`.
    matrix: Array2<usize>,
}

impl<T: PartialEq + std::fmt::Debug> ConfusionMatrix<T> {
    /// Creates an all-zero matrix for the given classes.
    ///
    /// Currently never returns an error.
    fn new(classes: Vec<T>) -> Result<Self> {
        let n = classes.len();
        Ok(ConfusionMatrix {
            classes,
            matrix: Array2::zeros((n, n)),
        })
    }

    /// Records one sample with the given predicted and actual labels.
    ///
    /// # Errors
    /// Fails if either label is not in the class list.
    fn increment(&mut self, predicted: T, actual: T) -> Result<()> {
        let pred_idx = self
            .classes
            .iter()
            .position(|c| *c == predicted)
            .ok_or_else(|| anyhow::anyhow!("Unknown class: {:?}", predicted))?;
        let actual_idx = self
            .classes
            .iter()
            .position(|c| *c == actual)
            .ok_or_else(|| anyhow::anyhow!("Unknown class: {:?}", actual))?;
        self.matrix[[pred_idx, actual_idx]] += 1;
        Ok(())
    }

    /// Fraction of recorded samples whose prediction matched the actual class.
    ///
    /// Returns `0.0` when nothing has been recorded.
    fn accuracy(&self) -> f64 {
        let total = self.matrix.sum();
        if total == 0 {
            return 0.0;
        }
        // Correct predictions lie on the diagonal (predicted == actual).
        let correct = self.matrix.diag().sum();
        correct as f64 / total as f64
    }
}

impl<T: std::fmt::Display> std::fmt::Display for ConfusionMatrix<T> {
    /// Renders a text table. The header row lists the classes in column (actual) order.
    /// Each following row starts with a predicted class and then gives its counts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:<10} | ", "classes")?;
        for class in &self.classes {
            write!(f, "{:<10} | ", class)?;
        }
        writeln!(f)?;
        for (class, row) in self.classes.iter().zip(self.matrix.rows()) {
            write!(f, "{:<10} | ", class)?;
            for count in row {
                write!(f, "{:<10} | ", count)?;
            }
            writeln!(f)?;
        }
        Ok(())
    }
}