1use std::collections::{HashMap, HashSet};
4use std::f32;
5
6use ndarray::{ArrayView2, ArrayViewMut2, Axis};
7use ordered_float::OrderedFloat;
8
9pub fn chu_liu_edmonds(scores: ArrayView2<f32>, root_vertex: usize) -> Vec<Option<usize>> {
41 assert_eq!(
42 scores.nrows(),
43 scores.ncols(),
44 "Score matrix must be a square matrix, has shape: ({}, {})",
45 scores.nrows(),
46 scores.ncols()
47 );
48
49 let mut active_vertices = vec![true; scores.nrows()];
52
53 chu_liu_edmonds_(
54 scores.to_owned().view_mut(),
55 root_vertex,
56 &mut active_vertices,
57 )
58}
59
60fn chu_liu_edmonds_(
61 mut scores: ArrayViewMut2<f32>,
62 root_vertex: usize,
63 active_vertices: &mut [bool],
64) -> Vec<Option<usize>> {
65 let max_parents = find_max_parents(scores.view(), root_vertex, active_vertices);
68
69 let cycle = match find_cycle(&max_parents) {
72 Some(cycle) => cycle,
73 None => return max_parents,
74 };
75
76 let (incoming_replacements, outgoing_replacements) =
79 contract_cycle(scores.view_mut(), &max_parents, active_vertices, &cycle);
80
81 let contracted_mst = chu_liu_edmonds_(scores, root_vertex, active_vertices);
84
85 expand_cycle(
87 max_parents,
88 contracted_mst,
89 cycle,
90 incoming_replacements,
91 outgoing_replacements,
92 )
93}
94
95#[allow(clippy::type_complexity)]
101fn contract_cycle(
102 mut scores: ArrayViewMut2<f32>,
103 max_parents: &[Option<usize>],
104 active_vertices: &mut [bool],
105 cycle: &[usize],
106) -> (
107 HashMap<(usize, usize), usize>,
108 HashMap<(usize, usize), usize>,
109) {
110 let first_in_cycle = cycle[0];
113
114 let cycle_sum = cycle
117 .iter()
118 .map(|&vertex| {
119 let parent = max_parents[vertex].unwrap();
120 scores[(parent, vertex)]
121 })
122 .sum::<f32>();
123
124 for &vertex in &cycle[1..] {
126 active_vertices[vertex] = false;
127 }
128
129 let cycle = cycle.iter().map(ToOwned::to_owned).collect::<HashSet<_>>();
134
135 let mut incoming_replacements = HashMap::new();
136 let mut outgoing_replacements = HashMap::new();
137 for vertex in 0..scores.nrows() {
138 if !active_vertices[vertex] || cycle.contains(&vertex) {
140 continue;
141 }
142
143 let mut best_incoming = -f32::INFINITY;
144 let mut best_outgoing = -f32::INFINITY;
145
146 let mut best_incoming_vertex = None;
147 let mut best_outgoing_vertex = None;
148
149 for &cycle_vertex in &cycle {
150 if scores[(cycle_vertex, vertex)] > best_outgoing {
152 best_outgoing = scores[(cycle_vertex, vertex)];
153 best_outgoing_vertex = Some(cycle_vertex);
154 }
155
156 let best_parent = max_parents[cycle_vertex].unwrap();
157 let best_weight = scores[(best_parent, cycle_vertex)];
158 let incoming_score = cycle_sum + scores[(vertex, cycle_vertex)] - best_weight;
159
160 if incoming_score > best_incoming {
162 best_incoming = incoming_score;
163 best_incoming_vertex = Some(cycle_vertex);
164 }
165 }
166
167 scores[(vertex, first_in_cycle)] = best_incoming;
170 scores[(first_in_cycle, vertex)] = best_outgoing;
171
172 incoming_replacements.insert(
173 (vertex, first_in_cycle),
174 best_incoming_vertex.expect("No edge improves over -INF"),
175 );
176 outgoing_replacements.insert(
177 (first_in_cycle, vertex),
178 best_outgoing_vertex.expect("No edge improves over -INF"),
179 );
180 }
181
182 (incoming_replacements, outgoing_replacements)
183}
184
/// Undoes one cycle contraction on a tree computed over the contracted graph.
///
/// `cycle[0]` is the vertex the cycle was contracted into. Its parent in `mst` is the
/// external vertex feeding the cycle. `incoming_replacements` names the cycle vertex that
/// edge really enters, and that vertex takes the external parent. All other cycle members get
/// their greedy parent from `max_parents`. Finally, any edge leaving the contracted vertex is
/// redirected to the cycle vertex recorded for it in `outgoing_replacements`.
fn expand_cycle(
    max_parents: Vec<Option<usize>>,
    mut mst: Vec<Option<usize>>,
    cycle: Vec<usize>,
    incoming_replacements: HashMap<(usize, usize), usize>,
    outgoing_replacements: HashMap<(usize, usize), usize>,
) -> Vec<Option<usize>> {
    let representative = cycle[0];
    let external_parent = mst[representative];

    let entry_vertex = incoming_replacements[&(external_parent.unwrap(), representative)];
    mst[entry_vertex] = external_parent;

    // Every other cycle member keeps the parent it had inside the cycle.
    for &member in cycle.iter().filter(|&&member| member != entry_vertex) {
        mst[member] = max_parents[member];
    }

    // Re-home children of the contracted vertex onto the cycle vertex that realised the edge.
    for ((source, target), original_source) in outgoing_replacements {
        if mst[target] == Some(source) {
            mst[target] = Some(original_source);
        }
    }

    mst
}
222
223fn find_max_parents(
226 scores: ArrayView2<f32>,
227 root_vertex: usize,
228 active_vertices: &[bool],
229) -> Vec<Option<usize>> {
230 let mut max_parents = vec![None; active_vertices.len()];
231
232 for child in 0..scores.ncols() {
233 if child == root_vertex {
235 continue;
236 }
237
238 if !active_vertices[child] {
240 continue;
241 }
242
243 let parent = scores
245 .index_axis(Axis(1), child)
246 .iter()
247 .enumerate()
248 .filter(|v| v.0 != child && active_vertices[v.0])
250 .max_by_key(|v| OrderedFloat(*v.1))
252 .map(|v| v.0);
254
255 max_parents[child] = parent;
256 }
257
258 max_parents
259}
260
261fn find_cycle(parents: &[Option<usize>]) -> Option<Vec<usize>> {
262 let mut visited = vec![false; parents.len()];
263 let mut on_stack = vec![false; parents.len()];
264 let mut edge_to = vec![0; parents.len()];
265
266 for start in 0..parents.len() {
267 if let cycle @ Some(_) =
268 find_cycle_(parents, &mut visited, &mut edge_to, &mut on_stack, start)
269 {
270 return cycle;
271 }
272 }
273
274 None
275}
276
/// Follows parent pointers from `vertex` and reports a cycle closed on the current walk.
///
/// This is one step of a DFS-style cycle search over a graph where every vertex has at most
/// one parent. `visited` persists across calls. `on_stack` marks the vertices of the walk
/// started at `vertex`. `edge_to[p] = c` records that `p` was reached from child `c`; this is
/// what allows the cycle to be read back.
///
/// Returns the cycle starting at the vertex that closed it and following `edge_to` back to
/// the repeated vertex. Returns `None` when the walk ends at a vertex without a parent or at
/// a vertex finished by an earlier walk. When a cycle is returned, `on_stack` is left set for
/// the walk (the search stops there). Otherwise, `on_stack` is cleared again for the walk.
///
/// The walk is iterative, so long parent chains cannot overflow the call stack (the previous
/// recursive version used one stack frame per vertex on the chain).
fn find_cycle_(
    parents: &[Option<usize>],
    visited: &mut [bool],
    edge_to: &mut [usize],
    on_stack: &mut [bool],
    vertex: usize,
) -> Option<Vec<usize>> {
    visited[vertex] = true;
    on_stack[vertex] = true;

    // Walk up the parent chain while it leads into unexplored vertices.
    let mut current = vertex;
    while let Some(parent) = parents[current] {
        if !visited[parent] {
            edge_to[parent] = current;
            visited[parent] = true;
            on_stack[parent] = true;
            current = parent;
        } else if on_stack[parent] {
            // `parent` lies on the current walk: read the cycle back through `edge_to`.
            let mut cycle = Vec::new();
            let mut cycle_vertex = current;

            while cycle_vertex != parent {
                cycle.push(cycle_vertex);
                cycle_vertex = edge_to[cycle_vertex];
            }
            cycle.push(parent);

            return Some(cycle);
        } else {
            // Reached a vertex finished by an earlier walk: no new cycle along this path.
            break;
        }
    }

    // No cycle: take the walk (from `vertex` up to `current`) off the stack again.
    let mut walk_vertex = vertex;
    loop {
        on_stack[walk_vertex] = false;
        if walk_vertex == current {
            break;
        }
        walk_vertex = parents[walk_vertex].expect("the walk only follows existing parent edges");
    }

    None
}
314
#[cfg(test)]
mod tests {
    use ndarray::{array, Array};
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    use super::{chu_liu_edmonds, find_cycle, find_max_parents};

