chu_liu_edmonds/
lib.rs

//! Find the maximum spanning tree of a dense directed graph using the
//! Chu-Liu-Edmonds algorithm.
2
3use std::collections::{HashMap, HashSet};
4use std::f32;
5
6use ndarray::{ArrayView2, ArrayViewMut2, Axis};
7use ordered_float::OrderedFloat;
8
// The implementation follows the description of:
//
// https://en.wikipedia.org/wiki/Edmonds%27_algorithm
//
// There are several differences compared to the above
// description:
//
// - We want to compute the maximum spanning tree. So, we find
//   incoming edges with maximum scores. This means that we have to
//   change how the scores of incoming edges of contracted cycles are
//   calculated. Here we follow Kübler et al., 2009, pp. 47.
// - Since the input is a score matrix, there are no parallel edges
//   in the input graph.
// - Since we use (a copy of) the score matrix to store weights of
//   incoming/outgoing contraction edges, we cannot store parallel
//   edges. So, we only store the highest-scoring parallel edge
//   when computing the edges of the contraction. This does not change
//   the main algorithm, since the next recursion of Chu-Liu-Edmonds
//   would discard the lower-scoring edges anyway.
28
29/// Chu-Liu-Edmonds maximum spanning tree for dense graphs
30///
31/// This function returns the parent of each vertex in the maximum
32/// spanning tree of the `scores` square matrix, rooted at
33/// `root_vertex`.  Each row in the matrix represents outgoing edge
34/// scores of the corresponding vertex, each column incoming edge
35/// scores. Thus, `scores[(parent, child)]` should give the weight of
36/// the edge from `parent` to child.
37///
38/// Returns vertex parents. The length of the returned `Vec` equals
39/// the number of rows/columns of the scores matrix.
40pub fn chu_liu_edmonds(scores: ArrayView2<f32>, root_vertex: usize) -> Vec<Option<usize>> {
41    assert_eq!(
42        scores.nrows(),
43        scores.ncols(),
44        "Score matrix must be a square matrix, has shape: ({}, {})",
45        scores.nrows(),
46        scores.ncols()
47    );
48
49    // We use this `Vec` to keep track of which vertices are 'active'.
50    // Vertices that are part of a contracted cycle become inactive.
51    let mut active_vertices = vec![true; scores.nrows()];
52
53    chu_liu_edmonds_(
54        scores.to_owned().view_mut(),
55        root_vertex,
56        &mut active_vertices,
57    )
58}
59
60fn chu_liu_edmonds_(
61    mut scores: ArrayViewMut2<f32>,
62    root_vertex: usize,
63    active_vertices: &mut [bool],
64) -> Vec<Option<usize>> {
65    // For each vertex, find the parent with the highest incoming edge
66    // score.
67    let max_parents = find_max_parents(scores.view(), root_vertex, active_vertices);
68
69    // Base case: if the resulting graph does not contain a cycle, we
70    // have found the MST of the (possibly contracted) graph.
71    let cycle = match find_cycle(&max_parents) {
72        Some(cycle) => cycle,
73        None => return max_parents,
74    };
75
76    // Contract the cycle into a single vertex. We use the first
77    // vertex of the cycle to represent the cycle.
78    let (incoming_replacements, outgoing_replacements) =
79        contract_cycle(scores.view_mut(), &max_parents, active_vertices, &cycle);
80
81    // Recursively apply Chu-Liu-Edmonds to the graph with the
82    // contracted cycle, until we hit the base case.
83    let contracted_mst = chu_liu_edmonds_(scores, root_vertex, active_vertices);
84
85    // Expand the contracted cycle in the MST.
86    expand_cycle(
87        max_parents,
88        contracted_mst,
89        cycle,
90        incoming_replacements,
91        outgoing_replacements,
92    )
93}
94
95/// Contract the given cycle.
96///
97/// This updates the score matrix and active vertices.
98///
99/// Returns a mapping of replaced edges.
100#[allow(clippy::type_complexity)]
101fn contract_cycle(
102    mut scores: ArrayViewMut2<f32>,
103    max_parents: &[Option<usize>],
104    active_vertices: &mut [bool],
105    cycle: &[usize],
106) -> (
107    HashMap<(usize, usize), usize>,
108    HashMap<(usize, usize), usize>,
109) {
110    // We will use the first vertex of the cycle to represent the
111    // contraction.
112    let first_in_cycle = cycle[0];
113
114    // Get the sum of edge scores in the cycle. See Kübler et al.,
115    // 2009, pp. 47.
116    let cycle_sum = cycle
117        .iter()
118        .map(|&vertex| {
119            let parent = max_parents[vertex].unwrap();
120            scores[(parent, vertex)]
121        })
122        .sum::<f32>();
123
124    // Mark cycle vertices as inactive.
125    for &vertex in &cycle[1..] {
126        active_vertices[vertex] = false;
127    }
128
129    // Convert the cycle to a set for constant-time
130    // lookups. Constructing and using a set has a negative
131    // performance impact on small graphs, but we are willing to trade
132    // off a small loss for better runtime properties on large graphs.
133    let cycle = cycle.iter().map(ToOwned::to_owned).collect::<HashSet<_>>();
134
135    let mut incoming_replacements = HashMap::new();
136    let mut outgoing_replacements = HashMap::new();
137    for vertex in 0..scores.nrows() {
138        // Skip inactive vertices and vertices that are in the cycle.
139        if !active_vertices[vertex] || cycle.contains(&vertex) {
140            continue;
141        }
142
143        let mut best_incoming = -f32::INFINITY;
144        let mut best_outgoing = -f32::INFINITY;
145
146        let mut best_incoming_vertex = None;
147        let mut best_outgoing_vertex = None;
148
149        for &cycle_vertex in &cycle {
150            // Replace (v, w) by (v_cycle, w)
151            if scores[(cycle_vertex, vertex)] > best_outgoing {
152                best_outgoing = scores[(cycle_vertex, vertex)];
153                best_outgoing_vertex = Some(cycle_vertex);
154            }
155
156            let best_parent = max_parents[cycle_vertex].unwrap();
157            let best_weight = scores[(best_parent, cycle_vertex)];
158            let incoming_score = cycle_sum + scores[(vertex, cycle_vertex)] - best_weight;
159
160            // Replace (u, v) by (u, v_cycle)
161            if incoming_score > best_incoming {
162                best_incoming = incoming_score;
163                best_incoming_vertex = Some(cycle_vertex);
164            }
165        }
166
167        // Save max incoming edge (u, v_cyle) and max outgoing edge
168        // (v_cycle, w).
169        scores[(vertex, first_in_cycle)] = best_incoming;
170        scores[(first_in_cycle, vertex)] = best_outgoing;
171
172        incoming_replacements.insert(
173            (vertex, first_in_cycle),
174            best_incoming_vertex.expect("No edge improves over -INF"),
175        );
176        outgoing_replacements.insert(
177            (first_in_cycle, vertex),
178            best_outgoing_vertex.expect("No edge improves over -INF"),
179        );
180    }
181
182    (incoming_replacements, outgoing_replacements)
183}
184
/// Expand a contracted cycle in the spanning tree of the contracted graph.
///
/// In `mst`, the vertex `cycle[0]` stands in for the whole cycle. The
/// edge entering it is traced back (via `incoming_replacements`) to the
/// cycle vertex it originally entered. That vertex takes over the parent
/// of the contracted vertex, which breaks the cycle there. Every other
/// cycle vertex gets its parent from `max_parents`. Finally, edges that
/// leave the contracted vertex are redirected to their original source
/// vertex using `outgoing_replacements`.
///
/// # Panics
///
/// Panics if `cycle[0]` has no parent in `mst`, or if its incoming edge
/// has no entry in `incoming_replacements`.
fn expand_cycle(
    max_parents: Vec<Option<usize>>,
    mut mst: Vec<Option<usize>>,
    cycle: Vec<usize>,
    incoming_replacements: HashMap<(usize, usize), usize>,
    outgoing_replacements: HashMap<(usize, usize), usize>,
) -> Vec<Option<usize>> {
    let representative = cycle[0];
    let parent_of_cycle = mst[representative];

    // The cycle vertex that the incoming edge originally pointed to.
    let entry_vertex = incoming_replacements[&(parent_of_cycle.unwrap(), representative)];

    // The entry vertex hangs off the cycle's parent; all other cycle
    // vertices keep the cycle edges they were greedily assigned.
    mst[entry_vertex] = parent_of_cycle;
    for &vertex in cycle.iter().filter(|&&v| v != entry_vertex) {
        mst[vertex] = max_parents[vertex];
    }

    // An edge (v_cycle, w) in the tree really starts at the cycle vertex
    // recorded for it; rewrite it to (v, w).
    for ((source, target), original_source) in outgoing_replacements {
        if mst[target] == Some(source) {
            mst[target] = Some(original_source);
        }
    }

    mst
}
222
223/// Find the parent vertex with the highest edge score for every
224/// active vertex.
225fn find_max_parents(
226    scores: ArrayView2<f32>,
227    root_vertex: usize,
228    active_vertices: &[bool],
229) -> Vec<Option<usize>> {
230    let mut max_parents = vec![None; active_vertices.len()];
231
232    for child in 0..scores.ncols() {
233        // Do not search for parents of root.
234        if child == root_vertex {
235            continue;
236        }
237
238        // Skip inactive vertices.
239        if !active_vertices[child] {
240            continue;
241        }
242
243        // Edge scores are indexed as (parent, child).
244        let parent = scores
245            .index_axis(Axis(1), child)
246            .iter()
247            .enumerate()
248            // Ignore self-loops and inactive vertices.
249            .filter(|v| v.0 != child && active_vertices[v.0])
250            // Find the source (parent) with the largest score.
251            .max_by_key(|v| OrderedFloat(*v.1))
252            // Return the index of the largest parent.
253            .map(|v| v.0);
254
255        max_parents[child] = parent;
256    }
257
258    max_parents
259}
260
/// Find a cycle in the graph described by `parents`.
///
/// `parents[v]` is the parent of vertex `v`, or `None` if `v` has no
/// parent. Because every vertex has at most one parent, cycles are found
/// by following parent links from each unexplored vertex. The walk stops
/// at a vertex without a parent or at a vertex explored from an earlier
/// start. If it reaches a vertex already on the current path, a cycle has
/// been found.
///
/// Returns the first cycle found, or `None` if the graph is acyclic. In
/// the returned cycle, each vertex is the parent of the vertex that
/// follows it, and the last vertex is the parent of the first.
///
/// The walk is iterative, so long parent chains cannot overflow the call
/// stack. Runs in O(n) time for n vertices.
///
/// # Panics
///
/// Panics if a parent index is out of bounds for `parents`.
fn find_cycle(parents: &[Option<usize>]) -> Option<Vec<usize>> {
    // Vertices reached from any starting point so far.
    let mut visited = vec![false; parents.len()];
    // Vertices on the parent chain that is currently being followed.
    let mut on_path = vec![false; parents.len()];
    // The current parent chain: each element is the parent of the
    // element before it.
    let mut path = Vec::new();

    for start in 0..parents.len() {
        // Chains through an already-visited vertex were checked earlier.
        if visited[start] {
            continue;
        }

        let mut vertex = start;
        loop {
            visited[vertex] = true;
            on_path[vertex] = true;
            path.push(vertex);

            match parents[vertex] {
                Some(parent) if !visited[parent] => vertex = parent,
                Some(parent) if on_path[parent] => {
                    // The path ends with `parent, ..., vertex`, and
                    // `vertex` points back to `parent`. Reverse that
                    // segment so each vertex precedes its child.
                    let cycle_start = path
                        .iter()
                        .position(|&v| v == parent)
                        .expect("Vertices marked as on the path must be in the path");
                    return Some(path[cycle_start..].iter().rev().copied().collect());
                }
                // No parent, or the parent was fully explored from an
                // earlier start without finding a cycle.
                _ => break,
            }
        }

        for vertex in path.drain(..) {
            on_path[vertex] = false;
        }
    }

    None
}
314
#[cfg(test)]
mod tests {
    use ndarray::{array, Array};
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    use super::{chu_liu_edmonds, find_cycle, find_max_parents};

