/// Fits a model or transformer to training input of type `Input`.
///
/// The implementing type carries the hyperparameters; calling [`Estimator::fit`]
/// on it produces the fitted [`Estimator::Estimator`].
pub trait Estimator<Input> {
    /// The fitted model or transformer produced by [`Estimator::fit`].
    type Estimator;

    /// Fits an estimator to `input`.
    ///
    /// Returns `Some(fitted)` on success and `None` when fitting is not possible.
    /// The conditions that yield `None` are defined by each implementation.
    fn fit(&self, input: &Input) -> Option<Self::Estimator>;
}
Trait for fitting classification models, regression models, and transformers.
The type implementing this trait holds and validates the hyperparameters needed to fit the estimator; calling `fit` on it with the training input produces the fitted estimator. For example, a classification model may take as input a tuple of features and labels:
use ndarray::{Array1, Array2};
use rs_ml::Estimator;
/// Hyperparameters required to fit a [`Model`].
struct ModelParameters {
    /// Learning rate hyperparameter used during fitting.
    learning_rate: f64,
}
struct Model {
// Internal state of model required to predict features
means: Array2<f64>
};
impl Estimator<(Array2<f64>, Array1<String>)> for ModelParameters {
    type Estimator = Model;

    /// Fits a [`Model`] to a `(features, labels)` pair.
    ///
    /// This skeleton always succeeds and returns a placeholder `1 x 1` matrix of
    /// zeros as `means`. A real implementation would presumably derive `means`
    /// from the inputs.
    fn fit(&self, input: &(Array2<f64>, Array1<String>)) -> Option<Self::Estimator> {
        // Underscore prefixes keep the destructuring for illustration without
        // triggering `unused_variables` warnings in this placeholder body.
        let (_features, _labels) = input;
        // logic to fit the model
        Some(Model {
            means: Array2::zeros((1, 1)),
        })
    }
}