use crate::cluster::traits::optics::{OpticsOptions, OpticsResult};
use crate::cluster::validation::{validate_cluster_dtype, validate_data_2d};
use numr::dtype::DType;
use numr::error::Result;
use numr::ops::{
CompareOps, ConditionalOps, DistanceOps, IndexingOps, ReduceOps, ScalarOps, ShapeOps,
SortingOps, TensorOps, TypeConversionOps, UnaryOps, UtilityOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn optics_impl<R, C>(
client: &C,
data: &Tensor<R>,
options: &OpticsOptions,
) -> Result<OpticsResult<R>>
where
R: Runtime<DType = DType>,
C: DistanceOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ TensorOps<R>
+ CompareOps<R>
+ ConditionalOps<R>
+ ShapeOps<R>
+ IndexingOps<R>
+ UnaryOps<R>
+ UtilityOps<R>
+ TypeConversionOps<R>
+ SortingOps<R>
+ RuntimeClient<R>,
{
validate_cluster_dtype(data.dtype(), "optics")?;
validate_data_2d(data.shape(), "optics")?;
let n = data.shape()[0];
let dtype = data.dtype();
let device = data.device();
let inf = f64::INFINITY;
let dists = client.cdist(data, data, options.metric)?;
let sorted_dists = client.sort(&dists, 1, false)?; let ms = options.min_samples;
let core_distances = if ms <= n {
sorted_dists.narrow(1, ms, 1)?.contiguous()?.reshape(&[n])? } else {
Tensor::<R>::full_scalar(&[n], dtype, inf, device)
};
let max_eps_t = Tensor::<R>::full_scalar(&[n], dtype, options.max_eps, device);
let exceeds_max = client.gt(&core_distances, &max_eps_t)?;
let inf_t = Tensor::<R>::full_scalar(&[n], dtype, inf, device);
let core_distances = client.where_cond(&exceeds_max, &inf_t, &core_distances)?;
let mut reachability = Tensor::<R>::full_scalar(&[n], dtype, inf, device);
let mut processed = Tensor::<R>::zeros(&[n], dtype, device); let ones = Tensor::<R>::ones(&[n], dtype, device);
let mut ordering_vec = Vec::with_capacity(n);
for _step in 0..n {
let large = Tensor::<R>::full_scalar(&[n], dtype, inf + 1.0, device);
let proc_bool = client.gt(&processed, &Tensor::<R>::zeros(&[n], dtype, device))?;
let masked_reach = client.where_cond(&proc_bool, &large, &reachability)?;
let current_idx: i64 = client
.argmin(&masked_reach, 0, false)?
.reshape(&[1])?
.item()?;
ordering_vec.push(current_idx);
let idx_t = Tensor::<R>::from_slice(&[current_idx], &[1], device);
let one_val = Tensor::<R>::ones(&[1], dtype, device);
let proc_2d = processed.unsqueeze(0)?;
let idx_2d = idx_t.unsqueeze(0)?;
let one_2d = one_val.unsqueeze(0)?;
processed = client
.scatter(&proc_2d, 1, &idx_2d, &one_2d)?
.squeeze(Some(0));
let current_dists = client.index_select(&dists, 0, &idx_t)?.reshape(&[n])?;
let current_core = client
.index_select(&core_distances, 0, &idx_t)?
.reshape(&[1])?;
let current_core_broadcast = current_core.broadcast_to(&[n])?;
let new_reach = client.maximum(¤t_core_broadcast, ¤t_dists)?;
let within_eps = client.le(¤t_dists, &max_eps_t)?;
let not_processed = client.sub(&ones, &processed)?;
let update_mask = client.mul(&within_eps, ¬_processed)?;
let better = client.lt(&new_reach, &reachability)?;
let should_update = client.mul(&update_mask, &better)?;
reachability = client.where_cond(&should_update, &new_reach, &reachability)?;
}
let ordering = Tensor::<R>::from_slice(&ordering_vec, &[n], device);
let reachability_ordered = client.index_select(&reachability, 0, &ordering)?;
let core_distances_ordered = client.index_select(&core_distances, 0, &ordering)?;
let labels = if let Some(xi) = options.xi {
xi_cluster_extraction(
client,
&reachability_ordered,
&ordering,
n,
xi,
dtype,
device,
)?
} else {
Tensor::<R>::full_scalar(&[n], DType::I64, -1.0, device)
};
Ok(OpticsResult {
ordering,
reachability: reachability_ordered,
core_distances: core_distances_ordered,
labels,
})
}
/// Derives cluster labels from an OPTICS reachability plot using a simplified
/// xi steepness rule.
///
/// `reachability` and `ordering` are copied to the host, labelled by
/// [`xi_labels_from_ordering`], and the labels are uploaded as a tensor of
/// shape `[n]` built from `i64` values. `_client` and `_dtype` are unused.
///
/// # Errors
/// The signature returns `Result`, but no fallible call is made in this body.
fn xi_cluster_extraction<R, C>(
    _client: &C,
    reachability: &Tensor<R>,
    ordering: &Tensor<R>,
    n: usize,
    xi: f64,
    _dtype: numr::dtype::DType,
    device: &R::Device,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: CompareOps<R>
        + ConditionalOps<R>
        + ScalarOps<R>
        + TensorOps<R>
        + IndexingOps<R>
        + ReduceOps<R>
        + TypeConversionOps<R>
        + UtilityOps<R>
        + RuntimeClient<R>,
{
    // NOTE(review): reachability is read as f64 regardless of `_dtype`;
    // confirm `to_vec` is valid for non-f64 tensors.
    let reach_vec: Vec<f64> = reachability.to_vec();
    // The caller builds the ordering tensor from i64 indices, so it is read
    // back as i64 rather than as floats.
    let order_vec: Vec<i64> = ordering.to_vec();
    let labels_vec = xi_labels_from_ordering(&reach_vec, &order_vec, n, xi);
    Ok(Tensor::<R>::from_slice(&labels_vec, &[n], device))
}

/// Host-side labelling over a reachability plot.
///
/// `reach[i]` is the reachability at ordering position `i`, and `order[i]` is
/// the original index of the point at that position. Both are expected to hold
/// `n` entries. Returns `n` labels indexed by original point index; `-1` marks
/// points that were not assigned to any cluster.
///
/// A cluster opens at the first position `i` where
/// `reach[i] * (1 - xi) >= reach[i + 1]`. It closes, inclusive of `i`, when
/// `reach[i] <= reach[i + 1] * (1 - xi)` or when either value is infinite.
/// A cluster still open at the end covers the rest of the ordering.
fn xi_labels_from_ordering(reach: &[f64], order: &[i64], n: usize, xi: f64) -> Vec<i64> {
    /// Writes `id` into the label slot of every valid original index in `span`.
    fn label_span(labels: &mut [i64], span: &[i64], id: i64) {
        for &original in span {
            // Negative or out-of-range indices are skipped rather than wrapped.
            if let Some(slot) = usize::try_from(original)
                .ok()
                .and_then(|idx| labels.get_mut(idx))
            {
                *slot = id;
            }
        }
    }

    let mut labels = vec![-1i64; n];
    let mut cluster_id = 0i64;
    let factor = 1.0 - xi;
    let mut steep_down_start: Option<usize> = None;

    for (i, pair) in reach.windows(2).enumerate().take(n.saturating_sub(1)) {
        let (r_curr, r_next) = (pair[0], pair[1]);

        let closes_cluster = if r_curr.is_infinite() || r_next.is_infinite() {
            // An unreached point always ends whatever cluster is open.
            true
        } else if r_curr * factor >= r_next {
            // Steep-down: remember where the open cluster starts.
            steep_down_start.get_or_insert(i);
            false
        } else {
            // Steep-up ends the open cluster; anything else is flat.
            r_curr <= r_next * factor
        };

        if closes_cluster {
            if let Some(start) = steep_down_start.take() {
                label_span(&mut labels, &order[start..=i], cluster_id);
                cluster_id += 1;
            }
        }
    }

    // A cluster still open at the end of the plot runs to the last position.
    if let Some(start) = steep_down_start {
        label_span(&mut labels, &order[start..n], cluster_id);
    }
    labels
}