    /// Assert that `parents` encodes a tree rooted at `root`: only the
    /// root lacks a parent, and following parents never loops.
    fn assert_tree(parents: &[Option<usize>], root: usize) {
        for (vertex, &parent) in parents.iter().enumerate() {
            if vertex == root {
                assert_eq!(
                    parent, None,
                    "Root vertex {} has a parent in graph {:?}",
                    root, parents
                )
            } else {
                assert!(
                    parent.is_some(),
                    "Non-root vertex {} does not have a parent in the graph {:?}",
                    vertex,
                    parents
                )
            }
        }

        // Run cycle detection once and report the offending cycle.
        if let Some(cycle) = find_cycle(parents) {
            panic!("Graph {:?} contains a cycle: {:?}", parents, cycle);
        }
    }

    #[test]
    fn finds_max_parents() {
        let distances = Array::range(0f32, 25f32, 1f32).into_shape((5, 5)).unwrap();
        let max_parents = find_max_parents(distances.view(), 0, &[true; 5]);
        assert_eq!(max_parents, vec![None, Some(4), Some(4), Some(4), Some(3)]);
    }

    #[test]
    fn finds_max_parents_with_inactive_vertices() {
        let distances = Array::range(0f32, 25f32, 1f32).into_shape((5, 5)).unwrap();
        let max_parents = find_max_parents(distances.view(), 0, &[true, false, true, false, true]);
        assert_eq!(max_parents, vec![None, None, Some(4), None, Some(2)]);
    }

