use crate::{
api::{Predictor, SupervisedEstimator},
error::{Failed, FailedError},
linalg::basic::arrays::{Array1, Array2},
numbers::basenum::Number,
numbers::realnum::RealNumber,
};
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
#[derive(Debug)]
pub struct GridSearchCVParameters<
T: Number,
M: Array2<T>,
C: Clone,
I: Iterator<Item = C>,
E: Predictor<M, M::RowVector>,
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
K: BaseKFold,
S: Fn(&M::RowVector, &M::RowVector) -> T,
> {
_phantom: std::marker::PhantomData<(T, M)>,
parameters_search: I,
estimator: F,
score: S,
cv: K,
}
impl<
T: RealNumber,
M: Array2<T>,
C: Clone,
I: Iterator<Item = C>,
E: Predictor<M, M::RowVector>,
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
K: BaseKFold,
S: Fn(&M::RowVector, &M::RowVector) -> T,
> GridSearchCVParameters<T, M, C, I, E, F, K, S>
{
pub fn new(parameters_search: I, estimator: F, score: S, cv: K) -> Self {
GridSearchCVParameters {
_phantom: std::marker::PhantomData,
parameters_search,
estimator,
score,
cv,
}
}
}
#[derive(Debug)]
pub struct GridSearchCV<T: RealNumber, M: Array2<T>, C: Clone, E: Predictor<M, M::RowVector>> {
_phantom: std::marker::PhantomData<(T, M)>,
predictor: E,
pub cross_validation_result: CrossValidationResult<T>,
pub best_parameter: C,
}
impl<T: RealNumber, M: Array2<T>, E: Predictor<M, M::RowVector>, C: Clone>
GridSearchCV<T, M, C, E>
{
pub fn fit<
I: Iterator<Item = C>,
K: BaseKFold,
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
S: Fn(&M::RowVector, &M::RowVector) -> T,
>(
x: &M,
y: &M::RowVector,
gs_parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
) -> Result<Self, Failed> {
let mut best_result: Option<CrossValidationResult<T>> = None;
let mut best_parameters = None;
let parameters_search = gs_parameters.parameters_search;
let estimator = gs_parameters.estimator;
let cv = gs_parameters.cv;
let score = gs_parameters.score;
for parameters in parameters_search {
let result = cross_validate(&estimator, x, y, ¶meters, &cv, &score)?;
if best_result.is_none()
|| result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
{
best_parameters = Some(parameters);
best_result = Some(result);
}
}
if let (Some(best_parameter), Some(cross_validation_result)) =
(best_parameters, best_result)
{
let predictor = estimator(x, y, best_parameter.clone())?;
Ok(Self {
_phantom: gs_parameters._phantom,
predictor,
cross_validation_result,
best_parameter,
})
} else {
Err(Failed::because(
FailedError::FindFailed,
"there were no parameter sets found",
))
}
}
pub fn cv_results(&self) -> &CrossValidationResult<T> {
&self.cross_validation_result
}
pub fn best_parameters(&self) -> &C {
&self.best_parameter
}
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predictor.predict(x)
}
}
impl<
T: RealNumber,
M: Array2<T>,
C: Clone,
I: Iterator<Item = C>,
E: Predictor<M, M::RowVector>,
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
K: BaseKFold,
S: Fn(&M::RowVector, &M::RowVector) -> T,
> SupervisedEstimator<M, M::RowVector, GridSearchCVParameters<T, M, C, I, E, F, K, S>>
for GridSearchCV<T, M, C, E>
{
fn fit(
x: &M,
y: &M::RowVector,
parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
) -> Result<Self, Failed> {
GridSearchCV::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Array2<T>, C: Clone, E: Predictor<M, M::RowVector>>
Predictor<M, M::RowVector> for GridSearchCV<T, M, C, E>
{
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
}
#[cfg(test)]
mod tests {
use crate::{
linalg::naive::dense_matrix::DenseMatrix,
linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
metrics::accuracy,
model_selection::{
hyper_tuning::grid_search::{self, GridSearchCVParameters},
KFold,
},
};
use grid_search::GridSearchCV;
#[test]
fn test_grid_search() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let cv = KFold {
n_splits: 5,
..KFold::default()
};
let parameters = LogisticRegressionSearchParameters {
alpha: vec![0., 1.],
..Default::default()
};
let grid_search = GridSearchCV::fit(
&x,
&y,
GridSearchCVParameters {
estimator: LogisticRegression::fit,
score: accuracy,
cv,
parameters_search: parameters.into_iter(),
_phantom: Default::default(),
},
)
.unwrap();
let best_parameters = grid_search.best_parameters();
assert!([1.].contains(&best_parameters.alpha));
let cv_results = grid_search.cv_results();
assert_eq!(cv_results.mean_test_score(), 0.9);
let x = DenseMatrix::from_2d_array(&[&[5., 3., 1., 0.]]);
let result = grid_search.predict(&x).unwrap();
assert_eq!(result, vec![0.]);
}
}