use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::Fit;
use ndarray::{Array1, Array2};
use num_traits::Float;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
/// Priority-queue entry pairing a candidate point with the tentative
/// reachability distance it had when it was pushed.
///
/// `BinaryHeap` is a max-heap, so the ordering is deliberately reversed: the
/// entry with the *smallest* `reach_dist` compares greatest and is popped
/// first. Ties are broken in favour of the smaller `idx`, which makes the pop
/// order deterministic and keeps `Ord` consistent with `Eq` (two entries are
/// equal only when both the distance and the index agree). A NaN distance is
/// treated as larger than every other distance, i.e. it has the lowest
/// priority, so the ordering is total even for degenerate input.
#[derive(Clone, Copy)]
struct SeedEntry<F: Float> {
    /// Tentative reachability distance of the point at push time.
    reach_dist: F,
    /// Index of the candidate point.
    idx: usize,
}

impl<F: Float> PartialEq for SeedEntry<F> {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `==` and `Ordering::Equal` always agree.
        self.cmp(other) == Ordering::Equal
    }
}

impl<F: Float> Eq for SeedEntry<F> {}

impl<F: Float> Ord for SeedEntry<F> {
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .reach_dist
            .partial_cmp(&self.reach_dist)
            // `partial_cmp` is `None` only when at least one side is NaN; rank
            // the NaN side lowest (both NaN compare equal here).
            .unwrap_or_else(|| other.reach_dist.is_nan().cmp(&self.reach_dist.is_nan()))
            // Equal distances: the smaller index wins (compares greater).
            .then_with(|| other.idx.cmp(&self.idx))
    }
}

impl<F: Float> PartialOrd for SeedEntry<F> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Unfitted OPTICS estimator holding the hyper-parameters used by `fit`.
#[derive(Debug, Clone)]
pub struct OPTICS<F> {
/// Neighbourhood size used for core distances; `fit` rejects `0`.
pub min_samples: usize,
/// Maximum neighbourhood radius; `fit` rejects values `<= 0`.
/// `new` initialises it to infinity.
pub max_eps: F,
/// Steepness threshold for xi cluster extraction; `fit` requires it to lie
/// strictly between 0 and 1. `new` initialises it to 0.05.
pub xi: F,
/// Minimum number of points a cluster must have to survive post-filtering.
/// When `None`, `fit` falls back to `min_samples`.
pub min_cluster_size: Option<usize>,
}
impl<F: Float> OPTICS<F> {
    /// Creates an estimator with the given `min_samples`, an unbounded
    /// `max_eps`, `xi = 0.05` and no explicit `min_cluster_size`.
    ///
    /// # Panics
    ///
    /// Panics if `0.05` cannot be converted into `F`.
    #[must_use]
    pub fn new(min_samples: usize) -> Self {
        Self {
            min_samples,
            max_eps: F::infinity(),
            // A single conversion: retrying with the same literal spelled
            // differently could never succeed where this one fails.
            xi: F::from(0.05).expect("0.05 must be representable in the float type"),
            min_cluster_size: None,
        }
    }

    /// Sets the maximum neighbourhood radius and returns the updated estimator.
    #[must_use]
    pub fn with_max_eps(mut self, max_eps: F) -> Self {
        self.max_eps = max_eps;
        self
    }

    /// Sets the xi steepness threshold and returns the updated estimator.
    #[must_use]
    pub fn with_xi(mut self, xi: F) -> Self {
        self.xi = xi;
        self
    }

    /// Sets an explicit minimum cluster size and returns the updated estimator.
    #[must_use]
    pub fn with_min_cluster_size(mut self, size: usize) -> Self {
        self.min_cluster_size = Some(size);
        self
    }
}
/// Result of fitting OPTICS: the cluster ordering plus per-point quantities.
#[derive(Debug, Clone)]
pub struct FittedOPTICS<F> {
// Point indices in the order in which `fit` processed them.
ordering_: Vec<usize>,
// Reachability distance per point index; infinity where never updated.
reachability_: Array1<F>,
// Core distance per point index.
core_distances_: Array1<F>,
// Cluster label per point index; `-1` marks an unlabelled point.
labels_: Array1<isize>,
// Per point index, the point whose expansion last lowered its reachability.
predecessors_: Vec<Option<usize>>,
// `min_samples` used during fitting, reused by `extract_clusters`.
min_samples_: usize,
}
impl<F: Float> FittedOPTICS<F> {
    /// Point indices in the order in which they were processed.
    #[must_use]
    pub fn ordering(&self) -> &[usize] {
        &self.ordering_
    }

