rs_graph/
branching.rs

// Copyright (c) 2016-2022 Frank Fischer <frank-fischer@shadow-soft.de>
//
// This program is free software: you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see  <http://www.gnu.org/licenses/>
//

//! Compute a maximum weight branching of a directed graph.

19use crate::builder::{Buildable, Builder};
20use crate::linkedlistgraph::LinkedListGraph;
21use crate::traits::{IndexDigraph, IndexGraph};
22
23use crate::num::traits::NumAssign;
24
25#[allow(clippy::cognitive_complexity)]
26pub fn max_weight_branching<'a, G, W>(g: &'a G, weights: &[W]) -> Vec<G::Edge<'a>>
27where
28    G: IndexDigraph,
29    W: NumAssign + Ord + Copy,
30{
31    // find non-cycle-free subset
32    let mut inarcs = vec![None; g.num_nodes()];
33    for e in g.edges() {
34        let u = g.snk(e);
35        let uid = g.node_id(u);
36        let w = weights[g.edge_id(e)];
37        if let Some((_, max_w)) = inarcs[uid] {
38            if max_w < w {
39                inarcs[uid] = Some((e, w))
40            }
41        } else if w > W::zero() {
42            inarcs[uid] = Some((e, w))
43        }
44    }
45
46    let mut newnodes = vec![None; g.num_nodes()];
47    let mut newg = LinkedListGraph::<usize>::new_builder();
48
49    // find cycles
50    let mut label = vec![0; g.num_nodes()];
51    let mut diffweights = vec![W::zero(); g.num_nodes()];
52    for u in g.nodes() {
53        let uid = g.node_id(u);
54        if label[uid] != 0 {
55            continue;
56        } // node already seen
57
58        // run along predecessors of unseen nodes
59        let mut vid = uid;
60        while label[vid] == 0 {
61            label[vid] = 1;
62            if let Some((e, _)) = inarcs[vid] {
63                vid = g.node_id(g.src(e));
64            } else {
65                break;
66            }
67        }
68
69        if let Some((e, w_e)) = inarcs[vid] {
70            // last node has an incoming arc ...
71            if label[vid] == 1 {
72                // ... and has been seen on *this* path
73                // we have found a cycle
74                // find the minimal weight
75                let mut minweight = w_e;
76                let mut wid = g.node_id(g.src(e));
77                while wid != vid {
78                    let (e, w_e) = inarcs[wid].unwrap();
79                    minweight = minweight.min(w_e);
80                    wid = g.node_id(g.src(e));
81                }
82
83                // contract the cycle and compute the weight difference
84                // for each node
85                let contracted_node = newg.add_node();
86                newnodes[vid] = Some(contracted_node);
87                diffweights[vid] = w_e - minweight;
88                label[vid] = 2;
89                let mut wid = g.node_id(g.src(e));
90                while wid != vid {
91                    newnodes[wid] = Some(contracted_node);
92                    label[wid] = 2;
93                    let (e, w_e) = inarcs[wid].unwrap();
94                    diffweights[wid] = w_e - minweight;
95                    wid = g.node_id(g.src(e));
96                }
97            }
98        }
99
100        // add all remaining nodes on the path as single nodes
101        let mut vid = uid;
102        while label[vid] == 1 {
103            newnodes[vid] = Some(newg.add_node());
104            label[vid] = 2;
105            if let Some((e, _)) = inarcs[vid] {
106                vid = g.node_id(g.src(e));
107            } else {
108                break;
109            }
110        }
111    }
112
113    if newg.num_nodes() == g.num_nodes() {
114        // nothing contracted => found a branching
115        return inarcs.into_iter().filter_map(|e| e.map(|(e, _)| e)).collect();
116    }
117
118    // add arcs
119    let mut newweights = vec![];
120    let mut newarcs = vec![];
121    for e in g.edges() {
122        let newu = newnodes[g.node_id(g.src(e))].unwrap();
123        let newv = newnodes[g.node_id(g.snk(e))].unwrap();
124        if newu != newv {
125            let w_e = weights[g.edge_id(e)];
126            if w_e > W::zero() {
127                newg.add_edge(newu, newv);
128                newarcs.push(e);
129                newweights.push(w_e - diffweights[g.node_id(g.snk(e))]);
130            }
131        }
132    }
133
134    let newg = newg.into_graph();
135
136    // recursively determine branching on smaller graph
137    let newbranching = max_weight_branching(&newg, &newweights[..]);
138    let mut branching = vec![];
139
140    // add original arcs
141    for newa in newbranching {
142        let e = newarcs[newg.edge_id(newa)];
143        branching.push(e);
144        let uid = g.node_id(g.snk(e));
145        label[uid] = 3;
146        // if sink of arc is a contraction node, add the cycle
147        if let Some((inarc, _)) = inarcs[uid] {
148            if inarc != e {
149                let mut vid = g.node_id(g.src(inarc));
150                while vid != uid {
151                    label[vid] = 3;
152                    let e = inarcs[vid].unwrap().0;
153                    branching.push(e);
154                    vid = g.node_id(g.src(e));
155                }
156            }
157        }
158    }
159
160    // Now find all nodes that are not contained in the branching.
161    // These nodes might be contained in a cycle, we add that cycle
162    // except for the cheapest arc.
163    for u in g.nodes() {
164        let uid = g.node_id(u);
165        if label[uid] == 2 {
166            label[uid] = 3;
167            if let Some((mut minarc, mut min_w)) = inarcs[uid] {
168                let mut vid = g.node_id(g.src(minarc));
169                while label[vid] != 3 {
170                    label[vid] = 3;
171                    if let Some((e, w_e)) = inarcs[vid] {
172                        if w_e >= min_w {
173                            branching.push(e);
174                        } else {
175                            branching.push(minarc);
176                            minarc = e;
177                            min_w = w_e;
178                        }
179                        vid = g.node_id(g.src(e));
180                    } else {
181                        break;
182                    }
183                }
184            }
185        }
186    }
187
188    branching
189}
190
#[cfg(test)]
mod tests {
    use crate::branching::max_weight_branching;
    use crate::traits::IndexGraph;
    use crate::{Buildable, Builder, LinkedListGraph};

