use scirs2_core::ndarray::{Array2, ArrayView2};
use crate::error::{ClusteringError, Result};
/// Core distance of one input point, as produced by [`compute_core_distances`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CoreDistance {
    /// Row index of the point in the input data matrix.
    pub point_idx: usize,
    /// Core distance of the point, or `None` when the point does not have
    /// enough neighbours within the search radius to be a core point.
    pub core_dist: Option<f64>,
}
/// One entry of an OPTICS cluster ordering, as produced by [`optics`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ReachabilityPoint {
    /// Row index of the point in the input data matrix.
    pub point_idx: usize,
    /// Reachability distance assigned when the point was emitted, or `None`
    /// when the point was not reached from any previously processed core
    /// point (it opens a new component of the ordering).
    pub reachability_dist: Option<f64>,
}
/// Priority-queue entry: a candidate point together with the reachability
/// distance it was queued with.
///
/// The ordering is reversed on `reachability` so that
/// `std::collections::BinaryHeap` (a max-heap) pops the entry with the
/// smallest reachability first. Among equal reachabilities the entry with
/// the larger `point_idx` is popped first.
///
/// Reachabilities are compared with `f64::total_cmp`, so `Eq` and `Ord`
/// agree with each other and form a total order even for NaN.
#[derive(Debug, Clone)]
struct SeedEntry {
    point_idx: usize,
    reachability: f64,
}

