entrenar/aprender_compat.rs
1//! Aprender Compatibility Layer
2//!
3//! Re-exports from the `aprender` crate for users who need direct access to
4//! aprender's ML primitives without adding a separate dependency.
5//!
6//! ## Architecture Boundary
7//!
8//! **Entrenar** owns training orchestration (autograd, optimizers, LoRA, training loop).
9//! **Aprender** owns ML primitives (loss functions, metrics, pruning algorithms, HF Hub client).
10//!
11//! Entrenar delegates to aprender internally (e.g., regression metrics) and re-exports
12//! aprender's APIs here for convenience.
13//!
14//! ## Loss Functions
15//!
16//! Aprender provides standalone loss functions that operate on `Vector<f32>`.
17//! For training with autograd backward passes, use entrenar's `train::LossFn` trait instead.
18//!
19//! ```rust,no_run
20//! use entrenar::aprender_compat::loss;
21//! use entrenar::aprender_compat::primitives::Vector;
22//!
23//! let y_pred = Vector::from_slice(&[0.9, 0.1, 0.8]);
24//! let y_true = Vector::from_slice(&[1.0, 0.0, 1.0]);
25//! let error = loss::mse_loss(&y_pred, &y_true);
26//! ```
27//!
28//! ## Metrics
29//!
30//! Aprender provides standalone metric functions. Entrenar's `train::Metric` trait
31//! wraps these for integration with the training loop.
32//!
33//! ```rust,no_run
34//! use entrenar::aprender_compat::metrics;
35//! use entrenar::aprender_compat::primitives::Vector;
36//!
37//! let y_pred = Vector::from_slice(&[1.1, 2.0, 3.2]);
38//! let y_true = Vector::from_slice(&[1.0, 2.0, 3.0]);
39//! let r2 = metrics::r_squared(&y_pred, &y_true);
40//! ```
41//!
42//! ## Pruning
43//!
44//! Aprender provides low-level pruning algorithms (magnitude, WANDA, SparseGPT).
45//! Entrenar's `prune` module wraps these with training-loop integration.
46
47/// Re-export aprender's loss functions
48pub mod loss {
49 pub use aprender::loss::{
50 dice_loss, focal_loss, hinge_loss, huber_loss, info_nce_loss, kl_divergence, mae_loss,
51 mse_loss, squared_hinge_loss, triplet_loss, wasserstein_discriminator_loss,
52 wasserstein_generator_loss, wasserstein_loss,
53 };
54
55 // Loss trait and struct implementations
56 pub use aprender::loss::{
57 CTCLoss, DiceLoss, FocalLoss, HingeLoss, HuberLoss, InfoNCELoss, Loss, MAELoss, MSELoss,
58 TripletLoss, WassersteinLoss,
59 };
60}
61
62/// Re-export aprender's metrics
63pub mod metrics {
64 pub use aprender::metrics::{mae, mse, r_squared, rmse};
65
66 // Classification metrics
67 pub use aprender::metrics::classification;
68
69 // Ranking metrics
70 pub use aprender::metrics::ranking;
71}
72
73/// Re-export aprender's pruning primitives
74pub mod pruning {
75 pub use aprender::pruning::{
76 generate_block_mask, generate_column_mask, generate_nm_mask, generate_row_mask,
77 generate_unstructured_mask, sparsify, Importance, MagnitudeImportance, MagnitudePruner,
78 Pruner, PruningError, PruningResult, SparseGPTImportance, SparseTensor, SparsityMask,
79 SparsityPattern, WandaImportance, WandaPruner,
80 };
81}
82
83/// Re-export aprender's primitive types
84pub mod primitives {
85 pub use aprender::primitives::{Matrix, Vector};
86}
87
/// sklearn-style estimator type stubs (CP-05).
///
/// This module declares sklearn-named parameter structs plus a minimal
/// fit/predict trait, `Estimator`. It contains **no algorithm implementations**:
/// this module itself provides no implementation of `Estimator`, and none of the
/// structs below call into aprender. They reserve the sklearn-compatible names and
/// hyperparameter fields; wiring each one to the corresponding aprender algorithm
/// is still outstanding (TODO confirm the intended aprender types).
///
/// ## Supervised Learning
/// - `LinearRegression` — ordinary least squares linear regression
/// - `LogisticRegression` — logistic regression classifier
/// - `Ridge` — ridge regression (L2 regularization)
/// - `Lasso` — lasso regression (L1 regularization)
/// - `DecisionTree` — decision tree classifier/regressor
/// - `RandomForest` — random forest ensemble
/// - `GradientBoosting` — gradient boosting ensemble
/// - `SVM` — support vector machine classifier
/// - `KNeighbors` — k-nearest neighbors classifier
/// - `NaiveBayes` — naive Bayes classifier
///
/// ## Unsupervised Learning
/// - `KMeans` — k-means clustering
/// - `DBSCAN` — density-based spatial clustering
/// - `PCA` — principal component analysis
///
/// ## Preprocessing
/// - `StandardScaler` — feature standardization
///
/// ## Defaults
///
/// Every type derives `Default`, which zero-fills all fields (`0`, `0.0`, empty
/// `String`/`Vec`). These are **not** sklearn's defaults. For example,
/// `KNeighbors::default().k == 0` and `Ridge::default().alpha == 0.0`. Set
/// hyperparameters explicitly, e.g. `KNeighbors { k: 5 }`.
pub mod estimators {
    /// Minimal supervised-learning interface mirroring sklearn's `fit`/`predict`.
    ///
    /// This module provides no implementations of it yet.
    pub trait Estimator {
        /// Fits the model to training data.
        ///
        /// `x` holds one `Vec<f32>` per sample and `y` the targets. Presumably
        /// `x.len() == y.len()` (one target per sample); implementations should
        /// validate this rather than assume it.
        fn fit(&mut self, x: &[Vec<f32>], y: &[f32]);