    /// Maximum branching on a 9-node digraph with `u32` weights.
    #[test]
    fn test_branching1() {
        // (source, sink, weight) with 1-based node numbers, in insertion order.
        let arcs: &[(usize, usize, u32)] = &[
            (1, 4, 17),
            (1, 5, 5),
            (1, 3, 18),
            (2, 1, 21),
            (2, 6, 17),
            (2, 7, 12),
            (3, 2, 21),
            (3, 8, 15),
            (4, 9, 12),
            (5, 2, 12),
            (5, 4, 12),
            (6, 5, 4),
            (6, 7, 13),
            (7, 3, 14),
            (7, 8, 12),
            (8, 9, 18),
            (9, 1, 19),
            (9, 3, 15),
        ];

        let mut builder = LinkedListGraph::<usize>::new_builder();
        let nodes = builder.add_nodes(9);
        for &(src, snk, _) in arcs {
            builder.add_edge(nodes[src - 1], nodes[snk - 1]);
        }
        let weights: Vec<u32> = arcs.iter().map(|&(_, _, w)| w).collect();
        let graph = builder.into_graph();

        let branching = max_weight_branching(&graph, &weights);
        let total: u32 = branching.iter().map(|&e| weights[graph.edge_id(e)]).sum();
        assert_eq!(total, 131);
    }

    /// Maximum branching on a 9-node digraph (two isolated nodes) with `i32` weights.
    #[test]
    fn test_branching2() {
        // (source, sink, weight) with 1-based node numbers, in insertion order.
        let arcs: &[(usize, usize, i32)] = &[
            (2, 1, 3),
            (1, 3, 4),
            (6, 3, 3),
            (6, 7, 1),
            (7, 4, 3),
            (1, 2, 10),
            (4, 1, 5),
            (3, 4, 5),
            (4, 5, 2),
            (4, 6, 4),
            (5, 6, 2),
        ];

        let mut builder = LinkedListGraph::<usize>::new_builder();
        let nodes = builder.add_nodes(9);
        for &(src, snk, _) in arcs {
            builder.add_edge(nodes[src - 1], nodes[snk - 1]);
        }
        let weights: Vec<i32> = arcs.iter().map(|&(_, _, w)| w).collect();
        let graph = builder.into_graph();

        let branching = max_weight_branching(&graph, &weights);
        let total: i32 = branching.iter().map(|&e| weights[graph.edge_id(e)]).sum();
        assert_eq!(total, 28);
    }
}
262        assert_eq!(branching.iter().fold(0, |acc, &e| acc + weights[g.edge_id(e)]), 28);
263    }
264}