use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use super::TopicId;
/// A single node of a hierarchical-clustering dendrogram.
///
/// A node is either a leaf (no children) or an internal node that records
/// the two child node ids it joins; see [`DendrogramNode::leaf`] and
/// [`DendrogramNode::internal`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DendrogramNode {
/// Identifier of this node; `Dendrogram` stores nodes keyed by this id.
pub id: u32,
/// Id of the left child, or `None` for a leaf.
pub left: Option<u32>,
/// Id of the right child, or `None` for a leaf.
pub right: Option<u32>,
/// Distance recorded for the merge that created this node (`0.0` for leaves).
pub distance: f32,
/// `1` for leaves; for internal nodes the value supplied by the caller.
/// NOTE(review): presumably the number of leaves under this node — confirm
/// against the code that produces the linkage.
pub count: usize,
}
impl DendrogramNode {
    /// Builds a childless node with the given `id`, a distance of `0.0`
    /// and a count of `1`.
    pub fn leaf(id: u32) -> Self {
        let (left, right) = (None, None);
        Self {
            id,
            left,
            right,
            distance: 0.0,
            count: 1,
        }
    }

    /// Builds a node that joins the children `left` and `right` at the given
    /// `distance`, carrying the caller-supplied `count`.
    pub fn internal(id: u32, left: u32, right: u32, distance: f32, count: usize) -> Self {
        // Start from a leaf with the same id and overwrite everything else.
        Self {
            left: Some(left),
            right: Some(right),
            distance,
            count,
            ..Self::leaf(id)
        }
    }

    /// Returns `true` only when the node has neither a left nor a right child.
    #[inline]
    pub fn is_leaf(&self) -> bool {
        matches!((self.left, self.right), (None, None))
    }
}
/// A binary merge tree over a fixed set of leaves.
///
/// Leaves occupy ids `0..num_leaves`; `from_linkage` gives the `i`-th merge
/// the id `num_leaves + i`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Dendrogram {
/// Every node of the tree (leaves and internal nodes), keyed by node id.
nodes: HashMap<u32, DendrogramNode>,
/// Number of leaf nodes the dendrogram was created with.
num_leaves: usize,
/// Id of the root node, or `None` when no root has been set.
root: Option<u32>,
}
impl Dendrogram {
    /// Creates a dendrogram containing only `num_leaves` leaf nodes
    /// (ids `0..num_leaves`) and no root.
    ///
    /// `num_leaves == 0` yields an empty dendrogram.
    pub fn new(num_leaves: usize) -> Self {
        // A full binary tree over n leaves has 2n - 1 nodes. Saturate so that
        // n == 0 does not underflow.
        let capacity = num_leaves.saturating_mul(2).saturating_sub(1);
        let mut nodes: HashMap<u32, DendrogramNode> = HashMap::with_capacity(capacity);
        nodes.extend((0..num_leaves).map(|i| {
            // Node ids are u32 throughout this type.
            let id = i as u32;
            (id, DendrogramNode::leaf(id))
        }));
        Self {
            nodes,
            num_leaves,
            root: None,
        }
    }

    /// Builds a dendrogram from a linkage table.
    ///
    /// Each entry is `(child_a, child_b, distance, count)`. The `i`-th entry
    /// becomes the internal node with id `num_leaves + i`. The last entry is
    /// taken as the root. With an empty table, a single leaf is its own root;
    /// otherwise the dendrogram has no root.
    ///
    /// Child ids are stored as given and are not validated.
    pub fn from_linkage(linkage: &[(u32, u32, f32, u32)], num_leaves: usize) -> Self {
        let mut dendro = Self::new(num_leaves);
        for (i, &(c1, c2, dist, count)) in linkage.iter().enumerate() {
            let new_id = (num_leaves + i) as u32;
            dendro.nodes.insert(
                new_id,
                DendrogramNode::internal(new_id, c1, c2, dist, count as usize),
            );
        }
        dendro.root = match linkage.len() {
            0 if num_leaves == 1 => Some(0),
            0 => None,
            merges => Some((num_leaves + merges - 1) as u32),
        };
        dendro
    }

    /// Returns the root node, if a root is set and present in the node map.
    pub fn root(&self) -> Option<&DendrogramNode> {
        self.root.and_then(|id| self.nodes.get(&id))
    }

    /// Looks up a node by id.
    pub fn get(&self, id: u32) -> Option<&DendrogramNode> {
        self.nodes.get(&id)
    }

    /// Number of leaves the dendrogram was created with.
    pub fn num_leaves(&self) -> usize {
        self.num_leaves
    }

    /// Total number of stored nodes (leaves plus internal nodes).
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Cuts the tree at `threshold` and returns one cluster label per leaf.
    ///
    /// Walking down from the root, every internal node whose distance is
    /// strictly greater than `threshold` is split; any other subtree becomes
    /// one cluster. Labels start at `0` and increase in left-to-right
    /// traversal order. Without a root, each leaf `i` gets label `i`.
    pub fn cut_at_distance(&self, threshold: f32) -> Vec<u32> {
        let Some(root_id) = self.root else {
            return (0..self.num_leaves as u32).collect();
        };
        let mut assignments = vec![0u32; self.num_leaves];
        let mut next_cluster = 0u32;
        self.cut_recursive(root_id, threshold, &mut assignments, &mut next_cluster);
        assignments
    }

