Skip to main content

entrenar/
aprender_compat.rs

//! Aprender Compatibility Layer
//!
//! Re-exports from the `aprender` crate for users who need direct access to
//! aprender's ML primitives without adding a separate dependency.
//!
//! ## Architecture Boundary
//!
//! **Entrenar** owns training orchestration (autograd, optimizers, LoRA, training loop).
//! **Aprender** owns ML primitives (loss functions, metrics, pruning algorithms, HF Hub client).
//!
//! Entrenar delegates to aprender internally (e.g., regression metrics) and re-exports
//! aprender's APIs here for convenience.
//!
//! ## Loss Functions
//!
//! Aprender provides standalone loss functions that operate on `Vector<f32>`.
//! For training with autograd backward passes, use entrenar's `train::LossFn` trait instead.
//!
//! ```rust,no_run
//! use entrenar::aprender_compat::loss;
//! use entrenar::aprender_compat::primitives::Vector;
//!
//! let y_pred = Vector::from_slice(&[0.9, 0.1, 0.8]);
//! let y_true = Vector::from_slice(&[1.0, 0.0, 1.0]);
//! let error = loss::mse_loss(&y_pred, &y_true);
//! ```
//!
//! ## Metrics
//!
//! Aprender provides standalone metric functions. Entrenar's `train::Metric` trait
//! wraps these for integration with the training loop.
//!
//! ```rust,no_run
//! use entrenar::aprender_compat::metrics;
//! use entrenar::aprender_compat::primitives::Vector;
//!
//! let y_pred = Vector::from_slice(&[1.1, 2.0, 3.2]);
//! let y_true = Vector::from_slice(&[1.0, 2.0, 3.0]);
//! let r2 = metrics::r_squared(&y_pred, &y_true);
//! ```
//!
//! ## Pruning
//!
//! Aprender provides low-level pruning algorithms (magnitude, WANDA, SparseGPT).
//! Entrenar's `prune` module wraps these with training-loop integration.
46
47/// Re-export aprender's loss functions
48pub mod loss {
49    pub use aprender::loss::{
50        dice_loss, focal_loss, hinge_loss, huber_loss, info_nce_loss, kl_divergence, mae_loss,
51        mse_loss, squared_hinge_loss, triplet_loss, wasserstein_discriminator_loss,
52        wasserstein_generator_loss, wasserstein_loss,
53    };
54
55    // Loss trait and struct implementations
56    pub use aprender::loss::{
57        CTCLoss, DiceLoss, FocalLoss, HingeLoss, HuberLoss, InfoNCELoss, Loss, MAELoss, MSELoss,
58        TripletLoss, WassersteinLoss,
59    };
60}
61
62/// Re-export aprender's metrics
63pub mod metrics {
64    pub use aprender::metrics::{mae, mse, r_squared, rmse};
65
66    // Classification metrics
67    pub use aprender::metrics::classification;
68
69    // Ranking metrics
70    pub use aprender::metrics::ranking;
71}
72
73/// Re-export aprender's pruning primitives
74pub mod pruning {
75    pub use aprender::pruning::{
76        generate_block_mask, generate_column_mask, generate_nm_mask, generate_row_mask,
77        generate_unstructured_mask, sparsify, Importance, MagnitudeImportance, MagnitudePruner,
78        Pruner, PruningError, PruningResult, SparseGPTImportance, SparseTensor, SparsityMask,
79        SparsityPattern, WandaImportance, WandaPruner,
80    };
81}
82
83/// Re-export aprender's primitive types
84pub mod primitives {
85    pub use aprender::primitives::{Matrix, Vector};
86}
87
/// sklearn estimator coverage via aprender's ML algorithms (CP-05)
///
/// Aprender provides Rust implementations of common sklearn estimators:
///
/// ## Supervised Learning
/// - `LinearRegression` — Ordinary Least Squares linear regression
/// - `LogisticRegression` — Logistic regression classifier
/// - `Ridge` — Ridge regression (L2 regularization)
/// - `Lasso` — Lasso regression (L1 regularization)
/// - `DecisionTree` — Decision tree classifier/regressor
/// - `RandomForest` — Random forest ensemble
/// - `GradientBoosting` — Gradient boosting ensemble
/// - `SVM` — Support Vector Machine classifier
/// - `KNeighbors` — k-Nearest Neighbors classifier
/// - `NaiveBayes` — Naive Bayes classifier
///
/// ## Unsupervised Learning
/// - `KMeans` — K-Means clustering
/// - `DBSCAN` — Density-based spatial clustering
/// - `PCA` — Principal Component Analysis
///
/// ## Preprocessing
/// - `StandardScaler` — Feature standardization
///
/// ## Note
///
/// The types declared in this module are lightweight stubs that only carry
/// public configuration/state fields. This module itself defines no
/// [`Estimator`](estimators::Estimator) implementations for them.
pub mod estimators {
    // sklearn-compatible estimator type stubs (CP-05).
    //
    // Every stub derives `Debug`, `Clone`, `Default` and `PartialEq` so that
    // configurations can be printed, duplicated, default-constructed and
    // compared uniformly. (`Eq` is not derived because several stubs hold
    // `f32` fields.)

