use super::traits::Clustering;
use crate::error::{Error, Result};
use clump::DistanceMetric;
pub use clump::NOISE;
/// DBSCAN clustering, implemented as a thin wrapper around [`clump::Dbscan`].
///
/// The wrapper is generic over the distance metric `D` and falls back to
/// [`clump::Euclidean`] when no metric is named. All of the actual work is
/// delegated to the wrapped `clump` estimator.
#[derive(Debug, Clone)]
pub struct Dbscan<D = clump::Euclidean>
where
    D: DistanceMetric,
{
    /// The underlying `clump` estimator that every method forwards to.
    inner: clump::Dbscan<D>,
}
impl Dbscan<clump::Euclidean> {
    /// Builds a DBSCAN estimator that uses the Euclidean metric.
    ///
    /// Both arguments are handed unchanged to [`clump::Dbscan::new`].
    ///
    /// # Panics
    ///
    /// Parameter validation is performed by `clump`; this module's tests
    /// expect a panic when `epsilon` is zero or negative, or when `min_pts`
    /// is zero.
    pub fn new(epsilon: f32, min_pts: usize) -> Self {
        let inner = clump::Dbscan::new(epsilon, min_pts);
        Self { inner }
    }
}
impl<D> Dbscan<D>
where
    D: DistanceMetric,
{
    /// Builds a DBSCAN estimator that measures distances with `metric`.
    ///
    /// All three arguments are forwarded to [`clump::Dbscan::with_metric`].
    pub fn with_metric(epsilon: f32, min_pts: usize, metric: D) -> Self {
        let inner = clump::Dbscan::with_metric(epsilon, min_pts, metric);
        Self { inner }
    }

    /// Builder-style setter: forwards `epsilon` to the wrapped estimator's
    /// `with_epsilon` and returns the updated wrapper.
    pub fn with_epsilon(self, epsilon: f32) -> Self {
        Self {
            inner: self.inner.with_epsilon(epsilon),
        }
    }

    /// Builder-style setter: forwards `min_pts` to the wrapped estimator's
    /// `with_min_pts` and returns the updated wrapper.
    pub fn with_min_pts(self, min_pts: usize) -> Self {
        Self {
            inner: self.inner.with_min_pts(min_pts),
        }
    }

    /// Returns `true` when `label` equals the [`NOISE`] sentinel.
    pub fn is_noise(label: usize) -> bool {
        NOISE == label
    }
}
impl Default for Dbscan<clump::Euclidean> {
    /// Wraps whatever [`clump::Dbscan`] produces as its default configuration.
    fn default() -> Self {
        let inner = clump::Dbscan::default();
        Self { inner }
    }
}
impl<D> Clustering for Dbscan<D>
where
    D: DistanceMetric,
{
    /// Clusters `data` and returns the labels produced by the wrapped
    /// estimator.
    ///
    /// This module's tests expect one label per input row, with outliers
    /// carrying the [`NOISE`] sentinel.
    ///
    /// # Errors
    ///
    /// Any error reported by `clump` is converted into this crate's
    /// [`Error`]. The tests expect an empty `data` slice to be rejected.
    fn fit_predict(&self, data: &[Vec<f32>]) -> Result<Vec<usize>> {
        let labels = self.inner.fit_predict(data)?;
        Ok(labels)
    }

    /// Always returns `0`.
    ///
    /// `fit_predict` takes `&self` and stores nothing on the wrapper, so no
    /// fitted cluster count is available to report here.
    // NOTE(review): callers that need the real cluster count presumably have
    // to derive it from the labels returned by `fit_predict` — confirm.
    fn n_clusters(&self) -> usize {
        0
    }
}
/// Extension API for DBSCAN-style estimators that report noise explicitly.
pub trait DbscanExt {
    /// Clusters `data`, returning an `Option<usize>` per entry rather than a
    /// sentinel-encoded label.
    ///
    /// This module's tests expect `None` for points classified as noise and
    /// `Some(_)` for points assigned to a cluster.
    fn fit_predict_with_noise(&self, data: &[Vec<f32>]) -> Result<Vec<Option<usize>>>;

