use crate::cluster::traits::dbscan::{DbscanOptions, DbscanResult};
use crate::cluster::validation::{validate_cluster_dtype, validate_data_2d};
use numr::dtype::DType;
use numr::error::Result;
use numr::ops::{
CompareOps, ConditionalOps, CumulativeOps, DistanceOps, IndexingOps, ReduceOps, ScalarOps,
ShapeOps, TensorOps, TypeConversionOps, UnaryOps, UtilityOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn dbscan_impl<R, C>(
client: &C,
data: &Tensor<R>,
options: &DbscanOptions,
) -> Result<DbscanResult<R>>
where
R: Runtime<DType = DType>,
C: DistanceOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ TensorOps<R>
+ CompareOps<R>
+ ConditionalOps<R>
+ CumulativeOps<R>
+ ShapeOps<R>
+ IndexingOps<R>
+ UnaryOps<R>
+ UtilityOps<R>
+ TypeConversionOps<R>
+ RuntimeClient<R>,
{
validate_cluster_dtype(data.dtype(), "dbscan")?;
validate_data_2d(data.shape(), "dbscan")?;
let n = data.shape()[0];
let dtype = data.dtype();
let device = data.device();
let zeros_n = Tensor::<R>::zeros(&[n], dtype, device);
let ones_n = Tensor::<R>::ones(&[n], dtype, device);
let dists = client.cdist(data, data, options.metric)?;
let eps_t = Tensor::<R>::full_scalar(&[n, n], dtype, options.eps, device);
let neighbor_mask = client.le(&dists, &eps_t)?;
let neighbor_counts = client.sum(&neighbor_mask, &[1], false)?; let min_samples_t = Tensor::<R>::full_scalar(&[n], dtype, options.min_samples as f64, device);
let is_core = client.ge(&neighbor_counts, &min_samples_t)?;
let is_core_row = is_core.unsqueeze(1)?.broadcast_to(&[n, n])?; let is_core_col = is_core.unsqueeze(0)?.broadcast_to(&[n, n])?; let either_core = client.maximum(&is_core_row, &is_core_col)?; let adjacency = client.mul(&neighbor_mask, &either_core)?;
let mut labels = client.arange(0.0, n as f64, 1.0, dtype)?; let large_val_nn = Tensor::<R>::full_scalar(&[n, n], dtype, (n + 1) as f64, device);
for _ in 0..n {
let labels_row = labels.unsqueeze(0)?.broadcast_to(&[n, n])?;
let adj_nn = adjacency.broadcast_to(&[n, n])?;
let ones_nn = Tensor::<R>::ones(&[n, n], dtype, device);
let not_adj = client.sub(&ones_nn, &adj_nn)?;
let masked_large = client.mul(¬_adj, &large_val_nn)?;
let masked_labels_part = client.mul(&adj_nn, &labels_row)?;
let masked_labels = client.add(&masked_large, &masked_labels_part)?;
let new_labels = client.min(&masked_labels, &[1], false)?;
let own_smaller = client.le(&labels, &new_labels)?; let not_own = client.sub(&ones_n, &own_smaller)?;
let new_labels_merged = client.add(
&client.mul(&own_smaller, &labels)?,
&client.mul(¬_own, &new_labels)?,
)?;
let diff = client.sub(&new_labels_merged, &labels)?;
let abs_diff = client.abs(&diff)?;
let total_diff: f64 = client.sum(&abs_diff, &[0], false)?.item()?;
labels = new_labels_merged;
if total_diff == 0.0 {
break;
}
}
let core_neighbor = client.mul(&neighbor_mask, &is_core_row)?;
let reachable_from_core = client.sum(&core_neighbor, &[0], false)?; let is_reachable = client.gt(&reachable_from_core, &zeros_n)?;
let not_core = client.sub(&ones_n, &is_core)?;
let not_reachable = client.sub(&ones_n, &is_reachable)?;
let is_noise = client.mul(¬_core, ¬_reachable)?;
let neg_one = Tensor::<R>::full_scalar(&[n], dtype, -1.0, device);
let not_noise = client.sub(&ones_n, &is_noise)?;
let labels_after_noise = client.add(
&client.mul(&is_noise, &neg_one)?,
&client.mul(¬_noise, &labels)?,
)?;
let safe_labels = client.maximum(&labels_after_noise, &zeros_n)?;
let safe_labels_i64 = client.cast(&safe_labels, DType::I64)?;
let used = Tensor::<R>::zeros(&[1, n], dtype, device);
let non_noise_2d = client.sub(&ones_n, &is_noise)?.unsqueeze(0)?; let used = client
.scatter_reduce(
&used,
1,
&safe_labels_i64.unsqueeze(0)?,
&non_noise_2d,
numr::ops::ScatterReduceOp::Max,
true,
)?
.squeeze(Some(0));
let mapping = client.sub(&client.cumsum(&used, 0)?, &ones_n)?;
let new_labels_f = client
.gather(&mapping.unsqueeze(0)?, 1, &safe_labels_i64.unsqueeze(0)?)?
.squeeze(Some(0));
let final_labels = client.where_cond(&is_noise, &neg_one, &new_labels_f)?;
let n_clusters: f64 = client.sum(&used, &[0], false)?.item()?;
let is_core_u8 = client.cast(&is_core, DType::U8)?;
let all_indices = client.arange(0.0, n as f64, 1.0, DType::I64)?;
let core_sample_indices = client.masked_select(&all_indices, &is_core_u8)?;
Ok(DbscanResult {
labels: final_labels,
core_sample_indices,
n_clusters: n_clusters as usize,
})
}