impl PartialEq for SeedEntry {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality can never disagree with ordering.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for SeedEntry {}

impl PartialOrd for SeedEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SeedEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are swapped on purpose: a smaller reachability must rank higher.
        other
            .reachability
            .total_cmp(&self.reachability)
            .then_with(|| self.point_idx.cmp(&other.point_idx))
    }
}
/// Squared Euclidean distance between two coordinate slices.
///
/// Coordinates are paired positionally; if the slices differ in length the
/// surplus coordinates of the longer one are ignored.
#[inline]
fn sq_euclid(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(&p, &q)| {
            let delta = p - q;
            delta * delta
        })
        .sum()
}
/// Euclidean distance between two coordinate slices: the square root of
/// [`sq_euclid`].
#[inline]
fn euclid(a: &[f64], b: &[f64]) -> f64 {
    let squared = sq_euclid(a, b);
    f64::sqrt(squared)
}
/// Builds the dense, symmetric `n x n` matrix of pairwise Euclidean distances
/// between the rows of `data`, where `n` is the number of rows.
///
/// The diagonal is left at zero. Only the upper triangle is computed; each
/// value is mirrored into the lower triangle.
fn build_distance_matrix(data: ArrayView2<f64>) -> Array2<f64> {
    let n = data.shape()[0];
    // Copy every row exactly once so the pairwise loop below does not
    // allocate a fresh vector for each (i, j) pair.
    let rows: Vec<Vec<f64>> = (0..n).map(|i| data.row(i).to_vec()).collect();
    let mut dm = Array2::<f64>::zeros((n, n));
    for (i, ri) in rows.iter().enumerate() {
        for (j, rj) in rows.iter().enumerate().skip(i + 1) {
            let d = euclid(ri, rj);
            dm[[i, j]] = d;
            dm[[j, i]] = d;
        }
    }
    dm
}
/// Lists, in ascending index order, every point other than `point_idx` whose
/// distance to `point_idx` in `dm` is at most `max_eps`.
fn neighbours_within(point_idx: usize, dm: &Array2<f64>, max_eps: f64) -> Vec<usize> {
    let total = dm.shape()[0];
    let mut found = Vec::new();
    for candidate in 0..total {
        // A point is never its own neighbour.
        if candidate == point_idx {
            continue;
        }
        if dm[[point_idx, candidate]] <= max_eps {
            found.push(candidate);
        }
    }
    found
}
/// Core distance of `point_idx` given its precomputed neighbour list.
///
/// `neighbours` excludes the point itself, while `min_pts` counts it, so the
/// point qualifies as a core point when `neighbours.len() + 1 >= min_pts`.
/// In that case the neighbour distances are sorted ascending and the value at
/// rank `min_pts - 2` (clamped at 0) is returned. Returns `None` when the
/// point does not qualify or when that rank does not exist.
fn core_distance(
    point_idx: usize,
    neighbours: &[usize],
    dm: &Array2<f64>,
    min_pts: usize,
) -> Option<f64> {
    let dense_enough = neighbours.len() + 1 >= min_pts;
    if !dense_enough {
        return None;
    }

    let mut sorted: Vec<f64> = Vec::with_capacity(neighbours.len());
    for &nb in neighbours {
        sorted.push(dm[[point_idx, nb]]);
    }
    // Incomparable pairs (NaN) are treated as equal.
    sorted.sort_by(|lhs: &f64, rhs: &f64| match lhs.partial_cmp(rhs) {
        Some(order) => order,
        None => std::cmp::Ordering::Equal,
    });

    let rank = min_pts.saturating_sub(2);
    sorted.get(rank).copied()
}
/// Relaxes the reachability of every unprocessed neighbour of a core point
/// and queues the improved candidates.
///
/// For each neighbour the candidate reachability is
/// `max(core_dist, distance(core_pt, neighbour))`. When that is smaller than
/// the value stored in `current_reach` (or nothing is stored yet), the stored
/// value is replaced and a new entry is pushed onto `seeds`. Older heap
/// entries for the same point are not removed, so consumers must skip entries
/// whose point has already been processed.
///
/// * `core_pt` - index of the core point being expanded.
/// * `core_dist` - core distance of `core_pt`.
/// * `neighbours` - indices of the points in the neighbourhood of `core_pt`.
/// * `dm` - pairwise distance matrix.
/// * `processed` - per-point flag; flagged points are left untouched.
/// * `current_reach` - best reachability found so far for each point.
/// * `seeds` - priority queue receiving the improved candidates.
fn update_seeds(
    core_pt: usize,
    core_dist: f64,
    neighbours: &[usize],
    dm: &Array2<f64>,
    processed: &[bool],
    current_reach: &mut [Option<f64>],
    seeds: &mut std::collections::BinaryHeap<SeedEntry>,
) {
    for &nb in neighbours {
        if processed[nb] {
            continue;
        }
        let new_reach = core_dist.max(dm[[core_pt, nb]]);
        let improves = match current_reach[nb] {
            None => true,
            Some(old) => new_reach < old,
        };
        if improves {
            current_reach[nb] = Some(new_reach);
            seeds.push(SeedEntry {
                point_idx: nb,
                reachability: new_reach,
            });
        }
    }
}
/// Computes the OPTICS cluster ordering of the rows of `data`.
///
/// Every row appears exactly once in the returned ordering. A point that
/// opens a new component (it was not reached from any earlier core point)
/// carries `reachability_dist == None`; every other point carries the
/// smallest reachability distance found for it before it was emitted.
///
/// * `data` - one point per row.
/// * `min_pts` - number of points (the point itself included) required
///   within `max_eps` for a point to be a core point; must be at least 2.
/// * `max_eps` - neighbourhood radius; must be greater than zero and may be
///   `f64::INFINITY`.
///
/// A dense `n x n` distance matrix is built up front, so memory use grows
/// quadratically with the number of rows.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `data` has no rows, when
/// `min_pts < 2`, or when `max_eps` is NaN or not strictly positive.
pub fn optics(
    data: ArrayView2<f64>,
    min_pts: usize,
    max_eps: f64,
) -> Result<Vec<ReachabilityPoint>> {
    let n = data.shape()[0];
    if n == 0 {
        return Err(ClusteringError::InvalidInput("Empty input data".into()));
    }
    if min_pts < 2 {
        return Err(ClusteringError::InvalidInput(
            "min_pts must be >= 2".into(),
        ));
    }
    // NaN compares false against everything, so it needs its own test.
    if max_eps.is_nan() || max_eps <= 0.0 {
        return Err(ClusteringError::InvalidInput(
            "max_eps must be > 0".into(),
        ));
    }

    let dm = build_distance_matrix(data);
    let mut processed = vec![false; n];
    let mut current_reach: Vec<Option<f64>> = vec![None; n];
    let mut ordering: Vec<ReachabilityPoint> = Vec::with_capacity(n);

    for start in 0..n {
        if processed[start] {
            continue;
        }
        // `start` opens a new component: it has no predecessor to be reached from.
        processed[start] = true;
        let nbrs = neighbours_within(start, &dm, max_eps);
        ordering.push(ReachabilityPoint {
            point_idx: start,
            reachability_dist: None,
        });
        let cd_val = match core_distance(start, &nbrs, &dm, min_pts) {
            Some(cd) => cd,
            // Not a core point: nothing can be expanded from it.
            None => continue,
        };

        let mut seeds = std::collections::BinaryHeap::new();
        update_seeds(
            start,
            cd_val,
            &nbrs,
            &dm,
            &processed,
            &mut current_reach,
            &mut seeds,
        );
        while let Some(entry) = seeds.pop() {
            let pt = entry.point_idx;
            // Stale heap entry: the point was already emitted through a
            // better (smaller) reachability.
            if processed[pt] {
                continue;
            }
            processed[pt] = true;
            let pt_nbrs = neighbours_within(pt, &dm, max_eps);
            ordering.push(ReachabilityPoint {
                point_idx: pt,
                reachability_dist: current_reach[pt],
            });
            if let Some(pt_cd_val) = core_distance(pt, &pt_nbrs, &dm, min_pts) {
                update_seeds(
                    pt,
                    pt_cd_val,
                    &pt_nbrs,
                    &dm,
                    &processed,
                    &mut current_reach,
                    &mut seeds,
                );
            }
        }
    }
    Ok(ordering)
}
/// Derives flat cluster labels from an OPTICS ordering using a fixed
/// reachability threshold `eps`.
///
/// The ordering is walked front to back:
///
/// * an entry whose reachability is `None` or greater than `eps` starts a new
///   cluster;
/// * any other entry inherits the label of the entry immediately before it in
///   the ordering (or starts a new cluster if that entry is unlabelled).
///
/// Core distances are not part of `ReachabilityPoint`, so an unreachable
/// point becomes a cluster of its own rather than being marked as noise.
///
/// The result has one slot per ordering entry and is indexed by `point_idx`.
/// Slots that are never written keep the value `-1`; entries whose
/// `point_idx` is not smaller than `reachability.len()` are skipped.
pub fn extract_dbscan(reachability: &[ReachabilityPoint], eps: f64) -> Vec<i32> {
    let n = reachability.len();
    let mut labels = vec![-1i32; n];
    let mut cluster_id: i32 = -1;

    for (pos, rp) in reachability.iter().enumerate() {
        let starts_new_cluster = match rp.reachability_dist {
            Some(r) => r > eps,
            None => true,
        };

        let label = if starts_new_cluster {
            cluster_id += 1;
            cluster_id
        } else if pos > 0 {
            let prev_label = labels
                .get(reachability[pos - 1].point_idx)
                .copied()
                .unwrap_or(-1);
            if prev_label >= 0 {
                prev_label
            } else {
                cluster_id += 1;
                cluster_id
            }
        } else {
            // A reachable first entry has no predecessor to inherit from.
            continue;
        };

        // Ignore indices that do not fit the label vector instead of panicking.
        if let Some(slot) = labels.get_mut(rp.point_idx) {
            *slot = label;
        }
    }
    labels
}
/// Returns the maximal runs of consecutive positions in `0..n` for which
/// `is_steep` holds, as inclusive `(first, last)` pairs in ascending order.
fn xi_steep_runs(n: usize, is_steep: impl Fn(usize) -> bool) -> Vec<(usize, usize)> {
    let mut runs = Vec::new();
    let mut i = 0;
    while i < n {
        if is_steep(i) {
            let first = i;
            while i + 1 < n && is_steep(i + 1) {
                i += 1;
            }
            runs.push((first, i));
        }
        i += 1;
    }
    runs
}