    /// Supervised estimator trait (sklearn-like fit/predict API)
    pub trait Estimator {
        /// Fits the estimator in place from inputs `x` and targets `y`.
        ///
        /// NOTE(review): presumably each inner `Vec` of `x` is one sample's
        /// feature row and `y[i]` is the target for `x[i]` — confirm with
        /// implementors. The signature has no error channel.
        fn fit(&mut self, x: &[Vec<f32>], y: &[f32]);

        /// Produces predictions for the inputs `x` without mutating `self`.
        ///
        /// NOTE(review): presumably returns one value per entry of `x` —
        /// confirm with implementors.
        fn predict(&self, x: &[Vec<f32>]) -> Vec<f32>;
    }

    /// LinearRegression: OLS linear regression
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct LinearRegression {
        /// Model weights; empty by default.
        pub weights: Vec<f32>,
    }

    /// LogisticRegression: logistic classifier
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct LogisticRegression {
        /// Model weights; empty by default.
        pub weights: Vec<f32>,
    }

    /// Ridge regression (L2 regularization)
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct Ridge {
        /// Regularization coefficient (presumably mirrors sklearn's `alpha`); `0.0` by default.
        pub alpha: f32,
        /// Model weights; empty by default.
        pub weights: Vec<f32>,
    }

    /// Lasso regression (L1 regularization)
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct Lasso {
        /// Regularization coefficient (presumably mirrors sklearn's `alpha`); `0.0` by default.
        pub alpha: f32,
        /// Model weights; empty by default.
        pub weights: Vec<f32>,
    }

    /// DecisionTree classifier/regressor
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct DecisionTree {
        /// Depth setting; `0` by default.
        pub max_depth: usize,
    }

    /// RandomForest ensemble
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct RandomForest {
        /// Tree-count setting; `0` by default.
        pub n_trees: usize,
    }

    /// GradientBoosting ensemble
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct GradientBoosting {
        /// Estimator-count setting; `0` by default.
        pub n_estimators: usize,
        /// Learning-rate setting; `0.0` by default.
        pub learning_rate: f32,
    }

    /// SVM: Support Vector Machine
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct SVM {
        /// Kernel name as free-form text; empty by default.
        pub kernel: String,
    }

    /// KNeighbors: k-Nearest Neighbors
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct KNeighbors {
        /// Neighbor-count setting; `0` by default.
        pub k: usize,
    }

    /// NaiveBayes classifier
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct NaiveBayes;

    /// KMeans clustering
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct KMeans {
        /// Cluster-count setting; `0` by default.
        pub k: usize,
    }

    /// DBSCAN density-based clustering
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct DBSCAN {
        /// Neighborhood-radius setting (presumably mirrors sklearn's `eps`); `0.0` by default.
        pub eps: f32,
        /// Minimum-points setting; `0` by default.
        pub min_points: usize,
    }

    /// PCA: Principal Component Analysis
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct PCA {
        /// Component-count setting; `0` by default.
        pub n_components: usize,
    }

    /// StandardScaler: feature standardization
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct StandardScaler {
        /// Stored means (presumably one per feature); empty by default.
        pub mean: Vec<f32>,
        /// Stored standard deviations (presumably one per feature); empty by default.
        pub std: Vec<f32>,
    }
}