use std::collections::HashMap;
use crate::error::{GraphalgError, GraphalgResult};
use crate::repr::weighted_graph::WeightedGraph;
/// Computes a minimum-weight spanning arborescence of `g` rooted at `root`
/// using the Chu–Liu/Edmonds algorithm.
///
/// The result is a parent vector indexed by node: entry `v` holds the tail of
/// the edge chosen to enter `v`, and the root is listed as its own parent.
///
/// # Errors
///
/// Returns [`GraphalgError::SourceOutOfRange`] when `root` is not smaller than
/// `g.n`, and forwards any error produced by the internal solver (for example
/// when some node has no incoming edge).
pub fn chu_liu_edmonds(g: &WeightedGraph, root: usize) -> GraphalgResult<Vec<usize>> {
    let node_count = g.n;
    if root < node_count {
        // Flatten the graph into a plain edge list; the solver works on
        // `(from, to, weight)` triples only.
        let edge_list: Vec<(usize, usize, f64)> = g.all_edges();
        cle_internal(node_count, root, &edge_list)
    } else {
        Err(GraphalgError::SourceOutOfRange {
            node: root,
            n: node_count,
        })
    }
}
/// Recursive core of the Chu–Liu/Edmonds algorithm.
///
/// Works on a plain edge list of `(from, to, weight)` triples over the nodes
/// `0..n` and returns a parent vector: entry `v` is the tail of the edge
/// selected to enter `v`, and `parent[root] == root`.
///
/// Steps:
/// 1. Pick the cheapest incoming edge of every non-root node (self-loops are
///    ignored because they can never be part of an arborescence).
/// 2. If those edges contain no cycle they already form the answer.
/// 3. Otherwise contract every cycle into a super-node, reduce the weight of
///    each edge entering a cycle by the weight of the cycle edge it would
///    replace, and solve the contracted problem recursively.
/// 4. Expand: every edge selected in the contracted solution is mapped back
///    to the original edge it stands for and overrides the step-1 choice for
///    that edge's head. This is done for *all* super-nodes, not only for
///    cycles, because the recursive call may have re-routed a plain node that
///    sat on a cycle of the contracted graph.
///
/// Among edges of equal weight the one appearing first in `edges` wins, so
/// the result is deterministic for a given edge order.
///
/// # Errors
///
/// Returns [`GraphalgError::NoSolution`] when some non-root node (possibly a
/// contracted one) has no usable incoming edge.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if `root` or an edge endpoint is not
/// smaller than `n`.
fn cle_internal(
    n: usize,
    root: usize,
    edges: &[(usize, usize, f64)],
) -> GraphalgResult<Vec<usize>> {
    const UNSET: usize = usize::MAX;

    // Step 1: cheapest incoming edge `(tail, weight)` per node.
    let mut best_in: Vec<Option<(usize, f64)>> = vec![None; n];
    for &(u, v, w) in edges {
        if v == root || u == v {
            continue;
        }
        // Strict `<` keeps the first edge on ties (and ignores NaN weights
        // once a candidate exists).
        let improves = best_in[v].map_or(true, |(_, best)| w < best);
        if improves {
            best_in[v] = Some((u, w));
        }
    }

    // `pred[v]` / `pred_w[v]`: tail and weight of the chosen edge into `v`.
    // The root keeps itself as predecessor so `pred` doubles as the result.
    let mut pred = vec![root; n];
    let mut pred_w = vec![0.0_f64; n];
    for (v, entry) in best_in.iter().enumerate() {
        if v == root {
            continue;
        }
        match *entry {
            Some((u, w)) => {
                pred[v] = u;
                pred_w[v] = w;
            }
            None => {
                return Err(GraphalgError::NoSolution(format!(
                    "node {v} has no incoming edge"
                )));
            }
        }
    }

    // Step 2: find cycles by walking predecessor links. `walk_mark[v]` stores
    // the start node of the walk that first reached `v`; running into a node
    // marked by the *current* walk means that node lies on a new cycle.
    let mut cycle_of = vec![UNSET; n];
    let mut walk_mark = vec![UNSET; n];
    let mut cycle_count = 0usize;
    for start in 0..n {
        if walk_mark[start] != UNSET {
            continue;
        }
        let mut node = start;
        while node != root && walk_mark[node] == UNSET {
            walk_mark[node] = start;
            node = pred[node];
        }
        if node != root && walk_mark[node] == start {
            let mut cur = node;
            loop {
                cycle_of[cur] = cycle_count;
                cur = pred[cur];
                if cur == node {
                    break;
                }
            }
            cycle_count += 1;
        }
    }

    if cycle_count == 0 {
        return Ok(pred);
    }

    // Step 3: contract. Cycles take ids `0..cycle_count`; every other node
    // (the root included) gets a fresh id after those.
    let mut super_count = cycle_count;
    let super_of: Vec<usize> = cycle_of
        .iter()
        .map(|&c| {
            if c != UNSET {
                c
            } else {
                let id = super_count;
                super_count += 1;
                id
            }
        })
        .collect();
    let super_root = super_of[root];

    // One contracted edge per ordered super-node pair (the cheapest one).
    // `origin[i]` remembers which original edge `contracted[i]` stands for;
    // `slot_of` finds the slot of a pair. Keeping the edges in a `Vec` makes
    // their order independent of hash-map iteration order.
    let mut contracted: Vec<(usize, usize, f64)> = Vec::new();
    let mut origin: Vec<(usize, usize)> = Vec::new();
    let mut slot_of: HashMap<(usize, usize), usize> = HashMap::new();
    for &(u, v, w) in edges {
        if v == root {
            continue;
        }
        let su = super_of[u];
        let sv = super_of[v];
        if su == sv {
            continue;
        }
        // Entering a cycle at `v` drops the cycle edge into `v`, so only the
        // difference counts.
        let adj_w = if cycle_of[v] != UNSET {
            w - pred_w[v]
        } else {
            w
        };
        match slot_of.get(&(su, sv)).copied() {
            Some(i) => {
                if adj_w < contracted[i].2 {
                    contracted[i].2 = adj_w;
                    origin[i] = (u, v);
                }
            }
            None => {
                slot_of.insert((su, sv), contracted.len());
                contracted.push((su, sv, adj_w));
                origin.push((u, v));
            }
        }
    }

    let super_parent = cle_internal(super_count, super_root, &contracted)?;

    // Step 4: expand. Start from the step-1 choices (which keep all cycle
    // edges) and let every edge of the contracted solution override the
    // entry of the original node it enters. For a cycle this breaks the
    // cycle at exactly one node; for a plain node it installs whatever edge
    // the recursive call settled on.
    let mut parent = pred;
    for (s, &p) in super_parent.iter().enumerate() {
        if s == super_root {
            continue;
        }
        match slot_of.get(&(p, s)) {
            Some(&i) => {
                let (orig_u, orig_v) = origin[i];
                parent[orig_v] = orig_u;
            }
            None => {
                // The recursive call only selects edges from `contracted`,
                // so this indicates a broken invariant rather than bad input.
                return Err(GraphalgError::NoSolution(format!(
                    "contracted edge {p} -> {s} has no original counterpart"
                )));
            }
        }
    }
    Ok(parent)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// On a graph whose cheapest in-edges already form a tree, the result is
    /// exactly those edges.
    #[test]
    fn cle_no_cycle() {
        let mut graph = WeightedGraph::new(4);
        // Tree: 0 -> 1, 0 -> 2, 1 -> 3, all with unit weight.
        for (from, to) in [(0, 1), (0, 2), (1, 3)] {
            g_add(&mut graph, from, to);
        }
        let parents = chu_liu_edmonds(&graph, 0).expect("ok");
        // (child, expected parent) pairs.
        for (child, expected) in [(1usize, 0usize), (2, 0), (3, 1)] {
            assert_eq!(parents[child], expected);
        }
    }

    /// Inserts a unit-weight edge and fails the test if insertion is rejected.
    fn g_add(graph: &mut WeightedGraph, from: usize, to: usize) {
        graph.add_edge(from, to, 1.0).expect("ok");
    }
}