#![allow(clippy::disallowed_methods)]
use aprender::classification::{GaussianNB, KNearestNeighbors, LinearSVM};
use aprender::primitives::Matrix;
/// Runs the binary Iris demo built around `LinearSVM`.
///
/// The walkthrough fits a baseline SVM, prints decision values and signed margins for the
/// first few test rows, sweeps the regularization strength `C`, compares the result with
/// `GaussianNB` and `KNearestNeighbors`, and ends with a per-class breakdown and a summary.
///
/// # Errors
///
/// Propagates any error produced while loading the data or by a model's `fit`, `predict`
/// or `decision_function` call.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Linear SVM: Iris Binary Classification ===\n");

    let (x_train, y_train, x_test, y_test) = load_binary_iris_data()?;
    let n_train = x_train.n_rows();
    let n_test = x_test.n_rows();
    println!(
        "Binary Dataset: {} training samples, {} test samples",
        n_train, n_test
    );
    println!("Classes: 0=Setosa, 1=Versicolor\n");

    // ---- Part 1: baseline model --------------------------------------------------------
    println!("=== Part 1: Basic Linear SVM ===\n");
    let mut baseline = LinearSVM::new()
        .with_c(1.0)
        .with_max_iter(1000)
        .with_learning_rate(0.1);
    baseline.fit(&x_train, &y_train)?;
    let svm_preds = baseline.predict(&x_test)?;
    let svm_acc = compute_accuracy(&svm_preds, &y_test);
    println!("Test Accuracy: {:.1}%\n", svm_acc * 100.0);

    // ---- Part 2: raw decision values ---------------------------------------------------
    println!("=== Part 2: Decision Function & Margins ===\n");
    let decisions = baseline.decision_function(&x_test)?;
    println!("Sample predictions with decision values:");
    println!("Sample True Predicted Decision Margin");
    println!("───────────────────────────────────────────");
    // Only the first five test rows (or fewer, if the test set is smaller) are listed.
    let shown = n_test.min(5);
    for i in 0..shown {
        let (true_label, pred, decision) = (y_test[i], svm_preds[i], decisions[i]);
        // Flip the sign for class 0 so a positive margin always means "correct side".
        let margin = match true_label {
            1 => decision,
            _ => -decision,
        };
        println!(" {i} {true_label} {pred} {decision:.3} {margin:.3}");
    }
    println!();

    // ---- Part 3: sweep over C ----------------------------------------------------------
    println!("=== Part 3: Effect of Regularization (C) ===\n");
    for &c_value in [0.01, 0.1, 1.0, 10.0, 100.0].iter() {
        let mut candidate = LinearSVM::new()
            .with_c(c_value)
            .with_max_iter(1000)
            .with_learning_rate(0.1);
        candidate.fit(&x_train, &y_train)?;
        let candidate_preds = candidate.predict(&x_test)?;
        let candidate_acc = compute_accuracy(&candidate_preds, &y_test);
        println!(
            "C={:6.2}: Accuracy = {:.1}%",
            c_value,
            candidate_acc * 100.0
        );
    }
    println!();

    // ---- Part 4: other classifiers on the same split -----------------------------------
    println!("=== Part 4: Comparison with Other Classifiers ===\n");
    let mut bayes = GaussianNB::new();
    bayes.fit(&x_train, &y_train)?;
    let bayes_preds = bayes.predict(&x_test)?;
    let bayes_acc = compute_accuracy(&bayes_preds, &y_test);

    let mut neighbors = KNearestNeighbors::new(5).with_weights(true);
    neighbors.fit(&x_train, &y_train)?;
    let neighbor_preds = neighbors.predict(&x_test)?;
    let neighbor_acc = compute_accuracy(&neighbor_preds, &y_test);

    println!("Classifier Accuracy");
    println!("─────────────────────────────");
    println!("Linear SVM {:.1}%", svm_acc * 100.0);
    println!("Naive Bayes {:.1}%", bayes_acc * 100.0);
    println!("k-NN (k=5) {:.1}%\n", neighbor_acc * 100.0);

    // ---- Part 5: static explanatory text -----------------------------------------------
    println!("=== Part 5: Understanding the Model ===\n");
    println!("Linear SVM Characteristics:");
    println!("✓ Maximizes margin between classes");
    println!("✓ Robust to outliers (with appropriate C)");
    println!("✓ Fast prediction (linear decision function)");
    println!("✓ Convex optimization (guaranteed convergence)");
    println!("✓ Effective in high-dimensional spaces\n");
    println!("Regularization Parameter C:");
    println!("- Small C (0.01-0.1): Large margin, simpler model, more regularization");
    println!("- Large C (10-100): Small margin, complex model, less regularization");
    println!("- Default C=1.0: Balanced trade-off\n");

    // ---- Part 6: per-class tallies for the baseline SVM --------------------------------
    println!("=== Part 6: Per-Class Performance ===\n");
    let mut hits = [0_i32; 2];
    let mut seen = [0_i32; 2];
    for (guess, actual) in svm_preds.iter().zip(&y_test) {
        seen[*actual] += 1;
        if guess == actual {
            hits[*actual] += 1;
        }
    }
    println!("Species Correct Total Accuracy");
    println!("──────────────────────────────────────");
    for (class_idx, name) in ["Setosa", "Versicolor"].iter().enumerate() {
        let (right, total) = (hits[class_idx], seen[class_idx]);
        let acc = right as f32 / total as f32 * 100.0;
        println!("{:12} {}/{} {:2} {:.1}%", name, right, total, total, acc);
    }
    println!();

    // ---- Summary -----------------------------------------------------------------------
    println!("=== Summary ===\n");
    println!("Linear SVM vs Naive Bayes vs k-NN:");
    println!("- Training: SVM iterative vs NB instant vs kNN instant (lazy)");
    println!("- Prediction: SVM O(p) vs NB O(p·c) vs kNN O(n·p)");
    println!("- Decision: SVM margin-based vs NB probabilistic vs kNN similarity");
    println!("- Regularization: SVM C parameter vs NB variance smoothing vs kNN k");
    println!(
        "- Accuracy: SVM {:.1}% vs NB {:.1}% vs kNN {:.1}%",
        svm_acc * 100.0,
        bayes_acc * 100.0,
        neighbor_acc * 100.0
    );

    Ok(())
}
/// Builds the hard-coded two-class Iris subset used by this example.
///
/// Returns `(x_train, y_train, x_test, y_test)`: 14 training samples and 6 test samples with
/// 4 feature values each, labelled `0` or `1` (reported by `main` as Setosa and Versicolor).
///
/// # Errors
///
/// Forwards whatever error `Matrix::from_vec` returns if it rejects a buffer.
// The four-element tuple return type is intentional for this small example.
#[allow(clippy::type_complexity)]
fn load_binary_iris_data(
) -> Result<(Matrix<f32>, Vec<usize>, Matrix<f32>, Vec<usize>), &'static str> {
    const N_FEATURES: usize = 4;
    const TRAIN_PER_CLASS: usize = 7;
    const TEST_PER_CLASS: usize = 3;

    // One sample per line, four feature values each.
    // NOTE(review): assumes `Matrix::from_vec` consumes the buffer sample by sample — confirm.
    #[rustfmt::skip]
    let train_values = vec![
        // label 0
        5.1, 3.5, 1.4, 0.2,
        4.9, 3.0, 1.4, 0.2,
        4.7, 3.2, 1.3, 0.2,
        4.6, 3.1, 1.5, 0.2,
        5.0, 3.6, 1.4, 0.2,
        5.4, 3.9, 1.7, 0.4,
        4.6, 3.4, 1.4, 0.3,
        // label 1
        7.0, 3.2, 4.7, 1.4,
        6.4, 3.2, 4.5, 1.5,
        6.9, 3.1, 4.9, 1.5,
        5.5, 2.3, 4.0, 1.3,
        6.5, 2.8, 4.6, 1.5,
        5.7, 2.8, 4.5, 1.3,
        6.3, 3.3, 4.7, 1.6,
    ];
    let x_train = Matrix::from_vec(2 * TRAIN_PER_CLASS, N_FEATURES, train_values)?;

    // Labels follow the sample order above: a run of 0s, then a run of 1s.
    let mut y_train = vec![0_usize; TRAIN_PER_CLASS];
    y_train.resize(2 * TRAIN_PER_CLASS, 1);

    #[rustfmt::skip]
    let test_values = vec![
        // label 0
        5.0, 3.3, 1.4, 0.2,
        4.4, 2.9, 1.4, 0.2,
        4.9, 3.1, 1.5, 0.1,
        // label 1
        4.9, 2.4, 3.3, 1.0,
        6.6, 2.9, 4.6, 1.3,
        5.2, 2.7, 3.9, 1.4,
    ];
    let x_test = Matrix::from_vec(2 * TEST_PER_CLASS, N_FEATURES, test_values)?;

    let mut y_test = vec![0_usize; TEST_PER_CLASS];
    y_test.resize(2 * TEST_PER_CLASS, 1);

    Ok((x_train, y_train, x_test, y_test))
}
/// Returns the fraction of entries in `true_labels` that `predictions` gets right.
///
/// Labels are compared position by position. The denominator is always
/// `true_labels.len()`, so if `predictions` is shorter, the labels without a matching
/// prediction count as misses.
///
/// An empty `true_labels` slice yields `0.0` rather than the NaN that `0.0 / 0.0` would
/// produce, so callers can format or compare the result safely.
fn compute_accuracy(predictions: &[usize], true_labels: &[usize]) -> f32 {
    // Guard the division: with no labels there is nothing to score.
    if true_labels.is_empty() {
        return 0.0;
    }

    let correct = predictions
        .iter()
        .zip(true_labels)
        .filter(|(pred, actual)| pred == actual)
        .count();

    correct as f32 / true_labels.len() as f32
}