/// Extracts flat cluster labels from an OPTICS ordering by pairing steep
/// drops with steep rises of the reachability plot.
///
/// Steps:
///
/// 1. Missing or non-finite reachabilities are replaced by a value above
///    every finite one, and the same value is appended as a sentinel after
///    the last position so that a valley at the end of the ordering can close.
/// 2. Position `i` is *steep downward* when `r[i] * (1 - xi) >= r[i + 1]` and
///    *steep upward* when `r[i] <= r[i + 1] * (1 - xi)` (the xi-steep
///    definitions of the OPTICS paper). Consecutive steep positions form runs.
/// 3. Each downward run is paired with the first later upward run for which
///    at least one position lies between them, the reachability at the start
///    of the drop and just after the rise differ by no more than a factor of
///    `(1 - xi)^2`, and the valley floor lies below the higher of the two.
///    The cluster spans from the first position of the drop to the last
///    position of the rise; the position after the rise is not included.
/// 4. Ranges contained in an earlier kept range are discarded, and the
///    remaining ranges are numbered `0, 1, ...` in order of their start.
///
/// The result has one slot per ordering entry, indexed by `point_idx`;
/// points outside every range keep the label `-1`.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `xi` is NaN or lies outside
/// the open interval `(0, 1)`.
pub fn extract_xi_clusters(reachability: &[ReachabilityPoint], xi: f64) -> Result<Vec<i32>> {
    // NaN compares false against everything, so it needs its own test.
    if xi.is_nan() || xi <= 0.0 || xi >= 1.0 {
        return Err(ClusteringError::InvalidInput(
            "xi must be in (0, 1)".into(),
        ));
    }
    let n = reachability.len();
    if n == 0 {
        return Ok(Vec::new());
    }

    let max_finite = reachability
        .iter()
        .filter_map(|rp| rp.reachability_dist)
        .filter(|r| r.is_finite())
        .fold(f64::NEG_INFINITY, f64::max);
    let fill = if max_finite.is_finite() {
        max_finite * 1.1 + 1.0
    } else {
        1.0
    };

    // Finite reachability profile with one sentinel entry at index `n`.
    let mut rf: Vec<f64> = reachability
        .iter()
        .map(|rp| match rp.reachability_dist {
            Some(r) if r.is_finite() => r,
            _ => fill,
        })
        .collect();
    rf.push(fill);

    // Both predicates are only evaluated for `i < n`, so `rf[i + 1]` exists.
    let keep_ratio = 1.0 - xi;
    let is_steep_down = |i: usize| rf[i] > 0.0 && rf[i] * keep_ratio >= rf[i + 1];
    let is_steep_up = |i: usize| rf[i + 1] > 0.0 && rf[i] <= rf[i + 1] * keep_ratio;

    // (first, last, reachability at the top of the drop)
    let sd_areas: Vec<(usize, usize, f64)> = xi_steep_runs(n, is_steep_down)
        .into_iter()
        .map(|(s, e)| (s, e, rf[s]))
        .collect();
    // (first, last, reachability just after the rise)
    let su_areas: Vec<(usize, usize, f64)> = xi_steep_runs(n, is_steep_up)
        .into_iter()
        .map(|(s, e)| (s, e, rf[e + 1]))
        .collect();

    let similarity_floor = keep_ratio.powi(2);
    let mut cluster_ranges: Vec<(usize, usize)> = Vec::new();
    for &(sd_s, sd_e, sd_r) in &sd_areas {
        for &(su_s, su_e, su_r) in &su_areas {
            // The valley runs from the bottom of the drop up to, but not
            // including, the first rising position; it must not be empty.
            let valley_lo = sd_e + 1;
            let valley_hi = su_s;
            if valley_lo >= valley_hi {
                continue;
            }
            let r_high = sd_r.max(su_r);
            let r_low = sd_r.min(su_r);
            if r_high <= 0.0 || r_low / r_high < similarity_floor {
                continue;
            }
            let valley_min = rf[valley_lo..valley_hi]
                .iter()
                .copied()
                .fold(f64::INFINITY, f64::min);
            if valley_min < r_high {
                cluster_ranges.push((sd_s, su_e));
                break;
            }
        }
    }

    // Earliest start first; for equal starts the longer range first, so an
    // enclosing range is always visited before the ranges it contains.
    cluster_ranges.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| (b.1 - b.0).cmp(&(a.1 - a.0))));
    let mut keep = vec![true; cluster_ranges.len()];
    for outer in 0..cluster_ranges.len() {
        if !keep[outer] {
            continue;
        }
        let (outer_s, outer_e) = cluster_ranges[outer];
        for inner in (outer + 1)..cluster_ranges.len() {
            let (inner_s, inner_e) = cluster_ranges[inner];
            if keep[inner] && inner_s >= outer_s && inner_e <= outer_e {
                keep[inner] = false;
            }
        }
    }

    let mut labels = vec![-1i32; n];
    let kept_ranges = cluster_ranges
        .iter()
        .zip(keep.iter())
        .filter_map(|(&range, &kept)| if kept { Some(range) } else { None });
    for (cid, (range_s, range_e)) in kept_ranges.enumerate() {
        for rp in &reachability[range_s..=range_e.min(n - 1)] {
            // The first cluster to claim a point keeps it.
            if let Some(slot) = labels.get_mut(rp.point_idx) {
                if *slot < 0 {
                    *slot = cid as i32;
                }
            }
        }
    }
    Ok(labels)
}
/// Converts an OPTICS ordering into plot coordinates.
///
/// Returns `(x, y)` where `x[i]` is the position `i` in the ordering as a
/// float and `y[i]` is the reachability distance at that position, with
/// `None` mapped to `f64::INFINITY`.
pub fn reachability_plot(optics_result: &[ReachabilityPoint]) -> (Vec<f64>, Vec<f64>) {
    let count = optics_result.len();
    let mut xs: Vec<f64> = Vec::with_capacity(count);
    let mut ys: Vec<f64> = Vec::with_capacity(count);
    for (pos, rp) in optics_result.iter().enumerate() {
        xs.push(pos as f64);
        ys.push(match rp.reachability_dist {
            Some(reach) => reach,
            None => f64::INFINITY,
        });
    }
    (xs, ys)
}
/// Computes the core distance of every row of `data`.
///
/// Returns one [`CoreDistance`] per row, in row order. `core_dist` is `None`
/// for rows that have fewer than `min_pts` points (the row itself included)
/// within `max_eps`. Empty input yields an empty vector. `min_pts` and
/// `max_eps` are used as given, without validation.
pub fn compute_core_distances(
    data: ArrayView2<f64>,
    min_pts: usize,
    max_eps: f64,
) -> Result<Vec<CoreDistance>> {
    let n = data.shape()[0];
    if n == 0 {
        return Ok(Vec::new());
    }

    let dm = build_distance_matrix(data);
    let mut per_point: Vec<CoreDistance> = Vec::with_capacity(n);
    for point_idx in 0..n {
        let nbrs = neighbours_within(point_idx, &dm, max_eps);
        let core_dist = core_distance(point_idx, &nbrs, &dm, min_pts);
        per_point.push(CoreDistance {
            point_idx,
            core_dist,
        });
    }
    Ok(per_point)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Fixture: two tight blobs of seven 2-D points each. Rows 0..7 lie near
    /// (1, 2) and rows 7..14 lie near (8, 8).
    fn two_cluster_data() -> Array2<f64> {
        let coords = vec![
            1.0, 2.0, 1.1, 1.9, 0.9, 2.1, 1.2, 1.8, 0.8, 2.0, 1.0, 2.2, 1.15, 1.85,
            8.0, 8.0, 8.1, 7.9, 7.9, 8.1, 8.2, 7.8, 7.8, 8.0, 8.0, 8.2, 8.15, 7.85,
        ];
        Array2::from_shape_vec((14, 2), coords).expect("shape ok")
    }

    /// OPTICS ordering of the fixture with `min_pts = 3` and unbounded radius.
    fn default_ordering() -> Vec<ReachabilityPoint> {
        let data = two_cluster_data();
        optics(data.view(), 3, f64::INFINITY).expect("optics")
    }

    // ---- optics ----

    #[test]
    fn test_optics_produces_full_ordering() {
        let ord = default_ordering();
        assert_eq!(ord.len(), 14, "every point must appear in ordering");
        let mut seen = [false; 14];
        for rp in ord.iter() {
            assert!(!seen[rp.point_idx], "duplicate index {}", rp.point_idx);
            seen[rp.point_idx] = true;
        }
        assert!(seen.iter().all(|&s| s), "missing indices in ordering");
    }

    #[test]
    fn test_optics_first_point_has_no_reachability() {
        let ord = default_ordering();
        assert!(
            ord[0].reachability_dist.is_none(),
            "first ordering entry should have reachability = None"
        );
    }

    #[test]
    fn test_optics_within_cluster_reachabilities_small() {
        let ord = default_ordering();
        // 0 for the first blob, 1 for the second.
        let blob_of = |rp: &ReachabilityPoint| usize::from(rp.point_idx >= 7);
        // Reachabilities of entries whose predecessor belongs to the same blob.
        let within_reaches: Vec<f64> = ord
            .windows(2)
            .filter(|pair| blob_of(&pair[0]) == blob_of(&pair[1]))
            .filter_map(|pair| pair[1].reachability_dist)
            .collect();
        if !within_reaches.is_empty() {
            let total: f64 = within_reaches.iter().sum();
            let avg = total / within_reaches.len() as f64;
            assert!(avg < 2.0, "expected small within-cluster reach, got {}", avg);
        }
    }

    #[test]
    fn test_optics_max_eps_restricts_reachability() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 2, 0.01).expect("optics");
        assert_eq!(ord.len(), 14);
        let any_reachable = ord.iter().any(|rp| rp.reachability_dist.is_some());
        assert!(!any_reachable, "with tiny max_eps every point is isolated");
    }

    #[test]
    fn test_optics_single_point() {
        let data = Array2::from_shape_vec((1, 2), vec![3.0, 4.0]).expect("shape");
        let ord = optics(data.view(), 2, f64::INFINITY).expect("optics");
        assert_eq!(ord.len(), 1);
        let only = &ord[0];
        assert_eq!(only.point_idx, 0);
        assert!(only.reachability_dist.is_none());
    }

    #[test]
    fn test_optics_error_empty() {
        let data = Array2::<f64>::zeros((0, 2));
        let outcome = optics(data.view(), 2, f64::INFINITY);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_optics_error_min_pts_too_small() {
        let data = two_cluster_data();
        let outcome = optics(data.view(), 1, f64::INFINITY);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_optics_error_non_positive_max_eps() {
        let data = two_cluster_data();
        for &bad_eps in [0.0, -1.0].iter() {
            assert!(optics(data.view(), 3, bad_eps).is_err());
        }
    }

    // ---- extract_dbscan ----

    #[test]
    fn test_extract_dbscan_two_clusters() {
        let labels = extract_dbscan(&default_ordering(), 0.5);
        assert_eq!(labels.len(), 14);
        let (a_labels, b_labels) = labels.split_at(7);
        assert!(a_labels.iter().all(|&l| l >= 0));
        assert!(b_labels.iter().all(|&l| l >= 0));
        // Most frequent label within each blob.
        let a_mode = *a_labels
            .iter()
            .max_by_key(|&&l| a_labels.iter().filter(|&&x| x == l).count())
            .expect("a has labels");
        let b_mode = *b_labels
            .iter()
            .max_by_key(|&&l| b_labels.iter().filter(|&&x| x == l).count())
            .expect("b has labels");
        assert_ne!(a_mode, b_mode, "clusters should receive distinct labels");
    }

    #[test]
    fn test_extract_dbscan_all_noise_small_eps() {
        let labels = extract_dbscan(&default_ordering(), 1e-10);
        assert_eq!(labels.len(), 14);
    }

    #[test]
    fn test_extract_dbscan_empty_ordering() {
        let labels = extract_dbscan(&[], 0.5);
        assert!(labels.is_empty());
    }

    // ---- extract_xi_clusters ----

    #[test]
    fn test_extract_xi_returns_correct_length() {
        let labels = extract_xi_clusters(&default_ordering(), 0.05).expect("xi");
        assert_eq!(labels.len(), 14);
    }

    #[test]
    fn test_extract_xi_labels_valid_range() {
        let labels = extract_xi_clusters(&default_ordering(), 0.1).expect("xi");
        assert!(labels.iter().all(|&l| l >= -1), "labels must be >= -1");
    }

    #[test]
    fn test_extract_xi_error_invalid_xi() {
        let ord = default_ordering();
        for &bad_xi in [0.0, 1.0, -0.1, 1.5].iter() {
            assert!(extract_xi_clusters(&ord, bad_xi).is_err());
        }
    }

    #[test]
    fn test_extract_xi_empty_ordering() {
        let labels = extract_xi_clusters(&[], 0.1).expect("xi empty");
        assert!(labels.is_empty());
    }

    // ---- reachability_plot ----

    #[test]
    fn test_reachability_plot_lengths() {
        let (xs, ys) = reachability_plot(&default_ordering());
        assert_eq!(xs.len(), 14);
        assert_eq!(ys.len(), 14);
    }

    #[test]
    fn test_reachability_plot_x_sequential() {
        let (xs, _ys) = reachability_plot(&default_ordering());
        for (i, &x) in xs.iter().enumerate() {
            let gap = (x - i as f64).abs();
            assert!(gap < 1e-12, "x[{}] should be {}, got {}", i, i, x);
        }
    }

    #[test]
    fn test_reachability_plot_none_becomes_infinity() {
        let (_, ys) = reachability_plot(&default_ordering());
        assert!(ys[0].is_infinite(), "component root should be INFINITY");
    }

    #[test]
    fn test_reachability_plot_empty() {
        let (xs, ys) = reachability_plot(&[]);
        assert!(xs.is_empty());
        assert!(ys.is_empty());
    }

    // ---- compute_core_distances ----

    #[test]
    fn test_core_distances_length() {
        let data = two_cluster_data();
        let cds = compute_core_distances(data.view(), 3, f64::INFINITY).expect("cds");
        assert_eq!(cds.len(), 14);
    }

    #[test]
    fn test_core_distances_dense_cluster_are_core() {
        let data = two_cluster_data();
        let cds = compute_core_distances(data.view(), 3, f64::INFINITY).expect("cds");
        let n_core = cds.iter().filter(|cd| cd.core_dist.is_some()).count();
        assert!(
            n_core >= 10,
            "most points should be core points, got {}",
            n_core
        );
    }

    #[test]
    fn test_core_distances_tiny_eps_no_cores() {
        let data = two_cluster_data();
        let cds = compute_core_distances(data.view(), 3, 1e-15).expect("cds");
        let n_core = cds.iter().filter(|cd| cd.core_dist.is_some()).count();
        assert_eq!(n_core, 0, "no cores expected with tiny eps");
    }

    #[test]
    fn test_core_distances_empty_data() {
        let data = Array2::<f64>::zeros((0, 2));
        let cds = compute_core_distances(data.view(), 3, f64::INFINITY).expect("cds empty");
        assert!(cds.is_empty());
    }
}