use crate::error::Result;
use crate::primitives::Matrix;
use crate::traits::UnsupervisedEstimator;
use serde::{Deserialize, Serialize};
/// Density-based clustering (DBSCAN) over the rows of an `f32` matrix.
///
/// Configure with [`DBSCAN::new`], then call `fit` from
/// [`UnsupervisedEstimator`]. The stored per-row labels use `-1` for noise and
/// `0, 1, 2, …` for clusters, in the order the clusters are discovered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DBSCAN {
    /// Neighborhood radius: two rows are neighbors when their distance is `<= eps`.
    eps: f32,
    /// Minimum neighborhood size (the query row itself included) for a core point.
    min_samples: usize,
    /// Labels from the last `fit`; `None` until the model has been fitted.
    labels: Option<Vec<i32>>,
}
impl DBSCAN {
    /// Label stored for points that belong to no cluster.
    const NOISE: i32 = -1;
    /// Scratch label for points that `fit` has not examined yet.
    const UNVISITED: i32 = -2;

    /// Creates an unfitted model.
    ///
    /// * `eps` – neighborhood radius; two rows are neighbors when their
    ///   Euclidean distance is `<= eps`.
    /// * `min_samples` – minimum neighborhood size for a row to be a core
    ///   point. The neighborhood scan includes the query row itself, so with a
    ///   non-negative `eps` that row counts toward this threshold.
    #[must_use]
    pub fn new(eps: f32, min_samples: usize) -> Self {
        Self {
            eps,
            min_samples,
            labels: None,
        }
    }

    /// Returns the neighborhood radius.
    #[must_use]
    pub fn eps(&self) -> f32 {
        self.eps
    }

    /// Returns the minimum neighborhood size required for a core point.
    #[must_use]
    pub fn min_samples(&self) -> usize {
        self.min_samples
    }

    /// Returns `true` when labels are present (set by `fit`, or restored by
    /// deserialization).
    #[must_use]
    pub fn is_fitted(&self) -> bool {
        self.labels.is_some()
    }

    /// Returns the labels computed by the last `fit`, one per input row.
    ///
    /// `-1` marks noise; non-negative values are cluster ids.
    ///
    /// # Panics
    ///
    /// Panics if the model has not been fitted.
    #[must_use]
    pub fn labels(&self) -> &Vec<i32> {
        self.labels
            .as_ref()
            .expect("Model not fitted. Call fit() first.")
    }

    /// Returns the indices of all rows of `x` within `eps` of row `i`, in
    /// ascending order.
    ///
    /// Every row is scanned, including row `i` itself.
    fn region_query(&self, x: &Matrix<f32>, i: usize) -> Vec<usize> {
        (0..x.shape().0)
            .filter(|&j| self.euclidean_distance(x, i, j) <= self.eps)
            .collect()
    }

    /// Euclidean distance between rows `i` and `j` of `x`.
    ///
    /// The rows are copied into temporary vectors because the shared helper
    /// takes slices.
    // `self` is unused, but this stays a method so that existing call sites
    // keep compiling. Those call sites include the child test module, which
    // can reach private items.
    #[allow(clippy::unused_self)]
    fn euclidean_distance(&self, x: &Matrix<f32>, i: usize, j: usize) -> f32 {
        let n_features = x.shape().1;
        let row_i: Vec<f32> = (0..n_features).map(|k| x.get(i, k)).collect();
        let row_j: Vec<f32> = (0..n_features).map(|k| x.get(j, k)).collect();
        crate::nn::functional::euclidean_distance(&row_i, &row_j)
    }

    /// Grows cluster `cluster_id` outward from `point`.
    ///
    /// `fit` only calls this for core points, passing that point's
    /// neighborhood. `neighbors` is used as a FIFO work queue, and each index
    /// is enqueued at most once:
    /// * An unvisited point is labeled `cluster_id`. If it is itself a core
    ///   point, its neighborhood is appended to the queue.
    /// * A point previously marked as noise is absorbed as a border point. It
    ///   is labeled `cluster_id` but not expanded.
    /// * A point already assigned to a cluster is left untouched.
    fn expand_cluster(
        &self,
        x: &Matrix<f32>,
        labels: &mut [i32],
        point: usize,
        neighbors: &mut Vec<usize>,
        cluster_id: i32,
    ) {
        labels[point] = cluster_id;

        // `queued[j]` answers "has `j` ever been in `neighbors`?" in O(1).
        // The previous linear `neighbors.contains` scan made expansion
        // quadratic in the cluster size.
        let mut queued = vec![false; labels.len()];
        for &j in neighbors.iter() {
            queued[j] = true;
        }

        let mut cursor = 0;
        while cursor < neighbors.len() {
            let neighbor = neighbors[cursor];
            cursor += 1;

            if labels[neighbor] == Self::UNVISITED {
                labels[neighbor] = cluster_id;
                let neighbor_neighbors = self.region_query(x, neighbor);
                if neighbor_neighbors.len() >= self.min_samples {
                    for nn in neighbor_neighbors {
                        if !queued[nn] {
                            queued[nn] = true;
                            neighbors.push(nn);
                        }
                    }
                }
            } else if labels[neighbor] == Self::NOISE {
                labels[neighbor] = cluster_id;
            }
        }
    }
}
impl UnsupervisedEstimator for DBSCAN {
    type Labels = Vec<i32>;

    /// Clusters the rows of `x` and stores one label per row.
    ///
    /// Rows are visited in index order. An unvisited row whose neighborhood
    /// holds at least `min_samples` rows seeds a new cluster. Clusters are
    /// numbered `0, 1, 2, …` in the order they are seeded. Any other unvisited
    /// row is marked `-1` (noise), although a later cluster may still absorb it
    /// as a border point. This implementation always returns `Ok(())`.
    fn fit(&mut self, x: &Matrix<f32>) -> Result<()> {
        // Every row is either seeded, marked noise, or reached by a cluster
        // before the loop ends, so this marker never survives into the result.
        const UNVISITED: i32 = -2;
        const NOISE: i32 = -1;

        let n_rows = x.shape().0;
        let mut assignments = vec![UNVISITED; n_rows];
        let mut next_cluster: i32 = 0;

        for seed in 0..n_rows {
            if assignments[seed] == UNVISITED {
                let mut frontier = self.region_query(x, seed);
                if frontier.len() >= self.min_samples {
                    self.expand_cluster(
                        x,
                        &mut assignments,
                        seed,
                        &mut frontier,
                        next_cluster,
                    );
                    next_cluster += 1;
                } else {
                    assignments[seed] = NOISE;
                }
            }
        }

        self.labels = Some(assignments);
        Ok(())
    }

    /// Returns a copy of the labels computed by the last `fit`.
    ///
    /// `_x` is ignored: no new points are assigned here, so the training
    /// labels are returned regardless of the input.
    ///
    /// # Panics
    ///
    /// Panics if the model has not been fitted.
    fn predict(&self, _x: &Matrix<f32>) -> Self::Labels {
        self.labels().to_vec()
    }
}
// Contract tests for `DBSCAN`, loaded from `tests_dbscan_contract.rs` and
// compiled only for test builds. As a child module it can also exercise this
// file's private items.
#[cfg(test)]
#[path = "tests_dbscan_contract.rs"]
mod tests_dbscan_contract;