use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::Fit;
use ndarray::{Array1, Array2};
use num_traits::Float;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// Priority-queue entry pairing a candidate sample with the reachability
/// distance it had when it was pushed.
///
/// `BinaryHeap` is a max-heap, so the [`Ord`] implementation is reversed: the
/// entry with the *smallest* `reach_dist` compares greatest and is popped
/// first. Ties are broken in favour of the smaller `idx`, which keeps the
/// ordering consistent with the derived `PartialEq` (both fields take part)
/// and makes the pop order deterministic.
#[derive(Clone, Copy, PartialEq)]
struct SeedEntry<F: Float> {
    /// Reachability distance recorded at push time.
    reach_dist: F,
    /// Index of the sample this entry refers to.
    idx: usize,
}

// `Eq` is asserted by hand so the type can implement `Ord`; the derived
// `PartialEq` compares the float field directly.
impl<F: Float> Eq for SeedEntry<F> {}

impl<F: Float> Ord for SeedEntry<F> {
    /// Reversed comparison on `reach_dist`, then reversed comparison on `idx`.
    ///
    /// A NaN distance is ranked below every real distance (it is popped
    /// last) so that the relation stays a total order even for NaN input.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .reach_dist
            .partial_cmp(&self.reach_dist)
            // `partial_cmp` only fails when at least one side is NaN.
            .unwrap_or_else(|| other.reach_dist.is_nan().cmp(&self.reach_dist.is_nan()))
            .then_with(|| other.idx.cmp(&self.idx))
    }
}

impl<F: Float> PartialOrd for SeedEntry<F> {
    /// Delegates to [`Ord::cmp`], so the result is always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Unfitted OPTICS estimator holding the hyper-parameters used by `fit`.
///
/// `fit` validates these values and returns an `InvalidParameter` error when
/// `min_samples` is 0, `max_eps` is not positive, or `xi` is outside `(0, 1)`.
#[derive(Debug, Clone)]
pub struct OPTICS<F> {
/// Number of samples, the point itself included, that must lie within
/// `max_eps` for the point to have a finite core distance.
pub min_samples: usize,
/// Largest distance at which two samples are still treated as neighbours.
pub max_eps: F,
/// Steepness threshold used when extracting cluster labels from the
/// reachability values.
pub xi: F,
}
impl<F: Float> OPTICS<F> {
    /// Creates an estimator with the given `min_samples`, an unbounded
    /// `max_eps` (positive infinity) and `xi = 0.05`.
    ///
    /// # Panics
    ///
    /// Panics if `0.05` cannot be converted into `F`.
    #[must_use]
    pub fn new(min_samples: usize) -> Self {
        Self {
            min_samples,
            max_eps: F::infinity(),
            xi: F::from(0.05).expect("0.05 must be representable in the float type `F`"),
        }
    }

    /// Returns the estimator with `max_eps` replaced.
    ///
    /// The value is not checked here; `fit` rejects non-positive values.
    #[must_use]
    pub fn with_max_eps(mut self, max_eps: F) -> Self {
        self.max_eps = max_eps;
        self
    }

    /// Returns the estimator with `xi` replaced.
    ///
    /// The value is not checked here; `fit` rejects values outside `(0, 1)`.
    #[must_use]
    pub fn with_xi(mut self, xi: F) -> Self {
        self.xi = xi;
        self
    }
}
/// Result of fitting [`OPTICS`]: the processing order plus per-sample
/// reachability distances, core distances and cluster labels.
#[derive(Debug, Clone)]
pub struct FittedOPTICS<F> {
// Sample indices in the order in which `fit` processed them.
ordering_: Vec<usize>,
// Reachability distance per sample index; infinity where `fit` never lowered it.
reachability_: Array1<F>,
// Core distance per sample index; infinity for samples with too few neighbours.
core_distances_: Array1<F>,
// Cluster label per sample index; -1 marks samples assigned to no cluster.
labels_: Array1<isize>,
}
impl<F: Float> FittedOPTICS<F> {
    /// Sample indices in the order in which they were processed during fitting.
    #[must_use]
    pub fn ordering(&self) -> &[usize] {
        &self.ordering_
    }

    /// Reachability distance of every sample, indexed by sample index.
    #[must_use]
    pub fn reachability(&self) -> &Array1<F> {
        &self.reachability_
    }

    /// Core distance of every sample, indexed by sample index.
    #[must_use]
    pub fn core_distances(&self) -> &Array1<F> {
        &self.core_distances_
    }

