aprender/
traits.rs

//! Core traits for ML estimators and transformers.
//!
//! These traits define the API contracts shared by all ML algorithms in the crate.
4
5use crate::error::Result;
6use crate::primitives::{Matrix, Vector};
7
8/// Primary trait for supervised learning estimators.
9///
10/// Estimators implement fit/predict/score following sklearn conventions.
11///
12/// # Examples
13///
14/// ```
15/// use aprender::prelude::*;
16///
17/// // Create training data: y = 2x + 1
18/// let x_train = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
19/// let y_train = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
20///
21/// // Test data
22/// let x_test = Matrix::from_vec(2, 1, vec![5.0, 6.0]).unwrap();
23/// let y_test = Vector::from_slice(&[11.0, 13.0]);
24///
25/// let mut model = LinearRegression::new();
26/// model.fit(&x_train, &y_train).unwrap();
27/// let predictions = model.predict(&x_test);
28/// let score = model.score(&x_test, &y_test);
29/// assert!(score > 0.99);
30/// ```
31pub trait Estimator {
32    /// Fits the model to training data.
33    ///
34    /// # Errors
35    ///
36    /// Returns an error if fitting fails (dimension mismatch, singular matrix, etc.).
37    fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()>;
38
39    /// Predicts target values for input data.
40    fn predict(&self, x: &Matrix<f32>) -> Vector<f32>;
41
42    /// Computes the score (R² for regression, accuracy for classification).
43    fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32;
44}
45
46/// Trait for unsupervised learning models.
47///
48/// # Examples
49///
50/// ```
51/// use aprender::prelude::*;
52///
53/// // Create data with 2 clear clusters
54/// let data = Matrix::from_vec(6, 2, vec![
55///     0.0, 0.0, 0.1, 0.1, 0.2, 0.0,  // Cluster 1
56///     10.0, 10.0, 10.1, 10.1, 10.0, 10.2,  // Cluster 2
57/// ]).unwrap();
58///
59/// let mut kmeans = KMeans::new(2).with_random_state(42);
60/// kmeans.fit(&data).unwrap();
61/// let labels = kmeans.predict(&data);
62/// assert_eq!(labels.len(), 6);
63/// ```
64pub trait UnsupervisedEstimator {
65    /// The type of labels/clusters produced.
66    type Labels;
67
68    /// Fits the model to data.
69    ///
70    /// # Errors
71    ///
72    /// Returns an error if fitting fails (empty data, invalid parameters, etc.).
73    fn fit(&mut self, x: &Matrix<f32>) -> Result<()>;
74
75    /// Predicts cluster assignments or transforms data.
76    fn predict(&self, x: &Matrix<f32>) -> Self::Labels;
77}
78
79/// Trait for data transformers (scalers, encoders, etc.).
80///
81/// This trait defines the interface for preprocessing transformers.
82/// Implementations include scalers, encoders, and feature transformers.
83///
84/// # Future Usage
85///
86/// ```text
87/// let mut scaler = StandardScaler::new();
88/// let x_scaled = scaler.fit_transform(&x)?;
89/// let x_test_scaled = scaler.transform(&x_test)?;
90/// ```
91pub trait Transformer {
92    /// Fits the transformer to data.
93    ///
94    /// # Errors
95    ///
96    /// Returns an error if fitting fails.
97    fn fit(&mut self, x: &Matrix<f32>) -> Result<()>;
98
99    /// Transforms data using fitted parameters.
100    ///
101    /// # Errors
102    ///
103    /// Returns an error if transformer is not fitted.
104    fn transform(&self, x: &Matrix<f32>) -> Result<Matrix<f32>>;
105
106    /// Fits and transforms in one step.
107    ///
108    /// # Errors
109    ///
110    /// Returns an error if fitting fails.
111    fn fit_transform(&mut self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
112        self.fit(x)?;
113        self.transform(x)
114    }
115}
116
#[cfg(test)]
mod tests {
    //! Intentionally empty: the traits in this module are exercised through
    //! the tests of the types that implement them.
}