use crate::cluster::traits::affinity_propagation::{
AffinityPropagationOptions, AffinityPropagationResult,
};
use crate::cluster::validation::validate_cluster_dtype;
use numr::dtype::DType;
use numr::error::Result;
use numr::ops::{
CompareOps, ConditionalOps, DistanceOps, IndexingOps, ReduceOps, ScalarOps, ShapeOps,
SortingOps, TensorOps, TypeConversionOps, UnaryOps, UtilityOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Affinity propagation clustering on a precomputed similarity matrix.
///
/// Iteratively exchanges "responsibility" (`r`) and "availability" (`a`)
/// messages between all pairs of points. A point `k` is an exemplar when
/// `a(k, k) + r(k, k) > 0`. Iteration stops once the (non-empty) exemplar set
/// has not changed for `options.convergence_iter` consecutive iterations, or
/// after `options.max_iter` iterations. Every point is then assigned to its
/// most similar exemplar.
///
/// # Arguments
/// * `client` - runtime client providing the tensor operations.
/// * `similarities` - pairwise similarities; expected to be a square `(n, n)`
///   matrix whose row `i`, column `k` entry is the similarity of point `i` to
///   candidate exemplar `k`.
/// * `options` - damping factor (`r = damping * r_old + (1 - damping) * r_new`,
///   same for `a`), preference placed on the diagonal (defaults to the median
///   of all similarities), `max_iter` and `convergence_iter`.
///
/// # Returns
/// * `cluster_centers_indices` - `I64` indices of the exemplar points.
/// * `labels` - for each point, the position of its exemplar inside
///   `cluster_centers_indices`; every label is `-1` when no exemplar emerged.
/// * `n_iter` - number of message-passing iterations that ran.
///
/// With fewer than two points no messages are exchanged and every point is
/// its own exemplar (`n_iter == 0`).
///
/// # Errors
/// Propagates dtype-validation failures and any error returned by the
/// underlying tensor operations.
pub fn affinity_propagation_impl<R, C>(
    client: &C,
    similarities: &Tensor<R>,
    options: &AffinityPropagationOptions,
) -> Result<AffinityPropagationResult<R>>
where
    R: Runtime<DType = DType>,
    C: DistanceOps<R>
        + ReduceOps<R>
        + ScalarOps<R>
        + TensorOps<R>
        + CompareOps<R>
        + ConditionalOps<R>
        + ShapeOps<R>
        + IndexingOps<R>
        + UnaryOps<R>
        + UtilityOps<R>
        + TypeConversionOps<R>
        + SortingOps<R>
        + RuntimeClient<R>,
{
    validate_cluster_dtype(similarities.dtype(), "affinity_propagation")?;
    let n = similarities.shape()[0];
    let dtype = similarities.dtype();
    let device = similarities.device();
    let damping = options.damping;

    // Fewer than two points: no messages to exchange (and the second-largest
    // lookup in the responsibility update would be out of range), so every
    // point is trivially its own exemplar.
    if n < 2 {
        return Ok(AffinityPropagationResult {
            labels: client.arange(0.0, n as f64, 1.0, DType::I64)?,
            cluster_centers_indices: client.arange(0.0, n as f64, 1.0, DType::I64)?,
            n_iter: 0,
        });
    }

    // Default preference: the median of all n*n similarities, averaging the
    // two middle values when the count is even.
    let pref = match options.preference {
        Some(p) => p,
        None => {
            let total = n * n;
            let flat = similarities.reshape(&[total])?;
            let sorted = client.sort(&flat, 0, false)?;
            let upper: f64 = sorted.narrow(0, total / 2, 1)?.item()?;
            if total % 2 == 0 {
                let lower: f64 = sorted.narrow(0, total / 2 - 1, 1)?.item()?;
                0.5 * (lower + upper)
            } else {
                upper
            }
        }
    };

    // Boolean mask of the main diagonal.
    let indices = client.arange(0.0, n as f64, 1.0, DType::I64)?;
    let row_idx = indices.unsqueeze(1)?.broadcast_to(&[n, n])?;
    let col_idx = indices.unsqueeze(0)?.broadcast_to(&[n, n])?;
    let is_diag = client.eq(&row_idx, &col_idx)?;

    // Similarities with the preference placed on the diagonal.
    let pref_matrix = Tensor::<R>::full_scalar(&[n, n], dtype, pref, device);
    let s = client.where_cond(&is_diag, &pref_matrix, similarities)?;

    // Loop-invariant constants.
    let damp_t = Tensor::<R>::full_scalar(&[n, n], dtype, damping, device);
    let one_minus_damp_t = Tensor::<R>::full_scalar(&[n, n], dtype, 1.0 - damping, device);
    let zeros_nn = Tensor::<R>::zeros(&[n, n], dtype, device);
    let zeros_n = Tensor::<R>::zeros(&[n], dtype, device);
    // Gathering column i from row i extracts the diagonal.
    let diag_idx = indices.unsqueeze(1)?;

    let mut r = Tensor::<R>::zeros(&[n, n], dtype, device);
    let mut a = Tensor::<R>::zeros(&[n, n], dtype, device);
    let mut n_iter = 0;
    let mut no_change_count = 0;
    let mut prev_exemplar_f: Option<Tensor<R>> = None;

    for iter in 0..options.max_iter {
        n_iter = iter + 1;

        // Responsibility: r(i,k) = s(i,k) - max_{k' != k} (a(i,k') + s(i,k')).
        // The row maximum is excluded by substituting the second-largest value
        // wherever the maximum sits (ties make both values equal).
        let as_sum = client.add(&a, &s)?;
        let sorted_row = client.sort(&as_sum, 1, true)?;
        let max1 = sorted_row.narrow(1, 0, 1)?.broadcast_to(&[n, n])?;
        let max2 = sorted_row.narrow(1, 1, 1)?.broadcast_to(&[n, n])?;
        let is_max = client.eq(&as_sum, &max1)?;
        let exclude_max = client.where_cond(&is_max, &max2, &max1)?;
        let r_new = client.sub(&s, &exclude_max)?;
        r = client.add(
            &client.mul(&damp_t, &r)?,
            &client.mul(&one_minus_damp_t, &r_new)?,
        )?;

        // Availability:
        //   a(i,k) = min(0, r(k,k) + sum_{i' not in {i,k}} max(0, r(i',k)))  (i != k)
        //   a(k,k) = sum_{i' != k} max(0, r(i',k))
        let r_pos = client.maximum(&r, &zeros_nn)?;
        let sum_r_pos = client.sum(&r_pos, &[0], false)?;
        let r_diag = client.gather(&r, 1, &diag_idx)?.reshape(&[n])?;
        let r_pos_diag = client.gather(&r_pos, 1, &diag_idx)?.reshape(&[n])?;
        let sum_r_pos_exp = sum_r_pos.unsqueeze(0)?.broadcast_to(&[n, n])?;
        let r_pos_diag_exp = r_pos_diag.unsqueeze(0)?.broadcast_to(&[n, n])?;
        let r_diag_exp = r_diag.unsqueeze(0)?.broadcast_to(&[n, n])?;
        let a_raw = client.add(
            &r_diag_exp,
            &client.sub(&client.sub(&sum_r_pos_exp, &r_pos)?, &r_pos_diag_exp)?,
        )?;
        let a_non_diag = client.minimum(&a_raw, &zeros_nn)?;
        let a_diag_vals = client.sub(&sum_r_pos, &r_pos_diag)?;
        let a_diag_exp = a_diag_vals.unsqueeze(0)?.broadcast_to(&[n, n])?;
        let a_new = client.where_cond(&is_diag, &a_diag_exp, &a_non_diag)?;
        a = client.add(
            &client.mul(&damp_t, &a)?,
            &client.mul(&one_minus_damp_t, &a_new)?,
        )?;

        // Current exemplar indicator (1.0 for exemplars, 0.0 otherwise).
        let a_diag = client.gather(&a, 1, &diag_idx)?.reshape(&[n])?;
        let ar_diag = client.add(&a_diag, &r_diag)?;
        let exemplar_f = client.cast(&client.gt(&ar_diag, &zeros_n)?, dtype)?;

        if let Some(prev) = &prev_exemplar_f {
            let abs_diff = client.abs(&client.sub(&exemplar_f, prev)?)?;
            let total_diff: f64 = client.sum(&abs_diff, &[0], false)?.item()?;
            if total_diff == 0.0 {
                no_change_count += 1;
                // Only accept a stable state that actually contains exemplars;
                // an empty exemplar set keeps iterating (as in the reference
                // algorithm) instead of converging to an all-noise result.
                if no_change_count >= options.convergence_iter {
                    let n_exemplars: f64 = client.sum(&exemplar_f, &[0], false)?.item()?;
                    if n_exemplars > 0.0 {
                        break;
                    }
                }
            } else {
                no_change_count = 0;
            }
        }
        prev_exemplar_f = Some(exemplar_f);
    }

    // Exemplars of the final message state.
    let ar = client.add(&a, &r)?;
    let ar_diag = client.gather(&ar, 1, &diag_idx)?.reshape(&[n])?;
    let exemplar_mask = client.gt(&ar_diag, &zeros_n)?;
    let exemplar_mask_u8 = client.cast(&exemplar_mask, DType::U8)?;
    let cluster_centers_indices = client.masked_select(&indices, &exemplar_mask_u8)?;
    let n_clusters = cluster_centers_indices.shape()[0];
    if n_clusters == 0 {
        let labels = Tensor::<R>::full_scalar(&[n], DType::I64, -1.0, device);
        return Ok(AffinityPropagationResult {
            labels,
            cluster_centers_indices,
            n_iter,
        });
    }

    // Assign each point i to argmax_c s(i, centers[c]) (row i, exemplar
    // columns). The diagonal is forced to +inf so every exemplar is labelled
    // with its own cluster; otherwise its self-similarity could lose to
    // another exemplar.
    let inf_nn = Tensor::<R>::full_scalar(&[n, n], dtype, f64::INFINITY, device);
    let assign_s = client.where_cond(&is_diag, &inf_nn, similarities)?;
    let to_centers = client.index_select(&assign_s, 1, &cluster_centers_indices)?;
    let labels = client.argmax(&to_centers, 1, false)?;
    Ok(AffinityPropagationResult {
        labels,
        cluster_centers_indices,
        n_iter,
    })
}