1use std::cmp::Ordering;
38use std::collections::BinaryHeap;
39
40use crate::error::{GraphalgError, GraphalgResult};
41use crate::repr::weighted_graph::WeightedGraph;
42
/// Result of [`suurballe_vertex_disjoint`]: two internally vertex-disjoint paths from the
/// source to the target together with their combined weight.
#[derive(Debug, Clone, PartialEq)]
pub struct DisjointPaths {
    /// First path, as the sequence of original-graph vertices from the source to the target
    /// (both endpoints included).
    pub path_a: Vec<usize>,
    /// Second path, in the same format as `path_a`. Apart from the shared endpoints it visits
    /// no vertex of `path_a`.
    pub path_b: Vec<usize>,
    /// Combined weight of the edges used by both paths.
    pub total_cost: f64,
}
54
/// Index of the "in" half of vertex `v` in the vertex-split residual network.
///
/// Every original vertex occupies two consecutive residual nodes; the in-node is the even one,
/// so it always equals `node_out(v) - 1`.
#[inline]
fn node_in(v: usize) -> usize {
    v + v
}
/// Index of the "out" half of vertex `v` in the vertex-split residual network.
///
/// This is the odd-numbered node directly after `v`'s in-node.
#[inline]
fn node_out(v: usize) -> usize {
    v * 2 + 1
}
77
/// One directed arc of the residual network.
///
/// Arcs are created in forward/reverse pairs by `Residual::add`, and each arc records the index
/// of its partner in `rev`.
#[derive(Debug, Clone, Copy)]
struct ResEdge {
    /// Head of the arc, as a residual-node index.
    to: usize,
    /// Index in `Residual::edges` of the paired arc running the opposite way.
    rev: usize,
    /// Remaining capacity. Searches skip the arc unless this is positive.
    cap: i64,
    /// Per-unit traversal cost. The reverse arc of a pair carries the negated cost.
    cost: f64,
}
85
86struct Residual {
87 adj: Vec<Vec<usize>>,
88 edges: Vec<ResEdge>,
89}
90
91impl Residual {
92 fn new(num_nodes: usize) -> Self {
93 Self {
94 adj: vec![Vec::new(); num_nodes],
95 edges: Vec::new(),
96 }
97 }
98
99 fn add(&mut self, u: usize, v: usize, cap: i64, cost: f64) {
101 let a = self.edges.len();
102 let b = a + 1;
103 self.edges.push(ResEdge {
104 to: v,
105 rev: b,
106 cap,
107 cost,
108 });
109 self.edges.push(ResEdge {
110 to: u,
111 rev: a,
112 cap: 0,
113 cost: -cost,
114 });
115 self.adj[u].push(a);
116 self.adj[v].push(b);
117 }
118}
119
/// Priority-queue entry for the Dijkstra searches in this module.
///
/// The ordering is reversed so that std's max-heap `BinaryHeap` pops the smallest `dist` first.
/// Ties, and incomparable (NaN) distances, fall back to the node id, again reversed so the
/// smaller id is popped first.
#[derive(Debug, Clone, Copy, PartialEq)]
struct HeapItem {
    dist: f64,
    node: usize,
}

impl Eq for HeapItem {}