    /// Asserts that `parents` encodes a tree rooted at `root`: only the root lacks a parent
    /// and the parent pointers contain no cycle.
    fn assert_tree(parents: &[Option<usize>], root: usize) {
        for (vertex, &parent) in parents.iter().enumerate() {
            if vertex == root {
                assert_eq!(
                    parent, None,
                    "Root vertex {} has a parent in graph {:?}",
                    root, parents
                );
            } else {
                assert!(
                    parent.is_some(),
                    "Non-root vertex {} does not have a parent in the graph {:?}",
                    vertex,
                    parents
                );
            }
        }

        // Search for a cycle once and report it directly, rather than searching twice.
        if let Some(cycle) = find_cycle(parents) {
            panic!("Graph {:?} contains a cycle: {:?}", parents, cycle);
        }
    }

    #[test]
    fn finds_max_parents() {
        let distances = Array::range(0f32, 25f32, 1f32).into_shape((5, 5)).unwrap();
        let max_parents = find_max_parents(distances.view(), 0, &[true; 5]);
        assert_eq!(max_parents, vec![None, Some(4), Some(4), Some(4), Some(3)]);
    }

    #[test]
    fn finds_max_parents_with_inactive_vertices() {
        let distances = Array::range(0f32, 25f32, 1f32).into_shape((5, 5)).unwrap();
        let max_parents = find_max_parents(distances.view(), 0, &[true, false, true, false, true]);
        assert_eq!(max_parents, vec![None, None, Some(4), None, Some(2)]);
    }

