use crate::cluster::traits::hierarchy::{FClusterCriterion, LinkageMatrix, LinkageMethod};
use crate::cluster::validation::validate_cluster_dtype;
use numr::dtype::DType;
use numr::error::{Error, Result};
use numr::ops::{
CompareOps, ConditionalOps, CumulativeOps, DistanceMetric, DistanceOps, IndexingOps, ReduceOps,
ScalarOps, ScatterReduceOp, ShapeOps, TensorOps, TypeConversionOps, UnaryOps, UtilityOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn linkage_impl<R, C>(
client: &C,
distances: &Tensor<R>,
n: usize,
method: LinkageMethod,
) -> Result<LinkageMatrix<R>>
where
R: Runtime<DType = DType>,
C: DistanceOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ TensorOps<R>
+ CompareOps<R>
+ ConditionalOps<R>
+ ShapeOps<R>
+ IndexingOps<R>
+ UnaryOps<R>
+ UtilityOps<R>
+ TypeConversionOps<R>
+ RuntimeClient<R>,
{
let dtype = distances.dtype();
let device = distances.device();
let sq = client.squareform(distances, n)?;
linkage_from_square(client, &sq, n, method, dtype, device)
}
/// Hierarchical (agglomerative) clustering directly from observations.
///
/// The rows of `data` are treated as points. Their pairwise distances under `metric` are
/// computed with `client.cdist(data, data, metric)`, and the resulting dense matrix is clustered
/// with `method`. The number of points is taken from `data.shape()[0]`.
///
/// # Errors
/// Fails when `data`'s dtype is rejected by `validate_cluster_dtype`, when fewer than two points
/// are given, or when any tensor operation fails.
pub fn linkage_from_data_impl<R, C>(
    client: &C,
    data: &Tensor<R>,
    method: LinkageMethod,
    metric: DistanceMetric,
) -> Result<LinkageMatrix<R>>
where
    R: Runtime<DType = DType>,
    C: DistanceOps<R>
        + ReduceOps<R>
        + ScalarOps<R>
        + TensorOps<R>
        + CompareOps<R>
        + ConditionalOps<R>
        + ShapeOps<R>
        + IndexingOps<R>
        + UnaryOps<R>
        + UtilityOps<R>
        + TypeConversionOps<R>
        + RuntimeClient<R>,
{
    let dtype = data.dtype();
    validate_cluster_dtype(dtype, "linkage")?;

    let num_points = data.shape()[0];
    let pairwise = client.cdist(data, data, metric)?;
    linkage_from_square(client, &pairwise, num_points, method, dtype, data.device())
}
/// Agglomerative clustering on a dense pairwise distance matrix.
///
/// `dist_matrix` is expected to be a symmetric `[n, n]` matrix in `dtype` on `device` (assumed;
/// both callers build it with `squareform` / `cdist`). Each of the `n - 1` steps does four things:
/// * finds the closest pair of still-active slots `(i, j)`;
/// * records the merge;
/// * lets slot `i` (the smaller index) hold the merged cluster, with its distances to every
///   other slot updated by the Lance–Williams rule for `method`;
/// * deactivates slot `j`.
///
/// Each row of the returned matrix is `[min(i, j), max(i, j), distance, new_size]` in `dtype`.
/// The first two columns are *slot* indices in `0..n` (the surviving slot keeps its index). They
/// are not SciPy-style ids, where the cluster created by merge `m` is numbered `n + m`.
///
/// # Errors
/// `InvalidArgument` when `n < 2`; otherwise propagates tensor-op errors.
fn linkage_from_square<R, C>(
    client: &C,
    dist_matrix: &Tensor<R>,
    n: usize,
    method: LinkageMethod,
    dtype: DType,
    device: &R::Device,
) -> Result<LinkageMatrix<R>>
where
    R: Runtime<DType = DType>,
    C: ReduceOps<R>
        + ScalarOps<R>
        + TensorOps<R>
        + CompareOps<R>
        + ConditionalOps<R>
        + ShapeOps<R>
        + IndexingOps<R>
        + UnaryOps<R>
        + UtilityOps<R>
        + TypeConversionOps<R>
        + RuntimeClient<R>,
{
    if n < 2 {
        return Err(Error::InvalidArgument {
            arg: "n",
            reason: "linkage requires at least 2 points".to_string(),
        });
    }

    // Self-distances are never merge candidates: keep the diagonal at +inf.
    let idx = client.arange(0.0, n as f64, 1.0, dtype)?;
    let diag_mask = client.cast(
        &client.eq(&idx.unsqueeze(1)?, &idx.unsqueeze(0)?)?,
        DType::U8,
    )?;
    let mut dists = client.masked_fill(dist_matrix, &diag_mask, f64::INFINITY)?;

    // Loop-invariant scalar: activity test, clamping before sqrt, and slot deactivation.
    let zero = Tensor::<R>::zeros(&[1], dtype, device);

    // `active[k]` is 1 while slot k still holds a cluster; `sizes[k]` is that cluster's size.
    let mut active = Tensor::<R>::ones(&[n], dtype, device);
    let mut sizes = Tensor::<R>::ones(&[n], dtype, device);
    let mut z_rows: Vec<Tensor<R>> = Vec::with_capacity(n - 1);

    for _step in 0..n - 1 {
        // Hide every pair that touches a deactivated slot.
        let active_mask = client.mul(&active.unsqueeze(1)?, &active.unsqueeze(0)?)?;
        let inactive = client.cast(&client.eq(&active_mask, &zero)?, DType::U8)?;
        let masked_dists = client.masked_fill(&dists, &inactive, f64::INFINITY)?;

        // Locate the global minimum with two integer argmins: first the row, then the column
        // within that row. This avoids decoding a flat index with float division and an int
        // cast, which picks the wrong row whenever the cast truncates. Tie-breaking is
        // unchanged: we get the first minimum in row-major order.
        let row_mins = client.min(&masked_dists, &[1], false)?;
        let row_i64 = client.cast(
            &client.argmin(&row_mins, 0, false)?.reshape(&[1])?,
            DType::I64,
        )?;
        let min_dist = client.index_select(&row_mins, 0, &row_i64)?;
        let best_row = client
            .index_select(&masked_dists, 0, &row_i64)?
            .reshape(&[n])?;
        let col_i64 = client.cast(
            &client.argmin(&best_row, 0, false)?.reshape(&[1])?,
            DType::I64,
        )?;
        let row_f = client.cast(&row_i64, dtype)?;
        let col_f = client.cast(&col_i64, dtype)?;

        let size_i = client.index_select(&sizes, 0, &row_i64)?;
        let size_j = client.index_select(&sizes, 0, &col_i64)?;
        let new_size = client.add(&size_i, &size_j)?;

        let min_ij = client.minimum(&row_f, &col_f)?;
        let max_ij = client.maximum(&row_f, &col_f)?;
        z_rows.push(client.cat(&[&min_ij, &max_ij, &min_dist, &new_size], 0)?);

        // Current distances from the two merging clusters to every slot k.
        let dist_row_i = client.index_select(&dists, 0, &row_i64)?.reshape(&[n])?;
        let dist_row_j = client.index_select(&dists, 0, &col_i64)?.reshape(&[n])?;

        // Lance–Williams update: distance from the merged cluster (i ∪ j) to every slot k.
        let new_dists = match method {
            LinkageMethod::Single => client.minimum(&dist_row_i, &dist_row_j)?,
            LinkageMethod::Complete => client.maximum(&dist_row_i, &dist_row_j)?,
            LinkageMethod::Average => {
                // (s_i d_ik + s_j d_jk) / (s_i + s_j)
                let w_i = client.div(&size_i, &new_size)?;
                let w_j = client.div(&size_j, &new_size)?;
                client.add(
                    &client.mul(&dist_row_i, &w_i)?,
                    &client.mul(&dist_row_j, &w_j)?,
                )?
            }
            LinkageMethod::Weighted => {
                // (d_ik + d_jk) / 2
                client.div_scalar(&client.add(&dist_row_i, &dist_row_j)?, 2.0)?
            }
            LinkageMethod::Ward => {
                // sqrt(((s_i+s_k) d_ik² + (s_j+s_k) d_jk² - s_k d_ij²) / (s_i+s_j+s_k))
                let d_ij_sq = client.mul(&min_dist, &min_dist)?;
                let d_ik_sq = client.mul(&dist_row_i, &dist_row_i)?;
                let d_jk_sq = client.mul(&dist_row_j, &dist_row_j)?;
                let term1 = client.mul(&client.add(&size_i, &sizes)?, &d_ik_sq)?;
                let term2 = client.mul(&client.add(&size_j, &sizes)?, &d_jk_sq)?;
                let term3 = client.mul(&sizes, &d_ij_sq)?;
                let numer = client.sub(&client.add(&term1, &term2)?, &term3)?;
                let total = client.add(&new_size, &sizes)?;
                let result = client.div(&numer, &total)?;
                client.sqrt(&client.maximum(&result, &zero)?)?
            }
            LinkageMethod::Centroid => {
                // sqrt(w_i d_ik² + w_j d_jk² - w_i w_j d_ij²), where w_x = s_x / (s_i + s_j)
                let w_i = client.div(&size_i, &new_size)?;
                let w_j = client.div(&size_j, &new_size)?;
                let d_ij_sq = client.mul(&min_dist, &min_dist)?;
                let d_ik_sq = client.mul(&dist_row_i, &dist_row_i)?;
                let d_jk_sq = client.mul(&dist_row_j, &dist_row_j)?;
                let weighted = client.add(
                    &client.mul(&d_ik_sq, &w_i)?,
                    &client.mul(&d_jk_sq, &w_j)?,
                )?;
                let spread = client.mul(&client.mul(&w_i, &w_j)?, &d_ij_sq)?;
                let sq = client.sub(&weighted, &spread)?;
                // Clamp rounding-induced negatives before the square root (as for Ward).
                client.sqrt(&client.maximum(&sq, &zero)?)?
            }
            LinkageMethod::Median => {
                // sqrt(d_ik²/2 + d_jk²/2 - d_ij²/4)
                let d_ij_sq = client.mul(&min_dist, &min_dist)?;
                let d_ik_sq = client.mul(&dist_row_i, &dist_row_i)?;
                let d_jk_sq = client.mul(&dist_row_j, &dist_row_j)?;
                let half_sum = client.mul_scalar(&client.add(&d_ik_sq, &d_jk_sq)?, 0.5)?;
                let sq = client.sub(&half_sum, &client.mul_scalar(&d_ij_sq, 0.25)?)?;
                client.sqrt(&client.maximum(&sq, &zero)?)?
            }
        };

        // Write the updated distances into row i and column i, keeping `dists` symmetric.
        let row_idx_exp = row_i64.unsqueeze(1)?.broadcast_to(&[1, n])?;
        dists = client.scatter(&dists, 0, &row_idx_exp, &new_dists.unsqueeze(0)?)?;
        let row_idx_col = row_i64.unsqueeze(0)?.broadcast_to(&[n, 1])?;
        dists = client.scatter(&dists, 1, &row_idx_col, &new_dists.unsqueeze(1)?)?;

        // Slot j is absorbed; slot i now holds the merged cluster.
        active = client
            .scatter(
                &active.unsqueeze(0)?,
                1,
                &col_i64.unsqueeze(0)?,
                &zero.unsqueeze(0)?,
            )?
            .squeeze(Some(0));
        sizes = client
            .scatter(
                &sizes.unsqueeze(0)?,
                1,
                &row_i64.unsqueeze(0)?,
                &new_size.unsqueeze(0)?,
            )?
            .squeeze(Some(0));

        // The scatters above overwrote (i, i); restore the +inf diagonal.
        dists = client.masked_fill(&dists, &diag_mask, f64::INFINITY)?;
    }

    let z_refs: Vec<&Tensor<R>> = z_rows.iter().collect();
    let z = client.stack(&z_refs, 0)?;
    Ok(LinkageMatrix { z })
}
pub fn fcluster_impl<R, C>(
client: &C,
z: &LinkageMatrix<R>,
criterion: FClusterCriterion,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: ReduceOps<R>
+ ScalarOps<R>
+ TensorOps<R>
+ CompareOps<R>
+ ConditionalOps<R>
+ CumulativeOps<R>
+ ShapeOps<R>
+ IndexingOps<R>
+ UnaryOps<R>
+ UtilityOps<R>
+ TypeConversionOps<R>
+ RuntimeClient<R>,
{
let n_merges = z.z.shape()[0];
let n = n_merges + 1;
let dtype = z.z.dtype();
let device = z.z.device();
let threshold = match criterion {
FClusterCriterion::Distance(t) => t,
FClusterCriterion::MaxClust(k) => {
if k >= n {
return client.arange(0.0, n as f64, 1.0, DType::I64);
}
let dists_col = z.z.narrow(1, 2, 1)?.squeeze(Some(1)); let idx = n - k;
if idx == 0 {
0.0
} else {
let val: f64 = dists_col.narrow(0, idx - 1, 1)?.item()?;
if idx < n_merges {
let next: f64 = dists_col.narrow(0, idx, 1)?.item()?;
(val + next) / 2.0
} else {
val + 1.0
}
}
}
};
let id1_f = z.z.narrow(1, 0, 1)?.squeeze(Some(1)); let id2_f = z.z.narrow(1, 1, 1)?.squeeze(Some(1)); let dists_col = z.z.narrow(1, 2, 1)?.squeeze(Some(1));
let thresh_t = Tensor::<R>::full_scalar(&[n_merges], dtype, threshold, device);
let merge_mask = client.le(&dists_col, &thresh_t)?;
let id1_i64 = client.cast(&id1_f, DType::I64)?;
let id2_i64 = client.cast(&id2_f, DType::I64)?;
let mut adjacency = Tensor::<R>::zeros(&[n, n], dtype, device);
let merge_row = Tensor::<R>::zeros(&[n_merges, n], dtype, device);
let merge_row = client.scatter(
&merge_row,
1,
&id2_i64.unsqueeze(1)?,
&merge_mask.unsqueeze(1)?,
)?;
let id1_exp = id1_i64.unsqueeze(1)?.broadcast_to(&[n_merges, n])?;
adjacency = client.scatter_reduce(
&adjacency,
0,
&id1_exp,
&merge_row,
ScatterReduceOp::Max,
true,
)?;
let merge_col = Tensor::<R>::zeros(&[n_merges, n], dtype, device);
let merge_col = client.scatter(
&merge_col,
1,
&id1_i64.unsqueeze(1)?,
&merge_mask.unsqueeze(1)?,
)?;
let id2_exp = id2_i64.unsqueeze(1)?.broadcast_to(&[n_merges, n])?;
adjacency = client.scatter_reduce(
&adjacency,
0,
&id2_exp,
&merge_col,
ScatterReduceOp::Max,
true,
)?;
let eye_mask = client.cast(
&client.eq(
&client.arange(0.0, n as f64, 1.0, dtype)?.unsqueeze(1)?,
&client.arange(0.0, n as f64, 1.0, dtype)?.unsqueeze(0)?,
)?,
dtype,
)?;
adjacency = client.maximum(&adjacency, &eye_mask)?;
let ones_n = Tensor::<R>::ones(&[n], dtype, device);
let large_val = Tensor::<R>::full_scalar(&[n, n], dtype, (n + 1) as f64, device);
let mut labels = client.arange(0.0, n as f64, 1.0, dtype)?;
for _ in 0..n {
let labels_row = labels.unsqueeze(0)?.broadcast_to(&[n, n])?;
let not_adj = client.sub(&Tensor::<R>::ones(&[n, n], dtype, device), &adjacency)?;
let masked = client.add(
&client.mul(¬_adj, &large_val)?,
&client.mul(&adjacency, &labels_row)?,
)?;
let new_labels = client.min(&masked, &[1], false)?;
let own_smaller = client.le(&labels, &new_labels)?;
let not_own = client.sub(&ones_n, &own_smaller)?;
let merged = client.add(
&client.mul(&own_smaller, &labels)?,
&client.mul(¬_own, &new_labels)?,
)?;
let diff = client.sub(&merged, &labels)?;
let abs_diff = client.abs(&diff)?;
let total_diff: f64 = client.sum(&abs_diff, &[0], false)?.item()?;
labels = merged;
if total_diff == 0.0 {
break;
}
}
let labels_i64 = client.cast(&labels, DType::I64)?;
let used = Tensor::<R>::zeros(&[1, n], dtype, device);
let used = client
.scatter_reduce(
&used,
1,
&labels_i64.unsqueeze(0)?,
&ones_n.unsqueeze(0)?,
ScatterReduceOp::Max,
true,
)?
.squeeze(Some(0));
let mapping = client.sub(&client.cumsum(&used, 0)?, &ones_n)?;
let new_labels = client
.gather(&mapping.unsqueeze(0)?, 1, &labels_i64.unsqueeze(0)?)?
.squeeze(Some(0));
Ok(new_labels)
}
/// One-shot flat clustering of raw observations.
///
/// Builds the hierarchy with [`linkage_from_data_impl`] (rows of `data` as points, `metric` for
/// distances, `method` for merging), then cuts it with [`fcluster_impl`] under `criterion`.
/// The first error from either stage is returned unchanged.
pub fn fclusterdata_impl<R, C>(
    client: &C,
    data: &Tensor<R>,
    criterion: FClusterCriterion,
    method: LinkageMethod,
    metric: DistanceMetric,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: DistanceOps<R>
        + ReduceOps<R>
        + ScalarOps<R>
        + TensorOps<R>
        + CompareOps<R>
        + ConditionalOps<R>
        + CumulativeOps<R>
        + ShapeOps<R>
        + IndexingOps<R>
        + UnaryOps<R>
        + UtilityOps<R>
        + TypeConversionOps<R>
        + RuntimeClient<R>,
{
    linkage_from_data_impl(client, data, method, metric)
        .and_then(|tree| fcluster_impl(client, &tree, criterion))
}
/// Returns the left-to-right leaf order of the dendrogram described by `z`, as an I64 tensor of
/// shape `[n]` where `n = z.z.shape()[0] + 1`.
///
/// Two id conventions for columns 0 and 1 are accepted:
/// * **Slot ids** (every id `< n`), as emitted by this module's linkage. Merge `(a, b)` leaves
///   the combined cluster in slot `a`, so the members of `a` are followed by those of `b`.
/// * **SciPy-style ids**, where id `n + m` names the cluster formed by merge row `m`. These are
///   detected by any id `>= n` and walked depth-first from the root `2n - 2`.
///
/// For `n == 2` both conventions describe the same tree.
pub fn leaves_list_impl<R, C>(_client: &C, z: &LinkageMatrix<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>,
{
    let n_merges = z.z.shape()[0];
    let n = n_merges + 1;
    let device = z.z.device();
    // NOTE(review): reads the linkage matrix as f64 on the host; confirm `to_vec::<f64>` is
    // valid for every dtype accepted by `validate_cluster_dtype` (e.g. F32).
    let z_data: Vec<f64> = z.z.to_vec();

    let uses_scipy_ids = z_data
        .chunks_exact(4)
        .any(|row| row[0] as usize >= n || row[1] as usize >= n);
    let order = if uses_scipy_ids {
        scipy_leaf_order(&z_data, n)
    } else {
        slot_leaf_order(&z_data, n)
    };
    Ok(Tensor::<R>::from_slice(&order, &[n], device))
}

/// Leaf order for SciPy-style ids: depth-first walk from the root cluster `2n - 2`.
fn scipy_leaf_order(z_data: &[f64], n: usize) -> Vec<i64> {
    let mut order = Vec::with_capacity(n);
    let mut stack = vec![2 * n - 2];
    while let Some(node) = stack.pop() {
        if node < n {
            order.push(node as i64);
        } else {
            let merge_idx = node - n;
            let left = z_data[merge_idx * 4] as usize;
            let right = z_data[merge_idx * 4 + 1] as usize;
            // Push right first so the left subtree is emitted first.
            stack.push(right);
            stack.push(left);
        }
    }
    order
}

/// Leaf order for slot ids: replay the merges, appending slot `b`'s member list to slot `a`'s.
/// Member lists are singly linked (`head`/`tail` per slot, `next` per leaf), so each merge is
/// O(1) and the whole replay is O(n).
fn slot_leaf_order(z_data: &[f64], n: usize) -> Vec<i64> {
    let mut head: Vec<usize> = (0..n).collect();
    let mut tail: Vec<usize> = (0..n).collect();
    let mut next: Vec<Option<usize>> = vec![None; n];
    let mut root = 0;
    for row in z_data.chunks_exact(4) {
        let (a, b) = (row[0] as usize, row[1] as usize);
        next[tail[a]] = Some(head[b]);
        tail[a] = tail[b];
        root = a;
    }

    let mut order = Vec::with_capacity(n);
    let mut cursor = Some(head[root]);
    // Bounded by `n` so a malformed matrix (e.g. a == b) cannot loop forever.
    while let Some(leaf) = cursor {
        if order.len() == n {
            break;
        }
        order.push(leaf as i64);
        cursor = next[leaf];
    }
    order
}
/// Cuts the hierarchy at several cluster counts at once.
///
/// For each `k` in `n_clusters`, [`fcluster_impl`] is run with `FClusterCriterion::MaxClust(k)`.
/// Each resulting label vector becomes one column, and the columns are concatenated along
/// dimension 1, giving one row per point and one column per requested `k`.
pub fn cut_tree_impl<R, C>(
    client: &C,
    z: &LinkageMatrix<R>,
    n_clusters: &[usize],
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ReduceOps<R>
        + ScalarOps<R>
        + TensorOps<R>
        + CompareOps<R>
        + ConditionalOps<R>
        + CumulativeOps<R>
        + ShapeOps<R>
        + IndexingOps<R>
        + UnaryOps<R>
        + UtilityOps<R>
        + TypeConversionOps<R>
        + RuntimeClient<R>,
{
    let mut columns: Vec<Tensor<R>> = Vec::with_capacity(n_clusters.len());
    for k in n_clusters.iter().copied() {
        let column = fcluster_impl(client, z, FClusterCriterion::MaxClust(k))?.unsqueeze(1)?;
        columns.push(column);
    }
    let column_refs: Vec<&Tensor<R>> = columns.iter().collect();
    client.cat(&column_refs, 1)
}