use std::{collections::VecDeque, fmt::Display};
use thiserror::Error;
use crate::math::{
clustering::{Cluster, ClusteringAlgorithm},
neighbors::{kdtree::KDTreeSearch, neighbor::Neighbor, search::NeighborSearch},
DistanceMetric,
FloatNumber,
Point,
};
/// Errors produced while configuring or running the [`DBSCAN`] algorithm.
///
/// `T` is the floating point type used for the epsilon radius; it has to be
/// `Display` so that a rejected epsilon can be rendered in the error message.
#[derive(Debug, PartialEq, Error)]
pub enum DBSCANError<T: FloatNumber + Display> {
    /// The requested minimum number of points was zero.
    ///
    /// Carries the rejected value.
    #[error("Invalid minimum points: The minimum number of points must be greater than zero: {0}")]
    InvalidMinPoints(usize),

    /// The requested epsilon was not greater than zero.
    ///
    /// Carries the rejected value.
    #[error("Invalid epsilon: The epsilon must be greater than zero: {0}")]
    InvalidEpsilon(T),

    /// The slice of points handed to the algorithm was empty.
    #[error("Empty points: The points must be non-empty.")]
    EmptyPoints,
}
// Per-point label values used while clustering. Real cluster labels are
// non-negative (they start at 0 and count upwards), so every negative value
// below is a bookkeeping state rather than a cluster id.

