use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::Euclidian;
use crate::metrics::distance::{Distance, Distances};
use crate::neighbors::KNNWeightFunction;
use crate::numbers::basenum::Number;
/// Hyperparameters for [`KNNRegressor`].
///
/// Build with `KNNRegressorParameters::default()` and customize via the
/// `with_*` builder methods.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KNNRegressorParameters<T: Number, D: Distance<Vec<T>>> {
    /// Distance metric used to compare sample rows (set via `with_distance`;
    /// private because changing it changes the type parameter `D`).
    #[cfg_attr(feature = "serde", serde(default))]
    distance: D,
    /// Nearest-neighbour search algorithm used to index the training rows.
    #[cfg_attr(feature = "serde", serde(default))]
    pub algorithm: KNNAlgorithmName,
    /// How the k neighbours' target values are weighted when averaged.
    #[cfg_attr(feature = "serde", serde(default))]
    pub weight: KNNWeightFunction,
    /// Number of neighbours to consult per prediction (the `Default` impl uses 3).
    #[cfg_attr(feature = "serde", serde(default))]
    pub k: usize,
    /// Zero-sized marker tying the parameter set to element type `T`.
    #[cfg_attr(feature = "serde", serde(default))]
    t: PhantomData<T>,
}
/// K-nearest-neighbors regressor, produced by [`KNNRegressor::fit`].
///
/// Every field is an `Option` so that an empty, unfitted instance can be
/// created by `SupervisedEstimator::new`; the accessors panic when a field
/// is still unset.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct KNNRegressor<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
{
    /// Training targets; a (possibly distance-weighted) average of the
    /// neighbours' entries is returned at prediction time.
    y: Option<Y>,
    /// Fitted nearest-neighbour search structure over the training rows.
    knn_algorithm: Option<KNNAlgorithm<TX, D>>,
    /// Weighting scheme applied to the k neighbours' targets.
    weight: Option<KNNWeightFunction>,
    /// Number of neighbours consulted per prediction.
    k: Option<usize>,
    _phantom_tx: PhantomData<TX>,
    _phantom_ty: PhantomData<TY>,
    _phantom_x: PhantomData<X>,
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
    KNNRegressor<TX, TY, X, Y, D>
{
    /// Training targets.
    ///
    /// Panics with a descriptive message if called before `fit`
    /// (previously a bare `unwrap()`, inconsistent with the other accessors).
    fn y(&self) -> &Y {
        self.y.as_ref().expect("Missing parameter: y")
    }

    /// Fitted nearest-neighbour search structure. Panics if called before `fit`.
    fn knn_algorithm(&self) -> &KNNAlgorithm<TX, D> {
        self.knn_algorithm
            .as_ref()
            .expect("Missing parameter: KNNAlgorithm")
    }

    /// Neighbour weighting scheme. Panics if called before `fit`.
    fn weight(&self) -> &KNNWeightFunction {
        self.weight.as_ref().expect("Missing parameter: weight")
    }

    /// Number of neighbours. Panics if called before `fit`.
    #[allow(dead_code)]
    fn k(&self) -> usize {
        self.k.expect("Missing parameter: k")
    }
}
impl<T: Number, D: Distance<Vec<T>>> KNNRegressorParameters<T, D> {
    /// Sets the number of neighbours consulted per prediction.
    pub fn with_k(self, k: usize) -> Self {
        Self { k, ..self }
    }

    /// Replaces the distance metric, changing the distance type parameter
    /// from `D` to `DD` (hence a brand-new parameter struct is returned).
    pub fn with_distance<DD: Distance<Vec<T>>>(
        self,
        distance: DD,
    ) -> KNNRegressorParameters<T, DD> {
        let KNNRegressorParameters {
            algorithm, weight, k, ..
        } = self;
        KNNRegressorParameters {
            distance,
            algorithm,
            weight,
            k,
            t: PhantomData,
        }
    }

    /// Sets the nearest-neighbour search algorithm.
    pub fn with_algorithm(self, algorithm: KNNAlgorithmName) -> Self {
        Self { algorithm, ..self }
    }

    /// Sets the scheme used to weight neighbour targets.
    pub fn with_weight(self, weight: KNNWeightFunction) -> Self {
        Self { weight, ..self }
    }
}
impl<T: Number> Default for KNNRegressorParameters<T, Euclidian<T>> {
    /// Euclidean distance, default search algorithm and weighting, and k = 3.
    fn default() -> Self {
        Self {
            k: 3,
            distance: Distances::euclidian(),
            algorithm: Default::default(),
            weight: Default::default(),
            t: PhantomData,
        }
    }
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> PartialEq
    for KNNRegressor<TX, TY, X, Y, D>
{
    /// Two regressors are equal when they share the same `k` and identical
    /// training targets.
    ///
    /// Compares the `y` `Option`s directly instead of going through the
    /// panicking `y()` accessor, so comparing unfitted models (e.g. two
    /// freshly `new()`-constructed instances) returns a result instead of
    /// panicking.
    fn eq(&self, other: &Self) -> bool {
        if self.k != other.k {
            return false;
        }
        match (self.y.as_ref(), other.y.as_ref()) {
            (Some(a), Some(b)) => {
                a.shape() == b.shape() && (0..a.shape()).all(|i| a.get(i) == b.get(i))
            }
            (None, None) => true,
            _ => false,
        }
    }
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
    SupervisedEstimator<X, Y, KNNRegressorParameters<TX, D>> for KNNRegressor<TX, TY, X, Y, D>
{
    /// Creates an empty, unfitted estimator; every field starts unset.
    fn new() -> Self {
        Self {
            y: None,
            knn_algorithm: None,
            weight: None,
            k: None,
            _phantom_tx: PhantomData,
            _phantom_ty: PhantomData,
            _phantom_x: PhantomData,
        }
    }

    /// Delegates to the inherent [`KNNRegressor::fit`].
    fn fit(x: &X, y: &Y, parameters: KNNRegressorParameters<TX, D>) -> Result<Self, Failed> {
        KNNRegressor::fit(x, y, parameters)
    }
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> Predictor<X, Y>
for KNNRegressor<TX, TY, X, Y, D>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
    KNNRegressor<TX, TY, X, Y, D>
{
    /// Fits a KNN regressor to the training matrix `x` and targets `y`.
    ///
    /// # Errors
    /// Returns `Failed::fit` when the number of rows in `x` differs from
    /// the length of `y`, or when `parameters.k` is 0. The neighbour-index
    /// construction may also fail and is propagated.
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: KNNRegressorParameters<TX, D>,
    ) -> Result<KNNRegressor<TX, TY, X, Y, D>, Failed> {
        let y_n = y.shape();
        let (x_n, _) = x.shape();
        // Validate before copying the training rows, so no O(rows * cols)
        // work is done only to be discarded on a bad-input error.
        if x_n != y_n {
            return Err(Failed::fit(&format!(
                "Size of x should equal size of y; |x|=[{x_n}], |y|=[{y_n}]"
            )));
        }
        if parameters.k < 1 {
            return Err(Failed::fit(&format!(
                "k should be > 0, k=[{}]",
                parameters.k
            )));
        }
        // Materialize each row as an owned Vec for the neighbour index.
        let data = x
            .row_iter()
            .map(|row| row.iterator(0).copied().collect())
            .collect();
        let knn_algo = parameters.algorithm.fit(data, parameters.distance)?;
        Ok(KNNRegressor {
            y: Some(y.clone()),
            k: Some(parameters.k),
            knn_algorithm: Some(knn_algo),
            weight: Some(parameters.weight),
            _phantom_tx: PhantomData,
            _phantom_ty: PhantomData,
            _phantom_x: PhantomData,
        })
    }

    /// Predicts a target value for every row of `x`.
    ///
    /// # Errors
    /// Propagates any failure from the underlying neighbour search.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let (n_rows, n_cols) = x.shape();
        let mut result = Y::zeros(n_rows);
        // One reusable buffer for the current row avoids a fresh
        // allocation per sample.
        let mut row_vec = vec![TX::zero(); n_cols];
        for (i, row) in x.row_iter().enumerate() {
            row.iterator(0)
                .zip(row_vec.iter_mut())
                .for_each(|(&s, v)| *v = s);
            result.set(i, self.predict_for_row(&row_vec)?);
        }
        Ok(result)
    }

    /// Predicts one target: the weighted average of the k nearest
    /// neighbours' targets, with weights normalized to sum to 1.
    fn predict_for_row(&self, row: &Vec<TX>) -> Result<TY, Failed> {
        let search_result = self.knn_algorithm().find(row, self.k.unwrap())?;
        let mut result = TY::zero();
        // calc_weights maps the neighbour distances to per-neighbour
        // weights according to the configured KNNWeightFunction.
        let weights = self
            .weight()
            .calc_weights(search_result.iter().map(|v| v.1).collect());
        let w_sum: f64 = weights.iter().copied().sum();
        for (r, w) in search_result.iter().zip(weights.iter()) {
            // r.0 is the training-row index of the neighbour.
            result += *self.y().get(r.0) * TY::from_f64(*w / w_sum).unwrap();
        }
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::distance::Distances;

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn knn_fit_predict_weighted() {
        let x =
            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
                .unwrap();
        let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
        // Distance weighting reproduces the training targets exactly.
        let expected = [1., 2., 3., 4., 5.];
        let model = KNNRegressor::fit(
            &x,
            &y,
            KNNRegressorParameters::default()
                .with_k(3)
                .with_distance(Distances::euclidian())
                .with_algorithm(KNNAlgorithmName::LinearSearch)
                .with_weight(KNNWeightFunction::Distance),
        )
        .unwrap();
        let predictions = model.predict(&x).unwrap();
        assert_eq!(predictions.len(), 5);
        for (got, want) in predictions.iter().zip(expected.iter()) {
            assert!((got - want).abs() < f64::EPSILON);
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn knn_fit_predict_uniform() {
        let x =
            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
                .unwrap();
        let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
        // Uniform weighting averages the 3 nearest targets, flattening the ends.
        let expected = [2., 2., 3., 4., 4.];
        let model = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
        let predictions = model.predict(&x).unwrap();
        assert_eq!(predictions.len(), 5);
        for (got, want) in predictions.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-7);
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    #[cfg(feature = "serde")]
    fn serde() {
        let x =
            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
                .unwrap();
        let y = vec![1., 2., 3., 4., 5.];
        let model = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
        // A serialize/deserialize round trip must preserve equality.
        let restored = bincode::deserialize(&bincode::serialize(&model).unwrap()).unwrap();
        assert_eq!(model, restored);
    }
}