    /// Recursive worker for [`Self::cut_at_distance`].
    ///
    /// `next_cluster` holds the label currently being handed out. It is
    /// bumped between the left and right subtree of every split node.
    /// Unknown node ids and out-of-range leaf ids are skipped.
    fn cut_recursive(
        &self,
        node_id: u32,
        threshold: f32,
        assignments: &mut [u32],
        next_cluster: &mut u32,
    ) {
        let Some(node) = self.nodes.get(&node_id) else {
            return;
        };
        if node.is_leaf() {
            // Checked access: a malformed tree must not panic here.
            if let Some(slot) = assignments.get_mut(node_id as usize) {
                *slot = *next_cluster;
            }
        } else if node.distance > threshold {
            if let Some(left) = node.left {
                self.cut_recursive(left, threshold, assignments, next_cluster);
            }
            *next_cluster += 1;
            if let Some(right) = node.right {
                self.cut_recursive(right, threshold, assignments, next_cluster);
            }
        } else {
            self.assign_all_leaves(node_id, *next_cluster, assignments);
        }
    }

    /// Labels every leaf below `node_id` with `cluster`.
    ///
    /// Unknown node ids and out-of-range leaf ids are skipped.
    fn assign_all_leaves(&self, node_id: u32, cluster: u32, assignments: &mut [u32]) {
        let Some(node) = self.nodes.get(&node_id) else {
            return;
        };
        if node.is_leaf() {
            if let Some(slot) = assignments.get_mut(node_id as usize) {
                *slot = cluster;
            }
            return;
        }
        for child in node.left.into_iter().chain(node.right) {
            self.assign_all_leaves(child, cluster, assignments);
        }
    }

    /// Cuts the tree into (ideally) `k` clusters and returns one label per leaf.
    ///
    /// * `k >= num_leaves`: every leaf `i` gets label `i`.
    /// * `k <= 1`: every leaf gets label `0`.
    /// * otherwise the tree is cut just above the `k`-th largest *distinct*
    ///   merge distance via [`Self::cut_at_distance`].
    ///
    /// Because the cut is distance based, tied merge distances cannot be
    /// separated. When there are too few distinct distances to produce `k`
    /// clusters, every leaf ends up in its own cluster.
    pub fn cut_to_k_clusters(&self, k: usize) -> Vec<u32> {
        if k >= self.num_leaves {
            return (0..self.num_leaves as u32).collect();
        }
        if k <= 1 {
            return vec![0; self.num_leaves];
        }

        // Distinct merge distances in ascending order. NaN distances are
        // dropped: a NaN node is never split by `cut_at_distance`, so it
        // cannot serve as a cut level.
        let mut distances: Vec<f32> = self
            .nodes
            .values()
            .filter(|n| !n.is_leaf())
            .map(|n| n.distance)
            .filter(|d| !d.is_nan())
            .collect();
        distances.sort_unstable_by(f32::total_cmp);
        distances.dedup();

        let m = distances.len();
        if k > m + 1 {
            return (0..self.num_leaves as u32).collect();
        }

        let threshold = if k > m {
            // k == m + 1: cut below every merge. An absolute epsilon would
            // vanish for large f32 distances, so use negative infinity.
            f32::NEG_INFINITY
        } else {
            // The cut uses a strict `>`, so the (m - k)-th distance itself
            // splits exactly the nodes at the k - 1 largest distinct levels.
            // This avoids midpoint rounding between two levels.
            distances[m - k]
        };
        self.cut_at_distance(threshold)
    }

    /// Returns all nodes (leaves included) whose distance lies in the
    /// half-open range `[min_distance, max_distance)`, in unspecified order.
    pub fn nodes_at_level(&self, min_distance: f32, max_distance: f32) -> Vec<&DendrogramNode> {
        self.nodes
            .values()
            .filter(|n| n.distance >= min_distance && n.distance < max_distance)
            .collect()
    }

    /// Returns the ids of all leaves below `node_id`, left to right.
    ///
    /// A leaf returns just itself. An unknown id returns an empty vector.
    pub fn leaves_under(&self, node_id: u32) -> Vec<u32> {
        let mut leaves = Vec::new();
        self.collect_leaves(node_id, &mut leaves);
        leaves
    }

    /// Recursive worker for [`Self::leaves_under`]; unknown ids are skipped.
    fn collect_leaves(&self, node_id: u32, leaves: &mut Vec<u32>) {
        let Some(node) = self.nodes.get(&node_id) else {
            return;
        };
        if node.is_leaf() {
            leaves.push(node_id);
            return;
        }
        for child in node.left.into_iter().chain(node.right) {
            self.collect_leaves(child, leaves);
        }
    }

    /// Returns the number of edges between the root and `node_id`.
    ///
    /// Returns `None` when the dendrogram has no root or when `node_id` is
    /// not reachable from the root.
    pub fn depth(&self, node_id: u32) -> Option<usize> {
        self.root
            .and_then(|root_id| self.depth_from(root_id, node_id, 0))
    }

