use crate::utils::distance::{Distance, DistanceError, KNNRegressorDistance};
use smartcore::numbers::{floatnum::FloatNumber, realnum::RealNumber};
pub use smartcore::{algorithm::neighbour::KNNAlgorithmName, neighbors::KNNWeightFunction};
/// Configuration for k-nearest-neighbour models.
///
/// Holds the neighbour count, weighting function, search algorithm and distance
/// metric. `Clone` is derived so a base configuration can be copied and then
/// adjusted through the consuming `with_*` builder methods.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct KNNParameters {
    /// Number of neighbours; forwarded to smartcore's `with_k`.
    pub(crate) k: usize,
    /// Neighbour weighting function; forwarded to smartcore's `with_weight`.
    pub(crate) weight: KNNWeightFunction,
    /// Neighbour-search algorithm; forwarded to smartcore's `with_algorithm`.
    pub(crate) algorithm: KNNAlgorithmName,
    /// Distance metric; converted into a `KNNRegressorDistance` when smartcore
    /// parameters are built.
    pub(crate) distance: Distance,
}
impl KNNParameters {
    /// Consumes the parameters and returns them with the neighbour count set to `k`.
    #[must_use]
    pub const fn with_k(self, k: usize) -> Self {
        Self { k, ..self }
    }

    /// Consumes the parameters and returns them with the given weighting function.
    #[must_use]
    pub const fn with_weight(self, weight: KNNWeightFunction) -> Self {
        Self { weight, ..self }
    }

    /// Consumes the parameters and returns them with the given neighbour-search algorithm.
    #[must_use]
    pub const fn with_algorithm(self, algorithm: KNNAlgorithmName) -> Self {
        Self { algorithm, ..self }
    }

    /// Consumes the parameters and returns them with the given distance metric.
    #[must_use]
    pub const fn with_distance(self, distance: Distance) -> Self {
        Self { distance, ..self }
    }

    /// Builds smartcore `KNNClassifierParameters` from this configuration.
    ///
    /// The neighbour count, algorithm and weighting function are copied across, and
    /// the configured [`Distance`] is converted into a `KNNRegressorDistance`.
    ///
    /// # Errors
    ///
    /// Propagates the [`DistanceError`] returned by `KNNRegressorDistance::from` if
    /// converting the configured distance fails.
    pub fn to_classifier_params<INPUT: RealNumber + FloatNumber>(
        &self,
    ) -> Result<
        smartcore::neighbors::knn_classifier::KNNClassifierParameters<
            INPUT,
            KNNRegressorDistance<INPUT>,
        >,
        DistanceError,
    > {
        // Resolve the only fallible piece first; the builder calls below cannot fail.
        let metric: KNNRegressorDistance<INPUT> = KNNRegressorDistance::from(self.distance)?;
        let params = smartcore::neighbors::knn_classifier::KNNClassifierParameters::default()
            .with_k(self.k)
            .with_algorithm(self.algorithm.clone())
            .with_weight(self.weight.clone())
            .with_distance(metric);
        Ok(params)
    }

    /// Builds smartcore `KNNRegressorParameters` from this configuration.
    ///
    /// The neighbour count, algorithm and weighting function are copied across, and
    /// the configured [`Distance`] is converted into a `KNNRegressorDistance`.
    ///
    /// # Errors
    ///
    /// Propagates the [`DistanceError`] returned by `KNNRegressorDistance::from` if
    /// converting the configured distance fails.
    pub fn to_regressor_params<INPUT: RealNumber + FloatNumber>(
        &self,
    ) -> Result<
        smartcore::neighbors::knn_regressor::KNNRegressorParameters<
            INPUT,
            KNNRegressorDistance<INPUT>,
        >,
        DistanceError,
    > {
        // Resolve the only fallible piece first; the builder calls below cannot fail.
        let metric: KNNRegressorDistance<INPUT> = KNNRegressorDistance::from(self.distance)?;
        let params = smartcore::neighbors::knn_regressor::KNNRegressorParameters::default()
            .with_k(self.k)
            .with_algorithm(self.algorithm.clone())
            .with_weight(self.weight.clone())
            .with_distance(metric);
        Ok(params)
    }
}
impl Default for KNNParameters {
    /// Three neighbours, uniform weighting, cover-tree search and Euclidean distance.
    fn default() -> Self {
        let neighbours = 3;
        Self {
            distance: Distance::Euclidean,
            algorithm: KNNAlgorithmName::CoverTree,
            weight: KNNWeightFunction::Uniform,
            k: neighbours,
        }
    }
}