    /// Cluster label of every sample, indexed by sample index (`-1` = no cluster).
    #[must_use]
    pub fn labels(&self) -> &Array1<isize> {
        &self.labels_
    }

    /// Number of clusters, computed as the largest stored label plus one.
    ///
    /// Returns 0 when there are no labels or when every label is negative.
    #[must_use]
    pub fn n_clusters(&self) -> usize {
        match self.labels_.iter().copied().max() {
            Some(max_label) if max_label >= 0 => (max_label + 1) as usize,
            _ => 0,
        }
    }

    /// Re-runs the xi-based label extraction on the stored ordering and
    /// reachability values with a different `xi`.
    ///
    /// The labels stored in `self` are left untouched; a fresh label array is
    /// returned.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::InvalidParameter`] when `xi` is NaN or lies
    /// outside the open interval `(0, 1)`.
    pub fn extract_clusters(&self, xi: F) -> Result<Array1<isize>, FerroError> {
        // NaN compares false against both bounds, so it has to be rejected
        // explicitly; otherwise it would slip through and label nothing.
        if xi.is_nan() || xi <= F::zero() || xi >= F::one() {
            return Err(FerroError::InvalidParameter {
                name: "xi".into(),
                reason: "must be in (0, 1)".into(),
            });
        }
        Ok(xi_cluster_extraction(
            &self.ordering_,
            &self.reachability_,
            xi,
        ))
    }
}
/// Euclidean (L2) distance between two coordinate slices.
///
/// Components are paired positionally; if the slices differ in length the
/// surplus components of the longer one are ignored.
#[inline]
fn euclidean<F: Float>(a: &[F], b: &[F]) -> F {
    let mut sum_sq = F::zero();
    for (&lhs, &rhs) in a.iter().zip(b.iter()) {
        let delta = lhs - rhs;
        sum_sq = sum_sq + delta * delta;
    }
    sum_sq.sqrt()
}
/// Finds every sample other than `idx` whose Euclidean distance to row `idx`
/// is at most `max_eps`.
///
/// Returns `(indices, distances)` as two parallel vectors sorted by ascending
/// distance; samples at equal distance keep ascending index order.
fn get_neighbors<F: Float>(x: &Array2<F>, idx: usize, max_eps: F) -> (Vec<usize>, Vec<F>) {
    // Copy the query row once. `as_slice` yields `None` for rows that are not
    // contiguous in memory (e.g. column-major or transposed arrays); treating
    // that as an empty slice would silently make every distance zero.
    let query: Vec<F> = x.row(idx).to_vec();

    let mut pairs: Vec<(F, usize)> = (0..x.nrows())
        .filter(|&j| j != idx)
        .filter_map(|j| {
            let other = x.row(j);
            let d = match other.as_slice() {
                Some(os) => euclidean(&query, os),
                // Non-contiguous row: fall back to an owned copy.
                None => euclidean(&query, &other.to_vec()),
            };
            if d <= max_eps {
                Some((d, j))
            } else {
                None
            }
        })
        .collect();

    // Stable sort, so ties stay in ascending index order.
    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
    pairs.into_iter().map(|(d, j)| (j, d)).unzip()
}
/// Core distance of sample `idx`: the distance to its `(min_samples - 1)`-th
/// nearest other sample among those within `max_eps`.
///
/// The sample itself counts towards `min_samples`, so `min_samples <= 1`
/// gives a core distance of zero. Returns infinity when fewer than
/// `min_samples - 1` other samples lie within `max_eps`.
fn core_distance<F: Float>(x: &Array2<F>, idx: usize, max_eps: F, min_samples: usize) -> F {
    let k = min_samples.saturating_sub(1);
    if k == 0 {
        // No other sample is required, so no distances need to be computed.
        return F::zero();
    }

    // Copy the query row once. `as_slice` yields `None` for rows that are not
    // contiguous in memory; treating that as an empty slice would silently
    // make every distance zero.
    let query: Vec<F> = x.row(idx).to_vec();

    let mut dists: Vec<F> = (0..x.nrows())
        .filter(|&j| j != idx)
        .filter_map(|j| {
            let other = x.row(j);
            let d = match other.as_slice() {
                Some(os) => euclidean(&query, os),
                // Non-contiguous row: fall back to an owned copy.
                None => euclidean(&query, &other.to_vec()),
            };
            if d <= max_eps { Some(d) } else { None }
        })
        .collect();

    if dists.len() < k {
        return F::infinity();
    }

    // Only the k-th smallest distance is needed, so a selection is enough;
    // a full sort is unnecessary. `k - 1 < dists.len()` holds by the check above.
    let (_, kth, _) = dists
        .select_nth_unstable_by(k - 1, |a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
    *kth
}
/// Relaxes the reachability of the unprocessed neighbours of a core point and
/// queues every neighbour whose reachability improved.
///
/// `neighbors` and `neighbor_dists` are parallel slices: `neighbor_dists[i]`
/// is the distance from the core point to `neighbors[i]`. The candidate
/// reachability is the larger of `core_dist_p` and that distance. Improved
/// entries are pushed without removing older ones, so `seeds` may contain
/// stale entries for the same sample.
fn update_seeds<F: Float>(
    core_dist_p: F,
    neighbors: &[usize],
    neighbor_dists: &[F],
    processed: &[bool],
    reachability: &mut Array1<F>,
    seeds: &mut BinaryHeap<SeedEntry<F>>,
) {
    let pending = neighbors
        .iter()
        .enumerate()
        .filter(|&(_, &q)| !processed[q]);

    for (pos, &q) in pending {
        let dist_pq = neighbor_dists[pos];
        let candidate = if dist_pq < core_dist_p {
            core_dist_p
        } else {
            dist_pq
        };

        if candidate < reachability[q] {
            reachability[q] = candidate;
            seeds.push(SeedEntry {
                reach_dist: candidate,
                idx: q,
            });
        }
    }
}
/// Derives cluster labels from a reachability ordering using steepness
/// threshold `xi`.
///
/// Walking the reachability values in `ordering`, position `i` is a
/// *steep-down* point when `r[i + 1] / r[i] <= 1 - xi` and position `i + 1` is
/// a *steep-up* point when `r[i + 1] / r[i] >= 1 / (1 - xi)`. Each steep-down
/// point is paired with the first steep-up point after it, and the samples in
/// that inclusive span that are still unlabelled receive the next cluster id.
/// Non-finite reachability values are replaced by the largest finite value
/// plus one before ratios are taken.
///
/// Returns one label per entry of `reachability` (indexed by sample index);
/// `-1` marks samples that belong to no cluster. Cluster ids are contiguous,
/// starting at 0.
fn xi_cluster_extraction<F: Float>(
    ordering: &[usize],
    reachability: &Array1<F>,
    xi: F,
) -> Array1<isize> {
    let mut labels = Array1::from_elem(reachability.len(), -1isize);
    if ordering.is_empty() {
        return labels;
    }

    // Stand-in for infinite reachability so that ratios stay well defined.
    let max_finite = reachability
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold(F::zero(), |acc, v| if v > acc { v } else { acc });
    let placeholder = max_finite + F::one();

    let r_ord: Vec<F> = ordering
        .iter()
        .map(|&i| {
            let v = reachability[i];
            if v.is_finite() { v } else { placeholder }
        })
        .collect();

    let one_minus_xi = F::one() - xi;
    let up_threshold = F::one() / one_minus_xi;

    // Both lists are built in ascending position order.
    let mut steep_down: Vec<usize> = Vec::new();
    let mut steep_up: Vec<usize> = Vec::new();
    for (i, pair) in r_ord.windows(2).enumerate() {
        let (prev, next) = (pair[0], pair[1]);
        if prev == F::zero() {
            // Avoid dividing by zero; such a position is neither kind of point.
            continue;
        }
        let ratio = next / prev;
        if ratio <= one_minus_xi {
            steep_down.push(i);
        }
        if ratio >= up_threshold {
            steep_up.push(i + 1);
        }
    }

    let mut cluster_id: isize = 0;
    for &sd in &steep_down {
        // `steep_up` is sorted, so the first entry greater than `sd` can be
        // located by binary search.
        let first_after = steep_up.partition_point(|&su| su <= sd);
        if let Some(&su) = steep_up.get(first_after) {
            let mut claimed_any = false;
            for &pt in &ordering[sd..=su] {
                if labels[pt] == -1 {
                    labels[pt] = cluster_id;
                    claimed_any = true;
                }
            }
            // Only consume an id when it was actually assigned; otherwise the
            // label set would contain gaps and the cluster count derived from
            // the maximum label would be too high.
            if claimed_any {
                cluster_id += 1;
            }
        }
    }
    labels
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for OPTICS<F> {
    type Fitted = FittedOPTICS<F>;
    type Error = FerroError;

    /// Computes the OPTICS ordering, reachability distances, core distances
    /// and xi-based cluster labels for the rows of `x`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::InvalidParameter`] when `min_samples` is 0, when
    /// `max_eps` is NaN or not positive, or when `xi` is NaN or outside
    /// `(0, 1)`. Returns [`FerroError::InsufficientSamples`] when `x` has no
    /// rows.
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedOPTICS<F>, FerroError> {
        let n_samples = x.nrows();
        if self.min_samples == 0 {
            return Err(FerroError::InvalidParameter {
                name: "min_samples".into(),
                reason: "must be at least 1".into(),
            });
        }
        // NaN compares false against every bound, so it must be rejected
        // explicitly or it would pass the range checks below.
        if self.max_eps.is_nan() || self.max_eps <= F::zero() {
            return Err(FerroError::InvalidParameter {
                name: "max_eps".into(),
                reason: "must be positive".into(),
            });
        }
        if self.xi.is_nan() || self.xi <= F::zero() || self.xi >= F::one() {
            return Err(FerroError::InvalidParameter {
                name: "xi".into(),
                reason: "must be in (0, 1)".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "OPTICS requires at least 1 sample".into(),
            });
        }

        let mut reachability = Array1::from_elem(n_samples, F::infinity());
        let mut core_distances = Array1::from_elem(n_samples, F::infinity());
        let mut processed = vec![false; n_samples];
        let mut ordering: Vec<usize> = Vec::with_capacity(n_samples);

        for i in 0..n_samples {
            core_distances[i] = core_distance(x, i, self.max_eps, self.min_samples);
        }

        for start in 0..n_samples {
            if processed[start] {
                continue;
            }
            processed[start] = true;
            ordering.push(start);

            // A non-core start point cannot seed an expansion.
            if core_distances[start].is_infinite() {
                continue;
            }

            let mut seeds: BinaryHeap<SeedEntry<F>> = BinaryHeap::new();
            let (nbrs, nbr_dists) = get_neighbors(x, start, self.max_eps);
            update_seeds(
                core_distances[start],
                &nbrs,
                &nbr_dists,
                &processed,
                &mut reachability,
                &mut seeds,
            );

            while let Some(entry) = seeds.pop() {
                let p = entry.idx;
                if processed[p] {
                    continue;
                }
                // Stale entry: a better reachability was pushed later.
                if entry.reach_dist > reachability[p] {
                    continue;
                }
                processed[p] = true;
                ordering.push(p);

                // Only core points expand the seed set further.
                if core_distances[p].is_finite() {
                    let (p_nbrs, p_nbr_dists) = get_neighbors(x, p, self.max_eps);
                    update_seeds(
                        core_distances[p],
                        &p_nbrs,
                        &p_nbr_dists,
                        &processed,
                        &mut reachability,
                        &mut seeds,
                    );
                }
            }
        }

        let labels = xi_cluster_extraction(&ordering, &reachability, self.xi);
        Ok(FittedOPTICS {
            ordering_: ordering,
            reachability_: reachability,
            core_distances_: core_distances,
            labels_: labels,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Nine 2-D points forming three tight groups of three points each.
    fn three_blobs() -> Array2<f64> {
        Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.1, 0.0, 0.0, 0.1, 5.0, 5.0, 5.1, 5.0, 5.0, 5.1, 10.0, 0.0, 10.1, 0.0,
                10.0, 0.1,
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_ordering_covers_all_points() {
        let x = three_blobs();
        let fitted = OPTICS::<f64>::new(2).fit(&x, &()).unwrap();
        let mut sorted = fitted.ordering().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..9).collect::<Vec<_>>());
    }

    #[test]
    fn test_reachability_length() {
        let x = three_blobs();
        let fitted = OPTICS::<f64>::new(2).fit(&x, &()).unwrap();
        assert_eq!(fitted.reachability().len(), 9);
    }

    #[test]
    fn test_core_distances_length() {
        let x = three_blobs();
        let fitted = OPTICS::<f64>::new(2).fit(&x, &()).unwrap();
        assert_eq!(fitted.core_distances().len(), 9);
    }

    #[test]
    fn test_labels_length() {
        let x = three_blobs();
        let fitted = OPTICS::<f64>::new(2).fit(&x, &()).unwrap();
        assert_eq!(fitted.labels().len(), 9);
    }

    #[test]
    fn test_core_points_have_finite_core_distance() {
        let x = three_blobs();
        let fitted = OPTICS::<f64>::new(2).fit(&x, &()).unwrap();
        for i in 0..9 {
            assert!(
                fitted.core_distances()[i].is_finite(),
                "expected finite core distance for point {i}"
            );
        }
    }

    #[test]
    fn test_isolated_point_infinite_core_distance() {
        // Append a point far away from every blob.
        let mut data = three_blobs().into_raw_vec_and_offset().0;
        data.extend_from_slice(&[100.0, 100.0]);
        let x = Array2::from_shape_vec((10, 2), data).unwrap();
        let fitted = OPTICS::<f64>::new(3)
            .with_max_eps(2.0)
            .fit(&x, &())
            .unwrap();
        assert!(
            fitted.core_distances()[9].is_infinite(),
            "isolated point should have infinite core distance"
        );
    }

    #[test]
    fn test_reachability_first_point_infinite() {
        let x = three_blobs();
        let fitted = OPTICS::<f64>::new(2).fit(&x, &()).unwrap();
        let first = fitted.ordering()[0];
        assert!(
            fitted.reachability()[first].is_infinite(),
            "first point in ordering should have infinite reachability"
        );
    }

    #[test]
    fn test_extract_clusters_valid_xi() {
        let x = three_blobs();
        let fitted = OPTICS::<f64>::new(2).fit(&x, &()).unwrap();
        let labels = fitted.extract_clusters(0.05).unwrap();
        assert_eq!(labels.len(), 9);
    }

    #[test]
    fn test_extract_clusters_invalid_xi_zero() {
        let x = three_blobs();
        let fitted = OPTICS::<f64>::new(2).fit(&x, &()).unwrap();
        assert!(fitted.extract_clusters(0.0).is_err());
    }

    #[test]
    fn test_extract_clusters_invalid_xi_one() {
        let x = three_blobs();
        let fitted = OPTICS::<f64>::new(2).fit(&x, &()).unwrap();
        assert!(fitted.extract_clusters(1.0).is_err());
    }

    #[test]
    fn test_invalid_min_samples_zero() {
        let x = three_blobs();
        let result = OPTICS::<f64>::new(0).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_max_eps_zero() {
        let x = three_blobs();
        let result = OPTICS::<f64>::new(2).with_max_eps(0.0).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_max_eps_negative() {
        let x = three_blobs();
        let result = OPTICS::<f64>::new(2).with_max_eps(-1.0).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_xi_zero() {
        let x = three_blobs();
        let result = OPTICS::<f64>::new(2).with_xi(0.0).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_xi_one() {
        let x = three_blobs();
        let result = OPTICS::<f64>::new(2).with_xi(1.0).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_data_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let result = OPTICS::<f64>::new(2).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_single_sample() {
        let x = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();
        let fitted = OPTICS::<f64>::new(1).fit(&x, &()).unwrap();
        assert_eq!(fitted.ordering().len(), 1);
        assert_eq!(fitted.ordering()[0], 0);
    }

    #[test]
    fn test_f32_support() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                0.0f32, 0.0, 0.1, 0.0, 0.0, 0.1, 10.0, 10.0, 10.1, 10.0, 10.0, 10.1,
            ],
        )
        .unwrap();
        let fitted = OPTICS::<f32>::new(2).fit(&x, &()).unwrap();
        assert_eq!(fitted.ordering().len(), 6);
    }

    #[test]
    fn test_n_clusters_non_negative() {
        let x = three_blobs();
        let fitted = OPTICS::<f64>::new(2).fit(&x, &()).unwrap();
        // `n_clusters` is unsigned, so the meaningful property to check is
        // that it bounds every label: each label is noise (-1) or a valid id.
        let n_clusters = fitted.n_clusters() as isize;
        for &label in fitted.labels().iter() {
            assert!(label >= -1, "unexpected label {label}");
            assert!(
                label < n_clusters,
                "label {label} is not below n_clusters {n_clusters}"
            );
        }
    }

    #[test]
    fn test_ordering_unique_indices() {
        let x = three_blobs();
        let fitted = OPTICS::<f64>::new(2).fit(&x, &()).unwrap();
        let ordering = fitted.ordering();
        let mut seen = std::collections::HashSet::new();
        for &idx in ordering {
            assert!(seen.insert(idx), "duplicate index {idx} in ordering");
        }
    }

    #[test]
    fn test_with_max_eps_limits_reachability() {
        let x = three_blobs();
        let max_eps = 0.5;
        let fitted = OPTICS::<f64>::new(2)
            .with_max_eps(max_eps)
            .fit(&x, &())
            .unwrap();
        for &r in fitted.reachability().iter() {
            if r.is_finite() {
                assert!(r <= max_eps + 1e-10);
            }
        }
    }
}