use crate::error::Result;
#[cfg(feature = "std")]
use std::vec::Vec;
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// Ordered record of the merges produced by a hierarchical clustering run
/// over `n_items` original items.
///
/// Cluster ids: `0..n_items` are the original items, and the merge stored at
/// index `i` creates the cluster with id `n_items + i`.
#[derive(Debug, Clone)]
pub struct Dendrogram {
    /// Merges in the order they were added via [`Dendrogram::add_merge`].
    merges: Vec<Merge>,
    /// Number of original items (leaves).
    n_items: usize,
}

/// A single merge step joining two clusters.
#[derive(Debug, Clone, Copy)]
pub struct Merge {
    /// Id of the first cluster taking part in the merge.
    pub cluster_a: usize,
    /// Id of the second cluster taking part in the merge.
    pub cluster_b: usize,
    /// Distance recorded for this merge.
    pub distance: f64,
    /// Caller-supplied size value (presumably the number of items in the
    /// merged cluster); stored as-is and not interpreted by `Dendrogram`.
    pub size: usize,
}

impl Dendrogram {
    /// Creates an empty dendrogram over `n_items` items.
    pub fn new(n_items: usize) -> Self {
        Self {
            // A complete dendrogram over n items holds n - 1 merges.
            merges: Vec::with_capacity(n_items.saturating_sub(1)),
            n_items,
        }
    }

    /// Appends a merge step. No validation is performed on the arguments.
    pub fn add_merge(&mut self, cluster_a: usize, cluster_b: usize, distance: f64, size: usize) {
        self.merges.push(Merge {
            cluster_a,
            cluster_b,
            distance,
            size,
        });
    }

    /// Returns one cluster label per item after applying every leading merge
    /// whose distance does not exceed `threshold`.
    ///
    /// Merges are scanned in insertion order and scanning stops at the first
    /// merge whose distance is greater than `threshold`; later merges are not
    /// applied even if their own distance is below the threshold.
    ///
    /// Labels are dense (`0..number_of_clusters`) and assigned in ascending
    /// order of the internal cluster id, not in order of first appearance.
    /// Merges that reference a cluster id which does not exist yet at that
    /// step are ignored instead of panicking.
    pub fn cut_at_distance(&self, threshold: f64) -> Vec<usize> {
        let applied = self
            .merges
            .iter()
            .position(|merge| merge.distance > threshold)
            .unwrap_or(self.merges.len());
        self.labels_after(applied)
    }

    /// Returns one cluster label per item after applying the first
    /// `n_items - k` merges, which yields `k` clusters for a well-formed
    /// dendrogram.
    ///
    /// The cut is made by merge count rather than by distance, so tied or
    /// zero distances cannot cause extra merges to be applied. If fewer
    /// merges than `n_items - k` are stored, all of them are applied. If `k`
    /// is `0` or greater than `n_items`, every item gets its own label.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn cut_to_k(&self, k: usize) -> Result<Vec<usize>> {
        if k == 0 || k > self.n_items {
            return Ok((0..self.n_items).collect());
        }
        let wanted = self.n_items - k;
        Ok(self.labels_after(wanted.min(self.merges.len())))
    }

    /// Number of original items.
    pub fn n_items(&self) -> usize {
        self.n_items
    }

    /// Number of merges recorded so far.
    pub fn n_merges(&self) -> usize {
        self.merges.len()
    }

    /// Iterates over the recorded merges in insertion order.
    pub fn merges(&self) -> impl Iterator<Item = &Merge> {
        self.merges.iter()
    }

    /// Collects the distance of every recorded merge, in insertion order.
    pub fn distances(&self) -> Vec<f64> {
        self.merges.iter().map(|m| m.distance).collect()
    }

    /// Applies the first `n_applied` merges and returns dense per-item labels.
    fn labels_after(&self, n_applied: usize) -> Vec<usize> {
        /// Follows parent links until reaching a self-parented (root) id.
        fn find_root(parent: &[usize], mut id: usize) -> usize {
            while parent[id] != id {
                id = parent[id];
            }
            id
        }

        // parent[id] == id marks a root. Slots n_items.. are the clusters
        // created by the applied merges.
        let mut parent: Vec<usize> = (0..self.n_items + n_applied).collect();

        for (i, merge) in self.merges.iter().take(n_applied).enumerate() {
            let new_id = self.n_items + i;
            // Only clusters that already exist may be merged. Rejecting
            // anything else keeps every index in bounds and guarantees parent
            // links always point to a larger id, so root walks terminate.
            if merge.cluster_a >= new_id || merge.cluster_b >= new_id {
                continue;
            }
            let root_a = find_root(&parent, merge.cluster_a);
            let root_b = find_root(&parent, merge.cluster_b);
            parent[root_a] = new_id;
            parent[root_b] = new_id;
            // Shortcut the two named clusters straight to the new root.
            parent[merge.cluster_a] = new_id;
            parent[merge.cluster_b] = new_id;
        }

        let roots: Vec<usize> = (0..self.n_items)
            .map(|item| find_root(&parent, item))
            .collect();

        // Rank of each root among the distinct roots becomes its label.
        let mut distinct = roots.clone();
        distinct.sort_unstable();
        distinct.dedup();

        roots
            .iter()
            .map(|root| distinct.binary_search(root).unwrap_or(0))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built dendrogram reports its item count and holds no merges.
    #[test]
    fn test_dendrogram_creation() {
        let fresh = Dendrogram::new(5);
        assert_eq!(fresh.n_merges(), 0);
        assert_eq!(fresh.n_items(), 5);
    }

    /// Every call to `add_merge` is reflected in `n_merges`.
    #[test]
    fn test_dendrogram_merge() {
        // (cluster_a, cluster_b, distance, size) for a full 4-item tree.
        let steps: [(usize, usize, f64, usize); 3] =
            [(0, 1, 0.5, 2), (2, 3, 0.7, 2), (4, 5, 1.0, 4)];

        let mut tree = Dendrogram::new(4);
        for &(a, b, distance, size) in steps.iter() {
            tree.add_merge(a, b, distance, size);
        }

        assert_eq!(tree.n_merges(), 3);
    }
}