    #[test]
    fn finds_trees_in_random_graphs() {
        const NUM_TEST_ITERATIONS: usize = 1000;

        for _ in 0..NUM_TEST_ITERATIONS {
            let scores = Array::random((10, 10), Uniform::new(0., 1.));
            let mst = chu_liu_edmonds(scores.view(), 0);
            assert_tree(&mst, 0);
        }
    }

    #[test]
    fn finds_cycle() {
        assert_eq!(
            find_cycle(&[None, Some(0), Some(1), Some(2), Some(3)]),
            None,
        );

        assert_eq!(
            find_cycle(&[None, Some(0), Some(0), Some(0), Some(0)]),
            None,
        );

        assert_eq!(
            find_cycle(&[None, Some(4), Some(4), Some(4), Some(3)]),
            Some(vec![3, 4])
        );

        assert_eq!(
            find_cycle(&[None, Some(4), Some(1), Some(2), Some(3)]),
            Some(vec![2, 3, 4, 1])
        );

        assert_eq!(find_cycle(&[Some(0)]), Some(vec![0]));
    }

    #[test]
    fn correctly_decodes_toy_matrices() {
        let scores = Array::zeros((1, 1));
        let parents = chu_liu_edmonds(scores.view(), 0);
        assert_eq!(parents, vec![None]);

        let scores = Array::range(1f32, 10f32, 1f32).into_shape((3, 3)).unwrap();
        let parents = chu_liu_edmonds(scores.view(), 0);
        assert_eq!(parents, vec![None, Some(2), Some(0)]);

        let scores = Array::range(1f32, 17f32, 1f32).into_shape((4, 4)).unwrap();
        let parents = chu_liu_edmonds(scores.view(), 0);
        assert_eq!(parents, vec![None, Some(3), Some(3), Some(0)]);
    }