    /// Searches the subtree rooted at `current` for `target`.
    ///
    /// Returns the depth of the shallowest match, where `current` itself is
    /// at `current_depth`, or `None` if there is no match.
    fn depth_from(&self, current: u32, target: u32, current_depth: usize) -> Option<usize> {
        if current == target {
            return Some(current_depth);
        }
        let node = self.nodes.get(&current)?;
        let left = node
            .left
            .and_then(|child| self.depth_from(child, target, current_depth + 1));
        let right = node
            .right
            .and_then(|child| self.depth_from(child, target, current_depth + 1));
        left.into_iter().chain(right).min()
    }

    /// Wraps each cluster label in a `TopicId`, preserving order.
    pub fn assignments_to_topic_ids(assignments: &[u32]) -> Vec<TopicId> {
        assignments.iter().map(|&a| TopicId::new(a)).collect()
    }

    /// Returns the distinct cluster labels in ascending order.
    pub fn unique_clusters(assignments: &[u32]) -> Vec<u32> {
        let distinct: HashSet<u32> = assignments.iter().copied().collect();
        let mut sorted: Vec<u32> = distinct.into_iter().collect();
        sorted.sort_unstable();
        sorted
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Linkage over four leaves: (0,1) merge at 1.0, (2,3) merge at 1.5,
    /// and the two resulting nodes merge at 2.0.
    const LINKAGE: [(u32, u32, f32, u32); 3] = [(0, 1, 1.0, 2), (2, 3, 1.5, 2), (4, 5, 2.0, 4)];

    /// Builds the four-leaf dendrogram shared by most tests.
    fn four_leaf_tree() -> Dendrogram {
        Dendrogram::from_linkage(&LINKAGE, 4)
    }

    /// Number of distinct labels in an assignment vector.
    fn cluster_count(assignments: &[u32]) -> usize {
        Dendrogram::unique_clusters(assignments).len()
    }

    #[test]
    fn test_empty_dendrogram() {
        let tree = Dendrogram::new(5);
        assert_eq!(tree.num_leaves(), 5);
        assert_eq!(tree.num_nodes(), 5);
        assert!(tree.root().is_none());
    }

    #[test]
    fn test_from_linkage() {
        let tree = four_leaf_tree();
        assert_eq!(tree.num_leaves(), 4);
        assert_eq!(tree.num_nodes(), 7);

        let top = tree.root().expect("should have root");
        assert_eq!(top.id, 6);
        assert_eq!(top.distance, 2.0);
        assert_eq!(top.count, 4);
    }

    #[test]
    fn test_cut_at_distance() {
        let tree = four_leaf_tree();
        // Above every merge, between the top two merges, below every merge.
        assert_eq!(cluster_count(&tree.cut_at_distance(3.0)), 1);
        assert_eq!(cluster_count(&tree.cut_at_distance(1.8)), 2);
        assert_eq!(cluster_count(&tree.cut_at_distance(0.5)), 4);
    }

    #[test]
    fn test_cut_to_k_clusters() {
        let tree = four_leaf_tree();

        let labels = tree.cut_to_k_clusters(2);
        assert_eq!(cluster_count(&labels), 2);
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);

        assert_eq!(cluster_count(&tree.cut_to_k_clusters(1)), 1);
        assert_eq!(cluster_count(&tree.cut_to_k_clusters(4)), 4);
    }

    #[test]
    fn test_leaves_under() {
        let tree = four_leaf_tree();

        let below_first_merge = tree.leaves_under(4);
        assert_eq!(below_first_merge.len(), 2);
        assert!(below_first_merge.contains(&0));
        assert!(below_first_merge.contains(&1));

        let below_root = tree.leaves_under(6);
        assert_eq!(below_root.len(), 4);
    }

    #[test]
    fn test_nodes_at_level() {
        let tree = four_leaf_tree();
        assert_eq!(tree.nodes_at_level(1.0, 1.6).len(), 2);
        assert_eq!(tree.nodes_at_level(2.0, f32::MAX).len(), 1);
    }

    #[test]
    fn test_assignments_to_topic_ids() {
        let labels = vec![0, 0, 1, 1, 2];
        let topic_ids = Dendrogram::assignments_to_topic_ids(&labels);
        assert_eq!(topic_ids.len(), 5);
        assert_eq!(topic_ids[0], TopicId::new(0));
        assert_eq!(topic_ids[4], TopicId::new(2));
    }

    #[test]
    fn test_single_node() {
        let bare = Dendrogram::new(1);
        assert_eq!(bare.num_leaves(), 1);

        let no_merges: Vec<(u32, u32, f32, u32)> = vec![];
        let linked = Dendrogram::from_linkage(&no_merges, 1);
        assert_eq!(linked.num_leaves(), 1);
    }

    #[test]
    fn test_depth() {
        let tree = four_leaf_tree();
        for (node_id, expected) in [(6, 0), (4, 1), (0, 2)] {
            assert_eq!(tree.depth(node_id), Some(expected));
        }
    }
}