    #[test]
    fn finds_trees_in_random_graphs() {
        // We should probably use quickcheck or proptest for this, but
        // then I have to figure out how to do proper shrinkage, since
        // we require square matrices. For now, we are just happy to know
        // if we produce proper trees.

        const NUM_TEST_ITERATIONS: usize = 1000;

        for _ in 0..NUM_TEST_ITERATIONS {
            let scores = Array::random((10, 10), Uniform::new(0., 1.));
            let mst = chu_liu_edmonds(scores.view(), 0);
            assert_tree(&mst, 0);
        }
    }

    #[test]
    fn finds_cycle() {
        // No cycle.
        assert_eq!(
            find_cycle(&[None, Some(0), Some(1), Some(2), Some(3)]),
            None,
        );

        // No cycle.
        assert_eq!(
            find_cycle(&[None, Some(0), Some(0), Some(0), Some(0)]),
            None,
        );

        // Short cycle: 3 -> 4 -> 3
        assert_eq!(
            find_cycle(&[None, Some(4), Some(4), Some(4), Some(3)]),
            Some(vec![3, 4])
        );

        // Long cycle: 1 -> 2 -> 3 -> 4 -> 1
        assert_eq!(
            find_cycle(&[None, Some(4), Some(1), Some(2), Some(3)]),
            Some(vec![2, 3, 4, 1])
        );

        // Self-cycle
        assert_eq!(find_cycle(&[Some(0)]), Some(vec![0]));
    }