    #[test]
    fn correctly_decodes_random_large_matrices() {
        let check1 = array![
            [
                0.15154335, 0.21364425, 0.02926004, 0.24640401, 0.05929783, 0.98366485, 0.53015432,
                0.07778964, 0.00989446, 0.17998191
            ],
            [
                0.68921352, 0.33551225, 0.91974265, 0.08476561, 0.48800752, 0.87661821, 0.31723634,
                0.51386131, 0.97963044, 0.36960274
            ],
            [
                0.13969799, 0.46092784, 0.75821582, 0.78823102, 0.63945137, 0.42556879, 0.81997744,
                0.12978648, 0.40536874, 0.4744205
            ],
            [
                0.40688978, 0.25514681, 0.59851297, 0.82950985, 0.46627791, 0.05888491, 0.97450763,
                0.90287058, 0.35996474, 0.6448661
            ],
            [
                0.30530523, 0.76566773, 0.64714425, 0.1424588, 0.14283951, 0.00153444, 0.9688441,
                0.87582559, 0.63371798, 0.67004456
            ],
            [
                0.88822529, 0.26780501, 0.61901697, 0.35049028, 0.06430303, 0.44334551, 0.15308377,
                0.42145127, 0.87420229, 0.3309963
            ],
            [
                0.31808055, 0.35399265, 0.31438455, 0.63534316, 0.36917357, 0.7707749, 0.1686939,
                0.66622048, 0.67872444, 0.28663183
            ],
            [
                0.82167446, 0.15910145, 0.6654594, 0.54279563, 0.19068867, 0.17368633, 0.07199292,
                0.29239669, 0.60002772, 0.75121407
            ],
            [
                0.74016819, 0.28619099, 0.71608573, 0.64490596, 0.05975497, 0.8792097, 0.85888953,
                0.90590799, 0.62783992, 0.12660846
            ],
            [
                0.80810707, 0.10910174, 0.11777376, 0.36885688, 0.88732921, 0.82053854, 0.84096041,
                0.53546477, 0.49554398, 0.21705035
            ]
        ];

        assert_eq!(
            chu_liu_edmonds(check1.view(), 0),
            [
                None,
                Some(4),
                Some(1),
                Some(2),
                Some(9),
                Some(0),
                Some(3),
                Some(8),
                Some(5),
                Some(7)
            ]
        );

        let check2 = array![
            [
                0.63699522, 0.87615555, 0.45236657, 0.5188734, 0.13080447, 0.30954603, 0.70385654,
                0.00940039, 0.99012901, 0.91048303
            ],
            [
                0.6110081, 0.11629512, 0.91845679, 0.55938488, 0.45709085, 0.16727591, 0.3338458,
                0.87262039, 0.26543677, 0.78429413
            ],
            [
                0.06226577, 0.3509711, 0.8738929, 0.77723445, 0.83439156, 0.72800083, 0.70465176,
                0.9323746, 0.01803918, 0.50092784
            ],
            [
                0.30294811, 0.65599656, 0.23342294, 0.01840916, 0.78500845, 0.78103093, 0.82584077,
                0.72756822, 0.60326683, 0.44574654
            ],
            [
                0.75513096, 0.06980882, 0.72330091, 0.94334981, 0.262673, 0.84566782, 0.6318016,
                0.0442728, 0.2669838, 0.59781991
            ],
            [
                0.27443631, 0.33890352, 0.83353679, 0.88552379, 0.89789705, 0.00165288, 0.17836232,
                0.59181986, 0.426987, 0.91632828
            ],
            [
                0.55585136, 0.87230681, 0.10995064, 0.65543565, 0.96603594, 0.34425304, 0.07438735,
                0.21991817, 0.53278602, 0.46460502
            ],
            [
                0.78368679, 0.55949995, 0.42268737, 0.1681499, 0.62903574, 0.75765237, 0.07484798,
                0.37319298, 0.62900207, 0.26623339
            ],
            [
                0.66636035, 0.19227743, 0.48126272, 0.14611228, 0.6107612, 0.30056951, 0.77329224,
                0.93780084, 0.12710157, 0.96506847
            ],
            [
                0.76441608, 0.25583239, 0.14817458, 0.68389535, 0.85748418, 0.81745151, 0.71656758,
                0.11733889, 0.98476048, 0.26556185
            ]
        ];

        assert_eq!(
            chu_liu_edmonds(check2.view(), 0),
            [
                None,
                Some(0),
                Some(1),
                Some(4),
                Some(6),
                Some(4),
                Some(8),
                Some(8),
                Some(0),
                Some(8)
            ]
        );

        let check3 = array![
            [
                0.32226934, 0.03494655, 0.13943128, 0.77627796, 0.32289177, 0.20728151, 0.79354934,
                0.44277001, 0.70666543, 0.76361263
            ],
            [
                0.89787456, 0.19412729, 0.2769623, 0.42547065, 0.78306101, 0.99639906, 0.44910723,
                0.69166559, 0.5974235, 0.6019087
            ],
            [
                0.01936413, 0.77783413, 0.2635923, 0.24239049, 0.15320177, 0.58810727, 0.93770173,
                0.97238493, 0.40536974, 0.28189387
            ],
            [
                0.21176774, 0.90580752, 0.48167285, 0.17517493, 0.35126148, 0.09566258, 0.77651317,
                0.844114, 0.32902123, 0.93356815
            ],
            [
                0.68965019, 0.98577739, 0.06460552, 0.103729, 0.59807881, 0.82418659, 0.20288672,
                0.55119795, 0.01953631, 0.75208802
            ],
            [
                0.49706455, 0.52543525, 0.16288358, 0.72442708, 0.57151594, 0.68195141, 0.47521668,
                0.56127222, 0.6673682, 0.93037853
            ],
            [
                0.12841745, 0.89183647, 0.21585613, 0.73852511, 0.09812739, 0.06616884, 0.12730214,
                0.8322976, 0.93773286, 0.23950978
            ],
            [
                0.73496813, 0.52910843, 0.94925765, 0.77135859, 0.85716859, 0.47158383, 0.88753378,
                0.00141653, 0.47463287, 0.33777619
            ],
            [
                0.76116294, 0.77581507, 0.99508616, 0.24001213, 0.13688175, 0.57771731, 0.1435426,
                0.18420174, 0.07373099, 0.15492254
            ],
            [
                0.88146862, 0.27868822, 0.41427004, 0.989063, 0.08847578, 0.31721111, 0.13694788,
                0.99730908, 0.8523681, 0.81020978
            ]
        ];

        assert_eq!(
            chu_liu_edmonds(check3.view(), 0),
            [
                None,
                Some(4),
                Some(8),
                Some(9),
                Some(7),
                Some(1),
                Some(0),
                Some(2),
                Some(6),
                Some(5)
            ]
        );

        let check4 = array![
            [
                0.94146094, 0.08429249, 0.11658879, 0.7209569, 0.04588338, 0.41361274, 0.00335799,
                0.58725318, 0.37633847, 0.50978681
            ],
            [
                0.50163181, 0.96919669, 0.16614751, 0.15533209, 0.15054694, 0.08811524, 0.13978445,
                0.65591973, 0.95264964, 0.17669406
            ],
            [
                0.36864862, 0.95739286, 0.65356991, 0.71690581, 0.29263559, 0.98409776, 0.61308834,
                0.50921288, 0.49160935, 0.53610581
            ],
            [
                0.23275999, 0.60587704, 0.55893549, 0.69733286, 0.30008536, 0.13133368, 0.90196987,
                0.52283165, 0.96302483, 0.44467621
            ],
            [
                0.15057842, 0.58499236, 0.11330645, 0.57510935, 0.39645653, 0.53736407, 0.08391498,
                0.06004636, 0.88086527, 0.25429321
            ],
            [
                0.40042428, 0.08725659, 0.87216523, 0.18444633, 0.61547065, 0.8032823, 0.16163181,
                0.81884952, 0.51741822, 0.73005934
            ],
            [
                0.08460523, 0.01342742, 0.70127922, 0.45693109, 0.40153192, 0.07611445, 0.74831201,
                0.3385515, 0.24000027, 0.33290993
            ],
            [
                0.01990056, 0.28629396, 0.85476794, 0.68330081, 0.93204836, 0.14587584, 0.06681271,
                0.50342723, 0.30878763, 0.51632671
            ],
            [
                0.22297607, 0.99004514, 0.02590417, 0.61425698, 0.16932825, 0.06197453, 0.58227628,
                0.46317503, 0.21611736, 0.88426682
            ],
            [
                0.21695749, 0.52528143, 0.9569687, 0.70641648, 0.45516634, 0.59951297, 0.82591367,
                0.6038499, 0.14423517, 0.12984568
            ]
        ];

        assert_eq!(
            chu_liu_edmonds(check4.view(), 0),
            [
                None,
                Some(8),
                Some(9),
                Some(0),
                Some(7),
                Some(2),
                Some(3),
                Some(5),
                Some(3),
                Some(8)
            ]
        );

        let check5 = array![
            [
                0.19181828, 0.07215655, 0.49029481, 0.40338361, 0.77464947, 0.15287357, 0.33550702,
                0.9075557, 0.16816009, 0.12815985
            ],
            [
                0.39814249, 0.83951939, 0.6197687, 0.10285881, 0.35754604, 0.03372432, 0.26903616,
                0.39758852, 0.27831648, 0.8626124
            ],
            [
                0.32651809, 0.36621293, 0.55139869, 0.48841691, 0.86105511, 0.95220918, 0.99901665,
                0.43452191, 0.51957831, 0.12977951
            ],
            [
                0.24777433, 0.20835293, 0.35423981, 0.8647926, 0.54734269, 0.19705202, 0.20262791,
                0.29885766, 0.89558149, 0.48529723
            ],
            [
                0.99486246, 0.02998787, 0.94388915, 0.16682153, 0.04621821, 0.78283825, 0.32711021,
                0.11668783, 0.54230828, 0.01990573
            ],
            [
                0.81816179, 0.77223827, 0.3778254, 0.14590591, 0.53032985, 0.12751733, 0.80951733,
                0.94590486, 0.14917576, 0.0905699
            ],
            [
                0.56977204, 0.6759112, 0.86349563, 0.30270709, 0.03673155, 0.8814458, 0.52538187,
                0.97650872, 0.9278274, 0.73412665
            ],
            [
                0.96577082, 0.17352435, 0.71417166, 0.57713058, 0.99690502, 0.5856659, 0.87223811,
                0.8265802, 0.07539461, 0.28718492
            ],
            [
                0.64135636, 0.53712009, 0.98343642, 0.68861079, 0.33153221, 0.86677607, 0.65411023,
                0.97146557, 0.78007143, 0.24988737
            ],
            [
                0.52704545, 0.39384584, 0.99308, 0.03148114, 0.43305557, 0.11551732, 0.13331425,
                0.17881437, 0.05076005, 0.20889167
            ]
        ];

        assert_eq!(
            chu_liu_edmonds(check5.view(), 0),
            [
                None,
                Some(5),
                Some(4),
                Some(8),
                Some(7),
                Some(2),
                Some(2),
                Some(0),
                Some(6),
                Some(1)
            ]
        );
    }

    #[test]
    #[should_panic]
    fn panics_on_incorrect_shape_score_matrix() {
        let scores = Array::range(0f32, 16f32, 1f32).into_shape((2, 8)).unwrap();
        let _ = chu_liu_edmonds(scores.view(), 0);
    }
}