use crate::error::{AnalyticsError, Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use std::collections::VecDeque;
/// Output of [`DbscanClusterer::fit`].
#[derive(Debug, Clone)]
pub struct DbscanResult {
/// Cluster label of each sample, in the same order as the input rows.
/// Clusters are numbered 0, 1, 2, ...; `-1` marks a noise sample that belongs to no cluster.
pub labels: Array1<i32>,
/// Number of clusters found (labels `0..n_clusters` are in use).
pub n_clusters: usize,
/// Number of samples labelled `-1` (noise).
pub n_noise: usize,
/// `true` for samples that are core points, i.e. that have at least `min_samples`
/// other samples within distance `eps` (the sample itself is not counted).
pub core_points: Array1<bool>,
}
/// Density-based clustering (DBSCAN) configuration.
///
/// `eps` is the Euclidean neighbourhood radius. `min_samples` is the number of *other*
/// samples that must lie within `eps` for a sample to count as a core point.
#[derive(Debug, Clone, PartialEq)]
pub struct DbscanClusterer {
    /// Neighbourhood radius; samples at distance `<= eps` are neighbours.
    eps: f64,
    /// Minimum neighbour count (excluding the sample itself) for a core point.
    min_samples: usize,
}
impl DbscanClusterer {
    /// Label given to samples that end up in no cluster.
    const NOISE: i32 = -1;

    /// Creates a clusterer with neighbourhood radius `eps` and density threshold `min_samples`.
    ///
    /// The parameters are not validated here; [`fit`](Self::fit) rejects invalid values.
    pub fn new(eps: f64, min_samples: usize) -> Self {
        Self { eps, min_samples }
    }

    /// Runs DBSCAN over the rows of `data` (one sample per row).
    ///
    /// A sample is a *core* point when at least `min_samples` other samples lie within
    /// Euclidean distance `eps` of it; the sample itself is not counted.
    ///
    /// Clusters are grown breadth-first from core points. Non-core samples reachable from a
    /// core point join that cluster as border points. All remaining samples are labelled `-1`
    /// (noise).
    ///
    /// Clusters are numbered from 0, in the order in which an unassigned core point is first
    /// met while scanning the rows.
    ///
    /// A dense `n x n` distance matrix is built, so memory grows quadratically with the
    /// number of samples.
    ///
    /// # Errors
    ///
    /// - insufficient-data error if `data` has no rows;
    /// - invalid-parameter error if `eps` is not strictly positive (NaN included) or
    ///   `min_samples` is zero;
    /// - any error propagated from the pairwise distance computation.
    pub fn fit(&self, data: &ArrayView2<f64>) -> Result<DbscanResult> {
        let n_samples = data.nrows();
        if n_samples == 0 {
            return Err(AnalyticsError::insufficient_data("Data is empty"));
        }
        // `eps <= 0.0` is false for NaN, so NaN must be rejected explicitly; otherwise every
        // distance comparison fails and all samples are silently reported as noise.
        if self.eps.is_nan() || self.eps <= 0.0 {
            return Err(AnalyticsError::invalid_parameter("eps", "must be positive"));
        }
        if self.min_samples == 0 {
            return Err(AnalyticsError::invalid_parameter(
                "min_samples",
                "must be positive",
            ));
        }

        let distances = self.compute_distances(data)?;
        let neighbors = self.find_neighbors(&distances);

        let mut labels = Array1::from_elem(n_samples, Self::NOISE);
        let mut core_points = Array1::from_elem(n_samples, false);
        let mut cluster_id: i32 = 0;
        for i in 0..n_samples {
            // Already-assigned samples are skipped. Non-core samples cannot seed a cluster:
            // they stay noise unless a later cluster reaches them as border points.
            if labels[i] != Self::NOISE || neighbors[i].len() < self.min_samples {
                continue;
            }
            self.expand_cluster(i, cluster_id, &neighbors, &mut labels, &mut core_points);
            cluster_id += 1;
        }

        let n_noise = labels.iter().filter(|&&label| label == Self::NOISE).count();
        Ok(DbscanResult {
            labels,
            // `cluster_id` only counts up from 0, so the cast cannot wrap.
            n_clusters: cluster_id as usize,
            n_noise,
            core_points,
        })
    }

    /// Builds the symmetric matrix of pairwise Euclidean distances between rows of `data`.
    ///
    /// Each unordered pair is computed once. The diagonal is left at zero.
    fn compute_distances(&self, data: &ArrayView2<f64>) -> Result<Array2<f64>> {
        let n_samples = data.nrows();
        let mut distances = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            for j in (i + 1)..n_samples {
                let dist = euclidean_distance(&data.row(i), &data.row(j))?;
                distances[[i, j]] = dist;
                distances[[j, i]] = dist;
            }
        }
        Ok(distances)
    }

    /// For every sample, lists (in ascending index order) the *other* samples within `eps`.
    ///
    /// A sample is never its own neighbour. NaN distances never qualify, because
    /// `NaN <= eps` is false.
    fn find_neighbors(&self, distances: &Array2<f64>) -> Vec<Vec<usize>> {
        distances
            .outer_iter()
            .enumerate()
            .map(|(i, row)| {
                row.iter()
                    .enumerate()
                    .filter(|&(j, &dist)| j != i && dist <= self.eps)
                    .map(|(j, _)| j)
                    .collect()
            })
            .collect()
    }

    /// Grows cluster `cluster_id` breadth-first, starting from the core point `start`.
    ///
    /// Every reachable sample still labelled noise is assigned to the cluster. Only core
    /// points (which are flagged in `core_points`) propagate the expansion further.
    fn expand_cluster(
        &self,
        start: usize,
        cluster_id: i32,
        neighbors: &[Vec<usize>],
        labels: &mut Array1<i32>,
        core_points: &mut Array1<bool>,
    ) {
        let mut queue = VecDeque::new();
        queue.push_back(start);
        labels[start] = cluster_id;
        while let Some(point) = queue.pop_front() {
            if neighbors[point].len() < self.min_samples {
                // Border point: belongs to the cluster but does not extend it.
                continue;
            }
            core_points[point] = true;
            for &neighbor in &neighbors[point] {
                if labels[neighbor] == Self::NOISE {
                    labels[neighbor] = cluster_id;
                    queue.push_back(neighbor);
                }
            }
        }
    }
}
/// Euclidean (L2) distance between two vectors of equal length.
///
/// # Errors
///
/// Returns a dimension-mismatch error, carrying both lengths, when `p1` and `p2` differ
/// in length.
fn euclidean_distance(
    p1: &scirs2_core::ndarray::ArrayView1<f64>,
    p2: &scirs2_core::ndarray::ArrayView1<f64>,
) -> Result<f64> {
    if p1.len() != p2.len() {
        return Err(AnalyticsError::dimension_mismatch(
            p1.len().to_string(),
            p2.len().to_string(),
        ));
    }
    let dist_sq: f64 = p1.iter().zip(p2.iter()).map(|(a, b)| (a - b).powi(2)).sum();
    Ok(dist_sq.sqrt())
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_dbscan_simple() {
        // Two tight groups of three points, plus one isolated outlier at (5, 5).
        let data = array![
            [0.0, 0.0],
            [0.1, 0.1],
            [0.2, 0.0],
            [10.0, 10.0],
            [10.1, 10.1],
            [10.0, 10.2],
            [5.0, 5.0],
        ];
        let clusterer = DbscanClusterer::new(0.5, 2);
        let result = clusterer
            .fit(&data.view())
            .expect("DBSCAN clustering should succeed for valid data");
        assert_eq!(result.n_clusters, 2);
        assert_eq!(result.n_noise, 1);

        let labels = &result.labels;
        // Points within each group share a label; the two groups get different labels.
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
        assert_eq!(labels[3], labels[4]);
        assert_eq!(labels[4], labels[5]);
        assert!(labels[0] >= 0 && labels[3] >= 0);
        assert_ne!(labels[0], labels[3]);
        // The outlier is noise and not a core point.
        assert_eq!(labels[6], -1);
        assert!(!result.core_points[6]);
    }

    #[test]
    fn test_dbscan_all_noise() {
        let data = array![[0.0, 0.0], [10.0, 10.0], [20.0, 20.0]];
        let clusterer = DbscanClusterer::new(0.5, 2);
        let result = clusterer
            .fit(&data.view())
            .expect("DBSCAN should succeed even when all points are noise");
        assert_eq!(result.n_clusters, 0);
        assert_eq!(result.n_noise, 3);
        assert!(result.labels.iter().all(|&label| label == -1));
        assert!(result.core_points.iter().all(|&is_core| !is_core));
    }

    #[test]
    fn test_dbscan_single_cluster() {
        let data = array![[0.0, 0.0], [0.1, 0.0], [0.2, 0.0], [0.3, 0.0], [0.4, 0.0],];
        let clusterer = DbscanClusterer::new(0.15, 2);
        let result = clusterer
            .fit(&data.view())
            .expect("DBSCAN should succeed for single cluster data");
        assert_eq!(result.n_clusters, 1);
        assert_eq!(result.n_noise, 0);
        assert!(result.labels.iter().all(|&label| label == 0));
        // The endpoints have a single neighbour (the point itself is not counted), so with
        // `min_samples = 2` they are border points rather than core points.
        assert_eq!(
            result.core_points.to_vec(),
            vec![false, true, true, true, false]
        );
    }

    #[test]
    fn test_dbscan_invalid_params() {
        let data = array![[1.0, 2.0]];
        assert!(DbscanClusterer::new(-1.0, 2).fit(&data.view()).is_err());
        assert!(DbscanClusterer::new(0.0, 2).fit(&data.view()).is_err());
        assert!(DbscanClusterer::new(0.5, 0).fit(&data.view()).is_err());
    }

    #[test]
    fn test_dbscan_rejects_empty_data() {
        let data = Array2::<f64>::zeros((0, 2));
        assert!(DbscanClusterer::new(0.5, 2).fit(&data.view()).is_err());
    }
}