        /// Predicts targets for the rows of `x`.
        ///
        /// Presumably returns one value per row of `x`; confirm per implementation.
        fn predict(&self, x: &[Vec<f32>]) -> Vec<f32>;
    }

    /// Ordinary least squares linear regression (stub).
    #[derive(Debug, Clone, Default)]
    pub struct LinearRegression {
        /// Model coefficients; not populated by anything in this module.
        pub weights: Vec<f32>,
    }

    /// Logistic regression classifier (stub).
    #[derive(Debug, Clone, Default)]
    pub struct LogisticRegression {
        /// Model coefficients; not populated by anything in this module.
        pub weights: Vec<f32>,
    }

    /// Ridge regression with L2 regularization (stub).
    #[derive(Debug, Clone, Default)]
    pub struct Ridge {
        /// L2 regularization strength. `Default` gives `0.0` (no regularization).
        pub alpha: f32,
        /// Model coefficients; not populated by anything in this module.
        pub weights: Vec<f32>,
    }

    /// Lasso regression with L1 regularization (stub).
    #[derive(Debug, Clone, Default)]
    pub struct Lasso {
        /// L1 regularization strength. `Default` gives `0.0` (no regularization).
        pub alpha: f32,
        /// Model coefficients; not populated by anything in this module.
        pub weights: Vec<f32>,
    }

    /// Decision tree classifier/regressor (stub).
    #[derive(Debug, Clone, Default)]
    pub struct DecisionTree {
        /// Maximum tree depth. `Default` gives `0`.
        pub max_depth: usize,
    }

    /// Random forest ensemble (stub).
    #[derive(Debug, Clone, Default)]
    pub struct RandomForest {
        /// Number of trees in the ensemble. `Default` gives `0`.
        pub n_trees: usize,
    }

    /// Gradient boosting ensemble (stub).
    #[derive(Debug, Clone, Default)]
    pub struct GradientBoosting {
        /// Number of boosting stages. `Default` gives `0`.
        pub n_estimators: usize,
        /// Shrinkage applied per stage. `Default` gives `0.0`.
        pub learning_rate: f32,
    }

    /// Support vector machine classifier (stub).
    // Name mirrors sklearn; renaming to `Svm` would break the public API.
    #[allow(clippy::upper_case_acronyms)]
    #[derive(Debug, Clone, Default)]
    pub struct SVM {
        /// Kernel name. Accepted values are not defined here; `Default` gives `""`.
        pub kernel: String,
    }

    /// k-nearest neighbors classifier (stub).
    #[derive(Debug, Clone, Default)]
    pub struct KNeighbors {
        /// Number of neighbors consulted. `Default` gives `0`, which is not a usable value.
        pub k: usize,
    }

    /// Naive Bayes classifier (stub; carries no parameters).
    #[derive(Debug, Clone, Default)]
    pub struct NaiveBayes;

    /// k-means clustering (stub).
    #[derive(Debug, Clone, Default)]
    pub struct KMeans {
        /// Number of clusters. `Default` gives `0`, which is not a usable value.
        pub k: usize,
    }

    /// Density-based spatial clustering (stub).
    // Name mirrors sklearn; renaming to `Dbscan` would break the public API.
    #[allow(clippy::upper_case_acronyms)]
    #[derive(Debug, Clone, Default)]
    pub struct DBSCAN {
        /// Neighborhood radius. `Default` gives `0.0`.
        pub eps: f32,
        /// Minimum points to form a dense region. `Default` gives `0`.
        pub min_points: usize,
    }

    /// Principal component analysis (stub).
    // Name mirrors sklearn; renaming to `Pca` would break the public API.
    #[allow(clippy::upper_case_acronyms)]
    #[derive(Debug, Clone, Default)]
    pub struct PCA {
        /// Number of components to keep. `Default` gives `0`.
        pub n_components: usize,
    }

    /// Feature standardization (stub).
    #[derive(Debug, Clone, Default)]
    pub struct StandardScaler {
        /// Presumably per-feature means; not populated by anything in this module.
        pub mean: Vec<f32>,
        /// Presumably per-feature standard deviations; not populated by anything in this module.
        pub std: Vec<f32>,
    }
}
214}