use anyhow::{anyhow, Result};
use clap::{Parser, Subcommand};
use linfa::prelude::*;
use ndarray::{Array1, Array2, Axis};
use ndarray_rand::rand::SeedableRng;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use rand_xoshiro::Xoshiro256Plus;
use std::collections::HashMap;
use std::fs::File;
use std::path::PathBuf;
mod datasets;
mod evaluation;
mod models;
// Top-level command-line interface, parsed with clap's derive API.
// Plain `//` comments are used here on purpose: clap turns `///` doc comments
// into help text, and the `about` string below already provides that.
#[derive(Parser)]
#[command(author, version, about = "Linfa Machine Learning Examples")]
struct Cli {
// The subcommand chosen on the command line; `main` dispatches on it.
#[command(subcommand)]
command: Commands,
}
// Subcommands understood by the CLI; `main` dispatches on these.
// The `///` comments on the variants and fields below are picked up by clap
// and shown as descriptions in the generated `--help` output.
#[derive(Subcommand)]
enum Commands {
    /// Run logistic regression classification example
    Classify,
    /// Run decision tree classification example
    Tree,
    /// Run linear regression example
    Regress,
    /// Run DBSCAN clustering example
    Cluster,
    /// Run all examples
    All,
    /// Load a custom dataset and run analysis (not implemented yet)
    Custom {
        /// Path to the custom dataset file
        #[arg(short, long)]
        file: PathBuf,
        /// Kind of analysis to run on the custom dataset
        #[arg(short, long, default_value = "classify")]
        analysis: String,
    },
    /// Show this help information
    Help,
}
/// Parses the command line and runs the selected example.
///
/// `All` runs the four built-in examples in sequence, `Custom` only reports
/// that custom analysis is not implemented, and `Help` prints a short command
/// overview.
///
/// # Errors
///
/// Propagates the first error returned by any example that is run.
fn main() -> Result<()> {
    let cli = Cli::parse();
    match &cli.command {
        Commands::Classify => run_logistic_regression_example()?,
        Commands::Tree => run_decision_tree_example()?,
        Commands::Regress => run_regression_example()?,
        Commands::Cluster => run_dbscan_example()?,
        Commands::All => {
            println!("\n=== Running LogisticRegression Classification Example ===\n");
            run_logistic_regression_example()?;
            println!("\n=== Running DecisionTree Classification Example ===\n");
            run_decision_tree_example()?;
            println!("\n=== Running LinearRegression Example ===\n");
            run_regression_example()?;
            println!("\n=== Running DBSCAN Clustering Example ===\n");
            run_dbscan_example()?;
        }
        // Custom analysis is a stub, so the requested analysis kind is not
        // consulted yet; the underscore binding keeps the field "read" without
        // triggering an unused-variable warning.
        Commands::Custom {
            file,
            analysis: _analysis,
        } => {
            println!("Loading custom dataset from {:?}", file);
            println!("Custom dataset analysis is not implemented in this example.");
            println!("Please use one of the built-in examples.");
        }
        Commands::Help => {
            println!("Linfa Machine Learning Examples");
            println!("==============================");
            println!("Available commands:");
            println!(" classify - Run logistic regression classification example");
            println!(" tree - Run decision tree classification example");
            println!(" regress - Run linear regression example");
            println!(" cluster - Run DBSCAN clustering example");
            println!(" all - Run all examples");
            println!(" custom - Load a custom dataset and run analysis");
            println!(" help - Show this help information");
        }
    }
    Ok(())
}
/// Runs the logistic regression classification demo on a small in-memory dataset.
///
/// Builds six two-feature samples with 0/1 labels, shuffles them with a fixed
/// seed, splits them with ratio 0.5, fits a
/// `linfa_logistic::LogisticRegression` (at most 100 iterations) on the
/// training part and prints the predictions, confusion matrix and accuracy
/// for the test part.
///
/// # Errors
///
/// Returns an error if the feature matrix cannot be built, if fitting fails,
/// or if the confusion matrix cannot be computed.
fn run_logistic_regression_example() -> Result<()> {
    println!("Linfa 0.7.1 LogisticRegression Classification Example");

    // Row-major data: each line below is one sample with two features.
    let records = Array2::from_shape_vec(
        (6, 2),
        vec![
            1.0, 2.0,
            1.0, 3.0,
            2.0, 2.0,
            3.0, 1.0,
            3.0, 3.0,
            4.0, 2.0,
        ],
    )?;
    let labels = Array1::from_vec(vec![0, 0, 0, 1, 1, 1]);

    // A fixed seed makes the shuffle, and therefore the split, repeatable.
    let mut rng = Xoshiro256Plus::seed_from_u64(42);
    let (train, test) = Dataset::new(records, labels)
        .shuffle(&mut rng)
        .split_with_ratio(0.5);
    println!("Training dataset: {} samples", train.nsamples());
    println!("Testing dataset: {} samples", test.nsamples());

    println!("Training LogisticRegression model...");
    let classifier = linfa_logistic::LogisticRegression::default()
        .max_iterations(100)
        .fit(&train)?;

    println!("Making predictions...");
    let predicted = classifier.predict(test.records());
    println!("Predictions: {:?}", predicted);

    // Compare the predicted labels against the held-out targets.
    let confusion = predicted.confusion_matrix(&test)?;
    println!("Confusion Matrix:");
    println!("{:?}", confusion);
    println!("Accuracy: {:.2}", confusion.accuracy());

    println!("\nThis example demonstrates a complete classification workflow using LogisticRegression in Linfa 0.7.1.");
    println!("It shows how to create a dataset, split it into training and testing sets, train a model, make predictions, and evaluate the results.");
    Ok(())
}
/// Runs the decision tree classification demo on a small in-memory dataset.
///
/// Builds eight two-feature samples with 0/1 labels, shuffles them with a
/// fixed seed, splits them with ratio 0.75, fits a `linfa_trees::DecisionTree`
/// limited to depth 3 on the training part and prints each prediction next to
/// the actual label, followed by the confusion matrix and accuracy.
///
/// # Errors
///
/// Returns an error if the feature matrix cannot be built, if fitting fails,
/// or if the confusion matrix cannot be computed.
fn run_decision_tree_example() -> Result<()> {
    println!("Linfa 0.7.1 Decision Tree Classification Example");

    // Row-major data: each line below is one sample with two features.
    let records = Array2::from_shape_vec(
        (8, 2),
        vec![
            1.0, 2.0,
            1.5, 2.5,
            2.0, 3.0,
            1.0, 3.0,
            4.0, 1.0,
            4.5, 1.5,
            5.0, 1.0,
            5.5, 0.5,
        ],
    )?;
    let labels = Array1::from_vec(vec![0, 0, 0, 0, 1, 1, 1, 1]);

    // A fixed seed makes the shuffle, and therefore the split, repeatable.
    let mut rng = Xoshiro256Plus::seed_from_u64(42);
    let (train, test) = Dataset::new(records, labels)
        .shuffle(&mut rng)
        .split_with_ratio(0.75);
    println!("Training dataset: {} samples", train.nsamples());
    println!("Testing dataset: {} samples", test.nsamples());

    println!("Training Decision Tree model...");
    let tree = linfa_trees::DecisionTree::params()
        .max_depth(Some(3))
        .fit(&train)?;

    println!("Making predictions...");
    let predicted = tree.predict(test.records());

    println!("Predictions vs Actual:");
    // One prediction per test record, so the two sequences pair up one-to-one.
    for (guess, truth) in predicted.iter().zip(test.targets().iter()) {
        println!(" Predicted: {}, Actual: {}", guess, truth);
    }

    let confusion = predicted.confusion_matrix(test.targets())?;
    println!("Confusion Matrix:\n");
    println!("{:?}", confusion);
    println!("Accuracy: {:.2}", confusion.accuracy());

    println!("\nThis example demonstrates a complete Decision Tree classification workflow using Linfa 0.7.1.");
    println!("It shows how to create a dataset, split it into training and testing sets, train a Decision Tree model,");
    println!("make predictions, and evaluate the results.");
    Ok(())
}
/// Runs the linear regression demo on a small in-memory dataset.
///
/// Builds six single-feature samples with float targets, shuffles them with a
/// fixed seed, splits them with ratio 0.7, fits a
/// `linfa_linear::LinearRegression` on the training part, prints predictions
/// and the mean squared error for the test part, reports the fitted
/// coefficient and intercept, and finally predicts three new inputs.
///
/// # Errors
///
/// Returns an error if an input matrix cannot be built or if fitting fails.
fn run_regression_example() -> Result<()> {
    println!("Linfa 0.7.1 Linear Regression Example");

    let features = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])?;
    let targets = Array1::from_vec(vec![3.1, 5.2, 7.0, 8.9, 10.8, 13.1]);
    // Neither array is needed afterwards, so both are moved rather than cloned.
    let dataset = Dataset::new(features, targets);

    // A fixed seed makes the shuffle, and therefore the split, repeatable.
    let mut rng = Xoshiro256Plus::seed_from_u64(42);
    let (train, test) = dataset.shuffle(&mut rng).split_with_ratio(0.7);
    println!("Training dataset: {} samples", train.nsamples());
    println!("Testing dataset: {} samples", test.nsamples());

    println!("Training LinearRegression model...");
    let model = linfa_linear::LinearRegression::default().fit(&train)?;

    println!("Making predictions...");
    let predictions = model.predict(test.records());
    println!("Predictions vs Actual:");
    for (pred, actual) in predictions.iter().zip(test.targets().iter()) {
        println!(" Predicted: {:.2}, Actual: {:.2}", pred, actual);
    }

    let mse = predictions
        .iter()
        .zip(test.targets().iter())
        .map(|(&p, &a)| (p - a) * (p - a))
        .sum::<f64>()
        / predictions.len() as f64;
    println!("Mean Squared Error: {:.4}", mse);

    println!("Model parameters:");
    let params = model.params();
    println!(" Parameters shape: {:?}", params.shape());
    println!(" Parameters values: {:?}", params);
    // Read the slope and intercept from the fitted model itself. Deriving them
    // from a single test prediction would divide by zero whenever that test
    // point's x equals the mean of the training inputs.
    if let Some(&coefficient) = params.first() {
        let intercept = model.intercept();
        println!(" Estimated coefficient (m): {:.4}", coefficient);
        println!(" Estimated intercept (b): {:.4}", intercept);
        println!(
            " Estimated model equation: y = {:.4} * x + {:.4}",
            coefficient, intercept
        );
    }

    println!("\nPredicting on new data:");
    let new_data = Array2::from_shape_vec((3, 1), vec![0.5, 7.0, 10.0])?;
    let new_predictions = model.predict(&new_data);
    for (x, y) in new_data.column(0).iter().zip(new_predictions.iter()) {
        println!(" x = {:.1}, predicted y = {:.2}", x, y);
    }
    Ok(())
}
fn run_dbscan_example() -> Result<()> {
println!("Linfa 0.7.1 DBSCAN Clustering Example");
let n_samples_per_cluster = 100;
let n_clusters = 3;
let n_samples = n_samples_per_cluster * n_clusters;
let mut rng = Xoshiro256Plus::seed_from_u64(42);
let cluster_centers = Array2::random_using(
(n_clusters, 2),
Uniform::new(-20.0, 20.0),
&mut rng,
);
let mut samples = Array2::zeros((n_samples, 2));
for i in 0..n_clusters {
let cluster_center = cluster_centers.row(i);
let start_idx = i * n_samples_per_cluster;
let end_idx = start_idx + n_samples_per_cluster;
let cluster_samples = Array2::random_using(
(n_samples_per_cluster, 2),
Uniform::new(-5.0, 5.0),
&mut rng,
);
for j in start_idx..end_idx {
for k in 0..2 {
samples[[j, k]] = cluster_center[k] + cluster_samples[[j - start_idx, k]];
}
}
}
println!("Dataset shape: {:?}", samples.shape());
println!("Number of samples: {}", n_samples);
let dataset = Dataset::from(samples);
let min_points = 3; let tolerance = 2.0;
println!("Running DBSCAN clustering with min_points = {}, tolerance = {}", min_points, tolerance);
let model = linfa_clustering::Dbscan::params(min_points)
.tolerance(tolerance)
.transform(&dataset)?;
let cluster_memberships = model.cluster_memberships();
println!("First 10 cluster assignments: {:?}", &cluster_memberships.slice(Axis(0), 0..10.min(cluster_memberships.len())));
let mut cluster_counts: HashMap<usize, usize> = HashMap::new();
for membership in cluster_memberships.iter() {
if let Some(cluster_idx) = membership {
*cluster_counts.entry(*cluster_idx).or_insert(0) += 1;
}
}
println!("Cluster counts:");
for (cluster_idx, count) in cluster_counts.iter() {
println!(" Cluster {}: {} points", cluster_idx, count);
}
Ok(())
}