impl PartialOrd for HeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for HeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        match other.dist.partial_cmp(&self.dist) {
            Some(Ordering::Less) => Ordering::Less,
            Some(Ordering::Greater) => Ordering::Greater,
            Some(Ordering::Equal) | None => other.node.cmp(&self.node),
        }
    }
}
143
144fn dijkstra_residual(res: &Residual, src: usize, num_nodes: usize) -> (Vec<f64>, Vec<usize>) {
147 let mut dist = vec![f64::INFINITY; num_nodes];
148 let mut prev_edge = vec![usize::MAX; num_nodes];
149 dist[src] = 0.0;
150 let mut heap: BinaryHeap<HeapItem> = BinaryHeap::new();
151 heap.push(HeapItem {
152 dist: 0.0,
153 node: src,
154 });
155 while let Some(HeapItem { dist: d, node: u }) = heap.pop() {
156 if d > dist[u] + 1e-12 {
157 continue;
158 }
159 for &eid in &res.adj[u] {
160 let e = res.edges[eid];
161 if e.cap <= 0 {
162 continue;
163 }
164 let step = if e.cost < 0.0 { 0.0 } else { e.cost };
166 let nd = d + step;
167 if nd + 1e-12 < dist[e.to] {
168 dist[e.to] = nd;
169 prev_edge[e.to] = eid;
170 heap.push(HeapItem {
171 dist: nd,
172 node: e.to,
173 });
174 }
175 }
176 }
177 (dist, prev_edge)
178}
179
180pub fn suurballe_vertex_disjoint(
192 graph: &WeightedGraph,
193 s: usize,
194 t: usize,
195) -> GraphalgResult<DisjointPaths> {
196 let n = graph.n;
197 if s >= n || t >= n {
198 return Err(GraphalgError::SourceOutOfRange { node: s.max(t), n });
199 }
200 if s == t {
201 return Err(GraphalgError::InvalidParameter(
202 "source must differ from target".to_string(),
203 ));
204 }
205 for u in 0..n {
207 for &(v, w) in graph.neighbors(u)? {
208 if w < 0.0 {
209 return Err(GraphalgError::NegativeWeight {
210 edge: (u, v),
211 weight: w,
212 });
213 }
214 }
215 }
216
217 let d = dijkstra_potentials(graph, s)?;
219 if d[t].is_infinite() {
220 return Err(GraphalgError::NoSolution(
221 "target unreachable from source".to_string(),
222 ));
223 }
224
225 let num_nodes = 2 * n;
227 let mut res = Residual::new(num_nodes);
228
229 for v in 0..n {
231 if v == s || v == t {
232 continue;
233 }
234 res.add(node_in(v), node_out(v), 1, 0.0);
235 }
236
237 for u in 0..n {
240 for &(v, w) in graph.neighbors(u)? {
241 if u == v {
242 continue; }
244 if d[u].is_infinite() {
247 continue;
248 }
249 if v == s || u == t {
252 continue;
253 }
254 let from = node_out(u);
255 let to = node_in(v);
256 let mut rc = w + d[u] - d[v];
258 if rc < 0.0 {
259 rc = 0.0;
260 }
261 res.add(from, to, 1, rc);
262 }
263 }
264
265 let src = node_out(s);
266 let dst = node_in(t);
267
268 let (_, prev1) = dijkstra_residual(&res, src, num_nodes);
270 if prev1[dst] == usize::MAX {
271 return Err(GraphalgError::NoSolution(
272 "target unreachable in residual graph".to_string(),
273 ));
274 }
275 augment(&mut res, src, dst, &prev1);
276
277 let (_, prev2) = dijkstra_residual(&res, src, num_nodes);
279 if prev2[dst] == usize::MAX {
280 return Err(GraphalgError::NoSolution(
281 "no second vertex-disjoint path exists".to_string(),
282 ));
283 }
284 augment(&mut res, src, dst, &prev2);
285
286 let (path_a, path_b) = decompose_two_paths(&mut res, s, t, n)?;
288
289 let total_cost = path_cost(graph, &path_a)? + path_cost(graph, &path_b)?;
291
292 Ok(DisjointPaths {
293 path_a,
294 path_b,
295 total_cost,
296 })
297}
298
299fn dijkstra_potentials(graph: &WeightedGraph, src: usize) -> GraphalgResult<Vec<f64>> {
301 let n = graph.n;
302 let mut dist = vec![f64::INFINITY; n];
303 dist[src] = 0.0;
304 let mut heap: BinaryHeap<HeapItem> = BinaryHeap::new();
305 heap.push(HeapItem {
306 dist: 0.0,
307 node: src,
308 });
309 while let Some(HeapItem { dist: dd, node: u }) = heap.pop() {
310 if dd > dist[u] + 1e-12 {
311 continue;
312 }
313 for &(v, w) in graph.neighbors(u)? {
314 let nd = dd + w;
315 if nd + 1e-12 < dist[v] {
316 dist[v] = nd;
317 heap.push(HeapItem { dist: nd, node: v });
318 }
319 }
320 }
321 Ok(dist)
322}
323
324fn augment(res: &mut Residual, src: usize, dst: usize, prev_edge: &[usize]) {
326 let mut v = dst;
327 while v != src {
328 let eid = prev_edge[v];
329 res.edges[eid].cap -= 1;
330 let rev = res.edges[eid].rev;
331 res.edges[rev].cap += 1;
332 v = res.edges[rev].to;
333 }
334}
335
336fn carries_flow(edges: &[ResEdge], eid: usize) -> bool {
342 if eid % 2 != 0 {
343 return false;
344 }
345 let rev = edges[eid].rev;
346 edges[rev].cap > 0
347}
348
349fn decompose_two_paths(
354 res: &mut Residual,
355 s: usize,
356 t: usize,
357 n: usize,
358) -> GraphalgResult<(Vec<usize>, Vec<usize>)> {
359 let num_nodes = 2 * n;
368 let mut used_next: Vec<usize> = vec![0; num_nodes];
370
371 let src = node_out(s);
372 let dst = node_in(t);
373
374 let mut paths: Vec<Vec<usize>> = Vec::new();
375
376 for _ in 0..2 {
377 let mut path_vertices: Vec<usize> = vec![s];
378 let mut cur = src;
379 let mut guard = 0usize;
380 let limit = num_nodes * 4 + 8;
381 loop {
382 guard += 1;
383 if guard > limit {
384 return Err(GraphalgError::NoSolution(
385 "path decomposition did not terminate".to_string(),
386 ));
387 }
388 if cur == dst {
389 break;
390 }
391 let mut advanced = false;
393 while used_next[cur] < res.adj[cur].len() {
394 let eid = res.adj[cur][used_next[cur]];
395 used_next[cur] += 1;
396 if carries_flow(&res.edges, eid) {
397 let rev = res.edges[eid].rev;
399 res.edges[rev].cap -= 1;
400 let to = res.edges[eid].to;
401 if cur % 2 == 1 && to % 2 == 0 {
403 let v = to / 2;
404 path_vertices.push(v);
405 }
406 cur = to;
407 advanced = true;
408 break;
409 }
410 }
411 if !advanced {
412 return Err(GraphalgError::NoSolution(
413 "incomplete vertex-disjoint path pair".to_string(),
414 ));
415 }
416 }
417 paths.push(path_vertices);
418 }
419
420 let path_a = paths.remove(0);
421 let path_b = paths.remove(0);
422
423 if path_a.first() != Some(&s)
425 || path_a.last() != Some(&t)
426 || path_b.first() != Some(&s)
427 || path_b.last() != Some(&t)
428 {
429 return Err(GraphalgError::NoSolution(
430 "recovered paths are malformed".to_string(),
431 ));
432 }
433 Ok((path_a, path_b))
434}
435
436fn path_cost(graph: &WeightedGraph, path: &[usize]) -> GraphalgResult<f64> {
438 let mut total = 0.0;
439 for w in path.windows(2) {
440 let (u, v) = (w[0], w[1]);
441 let mut best: Option<f64> = None;
442 for &(nb, weight) in graph.neighbors(u)? {
443 if nb == v {
444 best = Some(best.map_or(weight, |b: f64| b.min(weight)));
445 }
446 }
447 match best {
448 Some(c) => total += c,
449 None => {
450 return Err(GraphalgError::NoSolution(format!(
451 "reconstructed edge ({u},{v}) absent from graph"
452 )));
453 }
454 }
455 }
456 Ok(total)
457}
458
// Tests for `suurballe_vertex_disjoint` and its helpers. Where possible, the total cost is
// cross-checked against the crate's min-cost-flow solver run on the same vertex-split model.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::min_cost_flow::successive_shortest_paths::{
        MinCostFlowNetwork, min_cost_flow_bounded,
    };

    // Builds an `n`-vertex `WeightedGraph` from `(u, v, w)` triples; panics if an insertion fails.
    fn wgraph(n: usize, edges: &[(usize, usize, f64)]) -> WeightedGraph {
        let mut g = WeightedGraph::new(n);
        for &(u, v, w) in edges {
            g.add_edge(u, v, w).expect("add ok");
        }
        g
    }

    // Reference oracle. It mirrors the implementation's vertex split: in-node `2v` and out-node
    // `2v + 1`, a split arc for every interior vertex, and one arc per kept graph edge carrying
    // the raw weight. The third/fourth `add_edge` arguments appear to be capacity and cost —
    // TODO confirm against `MinCostFlowNetwork::add_edge`.
    // It calls `min_cost_flow_bounded` with a bound of 2.0 and returns the cost only if the
    // returned flow is 2; otherwise it returns `None`.
    fn mcf_vertex_disjoint_cost(
        n: usize,
        edges: &[(usize, usize, f64)],
        s: usize,
        t: usize,
    ) -> Option<f64> {
        let mut net = MinCostFlowNetwork::new(2 * n);
        // Split arcs for every vertex except the endpoints.
        for v in 0..n {
            if v == s || v == t {
                continue;
            }
            net.add_edge(2 * v, 2 * v + 1, 1.0, 0.0).expect("ok");
        }
        // Same filtering as the implementation: no self-loops, nothing into `s` or out of `t`.
        for &(u, v, w) in edges {
            if u == v || v == s || u == t {
                continue;
            }
            net.add_edge(2 * u + 1, 2 * v, 1.0, w).expect("ok");
        }
        let src = 2 * s + 1;
        let dst = 2 * t;
        let r = min_cost_flow_bounded(&net, src, dst, 2.0).expect("mcf ok");
        if (r.flow - 2.0).abs() < 1e-9 {
            Some(r.cost)
        } else {
            None
        }
    }

    // Asserts that:
    // - both paths run from `s` to `t`;
    // - neither path revisits a vertex;
    // - their interiors (everything except `s` and `t`) share no vertex.
    fn assert_vertex_disjoint(dp: &DisjointPaths, s: usize, t: usize) {
        use std::collections::HashSet;
        let interior_a: HashSet<usize> = dp
            .path_a
            .iter()
            .copied()
            .filter(|&v| v != s && v != t)
            .collect();
        let interior_b: HashSet<usize> = dp
            .path_b
            .iter()
            .copied()
            .filter(|&v| v != s && v != t)
            .collect();
        assert!(
            interior_a.is_disjoint(&interior_b),
            "paths share an interior vertex: {:?} vs {:?}",
            dp.path_a,
            dp.path_b
        );
        // Each path must itself be simple.
        let mut seen_a = HashSet::new();
        for &v in &dp.path_a {
            assert!(seen_a.insert(v), "path_a revisits {v}");
        }
        let mut seen_b = HashSet::new();
        for &v in &dp.path_b {
            assert!(seen_b.insert(v), "path_b revisits {v}");
        }
        assert_eq!(dp.path_a.first(), Some(&s));
        assert_eq!(dp.path_a.last(), Some(&t));
        assert_eq!(dp.path_b.first(), Some(&s));
        assert_eq!(dp.path_b.last(), Some(&t));
    }

    // Two symmetric routes 0-1-3 and 0-2-3 of weight 2 each; the optimum uses both (total 4).
    #[test]
    fn two_parallel_paths_diamond() {
        let edges = [(0, 1, 1.0), (1, 3, 1.0), (0, 2, 1.0), (2, 3, 1.0)];
        let g = wgraph(4, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 3).expect("ok");
        assert_vertex_disjoint(&dp, 0, 3);
        assert!((dp.total_cost - 4.0).abs() < 1e-9, "cost={}", dp.total_cost);
    }

    // Three two-hop routes of weights 2, 4 and 6, plus a 1->2 shortcut; the total must match
    // the min-cost-flow oracle.
    #[test]
    fn min_total_cost_matches_mcf_oracle() {
        let edges = [
            (0, 1, 1.0),
            (1, 4, 1.0),
            (0, 2, 2.0),
            (2, 4, 2.0),
            (0, 3, 3.0),
            (3, 4, 3.0),
            (1, 2, 1.0),
        ];
        let g = wgraph(5, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 4).expect("ok");
        assert_vertex_disjoint(&dp, 0, 4);
        let oracle = mcf_vertex_disjoint_cost(5, &edges, 0, 4).expect("oracle has 2 paths");
        assert!(
            (dp.total_cost - oracle).abs() < 1e-6,
            "suurballe={} oracle={}",
            dp.total_cost,
            oracle
        );
    }

    // Layered DAG from 0 to 8 (every edge goes to a higher index); compared with the oracle.
    #[test]
    fn matches_oracle_on_grid() {
        let edges = [
            (0, 1, 2.0),
            (0, 2, 1.0),
            (1, 3, 1.0),
            (1, 4, 3.0),
            (2, 4, 1.0),
            (2, 5, 2.0),
            (3, 6, 2.0),
            (4, 6, 1.0),
            (4, 7, 2.0),
            (5, 7, 1.0),
            (6, 8, 1.0),
            (7, 8, 2.0),
        ];
        let n = 9;
        let g = wgraph(n, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 8).expect("ok");
        assert_vertex_disjoint(&dp, 0, 8);
        let oracle = mcf_vertex_disjoint_cost(n, &edges, 0, 8).expect("oracle");
        assert!(
            (dp.total_cost - oracle).abs() < 1e-6,
            "suurballe={} oracle={}",
            dp.total_cost,
            oracle
        );
    }

    // Every 0->2 route passes through vertex 1, so no second vertex-disjoint path exists.
    #[test]
    fn fails_when_only_a_bridge_connects() {
        let edges = [(0, 1, 1.0), (1, 2, 1.0)];
        let g = wgraph(3, &edges);
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 0, 2),
            Err(GraphalgError::NoSolution(_))
        ));
    }

    // Vertex 2 has no incoming edge, so it cannot be reached from 0.
    #[test]
    fn fails_when_target_unreachable() {
        let edges = [(0, 1, 1.0)];
        let g = wgraph(3, &edges);
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 0, 2),
            Err(GraphalgError::NoSolution(_))
        ));
    }

    // A single direct s->t edge has capacity 1, so only one path fits.
    #[test]
    fn fails_with_single_direct_edge_only() {
        let edges = [(0, 1, 5.0)];
        let g = wgraph(2, &edges);
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 0, 1),
            Err(GraphalgError::NoSolution(_))
        ));
    }

    // The shortest path is 0-1-2-5 (weight 3). The only disjoint pair is 0-1-4-5 + 0-3-2-5
    // (total 8), so the second augmentation has to cancel the 1->2 edge of the first path.
    #[test]
    fn two_disjoint_with_a_shared_cut_attempt() {
        let edges = [
            (0, 1, 1.0),
            (1, 2, 1.0),
            (2, 5, 1.0),
            (0, 3, 2.0),
            (3, 2, 1.0),
            (1, 4, 2.0),
            (4, 5, 1.0),
        ];
        let n = 6;
        let g = wgraph(n, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 5).expect("ok");
        assert_vertex_disjoint(&dp, 0, 5);
        let oracle = mcf_vertex_disjoint_cost(n, &edges, 0, 5).expect("oracle");
        assert!(
            (dp.total_cost - oracle).abs() < 1e-6,
            "suurballe={} oracle={}",
            dp.total_cost,
            oracle
        );
    }

    // Potentials from `dijkstra_potentials` must yield non-negative reduced costs
    // `w + d[u] - d[v]` on every edge between reachable vertices (up to 1e-9 of rounding).
    #[test]
    fn reduced_costs_are_nonnegative() {
        let edges = [
            (0, 1, 4.0),
            (0, 2, 1.0),
            (2, 1, 1.0),
            (1, 3, 1.0),
            (2, 3, 5.0),
        ];
        let g = wgraph(4, &edges);
        let d = dijkstra_potentials(&g, 0).expect("ok");
        for u in 0..g.n {
            if d[u].is_infinite() {
                continue;
            }
            for &(v, w) in g.neighbors(u).expect("nb") {
                if d[v].is_infinite() {
                    continue;
                }
                let rc = w + d[u] - d[v];
                assert!(rc >= -1e-9, "reduced cost {rc} negative on {u}->{v}");
            }
        }
    }

    #[test]
    fn rejects_negative_weight() {
        let mut g = WeightedGraph::new(3);
        g.add_edge(0, 1, -2.0).expect("add");
        g.add_edge(1, 2, 1.0).expect("add");
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 0, 2),
            Err(GraphalgError::NegativeWeight { .. })
        ));
    }

    #[test]
    fn rejects_source_equals_target() {
        let g = wgraph(3, &[(0, 1, 1.0), (1, 2, 1.0)]);
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 1, 1),
            Err(GraphalgError::InvalidParameter(_))
        ));
    }

    #[test]
    fn rejects_out_of_range() {
        let g = wgraph(3, &[(0, 1, 1.0)]);
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 0, 9),
            Err(GraphalgError::SourceOutOfRange { .. })
        ));
    }

    // In this graph the optimal pair is 0-1-3 (2) + 0-2-3 (6), so its cheaper member equals the
    // Dijkstra distance to 3. This does not hold for every graph (e.g. "trap" topologies), so
    // the check is specific to this fixture.
    #[test]
    fn k1_reduces_to_dijkstra_shortest_path() {
        use crate::shortest_path::dijkstra::dijkstra;
        let edges = [
            (0, 1, 1.0),
            (1, 3, 1.0),
            (0, 2, 5.0),
            (2, 3, 1.0),
            (0, 3, 9.0),
        ];
        let g = wgraph(4, &edges);
        let sp = dijkstra(&g, 0).expect("dij");
        let dp = suurballe_vertex_disjoint(&g, 0, 3).expect("ok");
        let ca = path_cost(&g, &dp.path_a).expect("ca");
        let cb = path_cost(&g, &dp.path_b).expect("cb");
        let cheaper = ca.min(cb);
        assert!(
            (cheaper - sp.dist[3]).abs() < 1e-9,
            "cheaper path {cheaper} != dijkstra {}",
            sp.dist[3]
        );
    }

    // Routes 0-1-3 (5) and 0-2-3 (5): `total_cost` must equal the sum of per-path costs (10).
    #[test]
    fn total_cost_is_sum_of_path_costs() {
        let edges = [(0, 1, 2.0), (1, 3, 3.0), (0, 2, 4.0), (2, 3, 1.0)];
        let g = wgraph(4, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 3).expect("ok");
        let ca = path_cost(&g, &dp.path_a).expect("ca");
        let cb = path_cost(&g, &dp.path_b).expect("cb");
        assert!((dp.total_cost - (ca + cb)).abs() < 1e-12);
        assert!(
            (dp.total_cost - 10.0).abs() < 1e-9,
            "cost={}",
            dp.total_cost
        );
    }

    // Three disjoint routes of weights 2, 4 and 6 (vertices 4-6 are isolated). The two cheapest
    // give a total of 6.
    #[test]
    fn three_disjoint_available_picks_cheapest_two() {
        let edges = [
            (0, 1, 1.0),
            (1, 7, 1.0),
            (0, 2, 2.0),
            (2, 7, 2.0),
            (0, 3, 3.0),
            (3, 7, 3.0),
        ];
        let n = 8;
        let g = wgraph(n, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 7).expect("ok");
        assert_vertex_disjoint(&dp, 0, 7);
        assert!((dp.total_cost - 6.0).abs() < 1e-9, "cost={}", dp.total_cost);
        let oracle = mcf_vertex_disjoint_cost(n, &edges, 0, 7).expect("oracle");
        assert!((dp.total_cost - oracle).abs() < 1e-6);
    }
}