use std::collections::VecDeque;
use crate::{GeoFloat, MultiPoint, Point};
use rstar::RTree;
use rstar::primitives::GeomWithData;
/// Density-based clustering (DBSCAN) over a collection of points.
///
/// The implementations in this module group points that are densely packed
/// together and report every remaining point as noise.
pub trait Dbscan<T>
where
    T: GeoFloat,
{
    /// Assigns a cluster label to every point of `self`.
    ///
    /// # Parameters
    ///
    /// - `epsilon`: the neighbourhood radius; two points are neighbours when
    ///   they lie within this distance of each other.
    /// - `min_samples`: the number of points (the point itself included) that
    ///   must lie in a point's neighbourhood for it to seed or extend a
    ///   cluster.
    ///
    /// # Returns
    ///
    /// A vector with one entry per input point, in input order:
    /// `Some(cluster_id)` for clustered points and `None` for noise. Cluster
    /// ids start at `0` and increase in the order clusters are discovered.
    /// Empty input yields an empty vector; a `min_samples` of `0`, or one
    /// larger than the number of points, yields all `None`.
    fn dbscan(&self, epsilon: T, min_samples: usize) -> Vec<Option<usize>>;
}
/// Bookkeeping label attached to each input point while clustering runs.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum PointState {
    /// Not examined yet; every point starts here.
    Unvisited,
    /// Examined and found to have too few neighbours to seed a cluster.
    /// A noise point can still be picked up later by a neighbouring cluster.
    Noise,
    /// Waiting in the expansion queue of the cluster currently being grown.
    Queued,
    /// Assigned to the cluster with the given id.
    Clustered(usize),
}
/// Shared DBSCAN routine behind every [`Dbscan`] implementation.
///
/// Points are indexed in an R-tree so each neighbourhood lookup is a range
/// query. Clusters are grown breadth-first from each not-yet-visited core
/// point (a point with at least `min_samples` points, itself included, within
/// `epsilon`).
///
/// # Parameters
///
/// - `points`: the points to cluster.
/// - `epsilon`: neighbourhood radius. A negative radius matches nothing, so
///   every point is reported as noise.
/// - `min_samples`: neighbourhood size (the point itself included) required
///   for a point to be a core point.
///
/// # Returns
///
/// One entry per input point, in input order: `Some(cluster_id)` for clustered
/// points, `None` for noise. Cluster ids start at `0` and increase in discovery
/// order. Empty input returns an empty vector; `min_samples == 0` or
/// `min_samples > points.len()` returns all `None`.
fn dbscan_impl<T>(points: &[Point<T>], epsilon: T, min_samples: usize) -> Vec<Option<usize>>
where
    T: GeoFloat,
{
    let n = points.len();
    if n == 0 {
        return Vec::new();
    }
    // No point can qualify as a core point in these cases. The negative-radius
    // check matters because the radius is squared below: without it, `-1.0`
    // would silently behave exactly like `1.0`.
    if min_samples == 0 || min_samples > n || epsilon < T::zero() {
        return vec![None; n];
    }

    // The range query is given the squared radius; compute it once.
    let epsilon_squared = epsilon * epsilon;

    // Each tree entry carries the point's index in `points` as its payload.
    let tree = RTree::bulk_load(
        points
            .iter()
            .enumerate()
            .map(|(idx, &point)| GeomWithData::new(point, idx))
            .collect(),
    );

    let mut states = vec![PointState::Unvisited; n];
    let mut cluster_id = 0;
    // Scratch buffer reused for every neighbourhood query during expansion.
    let mut neighbours_buf = Vec::with_capacity(min_samples);
    let mut queue = VecDeque::new();

    for point_idx in 0..n {
        if states[point_idx] != PointState::Unvisited {
            continue;
        }

        // The seed's neighbourhood (which contains the seed itself) doubles as
        // the initial expansion queue.
        queue.clear();
        queue.extend(
            tree.locate_within_distance(points[point_idx], epsilon_squared)
                .map(|geom_with_data| geom_with_data.data),
        );
        if queue.len() < min_samples {
            // Not a core point. It may still be claimed as a border point by a
            // cluster discovered later.
            states[point_idx] = PointState::Noise;
            continue;
        }

        states[point_idx] = PointState::Clustered(cluster_id);
        for &neighbour_idx in &queue {
            match states[neighbour_idx] {
                PointState::Unvisited => states[neighbour_idx] = PointState::Queued,
                // A noise point is already known to be non-core, so it joins
                // as a border point right away; the stale queue entry is
                // skipped when popped.
                PointState::Noise => states[neighbour_idx] = PointState::Clustered(cluster_id),
                PointState::Queued | PointState::Clustered(_) => {}
            }
        }

        while let Some(current_idx) = queue.pop_front() {
            // Entries that are no longer `Queued` (the seed itself, or points
            // labelled directly above) need no further processing.
            if states[current_idx] != PointState::Queued {
                continue;
            }
            states[current_idx] = PointState::Clustered(cluster_id);

            neighbours_buf.clear();
            neighbours_buf.extend(
                tree.locate_within_distance(points[current_idx], epsilon_squared)
                    .map(|geom_with_data| geom_with_data.data),
            );
            // Only core points extend the cluster; border points stop here.
            if neighbours_buf.len() < min_samples {
                continue;
            }
            for &neighbour_idx in &neighbours_buf {
                match states[neighbour_idx] {
                    PointState::Unvisited => {
                        states[neighbour_idx] = PointState::Queued;
                        queue.push_back(neighbour_idx);
                    }
                    // Known non-core: label it without another range query.
                    PointState::Noise => {
                        states[neighbour_idx] = PointState::Clustered(cluster_id);
                    }
                    PointState::Queued | PointState::Clustered(_) => {}
                }
            }
        }

        cluster_id += 1;
    }

    states
        .into_iter()
        .map(|state| match state {
            PointState::Clustered(id) => Some(id),
            // `Queued` cannot survive the loop (the queue is always drained);
            // it is listed so a new variant forces this match to be revisited.
            PointState::Unvisited | PointState::Noise | PointState::Queued => None,
        })
        .collect()
}
/// DBSCAN clustering over the points stored in a [`MultiPoint`].
impl<T: GeoFloat> Dbscan<T> for MultiPoint<T> {
    /// Clusters the multipoint's points; labels follow the stored point order.
    fn dbscan(&self, epsilon: T, min_samples: usize) -> Vec<Option<usize>> {
        let points = &self.0[..];
        dbscan_impl(points, epsilon, min_samples)
    }
}
/// DBSCAN clustering through a shared reference to a [`MultiPoint`].
impl<T: GeoFloat> Dbscan<T> for &MultiPoint<T> {
    /// Clusters the referenced multipoint's points; labels follow the stored
    /// point order.
    fn dbscan(&self, epsilon: T, min_samples: usize) -> Vec<Option<usize>> {
        let points = &self.0[..];
        dbscan_impl(points, epsilon, min_samples)
    }
}
/// DBSCAN clustering over a plain slice of points.
impl<T: GeoFloat> Dbscan<T> for [Point<T>] {
    /// Clusters the slice's points; labels follow the slice order.
    fn dbscan(&self, epsilon: T, min_samples: usize) -> Vec<Option<usize>> {
        let points: &[Point<T>] = self;
        dbscan_impl(points, epsilon, min_samples)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::point;

    // Empty input produces an empty label vector.
    #[test]
    fn test_dbscan_empty() {
        let points: Vec<Point<f64>> = vec![];
        let labels = points.dbscan(1.0, 2);
        assert_eq!(labels.len(), 0);
    }

    // `min_samples` larger than the input can never be satisfied.
    #[test]
    fn test_dbscan_single_point() {
        let points = [point!(x: 0.0, y: 0.0)];
        let labels = points.dbscan(1.0, 2);
        assert_eq!(labels, vec![None]);
    }

    #[test]
    fn test_dbscan_all_noise() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 10.0, y: 10.0),
            point!(x: 20.0, y: 20.0),
        ];
        let labels = points.dbscan(1.0, 2);
        assert_eq!(labels, vec![None, None, None]);
    }

    #[test]
    fn test_dbscan_single_cluster() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
            point!(x: 1.0, y: 1.0),
        ];
        let labels = points.dbscan(1.5, 2);
        assert_eq!(labels[0], Some(0));
        assert_eq!(labels[1], Some(0));
        assert_eq!(labels[2], Some(0));
        assert_eq!(labels[3], Some(0));
    }

    #[test]
    fn test_dbscan_two_clusters() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
            point!(x: 10.0, y: 10.0),
            point!(x: 11.0, y: 10.0),
            point!(x: 10.0, y: 11.0),
        ];
        let labels = points.dbscan(2.0, 2);
        assert_eq!(labels[0], Some(0));
        assert_eq!(labels[1], Some(0));
        assert_eq!(labels[2], Some(0));
        assert_eq!(labels[3], Some(1));
        assert_eq!(labels[4], Some(1));
        assert_eq!(labels[5], Some(1));
    }

    #[test]
    fn test_dbscan_with_noise() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
            point!(x: 100.0, y: 100.0),
        ];
        let labels = points.dbscan(2.0, 2);
        assert_eq!(labels[0], Some(0));
        assert_eq!(labels[1], Some(0));
        assert_eq!(labels[2], Some(0));
        assert_eq!(labels[3], None);
    }

    // With `min_samples = 3`, (2, 0) has only two points in its neighbourhood
    // (itself and (1, 0)), so it is a genuine border point: not core, but
    // reachable from the core point (1, 0).
    #[test]
    fn test_dbscan_border_points() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.5, y: 0.5),
            point!(x: 2.0, y: 0.0),
        ];
        let labels = points.dbscan(1.5, 3);
        assert_eq!(labels[0], Some(0));
        assert_eq!(labels[1], Some(0));
        assert_eq!(labels[2], Some(0));
        assert_eq!(labels[3], Some(0));
    }

    // The border point comes first in the input, so it is initially marked as
    // noise and must be claimed later when the cluster around it is grown.
    #[test]
    fn test_dbscan_noise_promoted_to_border() {
        let points = [
            point!(x: 2.0, y: 0.0),
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.5, y: 0.5),
        ];
        let labels = points.dbscan(1.5, 3);
        assert_eq!(labels, vec![Some(0), Some(0), Some(0), Some(0)]);
    }

    // (2, 0) is a border point of the unit-square cluster (three points in its
    // neighbourhood, four required). (3, 0) is only within reach of that border
    // point, so the cluster must not spread to it.
    #[test]
    fn test_dbscan_border_point_does_not_expand() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
            point!(x: 1.0, y: 1.0),
            point!(x: 2.0, y: 0.0),
            point!(x: 3.0, y: 0.0),
        ];
        let labels = points.dbscan(1.5, 4);
        assert_eq!(
            labels,
            vec![Some(0), Some(0), Some(0), Some(0), Some(0), None]
        );
    }

    #[test]
    fn test_dbscan_multipoint() {
        let points = MultiPoint::new(vec![
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
            point!(x: 10.0, y: 10.0),
            point!(x: 11.0, y: 10.0),
            point!(x: 10.0, y: 11.0),
        ]);
        let labels = points.dbscan(2.0, 2);
        assert_eq!(labels[0], Some(0));
        assert_eq!(labels[1], Some(0));
        assert_eq!(labels[2], Some(0));
        assert_eq!(labels[3], Some(1));
        assert_eq!(labels[4], Some(1));
        assert_eq!(labels[5], Some(1));
    }

    // A point counts towards its own neighbourhood, so `min_samples = 1` turns
    // every isolated point into its own cluster.
    #[test]
    fn test_dbscan_min_samples_includes_self() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 10.0, y: 10.0),
            point!(x: 20.0, y: 20.0),
        ];
        let labels = points.dbscan(1.0, 1);
        assert_eq!(labels[0], Some(0));
        assert_eq!(labels[1], Some(1));
        assert_eq!(labels[2], Some(2));
    }

    #[test]
    fn test_dbscan_varying_density() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 0.5, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.5, y: 0.5),
            point!(x: 10.0, y: 10.0),
            point!(x: 12.0, y: 10.0),
            point!(x: 11.0, y: 12.0),
        ];
        let labels = points.dbscan(2.5, 2);
        assert_eq!(labels[0], Some(0));
        assert_eq!(labels[1], Some(0));
        assert_eq!(labels[2], Some(0));
        assert_eq!(labels[3], Some(0));
        assert_eq!(labels[4], Some(1));
        assert_eq!(labels[5], Some(1));
        assert_eq!(labels[6], Some(1));
    }

    #[test]
    fn test_dbscan_min_samples_too_large() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 0.0, y: 1.0),
        ];
        let labels = points.dbscan(2.0, 10);
        assert_eq!(labels, vec![None, None, None]);
    }

    #[test]
    fn test_dbscan_min_samples_zero() {
        let points = [point!(x: 0.0, y: 0.0), point!(x: 1.0, y: 0.0)];
        let labels = points.dbscan(2.0, 0);
        assert_eq!(labels, vec![None, None]);
    }

    #[test]
    fn test_dbscan_identical_points() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 0.0, y: 0.0),
            point!(x: 0.0, y: 0.0),
        ];
        let labels = points.dbscan(0.1, 2);
        assert_eq!(labels[0], Some(0));
        assert_eq!(labels[1], Some(0));
        assert_eq!(labels[2], Some(0));
    }

    // Neighbouring points chain together even though the two ends of the line
    // are far more than `epsilon` apart.
    #[test]
    fn test_dbscan_linear_cluster() {
        let points = [
            point!(x: 0.0, y: 0.0),
            point!(x: 1.0, y: 0.0),
            point!(x: 2.0, y: 0.0),
            point!(x: 3.0, y: 0.0),
            point!(x: 4.0, y: 0.0),
        ];
        let labels = points.dbscan(1.5, 2);
        assert!(labels.iter().all(|&label| label == Some(0)));
    }
}