// aprender/classification/mod.rs
//! Classification algorithms.
//!
//! This module implements classification algorithms including:
//! - Logistic Regression for binary classification
//! - K-Nearest Neighbors (kNN) for instance-based classification
//! - Gaussian Naive Bayes for probabilistic classification
//! - Linear Support Vector Machine (SVM) for maximum-margin classification
//! - Softmax Regression for multi-class classification (planned)
//!
//! # Example
//!
//! ```
//! use aprender::classification::LogisticRegression;
//! use aprender::prelude::*;
//!
//! // Binary classification data
//! let x = Matrix::from_vec(4, 2, vec![
//!     0.0, 0.0,
//!     0.0, 1.0,
//!     1.0, 0.0,
//!     1.0, 1.0,
//! ]).expect("Matrix dimensions match data length");
//! let y = vec![0, 0, 0, 1];
//!
//! let mut model = LogisticRegression::new()
//!     .with_learning_rate(0.1)
//!     .with_max_iter(1000);
//! model.fit(&x, &y).expect("Training data is valid with 4 samples");
//! let predictions = model.predict(&x);
//!
//! assert_eq!(predictions.len(), 4);
//! for pred in predictions {
//!     assert!(pred == 0 || pred == 1);
//! }
//! ```

37use crate::error::Result;
38use crate::primitives::{Matrix, Vector};
39use serde::{Deserialize, Serialize};
40use std::path::Path;
41
/// Logistic Regression classifier for binary classification.
///
/// Uses sigmoid activation and binary cross-entropy loss with
/// gradient descent optimization.
///
/// Hyperparameters are configured with the `with_*` builder methods;
/// `coefficients` remains `None` until `fit` succeeds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogisticRegression {
    /// Model coefficients (weights), one per feature; `None` before fitting
    coefficients: Option<Vector<f32>>,
    /// Intercept (bias) term added to every linear combination
    intercept: f32,
    /// Learning rate (step size) for gradient descent
    learning_rate: f32,
    /// Maximum number of gradient descent iterations
    max_iter: usize,
    /// Convergence tolerance: training stops early when every gradient
    /// component has magnitude below this value
    tol: f32,
}
59
60impl LogisticRegression {
61    /// Creates a new logistic regression classifier with default parameters.
62    ///
63    /// # Example
64    ///
65    /// ```
66    /// use aprender::classification::LogisticRegression;
67    ///
68    /// let model = LogisticRegression::new();
69    /// ```
70    pub fn new() -> Self {
71        Self {
72            coefficients: None,
73            intercept: 0.0,
74            learning_rate: 0.01,
75            max_iter: 1000,
76            tol: 1e-4,
77        }
78    }
79
80    /// Sets the learning rate.
81    pub fn with_learning_rate(mut self, lr: f32) -> Self {
82        self.learning_rate = lr;
83        self
84    }
85
86    /// Sets the maximum number of iterations.
87    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
88        self.max_iter = max_iter;
89        self
90    }
91
92    /// Sets the convergence tolerance.
93    pub fn with_tolerance(mut self, tol: f32) -> Self {
94        self.tol = tol;
95        self
96    }
97
98    /// Sigmoid activation function: σ(z) = 1 / (1 + e^(-z))
99    fn sigmoid(z: f32) -> f32 {
100        1.0 / (1.0 + (-z).exp())
101    }
102
103    /// Predicts probabilities for samples.
104    ///
105    /// Returns probability of class 1 for each sample.
106    pub fn predict_proba(&self, x: &Matrix<f32>) -> Vector<f32> {
107        let coef = self.coefficients.as_ref().expect("Model not fitted yet");
108        let (n_samples, _) = x.shape();
109
110        let mut probas = Vec::with_capacity(n_samples);
111        for row in 0..n_samples {
112            let mut z = self.intercept;
113            for col in 0..coef.len() {
114                z += coef[col] * x.get(row, col);
115            }
116            probas.push(Self::sigmoid(z));
117        }
118
119        Vector::from_vec(probas)
120    }
121
122    /// Fits the logistic regression model to training data.
123    ///
124    /// # Arguments
125    ///
126    /// * `x` - Feature matrix (n_samples × n_features)
127    /// * `y` - Binary labels (n_samples), must be 0 or 1
128    ///
129    /// # Returns
130    ///
131    /// `Ok(())` on success, `Err` with message on failure
132    pub fn fit(&mut self, x: &Matrix<f32>, y: &[usize]) -> Result<()> {
133        let (n_samples, n_features) = x.shape();
134
135        if n_samples != y.len() {
136            return Err("Number of samples in X and y must match".into());
137        }
138        if n_samples == 0 {
139            return Err("Cannot fit with zero samples".into());
140        }
141
142        // Validate labels are binary (0 or 1)
143        for &label in y {
144            if label != 0 && label != 1 {
145                return Err("Labels must be 0 or 1 for binary classification".into());
146            }
147        }
148
149        // Initialize coefficients and intercept
150        self.coefficients = Some(Vector::from_vec(vec![0.0; n_features]));
151        self.intercept = 0.0;
152
153        // Gradient descent optimization
154        for _ in 0..self.max_iter {
155            // Compute predictions (probabilities)
156            let probas = self.predict_proba(x);
157
158            // Compute gradients
159            let mut coef_grad = vec![0.0; n_features];
160            let mut intercept_grad = 0.0;
161
162            for i in 0..n_samples {
163                let error = probas[i] - y[i] as f32;
164                intercept_grad += error;
165                for (j, grad) in coef_grad.iter_mut().enumerate() {
166                    *grad += error * x.get(i, j);
167                }
168            }
169
170            // Average gradients
171            let n = n_samples as f32;
172            intercept_grad /= n;
173            for grad in &mut coef_grad {
174                *grad /= n;
175            }
176
177            // Update parameters
178            self.intercept -= self.learning_rate * intercept_grad;
179            if let Some(ref mut coef) = self.coefficients {
180                for j in 0..n_features {
181                    coef[j] -= self.learning_rate * coef_grad[j];
182                }
183            }
184
185            // Check convergence (simplified - could check gradient norm)
186            if intercept_grad.abs() < self.tol && coef_grad.iter().all(|&g| g.abs() < self.tol) {
187                break;
188            }
189        }
190
191        Ok(())
192    }
193
194    /// Predicts class labels for samples.
195    ///
196    /// Returns 0 or 1 for each sample based on probability threshold of 0.5.
197    pub fn predict(&self, x: &Matrix<f32>) -> Vec<usize> {
198        let probas = self.predict_proba(x);
199        probas
200            .as_slice()
201            .iter()
202            .map(|&p| usize::from(p >= 0.5))
203            .collect()
204    }
205
206    /// Computes accuracy score on test data.
207    ///
208    /// Returns fraction of correctly classified samples.
209    pub fn score(&self, x: &Matrix<f32>, y: &[usize]) -> f32 {
210        let predictions = self.predict(x);
211        let correct = predictions
212            .iter()
213            .zip(y.iter())
214            .filter(|(pred, true_label)| pred == true_label)
215            .count();
216        correct as f32 / y.len() as f32
217    }
218
219    /// Get model coefficients (weights).
220    ///
221    /// # Panics
222    ///
223    /// Panics if the model is not fitted.
224    pub fn coefficients(&self) -> &Vector<f32> {
225        self.coefficients.as_ref().expect("Model not fitted")
226    }
227
    /// Get intercept (bias) term.
    ///
    /// Returns 0.0 if the model has not been fitted yet.
    pub fn intercept(&self) -> f32 {
        self.intercept
    }
