use crate::dataset::Dataset;
use crate::distance::{cosine_distance, euclidean_sq, manhattan};
use crate::error::{Result, ScryLearnError};
use crate::neighbors::kdtree::KdTree;
use crate::neighbors::DistanceMetric;
// Above this dimensionality a k-d tree degenerates toward a linear scan
// (curse of dimensionality), so brute force is used instead.
const KDTREE_MAX_DIM: usize = 20;
/// Density-based spatial clustering (DBSCAN).
///
/// Points with at least `min_samples` neighbors within radius `eps`
/// (the point itself included) are core points; clusters are the
/// density-connected components of core points plus their border points.
/// Samples in no cluster are labeled `-1` (noise).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Dbscan {
// Neighborhood radius.
eps: f64,
// Minimum neighborhood size (including the point itself) for a core point.
min_samples: usize,
// Distance metric used for neighborhood queries.
metric: DistanceMetric,
// Per-sample cluster labels from the last `fit` (`-1` = noise), and the
// number of clusters found.
labels: Vec<i32>, n_clusters: usize,
// Core points retained after `fit`, used by `predict` to assign new points.
core_features: Vec<Vec<f64>>,
// Cluster label of each retained core point (parallel to `core_features`).
core_labels: Vec<i32>,
// True once `fit` has completed successfully.
fitted: bool,
// Schema version stamp; `serde(default)` yields 0 for legacy payloads,
// which `predict` rejects via `check_schema_version`.
#[cfg_attr(feature = "serde", serde(default))]
_schema_version: u32,
}
impl Dbscan {
/// Creates a DBSCAN model with neighborhood radius `eps` and the minimum
/// neighborhood size `min_samples` (the point itself counts) required for
/// a point to be a core point. The metric defaults to Euclidean.
pub fn new(eps: f64, min_samples: usize) -> Self {
Self {
eps,
min_samples,
metric: DistanceMetric::Euclidean,
labels: Vec::new(),
n_clusters: 0,
core_features: Vec::new(),
core_labels: Vec::new(),
fitted: false,
_schema_version: crate::version::SCHEMA_VERSION,
}
}
/// Builder-style setter for the distance metric.
pub fn metric(mut self, m: DistanceMetric) -> Self {
self.metric = m;
self
}
/// Clusters `data`, storing a label per sample (`0..n_clusters`, or `-1`
/// for noise) along with the core points used later by [`Self::predict`].
///
/// # Errors
/// Returns an error if the dataset contains non-finite values or is empty.
pub fn fit(&mut self, data: &Dataset) -> Result<()> {
data.validate_finite()?;
let n = data.n_samples();
if n == 0 {
return Err(ScryLearnError::EmptyDataset);
}
let rows = data.feature_matrix();
let n_features = data.n_features();
let threshold = self.eps_threshold();
// The k-d tree only accelerates Euclidean queries and loses its edge in
// high dimensions, so otherwise fall back to brute-force scans.
let use_kdtree =
matches!(self.metric, DistanceMetric::Euclidean) && n_features <= KDTREE_MAX_DIM;
let kdtree = if use_kdtree {
Some(KdTree::build(&rows))
} else {
None
};
let mut labels = vec![-1i32; n];
// Core status is recorded while expanding clusters; every clustered
// point has its neighborhood queried during expansion (as a seed, or
// when dequeued with its own cluster id), so a second full pass of
// radius queries after clustering is unnecessary.
let mut is_core = vec![false; n];
let mut cluster_id = 0i32;
for i in 0..n {
if labels[i] != -1 {
// Already claimed by a previously grown cluster.
continue;
}
let neighbors = self.find_neighbors(i, &rows, threshold, kdtree.as_ref());
if neighbors.len() < self.min_samples {
// Not a core point; stays noise unless a later cluster claims
// it as a border point.
continue;
}
is_core[i] = true;
labels[i] = cluster_id;
// Breadth-first expansion over the density-connected region.
let mut queue: Vec<usize> = neighbors.into_iter().filter(|&j| j != i).collect();
let mut qi = 0;
while qi < queue.len() {
let j = queue[qi];
qi += 1;
if labels[j] == -1 {
// Unvisited (or previously noise): claim for this cluster.
labels[j] = cluster_id;
}
if labels[j] != cluster_id {
// Border point of an earlier cluster; do not expand from it.
continue;
}
let j_neighbors = self.find_neighbors(j, &rows, threshold, kdtree.as_ref());
if j_neighbors.len() >= self.min_samples {
is_core[j] = true;
// j is core, so its whole neighborhood joins the cluster.
for k in j_neighbors {
if labels[k] == -1 {
labels[k] = cluster_id;
queue.push(k);
}
}
}
}
cluster_id += 1;
}
// Retain core points in sample order so `predict` can assign new
// observations to the nearest core point within `eps`.
let mut core_features = Vec::new();
let mut core_labels = Vec::new();
for i in 0..n {
if labels[i] >= 0 && is_core[i] {
core_features.push(rows[i].clone());
core_labels.push(labels[i]);
}
}
self.labels = labels;
self.n_clusters = cluster_id as usize;
self.core_features = core_features;
self.core_labels = core_labels;
self.fitted = true;
Ok(())
}
/// Returns the indices of all samples within `threshold` of `rows[idx]`
/// (the point itself included), via the k-d tree when available.
fn find_neighbors(
&self,
idx: usize,
rows: &[Vec<f64>],
threshold: f64,
kdtree: Option<&KdTree>,
) -> Vec<usize> {
kdtree.map_or_else(
|| {
let n = rows.len();
(0..n)
.filter(|&j| self.distance(&rows[idx], &rows[j]) <= threshold)
.collect()
},
|tree| {
tree.query_radius(&rows[idx], threshold, rows)
},
)
}
/// Metric-dependent distance. NOTE: for Euclidean this is the *squared*
/// distance (avoids the sqrt), which pairs with the squared threshold
/// from [`Self::eps_threshold`].
#[inline]
fn distance(&self, a: &[f64], b: &[f64]) -> f64 {
match self.metric {
DistanceMetric::Euclidean => euclidean_sq(a, b),
DistanceMetric::Manhattan => manhattan(a, b),
DistanceMetric::Cosine => cosine_distance(a, b),
}
}
/// The `eps` value expressed in the same space as [`Self::distance`]:
/// squared for Euclidean, unchanged otherwise.
#[inline]
fn eps_threshold(&self) -> f64 {
match self.metric {
DistanceMetric::Euclidean => self.eps * self.eps,
DistanceMetric::Manhattan | DistanceMetric::Cosine => self.eps,
}
}
/// Assigns each query point the label of its nearest core point within
/// `eps`, or `-1` (noise) when no core point is that close.
///
/// # Errors
/// Returns an error on schema-version mismatch or if the model has not
/// been fitted.
pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<i32>> {
crate::version::check_schema_version(self._schema_version)?;
if !self.fitted {
return Err(ScryLearnError::NotFitted);
}
let threshold = self.eps_threshold();
Ok(features
.iter()
.map(|query| {
let mut best_dist = f64::INFINITY;
let mut best_label = -1i32;
for (i, core_pt) in self.core_features.iter().enumerate() {
let d = self.distance(query, core_pt);
if d <= threshold && d < best_dist {
best_dist = d;
best_label = self.core_labels[i];
}
}
best_label
})
.collect())
}
/// Per-sample cluster labels from the last `fit` (`-1` = noise).
pub fn labels(&self) -> &[i32] {
&self.labels
}
/// Number of clusters found by the last `fit`.
pub fn n_clusters(&self) -> usize {
self.n_clusters
}
/// Number of samples labeled as noise by the last `fit`.
pub fn n_noise(&self) -> usize {
self.labels.iter().filter(|&&l| l == -1).count()
}
/// Number of core points retained for `predict`.
pub fn n_core_points(&self) -> usize {
self.core_features.len()
}
}
#[cfg(test)]
mod tests {
use super::*;

/// Builds two well-separated 10-point blobs (seeded RNG) as a
/// two-feature dataset; shared by the Euclidean and Manhattan tests.
fn two_blob_dataset() -> Dataset {
let mut rng = crate::rng::FastRng::new(0);
let mut xs = Vec::new();
let mut ys = Vec::new();
for _ in 0..10 {
xs.push(rng.f64() * 2.0);
ys.push(rng.f64() * 2.0);
}
for _ in 0..10 {
xs.push(50.0 + rng.f64() * 2.0);
ys.push(50.0 + rng.f64() * 2.0);
}
Dataset::new(
vec![xs, ys],
vec![0.0; 20],
vec!["x".into(), "y".into()],
"label",
)
}

#[test]
fn test_dbscan_two_clusters() {
let mut model = Dbscan::new(5.0, 3);
model.fit(&two_blob_dataset()).unwrap();
assert_eq!(model.n_clusters(), 2, "should find 2 clusters");
}

#[test]
fn test_dbscan_noise() {
// Three isolated points, all farther apart than eps: everything is noise.
let data = Dataset::new(
vec![vec![0.0, 100.0, 200.0], vec![0.0, 100.0, 200.0]],
vec![0.0; 3],
vec!["x".into(), "y".into()],
"label",
);
let mut model = Dbscan::new(1.0, 2);
model.fit(&data).unwrap();
assert_eq!(model.n_noise(), 3, "all points should be noise");
}

#[test]
fn test_dbscan_kdtree_parity() {
// Two tight blobs plus scattered background noise.
let mut rng = crate::rng::FastRng::new(42);
let total = 100;
let mut xs = Vec::with_capacity(total);
let mut ys = Vec::with_capacity(total);
for _ in 0..40 {
xs.push(rng.f64() * 3.0);
ys.push(rng.f64() * 3.0);
}
for _ in 0..40 {
xs.push(20.0 + rng.f64() * 3.0);
ys.push(20.0 + rng.f64() * 3.0);
}
for _ in 0..20 {
xs.push(rng.f64() * 100.0);
ys.push(rng.f64() * 100.0);
}
let data = Dataset::new(
vec![xs, ys],
vec![0.0; total],
vec!["x".into(), "y".into()],
"label",
);
// Fitting twice on identical data must give identical labelings.
let mut first = Dbscan::new(4.0, 3);
first.fit(&data).unwrap();
let mut second = Dbscan::new(4.0, 3);
second.fit(&data).unwrap();
assert_eq!(
first.labels().to_vec(),
second.labels().to_vec(),
"DBSCAN should be deterministic"
);
assert!(first.n_clusters() >= 2, "should find at least 2 clusters");
}

#[test]
fn test_dbscan_predict() {
// Two triplets of coincident points form two clusters.
let data = Dataset::new(
vec![
vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
],
vec![0.0; 6],
vec!["x".into(), "y".into()],
"label",
);
let mut model = Dbscan::new(5.0, 2);
model.fit(&data).unwrap();
assert_eq!(model.n_clusters(), 2);
let near_a = model.predict(&[vec![0.5, 0.5]]).unwrap();
assert!(near_a[0] >= 0, "Should be assigned to cluster A");
let near_b = model.predict(&[vec![10.5, 10.5]]).unwrap();
assert!(near_b[0] >= 0, "Should be assigned to cluster B");
assert_ne!(near_a[0], near_b[0], "Different clusters");
let far = model.predict(&[vec![500.0, 500.0]]).unwrap();
assert_eq!(far[0], -1, "Far point should be noise");
}

#[test]
fn test_dbscan_manhattan() {
let mut model = Dbscan::new(5.0, 3).metric(DistanceMetric::Manhattan);
model.fit(&two_blob_dataset()).unwrap();
assert_eq!(
model.n_clusters(),
2,
"Manhattan DBSCAN should find 2 clusters"
);
}
}