pub use super::DistanceCalculationMetric;
use super::helper_function::preliminary_check;
use crate::error::ModelError;
use crate::math::{manhattan_distance_row, minkowski_distance_row, squared_euclidean_distance_row};
use crate::{Deserialize, Serialize};
use ahash::AHashSet;
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::{Array1, ArrayBase, ArrayView1, Data, Ix2};
use rayon::prelude::{IntoParallelIterator, ParallelBridge, ParallelIterator};
use std::collections::VecDeque;
/// Minimum number of samples before `region_query` switches from a
/// sequential scan to a rayon-parallel scan.
const DBSCAN_PARALLEL_THRESHOLD: usize = 1000;
/// Density-Based Spatial Clustering of Applications with Noise.
///
/// Hyper-parameters are set at construction; `labels_` and
/// `core_sample_indices` are `None` until `fit` has run.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DBSCAN {
    // Neighborhood radius: two points are neighbors when their distance is <= eps.
    eps: f64,
    // Minimum neighborhood size (including the point itself) for a core point.
    min_samples: usize,
    // Distance metric dispatched by `compute_distance`.
    metric: DistanceCalculationMetric,
    // Per-sample cluster label after `fit`; -1 marks noise.
    labels_: Option<Array1<i32>>,
    // Sorted indices of core samples after `fit`.
    core_sample_indices: Option<Array1<usize>>,
}
impl Default for DBSCAN {
    /// Default configuration: `eps = 0.5`, `min_samples = 5`, Euclidean
    /// metric, and no fit results yet.
    fn default() -> Self {
        Self {
            eps: 0.5,
            min_samples: 5,
            metric: DistanceCalculationMetric::Euclidean,
            labels_: None,
            core_sample_indices: None,
        }
    }
}
impl DBSCAN {
pub fn new(
eps: f64,
min_samples: usize,
metric: DistanceCalculationMetric,
) -> Result<Self, ModelError> {
if eps <= 0.0 || !eps.is_finite() {
return Err(ModelError::InputValidationError(format!(
"eps must be positive and finite, got {}",
eps
)));
}
if min_samples == 0 {
return Err(ModelError::InputValidationError(
"min_samples must be greater than 0".to_string(),
));
}
match metric {
DistanceCalculationMetric::Minkowski(p) => {
if p <= 0.0 || !p.is_finite() {
return Err(ModelError::InputValidationError(format!(
"Minkowski p must be positive and finite, got {}",
p
)));
}
}
_ => {} }
Ok(DBSCAN {
eps,
min_samples,
metric,
labels_: None,
core_sample_indices: None,
})
}
// Project macros (defined elsewhere in the crate): presumably generate
// by-value getters for the hyper-parameters and by-reference getters for
// the fit results — verify against the macro definitions.
get_field!(get_epsilon, eps, f64);
get_field!(get_min_samples, min_samples, usize);
get_field!(get_metric, metric, DistanceCalculationMetric);
get_field_as_ref!(get_labels, labels_, Option<&Array1<i32>>);
get_field_as_ref!(
    get_core_sample_indices,
    core_sample_indices,
    Option<&Array1<usize>>
);
fn compute_distance(&self, p_row: ArrayView1<f64>, q_row: ArrayView1<f64>) -> f64 {
match self.metric {
DistanceCalculationMetric::Euclidean => {
squared_euclidean_distance_row(&p_row, &q_row).sqrt()
}
DistanceCalculationMetric::Manhattan => manhattan_distance_row(&p_row, &q_row),
DistanceCalculationMetric::Minkowski(p) => minkowski_distance_row(&p_row, &q_row, p),
}
}
/// Returns the indices of all samples within `eps` of sample `p`
/// (including `p` itself, whose distance is zero).
///
/// Scans in parallel via rayon when the dataset has at least
/// `DBSCAN_PARALLEL_THRESHOLD` rows, sequentially otherwise.
///
/// # Errors
/// Returns `ModelError::InputValidationError` when `p` is out of bounds.
fn region_query<S>(&self, data: &ArrayBase<S, Ix2>, p: usize) -> Result<Vec<usize>, ModelError>
where
    S: Data<Elem = f64> + Send + Sync,
{
    let n_samples = data.nrows();
    if p >= n_samples {
        return Err(ModelError::InputValidationError(format!(
            "Point index {} is out of bounds (max: {})",
            p,
            n_samples - 1
        )));
    }
    let p_row = data.row(p);
    // One shared predicate for both branches, so the parallel and
    // sequential paths cannot drift apart (previously duplicated).
    let within_eps = |q: usize| {
        let dist = self.compute_distance(p_row, data.row(q));
        (dist <= self.eps).then_some(q)
    };
    let neighbors: Vec<usize> = if n_samples >= DBSCAN_PARALLEL_THRESHOLD {
        (0..n_samples).into_par_iter().filter_map(within_eps).collect()
    } else {
        (0..n_samples).filter_map(within_eps).collect()
    };
    Ok(neighbors)
}
/// Runs DBSCAN over `data`, storing per-sample cluster labels (`-1` =
/// noise) in `labels_` and the sorted core-sample indices in
/// `core_sample_indices`.
///
/// Expansion is a breadth-first search from each unvisited core point.
/// A point already claimed by an earlier cluster is never relabelled:
/// border points belong to the first cluster that reaches them, matching
/// the usual DBSCAN convention (e.g. scikit-learn).
///
/// # Errors
/// - `InputValidationError` when `preliminary_check` fails or the row
///   count exceeds `i32::MAX`.
/// - `ProcessingError` when a region query fails or the cluster id would
///   overflow `i32`.
pub fn fit<S>(&mut self, data: &ArrayBase<S, Ix2>) -> Result<&mut Self, ModelError>
where
    S: Data<Elem = f64> + Send + Sync,
{
    preliminary_check(&data, None)?;
    let n_samples = data.nrows();
    // Labels are i32, so the sample count must fit in i32.
    if n_samples > i32::MAX as usize {
        return Err(ModelError::InputValidationError(
            "Dataset too large: exceeds maximum number of samples".to_string(),
        ));
    }
    // -1 means "unvisited or noise"; noise points simply keep -1.
    let mut labels = Array1::from(vec![-1; n_samples]);
    let mut core_samples = AHashSet::with_capacity(n_samples / 4);
    let mut cluster_id = 0i32;
    let pb = ProgressBar::new(n_samples as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} | Clusters: {msg}")
            .expect("Failed to set progress bar template")
            .progress_chars("█▓░"),
    );
    pb.set_message("0 | Core points: 0");
    for p in 0..n_samples {
        pb.inc(1);
        // Already claimed by a previous cluster: nothing to do.
        if labels[p] != -1 {
            continue;
        }
        let neighbors = self
            .region_query(&data, p)
            .map_err(|e| ModelError::ProcessingError(format!("Region query failed: {:?}", e)))?;
        // Not a core point: leave it at -1 (noise) for now; a later
        // cluster may still claim it as a border point.
        if neighbors.len() < self.min_samples {
            continue;
        }
        // p is a core point: start a new cluster and expand it via BFS.
        labels[p] = cluster_id;
        core_samples.insert(p);
        let mut seeds: VecDeque<usize> = neighbors.into_iter().collect();
        while let Some(q) = seeds.pop_front() {
            // BUGFIX: skip any already-labelled point, not just points in
            // the current cluster. The previous check (labels[q] ==
            // cluster_id) let a later cluster relabel ("steal") border
            // points that an earlier cluster had already claimed.
            if labels[q] != -1 {
                continue;
            }
            labels[q] = cluster_id;
            let q_neighbors = self.region_query(&data, q).map_err(|e| {
                ModelError::ProcessingError(format!(
                    "Region query failed for point {}: {:?}",
                    q, e
                ))
            })?;
            // q is itself a core point: its neighborhood joins the frontier.
            if q_neighbors.len() >= self.min_samples {
                core_samples.insert(q);
                for r in q_neighbors {
                    if labels[r] != cluster_id {
                        seeds.push_back(r);
                    }
                }
            }
        }
        cluster_id += 1;
        pb.set_message(format!(
            "{} | Core points: {}",
            cluster_id,
            core_samples.len()
        ));
        if cluster_id >= i32::MAX {
            pb.finish_with_message("Error: cluster ID overflow");
            return Err(ModelError::ProcessingError(
                "Too many clusters: cluster ID overflow".to_string(),
            ));
        }
    }
    pb.finish_with_message(format!(
        "{} | Core points: {} | Noise points: {}",
        cluster_id,
        core_samples.len(),
        labels.iter().filter(|&&x| x == -1).count()
    ));
    self.labels_ = Some(labels);
    // Expose core samples in deterministic (sorted) order.
    let mut core_indices: Vec<usize> = core_samples.into_iter().collect();
    core_indices.sort_unstable();
    self.core_sample_indices = Some(Array1::from(core_indices));
    Ok(self)
}
/// Predicts cluster labels for `new_data` against a fitted model.
///
/// A new point takes the label of the first core sample found within
/// `eps`; otherwise the label of its closest labelled training point if
/// that point lies within `eps`; otherwise `-1` (noise).
///
/// NOTE(review): the fallback to non-core labelled points within `eps`
/// is looser than scikit-learn's core-samples-only convention — kept
/// as-is to preserve existing behavior.
///
/// # Errors
/// `NotFitted` when `fit` has not run; `InputValidationError` on empty
/// training data, feature-dimension mismatch, label-length mismatch, or
/// non-finite values in either matrix.
pub fn predict<S>(
    &self,
    trained_data: &ArrayBase<S, Ix2>,
    new_data: &ArrayBase<S, Ix2>,
) -> Result<Array1<i32>, ModelError>
where
    S: Data<Elem = f64> + Send + Sync,
{
    let labels = self.labels_.as_ref().ok_or(ModelError::NotFitted)?;
    let core_samples = self
        .core_sample_indices
        .as_ref()
        .ok_or(ModelError::NotFitted)?;
    if trained_data.nrows() == 0 {
        return Err(ModelError::InputValidationError(
            "Trained data is empty".to_string(),
        ));
    }
    if new_data.nrows() == 0 {
        return Ok(Array1::from(vec![]));
    }
    if trained_data.ncols() != new_data.ncols() {
        return Err(ModelError::InputValidationError(format!(
            "Feature dimension mismatch: trained data has {} features, new data has {} features",
            trained_data.ncols(),
            new_data.ncols()
        )));
    }
    if trained_data.nrows() != labels.len() {
        return Err(ModelError::InputValidationError(format!(
            "Trained data rows ({}) don't match labels length ({})",
            trained_data.nrows(),
            labels.len()
        )));
    }
    if trained_data.iter().any(|&val| !val.is_finite()) {
        return Err(ModelError::InputValidationError(
            "Trained data contains NaN or infinite values".to_string(),
        ));
    }
    if new_data.iter().any(|&val| !val.is_finite()) {
        return Err(ModelError::InputValidationError(
            "New data contains NaN or infinite values".to_string(),
        ));
    }
    // O(1) core-membership checks inside the hot loop.
    let core_set: AHashSet<usize> = core_samples.iter().copied().collect();
    // BUGFIX: the previous `rows().into_iter().par_bridge()` does not
    // preserve item order (rayon documents par_bridge as unordered), so
    // collected predictions could be scrambled relative to `new_data`'s
    // rows. An indexed parallel range keeps output order == input order.
    let predictions: Result<Vec<i32>, ModelError> = (0..new_data.nrows())
        .into_par_iter()
        .map(|i| -> Result<i32, ModelError> {
            let row = new_data.row(i);
            let mut min_dist = f64::MAX;
            let mut closest_label = -1;
            for (j, orig_row) in trained_data.rows().into_iter().enumerate() {
                // Noise points never attract predictions.
                if labels[j] == -1 {
                    continue;
                }
                let dist = self.compute_distance(row, orig_row);
                if dist.is_nan() || dist.is_infinite() {
                    continue;
                }
                // Within eps of a core sample: adopt its cluster immediately.
                if dist <= self.eps && core_set.contains(&j) {
                    return Ok(labels[j]);
                }
                if dist < min_dist {
                    min_dist = dist;
                    closest_label = labels[j];
                }
            }
            // Fall back to the nearest labelled point, if close enough.
            if min_dist <= self.eps {
                Ok(closest_label)
            } else {
                Ok(-1)
            }
        })
        .collect();
    Ok(Array1::from(predictions?))
}
/// Convenience wrapper: fits the model on `data` and returns a copy of
/// the resulting labels (`-1` = noise).
pub fn fit_predict<S>(&mut self, data: &ArrayBase<S, Ix2>) -> Result<Array1<i32>, ModelError>
where
    S: Data<Elem = f64> + Send + Sync,
{
    self.fit(data)?;
    let labels = self
        .labels_
        .clone()
        .expect("fit() always populates labels_ on success");
    Ok(labels)
}
// Project macro (defined elsewhere in the crate): presumably generates
// model save/load (serialization) methods for DBSCAN — verify against
// the macro definition.
model_save_and_load_methods!(DBSCAN);
}