232
233    /// Saves the trained model to SafeTensors format.
234    ///
235    /// SafeTensors is an industry-standard model serialization format
236    /// compatible with HuggingFace, Ollama, PyTorch, TensorFlow, and realizar.
237    ///
238    /// # Arguments
239    ///
240    /// * `path` - File path to save the model
241    ///
242    /// # Errors
243    ///
244    /// Returns an error if:
245    /// - Model is not fitted (call `fit()` first)
246    /// - File writing fails
247    /// - Serialization fails
248    ///
249    /// # Example
250    ///
251    /// ```
252    /// use aprender::classification::LogisticRegression;
253    /// use aprender::prelude::*;
254    ///
255    /// let mut model = LogisticRegression::new();
256    /// let x = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0]).expect("4x2 matrix with 8 values");
257    /// let y = vec![0, 0, 1, 1];
258    /// model.fit(&x, &y).expect("Valid training data");
259    ///
260    /// model.save_safetensors("model.safetensors").expect("Model is fitted and path is writable");
261    /// ```
262    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
263        use crate::serialization::safetensors;
264        use std::collections::BTreeMap;
265
266        // Verify model is fitted
267        let coefficients = self
268            .coefficients
269            .as_ref()
270            .ok_or("Cannot save unfitted model. Call fit() first.")?;
271
272        // Prepare tensors (BTreeMap ensures deterministic ordering)
273        let mut tensors = BTreeMap::new();
274
275        // Coefficients tensor
276        let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
277        let coef_shape = vec![coefficients.len()];
278        tensors.insert("coefficients".to_string(), (coef_data, coef_shape));
279
280        // Intercept tensor
281        let intercept_data = vec![self.intercept];
282        let intercept_shape = vec![1];
283        tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));
284
285        // Save to SafeTensors format
286        safetensors::save_safetensors(path, &tensors)?;
287        Ok(())
288    }
289
290    /// Loads a model from SafeTensors format.
291    ///
292    /// # Arguments
293    ///
294    /// * `path` - File path to load the model from
295    ///
296    /// # Errors
297    ///
298    /// Returns an error if:
299    /// - File reading fails
300    /// - SafeTensors format is invalid
301    /// - Required tensors are missing
302    ///
303    /// # Example
304    ///
305    /// ```
306    /// use aprender::classification::LogisticRegression;
307    ///
308    /// # use aprender::prelude::*;
309    /// # let mut model = LogisticRegression::new();
310    /// # let x = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0]).expect("4x2 matrix with 8 values");
311    /// # let y = vec![0, 0, 1, 1];
312    /// # model.fit(&x, &y).expect("Valid training data");
313    /// # model.save_safetensors("/tmp/doctest_logistic_model.safetensors").expect("Can save to /tmp");
314    /// let loaded_model = LogisticRegression::load_safetensors("/tmp/doctest_logistic_model.safetensors").expect("File exists and is valid SafeTensors format");
315    /// # std::fs::remove_file("/tmp/doctest_logistic_model.safetensors").ok();
316    /// ```
317    pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
318        use crate::serialization::safetensors;
319
320        // Load SafeTensors file
321        let (metadata, raw_data) = safetensors::load_safetensors(path)?;
322
323        // Extract coefficients tensor
324        let coef_meta = metadata
325            .get("coefficients")
326            .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
327        let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;
328
329        // Extract intercept tensor
330        let intercept_meta = metadata
331            .get("intercept")
332            .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
333        let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;
334
335        // Validate intercept shape
336        if intercept_data.len() != 1 {
337            return Err(format!(
338                "Invalid intercept tensor: expected 1 value, got {}",
339                intercept_data.len()
340            ));
341        }
342
343        // Construct model with default hyperparameters
344        // Note: Hyperparameters are not serialized as they're only needed during training
345        Ok(Self {
346            coefficients: Some(Vector::from_vec(coef_data)),
347            intercept: intercept_data[0],
348            learning_rate: 0.01, // Default value
349            max_iter: 1000,      // Default value
350            tol: 1e-4,           // Default value
351        })
352    }
353}
354
355impl Default for LogisticRegression {
356    fn default() -> Self {
357        Self::new()
358    }
359}
360
/// Distance metric for K-Nearest Neighbors.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DistanceMetric {
    /// Euclidean distance: sqrt(sum((x_i - y_i)^2))
    Euclidean,
    /// Manhattan distance: sum(|x_i - y_i|)
    Manhattan,
    /// Minkowski distance with parameter p (p=1 is Manhattan, p=2 is Euclidean)
    Minkowski(f32),
}
371
/// K-Nearest Neighbors classifier.
///
/// Instance-based learning algorithm that classifies new samples based on
/// the k closest training examples in the feature space.
///
/// # Example
///
/// ```
/// use aprender::classification::{KNearestNeighbors, DistanceMetric};
/// use aprender::primitives::Matrix;
///
/// let x = Matrix::from_vec(6, 2, vec![
///     0.0, 0.0,  // class 0
///     0.0, 1.0,  // class 0
///     1.0, 0.0,  // class 0
///     5.0, 5.0,  // class 1
///     5.0, 6.0,  // class 1
///     6.0, 5.0,  // class 1
/// ]).expect("6x2 matrix with 12 values");
/// let y = vec![0, 0, 0, 1, 1, 1];
///
/// let mut knn = KNearestNeighbors::new(3);
/// knn.fit(&x, &y).expect("Valid training data with 6 samples");
///
/// let test = Matrix::from_vec(1, 2, vec![0.5, 0.5]).expect("1x2 test matrix");
/// let predictions = knn.predict(&test).expect("Predict should succeed");
/// assert_eq!(predictions[0], 0);  // Closer to class 0
/// ```
#[derive(Debug, Clone)]
pub struct KNearestNeighbors {
    /// Number of neighbors consulted when voting
    k: usize,
    /// Distance metric used to compare samples
    metric: DistanceMetric,
    /// Whether to use weighted voting (inverse distance) instead of uniform votes
    weights: bool,
    /// Training feature matrix (stored during fit; kNN is a lazy learner)
    x_train: Option<Matrix<f32>>,
    /// Training labels (stored during fit)
    y_train: Option<Vec<usize>>,
}
413
414impl KNearestNeighbors {
415    /// Creates a new K-Nearest Neighbors classifier.
416    ///
417    /// # Arguments
418    ///
419    /// * `k` - Number of neighbors to use for voting
420    ///
421    /// # Example
422    ///
423    /// ```
424    /// use aprender::classification::KNearestNeighbors;
425    ///
426    /// let knn = KNearestNeighbors::new(5);
427    /// ```
428    #[must_use]
429    pub fn new(k: usize) -> Self {
430        Self {
431            k,
432            metric: DistanceMetric::Euclidean,
433            weights: false,
434            x_train: None,
435            y_train: None,
436        }
437    }
438
439    /// Sets the distance metric.
440    #[must_use]
441    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
442        self.metric = metric;
443        self
444    }
445
446    /// Enables weighted voting (inverse distance weighting).
447    #[must_use]
448    pub fn with_weights(mut self, weights: bool) -> Self {
449        self.weights = weights;
450        self
451    }
452
453    /// Fits the model by storing the training data.
454    ///
455    /// kNN is a lazy learner - it simply stores the training data
456    /// and defers computation until prediction time.
457    ///
458    /// # Errors
459    ///
460    /// Returns error if data dimensions are invalid.
461    pub fn fit(&mut self, x: &Matrix<f32>, y: &[usize]) -> Result<()> {
462        let (n_samples, _n_features) = x.shape();
463
464        if n_samples == 0 {
465            return Err("Cannot fit with zero samples".into());
466        }
467
468        if y.len() != n_samples {
469            return Err("Number of samples in X and y must match".into());
470        }
471
472        if self.k > n_samples {
473            return Err("k cannot be larger than number of training samples".into());
474        }
475
476        // Store training data
477        self.x_train = Some(x.clone());
478        self.y_train = Some(y.to_vec());
479
480        Ok(())
481    }
482
483    /// Predicts class labels for samples.
484    ///
485    /// For each test sample, finds the k nearest training samples
486    /// and returns the majority class.
487    ///
488    /// # Errors
489    ///
490    /// Returns error if model is not fitted or dimensions mismatch.
491    pub fn predict(&self, x: &Matrix<f32>) -> Result<Vec<usize>> {
492        let x_train = self.x_train.as_ref().ok_or("Model not fitted")?;
493        let y_train = self.y_train.as_ref().ok_or("Model not fitted")?;
494
495        let (n_samples, n_features) = x.shape();
496        let (_n_train, n_train_features) = x_train.shape();
497
498        if n_features != n_train_features {
499            return Err("Feature dimension mismatch".into());
500        }
501
502        let mut predictions = Vec::with_capacity(n_samples);
503
504        for i in 0..n_samples {
505            // Compute distances to all training samples
506            let mut distances: Vec<(f32, usize)> = Vec::with_capacity(y_train.len());
507
508            for (j, &label) in y_train.iter().enumerate() {
509                let dist = self.compute_distance(x, i, x_train, j, n_features);
510                distances.push((dist, label));
511            }
512
513            // Sort by distance and take k nearest
514            distances.sort_by(|a, b| {
515                a.0.partial_cmp(&b.0)
516                    .expect("Distance values are valid f32 (not NaN)")
517            });
518            let k_nearest = &distances[..self.k];
519
520            // Vote for class
521            let predicted_class = if self.weights {
522                self.weighted_vote(k_nearest)
523            } else {
524                self.majority_vote(k_nearest)
525            };
526
527            predictions.push(predicted_class);
528        }
529
530        Ok(predictions)
531    }
532
533    /// Returns probability estimates for each class.
534    ///
535    /// Probabilities are computed as the proportion of neighbors belonging
536    /// to each class (optionally weighted by inverse distance).
537    ///
538    /// # Errors
539    ///
540    /// Returns error if model is not fitted or dimensions mismatch.
541    pub fn predict_proba(&self, x: &Matrix<f32>) -> Result<Vec<Vec<f32>>> {
542        let x_train = self.x_train.as_ref().ok_or("Model not fitted")?;
543        let y_train = self.y_train.as_ref().ok_or("Model not fitted")?;
544
545        let (n_samples, n_features) = x.shape();
546        let (_n_train, n_train_features) = x_train.shape();
547
548        if n_features != n_train_features {
549            return Err("Feature dimension mismatch".into());
550        }
551
552        // Find number of classes
553        let n_classes = *y_train
554            .iter()
555            .max()
556            .expect("Training labels are non-empty (verified in fit())")
557            + 1;
558
559        let mut probabilities = Vec::with_capacity(n_samples);
560
561        for i in 0..n_samples {
562            // Compute distances to all training samples
563            let mut distances: Vec<(f32, usize)> = Vec::with_capacity(y_train.len());
564
565            for (j, &label) in y_train.iter().enumerate() {
566                let dist = self.compute_distance(x, i, x_train, j, n_features);
567                distances.push((dist, label));
568            }
569
570            // Sort by distance and take k nearest
571            distances.sort_by(|a, b| {
572                a.0.partial_cmp(&b.0)
573                    .expect("Distance values are valid f32 (not NaN)")
574            });
575            let k_nearest = &distances[..self.k];
576
577            // Compute class probabilities
578            let mut class_counts = vec![0.0; n_classes];
579
580            if self.weights {
581                // Weighted by inverse distance
582                for (dist, label) in k_nearest {
583                    let weight = if *dist < 1e-10 { 1.0 } else { 1.0 / dist };
584                    class_counts[*label] += weight;
585                }
586            } else {
587                // Uniform weights
588                for (_dist, label) in k_nearest {
589                    class_counts[*label] += 1.0;
590                }
591            }
592
593            // Normalize to probabilities
594            let total: f32 = class_counts.iter().sum();
595            for count in &mut class_counts {
596                *count /= total;
597            }
598
599            probabilities.push(class_counts);
600        }
601
602        Ok(probabilities)
603    }
604
605    /// Computes distance between two samples.
606    fn compute_distance(
607        &self,
608        x1: &Matrix<f32>,
609        i1: usize,
610        x2: &Matrix<f32>,
611        i2: usize,
612        n_features: usize,
613    ) -> f32 {
614        match self.metric {
615            DistanceMetric::Euclidean => {
616                let mut sum = 0.0;
617                for k in 0..n_features {
618                    let diff = x1.get(i1, k) - x2.get(i2, k);
619                    sum += diff * diff;
620                }
621                sum.sqrt()
622            }
623            DistanceMetric::Manhattan => {
624                let mut sum = 0.0;
625                for k in 0..n_features {
626                    sum += (x1.get(i1, k) - x2.get(i2, k)).abs();
627                }
628                sum
629            }
630            DistanceMetric::Minkowski(p) => {
631                let mut sum = 0.0;
632                for k in 0..n_features {
633                    sum += (x1.get(i1, k) - x2.get(i2, k)).abs().powf(p);
634                }
635                sum.powf(1.0 / p)
636            }
637        }
638    }
639
640    /// Performs majority voting among k nearest neighbors.
641    #[allow(clippy::unused_self)]
642    fn majority_vote(&self, neighbors: &[(f32, usize)]) -> usize {
643        let mut class_counts = std::collections::HashMap::new();
644
645        for (_dist, label) in neighbors {
646            *class_counts.entry(*label).or_insert(0) += 1;
647        }
648
649        *class_counts
650            .iter()
651            .max_by_key(|(_, count)| *count)
652            .map(|(label, _)| label)
653            .expect("Neighbors slice is non-empty (k >= 1)")
654    }
655
656    /// Performs weighted voting (inverse distance weighting).
657    #[allow(clippy::unused_self)]
658    fn weighted_vote(&self, neighbors: &[(f32, usize)]) -> usize {
659        let mut class_weights = std::collections::HashMap::new();
660
661        for (dist, label) in neighbors {
662            let weight = if *dist < 1e-10 { 1.0 } else { 1.0 / dist };
663            *class_weights.entry(*label).or_insert(0.0) += weight;
664        }
665
666        *class_weights
667            .iter()
668            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Weights are valid f32 (not NaN)"))
669            .map(|(label, _)| label)
670            .expect("Neighbors slice is non-empty (k >= 1)")
671    }
672}
673
/// Gaussian Naive Bayes classifier.
///
/// Assumes features follow a Gaussian (normal) distribution within each class.
/// Uses Bayes' theorem with independence assumption between features.
///
/// # Example
///
/// ```
/// use aprender::classification::GaussianNB;
/// use aprender::primitives::Matrix;
///
/// let x = Matrix::from_vec(4, 2, vec![
///     0.0, 0.0,
///     0.0, 1.0,
///     1.0, 0.0,
///     1.0, 1.0,
/// ]).expect("4x2 matrix with 8 values");
/// let y = vec![0, 0, 1, 1];
///
/// let mut model = GaussianNB::new();
/// model.fit(&x, &y).expect("Valid training data");
/// let predictions = model.predict(&x).expect("Model is fitted");
/// ```
#[derive(Debug, Clone)]
pub struct GaussianNB {
    /// Class prior probabilities P(y=c); `None` before fitting
    class_priors: Option<Vec<f32>>,
    /// Feature means per class: means[class][feature]
    means: Option<Vec<Vec<f32>>>,
    /// Feature variances per class: variances[class][feature]
    variances: Option<Vec<Vec<f32>>>,
    /// Sorted, deduplicated class labels seen during fit
    classes: Option<Vec<usize>>,
    /// Smoothing added to every variance to avoid numerical instability
    var_smoothing: f32,
}
710
711impl GaussianNB {
712    /// Creates a new Gaussian Naive Bayes classifier.
713    ///
714    /// # Example
715    ///
716    /// ```
717    /// use aprender::classification::GaussianNB;
718    ///
719    /// let model = GaussianNB::new();
720    /// ```
721    pub fn new() -> Self {
722        Self {
723            class_priors: None,
724            means: None,
725            variances: None,
726            classes: None,
727            var_smoothing: 1e-9,
728        }
729    }
730
731    /// Sets the variance smoothing parameter.
732    ///
733    /// Adds this value to variances to avoid numerical instability.
734    ///
735    /// # Example
736    ///
737    /// ```
738    /// use aprender::classification::GaussianNB;
739    ///
740    /// let model = GaussianNB::new().with_var_smoothing(1e-8);
741    /// ```
742    pub fn with_var_smoothing(mut self, var_smoothing: f32) -> Self {
743        self.var_smoothing = var_smoothing;
744        self
745    }
746
747    /// Trains the Gaussian Naive Bayes classifier.
748    ///
749    /// Computes class priors, feature means, and variances for each class.
750    ///
751    /// # Errors
752    ///
753    /// Returns error if:
754    /// - Sample count mismatch between X and y
755    /// - Empty data
756    /// - Less than 2 classes
757    pub fn fit(&mut self, x: &Matrix<f32>, y: &[usize]) -> Result<()> {
758        let (n_samples, n_features) = x.shape();
759
760        if n_samples == 0 {
761            return Err("Cannot fit with empty data".into());
762        }
763
764        if y.len() != n_samples {
765            return Err("Number of samples in X and y must match".into());
766        }
767
768        // Find unique classes
769        let mut classes: Vec<usize> = y.to_vec();
770        classes.sort_unstable();
771        classes.dedup();
772
773        if classes.len() < 2 {
774            return Err("Need at least 2 classes".into());
775        }
776
777        let n_classes = classes.len();
778
779        // Initialize storage
780        let mut class_priors = vec![0.0; n_classes];
781        let mut means = vec![vec![0.0; n_features]; n_classes];
782        let mut variances = vec![vec![0.0; n_features]; n_classes];
783
784        // Compute class priors and feature statistics
785        for (class_idx, &class_label) in classes.iter().enumerate() {
786            // Find samples belonging to this class
787            let class_samples: Vec<usize> = y
788                .iter()
789                .enumerate()
790                .filter_map(|(i, &label)| if label == class_label { Some(i) } else { None })
791                .collect();
792
793            let n_class_samples = class_samples.len() as f32;
794            class_priors[class_idx] = n_class_samples / n_samples as f32;
795
796            // Compute mean for each feature
797            for (feature_idx, mean_val) in means[class_idx].iter_mut().enumerate() {
798                let sum: f32 = class_samples
799                    .iter()
800                    .map(|&sample_idx| x.get(sample_idx, feature_idx))
801                    .sum();
802                *mean_val = sum / n_class_samples;
803            }
804
805            // Compute variance for each feature
806            for (feature_idx, variance_val) in variances[class_idx].iter_mut().enumerate() {
807                let mean = means[class_idx][feature_idx];
808                let sum_sq_diff: f32 = class_samples
809                    .iter()
810                    .map(|&sample_idx| {
811                        let diff = x.get(sample_idx, feature_idx) - mean;
812                        diff * diff
813                    })
814                    .sum();
815                *variance_val = sum_sq_diff / n_class_samples + self.var_smoothing;
816            }
817        }
818
819        self.class_priors = Some(class_priors);
820        self.means = Some(means);
821        self.variances = Some(variances);
822        self.classes = Some(classes);
823
824        Ok(())
825    }
826
827    /// Predicts class labels for samples.
828    ///
829    /// Returns the class with highest posterior probability for each sample.
830    ///
831    /// # Errors
832    ///
833    /// Returns error if model is not fitted or dimension mismatch.
834    pub fn predict(&self, x: &Matrix<f32>) -> Result<Vec<usize>> {
835        let probabilities = self.predict_proba(x)?;
836        let classes = self.classes.as_ref().ok_or("Model not fitted")?;
837
838        let predictions: Vec<usize> = probabilities
839            .iter()
840            .map(|probs| {
841                let max_idx = probs
842                    .iter()
843                    .enumerate()
844                    .max_by(|(_, a), (_, b)| {
845                        a.partial_cmp(b)
846                            .expect("Probabilities are valid f32 (not NaN)")
847                    })
848                    .map(|(idx, _)| idx)
849                    .expect("Probabilities vector is non-empty (n_classes >= 2)");
850                classes[max_idx]
851            })
852            .collect();
853
854        Ok(predictions)
855    }
856
857    /// Returns probability estimates for each class.
858    ///
859    /// Uses Bayes' theorem with Gaussian likelihood:
860    /// P(y=c|X) ∝ P(y=c) * ∏ P(x_i|y=c)
861    ///
862    /// # Errors
863    ///
864    /// Returns error if model is not fitted or dimension mismatch.
865    pub fn predict_proba(&self, x: &Matrix<f32>) -> Result<Vec<Vec<f32>>> {
866        let means = self.means.as_ref().ok_or("Model not fitted")?;
867        let variances = self.variances.as_ref().ok_or("Model not fitted")?;
868        let class_priors = self.class_priors.as_ref().ok_or("Model not fitted")?;
869
870        let (n_samples, n_features) = x.shape();
871        let n_classes = means.len();
872
873        if n_features != means[0].len() {
874            return Err("Feature dimension mismatch".into());
875        }
876
877        let mut probabilities = Vec::with_capacity(n_samples);
878
879        for sample_idx in 0..n_samples {
880            let mut log_probs = vec![0.0; n_classes];
881
882            // Compute log posterior for each class
883            for class_idx in 0..n_classes {
884                // Start with log prior
885                log_probs[class_idx] = class_priors[class_idx].ln();
886
887                // Add log likelihood for each feature (Gaussian PDF)
888                for feature_idx in 0..n_features {
889                    let x_val = x.get(sample_idx, feature_idx);
890                    let mean = means[class_idx][feature_idx];
891                    let variance = variances[class_idx][feature_idx];
892
893                    // Log of Gaussian PDF: -0.5 * log(2π*σ²) - (x-μ)² / (2σ²)
894                    let diff = x_val - mean;
895                    let log_likelihood = -0.5 * (2.0 * std::f32::consts::PI * variance).ln()
896                        - (diff * diff) / (2.0 * variance);
897
898                    log_probs[class_idx] += log_likelihood;
899                }
900            }
901
902            // Convert log probabilities to probabilities using log-sum-exp trick
903            let max_log_prob = log_probs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
904            let exp_probs: Vec<f32> = log_probs
905                .iter()
906                .map(|&log_p| (log_p - max_log_prob).exp())
907                .collect();
908            let sum: f32 = exp_probs.iter().sum();
909            let normalized: Vec<f32> = exp_probs.iter().map(|p| p / sum).collect();
910
911            probabilities.push(normalized);
912        }
913
914        Ok(probabilities)
915    }
916}
917
918impl Default for GaussianNB {
919    fn default() -> Self {
920        Self::new()
921    }
922}
923
/// Linear Support Vector Machine (SVM) classifier.
///
/// Implements binary classification using hinge loss and subgradient descent.
/// Labels are given as `0`/`1` and mapped internally to `-1`/`+1`.
/// For multi-class problems, use One-vs-Rest strategy.
///
/// # Algorithm
///
/// Minimizes the objective:
/// ```text
/// min  λ||w||² + (1/n) Σᵢ max(0, 1 - yᵢ(w·xᵢ + b))
/// ```
///
/// Where λ = 1/(2nC) controls regularization strength.
///
/// # Example
///
/// ```ignore
/// use aprender::classification::LinearSVM;
/// use aprender::primitives::Matrix;
///
/// let x = Matrix::from_vec(4, 2, vec![
///     0.0, 0.0,
///     0.0, 1.0,
///     1.0, 0.0,
///     1.0, 1.0,
/// ])?;
/// let y = vec![0, 0, 1, 1];
///
/// let mut svm = LinearSVM::new();
/// svm.fit(&x, &y)?;
/// let predictions = svm.predict(&x)?;
/// ```
#[derive(Debug, Clone)]
pub struct LinearSVM {
    /// Weights for each feature; `None` until `fit` has run successfully
    weights: Option<Vec<f32>>,
    /// Bias (intercept) term of the decision function w·x + b
    bias: f32,
    /// Regularization parameter (default: 1.0)
    /// Larger C means less regularization
    c: f32,
    /// Learning rate for subgradient descent (default: 0.01)
    learning_rate: f32,
    /// Maximum iterations (default: 1000)
    max_iter: usize,
    /// Convergence tolerance (default: 1e-4)
    tol: f32,
}
972
973impl LinearSVM {
974    /// Creates a new Linear SVM with default parameters.
975    ///
976    /// # Default Parameters
977    ///
978    /// - C: 1.0 (moderate regularization)
979    /// - learning_rate: 0.01
980    /// - max_iter: 1000
981    /// - tol: 1e-4
982    pub fn new() -> Self {
983        Self {
984            weights: None,
985            bias: 0.0,
986            c: 1.0,
987            learning_rate: 0.01,
988            max_iter: 1000,
989            tol: 1e-4,
990        }
991    }
992
993    /// Sets the regularization parameter C.
994    ///
995    /// Larger C means less regularization (fit data more closely).
996    /// Smaller C means more regularization (simpler model).
997    pub fn with_c(mut self, c: f32) -> Self {
998        self.c = c;
999        self
1000    }
1001
1002    /// Sets the learning rate for subgradient descent.
1003    pub fn with_learning_rate(mut self, learning_rate: f32) -> Self {
1004        self.learning_rate = learning_rate;
1005        self
1006    }
1007
1008    /// Sets the maximum number of iterations.
1009    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
1010        self.max_iter = max_iter;
1011        self
1012    }
1013
1014    /// Sets the convergence tolerance.
1015    pub fn with_tolerance(mut self, tol: f32) -> Self {
1016        self.tol = tol;
1017        self
1018    }
1019
1020    /// Trains the Linear SVM on the given data.
1021    ///
1022    /// # Arguments
1023    ///
1024    /// - `x`: Feature matrix (n_samples × n_features)
1025    /// - `y`: Binary labels (0 or 1)
1026    ///
1027    /// # Returns
1028    ///
1029    /// Ok(()) on success, Err with message on failure.
1030    pub fn fit(&mut self, x: &Matrix<f32>, y: &[usize]) -> Result<()> {
1031        if x.n_rows() != y.len() {
1032            return Err("x and y must have the same number of samples".into());
1033        }
1034
1035        if x.n_rows() == 0 {
1036            return Err("Cannot fit with 0 samples".into());
1037        }
1038
1039        // Convert labels to {-1, +1}
1040        let y_signed: Vec<f32> = y
1041            .iter()
1042            .map(|&label| if label == 0 { -1.0 } else { 1.0 })
1043            .collect();
1044
1045        let n_samples = x.n_rows();
1046        let n_features = x.n_cols();
1047
1048        // Initialize weights and bias
1049        let mut w = vec![0.0; n_features];
1050        let mut b = 0.0;
1051
1052        let lambda = 1.0 / (2.0 * n_samples as f32 * self.c);
1053
1054        // Subgradient descent with learning rate decay
1055        for epoch in 0..self.max_iter {
1056            let eta = self.learning_rate / (1.0 + epoch as f32 * 0.01);
1057            let prev_w = w.clone();
1058            let prev_b = b;
1059
1060            // Iterate over all samples (batch update)
1061            for (i, &y_i) in y_signed.iter().enumerate() {
1062                // Compute decision value: w·x + b
1063                let mut decision = b;
1064                for (j, &w_j) in w.iter().enumerate() {
1065                    decision += w_j * x.get(i, j);
1066                }
1067
1068                // Compute margin: y * (w·x + b)
1069                let margin = y_i * decision;
1070
1071                // Subgradient update
1072                if margin < 1.0 {
1073                    // Misclassified or within margin: update with hinge loss gradient
1074                    for (j, w_j) in w.iter_mut().enumerate() {
1075                        let gradient = 2.0 * lambda * *w_j - y_i * x.get(i, j);
1076                        *w_j -= eta * gradient;
1077                    }
1078                    b += eta * y_i;
1079                } else {
1080                    // Correctly classified outside margin: only regularization gradient
1081                    for w_j in &mut w {
1082                        let gradient = 2.0 * lambda * *w_j;
1083                        *w_j -= eta * gradient;
1084                    }
1085                }
1086            }
1087
1088            // Check convergence (weight change between iterations)
1089            let mut weight_change = 0.0;
1090            for j in 0..n_features {
1091                weight_change += (w[j] - prev_w[j]).powi(2);
1092            }
1093            weight_change += (b - prev_b).powi(2);
1094            weight_change = weight_change.sqrt();
1095
1096            if weight_change < self.tol {
1097                break;
1098            }
1099        }
1100
1101        self.weights = Some(w);
1102        self.bias = b;
1103
1104        Ok(())
1105    }
1106
1107    /// Computes the decision function for the given samples.
1108    ///
1109    /// Returns w·x + b for each sample. Positive values indicate class 1,
1110    /// negative values indicate class 0.
1111    ///
1112    /// # Arguments
1113    ///
1114    /// - `x`: Feature matrix (n_samples × n_features)
1115    ///
1116    /// # Returns
1117    ///
1118    /// Vector of decision values, one per sample.
1119    pub fn decision_function(&self, x: &Matrix<f32>) -> Result<Vec<f32>> {
1120        let weights = self.weights.as_ref().ok_or("Model not trained yet")?;
1121
1122        if x.n_cols() != weights.len() {
1123            return Err("Feature dimension mismatch".into());
1124        }
1125
1126        let mut decisions = Vec::with_capacity(x.n_rows());
1127
1128        for i in 0..x.n_rows() {
1129            let mut decision = self.bias;
1130            for (j, &w_j) in weights.iter().enumerate() {
1131                decision += w_j * x.get(i, j);
1132            }
1133            decisions.push(decision);
1134        }
1135
1136        Ok(decisions)
1137    }
1138
1139    /// Predicts class labels for the given samples.
1140    ///
1141    /// # Arguments
1142    ///
1143    /// - `x`: Feature matrix (n_samples × n_features)
1144    ///
1145    /// # Returns
1146    ///
1147    /// Vector of predicted labels (0 or 1).
1148    pub fn predict(&self, x: &Matrix<f32>) -> Result<Vec<usize>> {
1149        let decisions = self.decision_function(x)?;
1150
1151        Ok(decisions.iter().map(|&d| usize::from(d >= 0.0)).collect())
1152    }
1153}
1154
1155impl Default for LinearSVM {
1156    fn default() -> Self {
1157        Self::new()
1158    }
1159}
1160
1161#[cfg(test)]
1162mod tests {
1163    use super::*;
1164
1165    #[test]
1166    fn test_sigmoid() {
1167        assert!((LogisticRegression::sigmoid(0.0) - 0.5).abs() < 1e-6);
1168        assert!(LogisticRegression::sigmoid(10.0) > 0.99);
1169        assert!(LogisticRegression::sigmoid(-10.0) < 0.01);
1170    }
1171
1172    #[test]
1173    fn test_logistic_regression_new() {
1174        let model = LogisticRegression::new();
1175        assert!(model.coefficients.is_none());
1176        assert_eq!(model.intercept, 0.0);
1177    }
1178
1179    #[test]
1180    fn test_logistic_regression_builder() {
1181        let model = LogisticRegression::new()
1182            .with_learning_rate(0.1)
1183            .with_max_iter(500)
1184            .with_tolerance(1e-3);
1185
1186        assert_eq!(model.learning_rate, 0.1);
1187        assert_eq!(model.max_iter, 500);
1188        assert_eq!(model.tol, 1e-3);
1189    }
1190
1191    #[test]
1192    fn test_logistic_regression_fit_simple() {
1193        // Simple linearly separable data
1194        let x = Matrix::from_vec(
1195            4,
1196            2,
1197            vec![
1198                0.0, 0.0, // class 0
1199                0.0, 1.0, // class 0
1200                1.0, 0.0, // class 1
1201                1.0, 1.0, // class 1
1202            ],
1203        )
1204        .expect("4x2 matrix with 8 values");
1205        let y = vec![0, 0, 1, 1];
1206
1207        let mut model = LogisticRegression::new()
1208            .with_learning_rate(0.1)
1209            .with_max_iter(1000);
1210
1211        let result = model.fit(&x, &y);
1212        assert!(result.is_ok());
1213        assert!(model.coefficients.is_some());
1214    }
1215
1216    #[test]
1217    fn test_logistic_regression_predict() {
1218        let x = Matrix::from_vec(
1219            4,
1220            2,
1221            vec![
1222                0.0, 0.0, // class 0
1223                0.0, 1.0, // class 0
1224                1.0, 0.0, // class 1
1225                1.0, 1.0, // class 1
1226            ],
1227        )
1228        .expect("4x2 matrix with 8 values");
1229        let y = vec![0, 0, 1, 1];
1230
1231        let mut model = LogisticRegression::new()
1232            .with_learning_rate(0.1)
1233            .with_max_iter(1000);
1234
1235        model
1236            .fit(&x, &y)
1237            .expect("Training should succeed with valid data");
1238        let predictions = model.predict(&x);
1239
1240        // Should correctly classify training data
1241        assert_eq!(predictions.len(), 4);
1242        for pred in predictions {
1243            assert!(pred == 0 || pred == 1);
1244        }
1245    }
1246
1247    #[test]
1248    fn test_logistic_regression_score() {
1249        let x = Matrix::from_vec(
1250            4,
1251            2,
1252            vec![
1253                0.0, 0.0, // class 0
1254                0.0, 1.0, // class 0
1255                1.0, 0.0, // class 1
1256                1.0, 1.0, // class 1
1257            ],
1258        )
1259        .expect("4x2 matrix with 8 values");
1260        let y = vec![0, 0, 1, 1];
1261
1262        let mut model = LogisticRegression::new()
1263            .with_learning_rate(0.1)
1264            .with_max_iter(1000);
1265
1266        model
1267            .fit(&x, &y)
1268            .expect("Training should succeed with valid data");
1269        let accuracy = model.score(&x, &y);
1270
1271        // Should achieve high accuracy on linearly separable data
1272        assert!(accuracy >= 0.75); // At least 75% accuracy
1273    }
1274
1275    #[test]
1276    fn test_logistic_regression_invalid_labels() {
1277        let x = Matrix::from_vec(2, 2, vec![0.0, 0.0, 1.0, 1.0]).expect("2x2 matrix with 4 values");
1278        let y = vec![0, 2]; // Invalid label 2
1279
1280        let mut model = LogisticRegression::new();
1281        let result = model.fit(&x, &y);
1282
1283        assert!(result.is_err());
1284        assert_eq!(
1285            result.expect_err("Should fail with invalid label value"),
1286            "Labels must be 0 or 1 for binary classification"
1287        );
1288    }
1289
1290    #[test]
1291    fn test_logistic_regression_mismatched_samples() {
1292        let x = Matrix::from_vec(2, 2, vec![0.0, 0.0, 1.0, 1.0]).expect("2x2 matrix with 4 values");
1293        let y = vec![0]; // Only 1 label for 2 samples
1294
1295        let mut model = LogisticRegression::new();
1296        let result = model.fit(&x, &y);
1297
1298        assert!(result.is_err());
1299        assert_eq!(
1300            result.expect_err("Should fail with mismatched sample counts"),
1301            "Number of samples in X and y must match"
1302        );
1303    }
1304
1305    #[test]
1306    fn test_logistic_regression_zero_samples() {
1307        let x = Matrix::from_vec(0, 2, vec![]).expect("0x2 empty matrix");
1308        let y = vec![];
1309
1310        let mut model = LogisticRegression::new();
1311        let result = model.fit(&x, &y);
1312
1313        assert!(result.is_err());
1314        assert_eq!(
1315            result.expect_err("Should fail with zero samples"),
1316            "Cannot fit with zero samples"
1317        );
1318    }
1319
1320    #[test]
1321    fn test_predict_proba() {
1322        let x = Matrix::from_vec(2, 2, vec![0.0, 0.0, 1.0, 1.0]).expect("2x2 matrix with 4 values");
1323        let y = vec![0, 1];
1324
1325        let mut model = LogisticRegression::new()
1326            .with_learning_rate(0.1)
1327            .with_max_iter(1000);
1328
1329        model
1330            .fit(&x, &y)
1331            .expect("Training should succeed with valid data");
1332        let probas = model.predict_proba(&x);
1333
1334        assert_eq!(probas.len(), 2);
1335        for &p in probas.as_slice() {
1336            assert!((0.0..=1.0).contains(&p));
1337        }
1338    }
1339
1340    // SafeTensors Serialization Tests
1341    // RED PHASE: These tests will fail until we implement save_safetensors() and load_safetensors()
1342
1343    #[test]
1344    fn test_save_safetensors_unfitted_model() {
1345        // Test 1: Cannot save unfitted model
1346        let model = LogisticRegression::new();
1347        let result = model.save_safetensors("/tmp/test_unfitted_logistic.safetensors");
1348
1349        assert!(result.is_err());
1350        assert!(result
1351            .expect_err("Should fail when saving unfitted model")
1352            .contains("unfitted"));
1353    }
1354
1355    #[test]
1356    fn test_save_load_safetensors_roundtrip() {
1357        // Test 2: Save and load preserves model state
1358        let x = Matrix::from_vec(
1359            4,
1360            2,
1361            vec![
1362                0.0, 0.0, // class 0
1363                0.0, 1.0, // class 0
1364                1.0, 0.0, // class 1
1365                1.0, 1.0, // class 1
1366            ],
1367        )
1368        .expect("4x2 matrix with 8 values");
1369        let y = vec![0, 0, 1, 1];
1370
1371        // Train model
1372        let mut model = LogisticRegression::new()
1373            .with_learning_rate(0.1)
1374            .with_max_iter(1000);
1375        model
1376            .fit(&x, &y)
1377            .expect("Training should succeed with valid data");
1378
1379        // Save model
1380        let path = "/tmp/test_logistic_roundtrip.safetensors";
1381        model
1382            .save_safetensors(path)
1383            .expect("Should save fitted model to valid path");
1384
1385        // Load model
1386        let loaded =
1387            LogisticRegression::load_safetensors(path).expect("Should load valid SafeTensors file");
1388
1389        // Verify coefficients match
1390        assert_eq!(
1391            model
1392                .coefficients
1393                .as_ref()
1394                .expect("Model is fitted and has coefficients")
1395                .len(),
1396            loaded
1397                .coefficients
1398                .as_ref()
1399                .expect("Loaded model has coefficients")
1400                .len()
1401        );
1402        for i in 0..model
1403            .coefficients
1404            .as_ref()
1405            .expect("Model has coefficients")
1406            .len()
1407        {
1408            assert_eq!(
1409                model.coefficients.as_ref().expect("Model has coefficients")[i],
1410                loaded
1411                    .coefficients
1412                    .as_ref()
1413                    .expect("Loaded model has coefficients")[i]
1414            );
1415        }
1416        assert_eq!(model.intercept, loaded.intercept);
1417
1418        // Verify predictions match
1419        let predictions_original = model.predict(&x);
1420        let predictions_loaded = loaded.predict(&x);
1421        assert_eq!(predictions_original, predictions_loaded);
1422
1423        // Cleanup
1424        std::fs::remove_file(path).ok();
1425    }
1426
1427    #[test]
1428    fn test_load_safetensors_corrupted_file() {
1429        // Test 3: Loading corrupted file fails gracefully
1430        let path = "/tmp/test_corrupted_logistic.safetensors";
1431        std::fs::write(path, b"CORRUPTED DATA").expect("Should write test file");
1432
1433        let result = LogisticRegression::load_safetensors(path);
1434        assert!(result.is_err());
1435
1436        std::fs::remove_file(path).ok();
1437    }
1438
1439    #[test]
1440    fn test_load_safetensors_missing_file() {
1441        // Test 4: Loading missing file fails with clear error
1442        let result =
1443            LogisticRegression::load_safetensors("/tmp/nonexistent_logistic_xyz.safetensors");
1444        assert!(result.is_err());
1445        let err = result.expect_err("Should fail when loading nonexistent file");
1446        assert!(
1447            err.contains("No such file") || err.contains("not found"),
1448            "Error should mention file not found: {err}"
1449        );
1450    }
1451
1452    #[test]
1453    fn test_safetensors_preserves_probabilities() {
1454        // Test 5: Probabilities are identical after save/load
1455        let x = Matrix::from_vec(
1456            4,
1457            2,
1458            vec![
1459                0.0, 0.0, // class 0
1460                0.0, 1.0, // class 0
1461                1.0, 0.0, // class 1
1462                1.0, 1.0, // class 1
1463            ],
1464        )
1465        .expect("4x2 matrix with 8 values");
1466        let y = vec![0, 0, 1, 1];
1467
1468        let mut model = LogisticRegression::new()
1469            .with_learning_rate(0.1)
1470            .with_max_iter(1000);
1471        model
1472            .fit(&x, &y)
1473            .expect("Training should succeed with valid data");
1474
1475        let probas_before = model.predict_proba(&x);
1476
1477        // Save and load
1478        let path = "/tmp/test_logistic_probas.safetensors";
1479        model
1480            .save_safetensors(path)
1481            .expect("Should save fitted model to valid path");
1482        let loaded =
1483            LogisticRegression::load_safetensors(path).expect("Should load valid SafeTensors file");
1484
1485        let probas_after = loaded.predict_proba(&x);
1486
1487        // Verify probabilities match exactly
1488        assert_eq!(probas_before.len(), probas_after.len());
1489        for i in 0..probas_before.len() {
1490            assert_eq!(probas_before[i], probas_after[i]);
1491        }
1492
1493        std::fs::remove_file(path).ok();
1494    }
1495
1496    // K-Nearest Neighbors tests
1497    #[test]
1498    fn test_knn_new() {
1499        let knn = KNearestNeighbors::new(3);
1500        assert_eq!(knn.k, 3);
1501        assert_eq!(knn.metric, DistanceMetric::Euclidean);
1502        assert!(!knn.weights);
1503    }
1504
1505    #[test]
1506    fn test_knn_basic_fit_predict() {
1507        // Simple 2-class problem
1508        let x = Matrix::from_vec(
1509            6,
1510            2,
1511            vec![
1512                0.0, 0.0, // class 0
1513                0.0, 1.0, // class 0
1514                1.0, 0.0, // class 0
1515                5.0, 5.0, // class 1
1516                5.0, 6.0, // class 1
1517                6.0, 5.0, // class 1
1518            ],
1519        )
1520        .expect("6x2 matrix with 12 values");
1521        let y = vec![0, 0, 0, 1, 1, 1];
1522
1523        let mut knn = KNearestNeighbors::new(3);
1524        knn.fit(&x, &y)
1525            .expect("Training should succeed with valid data");
1526
1527        // Test point close to class 0
1528        let test1 = Matrix::from_vec(1, 2, vec![0.5, 0.5]).expect("1x2 test matrix");
1529        let pred1 = knn.predict(&test1).expect("Prediction should succeed");
1530        assert_eq!(pred1[0], 0);
1531
1532        // Test point close to class 1
1533        let test2 = Matrix::from_vec(1, 2, vec![5.5, 5.5]).expect("1x2 test matrix");
1534        let pred2 = knn.predict(&test2).expect("Prediction should succeed");
1535        assert_eq!(pred2[0], 1);
1536    }
1537
1538    #[test]
1539    fn test_knn_k_equals_one() {
1540        // With k=1, should predict nearest neighbor exactly
1541        let x = Matrix::from_vec(
1542            4,
1543            2,
1544            vec![
1545                0.0, 0.0, // class 0
1546                1.0, 1.0, // class 1
1547                2.0, 2.0, // class 0
1548                3.0, 3.0, // class 1
1549            ],
1550        )
1551        .expect("4x2 matrix with 8 values");
1552        let y = vec![0, 1, 0, 1];
1553
1554        let mut knn = KNearestNeighbors::new(1);
1555        knn.fit(&x, &y)
1556            .expect("Training should succeed with valid data");
1557
1558        // Predict on training data - should be perfect
1559        let predictions = knn.predict(&x).expect("Prediction should succeed");
1560        assert_eq!(predictions, y);
1561    }
1562
1563    #[test]
1564    fn test_knn_euclidean_distance() {
1565        let x = Matrix::from_vec(
1566            3,
1567            2,
1568            vec![
1569                0.0, 0.0, // class 0
1570                3.0, 4.0, // class 1 (distance 5.0 from origin)
1571                1.0, 1.0, // class 0
1572            ],
1573        )
1574        .expect("3x2 matrix with 6 values");
1575        let y = vec![0, 1, 0];
1576
1577        let mut knn = KNearestNeighbors::new(1).with_metric(DistanceMetric::Euclidean);
1578        knn.fit(&x, &y)
1579            .expect("Training should succeed with valid data");
1580
1581        // Test point at (1.5, 2.0) - closer to (1, 1) than (3, 4)
1582        let test = Matrix::from_vec(1, 2, vec![1.5, 2.0]).expect("1x2 test matrix");
1583        let pred = knn.predict(&test).expect("Prediction should succeed");
1584        assert_eq!(pred[0], 0);
1585    }
1586
1587    #[test]
1588    fn test_knn_manhattan_distance() {
1589        let x = Matrix::from_vec(
1590            3,
1591            2,
1592            vec![
1593                0.0, 0.0, // class 0
1594                2.0, 2.0, // class 1 (Manhattan distance 4.0 from origin)
1595                1.0, 0.0, // class 0
1596            ],
1597        )
1598        .expect("3x2 matrix with 6 values");
1599        let y = vec![0, 1, 0];
1600
1601        let mut knn = KNearestNeighbors::new(1).with_metric(DistanceMetric::Manhattan);
1602        knn.fit(&x, &y)
1603            .expect("Training should succeed with valid data");
1604
1605        let test = Matrix::from_vec(1, 2, vec![0.5, 0.0]).expect("1x2 test matrix");
1606        let pred = knn.predict(&test).expect("Prediction should succeed");
1607        assert_eq!(pred[0], 0); // Closer to (1, 0)
1608    }
1609
1610    #[test]
1611    fn test_knn_minkowski_distance() {
1612        let x = Matrix::from_vec(
1613            3,
1614            2,
1615            vec![
1616                0.0, 0.0, // class 0
1617                3.0, 4.0, // class 1
1618                1.0, 1.0, // class 0
1619            ],
1620        )
1621        .expect("3x2 matrix with 6 values");
1622        let y = vec![0, 1, 0];
1623
1624        // Minkowski with p=3
1625        let mut knn = KNearestNeighbors::new(1).with_metric(DistanceMetric::Minkowski(3.0));
1626        knn.fit(&x, &y)
1627            .expect("Training should succeed with valid data");
1628
1629        let test = Matrix::from_vec(1, 2, vec![0.5, 0.5]).expect("1x2 test matrix");
1630        let pred = knn.predict(&test).expect("Prediction should succeed");
1631        assert_eq!(pred[0], 0);
1632    }
1633
1634    #[test]
1635    fn test_knn_weighted_voting() {
1636        // Set up data where uniform voting gives different result than weighted
1637        let x = Matrix::from_vec(
1638            5,
1639            1,
1640            vec![
1641                0.0, // class 0
1642                0.1, // class 0
1643                5.0, // class 1
1644                5.5, // class 1
1645                6.0, // class 1
1646            ],
1647        )
1648        .expect("5x1 matrix with 5 values");
1649        let y = vec![0, 0, 1, 1, 1];
1650
1651        let mut knn_weighted = KNearestNeighbors::new(3).with_weights(true);
1652        knn_weighted
1653            .fit(&x, &y)
1654            .expect("Training should succeed with valid data");
1655
1656        // Test point at 0.05 - very close to class 0
1657        let test = Matrix::from_vec(1, 1, vec![0.05]).expect("1x1 test matrix");
1658        let pred = knn_weighted
1659            .predict(&test)
1660            .expect("Prediction should succeed");
1661        assert_eq!(pred[0], 0); // Should be class 0 due to proximity weighting
1662    }
1663
1664    #[test]
1665    fn test_knn_predict_proba() {
1666        let x = Matrix::from_vec(
1667            6,
1668            2,
1669            vec![
1670                0.0, 0.0, // class 0
1671                0.0, 1.0, // class 0
1672                1.0, 0.0, // class 0
1673                5.0, 5.0, // class 1
1674                5.0, 6.0, // class 1
1675                6.0, 5.0, // class 1
1676            ],
1677        )
1678        .expect("6x2 matrix with 12 values");
1679        let y = vec![0, 0, 0, 1, 1, 1];
1680
1681        let mut knn = KNearestNeighbors::new(3);
1682        knn.fit(&x, &y)
1683            .expect("Training should succeed with valid data");
1684
1685        let test = Matrix::from_vec(1, 2, vec![0.5, 0.5]).expect("1x2 test matrix");
1686        let probas = knn
1687            .predict_proba(&test)
1688            .expect("Probability prediction should succeed");
1689
1690        assert_eq!(probas.len(), 1);
1691        assert_eq!(probas[0].len(), 2); // 2 classes
1692
1693        // Probabilities should sum to 1.0
1694        let sum: f32 = probas[0].iter().sum();
1695        assert!((sum - 1.0).abs() < 1e-6);
1696
1697        // Point closer to class 0 should have higher probability for class 0
1698        assert!(probas[0][0] > probas[0][1]);
1699    }
1700
1701    #[test]
1702    fn test_knn_multiclass() {
1703        // 3-class problem
1704        let x = Matrix::from_vec(
1705            9,
1706            2,
1707            vec![
1708                0.0, 0.0, // class 0
1709                0.0, 1.0, // class 0
1710                1.0, 0.0, // class 0
1711                5.0, 5.0, // class 1
1712                5.0, 6.0, // class 1
1713                6.0, 5.0, // class 1
1714                10.0, 10.0, // class 2
1715                10.0, 11.0, // class 2
1716                11.0, 10.0, // class 2
1717            ],
1718        )
1719        .expect("9x2 matrix with 18 values");
1720        let y = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
1721
1722        let mut knn = KNearestNeighbors::new(3);
1723        knn.fit(&x, &y)
1724            .expect("Training should succeed with valid data");
1725
1726        // Test each cluster
1727        let test1 = Matrix::from_vec(1, 2, vec![0.5, 0.5]).expect("1x2 test matrix");
1728        assert_eq!(
1729            knn.predict(&test1).expect("Prediction should succeed")[0],
1730            0
1731        );
1732
1733        let test2 = Matrix::from_vec(1, 2, vec![5.5, 5.5]).expect("1x2 test matrix");
1734        assert_eq!(
1735            knn.predict(&test2).expect("Prediction should succeed")[0],
1736            1
1737        );
1738
1739        let test3 = Matrix::from_vec(1, 2, vec![10.5, 10.5]).expect("1x2 test matrix");
1740        assert_eq!(
1741            knn.predict(&test3).expect("Prediction should succeed")[0],
1742            2
1743        );
1744    }
1745
1746    #[test]
1747    fn test_knn_not_fitted_error() {
1748        let knn = KNearestNeighbors::new(3);
1749        let test = Matrix::from_vec(1, 2, vec![0.0, 0.0]).expect("1x2 test matrix");
1750
1751        let result = knn.predict(&test);
1752        assert!(result.is_err());
1753        assert_eq!(
1754            result.expect_err("Should fail when predicting with unfitted model"),
1755            "Model not fitted"
1756        );
1757    }
1758
1759    #[test]
1760    fn test_knn_dimension_mismatch() {
1761        let x = Matrix::from_vec(3, 2, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
1762            .expect("3x2 matrix with 6 values");
1763        let y = vec![0, 1, 0];
1764
1765        let mut knn = KNearestNeighbors::new(1);
1766        knn.fit(&x, &y)
1767            .expect("Training should succeed with valid data");
1768
1769        // Test with wrong number of features
1770        let test = Matrix::from_vec(1, 3, vec![0.0, 0.0, 0.0]).expect("1x3 test matrix");
1771        let result = knn.predict(&test);
1772
1773        assert!(result.is_err());
1774        assert_eq!(
1775            result.expect_err("Should fail with dimension mismatch"),
1776            "Feature dimension mismatch"
1777        );
1778    }
1779
1780    #[test]
1781    fn test_knn_sample_mismatch() {
1782        let x = Matrix::from_vec(3, 2, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
1783            .expect("3x2 matrix with 6 values");
1784        let y = vec![0, 1]; // Wrong length
1785
1786        let mut knn = KNearestNeighbors::new(1);
1787        let result = knn.fit(&x, &y);
1788
1789        assert!(result.is_err());
1790        assert_eq!(
1791            result.expect_err("Should fail with sample mismatch"),
1792            "Number of samples in X and y must match"
1793        );
1794    }
1795
1796    #[test]
1797    fn test_knn_k_too_large() {
1798        let x = Matrix::from_vec(3, 2, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
1799            .expect("3x2 matrix with 6 values");
1800        let y = vec![0, 1, 0];
1801
1802        let mut knn = KNearestNeighbors::new(5); // k > n_samples
1803        let result = knn.fit(&x, &y);
1804
1805        assert!(result.is_err());
1806        assert_eq!(
1807            result.expect_err("Should fail when k exceeds sample count"),
1808            "k cannot be larger than number of training samples"
1809        );
1810    }
1811
1812    #[test]
1813    fn test_knn_empty_data() {
1814        let x = Matrix::from_vec(0, 2, vec![]).expect("0x2 empty matrix");
1815        let y = vec![];
1816
1817        let mut knn = KNearestNeighbors::new(1);
1818        let result = knn.fit(&x, &y);
1819
1820        assert!(result.is_err());
1821        assert_eq!(
1822            result.expect_err("Should fail with empty data"),
1823            "Cannot fit with zero samples"
1824        );
1825    }
1826
1827    #[test]
1828    fn test_knn_builder_pattern() {
1829        let knn = KNearestNeighbors::new(5)
1830            .with_metric(DistanceMetric::Manhattan)
1831            .with_weights(true);
1832
1833        assert_eq!(knn.k, 5);
1834        assert_eq!(knn.metric, DistanceMetric::Manhattan);
1835        assert!(knn.weights);
1836    }
1837
1838    #[test]
1839    fn test_knn_distance_symmetry() {
1840        // Property test: distance(a, b) == distance(b, a)
1841        let x = Matrix::from_vec(
1842            2,
1843            2,
1844            vec![
1845                1.0, 2.0, // point a
1846                3.0, 4.0, // point b
1847            ],
1848        )
1849        .expect("2x2 matrix with 4 values");
1850        let y = vec![0, 1];
1851
1852        let mut knn = KNearestNeighbors::new(1);
1853        knn.fit(&x, &y)
1854            .expect("Training should succeed with valid data");
1855
1856        // Compute both directions
1857        let dist_ab = knn.compute_distance(&x, 0, &x, 1, 2);
1858        let dist_ba = knn.compute_distance(&x, 1, &x, 0, 2);
1859
1860        assert!((dist_ab - dist_ba).abs() < 1e-6);
1861    }
1862
1863    #[test]
1864    fn test_knn_perfect_fit_with_k1() {
1865        // Property test: k=1 on training data gives perfect predictions
1866        let x = Matrix::from_vec(
1867            10,
1868            3,
1869            vec![
1870                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 4.0, 5.0, 6.0, 5.0, 6.0, 7.0, 6.0,
1871                7.0, 8.0, 7.0, 8.0, 9.0, 8.0, 9.0, 10.0, 9.0, 10.0, 11.0, 10.0, 11.0, 12.0,
1872            ],
1873        )
1874        .expect("10x3 matrix with 30 values");
1875        let y = vec![0, 0, 1, 1, 0, 1, 0, 1, 0, 1];
1876
1877        let mut knn = KNearestNeighbors::new(1);
1878        knn.fit(&x, &y)
1879            .expect("Training should succeed with valid data");
1880
1881        let predictions = knn.predict(&x).expect("Prediction should succeed");
1882        assert_eq!(predictions, y);
1883    }
1884
1885    // ========== Gaussian Naive Bayes Tests ==========
1886
1887    #[test]
1888    fn test_gaussian_nb_new() {
1889        let model = GaussianNB::new();
1890        assert!(model.class_priors.is_none());
1891        assert!(model.means.is_none());
1892        assert!(model.variances.is_none());
1893        assert_eq!(model.var_smoothing, 1e-9);
1894    }
1895
1896    #[test]
1897    fn test_gaussian_nb_builder() {
1898        let model = GaussianNB::new().with_var_smoothing(1e-8);
1899        assert_eq!(model.var_smoothing, 1e-8);
1900    }
1901
1902    #[test]
1903    fn test_gaussian_nb_basic_fit_predict() {
1904        // Simple 2-class problem: class 0 at (0,0), class 1 at (1,1)
1905        let x = Matrix::from_vec(
1906            4,
1907            2,
1908            vec![
1909                0.0, 0.0, // class 0
1910                0.1, 0.1, // class 0
1911                1.0, 1.0, // class 1
1912                0.9, 0.9, // class 1
1913            ],
1914        )
1915        .expect("4x2 matrix with 8 values");
1916        let y = vec![0, 0, 1, 1];
1917
1918        let mut model = GaussianNB::new();
1919        model
1920            .fit(&x, &y)
1921            .expect("Training should succeed with valid data");
1922
1923        let predictions = model.predict(&x).expect("Prediction should succeed");
1924        assert_eq!(predictions, y);
1925    }
1926
1927    #[test]
1928    fn test_gaussian_nb_multiclass() {
1929        // 3-class problem
1930        let x = Matrix::from_vec(
1931            9,
1932            2,
1933            vec![
1934                0.0, 0.0, // class 0
1935                0.1, 0.1, // class 0
1936                0.0, 0.1, // class 0
1937                5.0, 5.0, // class 1
1938                5.1, 5.1, // class 1
1939                5.0, 5.1, // class 1
1940                -5.0, -5.0, // class 2
1941                -5.1, -5.1, // class 2
1942                -5.0, -5.1, // class 2
1943            ],
1944        )
1945        .expect("9x2 matrix with 18 values");
1946        let y = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
1947
1948        let mut model = GaussianNB::new();
1949        model
1950            .fit(&x, &y)
1951            .expect("Training should succeed with valid data");
1952
1953        let predictions = model.predict(&x).expect("Prediction should succeed");
1954        assert_eq!(predictions, y);
1955    }
1956
1957    #[test]
1958    fn test_gaussian_nb_predict_proba() {
1959        let x = Matrix::from_vec(
1960            4,
1961            2,
1962            vec![
1963                0.0, 0.0, // class 0
1964                0.1, 0.1, // class 0
1965                1.0, 1.0, // class 1
1966                0.9, 0.9, // class 1
1967            ],
1968        )
1969        .expect("4x2 matrix with 8 values");
1970        let y = vec![0, 0, 1, 1];
1971
1972        let mut model = GaussianNB::new();
1973        model
1974            .fit(&x, &y)
1975            .expect("Training should succeed with valid data");
1976
1977        let probabilities = model
1978            .predict_proba(&x)
1979            .expect("Probability prediction should succeed");
1980
1981        // Check all samples have probabilities
1982        assert_eq!(probabilities.len(), 4);
1983
1984        // Check probabilities sum to 1
1985        for probs in &probabilities {
1986            assert_eq!(probs.len(), 2);
1987            let sum: f32 = probs.iter().sum();
1988            assert!((sum - 1.0).abs() < 1e-5);
1989        }
1990
1991        // Check first sample (class 0) has high probability for class 0
1992        assert!(probabilities[0][0] > 0.5);
1993
1994        // Check last sample (class 1) has high probability for class 1
1995        assert!(probabilities[3][1] > 0.5);
1996    }
1997
1998    #[test]
1999    fn test_gaussian_nb_not_fitted_error() {
2000        let model = GaussianNB::new();
2001        let x_test = Matrix::from_vec(2, 2, vec![0.0, 0.0, 1.0, 1.0]).expect("2x2 test matrix");
2002
2003        let result = model.predict(&x_test);
2004        assert!(result.is_err());
2005        assert_eq!(
2006            result.expect_err("Should fail when predicting with unfitted model"),
2007            "Model not fitted"
2008        );
2009    }
2010
2011    #[test]
2012    fn test_gaussian_nb_empty_data() {
2013        let x = Matrix::from_vec(0, 2, vec![]).expect("0x2 empty matrix");
2014        let y: Vec<usize> = vec![];
2015
2016        let mut model = GaussianNB::new();
2017        let result = model.fit(&x, &y);
2018
2019        assert!(result.is_err());
2020        assert_eq!(
2021            result.expect_err("Should fail with empty data"),
2022            "Cannot fit with empty data"
2023        );
2024    }
2025
2026    #[test]
2027    fn test_gaussian_nb_sample_mismatch() {
2028        let x = Matrix::from_vec(3, 2, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
2029            .expect("3x2 matrix with 6 values");
2030        let y = vec![0, 1]; // Wrong length
2031
2032        let mut model = GaussianNB::new();
2033        let result = model.fit(&x, &y);
2034
2035        assert!(result.is_err());
2036        assert_eq!(
2037            result.expect_err("Should fail with sample mismatch"),
2038            "Number of samples in X and y must match"
2039        );
2040    }
2041
2042    #[test]
2043    fn test_gaussian_nb_single_class() {
2044        let x = Matrix::from_vec(3, 2, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
2045            .expect("3x2 matrix with 6 values");
2046        let y = vec![0, 0, 0]; // All same class
2047
2048        let mut model = GaussianNB::new();
2049        let result = model.fit(&x, &y);
2050
2051        assert!(result.is_err());
2052        assert_eq!(
2053            result.expect_err("Should fail with single class"),
2054            "Need at least 2 classes"
2055        );
2056    }
2057
2058    #[test]
2059    fn test_gaussian_nb_dimension_mismatch() {
2060        let x_train = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.1, 0.1, 1.0, 1.0, 0.9, 0.9])
2061            .expect("4x2 training matrix");
2062        let y_train = vec![0, 0, 1, 1];
2063
2064        let mut model = GaussianNB::new();
2065        model
2066            .fit(&x_train, &y_train)
2067            .expect("Training should succeed with valid data");
2068
2069        let x_test =
2070            Matrix::from_vec(2, 3, vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]).expect("2x3 test matrix");
2071        let result = model.predict(&x_test);
2072
2073        assert!(result.is_err());
2074        assert_eq!(
2075            result.expect_err("Should fail with dimension mismatch"),
2076            "Feature dimension mismatch"
2077        );
2078    }
2079
2080    #[test]
2081    fn test_gaussian_nb_balanced_classes() {
2082        // Equal number of samples per class
2083        let x = Matrix::from_vec(
2084            6,
2085            2,
2086            vec![
2087                0.0, 0.0, // class 0
2088                0.1, 0.1, // class 0
2089                0.2, 0.2, // class 0
2090                1.0, 1.0, // class 1
2091                1.1, 1.1, // class 1
2092                1.2, 1.2, // class 1
2093            ],
2094        )
2095        .expect("6x2 matrix with 12 values");
2096        let y = vec![0, 0, 0, 1, 1, 1];
2097
2098        let mut model = GaussianNB::new();
2099        model
2100            .fit(&x, &y)
2101            .expect("Training should succeed with valid data");
2102
2103        // Check class priors are equal
2104        let priors = model
2105            .class_priors
2106            .expect("Model is fitted and has class priors");
2107        assert!((priors[0] - 0.5).abs() < 1e-5);
2108        assert!((priors[1] - 0.5).abs() < 1e-5);
2109    }
2110
2111    #[test]
2112    fn test_gaussian_nb_imbalanced_classes() {
2113        // Imbalanced: 1 sample class 0, 3 samples class 1
2114        let x = Matrix::from_vec(
2115            4,
2116            2,
2117            vec![
2118                0.0, 0.0, // class 0
2119                1.0, 1.0, // class 1
2120                1.1, 1.1, // class 1
2121                1.2, 1.2, // class 1
2122            ],
2123        )
2124        .expect("4x2 matrix with 8 values");
2125        let y = vec![0, 1, 1, 1];
2126
2127        let mut model = GaussianNB::new();
2128        model
2129            .fit(&x, &y)
2130            .expect("Training should succeed with valid data");
2131
2132        // Check class priors reflect imbalance
2133        let priors = model
2134            .class_priors
2135            .expect("Model is fitted and has class priors");
2136        assert!((priors[0] - 0.25).abs() < 1e-5); // 1/4
2137        assert!((priors[1] - 0.75).abs() < 1e-5); // 3/4
2138    }
2139
2140    #[test]
2141    fn test_gaussian_nb_var_smoothing() {
2142        // Test that variance smoothing prevents division by zero
2143        let x = Matrix::from_vec(
2144            4,
2145            2,
2146            vec![
2147                0.0, 0.0, // class 0 - identical points
2148                0.0, 0.0, // class 0 - identical points
2149                1.0, 1.0, // class 1 - identical points
2150                1.0, 1.0, // class 1 - identical points
2151            ],
2152        )
2153        .expect("4x2 matrix with 8 values");
2154        let y = vec![0, 0, 1, 1];
2155
2156        let mut model = GaussianNB::new().with_var_smoothing(1e-8);
2157        model
2158            .fit(&x, &y)
2159            .expect("Training should succeed with valid data");
2160
2161        // Should not panic or produce NaN/Inf
2162        let predictions = model.predict(&x).expect("Prediction should succeed");
2163        assert_eq!(predictions, y);
2164
2165        let probabilities = model
2166            .predict_proba(&x)
2167            .expect("Probability prediction should succeed");
2168        for probs in &probabilities {
2169            for &p in probs {
2170                assert!(p.is_finite());
2171                assert!((0.0..=1.0).contains(&p));
2172            }
2173        }
2174    }
2175
2176    #[test]
2177    fn test_gaussian_nb_probabilities_sum_to_one() {
2178        // Property test: probabilities must sum to 1
2179        let x = Matrix::from_vec(
2180            10,
2181            3,
2182            vec![
2183                0.0, 0.0, 0.0, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 1.0, 1.0, 1.0, 1.1,
2184                1.1, 1.1, 1.2, 1.2, 1.2, 1.3, 1.3, 1.3, 2.0, 2.0, 2.0, 2.1, 2.1, 2.1,
2185            ],
2186        )
2187        .expect("10x3 matrix with 30 values");
2188        let y = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2];
2189
2190        let mut model = GaussianNB::new();
2191        model
2192            .fit(&x, &y)
2193            .expect("Training should succeed with valid data");
2194
2195        let probabilities = model
2196            .predict_proba(&x)
2197            .expect("Probability prediction should succeed");
2198
2199        for probs in &probabilities {
2200            let sum: f32 = probs.iter().sum();
2201            assert!((sum - 1.0).abs() < 1e-5);
2202        }
2203    }
2204
2205    #[test]
2206    fn test_gaussian_nb_default() {
2207        let model1 = GaussianNB::new();
2208        let model2 = GaussianNB::default();
2209
2210        assert_eq!(model1.var_smoothing, model2.var_smoothing);
2211    }
2212
2213    #[test]
2214    fn test_gaussian_nb_class_separation() {
2215        // Well-separated classes should have high confidence
2216        let x = Matrix::from_vec(
2217            4,
2218            2,
2219            vec![
2220                0.0, 0.0, // class 0
2221                0.1, 0.1, // class 0
2222                10.0, 10.0, // class 1 (far away)
2223                10.1, 10.1, // class 1 (far away)
2224            ],
2225        )
2226        .expect("4x2 matrix with 8 values");
2227        let y = vec![0, 0, 1, 1];
2228
2229        let mut model = GaussianNB::new();
2230        model
2231            .fit(&x, &y)
2232            .expect("Training should succeed with valid data");
2233
2234        let probabilities = model
2235            .predict_proba(&x)
2236            .expect("Probability prediction should succeed");
2237
2238        // First sample should have very high confidence for class 0
2239        assert!(probabilities[0][0] > 0.99);
2240
2241        // Last sample should have very high confidence for class 1
2242        assert!(probabilities[3][1] > 0.99);
2243    }
2244
2245    // ===== LinearSVM Tests =====
2246
2247    #[test]
2248    fn test_linear_svm_new() {
2249        let svm = LinearSVM::new();
2250        assert!(svm.weights.is_none());
2251        assert_eq!(svm.bias, 0.0);
2252        assert_eq!(svm.c, 1.0);
2253        assert_eq!(svm.learning_rate, 0.01);
2254        assert_eq!(svm.max_iter, 1000);
2255        assert_eq!(svm.tol, 1e-4);
2256    }
2257
2258    #[test]
2259    fn test_linear_svm_builder() {
2260        let svm = LinearSVM::new()
2261            .with_c(0.5)
2262            .with_learning_rate(0.001)
2263            .with_max_iter(500)
2264            .with_tolerance(1e-5);
2265
2266        assert_eq!(svm.c, 0.5);
2267        assert_eq!(svm.learning_rate, 0.001);
2268        assert_eq!(svm.max_iter, 500);
2269        assert_eq!(svm.tol, 1e-5);
2270    }
2271
2272    #[test]
2273    fn test_linear_svm_fit_simple() {
2274        // Simple linearly separable data
2275        let x = Matrix::from_vec(
2276            4,
2277            2,
2278            vec![
2279                0.0, 0.0, // class 0
2280                0.0, 1.0, // class 0
2281                1.0, 0.0, // class 1
2282                1.0, 1.0, // class 1
2283            ],
2284        )
2285        .expect("4x2 matrix with 8 values");
2286        let y = vec![0, 0, 1, 1];
2287
2288        let mut svm = LinearSVM::new().with_max_iter(1000).with_learning_rate(0.1);
2289
2290        let result = svm.fit(&x, &y);
2291        assert!(result.is_ok());
2292        assert!(svm.weights.is_some());
2293    }
2294
2295    #[test]
2296    fn test_linear_svm_predict_simple() {
2297        // Simple linearly separable data
2298        let x = Matrix::from_vec(
2299            4,
2300            2,
2301            vec![
2302                0.0, 0.0, // class 0
2303                0.0, 1.0, // class 0
2304                1.0, 0.0, // class 1
2305                1.0, 1.0, // class 1
2306            ],
2307        )
2308        .expect("4x2 matrix with 8 values");
2309        let y = vec![0, 0, 1, 1];
2310
2311        let mut svm = LinearSVM::new().with_max_iter(1000).with_learning_rate(0.1);
2312        svm.fit(&x, &y)
2313            .expect("Training should succeed with valid data");
2314
2315        let predictions = svm.predict(&x).expect("Prediction should succeed");
2316        assert_eq!(predictions.len(), 4);
2317
2318        // Should classify correctly (or close to it)
2319        let correct = predictions
2320            .iter()
2321            .zip(y.iter())
2322            .filter(|(pred, true_label)| *pred == *true_label)
2323            .count();
2324
2325        // Should get at least 3 out of 4 correct for simple case
2326        assert!(correct >= 3);
2327    }
2328
2329    #[test]
2330    fn test_linear_svm_decision_function() {
2331        let x = Matrix::from_vec(
2332            4,
2333            2,
2334            vec![
2335                0.0, 0.0, // class 0
2336                0.0, 1.0, // class 0
2337                1.0, 0.0, // class 1
2338                1.0, 1.0, // class 1
2339            ],
2340        )
2341        .expect("4x2 matrix with 8 values");
2342        let y = vec![0, 0, 1, 1];
2343
2344        let mut svm = LinearSVM::new().with_max_iter(1000).with_learning_rate(0.1);
2345        svm.fit(&x, &y)
2346            .expect("Training should succeed with valid data");
2347
2348        let decisions = svm
2349            .decision_function(&x)
2350            .expect("Decision function should succeed");
2351        assert_eq!(decisions.len(), 4);
2352
2353        // Class 0 samples should have negative decisions
2354        // Class 1 samples should have positive decisions
2355        // (may not be perfect for simple gradient descent)
2356    }
2357
2358    #[test]
2359    fn test_linear_svm_predict_untrained() {
2360        let svm = LinearSVM::new();
2361        let x = Matrix::from_vec(2, 2, vec![0.0, 0.0, 1.0, 1.0]).expect("2x2 matrix with 4 values");
2362
2363        let result = svm.predict(&x);
2364        assert!(result.is_err());
2365        assert_eq!(
2366            result.expect_err("Should fail when predicting with untrained model"),
2367            "Model not trained yet"
2368        );
2369    }
2370
2371    #[test]
2372    fn test_linear_svm_dimension_mismatch() {
2373        let x_train = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
2374            .expect("4x2 training matrix");
2375        let y = vec![0, 0, 1, 1];
2376
2377        let mut svm = LinearSVM::new();
2378        svm.fit(&x_train, &y)
2379            .expect("Training should succeed with valid data");
2380
2381        // Try to predict with wrong number of features
2382        let x_test =
2383            Matrix::from_vec(2, 3, vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]).expect("2x3 test matrix");
2384        let result = svm.predict(&x_test);
2385        assert!(result.is_err());
2386        assert_eq!(
2387            result.expect_err("Should fail with dimension mismatch"),
2388            "Feature dimension mismatch"
2389        );
2390    }
2391
2392    #[test]
2393    fn test_linear_svm_empty_data() {
2394        let x = Matrix::from_vec(0, 2, vec![]).expect("0x2 empty matrix");
2395        let y = vec![];
2396
2397        let mut svm = LinearSVM::new();
2398        let result = svm.fit(&x, &y);
2399        assert!(result.is_err());
2400        assert_eq!(
2401            result.expect_err("Should fail with empty data"),
2402            "Cannot fit with 0 samples"
2403        );
2404    }
2405
2406    #[test]
2407    fn test_linear_svm_mismatched_samples() {
2408        let x = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
2409            .expect("4x2 matrix with 8 values");
2410        let y = vec![0, 0, 1]; // Wrong length
2411
2412        let mut svm = LinearSVM::new();
2413        let result = svm.fit(&x, &y);
2414        assert!(result.is_err());
2415        assert_eq!(
2416            result.expect_err("Should fail with mismatched sample counts"),
2417            "x and y must have the same number of samples"
2418        );
2419    }
2420
2421    #[test]
2422    fn test_linear_svm_regularization_c() {
2423        let x = Matrix::from_vec(
2424            6,
2425            2,
2426            vec![
2427                0.0, 0.0, // class 0
2428                0.1, 0.1, // class 0
2429                0.0, 0.2, // class 0
2430                1.0, 1.0, // class 1
2431                0.9, 0.9, // class 1
2432                1.0, 0.8, // class 1
2433            ],
2434        )
2435        .expect("6x2 matrix with 12 values");
2436        let y = vec![0, 0, 0, 1, 1, 1];
2437
2438        // High C (less regularization) - should fit data more closely
2439        let mut svm_high_c = LinearSVM::new()
2440            .with_c(10.0)
2441            .with_max_iter(1000)
2442            .with_learning_rate(0.1);
2443        svm_high_c
2444            .fit(&x, &y)
2445            .expect("Training should succeed with valid data");
2446        let pred_high_c = svm_high_c.predict(&x).expect("Prediction should succeed");
2447
2448        // Low C (more regularization) - should prefer simpler model
2449        let mut svm_low_c = LinearSVM::new()
2450            .with_c(0.1)
2451            .with_max_iter(1000)
2452            .with_learning_rate(0.1);
2453        svm_low_c
2454            .fit(&x, &y)
2455            .expect("Training should succeed with valid data");
2456        let pred_low_c = svm_low_c.predict(&x).expect("Prediction should succeed");
2457
2458        // Both should make predictions
2459        assert_eq!(pred_high_c.len(), 6);
2460        assert_eq!(pred_low_c.len(), 6);
2461    }
2462
2463    #[test]
2464    fn test_linear_svm_binary_classification() {
2465        // More realistic binary classification problem
2466        let x = Matrix::from_vec(
2467            10,
2468            2,
2469            vec![
2470                // Class 0 (bottom-left cluster)
2471                0.0, 0.0, 0.1, 0.1, 0.0, 0.2, 0.2, 0.0, 0.1,
2472                0.2, // Class 1 (top-right cluster)
2473                1.0, 1.0, 0.9, 0.9, 1.0, 0.8, 0.8, 1.0, 0.9, 1.1,
2474            ],
2475        )
2476        .expect("10x2 matrix with 20 values");
2477        let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
2478
2479        let mut svm = LinearSVM::new()
2480            .with_c(1.0)
2481            .with_max_iter(2000)
2482            .with_learning_rate(0.1);
2483
2484        svm.fit(&x, &y)
2485            .expect("Training should succeed with valid data");
2486        let predictions = svm.predict(&x).expect("Prediction should succeed");
2487
2488        // Should achieve reasonable accuracy
2489        let correct = predictions
2490            .iter()
2491            .zip(y.iter())
2492            .filter(|(pred, true_label)| *pred == *true_label)
2493            .count();
2494
2495        // Should get at least 8 out of 10 correct for well-separated clusters
2496        assert!(
2497            correct >= 8,
2498            "Expected at least 8/10 correct, got {correct}/10"
2499        );
2500    }
2501
2502    #[test]
2503    fn test_linear_svm_convergence() {
2504        let x = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
2505            .expect("4x2 matrix with 8 values");
2506        let y = vec![0, 0, 1, 1];
2507
2508        // With very few iterations, might not converge
2509        let mut svm_few_iter = LinearSVM::new().with_max_iter(10).with_learning_rate(0.01);
2510        svm_few_iter
2511            .fit(&x, &y)
2512            .expect("Training should succeed with valid data");
2513
2514        // With many iterations, should converge better
2515        let mut svm_many_iter = LinearSVM::new().with_max_iter(2000).with_learning_rate(0.1);
2516        svm_many_iter
2517            .fit(&x, &y)
2518            .expect("Training should succeed with valid data");
2519
2520        // Both should train successfully
2521        assert!(svm_few_iter.weights.is_some());
2522        assert!(svm_many_iter.weights.is_some());
2523    }
2524
2525    #[test]
2526    fn test_linear_svm_default() {
2527        let svm1 = LinearSVM::new();
2528        let svm2 = LinearSVM::default();
2529
2530        assert_eq!(svm1.c, svm2.c);
2531        assert_eq!(svm1.learning_rate, svm2.learning_rate);
2532        assert_eq!(svm1.max_iter, svm2.max_iter);
2533    }
2534}