    /// Returns `true` when `label` equals the [`NOISE`] sentinel.
    fn is_noise(label: usize) -> bool {
        NOISE == label
    }
}
impl<D> DbscanExt for Dbscan<D>
where
    D: DistanceMetric,
{
    /// Delegates to the wrapped estimator's `fit_predict_with_noise`.
    ///
    /// # Errors
    ///
    /// Any error reported by `clump` is converted into this crate's
    /// [`Error`].
    fn fit_predict_with_noise(&self, data: &[Vec<f32>]) -> Result<Vec<Option<usize>>> {
        let labels = self.inner.fit_predict_with_noise(data)?;
        Ok(labels)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Index of the isolated point in [`two_groups_with_outlier`].
    const OUTLIER_INDEX: usize = 4;

    /// Nine 2-D points: two tight groups of four (near the origin and near
    /// `(5, 5)`) with a single far-away point at [`OUTLIER_INDEX`].
    fn two_groups_with_outlier() -> Vec<Vec<f32>> {
        vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![0.0, 0.1],
            vec![0.1, 0.1],
            vec![100.0, 100.0],
            vec![5.0, 5.0],
            vec![5.1, 5.0],
            vec![5.0, 5.1],
            vec![5.1, 5.1],
        ]
    }

    /// Two well-separated groups must receive two distinct, non-noise labels.
    #[test]
    fn test_dbscan_two_clusters() {
        let data = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![0.0, 0.1],
            vec![0.1, 0.1],
            vec![0.05, 0.05],
            vec![5.0, 5.0],
            vec![5.1, 5.0],
            vec![5.0, 5.1],
            vec![5.1, 5.1],
            vec![5.05, 5.05],
        ];
        let dbscan = Dbscan::new(0.3, 3);
        let labels = dbscan.fit_predict(&data).unwrap();
        assert_eq!(labels.len(), 10);

        let cluster1 = labels[0];
        for label in &labels[1..5] {
            assert_eq!(*label, cluster1);
        }
        let cluster2 = labels[5];
        for label in &labels[6..10] {
            assert_eq!(*label, cluster2);
        }
        assert_ne!(cluster1, cluster2);
        // Without these checks one of the groups could be entirely noise
        // and the test would still pass.
        assert_ne!(cluster1, NOISE);
        assert_ne!(cluster2, NOISE);
    }

    /// The `Option`-based API reports the outlier as `None` and every other
    /// point as `Some(_)`.
    #[test]
    fn test_dbscan_with_noise() {
        let data = two_groups_with_outlier();
        let dbscan = Dbscan::new(0.3, 3);
        let labels = dbscan.fit_predict_with_noise(&data).unwrap();
        assert_eq!(labels.len(), 9);
        assert!(labels[OUTLIER_INDEX].is_none());
        for (i, label) in labels.iter().enumerate() {
            if i != OUTLIER_INDEX {
                assert!(label.is_some());
            }
        }
    }

    /// The sentinel-based API marks the outlier with `NOISE`.
    #[test]
    fn test_dbscan_fit_predict_uses_noise_sentinel() {
        let data = two_groups_with_outlier();
        let dbscan = Dbscan::new(0.3, 3);
        let labels = dbscan.fit_predict(&data).unwrap();
        assert_eq!(labels.len(), 9);
        assert_eq!(labels[OUTLIER_INDEX], NOISE);
        assert!(Dbscan::<clump::Euclidean>::is_noise(labels[OUTLIER_INDEX]));
    }

    /// Points that are all far apart are all noise.
    #[test]
    fn test_dbscan_all_noise() {
        let data = vec![
            vec![0.0, 0.0],
            vec![10.0, 0.0],
            vec![0.0, 10.0],
            vec![10.0, 10.0],
        ];
        let dbscan = Dbscan::new(0.5, 3);
        let labels = dbscan.fit_predict_with_noise(&data).unwrap();
        for label in labels {
            assert!(label.is_none());
        }
    }

    /// A single tight group ends up in one real cluster.
    #[test]
    fn test_dbscan_all_one_cluster() {
        let data = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![0.0, 0.1],
            vec![0.1, 0.1],
        ];
        let dbscan = Dbscan::new(0.5, 2);
        let labels = dbscan.fit_predict(&data).unwrap();
        let cluster = labels[0];
        // "All labels equal" is also true when every point is noise, so rule
        // that out explicitly.
        assert_ne!(cluster, NOISE);
        for label in labels {
            assert_eq!(label, cluster);
        }
    }

    /// Empty input is rejected with an error rather than an empty labelling.
    #[test]
    fn test_dbscan_empty() {
        let data: Vec<Vec<f32>> = vec![];
        let dbscan = Dbscan::new(0.5, 3);
        let result = dbscan.fit_predict(&data);
        assert!(result.is_err());
    }

    /// Invalid hyper-parameters make construction panic.
    #[test]
    fn test_dbscan_invalid_params() {
        assert!(std::panic::catch_unwind(|| Dbscan::new(0.0, 3)).is_err());
        assert!(std::panic::catch_unwind(|| Dbscan::new(-1.0, 3)).is_err());
        assert!(std::panic::catch_unwind(|| Dbscan::new(0.5, 0)).is_err());
    }

    /// Points spaced 0.3 apart with `epsilon = 0.5` chain into one cluster.
    #[test]
    fn test_dbscan_chain() {
        let data: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.3, 0.0]).collect();
        let dbscan = Dbscan::new(0.5, 2);
        let labels = dbscan.fit_predict(&data).unwrap();
        let cluster = labels[0];
        // Guard against the degenerate "everything is noise" outcome, which
        // would otherwise satisfy the equality loop below.
        assert_ne!(cluster, NOISE);
        for label in labels {
            assert_eq!(label, cluster);
        }
    }
}