    /// Reachability distance of every point, indexed by point index.
    #[must_use]
    pub fn reachability(&self) -> &Array1<F> {
        &self.reachability_
    }

    /// Core distance of every point, indexed by point index.
    #[must_use]
    pub fn core_distances(&self) -> &Array1<F> {
        &self.core_distances_
    }

    /// Cluster label of every point, indexed by point index (`-1` = no cluster).
    #[must_use]
    pub fn labels(&self) -> &Array1<isize> {
        &self.labels_
    }

    /// Predecessor of every point, indexed by point index.
    #[must_use]
    pub fn predecessors(&self) -> &[Option<usize>] {
        &self.predecessors_
    }

    /// Number of clusters, computed as the largest label plus one.
    ///
    /// Returns `0` when there are no labels or every label is negative.
    #[must_use]
    pub fn n_clusters(&self) -> usize {
        match self.labels_.iter().copied().max() {
            // The guard makes the value non-negative, so the cast is lossless.
            Some(max_label) if max_label >= 0 => (max_label + 1) as usize,
            _ => 0,
        }
    }

    /// Re-runs xi cluster extraction on the stored ordering with a new `xi`.
    ///
    /// The returned labels are indexed by point index. Unlike the labels
    /// stored by `fit`, they are not post-filtered by a minimum cluster size.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::InvalidParameter`] when `xi` is NaN or does not
    /// lie strictly between 0 and 1.
    pub fn extract_clusters(&self, xi: F) -> Result<Array1<isize>, FerroError> {
        // NaN fails both range comparisons, so it has to be rejected explicitly.
        if xi.is_nan() || xi <= F::zero() || xi >= F::one() {
            return Err(FerroError::InvalidParameter {
                name: "xi".into(),
                reason: "must be in (0, 1)".into(),
            });
        }
        Ok(xi_cluster_extraction(
            &self.ordering_,
            &self.reachability_,
            &self.predecessors_,
            xi,
            self.min_samples_,
        ))
    }
}
/// Euclidean distance between two coordinate slices.
///
/// Components are paired positionally; if the slices differ in length the
/// surplus components of the longer one are ignored.
#[inline]
fn euclidean<F: Float>(a: &[F], b: &[F]) -> F {
    let mut sum_sq = F::zero();
    for (&lhs, &rhs) in a.iter().zip(b) {
        let diff = lhs - rhs;
        sum_sq = sum_sq + diff * diff;
    }
    sum_sq.sqrt()
}
/// Finds every other point within `max_eps` of point `idx`.
///
/// Returns `(indices, distances)` as parallel vectors sorted by ascending
/// distance; points at equal distance keep ascending index order. The query
/// point itself is never included.
///
/// Rows that are not contiguous in memory (for example in a column-major
/// array) are copied into a temporary buffer before measuring, instead of
/// being treated as empty, which would make every distance zero.
fn get_neighbors<F: Float>(x: &Array2<F>, idx: usize, max_eps: F) -> (Vec<usize>, Vec<F>) {
    // One copy of the query row guarantees a contiguous slice for the kernel.
    let query = x.row(idx).to_vec();
    let mut pairs: Vec<(F, usize)> = (0..x.nrows())
        .filter(|&j| j != idx)
        .filter_map(|j| {
            let other = x.row(j);
            let d = match other.as_slice() {
                Some(contiguous) => euclidean(&query, contiguous),
                // Strided row: materialise it rather than measuring nothing.
                None => euclidean(&query, &other.to_vec()),
            };
            (d <= max_eps).then_some((d, j))
        })
        .collect();
    // Stable sort keeps index order among equidistant neighbours.
    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
    pairs.into_iter().map(|(d, j)| (j, d)).unzip()
}
/// Core distance of point `idx`: the distance to its `(min_samples - 1)`-th
/// nearest other point among those within `max_eps`.
///
/// Returns zero when `min_samples <= 1` and infinity when fewer than
/// `min_samples - 1` other points lie within `max_eps`.
///
/// Rows that are not contiguous in memory are copied into a temporary buffer
/// before measuring, instead of being treated as empty.
fn core_distance<F: Float>(x: &Array2<F>, idx: usize, max_eps: F, min_samples: usize) -> F {
    // Number of *other* points required inside the neighbourhood.
    let k = min_samples.saturating_sub(1);
    if k == 0 {
        return F::zero();
    }

    let query = x.row(idx).to_vec();
    let mut dists: Vec<F> = (0..x.nrows())
        .filter(|&j| j != idx)
        .filter_map(|j| {
            let other = x.row(j);
            let d = match other.as_slice() {
                Some(contiguous) => euclidean(&query, contiguous),
                // Strided row: materialise it rather than measuring nothing.
                None => euclidean(&query, &other.to_vec()),
            };
            (d <= max_eps).then_some(d)
        })
        .collect();

    if dists.len() < k {
        return F::infinity();
    }
    // Only the k-th smallest distance is needed, so a selection replaces the
    // full sort; the selected value is the same.
    let (_, kth, _) = dists
        .select_nth_unstable_by(k - 1, |a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
    *kth
}
/// Relaxes the reachability of the unprocessed neighbours of `current_point`.
///
/// For each neighbour the candidate reachability is the larger of
/// `core_dist_p` and the neighbour's distance. When that is strictly smaller
/// than the stored value, the stored reachability and predecessor are
/// updated and a new entry is pushed onto `seeds`. Older heap entries for the
/// same point are left in place.
// The arguments mirror the separate pieces of traversal state owned by the caller.
#[allow(clippy::too_many_arguments)]
fn update_seeds<F: Float>(
    core_dist_p: F,
    current_point: usize,
    neighbors: &[usize],
    neighbor_dists: &[F],
    processed: &[bool],
    reachability: &mut Array1<F>,
    predecessors: &mut [Option<usize>],
    seeds: &mut BinaryHeap<SeedEntry<F>>,
) {
    let pending = neighbors
        .iter()
        .enumerate()
        .filter(|&(_, &candidate)| !processed[candidate]);
    for (pos, &candidate) in pending {
        let dist = neighbor_dists[pos];
        let candidate_reach = if core_dist_p > dist { core_dist_p } else { dist };
        if candidate_reach >= reachability[candidate] {
            continue;
        }
        reachability[candidate] = candidate_reach;
        predecessors[candidate] = Some(current_point);
        seeds.push(SeedEntry {
            reach_dist: candidate_reach,
            idx: candidate,
        });
    }
}
/// A steep downward region of the reachability plot, tracked as a candidate
/// cluster start during xi extraction.
#[derive(Debug, Clone)]
struct SteepDownArea {
// Plot position where the region begins.
start: usize,
// Plot position of the last steep point in the region.
end: usize,
// Largest "maximum in between" value recorded for this region so far;
// starts at 0.0 and is only ever raised by `update_filter_sdas`.
mib: f64,
}
/// Extends a steep region that begins at `start` and returns the position of
/// its last steep point (or `start` if none is found).
///
/// Scanning moves forward from `start`. A steep point extends the region and
/// resets the tolerance counter. A point flagged in `xward_point` (moving in
/// the opposite direction) ends the scan immediately. Any other point is
/// tolerated, but the scan stops once more than `min_samples` of them occur
/// in a row.
fn extend_region(
    steep_point: &[bool],
    xward_point: &[bool],
    start: usize,
    min_samples: usize,
) -> usize {
    let mut region_end = start;
    let mut tolerated_run = 0usize;
    for pos in start..steep_point.len() {
        if steep_point[pos] {
            tolerated_run = 0;
            region_end = pos;
            continue;
        }
        if xward_point[pos] {
            // The plot turned the other way: the region is over.
            break;
        }
        tolerated_run += 1;
        if tolerated_run > min_samples {
            break;
        }
    }
    region_end
}
/// Shrinks a candidate cluster `[s, e]` from the right until its last point
/// belongs to it, returning the corrected bounds or `None` if it collapses.
///
/// The end position `e` is accepted as soon as either the reachability at
/// `s` exceeds the reachability at `e`, or the predecessor of the point at
/// position `e` is one of the points at positions `s..e`. Otherwise `e` is
/// decremented and the check repeats while `s < e`.
///
/// `r_plot` and `pred_plot` are both indexed by *position in the ordering*;
/// `ordering` maps a position to a point index, and the values stored in
/// `pred_plot` are point indices.
fn correct_predecessor(
    r_plot: &[f64],
    pred_plot: &[Option<usize>],
    ordering: &[usize],
    s: usize,
    mut e: usize,
) -> Option<(usize, usize)> {
    while s < e {
        if r_plot[s] > r_plot[e] {
            return Some((s, e));
        }
        // `pred_plot` is already laid out by position, so it is indexed with
        // `e` directly; going through `ordering[e]` would look up the
        // predecessor of an unrelated point.
        if let Some(pred) = pred_plot[e] {
            let pred_in_span = ordering.iter().take(e).skip(s).any(|&point| point == pred);
            if pred_in_span {
                return Some((s, e));
            }
        }
        e -= 1;
    }
    None
}
/// Derives flat cluster labels from an OPTICS ordering using the xi
/// steepness criterion.
///
/// `reachability` and `predecessors` are indexed by point index; `ordering`
/// lists point indices in processing order. The result has one label per
/// entry of `reachability`, with `-1` for points that end up in no cluster.
/// Candidate clusters are found as (steep-down area, steep-up area) pairs on
/// the reachability plot and are then labelled in discovery order, skipping
/// any candidate that overlaps an already labelled position.
fn xi_cluster_extraction<F: Float>(
ordering: &[usize],
reachability: &Array1<F>,
predecessors: &[Option<usize>],
xi: F,
min_samples: usize,
) -> Array1<isize> {
let n_ordered = ordering.len();
let n_total = reachability.len();
// Nothing was ordered: every point is unlabelled.
if n_ordered == 0 {
return Array1::from_elem(n_total, -1isize);
}
// Reachability plot: reachability values in processing order, as f64.
// Non-finite values (and failed conversions) become +infinity.
let mut r_plot: Vec<f64> = ordering
.iter()
.map(|&i| {
let v = reachability[i];
if v.is_finite() {
v.to_f64().unwrap_or(f64::INFINITY)
} else {
f64::INFINITY
}
})
.collect();
// Trailing +infinity sentinel so that `r_plot[i + 1]` exists for the last
// real position.
r_plot.push(f64::INFINITY);
// Predecessors re-laid-out by position in the ordering (values remain
// point indices).
let pred_plot: Vec<Option<usize>> = ordering.iter().map(|&i| predecessors[i]).collect();
// Falls back to 0.05 if `xi` cannot be converted to f64.
let xi_f64 = xi.to_f64().unwrap_or(0.05);
let xi_complement = 1.0 - xi_f64;
let min_samples = min_samples.max(1);
// `n_plot` is the number of real plot positions (sentinel excluded); the
// four flag vectors classify the step from position i to i + 1.
let n_plot = r_plot.len() - 1; let mut steep_upward = vec![false; n_plot];
let mut steep_downward = vec![false; n_plot];
let mut upward = vec![false; n_plot];
let mut downward = vec![false; n_plot];
for i in 0..n_plot {
// Explicit handling of a zero successor: a drop from a positive value to
// zero counts as steep-down; zero to zero sets no flag.
if r_plot[i + 1] == 0.0 {
if r_plot[i] > 0.0 {
steep_downward[i] = true;
downward[i] = true;
}
continue;
}
// When the ratio is NaN (e.g. infinity / infinity) every comparison below
// is false and no flag is set.
let ratio = r_plot[i] / r_plot[i + 1];
if ratio <= xi_complement {
steep_upward[i] = true;
}
if ratio >= 1.0 / xi_complement {
steep_downward[i] = true;
}
if ratio > 1.0 {
downward[i] = true;
}
if ratio < 1.0 {
upward[i] = true;
}
}
// Steep-down areas still eligible to open a cluster.
let mut sdas: Vec<SteepDownArea> = Vec::new();
// Accepted clusters as inclusive (start, end) plot positions.
let mut clusters: Vec<(usize, usize)> = Vec::new();
// First plot position not yet consumed by a steep area.
let mut index = 0usize;
// Running maximum of the plot since the end of the last steep area.
let mut mib = 0.0_f64;
let steep_indices: Vec<usize> = (0..n_plot)
.filter(|&i| steep_upward[i] || steep_downward[i])
.collect();
for &steep_index in &steep_indices {
// Already swallowed by a previously extended area.
if steep_index < index {
continue;
}
// Fold r_plot[index..=steep_index] into the running maximum.
for item in r_plot.iter().take(steep_index + 1).skip(index) {
if *item > mib {
mib = *item;
}
}
if steep_downward[steep_index] {
// Steep-down step: prune stale areas, then record a new one.
sdas = update_filter_sdas(sdas, mib, xi_complement, &r_plot);
let d_start = steep_index;
let d_end = extend_region(&steep_downward, &upward, d_start, min_samples);
sdas.push(SteepDownArea {
start: d_start,
end: d_end,
mib: 0.0,
});
index = d_end + 1;
// Restart the running maximum just past the area.
if index < r_plot.len() {
mib = r_plot[index];
}
} else {
// Steep-up step: prune stale areas, extend the up area, then try to
// pair it with every remaining steep-down area.
sdas = update_filter_sdas(sdas, mib, xi_complement, &r_plot);
let u_start = steep_index;
let u_end = extend_region(&steep_upward, &downward, u_start, min_samples);
index = u_end + 1;
if index < r_plot.len() {
mib = r_plot[index];
}
let mut u_clusters: Vec<(usize, usize)> = Vec::new();
for sda in &sdas {
let mut c_start = sda.start;
let c_end_initial = u_end;
// Plot value immediately after the up area (infinity past the end).
let r_after = if c_end_initial + 1 < r_plot.len() {
r_plot[c_end_initial + 1]
} else {
f64::INFINITY
};
// Reject when something between the two areas rose too close to the
// level right after the up area.
if r_after * xi_complement < sda.mib {
continue;
}
let d_max = r_plot[sda.start];
let mut c_end = c_end_initial;
if d_max * xi_complement >= r_after {
// Down side starts clearly higher: advance the start (within the
// down area) while the next value is still above `r_after`.
while c_start < sda.end
&& c_start + 1 < r_plot.len()
&& r_plot[c_start + 1] > r_after
{
c_start += 1;
}
} else if r_after * xi_complement >= d_max {
// Up side ends clearly higher: pull the end back (within the up
// area) while the previous value is still above `d_max`.
while c_end > u_start && c_end > 0 && r_plot[c_end - 1] > d_max {
c_end -= 1;
}
}
// Trim the end using predecessor information; drop the candidate if
// it collapses.
if let Some((cs, ce)) =
correct_predecessor(&r_plot, &pred_plot, ordering, c_start, c_end)
{
c_start = cs;
c_end = ce;
} else {
continue;
}
// Needs at least two positions.
if c_end < c_start + 1 {
continue;
}
// The start must still lie inside the down area ...
if c_start > sda.end {
continue;
}
// ... and the end must still reach the up area.
if c_end < u_start {
continue;
}
u_clusters.push((c_start, c_end));
}
// Reverse so that candidates from the most recently recorded down areas
// are stored, and later labelled, first.
u_clusters.reverse();
clusters.extend(u_clusters);
}
}
// Label positions in the ordering: a candidate gets the next label only if
// none of its positions has been labelled yet.
let mut ord_labels = vec![-1isize; n_ordered];
let mut label = 0isize;
for &(c_start, c_end) in &clusters {
let end = c_end.min(n_ordered - 1);
let all_unassigned = (c_start..=end).all(|pos| ord_labels[pos] == -1);
if all_unassigned {
for item in ord_labels.iter_mut().take(end + 1).skip(c_start) {
*item = label;
}
label += 1;
}
}
// Scatter position-based labels back to point indices; points missing from
// `ordering` keep -1.
let mut labels = Array1::from_elem(n_total, -1isize);
for (ord_pos, &pt) in ordering.iter().enumerate() {
labels[pt] = ord_labels[ord_pos];
}
labels
}
/// Prunes the list of steep-down areas against the current running maximum
/// `mib` and raises each survivor's recorded `mib` to at least that value.
///
/// An infinite `mib` discards every area. Otherwise an area survives only if
/// `mib` does not exceed the plot value at its start scaled by
/// `xi_complement`. Survivors keep their relative order.
fn update_filter_sdas(
    mut sdas: Vec<SteepDownArea>,
    mib: f64,
    xi_complement: f64,
    r_plot: &[f64],
) -> Vec<SteepDownArea> {
    if mib.is_infinite() {
        return Vec::new();
    }
    sdas.retain(|area| mib <= r_plot[area.start] * xi_complement);
    for area in sdas.iter_mut().filter(|area| mib > area.mib) {
        area.mib = mib;
    }
    sdas
}
/// Turns clusters with fewer than `min_cluster_size` members into noise
/// (`-1`) and renumbers the remaining clusters `0, 1, 2, ...` in ascending
/// order of their previous label. Negative labels are left untouched.
fn filter_small_clusters(labels: &mut Array1<isize>, min_cluster_size: usize) {
    // Member count per non-negative label.
    let mut sizes: HashMap<isize, usize> = HashMap::new();
    for &id in labels.iter().filter(|&&id| id >= 0) {
        *sizes.entry(id).or_insert(0) += 1;
    }

    // Labels that are large enough, in ascending order.
    let mut survivors: Vec<isize> = sizes
        .iter()
        .filter(|&(_, &count)| count >= min_cluster_size)
        .map(|(&id, _)| id)
        .collect();
    survivors.sort_unstable();

    let renumbered: HashMap<isize, isize> = survivors
        .iter()
        .enumerate()
        .map(|(new_id, &old_id)| (old_id, new_id as isize))
        .collect();

    // Every non-negative label is either a survivor (renumbered) or too
    // small (absent from the map, hence noise).
    for slot in labels.iter_mut().filter(|slot| **slot >= 0) {
        *slot = renumbered.get(&*slot).copied().unwrap_or(-1);
    }
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for OPTICS<F> {
type Fitted = FittedOPTICS<F>;
type Error = FerroError;
fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedOPTICS<F>, FerroError> {
let n_samples = x.nrows();
if self.min_samples == 0 {
return Err(FerroError::InvalidParameter {
name: "min_samples".into(),
reason: "must be at least 1".into(),
});
}
if self.max_eps <= F::zero() {
return Err(FerroError::InvalidParameter {
name: "max_eps".into(),
reason: "must be positive".into(),
});
}
if self.xi <= F::zero() || self.xi >= F::one() {
return Err(FerroError::InvalidParameter {
name: "xi".into(),
reason: "must be in (0, 1)".into(),
});
}
if n_samples == 0 {
return Err(FerroError::InsufficientSamples {
required: 1,
actual: 0,
context: "OPTICS requires at least 1 sample".into(),
});
}
let mut reachability = Array1::from_elem(n_samples, F::infinity());
let mut core_distances = Array1::from_elem(n_samples, F::infinity());
let mut processed = vec![false; n_samples];
let mut ordering: Vec<usize> = Vec::with_capacity(n_samples);
let mut predecessors: Vec<Option<usize>> = vec![None; n_samples];
for i in 0..n_samples {
core_distances[i] = core_distance(x, i, self.max_eps, self.min_samples);
}
for start in 0..n_samples {
if processed[start] {
continue;
}
processed[start] = true;
ordering.push(start);
if core_distances[start].is_infinite() {
continue;
}
let mut seeds: BinaryHeap<SeedEntry<F>> = BinaryHeap::new();
let (nbrs, nbr_dists) = get_neighbors(x, start, self.max_eps);
update_seeds(
core_distances[start],
start,
&nbrs,
&nbr_dists,
&processed,
&mut reachability,
&mut predecessors,
&mut seeds,
);
while let Some(entry) = seeds.pop() {
let p = entry.idx;
if processed[p] {
continue;
}
if entry.reach_dist > reachability[p] {
continue;
}
processed[p] = true;
ordering.push(p);
if core_distances[p].is_finite() {
let (p_nbrs, p_nbr_dists) = get_neighbors(x, p, self.max_eps);
update_seeds(
core_distances[p],
p,
&p_nbrs,
&p_nbr_dists,
&processed,
&mut reachability,
&mut predecessors,
&mut seeds,
);
}
}
}
let mut labels = xi_cluster_extraction(
&ordering,
&reachability,
&predecessors,
self.xi,
self.min_samples,
);
let min_size = self.min_cluster_size.unwrap_or(self.min_samples);
filter_small_clusters(&mut labels, min_size);
Ok(FittedOPTICS {
ordering_: ordering,
reachability_: reachability,
core_distances_: core_distances,
labels_: labels,
predecessors_: predecessors,
min_samples_: self.min_samples,
})
}
}
impl<F: Float + Send + Sync + 'static> OPTICS<F> {
    /// Fits the model on `x` and returns the cluster label of every row
    /// (`-1` for rows assigned to no cluster).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `fit`.
    pub fn fit_predict(&self, x: &Array2<F>) -> Result<Array1<isize>, FerroError> {
        // The fitted model is dropped right away, so the labels are moved out
        // of it instead of being cloned.
        self.fit(x, &()).map(|fitted| fitted.labels_)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Nine 2-D points forming three tight, well-separated groups of three.
    fn three_blobs() -> Array2<f64> {
        Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.1, 0.0, 0.0, 0.1, 5.0, 5.0, 5.1, 5.0, 5.0, 5.1, 10.0, 0.0, 10.1, 0.0,
                10.0, 0.1,
            ],
        )
        .unwrap()
    }

    /// Fits OPTICS with `min_samples = 2` and default settings on `three_blobs`.
    fn fit_three_blobs() -> FittedOPTICS<f64> {
        OPTICS::<f64>::new(2).fit(&three_blobs(), &()).unwrap()
    }

    #[test]
    fn test_ordering_covers_all_points() {
        let fitted = fit_three_blobs();
        let mut sorted = fitted.ordering().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..9).collect::<Vec<_>>());
    }

    #[test]
    fn test_reachability_length() {
        assert_eq!(fit_three_blobs().reachability().len(), 9);
    }

    #[test]
    fn test_core_distances_length() {
        assert_eq!(fit_three_blobs().core_distances().len(), 9);
    }

    #[test]
    fn test_labels_length() {
        assert_eq!(fit_three_blobs().labels().len(), 9);
    }

    #[test]
    fn test_core_points_have_finite_core_distance() {
        let fitted = fit_three_blobs();
        for i in 0..9 {
            assert!(
                fitted.core_distances()[i].is_finite(),
                "expected finite core distance for point {i}"
            );
        }
    }

    #[test]
    fn test_isolated_point_infinite_core_distance() {
        let mut data = three_blobs().into_raw_vec_and_offset().0;
        data.extend_from_slice(&[100.0, 100.0]);
        let x = Array2::from_shape_vec((10, 2), data).unwrap();
        let fitted = OPTICS::<f64>::new(3)
            .with_max_eps(2.0)
            .fit(&x, &())
            .unwrap();
        assert!(
            fitted.core_distances()[9].is_infinite(),
            "isolated point should have infinite core distance"
        );
    }

    #[test]
    fn test_reachability_first_point_infinite() {
        let fitted = fit_three_blobs();
        let first = fitted.ordering()[0];
        assert!(
            fitted.reachability()[first].is_infinite(),
            "first point in ordering should have infinite reachability"
        );
    }

    #[test]
    fn test_extract_clusters_valid_xi() {
        let labels = fit_three_blobs().extract_clusters(0.05).unwrap();
        assert_eq!(labels.len(), 9);
    }

    #[test]
    fn test_extract_clusters_invalid_xi_zero() {
        assert!(fit_three_blobs().extract_clusters(0.0).is_err());
    }

    #[test]
    fn test_extract_clusters_invalid_xi_one() {
        assert!(fit_three_blobs().extract_clusters(1.0).is_err());
    }

    #[test]
    fn test_invalid_min_samples_zero() {
        let result = OPTICS::<f64>::new(0).fit(&three_blobs(), &());
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_max_eps_zero() {
        let result = OPTICS::<f64>::new(2)
            .with_max_eps(0.0)
            .fit(&three_blobs(), &());
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_max_eps_negative() {
        let result = OPTICS::<f64>::new(2)
            .with_max_eps(-1.0)
            .fit(&three_blobs(), &());
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_xi_zero() {
        let result = OPTICS::<f64>::new(2).with_xi(0.0).fit(&three_blobs(), &());
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_xi_one() {
        let result = OPTICS::<f64>::new(2).with_xi(1.0).fit(&three_blobs(), &());
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_data_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let result = OPTICS::<f64>::new(2).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_single_sample() {
        let x = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();
        let fitted = OPTICS::<f64>::new(1).fit(&x, &()).unwrap();
        assert_eq!(fitted.ordering().len(), 1);
        assert_eq!(fitted.ordering()[0], 0);
    }

    #[test]
    fn test_f32_support() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                0.0f32, 0.0, 0.1, 0.0, 0.0, 0.1, 10.0, 10.0, 10.1, 10.0, 10.0, 10.1,
            ],
        )
        .unwrap();
        let fitted = OPTICS::<f32>::new(2).fit(&x, &()).unwrap();
        assert_eq!(fitted.ordering().len(), 6);
    }

    #[test]
    fn test_n_clusters_non_negative() {
        let fitted = fit_three_blobs();
        let n_clusters = fitted.n_clusters();
        // There cannot be more clusters than points, and every label must be
        // either noise (-1) or a valid cluster id below `n_clusters`.
        assert!(n_clusters <= 9, "got {n_clusters} clusters for 9 points");
        for &l in fitted.labels() {
            assert!(
                l >= -1 && l < n_clusters as isize,
                "label {l} is inconsistent with n_clusters = {n_clusters}"
            );
        }
    }

    #[test]
    fn test_ordering_unique_indices() {
        let fitted = fit_three_blobs();
        let mut seen = std::collections::HashSet::new();
        for &idx in fitted.ordering() {
            assert!(seen.insert(idx), "duplicate index {idx} in ordering");
        }
    }

    #[test]
    fn test_with_max_eps_limits_reachability() {
        let max_eps = 0.5;
        let fitted = OPTICS::<f64>::new(2)
            .with_max_eps(max_eps)
            .fit(&three_blobs(), &())
            .unwrap();
        for &r in fitted.reachability() {
            if r.is_finite() {
                assert!(r <= max_eps + 1e-10);
            }
        }
    }

    #[test]
    fn test_predecessors_length() {
        assert_eq!(fit_three_blobs().predecessors().len(), 9);
    }

    #[test]
    fn test_first_point_has_no_predecessor() {
        let fitted = fit_three_blobs();
        let first = fitted.ordering()[0];
        assert!(
            fitted.predecessors()[first].is_none(),
            "first point in ordering should have no predecessor"
        );
    }

    #[test]
    fn test_min_cluster_size_filters_small_clusters() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                0.0, 0.0, 0.1, 0.0, 0.0, 0.1, 0.05, 0.05, 10.0, 10.0, 10.05, 10.0, 20.0, 20.0,
                20.05, 20.0,
            ],
        )
        .unwrap();
        let fitted = OPTICS::<f64>::new(2)
            .with_min_cluster_size(3)
            .fit(&x, &())
            .unwrap();
        for &l in fitted.labels() {
            if l >= 0 {
                let count = fitted.labels().iter().filter(|&&c| c == l).count();
                assert!(
                    count >= 3,
                    "cluster with label {l} has only {count} points, expected >= 3"
                );
            }
        }
    }

    #[test]
    fn test_with_min_cluster_size_builder() {
        let optics = OPTICS::<f64>::new(5).with_min_cluster_size(10);
        assert_eq!(optics.min_cluster_size, Some(10));
    }
}