pub mod loss {
pub use aprender::loss::{
dice_loss, focal_loss, hinge_loss, huber_loss, info_nce_loss, kl_divergence, mae_loss,
mse_loss, squared_hinge_loss, triplet_loss, wasserstein_discriminator_loss,
wasserstein_generator_loss, wasserstein_loss,
};
pub use aprender::loss::{
CTCLoss, DiceLoss, FocalLoss, HingeLoss, HuberLoss, InfoNCELoss, Loss, MAELoss, MSELoss,
TripletLoss, WassersteinLoss,
};
}
pub mod metrics {
pub use aprender::metrics::{mae, mse, r_squared, rmse};
pub use aprender::metrics::classification;
pub use aprender::metrics::ranking;
}
pub mod pruning {
pub use aprender::pruning::{
generate_block_mask, generate_column_mask, generate_nm_mask, generate_row_mask,
generate_unstructured_mask, sparsify, Importance, MagnitudeImportance, MagnitudePruner,
Pruner, PruningError, PruningResult, SparseGPTImportance, SparseTensor, SparsityMask,
SparsityPattern, WandaImportance, WandaPruner,
};
}
pub mod primitives {
pub use aprender::primitives::{Matrix, Vector};
}
/// A fit/predict trait plus plain data structs named after common estimators.
///
/// Every struct here derives `Debug`, `Clone`, `Default` and `PartialEq`, so values can be
/// duplicated and compared. With the derived `Default`, numeric fields start at zero and
/// `Vec`/`String` fields start empty. No `Estimator` implementations are visible in this
/// section of the file.
pub mod estimators {
    /// Common fit/predict interface over `f32` data.
    pub trait Estimator {
        /// Fits the estimator in place from inputs `x` and targets `y`.
        ///
        /// Presumably `x` holds one feature row per sample and `y` one target per row —
        /// confirm against implementors.
        fn fit(&mut self, x: &[Vec<f32>], y: &[f32]);

        /// Returns predictions for the inputs `x`.
        ///
        /// Presumably one output value per row of `x` — confirm against implementors.
        fn predict(&self, x: &[Vec<f32>]) -> Vec<f32>;
    }

    /// Data holder for a linear regression model.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct LinearRegression {
        /// Model weights; empty by default.
        pub weights: Vec<f32>,
    }

    /// Data holder for a logistic regression model.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct LogisticRegression {
        /// Model weights; empty by default.
        pub weights: Vec<f32>,
    }

    /// Data holder for a ridge regression model.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct Ridge {
        /// Presumably the regularization strength — confirm. Defaults to `0.0`.
        pub alpha: f32,
        /// Model weights; empty by default.
        pub weights: Vec<f32>,
    }

    /// Data holder for a lasso regression model.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct Lasso {
        /// Presumably the regularization strength — confirm. Defaults to `0.0`.
        pub alpha: f32,
        /// Model weights; empty by default.
        pub weights: Vec<f32>,
    }

    /// Configuration for a decision tree.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct DecisionTree {
        /// Maximum tree depth. Defaults to `0`.
        pub max_depth: usize,
    }

    /// Configuration for a random forest.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct RandomForest {
        /// Number of trees. Defaults to `0`.
        pub n_trees: usize,
    }

    /// Configuration for gradient boosting.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct GradientBoosting {
        /// Number of estimators. Defaults to `0`.
        pub n_estimators: usize,
        /// Learning rate. Defaults to `0.0`.
        pub learning_rate: f32,
    }

    /// Configuration for a support vector machine.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct SVM {
        /// Kernel name as free-form text; empty by default.
        /// NOTE(review): the accepted values are not visible here.
        pub kernel: String,
    }

    /// Configuration for a k-nearest-neighbours estimator.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct KNeighbors {
        /// Presumably the number of neighbours — confirm. Defaults to `0`.
        pub k: usize,
    }

    /// Field-less marker type for a naive Bayes estimator.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct NaiveBayes;

    /// Configuration for k-means clustering.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct KMeans {
        /// Presumably the number of clusters — confirm. Defaults to `0`.
        pub k: usize,
    }

    /// Configuration for DBSCAN clustering.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct DBSCAN {
        /// Presumably the neighbourhood radius — confirm. Defaults to `0.0`.
        pub eps: f32,
        /// Presumably the minimum points needed to form a dense region — confirm.
        /// Defaults to `0`.
        pub min_points: usize,
    }

    /// Configuration for principal component analysis.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct PCA {
        /// Number of components. Defaults to `0`.
        pub n_components: usize,
    }

    /// Data holder for a standard scaler.
    #[derive(Debug, Clone, Default, PartialEq)]
    pub struct StandardScaler {
        /// Presumably per-feature means — confirm. Empty by default.
        pub mean: Vec<f32>,
        /// Presumably per-feature standard deviations — confirm. Empty by default.
        pub std: Vec<f32>,
    }
}