pub use super::DistanceCalculationMetric;
use super::helper_function::preliminary_check;
use crate::error::ModelError;
use crate::math::{manhattan_distance_row, minkowski_distance_row, squared_euclidean_distance_row};
use crate::{Deserialize, Serialize};
use ahash::AHashMap;
use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2};
use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
/// Training-set size at which `predict_one` switches to its large-sample
/// distance-computation path (see the branch in `predict_one`).
const PARALLEL_THRESHOLD: usize = 1000;
/// How the votes of the k nearest neighbors are combined into a prediction.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum WeightingStrategy {
/// Every neighbor contributes exactly one vote.
Uniform,
/// Each neighbor's vote is weighted by the inverse of its distance (1 / d);
/// an exact match (distance 0) wins outright.
Distance,
}
/// K-nearest-neighbors classifier over `f64` feature rows with labels of
/// arbitrary hashable type `T` (encoded internally as `usize` class indices).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KNN<T> {
/// Number of neighbors consulted per prediction; `new` rejects 0.
k: usize,
/// Training feature matrix; `None` until `fit` succeeds.
x_train: Option<Array2<f64>>,
/// Training labels encoded as class indices; `None` until `fit` succeeds.
y_train_encoded: Option<Array1<usize>>,
#[serde(bound(
serialize = "T: Serialize + Eq + std::hash::Hash",
deserialize = "T: Deserialize<'de> + Eq + std::hash::Hash"
))]
/// Bidirectional label mapping: label -> class index, and class index -> label.
label_map: Option<(AHashMap<T, usize>, Vec<T>)>,
/// Voting scheme applied to the k neighbors.
weighting_strategy: WeightingStrategy,
/// Distance metric used for neighbor search.
metric: DistanceCalculationMetric,
}
impl<T: Clone + std::hash::Hash + Eq> Default for KNN<T> {
    /// Returns an unfitted classifier with k = 5, uniform voting and the
    /// Euclidean metric; training state starts out empty.
    fn default() -> Self {
        Self {
            k: 5,
            weighting_strategy: WeightingStrategy::Uniform,
            metric: DistanceCalculationMetric::Euclidean,
            x_train: None,
            y_train_encoded: None,
            label_map: None,
        }
    }
}
impl<T: Clone + std::hash::Hash + Eq> KNN<T> {
pub fn new(
k: usize,
weighting_strategy: WeightingStrategy,
metric: DistanceCalculationMetric,
) -> Result<Self, ModelError> {
if k == 0 {
return Err(ModelError::InputValidationError(
"k must be greater than 0".to_string(),
));
}
Ok(KNN {
k,
x_train: None,
y_train_encoded: None,
label_map: None,
weighting_strategy,
metric,
})
}
get_field!(get_k, k, usize);
get_field!(
get_weighting_strategy,
weighting_strategy,
WeightingStrategy
);
get_field!(get_metric, metric, DistanceCalculationMetric);
get_field_as_ref!(get_x_train, x_train, Option<&Array2<f64>>);
get_field_as_ref!(get_y_train_encoded, y_train_encoded, Option<&Array1<usize>>);
get_field_as_ref!(
get_label_map,
label_map,
Option<&(AHashMap<T, usize>, Vec<T>)>
);
pub fn fit<S1, S2>(
&mut self,
x: &ArrayBase<S1, Ix2>,
y: &ArrayBase<S2, Ix1>,
) -> Result<&mut Self, ModelError>
where
S1: Data<Elem = f64>,
S2: Data<Elem = T>,
{
preliminary_check(x, None)?;
if x.nrows() < self.k {
return Err(ModelError::InputValidationError(
"The number of samples is less than k".to_string(),
));
}
let mut label_to_idx: AHashMap<T, usize> = AHashMap::new();
let mut idx_to_label: Vec<T> = Vec::new();
let mut next_idx = 0;
let mut encoded_labels = Vec::with_capacity(y.len());
for label in y.iter() {
let idx = if let Some(&existing_idx) = label_to_idx.get(label) {
existing_idx
} else {
let new_idx = next_idx;
label_to_idx.insert(label.clone(), new_idx);
idx_to_label.push(label.clone());
next_idx += 1;
new_idx
};
encoded_labels.push(idx);
}
self.x_train = Some(x.to_owned());
self.y_train_encoded = Some(Array1::from(encoded_labels));
self.label_map = Some((label_to_idx, idx_to_label));
Ok(self)
}
pub fn predict<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<T>, ModelError>
where
S: Data<Elem = f64>,
{
if self.x_train.is_none() || self.y_train_encoded.is_none() || self.label_map.is_none() {
return Err(ModelError::NotFitted);
}
preliminary_check(x, None)?;
let x_train = self.x_train.as_ref().unwrap();
if x.ncols() != x_train.ncols() {
return Err(ModelError::InputValidationError(format!(
"Feature dimension mismatch: expected {}, got {}",
x_train.ncols(),
x.ncols()
)));
}
if x.is_empty() {
return Err(ModelError::InputValidationError(
"Input array is empty".to_string(),
));
}
let y_train_encoded = self.y_train_encoded.as_ref().unwrap();
let (_, idx_to_label) = self.label_map.as_ref().unwrap();
let encoded_results: Result<Vec<usize>, ModelError> = (0..x.nrows())
.map(|i| {
let sample = x.row(i);
self.predict_one(sample, x_train.view(), y_train_encoded)
})
.collect();
encoded_results.map(|encoded_preds| {
Array1::from(
encoded_preds
.into_iter()
.map(|idx| idx_to_label[idx].clone())
.collect::<Vec<_>>(),
)
})
}
}
impl<T: Clone + std::hash::Hash + Eq + Sync + Send> KNN<T> {
    /// Predicts a label for each row of `x`, running the per-sample neighbor
    /// search in parallel with rayon.
    ///
    /// # Errors
    /// `ModelError::NotFitted` if `fit` has not been called;
    /// `ModelError::InputValidationError` for dimension mismatches or
    /// empty/invalid input.
    pub fn predict_parallel<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<T>, ModelError>
    where
        S: Data<Elem = f64> + Send + Sync,
    {
        if self.x_train.is_none() || self.y_train_encoded.is_none() || self.label_map.is_none() {
            return Err(ModelError::NotFitted);
        }
        preliminary_check(x, None)?;
        let x_train = self.x_train.as_ref().unwrap();
        if x.ncols() != x_train.ncols() {
            return Err(ModelError::InputValidationError(format!(
                "Feature dimension mismatch: expected {}, got {}",
                x_train.ncols(),
                x.ncols()
            )));
        }
        if x.is_empty() {
            return Err(ModelError::InputValidationError(
                "Input array is empty".to_string(),
            ));
        }
        let y_train_encoded = self.y_train_encoded.as_ref().unwrap();
        let (_, idx_to_label) = self.label_map.as_ref().unwrap();
        // The expensive part — neighbor search per sample — runs in parallel.
        let encoded_results: Result<Vec<usize>, ModelError> = (0..x.nrows())
            .into_par_iter()
            .map(|i| self.predict_one(x.row(i), x_train.view(), y_train_encoded))
            .collect();
        // Decoding indices back to labels is a trivial clone per element;
        // doing it sequentially avoids pointless rayon scheduling overhead.
        encoded_results.map(|encoded_preds| {
            Array1::from(
                encoded_preds
                    .into_iter()
                    .map(|idx| idx_to_label[idx].clone())
                    .collect::<Vec<_>>(),
            )
        })
    }
}
impl<T: Clone + std::hash::Hash + Eq> KNN<T> {
fn calculate_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> f64 {
match self.metric {
DistanceCalculationMetric::Euclidean => squared_euclidean_distance_row(&a, &b).sqrt(),
DistanceCalculationMetric::Manhattan => manhattan_distance_row(&a, &b),
DistanceCalculationMetric::Minkowski(p) => minkowski_distance_row(&a, &b, p),
}
}
fn predict_one(
&self,
x: ArrayView1<f64>,
x_train: ArrayView2<f64>,
y_train_encoded: &Array1<usize>,
) -> Result<usize, ModelError> {
let n_samples = x_train.nrows();
let k = self.k.min(n_samples);
let mut distances: Vec<(f64, usize)> = if n_samples >= PARALLEL_THRESHOLD {
(0..n_samples)
.into_iter()
.map(|i| -> Result<(f64, usize), ModelError> {
let distance = self.calculate_distance(x, x_train.row(i));
Ok((distance, i))
})
.collect::<Result<Vec<_>, _>>()?
} else {
(0..n_samples)
.map(|i| -> Result<(f64, usize), ModelError> {
let distance = self.calculate_distance(x, x_train.row(i));
Ok((distance, i))
})
.collect::<Result<Vec<_>, _>>()?
};
distances.select_nth_unstable_by(k - 1, |a, b| {
a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal) });
let k_neighbors = &distances[..k];
let result = match self.weighting_strategy {
WeightingStrategy::Uniform => {
const VOTING_PARALLEL_THRESHOLD: usize = 100;
if k >= VOTING_PARALLEL_THRESHOLD {
let class_counts = k_neighbors
.par_iter()
.fold(
|| AHashMap::new(),
|mut acc: AHashMap<usize, usize>, &(_, idx)| {
let class_idx = y_train_encoded[idx];
*acc.entry(class_idx).or_insert(0) += 1;
acc
},
)
.reduce(
|| AHashMap::new(),
|mut a, b| {
for (class_idx, count) in b {
*a.entry(class_idx).or_insert(0) += count;
}
a
},
);
class_counts
.into_iter()
.max_by_key(|(_, count)| *count)
.map(|(class_idx, _)| class_idx)
.ok_or(ModelError::ProcessingError(
"No valid neighbors found for classification".to_string(),
))?
} else {
let mut class_counts: AHashMap<usize, usize> = AHashMap::with_capacity(k);
for &(_, idx) in k_neighbors {
let class_idx = y_train_encoded[idx];
*class_counts.entry(class_idx).or_insert(0) += 1;
}
class_counts
.into_iter()
.max_by_key(|(_, count)| *count)
.map(|(class_idx, _)| class_idx)
.ok_or(ModelError::ProcessingError(
"No valid neighbors found for classification".to_string(),
))?
}
}
WeightingStrategy::Distance => {
if let Some(&(distance, idx)) = k_neighbors.first() {
if distance == 0.0 {
return Ok(y_train_encoded[idx]);
}
}
const WEIGHT_PARALLEL_THRESHOLD: usize = 100;
if k >= WEIGHT_PARALLEL_THRESHOLD {
let class_weights = k_neighbors
.par_iter()
.fold(
|| AHashMap::new(),
|mut acc: AHashMap<usize, f64>, &(distance, idx)| {
let weight = 1.0 / distance;
let class_idx = y_train_encoded[idx];
*acc.entry(class_idx).or_insert(0.0) += weight;
acc
},
)
.reduce(
|| AHashMap::new(),
|mut a, b| {
for (class_idx, weight) in b {
*a.entry(class_idx).or_insert(0.0) += weight;
}
a
},
);
class_weights
.into_iter()
.max_by(|(_, weight_a), (_, weight_b)| {
weight_a
.partial_cmp(weight_b)
.unwrap_or(std::cmp::Ordering::Equal) })
.map(|(class_idx, _)| class_idx)
.ok_or(ModelError::ProcessingError(
"No valid neighbors found for classification".to_string(),
))?
} else {
let mut class_weights: AHashMap<usize, f64> = AHashMap::with_capacity(k);
for &(distance, idx) in k_neighbors {
let weight = 1.0 / distance;
let class_idx = y_train_encoded[idx];
*class_weights.entry(class_idx).or_insert(0.0) += weight;
}
class_weights
.into_iter()
.max_by(|(_, weight_a), (_, weight_b)| {
weight_a
.partial_cmp(weight_b)
.unwrap_or(std::cmp::Ordering::Equal) })
.map(|(class_idx, _)| class_idx)
.ok_or(ModelError::ProcessingError(
"No valid neighbors found for classification".to_string(),
))?
}
}
};
Ok(result)
}
pub fn fit_predict<S1, S2>(
&mut self,
x_train: &ArrayBase<S1, Ix2>,
y_train: &ArrayBase<S2, Ix1>,
) -> Result<Array1<T>, ModelError>
where
S1: Data<Elem = f64>,
S2: Data<Elem = T>,
{
self.fit(x_train, y_train)?;
Ok(self.predict(x_train)?)
}
}
// Persistence support: the macro presumably generates model save/load methods
// (defined elsewhere in the crate) — hence the extra serde bounds on `T`.
impl<T: Clone + std::hash::Hash + Eq + Serialize + for<'de> Deserialize<'de>> KNN<T> {
model_save_and_load_methods!(KNN<T>);
}