/// Label of a point whose neighborhood held fewer than `min_points` entries
/// when it was visited. Such a point can still be absorbed into a cluster
/// later if it turns up in the neighborhood of another point.
const OUTLIER: i32 = -1;
/// Label of a previously unclassified point that has been queued for cluster
/// expansion but has not been processed yet.
const MARKED: i32 = -2;
/// Initial label of every point before it has been visited.
const UNCLASSIFIED: i32 = -3;
/// Configuration of the DBSCAN density-based clustering algorithm.
///
/// Instances are created through [`DBSCAN::new`], which validates the
/// parameters.
#[derive(Debug, PartialEq)]
// The type deliberately keeps the all-caps algorithm name, which this lint
// would otherwise flag.
#[allow(clippy::upper_case_acronyms)]
pub struct DBSCAN<T: FloatNumber> {
    /// Minimum size a radius search result must have for the searched point
    /// to seed or extend a cluster; also the minimum size of a reported
    /// cluster.
    min_points: usize,
    /// Radius passed to every neighbor search.
    epsilon: T,
    /// Distance metric handed to the neighbor search structure.
    metric: DistanceMetric,
}
impl<T> DBSCAN<T>
where
    T: FloatNumber,
{
    /// Creates a new `DBSCAN` instance after validating its parameters.
    ///
    /// # Arguments
    /// * `min_points` - Minimum neighborhood size for a point to seed or
    ///   extend a cluster; must be greater than zero.
    /// * `epsilon` - Radius used for neighbor searches; must be greater than
    ///   zero.
    /// * `metric` - Distance metric used by the neighbor search.
    ///
    /// # Errors
    /// * [`DBSCANError::InvalidMinPoints`] if `min_points` is zero.
    /// * [`DBSCANError::InvalidEpsilon`] if `epsilon` is not strictly greater
    ///   than zero. This includes NaN, which compares as neither greater nor
    ///   less than zero.
    pub fn new(
        min_points: usize,
        epsilon: T,
        metric: DistanceMetric,
    ) -> Result<Self, DBSCANError<T>> {
        if min_points == 0 {
            return Err(DBSCANError::InvalidMinPoints(min_points));
        }

        // Test for "strictly positive" rather than "<= 0": every ordered
        // comparison with NaN is false, so a `<= 0` check would silently
        // accept a NaN radius.
        let is_positive = epsilon > T::zero();
        if !is_positive {
            return Err(DBSCANError::InvalidEpsilon(epsilon));
        }

        Ok(Self {
            min_points,
            epsilon,
            metric,
        })
    }

    /// Grows a single cluster outwards from the neighborhood of a seed point.
    ///
    /// # Arguments
    /// * `label` - Non-negative label written into `labels` for every point
    ///   that joins the cluster.
    /// * `labels` - Per-point label table, indexed like `points`; updated in
    ///   place.
    /// * `points` - All points being clustered.
    /// * `neighbors` - Initial work queue: the radius search result of the
    ///   seed point.
    /// * `neighbor_search` - Search structure used to look up further
    ///   neighborhoods.
    ///
    /// # Returns
    /// The cluster containing every point that received `label`.
    ///
    /// # Panics
    /// Panics if a neighbor index is out of bounds for `labels` or `points`.
    #[inline]
    #[must_use]
    fn expand_cluster<const N: usize, NS>(
        &self,
        label: i32,
        labels: &mut [i32],
        points: &[Point<T, N>],
        neighbors: Vec<Neighbor<T>>,
        neighbor_search: &NS,
    ) -> Cluster<T, N>
    where
        NS: NeighborSearch<T, N>,
    {
        let mut cluster = Cluster::new();
        let mut queue = VecDeque::from(neighbors);

        while let Some(neighbor) = queue.pop_front() {
            let index = neighbor.index;
            let previous_label = labels[index];

            // Already owned by a cluster (possibly this one, since outliers
            // may be queued more than once).
            if previous_label >= 0 {
                continue;
            }

            let point = &points[index];
            labels[index] = label;
            cluster.add_member(index, point);

            // A former outlier joins the cluster but is not expanded further:
            // it was already found to have too small a neighborhood.
            if previous_label == OUTLIER {
                continue;
            }

            let secondary_neighbors = neighbor_search.search_radius(point, self.epsilon);
            if secondary_neighbors.len() < self.min_points {
                // Neighborhood too small to keep growing the cluster from here.
                continue;
            }

            for secondary_neighbor in secondary_neighbors {
                match labels[secondary_neighbor.index] {
                    UNCLASSIFIED => {
                        // Mark on enqueue so the point is queued only once.
                        labels[secondary_neighbor.index] = MARKED;
                        queue.push_back(secondary_neighbor);
                    }
                    OUTLIER => queue.push_back(secondary_neighbor),
                    // Already queued or already assigned to a cluster.
                    _ => {}
                }
            }
        }

        cluster
    }
}
impl<T, const N: usize> ClusteringAlgorithm<T, N> for DBSCAN<T>
where
    T: FloatNumber,
{
    type Err = DBSCANError<T>;

    /// Clusters `points` and returns every cluster that reached the
    /// configured minimum size.
    ///
    /// Points that end up in no returned cluster are simply absent from the
    /// result.
    ///
    /// # Errors
    /// Returns [`DBSCANError::EmptyPoints`] when `points` is empty.
    fn fit(&self, points: &[Point<T, N>]) -> Result<Vec<Cluster<T, N>>, Self::Err> {
        if points.is_empty() {
            return Err(DBSCANError::EmptyPoints);
        }

        // NOTE(review): the literal 16 is presumably the KD-tree leaf size;
        // confirm against `KDTreeSearch::build`.
        let searcher = KDTreeSearch::build(points, self.metric, 16);
        let mut labels = vec![UNCLASSIFIED; points.len()];
        let mut found = Vec::new();
        let mut next_label = 0;

        for (seed_index, seed) in points.iter().enumerate() {
            // Only points nobody has looked at yet may seed a new cluster.
            if labels[seed_index] != UNCLASSIFIED {
                continue;
            }

            let reachable = searcher.search_radius(seed, self.epsilon);
            if reachable.len() < self.min_points {
                labels[seed_index] = OUTLIER;
                continue;
            }

            // Flag the fresh neighbors as queued before expansion so they
            // are not enqueued a second time while the cluster grows.
            for candidate in reachable.iter() {
                let slot = &mut labels[candidate.index];
                if *slot == UNCLASSIFIED {
                    *slot = MARKED;
                }
            }

            let grown =
                self.expand_cluster(next_label, &mut labels, points, reachable, &searcher);
            // The label is consumed even when the cluster is discarded below.
            next_label += 1;

            if grown.len() >= self.min_points {
                found.push(grown);
            }
        }

        Ok(found)
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::math::DistanceMetric;

    /// Seventeen 2-D points used as the shared clustering fixture.
    #[must_use]
    fn sample_points() -> Vec<Point<f32, 2>> {
        vec![
            [0.0, 0.0],
            [0.0, 1.0],
            [0.0, 7.0],
            [0.0, 8.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 2.0],
            [1.0, 7.0],
            [1.0, 8.0],
            [2.0, 1.0],
            [2.0, 2.0],
            [4.0, 3.0],
            [4.0, 4.0],
            [4.0, 5.0],
            [5.0, 3.0],
            [5.0, 4.0],
            [9.0, 8.0],
        ]
    }

    /// An empty fixture for exercising the empty-input error path.
    #[must_use]
    fn empty_points() -> Vec<Point<f32, 2>> {
        Vec::new()
    }

    /// Valid parameters are stored unchanged.
    #[test]
    fn test_new() {
        let dbscan = DBSCAN::new(5, 1e-3, DistanceMetric::Euclidean).unwrap();

        assert_eq!(dbscan.min_points, 5);
        assert_eq!(dbscan.epsilon, 1e-3);
        assert_eq!(dbscan.metric, DistanceMetric::Euclidean);
    }

    /// Invalid parameters are rejected with the matching error variant.
    #[rstest]
    #[case::invalid_min_points(
        0,
        1e-3,
        DistanceMetric::Euclidean,
        DBSCANError::InvalidMinPoints(0)
    )]
    #[case::invalid_epsilon(5, 0.0, DistanceMetric::Euclidean, DBSCANError::InvalidEpsilon(0.0))]
    fn test_new_error(
        #[case] min_points: usize,
        #[case] epsilon: f32,
        #[case] metric: DistanceMetric,
        #[case] expected: DBSCANError<f32>,
    ) {
        let outcome = DBSCAN::new(min_points, epsilon, metric);

        assert!(outcome.is_err());
        assert_eq!(outcome, Err(expected));
    }

    /// The fixture yields three clusters with the expected sizes and centroids.
    #[test]
    fn test_fit() {
        let fixture = sample_points();
        let dbscan = DBSCAN::new(4, 2.0, DistanceMetric::Euclidean).unwrap();

        let mut clusters = dbscan.fit(&fixture).unwrap();
        // Largest cluster first, so the assertions below can index by rank.
        clusters.sort_by_key(|cluster| std::cmp::Reverse(cluster.len()));

        assert_eq!(clusters.len(), 3);

        assert_eq!(clusters[0].len(), 7);
        assert_eq!(clusters[0].centroid(), &[1.0, 1.0]);

        assert_eq!(clusters[1].len(), 5);
        assert_eq!(clusters[1].centroid(), &[4.4, 3.8]);

        assert_eq!(clusters[2].len(), 4);
        assert_eq!(clusters[2].centroid(), &[0.5, 7.5]);
    }

    /// Fitting an empty slice reports `EmptyPoints`.
    #[test]
    fn test_fit_empty() {
        let fixture = empty_points();
        let dbscan = DBSCAN::new(4, 2.0, DistanceMetric::Euclidean).unwrap();

        let outcome = dbscan.fit(&fixture);

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), DBSCANError::EmptyPoints);
    }
}