use crate::clustering::mst::MstEdge;
use crate::prelude::*;
/// One merge step of a single-linkage tree.
///
/// Rows are produced by [`mst_to_linkage_tree`], one per processed MST edge.
/// Cluster labels below `n_samples` denote original samples; the cluster
/// created by the `k`-th merge (0-based) is labelled `n_samples + k`.
#[derive(Debug, Clone)]
pub struct LinkageRow<T> {
    /// Label of one merged cluster; the larger of the two labels.
    pub left: usize,
    /// Label of the other merged cluster; the smaller of the two labels.
    pub right: usize,
    /// Weight of the MST edge that triggered this merge.
    pub distance: T,
    /// Sum of the sizes recorded for the two merged clusters.
    pub size: usize,
}
/// Converts the edges of a minimum spanning tree into a single-linkage merge tree.
///
/// The edges in `mst` are sorted in place by ascending weight (stable sort, so
/// equal-weight edges keep their relative order) and then processed one by one.
/// Each edge merges the two clusters that currently contain its endpoints and
/// yields one [`LinkageRow`].
///
/// Labelling: samples are `0..n_samples`; the cluster created by the `k`-th
/// merge (0-based) is labelled `n_samples + k`. In every row `left` holds the
/// larger of the two merged labels and `right` the smaller.
///
/// An empty `mst` yields an empty result for any `n_samples`, including `0`.
///
/// Edge endpoints are expected to be sample indices in `0..n_samples` and the
/// edges are expected to be acyclic; neither is validated here.
///
/// # Panics
///
/// * if two edge weights cannot be ordered (e.g. a NaN weight),
/// * if `mst` is non-empty and holds `n_samples` or more edges,
/// * if an edge endpoint indexes past the internal node arrays.
pub fn mst_to_linkage_tree<T>(mst: &mut [MstEdge<T>], n_samples: usize) -> Vec<LinkageRow<T>>
where
    T: EvocFloat,
{
    mst.sort_by(|a, b| {
        a.weight
            .partial_cmp(&b.weight)
            .expect("MST edge weights must be comparable (NaN is not allowed)")
    });

    let n_merges = mst.len();
    if n_merges == 0 {
        // Nothing to merge. Returning here also keeps `2 * n_samples - 1`
        // below from underflowing when `n_samples` is zero.
        return Vec::new();
    }
    assert!(
        n_merges < n_samples,
        "a spanning forest over {} samples has at most n_samples - 1 edges, got {}",
        n_samples,
        n_merges
    );

    // Leaves occupy 0..n_samples; up to n_samples - 1 merged clusters follow.
    let total_nodes = 2 * n_samples - 1;
    // `usize::MAX` marks a node that has no parent yet, i.e. a current root.
    let mut parent = vec![usize::MAX; total_nodes];
    let mut size = vec![0usize; total_nodes];
    size[..n_samples].fill(1);

    let mut linkage = Vec::with_capacity(n_merges);
    for (merge_index, edge) in mst.iter().enumerate() {
        let next_label = n_samples + merge_index;

        let root_u = find_root(&mut parent, edge.u);
        let root_v = find_root(&mut parent, edge.v);
        let new_size = size[root_u] + size[root_v];

        // Report the larger label first.
        let (left, right) = if root_u < root_v {
            (root_v, root_u)
        } else {
            (root_u, root_v)
        };

        linkage.push(LinkageRow {
            left,
            right,
            distance: edge.weight,
            size: new_size,
        });

        parent[left] = next_label;
        parent[right] = next_label;
        size[next_label] = new_size;
    }
    linkage
}

/// Returns the root of the cluster containing `node`, compressing the walked path.
///
/// A root is a node whose parent entry is `usize::MAX`. After the lookup every
/// node on the walked path points directly at the root, so repeated lookups
/// through the same nodes stay short.
fn find_root(parent: &mut [usize], node: usize) -> usize {
    let mut root = node;
    while parent[root] != usize::MAX {
        root = parent[root];
    }

    let mut current = node;
    while current != root {
        let next = parent[current];
        parent[current] = root;
        current = next;
    }
    root
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clustering::mst::MstEdge;

    /// Shorthand for building a single MST edge.
    fn edge<T>(u: usize, v: usize, weight: T) -> MstEdge<T> {
        MstEdge { u, v, weight }
    }

    /// Three samples on a path: two merges with growing cluster sizes.
    #[test]
    fn test_linkage_basic() {
        let mut edges = vec![edge(0, 1, 1.0), edge(1, 2, 2.0)];
        let rows = mst_to_linkage_tree(&mut edges, 3);

        assert_eq!(rows.len(), 2);

        let first = &rows[0];
        assert_eq!(first.distance, 1.0);
        assert_eq!(first.size, 2);

        let second = &rows[1];
        assert_eq!(second.distance, 2.0);
        assert_eq!(second.size, 3);
    }

    /// Edges supplied out of order are processed by ascending weight.
    #[test]
    fn test_linkage_unsorted_input() {
        let mut edges = vec![edge(1, 2, 5.0), edge(0, 1, 1.0)];
        let rows = mst_to_linkage_tree(&mut edges, 3);

        assert_eq!(rows[0].distance, 1.0);
        assert_eq!(rows[1].distance, 5.0);
    }

    /// A four-sample path accumulates one extra sample per merge.
    #[test]
    fn test_linkage_sizes() {
        let mut edges = vec![edge(0, 1, 1.0), edge(1, 2, 2.0), edge(2, 3, 3.0)];
        let rows = mst_to_linkage_tree(&mut edges, 4);

        let sizes: Vec<usize> = rows.iter().map(|row| row.size).collect();
        assert_eq!(sizes, [2, 3, 4]);
    }

    /// Merge distances never decrease from one row to the next.
    #[test]
    fn test_linkage_monotonic_distances() {
        let mut edges = vec![edge(0, 1, 3.0), edge(2, 3, 1.0), edge(1, 2, 5.0)];
        let rows = mst_to_linkage_tree(&mut edges, 4);

        for pair in rows.windows(2) {
            assert!(pair[1].distance >= pair[0].distance);
        }
    }

    /// Two disjoint pairs merge first, then the pairs join into one cluster.
    #[test]
    fn test_linkage_two_simultaneous_merges() {
        let mut edges = vec![edge(0, 1, 1.0), edge(2, 3, 1.0), edge(0, 2, 10.0)];
        let rows = mst_to_linkage_tree(&mut edges, 4);

        let sizes: Vec<usize> = rows.iter().map(|row| row.size).collect();
        assert_eq!(sizes, [2, 2, 4]);
    }
}