    #[test]
    fn correctly_decodes_toy_matrices() {
        let scores = Array::zeros((1, 1));
        let parents = chu_liu_edmonds(scores.view(), 0);
        assert_eq!(parents, vec![None]);

        let scores = Array::range(1f32, 10f32, 1f32).into_shape((3, 3)).unwrap();
        let parents = chu_liu_edmonds(scores.view(), 0);
        assert_eq!(parents, vec![None, Some(2), Some(0)]);

        let scores = Array::range(1f32, 17f32, 1f32).into_shape((4, 4)).unwrap();
        let parents = chu_liu_edmonds(scores.view(), 0);
        assert_eq!(parents, vec![None, Some(3), Some(3), Some(0)]);
    }

    #[test]
    fn correctly_decodes_random_large_matrices() {
        // This unit test checks the output for five random matrices
        // against the output of the AllenNLP implementation of
        // Chu-Liu-Edmonds.

        let check1 = array![
            [
                0.15154335, 0.21364425, 0.02926004, 0.24640401, 0.05929783, 0.98366485, 0.53015432,
                0.07778964, 0.00989446, 0.17998191
            ],
            [
                0.68921352, 0.33551225, 0.91974265, 0.08476561, 0.48800752, 0.87661821, 0.31723634,
                0.51386131, 0.97963044, 0.36960274
            ],
            [
                0.13969799, 0.46092784, 0.75821582, 0.78823102, 0.63945137, 0.42556879, 0.81997744,
                0.12978648, 0.40536874, 0.4744205
            ],
            [
                0.40688978, 0.25514681, 0.59851297, 0.82950985, 0.46627791, 0.05888491, 0.97450763,
                0.90287058, 0.35996474, 0.6448661
            ],
            [
                0.30530523, 0.76566773, 0.64714425, 0.1424588, 0.14283951, 0.00153444, 0.9688441,
                0.87582559, 0.63371798, 0.67004456
            ],
            [
                0.88822529, 0.26780501, 0.61901697, 0.35049028, 0.06430303, 0.44334551, 0.15308377,
                0.42145127, 0.87420229, 0.3309963
            ],
            [
                0.31808055, 0.35399265, 0.31438455, 0.63534316, 0.36917357, 0.7707749, 0.1686939,
                0.66622048, 0.67872444, 0.28663183
            ],
            [
                0.82167446, 0.15910145, 0.6654594, 0.54279563, 0.19068867, 0.17368633, 0.07199292,
                0.29239669, 0.60002772, 0.75121407
            ],
            [
                0.74016819, 0.28619099, 0.71608573, 0.64490596, 0.05975497, 0.8792097, 0.85888953,
                0.90590799, 0.62783992, 0.12660846
            ],
            [
                0.80810707, 0.10910174, 0.11777376, 0.36885688, 0.88732921, 0.82053854, 0.84096041,
                0.53546477, 0.49554398, 0.21705035
            ]
        ];

        assert_eq!(
            chu_liu_edmonds(check1.view(), 0),
            [
                None,
                Some(4),
                Some(1),
                Some(2),
                Some(9),
                Some(0),
                Some(3),
                Some(8),
                Some(5),
                Some(7)
            ]
        );

        let check2 = array![
            [
                0.63699522, 0.87615555, 0.45236657, 0.5188734, 0.13080447, 0.30954603, 0.70385654,
                0.00940039, 0.99012901, 0.91048303
            ],
            [
                0.6110081, 0.11629512, 0.91845679, 0.55938488, 0.45709085, 0.16727591, 0.3338458,
                0.87262039, 0.26543677, 0.78429413
            ],
            [
                0.06226577, 0.3509711, 0.8738929, 0.77723445, 0.83439156, 0.72800083, 0.70465176,
                0.9323746, 0.01803918, 0.50092784
            ],
            [
                0.30294811, 0.65599656, 0.23342294, 0.01840916, 0.78500845, 0.78103093, 0.82584077,
                0.72756822, 0.60326683, 0.44574654
            ],
            [
                0.75513096, 0.06980882, 0.72330091, 0.94334981, 0.262673, 0.84566782, 0.6318016,
                0.0442728, 0.2669838, 0.59781991
            ],
            [
                0.27443631, 0.33890352, 0.83353679, 0.88552379, 0.89789705, 0.00165288, 0.17836232,
                0.59181986, 0.426987, 0.91632828
            ],
            [
                0.55585136, 0.87230681, 0.10995064, 0.65543565, 0.96603594, 0.34425304, 0.07438735,
                0.21991817, 0.53278602, 0.46460502
            ],
            [
                0.78368679, 0.55949995, 0.42268737, 0.1681499, 0.62903574, 0.75765237, 0.07484798,
                0.37319298, 0.62900207, 0.26623339
            ],
            [
                0.66636035, 0.19227743, 0.48126272, 0.14611228, 0.6107612, 0.30056951, 0.77329224,
                0.93780084, 0.12710157, 0.96506847
            ],
            [
                0.76441608, 0.25583239, 0.14817458, 0.68389535, 0.85748418, 0.81745151, 0.71656758,
                0.11733889, 0.98476048, 0.26556185
            ]
        ];

        assert_eq!(
            chu_liu_edmonds(check2.view(), 0),
            [
                None,
                Some(0),
                Some(1),
                Some(4),
                Some(6),
                Some(4),
                Some(8),
                Some(8),
                Some(0),
                Some(8)
            ]
        );

        let check3 = array![
            [
                0.32226934, 0.03494655, 0.13943128, 0.77627796, 0.32289177, 0.20728151, 0.79354934,
                0.44277001, 0.70666543, 0.76361263
            ],
            [
                0.89787456, 0.19412729, 0.2769623, 0.42547065, 0.78306101, 0.99639906, 0.44910723,
                0.69166559, 0.5974235, 0.6019087
            ],
            [
                0.01936413, 0.77783413, 0.2635923, 0.24239049, 0.15320177, 0.58810727, 0.93770173,
                0.97238493, 0.40536974, 0.28189387
            ],
            [
                0.21176774, 0.90580752, 0.48167285, 0.17517493, 0.35126148, 0.09566258, 0.77651317,
                0.844114, 0.32902123, 0.93356815
            ],
            [
                0.68965019, 0.98577739, 0.06460552, 0.103729, 0.59807881, 0.82418659, 0.20288672,
                0.55119795, 0.01953631, 0.75208802
            ],
            [
                0.49706455, 0.52543525, 0.16288358, 0.72442708, 0.57151594, 0.68195141, 0.47521668,
                0.56127222, 0.6673682, 0.93037853
            ],
            [
                0.12841745, 0.89183647, 0.21585613, 0.73852511, 0.09812739, 0.06616884, 0.12730214,
                0.8322976, 0.93773286, 0.23950978
            ],
            [
                0.73496813, 0.52910843, 0.94925765, 0.77135859, 0.85716859, 0.47158383, 0.88753378,
                0.00141653, 0.47463287, 0.33777619
            ],
            [
                0.76116294, 0.77581507, 0.99508616, 0.24001213, 0.13688175, 0.57771731, 0.1435426,
                0.18420174, 0.07373099, 0.15492254
            ],
            [
                0.88146862, 0.27868822, 0.41427004, 0.989063, 0.08847578, 0.31721111, 0.13694788,
                0.99730908, 0.8523681, 0.81020978
            ]
        ];

        assert_eq!(
            chu_liu_edmonds(check3.view(), 0),
            [
                None,
                Some(4),
                Some(8),
                Some(9),
                Some(7),
                Some(1),
                Some(0),
                Some(2),
                Some(6),
                Some(5)
            ]
        );

        let check4 = array![
            [
                0.94146094, 0.08429249, 0.11658879, 0.7209569, 0.04588338, 0.41361274, 0.00335799,
                0.58725318, 0.37633847, 0.50978681
            ],
            [
                0.50163181, 0.96919669, 0.16614751, 0.15533209, 0.15054694, 0.08811524, 0.13978445,
                0.65591973, 0.95264964, 0.17669406
            ],
            [
                0.36864862, 0.95739286, 0.65356991, 0.71690581, 0.29263559, 0.98409776, 0.61308834,
                0.50921288, 0.49160935, 0.53610581
            ],
            [
                0.23275999, 0.60587704, 0.55893549, 0.69733286, 0.30008536, 0.13133368, 0.90196987,
                0.52283165, 0.96302483, 0.44467621
            ],
            [
                0.15057842, 0.58499236, 0.11330645, 0.57510935, 0.39645653, 0.53736407, 0.08391498,
                0.06004636, 0.88086527, 0.25429321
            ],
            [
                0.40042428, 0.08725659, 0.87216523, 0.18444633, 0.61547065, 0.8032823, 0.16163181,
                0.81884952, 0.51741822, 0.73005934
            ],
            [
                0.08460523, 0.01342742, 0.70127922, 0.45693109, 0.40153192, 0.07611445, 0.74831201,
                0.3385515, 0.24000027, 0.33290993
            ],
            [
                0.01990056, 0.28629396, 0.85476794, 0.68330081, 0.93204836, 0.14587584, 0.06681271,
                0.50342723, 0.30878763, 0.51632671
            ],
            [
                0.22297607, 0.99004514, 0.02590417, 0.61425698, 0.16932825, 0.06197453, 0.58227628,
                0.46317503, 0.21611736, 0.88426682
            ],
            [
                0.21695749, 0.52528143, 0.9569687, 0.70641648, 0.45516634, 0.59951297, 0.82591367,
                0.6038499, 0.14423517, 0.12984568
            ]
        ];

        assert_eq!(
            chu_liu_edmonds(check4.view(), 0),
            [
                None,
                Some(8),
                Some(9),
                Some(0),
                Some(7),
                Some(2),
                Some(3),
                Some(5),
                Some(3),
                Some(8)
            ]
        );

        let check5 = array![
            [
                0.19181828, 0.07215655, 0.49029481, 0.40338361, 0.77464947, 0.15287357, 0.33550702,
                0.9075557, 0.16816009, 0.12815985
            ],
            [
                0.39814249, 0.83951939, 0.6197687, 0.10285881, 0.35754604, 0.03372432, 0.26903616,
                0.39758852, 0.27831648, 0.8626124
            ],
            [
                0.32651809, 0.36621293, 0.55139869, 0.48841691, 0.86105511, 0.95220918, 0.99901665,
                0.43452191, 0.51957831, 0.12977951
            ],
            [
                0.24777433, 0.20835293, 0.35423981, 0.8647926, 0.54734269, 0.19705202, 0.20262791,
                0.29885766, 0.89558149, 0.48529723
            ],
            [
                0.99486246, 0.02998787, 0.94388915, 0.16682153, 0.04621821, 0.78283825, 0.32711021,
                0.11668783, 0.54230828, 0.01990573
            ],
            [
                0.81816179, 0.77223827, 0.3778254, 0.14590591, 0.53032985, 0.12751733, 0.80951733,
                0.94590486, 0.14917576, 0.0905699
            ],
            [
                0.56977204, 0.6759112, 0.86349563, 0.30270709, 0.03673155, 0.8814458, 0.52538187,
                0.97650872, 0.9278274, 0.73412665
            ],
            [
                0.96577082, 0.17352435, 0.71417166, 0.57713058, 0.99690502, 0.5856659, 0.87223811,
                0.8265802, 0.07539461, 0.28718492
            ],
            [
                0.64135636, 0.53712009, 0.98343642, 0.68861079, 0.33153221, 0.86677607, 0.65411023,
                0.97146557, 0.78007143, 0.24988737
            ],
            [
                0.52704545, 0.39384584, 0.99308, 0.03148114, 0.43305557, 0.11551732, 0.13331425,
                0.17881437, 0.05076005, 0.20889167
            ]
        ];

        assert_eq!(
            chu_liu_edmonds(check5.view(), 0),
            [
                None,
                Some(5),
                Some(4),
                Some(8),
                Some(7),
                Some(2),
                Some(2),
                Some(0),
                Some(6),
                Some(1)
            ]
        );
    }

    #[test]
    #[should_panic]
    fn panics_on_incorrect_shape_score_matrix() {
        let scores = Array::range(0f32, 16f32, 1f32).into_shape((2, 8)).unwrap();
        let _ = chu_liu_edmonds(scores.view(), 0);
    }
}