Skip to main content

anofox_ml_core/
lib.rs

1//! Core traits and types for the anofox-ml machine learning library.
2//!
3//! This crate defines the foundational traits that all anofox-ml estimators and
4//! transformers implement. It uses a **type-state pattern**: calling [`Fit::fit`]
5//! or [`FitUnsupervised::fit`] on an unfitted configuration struct returns, when fitting
6//! succeeds, a distinct *fitted* type that implements [`Predict`] or [`Transform`]. This
7//! makes it a compile-time error to call `predict` on a model that has not been fitted.
8//!
9//! The crate also provides the [`Float`] trait (a unified bound for `f32`/`f64`), error types,
10//! train/test splitting and cross-validation utilities (e.g. [`cross_val_score`]), hyperparameter
11//! search (e.g. [`grid_search_cv`]), and a [`Pipeline`] for chaining transformers with an estimator.
12//!
13//! # Examples
14//!
15//! ```
16//! use anofox_ml_core::{Fit, Predict, FitUnsupervised, Transform, Float};
17//!
18//! // The type-state pattern in action:
19//! // 1. `Fit` takes an unfitted config and returns a `Fitted` type.
20//! // 2. Only the `Fitted` type implements `Predict`.
21//! fn example_trait_bounds<F, M, FM>(model: &M, x: &ndarray::Array2<F>, y: &ndarray::Array1<F>)
22//! where
23//!     F: Float,
24//!     M: Fit<F, Fitted = FM>,
25//!     FM: Predict<F>,
26//! {
27//!     let fitted = model.fit(x, y).unwrap();
28//!     let _predictions = fitted.predict(x).unwrap();
29//! }
30//! ```
31
32pub mod column_transformer;
33pub mod error;
34pub mod feature_union;
35pub mod float;
36pub mod function_transformer;
37pub mod halving;
38pub mod inspection;
39pub mod multi_output;
40pub mod persistence;
41pub mod pipeline;
42pub mod sparse;
43pub mod traits;
44pub mod utils;
45
46pub use column_transformer::{ColumnSelector, ColumnTransformer, Remainder};
47pub use error::{Result, RustMlError};
48pub use feature_union::{FeatureUnion, FittedFeatureUnion};
49pub use float::Float;
50pub use function_transformer::FunctionTransformer;
51pub use halving::{halving_grid_search_cv, halving_random_search_cv, HalvingResult};
52pub use inspection::{permutation_importance, PermutationImportance};
53pub use multi_output::{
54    ClassifierChain, FittedClassifierChain, FittedMultiOutputClassifier,
55    FittedMultiOutputRegressor, FittedRegressorChain, MultiOutputClassifier, MultiOutputRegressor,
56    RegressorChain,
57};
58pub use pipeline::{
59    FitPredict, FitTransform, FittedPipeline, Pipeline, PredictStep, TransformStep,
60};
61pub use sparse::CsrMatrix;
62pub use traits::{
63    ClassifierScore, DecisionFunction, Fit, FitUnsupervised, FitUnsupervisedWeighted, FitWeighted,
64    InverseTransform, PartialFit, Predict, PredictLogProba, PredictProba, RegressorScore,
65    Transform,
66};
67pub use utils::{
68    cross_val_predict, cross_val_score, cross_val_score_stratified, cross_validate, grid_search_cv,
69    group_k_fold, k_fold, learning_curve, leave_one_out, leave_p_out, randomized_search_cv,
70    repeated_k_fold, repeated_stratified_k_fold, shuffle_split, stratified_k_fold,
71    stratified_shuffle_split, time_series_split, train_test_split, validation_curve,
72    CrossValidateResult, GridSearchResult,
73};