aprender/traits.rs
//! Core traits for ML estimators and transformers.
//!
//! These traits define the shared API contracts for all ML algorithms.
4
5use crate::error::Result;
6use crate::primitives::{Matrix, Vector};
7
8/// Primary trait for supervised learning estimators.
9///
10/// Estimators implement fit/predict/score following sklearn conventions.
11///
12/// # Examples
13///
14/// ```
15/// use aprender::prelude::*;
16///
17/// // Create training data: y = 2x + 1
18/// let x_train = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
19/// let y_train = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
20///
21/// // Test data
22/// let x_test = Matrix::from_vec(2, 1, vec![5.0, 6.0]).unwrap();
23/// let y_test = Vector::from_slice(&[11.0, 13.0]);
24///
25/// let mut model = LinearRegression::new();
26/// model.fit(&x_train, &y_train).unwrap();
27/// let predictions = model.predict(&x_test);
28/// let score = model.score(&x_test, &y_test);
29/// assert!(score > 0.99);
30/// ```
31pub trait Estimator {
32 /// Fits the model to training data.
33 ///
34 /// # Errors
35 ///
36 /// Returns an error if fitting fails (dimension mismatch, singular matrix, etc.).
37 fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()>;
38
39 /// Predicts target values for input data.
40 fn predict(&self, x: &Matrix<f32>) -> Vector<f32>;
41
42 /// Computes the score (R² for regression, accuracy for classification).
43 fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32;
44}
45
46/// Trait for unsupervised learning models.
47///
48/// # Examples
49///
50/// ```
51/// use aprender::prelude::*;
52///
53/// // Create data with 2 clear clusters
54/// let data = Matrix::from_vec(6, 2, vec![
55/// 0.0, 0.0, 0.1, 0.1, 0.2, 0.0, // Cluster 1
56/// 10.0, 10.0, 10.1, 10.1, 10.0, 10.2, // Cluster 2
57/// ]).unwrap();
58///
59/// let mut kmeans = KMeans::new(2).with_random_state(42);
60/// kmeans.fit(&data).unwrap();
61/// let labels = kmeans.predict(&data);
62/// assert_eq!(labels.len(), 6);
63/// ```
64pub trait UnsupervisedEstimator {
65 /// The type of labels/clusters produced.
66 type Labels;
67
68 /// Fits the model to data.
69 ///
70 /// # Errors
71 ///
72 /// Returns an error if fitting fails (empty data, invalid parameters, etc.).
73 fn fit(&mut self, x: &Matrix<f32>) -> Result<()>;
74
75 /// Predicts cluster assignments or transforms data.
76 fn predict(&self, x: &Matrix<f32>) -> Self::Labels;
77}
78
79/// Trait for data transformers (scalers, encoders, etc.).
80///
81/// This trait defines the interface for preprocessing transformers.
82/// Implementations include scalers, encoders, and feature transformers.
83///
84/// # Future Usage
85///
86/// ```text
87/// let mut scaler = StandardScaler::new();
88/// let x_scaled = scaler.fit_transform(&x)?;
89/// let x_test_scaled = scaler.transform(&x_test)?;
90/// ```
91pub trait Transformer {
92 /// Fits the transformer to data.
93 ///
94 /// # Errors
95 ///
96 /// Returns an error if fitting fails.
97 fn fit(&mut self, x: &Matrix<f32>) -> Result<()>;
98
99 /// Transforms data using fitted parameters.
100 ///
101 /// # Errors
102 ///
103 /// Returns an error if transformer is not fitted.
104 fn transform(&self, x: &Matrix<f32>) -> Result<Matrix<f32>>;
105
106 /// Fits and transforms in one step.
107 ///
108 /// # Errors
109 ///
110 /// Returns an error if fitting fails.
111 fn fit_transform(&mut self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
112 self.fit(x)?;
113 self.transform(x)
114 }
115}
116
#[cfg(test)]
mod tests {
    //! Concrete estimators are tested alongside their implementations; this
    //! module covers the only logic defined here: the provided
    //! `Transformer::fit_transform` default method.

    use super::*;
    use std::cell::Cell;

    /// Minimal transformer that records how often each trait method runs.
    #[derive(Default)]
    struct RecordingTransformer {
        fit_calls: usize,
        // `transform` takes `&self`, so its counter needs interior mutability.
        transform_calls: Cell<usize>,
    }

    impl Transformer for RecordingTransformer {
        fn fit(&mut self, _x: &Matrix<f32>) -> Result<()> {
            self.fit_calls += 1;
            Ok(())
        }

        fn transform(&self, _x: &Matrix<f32>) -> Result<Matrix<f32>> {
            // The default `fit_transform` must fit before it transforms.
            assert!(self.fit_calls > 0, "transform called before fit");
            self.transform_calls.set(self.transform_calls.get() + 1);
            Ok(Matrix::from_vec(1, 1, vec![0.0]).expect("1x1 matrix is valid"))
        }
    }

    #[test]
    fn fit_transform_default_fits_then_transforms_once() {
        let x = Matrix::from_vec(2, 1, vec![1.0, 2.0]).expect("2x1 matrix is valid");
        let mut transformer = RecordingTransformer::default();

        assert!(transformer.fit_transform(&x).is_ok());
        assert_eq!(transformer.fit_calls, 1);
        assert_eq!(transformer.transform_calls.get(), 1);
    }
}
120}