use crate::error::{ClusterError, ClusterResult};
use scirs2_core::ndarray::{Array2, ArrayView1};
use scirs2_core::parallel_ops::{IntoParallelIterator, ParallelIterator};
use std::cmp::Ordering;
use torsh_tensor::Tensor;
/// Suggests a DBSCAN `eps` radius from the k-distance curve of `data`.
///
/// For each sample (row) the Euclidean distance to its `k`-th nearest *other*
/// sample is computed. The resulting k-distances are sorted ascending and one
/// of them is selected according to `method`:
///
/// * `"elbow"` – the k-distance at the point of maximum discrete curvature
///   (needs at least 3 samples);
/// * `"knee"` – the k-distance farthest from the straight line joining the
///   smallest and the largest k-distance;
/// * `"percentile"` – the k-distance at `percentile` (required, in `[0, 100]`).
///
/// # Arguments
///
/// * `data` – 2-D tensor of shape `(n_samples, n_features)`.
/// * `k` – neighbour rank; must satisfy `1 <= k < n_samples`.
/// * `method` – one of `"elbow"`, `"knee"` or `"percentile"`.
/// * `percentile` – only read when `method == "percentile"`.
///
/// # Errors
///
/// * [`ClusterError::InvalidInput`] if `k == 0`, `data` is not 2-D,
///   `k >= n_samples`, the percentile lies outside `[0, 100]`, the data cannot
///   be reshaped, or the selected method has too few samples to work with.
/// * [`ClusterError::ConfigError`] for an unknown `method`, or for
///   `"percentile"` without a `percentile` value.
/// * [`ClusterError::TensorError`] if the tensor contents cannot be read.
pub fn suggest_epsilon(
    data: &Tensor,
    k: usize,
    method: &str,
    percentile: Option<f64>,
) -> ClusterResult<f64> {
    if k == 0 {
        return Err(ClusterError::InvalidInput(
            "k must be greater than 0".to_string(),
        ));
    }
    let shape = data.shape();
    let data_shape = shape.dims();
    if data_shape.len() != 2 {
        return Err(ClusterError::InvalidInput(
            "Data tensor must be 2-dimensional".to_string(),
        ));
    }
    let n_samples = data_shape[0];
    let n_features = data_shape[1];
    if k >= n_samples {
        return Err(ClusterError::InvalidInput(format!(
            "k ({}) must be less than number of samples ({})",
            k, n_samples
        )));
    }
    // Validate the configuration before the O(n^2 * d) distance pass so a bad
    // `method`/`percentile` is reported immediately instead of after the work.
    let selection = EpsilonSelection::parse(method, percentile)?;
    // Only copy the tensor contents once every cheap check has passed.
    let data_vec = data.to_vec().map_err(ClusterError::TensorError)?;
    let data_array = Array2::from_shape_vec((n_samples, n_features), data_vec)
        .map_err(|e| ClusterError::InvalidInput(format!("Failed to reshape data array: {}", e)))?;
    let k_distances = compute_k_distances(&data_array, k)?;
    match selection {
        EpsilonSelection::Elbow => find_elbow_point(&k_distances),
        EpsilonSelection::Knee => find_knee_point(&k_distances),
        EpsilonSelection::Percentile(p) => find_percentile(&k_distances, p),
    }
}

/// How [`suggest_epsilon`] picks a value from the sorted k-distance curve.
#[derive(Debug, Clone, Copy)]
enum EpsilonSelection {
    Elbow,
    Knee,
    /// Percentile in `[0, 100]`.
    Percentile(f64),
}

impl EpsilonSelection {
    /// Parses the user-facing `method` string (plus its optional percentile).
    fn parse(method: &str, percentile: Option<f64>) -> ClusterResult<Self> {
        match method {
            "elbow" => Ok(Self::Elbow),
            "knee" => Ok(Self::Knee),
            "percentile" => {
                let p = percentile.ok_or_else(|| {
                    ClusterError::ConfigError(
                        "percentile parameter required for percentile method".to_string(),
                    )
                })?;
                // Same check (and message) as `find_percentile`, done early.
                if !(0.0..=100.0).contains(&p) {
                    return Err(ClusterError::InvalidInput(format!(
                        "Percentile must be between 0 and 100, got {}",
                        p
                    )));
                }
                Ok(Self::Percentile(p))
            }
            _ => Err(ClusterError::ConfigError(format!(
                "Unknown epsilon selection method: {}. Use 'elbow', 'knee', or 'percentile'",
                method
            ))),
        }
    }
}
fn compute_k_distances(data: &Array2<f32>, k: usize) -> ClusterResult<Vec<f64>> {
let n_samples = data.nrows();
let mut k_distances = Vec::with_capacity(n_samples);
if n_samples >= 500 {
let distances: Vec<f64> = (0..n_samples)
.into_par_iter()
.map(|i| {
let mut dists = Vec::with_capacity(n_samples - 1);
for j in 0..n_samples {
if i != j {
let dist = euclidean_distance(&data.row(i), &data.row(j));
dists.push(dist);
}
}
dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
if dists.len() > k - 1 {
dists[k - 1]
} else {
*dists.last().unwrap_or(&f64::MAX)
}
})
.collect();
k_distances = distances;
} else {
for i in 0..n_samples {
let mut dists = Vec::with_capacity(n_samples - 1);
for j in 0..n_samples {
if i != j {
let dist = euclidean_distance(&data.row(i), &data.row(j));
dists.push(dist);
}
}
dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
if dists.len() > k - 1 {
k_distances.push(dists[k - 1]);
} else {
k_distances.push(*dists.last().unwrap_or(&f64::MAX));
}
}
}
k_distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
Ok(k_distances)
}
/// Picks the k-distance at the point of maximum discrete curvature.
///
/// Both axes are rescaled to `[0, 1]`: the index `i` becomes `i / (n - 1)`,
/// and each distance `d` becomes `(d - first) / (last - first)`. The curvature
/// at every interior index is estimated as the change in slope between its
/// left and right secants, divided by the mean step width.
///
/// If the curve is flat (`last - first < 1e-10`), the middle element
/// `sorted_distances[n / 2]` is returned. That element is also returned when
/// no interior point has positive curvature. Ties keep the lowest index.
///
/// # Errors
///
/// Returns `ClusterError::InvalidInput` when fewer than three distances are
/// supplied.
fn find_elbow_point(sorted_distances: &[f64]) -> ClusterResult<f64> {
    let n = sorted_distances.len();
    if n < 3 {
        return Err(ClusterError::InvalidInput(
            "Need at least 3 points to find elbow".to_string(),
        ));
    }
    let lowest = sorted_distances[0];
    let span = sorted_distances[n - 1] - lowest;
    if span < 1e-10 {
        return Ok(sorted_distances[n / 2]);
    }

    let x_at = |idx: usize| idx as f64 / (n - 1) as f64;
    let y_at = |idx: usize| (sorted_distances[idx] - lowest) / span;
    let curvature_at = |idx: usize| {
        let step_left = x_at(idx) - x_at(idx - 1);
        let step_right = x_at(idx + 1) - x_at(idx);
        let slope_left = (y_at(idx) - y_at(idx - 1)) / step_left;
        let slope_right = (y_at(idx + 1) - y_at(idx)) / step_right;
        ((slope_right - slope_left) / ((step_left + step_right) / 2.0)).abs()
    };

    let (best_idx, _) = (1..n - 1).fold((n / 2, 0.0_f64), |(best_idx, best_curv), idx| {
        let curv = curvature_at(idx);
        if curv > best_curv {
            (idx, curv)
        } else {
            (best_idx, best_curv)
        }
    });
    Ok(sorted_distances[best_idx])
}
/// Picks the k-distance farthest from the chord joining the curve's endpoints.
///
/// The curve is the set of points `(i, sorted_distances[i])`. For each point,
/// the perpendicular distance to the straight line through the first and last
/// point is measured, and the distance value of the farthest point is
/// returned. Ties keep the lowest index. If every point lies on the chord, the
/// middle element `sorted_distances[n / 2]` is returned.
///
/// # Errors
///
/// Returns `ClusterError::InvalidInput` when fewer than two distances are
/// supplied.
fn find_knee_point(sorted_distances: &[f64]) -> ClusterResult<f64> {
    let n = sorted_distances.len();
    if n < 2 {
        return Err(ClusterError::InvalidInput(
            "Need at least 2 points to find knee".to_string(),
        ));
    }
    let (x_start, y_start) = (0.0, sorted_distances[0]);
    let (x_end, y_end) = ((n - 1) as f64, sorted_distances[n - 1]);
    let slope = (y_end - y_start) / (x_end - x_start);
    let intercept = y_start - slope * x_start;
    // The line's normalising factor does not depend on the point.
    let norm = (slope * slope + 1.0).sqrt();

    let (knee_idx, _) = sorted_distances.iter().enumerate().fold(
        (n / 2, 0.0_f64),
        |(best_idx, best_gap), (idx, &y)| {
            let gap = (slope * idx as f64 - y + intercept).abs() / norm;
            if gap > best_gap {
                (idx, gap)
            } else {
                (best_idx, best_gap)
            }
        },
    );
    Ok(sorted_distances[knee_idx])
}
/// Returns the element of `sorted_distances` at `percentile` (nearest rank).
///
/// The rank is `round(percentile / 100 * (n - 1))` (halves round away from
/// zero), clamped to the last index. No interpolation is performed.
///
/// # Errors
///
/// Returns `ClusterError::InvalidInput` if `percentile` is outside
/// `[0, 100]` (including NaN) or if `sorted_distances` is empty.
fn find_percentile(sorted_distances: &[f64], percentile: f64) -> ClusterResult<f64> {
    if !(0.0..=100.0).contains(&percentile) {
        return Err(ClusterError::InvalidInput(format!(
            "Percentile must be between 0 and 100, got {}",
            percentile
        )));
    }
    let last = match sorted_distances.len().checked_sub(1) {
        Some(last) => last,
        None => {
            return Err(ClusterError::InvalidInput(
                "Cannot compute percentile of empty array".to_string(),
            ))
        }
    };
    let rank = ((percentile / 100.0) * last as f64).round() as usize;
    Ok(sorted_distances[rank.min(last)])
}
/// Euclidean (L2) distance between two points given as `f32` coordinate views.
///
/// Each coordinate is widened to `f64` before subtracting, so the squared
/// differences are accumulated in double precision. Coordinates are paired
/// with `zip`, so any extra trailing coordinates of a longer view are ignored.
#[inline]
fn euclidean_distance(point1: &ArrayView1<f32>, point2: &ArrayView1<f32>) -> f64 {
    point1
        .iter()
        .zip(point2.iter())
        .map(|(&a, &b)| f64::from(a) - f64::from(b))
        .fold(0.0_f64, |acc, delta| acc + delta * delta)
        .sqrt()
}
/// Suggests a complete `(eps, min_samples)` pair for DBSCAN.
///
/// `min_samples` is `max(2 * n_features, 4)`. `eps` comes from
/// [`suggest_epsilon`] with `k = min_samples - 1`. A `method` of `"auto"` is
/// treated as `"elbow"`; any other value is forwarded unchanged together with
/// `percentile`.
///
/// # Errors
///
/// Returns `ClusterError::InvalidInput` if `data` is not 2-D, and otherwise
/// propagates every error from [`suggest_epsilon`] (for example when the data
/// has no more than `min_samples - 1` rows).
pub fn suggest_dbscan_params(
    data: &Tensor,
    method: &str,
    percentile: Option<f64>,
) -> ClusterResult<(f64, usize)> {
    let shape = data.shape();
    let n_features = match shape.dims()[..] {
        [_, features] => features,
        _ => {
            return Err(ClusterError::InvalidInput(
                "Data tensor must be 2-dimensional".to_string(),
            ))
        }
    };
    // Rule of thumb: twice the dimensionality, but never fewer than 4 points.
    let min_samples = usize::max(2 * n_features, 4);
    let resolved_method = match method {
        "auto" => "elbow",
        other => other,
    };
    suggest_epsilon(data, min_samples - 1, resolved_method, percentile)
        .map(|epsilon| (epsilon, min_samples))
}
/// Placeholder for a log-spaced grid search over DBSCAN `eps` values.
///
/// The arguments are validated, and then
/// [`ClusterError::NotImplemented`] is always returned: scoring each candidate
/// requires running DBSCAN, which this module does not do. The intended grid
/// has `n_values` points spaced evenly in `ln(eps)` between `min_eps` and
/// `max_eps`. That spacing needs strictly positive, finite bounds, so those
/// are checked already.
///
/// # Errors
///
/// * [`ClusterError::InvalidInput`] if `n_values < 2`, if `min_eps` is not
///   strictly less than `max_eps` (including NaN bounds), or if the bounds
///   are not positive and finite.
/// * [`ClusterError::NotImplemented`] for every valid input.
// NOTE(review): the function is `pub`, so this allowance presumably exists
// because the enclosing module is not re-exported — confirm and drop if not.
#[allow(dead_code)]
pub fn optimize_epsilon(
    _data: &Tensor,
    min_eps: f64,
    max_eps: f64,
    n_values: usize,
    _min_samples: usize,
) -> ClusterResult<(f64, usize, f64)> {
    if n_values < 2 {
        return Err(ClusterError::InvalidInput(
            "Need at least 2 epsilon values to optimize".to_string(),
        ));
    }
    if min_eps >= max_eps || min_eps.is_nan() || max_eps.is_nan() {
        return Err(ClusterError::InvalidInput(
            "min_eps must be less than max_eps".to_string(),
        ));
    }
    // ln() of a non-positive bound is NaN/-inf, and an infinite upper bound
    // makes the grid degenerate.
    if min_eps <= 0.0 || !max_eps.is_finite() {
        return Err(ClusterError::InvalidInput(
            "min_eps must be positive and max_eps finite for a log-spaced search".to_string(),
        ));
    }
    Err(ClusterError::NotImplemented(
        "optimize_epsilon requires external DBSCAN evaluation - use suggest_epsilon instead or implement manually".to_string(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;
    use std::f64::consts::SQRT_2;

    /// `sqrt(8)`: the distance between two points two diagonal grid steps apart.
    const TWO_SQRT_2: f64 = 2.0 * SQRT_2;

    /// Flattened `(x, y)` coordinates of a 10 x 100 integer grid (1000 points).
    fn grid_points() -> Vec<f32> {
        let mut points = Vec::with_capacity(2000);
        for i in 0..1000 {
            points.push((i % 10) as f32);
            points.push((i / 10) as f32);
        }
        points
    }

    #[test]
    fn test_suggest_epsilon_basic() -> ClusterResult<()> {
        // Four points near the origin plus two far-away points, as (x, y) rows.
        let data = Tensor::from_vec(
            vec![
                0.0, 0.0, //
                0.1, 0.1, //
                0.2, 0.0, //
                0.0, 0.2, //
                5.0, 5.0, //
                10.0, 10.0,
            ],
            &[6, 2],
        )?;
        let eps = suggest_epsilon(&data, 3, "elbow", None)?;
        assert!(eps > 0.0);
        assert!(eps < 20.0);
        Ok(())
    }

    #[test]
    fn test_suggest_epsilon_methods() -> ClusterResult<()> {
        // Diagonal points (0,0), (1,1), (2,2), (10,10). Their sorted
        // 2-distances are [sqrt(2), sqrt(8), sqrt(8), sqrt(162)], and every
        // method selects sqrt(8).
        let data = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 10.0, 10.0], &[4, 2])?;
        let eps_elbow = suggest_epsilon(&data, 2, "elbow", None)?;
        let eps_knee = suggest_epsilon(&data, 2, "knee", None)?;
        let eps_percentile = suggest_epsilon(&data, 2, "percentile", Some(75.0))?;
        assert_relative_eq!(eps_elbow, TWO_SQRT_2, epsilon = 1e-9);
        assert_relative_eq!(eps_knee, TWO_SQRT_2, epsilon = 1e-9);
        assert_relative_eq!(eps_percentile, TWO_SQRT_2, epsilon = 1e-9);
        Ok(())
    }

    #[test]
    fn test_suggest_epsilon_percentile() -> ClusterResult<()> {
        // Sorted 2-distances: [sqrt(2), sqrt(8), sqrt(8)].
        let data = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0], &[3, 2])?;
        let eps_0 = suggest_epsilon(&data, 2, "percentile", Some(0.0))?;
        let eps_50 = suggest_epsilon(&data, 2, "percentile", Some(50.0))?;
        let eps_90 = suggest_epsilon(&data, 2, "percentile", Some(90.0))?;
        assert_relative_eq!(eps_0, SQRT_2, epsilon = 1e-9);
        assert_relative_eq!(eps_50, TWO_SQRT_2, epsilon = 1e-9);
        assert_relative_eq!(eps_90, TWO_SQRT_2, epsilon = 1e-9);
        assert!(eps_90 >= eps_50);
        Ok(())
    }

    #[test]
    fn test_suggest_dbscan_params() -> ClusterResult<()> {
        let data = Tensor::from_vec(
            vec![
                0.0, 0.0, //
                0.1, 0.1, //
                0.2, 0.0, //
                5.0, 5.0, //
                5.1, 5.1, //
                5.2, 5.0,
            ],
            &[6, 2],
        )?;
        let (eps, min_samples) = suggest_dbscan_params(&data, "auto", None)?;
        // Two features: max(2 * 2, 4) == 4.
        assert_eq!(min_samples, 4);
        assert!(eps > 0.0);
        assert!(eps < 10.0);
        Ok(())
    }

    #[test]
    fn test_compute_k_distances() -> ClusterResult<()> {
        // Unit square: every corner has two neighbours at 1 and one at sqrt(2).
        let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("test data shape should be valid");
        let k2 = compute_k_distances(&data, 2)?;
        assert_eq!(k2.len(), 4);
        for pair in k2.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
        for &dist in &k2 {
            assert_relative_eq!(dist, 1.0, epsilon = 1e-12);
        }
        let k3 = compute_k_distances(&data, 3)?;
        for &dist in &k3 {
            assert_relative_eq!(dist, SQRT_2, epsilon = 1e-12);
        }
        Ok(())
    }

    #[test]
    fn test_find_elbow_point() -> ClusterResult<()> {
        // The largest second difference sits at index 7.
        let distances = vec![0.1, 0.2, 0.3, 0.4, 1.0, 2.0, 5.0, 10.0, 20.0];
        let elbow = find_elbow_point(&distances)?;
        assert_relative_eq!(elbow, 10.0);
        assert!(find_elbow_point(&[1.0, 2.0]).is_err());
        Ok(())
    }

    #[test]
    fn test_find_knee_point() -> ClusterResult<()> {
        // Index 5 is farthest below the chord from (0, 0.1) to (7, 15.0).
        let distances = vec![0.1, 0.2, 0.3, 0.5, 1.0, 3.0, 7.0, 15.0];
        let knee = find_knee_point(&distances)?;
        assert_relative_eq!(knee, 3.0);
        assert!(find_knee_point(&[1.0]).is_err());
        Ok(())
    }

    #[test]
    fn test_find_percentile() -> ClusterResult<()> {
        // Nearest rank: round(p / 100 * 9), halves rounding away from zero.
        let distances = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        assert_relative_eq!(find_percentile(&distances, 50.0)?, 6.0);
        assert_relative_eq!(find_percentile(&distances, 90.0)?, 9.0);
        assert_relative_eq!(find_percentile(&distances, 0.0)?, 1.0);
        assert_relative_eq!(find_percentile(&distances, 100.0)?, 10.0);
        assert!(find_percentile(&distances, f64::NAN).is_err());
        assert!(find_percentile(&[], 50.0).is_err());
        Ok(())
    }

    #[test]
    fn test_invalid_inputs() {
        let data = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[2, 2])
            .expect("test tensor creation should succeed");
        assert!(matches!(
            suggest_epsilon(&data, 10, "elbow", None),
            Err(ClusterError::InvalidInput(_))
        ));
        assert!(matches!(
            suggest_epsilon(&data, 0, "elbow", None),
            Err(ClusterError::InvalidInput(_))
        ));
        assert!(matches!(
            suggest_epsilon(&data, 1, "invalid_method", None),
            Err(ClusterError::ConfigError(_))
        ));
        assert!(matches!(
            suggest_epsilon(&data, 1, "percentile", None),
            Err(ClusterError::ConfigError(_))
        ));
        assert!(matches!(
            suggest_epsilon(&data, 1, "percentile", Some(150.0)),
            Err(ClusterError::InvalidInput(_))
        ));
    }

    #[test]
    fn test_euclidean_distance() {
        let p1 = Array1::from_vec(vec![0.0, 0.0]);
        let p2 = Array1::from_vec(vec![3.0, 4.0]);
        assert_relative_eq!(euclidean_distance(&p1.view(), &p2.view()), 5.0, epsilon = 1e-6);
        assert_relative_eq!(euclidean_distance(&p2.view(), &p2.view()), 0.0);
    }

    #[test]
    fn test_parallel_k_distances() -> ClusterResult<()> {
        let data = Tensor::from_vec(grid_points(), &[1000, 2])?;
        let eps = suggest_epsilon(&data, 4, "elbow", None)?;
        assert!(eps > 0.0);
        assert!(eps.is_finite());
        Ok(())
    }

    #[test]
    fn test_parallel_k_distances_grid() -> ClusterResult<()> {
        // 1000 rows (>= 500) exercises the parallel path. On the 10 x 100 grid:
        // - the 784 interior points have 4 neighbours at distance 1;
        // - the 212 non-corner edge points have only 3, so the 4th is diagonal;
        // - the 4 corners have only 2, so the 4th is two steps away.
        let data = Array2::from_shape_vec((1000, 2), grid_points())
            .expect("grid shape should be valid");
        let k_dists = compute_k_distances(&data, 4)?;
        assert_eq!(k_dists.len(), 1000);
        let count_near = |target: f64| {
            k_dists
                .iter()
                .filter(|&&d| (d - target).abs() < 1e-12)
                .count()
        };
        assert_eq!(count_near(1.0), 784);
        assert_eq!(count_near(SQRT_2), 212);
        assert_eq!(count_near(2.0), 4);
        Ok(())
    }
}