use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::error::{GraphalgError, GraphalgResult};
use crate::repr::weighted_graph::WeightedGraph;
/// Result of [`suurballe_vertex_disjoint`]: two source→target paths that share
/// no interior vertex, plus their combined weight.
#[derive(Clone, Debug)]
pub struct DisjointPaths {
    /// Vertex ids of the first path, starting at the source and ending at the target.
    pub path_a: Vec<usize>,
    /// Vertex ids of the second path, starting at the source and ending at the target.
    pub path_b: Vec<usize>,
    /// Sum of both path weights, where each hop is charged its cheapest parallel edge.
    pub total_cost: f64,
}
/// Entry half of vertex `v` in the node-split network: every vertex `v` becomes
/// the pair `(2v, 2v + 1)`, so entry nodes have even ids.
#[inline]
fn node_in(v: usize) -> usize {
    v * 2
}
/// Exit half of vertex `v` in the node-split network (odd id, right after its entry).
#[inline]
fn node_out(v: usize) -> usize {
    node_in(v) + 1
}
/// One directed arc of the residual network.
#[derive(Debug, Clone, Copy)]
struct ResEdge {
    /// Head node of the arc.
    to: usize,
    /// Index in `Residual::edges` of the paired arc running the opposite way.
    rev: usize,
    /// Remaining capacity; arcs with `cap <= 0` cannot be traversed.
    cap: i64,
    /// Cost per unit of flow; the paired reverse arc carries the negation.
    cost: f64,
}

/// Residual network stored as a flat arc list plus per-node lists of arc ids.
struct Residual {
    /// `adj[u]` holds the ids of every arc whose tail is `u`, forward and reverse alike.
    adj: Vec<Vec<usize>>,
    /// All arcs. Arcs are only appended in pairs by `add`, so each forward arc
    /// has an even id and its reverse arc the following odd id.
    edges: Vec<ResEdge>,
}

impl Residual {
    /// Creates a network with `num_nodes` nodes and no arcs.
    fn new(num_nodes: usize) -> Self {
        Residual {
            adj: (0..num_nodes).map(|_| Vec::new()).collect(),
            edges: Vec::new(),
        }
    }

    /// Adds the arc `u -> v` (capacity `cap`, unit cost `cost`) together with its
    /// reverse arc `v -> u`, which starts at capacity 0 with cost `-cost`.
    fn add(&mut self, u: usize, v: usize, cap: i64, cost: f64) {
        let fwd_id = self.edges.len();
        let bwd_id = fwd_id + 1;
        let forward = ResEdge { to: v, rev: bwd_id, cap, cost };
        let backward = ResEdge { to: u, rev: fwd_id, cap: 0, cost: -cost };
        self.edges.extend([forward, backward]);
        self.adj[u].push(fwd_id);
        self.adj[v].push(bwd_id);
    }
}
/// Priority-queue entry for the Dijkstra searches in this module.
///
/// `BinaryHeap` pops its greatest element, so the ordering below is reversed on
/// both keys: the entry with the smallest `dist` pops first, and ties go to the
/// smallest `node`. Incomparable (NaN) distances are treated as equal on `dist`.
#[derive(Debug, Clone, Copy, PartialEq)]
struct HeapItem {
    dist: f64,
    node: usize,
}
// Total order is asserted on the assumption that no NaN distance is ever pushed.
impl Eq for HeapItem {}
impl PartialOrd for HeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for HeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        let by_dist = self
            .dist
            .partial_cmp(&other.dist)
            .map_or(Ordering::Equal, Ordering::reverse);
        by_dist.then(self.node.cmp(&other.node).reverse())
    }
}
/// Dijkstra over the residual network from `src`, using only arcs with positive
/// remaining capacity.
///
/// Arc costs are the reduced costs installed by `suurballe_vertex_disjoint`.
/// Reverse arcs carry `-cost`, so any negative cost is floored to zero to keep
/// every step non-negative. An arc whose cost is NaN yields a NaN tentative
/// distance, which never compares as an improvement, so it is never relaxed.
///
/// Returns `(dist, prev_edge)`:
/// * `dist[v]` is the shortest distance from `src` (`f64::INFINITY` if unreachable);
/// * `prev_edge[v]` is the id of the arc entering `v` on that path
///   (`usize::MAX` for `src` and for unreachable nodes).
fn dijkstra_residual(res: &Residual, src: usize, num_nodes: usize) -> (Vec<f64>, Vec<usize>) {
    let mut dist = vec![f64::INFINITY; num_nodes];
    let mut prev_edge = vec![usize::MAX; num_nodes];
    dist[src] = 0.0;
    let mut heap: BinaryHeap<HeapItem> = BinaryHeap::new();
    heap.push(HeapItem {
        dist: 0.0,
        node: src,
    });
    while let Some(HeapItem { dist: d, node: u }) = heap.pop() {
        // Lazy deletion: an entry is stale once a strictly shorter one was pushed.
        // Entries are only pushed on strict improvement, so exact comparison is safe.
        if d > dist[u] {
            continue;
        }
        for &eid in &res.adj[u] {
            let e = res.edges[eid];
            if e.cap <= 0 {
                continue;
            }
            let step = if e.cost < 0.0 { 0.0 } else { e.cost };
            let nd = d + step;
            // Exact strict comparison. A fixed absolute slack (formerly 1e-12)
            // discarded real improvements whenever reduced costs were that small,
            // so the first arc found on a near-tie won and the result was suboptimal.
            if nd < dist[e.to] {
                dist[e.to] = nd;
                prev_edge[e.to] = eid;
                heap.push(HeapItem {
                    dist: nd,
                    node: e.to,
                });
            }
        }
    }
    (dist, prev_edge)
}
pub fn suurballe_vertex_disjoint(
graph: &WeightedGraph,
s: usize,
t: usize,
) -> GraphalgResult<DisjointPaths> {
let n = graph.n;
if s >= n || t >= n {
return Err(GraphalgError::SourceOutOfRange { node: s.max(t), n });
}
if s == t {
return Err(GraphalgError::InvalidParameter(
"source must differ from target".to_string(),
));
}
for u in 0..n {
for &(v, w) in graph.neighbors(u)? {
if w < 0.0 {
return Err(GraphalgError::NegativeWeight {
edge: (u, v),
weight: w,
});
}
}
}
let d = dijkstra_potentials(graph, s)?;
if d[t].is_infinite() {
return Err(GraphalgError::NoSolution(
"target unreachable from source".to_string(),
));
}
let num_nodes = 2 * n;
let mut res = Residual::new(num_nodes);
for v in 0..n {
if v == s || v == t {
continue;
}
res.add(node_in(v), node_out(v), 1, 0.0);
}
for u in 0..n {
for &(v, w) in graph.neighbors(u)? {
if u == v {
continue; }
if d[u].is_infinite() {
continue;
}
if v == s || u == t {
continue;
}
let from = node_out(u);
let to = node_in(v);
let mut rc = w + d[u] - d[v];
if rc < 0.0 {
rc = 0.0;
}
res.add(from, to, 1, rc);
}
}
let src = node_out(s);
let dst = node_in(t);
let (_, prev1) = dijkstra_residual(&res, src, num_nodes);
if prev1[dst] == usize::MAX {
return Err(GraphalgError::NoSolution(
"target unreachable in residual graph".to_string(),
));
}
augment(&mut res, src, dst, &prev1);
let (_, prev2) = dijkstra_residual(&res, src, num_nodes);
if prev2[dst] == usize::MAX {
return Err(GraphalgError::NoSolution(
"no second vertex-disjoint path exists".to_string(),
));
}
augment(&mut res, src, dst, &prev2);
let (path_a, path_b) = decompose_two_paths(&mut res, s, t, n)?;
let total_cost = path_cost(graph, &path_a)? + path_cost(graph, &path_b)?;
Ok(DisjointPaths {
path_a,
path_b,
total_cost,
})
}
/// Shortest-path distances from `src` over the original graph, used as node
/// potentials for the reduced costs `w + d[u] - d[v]`.
///
/// Expects non-negative edge weights, which `suurballe_vertex_disjoint`
/// validates before calling. Unreachable vertices get `f64::INFINITY`.
///
/// `dist[v]` is only ever assigned the floating-point sum `dist[u] + w`, and
/// relaxation uses an exact strict comparison. Together these guarantee
/// `w + d[u] - d[v] >= 0` for every finite-weight edge leaving a reachable
/// vertex, with no rounding slack.
///
/// # Errors
///
/// Propagates any error returned by `graph.neighbors`.
fn dijkstra_potentials(graph: &WeightedGraph, src: usize) -> GraphalgResult<Vec<f64>> {
    let n = graph.n;
    let mut dist = vec![f64::INFINITY; n];
    dist[src] = 0.0;
    let mut heap: BinaryHeap<HeapItem> = BinaryHeap::new();
    heap.push(HeapItem {
        dist: 0.0,
        node: src,
    });
    while let Some(HeapItem { dist: dd, node: u }) = heap.pop() {
        // Skip entries superseded by a strictly shorter push.
        if dd > dist[u] {
            continue;
        }
        for &(v, w) in graph.neighbors(u)? {
            let nd = dd + w;
            // No absolute epsilon: a fixed slack (formerly 1e-12) rejected real
            // improvements for small weights and left over-estimated potentials,
            // i.e. negative reduced costs downstream.
            if nd < dist[v] {
                dist[v] = nd;
                heap.push(HeapItem { dist: nd, node: v });
            }
        }
    }
    Ok(dist)
}
/// Pushes one unit of flow along the path recorded in `prev_edge`, walking
/// backwards from `dst` until `src` is reached.
///
/// `prev_edge[v]` must be the id of the arc used to enter `v` (as returned by
/// `dijkstra_residual`). Each arc on the path loses one unit of capacity and
/// its paired reverse arc gains one.
fn augment(res: &mut Residual, src: usize, dst: usize, prev_edge: &[usize]) {
    let mut node = dst;
    while node != src {
        let arc = prev_edge[node];
        let twin = res.edges[arc].rev;
        res.edges[arc].cap -= 1;
        res.edges[twin].cap += 1;
        // The twin arc points back at the tail of `arc`: step to it.
        node = res.edges[twin].to;
    }
}
/// Whether arc `eid` is a forward arc that currently carries flow.
///
/// Forward arcs sit at even ids, because `Residual::add` appends each forward
/// arc immediately before its reverse. Flow on a forward arc shows up as
/// positive capacity on its reverse arc. Reverse (odd) arcs always report `false`.
fn carries_flow(edges: &[ResEdge], eid: usize) -> bool {
    let is_forward = eid % 2 == 0;
    is_forward && edges[edges[eid].rev].cap > 0
}
/// Splits the two units of flow leaving `node_out(s)` into two vertex sequences.
///
/// Each unit is followed from `node_out(s)` to `node_in(t)` along forward arcs
/// that carry flow. Flow is consumed as it is followed (the reverse arc's
/// capacity is decremented), so the second walk cannot reuse it. A vertex id is
/// recorded whenever a walk crosses from an exit node (odd) to an entry node (even).
///
/// # Errors
///
/// `GraphalgError::NoSolution` if a walk exceeds `4 * (2 * n) + 8` steps, gets
/// stuck before reaching the target, or yields a path not running from `s` to `t`.
fn decompose_two_paths(
    res: &mut Residual,
    s: usize,
    t: usize,
    n: usize,
) -> GraphalgResult<(Vec<usize>, Vec<usize>)> {
    let num_nodes = 2 * n;
    let step_limit = num_nodes * 4 + 8;
    let (src, dst) = (node_out(s), node_in(t));
    // Per-node scan position in `res.adj`, shared by both walks so arcs already
    // examined by the first walk are never rescanned.
    let mut cursor = vec![0usize; num_nodes];

    let first = trace_flow_unit(res, &mut cursor, s, src, dst, step_limit)?;
    let second = trace_flow_unit(res, &mut cursor, s, src, dst, step_limit)?;

    let spans_s_to_t = |p: &[usize]| p.first() == Some(&s) && p.last() == Some(&t);
    if !(spans_s_to_t(&first) && spans_s_to_t(&second)) {
        return Err(GraphalgError::NoSolution(
            "recovered paths are malformed".to_string(),
        ));
    }
    Ok((first, second))
}

/// Follows and consumes one unit of flow from `src` to `dst`, returning the
/// original vertices visited, beginning with `s`.
fn trace_flow_unit(
    res: &mut Residual,
    cursor: &mut [usize],
    s: usize,
    src: usize,
    dst: usize,
    step_limit: usize,
) -> GraphalgResult<Vec<usize>> {
    let mut vertices = vec![s];
    let mut at = src;
    for _ in 0..step_limit {
        if at == dst {
            return Ok(vertices);
        }
        let mut next_arc = None;
        while cursor[at] < res.adj[at].len() {
            let eid = res.adj[at][cursor[at]];
            cursor[at] += 1;
            if carries_flow(&res.edges, eid) {
                next_arc = Some(eid);
                break;
            }
        }
        let eid = match next_arc {
            Some(eid) => eid,
            None => {
                return Err(GraphalgError::NoSolution(
                    "incomplete vertex-disjoint path pair".to_string(),
                ));
            }
        };
        let back = res.edges[eid].rev;
        res.edges[back].cap -= 1;
        let to = res.edges[eid].to;
        // Exit node -> entry node is an original edge: record its head vertex.
        if at % 2 == 1 && to % 2 == 0 {
            vertices.push(to / 2);
        }
        at = to;
    }
    Err(GraphalgError::NoSolution(
        "path decomposition did not terminate".to_string(),
    ))
}
/// Total weight of `path`, charging each hop the cheapest of its parallel edges.
///
/// Paths with fewer than two vertices cost `0.0`.
///
/// # Errors
///
/// `GraphalgError::NoSolution` if some consecutive pair `(u, v)` has no edge
/// `u -> v`. Any error from `graph.neighbors` is propagated.
fn path_cost(graph: &WeightedGraph, path: &[usize]) -> GraphalgResult<f64> {
    let mut total = 0.0;
    for hop in path.windows(2) {
        let (u, v) = (hop[0], hop[1]);
        let mut cheapest: Option<f64> = None;
        for &(head, weight) in graph.neighbors(u)? {
            if head != v {
                continue;
            }
            cheapest = Some(match cheapest {
                Some(so_far) => f64::min(so_far, weight),
                None => weight,
            });
        }
        match cheapest {
            None => {
                return Err(GraphalgError::NoSolution(format!(
                    "reconstructed edge ({u},{v}) absent from graph"
                )));
            }
            Some(c) => total += c,
        }
    }
    Ok(total)
}
// Unit tests for `suurballe_vertex_disjoint` and its helpers. Cost assertions are
// cross-checked against `mcf_vertex_disjoint_cost`, an oracle that solves the same
// vertex-split problem with the crate's successive-shortest-paths min-cost flow.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::min_cost_flow::successive_shortest_paths::{
        MinCostFlowNetwork, min_cost_flow_bounded,
    };
    /// Builds a `WeightedGraph` with `n` vertices from `(u, v, w)` triples,
    /// panicking if any `add_edge` call fails.
    fn wgraph(n: usize, edges: &[(usize, usize, f64)]) -> WeightedGraph {
        let mut g = WeightedGraph::new(n);
        for &(u, v, w) in edges {
            g.add_edge(u, v, w).expect("add ok");
        }
        g
    }
    /// Oracle: cost reported by `min_cost_flow_bounded` for routing 2 units from
    /// `s` to `t` through a vertex-split network, or `None` if it routes a
    /// different amount.
    ///
    /// Vertex `v` becomes `2v` (entry) and `2v + 1` (exit), joined by a
    /// unit-capacity zero-cost arc for every vertex except `s` and `t`. Self-loops,
    /// edges into `s` and edges out of `t` are dropped, as in the production code.
    /// NOTE(review): assumes `r.flow` is the routed amount and `r.cost` its total
    /// cost in original weights; confirm against the min_cost_flow module.
    fn mcf_vertex_disjoint_cost(
        n: usize,
        edges: &[(usize, usize, f64)],
        s: usize,
        t: usize,
    ) -> Option<f64> {
        let mut net = MinCostFlowNetwork::new(2 * n);
        for v in 0..n {
            if v == s || v == t {
                continue;
            }
            net.add_edge(2 * v, 2 * v + 1, 1.0, 0.0).expect("ok");
        }
        for &(u, v, w) in edges {
            if u == v || v == s || u == t {
                continue;
            }
            net.add_edge(2 * u + 1, 2 * v, 1.0, w).expect("ok");
        }
        let src = 2 * s + 1;
        let dst = 2 * t;
        let r = min_cost_flow_bounded(&net, src, dst, 2.0).expect("mcf ok");
        if (r.flow - 2.0).abs() < 1e-9 {
            Some(r.cost)
        } else {
            None
        }
    }
    /// Asserts that both paths run from `s` to `t`, repeat no vertex, and share
    /// no vertex other than the endpoints.
    fn assert_vertex_disjoint(dp: &DisjointPaths, s: usize, t: usize) {
        use std::collections::HashSet;
        let interior_a: HashSet<usize> = dp
            .path_a
            .iter()
            .copied()
            .filter(|&v| v != s && v != t)
            .collect();
        let interior_b: HashSet<usize> = dp
            .path_b
            .iter()
            .copied()
            .filter(|&v| v != s && v != t)
            .collect();
        assert!(
            interior_a.is_disjoint(&interior_b),
            "paths share an interior vertex: {:?} vs {:?}",
            dp.path_a,
            dp.path_b
        );
        let mut seen_a = HashSet::new();
        for &v in &dp.path_a {
            assert!(seen_a.insert(v), "path_a revisits {v}");
        }
        let mut seen_b = HashSet::new();
        for &v in &dp.path_b {
            assert!(seen_b.insert(v), "path_b revisits {v}");
        }
        assert_eq!(dp.path_a.first(), Some(&s));
        assert_eq!(dp.path_a.last(), Some(&t));
        assert_eq!(dp.path_b.first(), Some(&s));
        assert_eq!(dp.path_b.last(), Some(&t));
    }
    // Two disjoint two-hop routes of weight 2 each: the only pair costs 4.
    #[test]
    fn two_parallel_paths_diamond() {
        let edges = [(0, 1, 1.0), (1, 3, 1.0), (0, 2, 1.0), (2, 3, 1.0)];
        let g = wgraph(4, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 3).expect("ok");
        assert_vertex_disjoint(&dp, 0, 3);
        assert!((dp.total_cost - 4.0).abs() < 1e-9, "cost={}", dp.total_cost);
    }
    // Three disjoint routes plus a 1->2 chord; the result must match the oracle.
    #[test]
    fn min_total_cost_matches_mcf_oracle() {
        let edges = [
            (0, 1, 1.0),
            (1, 4, 1.0),
            (0, 2, 2.0),
            (2, 4, 2.0),
            (0, 3, 3.0),
            (3, 4, 3.0),
            (1, 2, 1.0),
        ];
        let g = wgraph(5, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 4).expect("ok");
        assert_vertex_disjoint(&dp, 0, 4);
        let oracle = mcf_vertex_disjoint_cost(5, &edges, 0, 4).expect("oracle has 2 paths");
        assert!(
            (dp.total_cost - oracle).abs() < 1e-6,
            "suurballe={} oracle={}",
            dp.total_cost,
            oracle
        );
    }
    // Nine-vertex layered DAG from 0 to 8; the result must match the oracle.
    #[test]
    fn matches_oracle_on_grid() {
        let edges = [
            (0, 1, 2.0),
            (0, 2, 1.0),
            (1, 3, 1.0),
            (1, 4, 3.0),
            (2, 4, 1.0),
            (2, 5, 2.0),
            (3, 6, 2.0),
            (4, 6, 1.0),
            (4, 7, 2.0),
            (5, 7, 1.0),
            (6, 8, 1.0),
            (7, 8, 2.0),
        ];
        let n = 9;
        let g = wgraph(n, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 8).expect("ok");
        assert_vertex_disjoint(&dp, 0, 8);
        let oracle = mcf_vertex_disjoint_cost(n, &edges, 0, 8).expect("oracle");
        assert!(
            (dp.total_cost - oracle).abs() < 1e-6,
            "suurballe={} oracle={}",
            dp.total_cost,
            oracle
        );
    }
    // Every 0->2 path passes through vertex 1, so no second disjoint path exists.
    #[test]
    fn fails_when_only_a_bridge_connects() {
        let edges = [(0, 1, 1.0), (1, 2, 1.0)];
        let g = wgraph(3, &edges);
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 0, 2),
            Err(GraphalgError::NoSolution(_))
        ));
    }
    // Vertex 2 has no incoming edge at all.
    #[test]
    fn fails_when_target_unreachable() {
        let edges = [(0, 1, 1.0)];
        let g = wgraph(3, &edges);
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 0, 2),
            Err(GraphalgError::NoSolution(_))
        ));
    }
    // A single 0->1 edge has capacity for only one path.
    #[test]
    fn fails_with_single_direct_edge_only() {
        let edges = [(0, 1, 5.0)];
        let g = wgraph(2, &edges);
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 0, 1),
            Err(GraphalgError::NoSolution(_))
        ));
    }
    // The shortest path 0-1-2-5 is not part of the optimal pair. The second
    // augmentation must cancel the 1->2 edge, yielding 0-1-4-5 and 0-3-2-5.
    #[test]
    fn two_disjoint_with_a_shared_cut_attempt() {
        let edges = [
            (0, 1, 1.0),
            (1, 2, 1.0),
            (2, 5, 1.0),
            (0, 3, 2.0),
            (3, 2, 1.0),
            (1, 4, 2.0),
            (4, 5, 1.0),
        ];
        let n = 6;
        let g = wgraph(n, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 5).expect("ok");
        assert_vertex_disjoint(&dp, 0, 5);
        let oracle = mcf_vertex_disjoint_cost(n, &edges, 0, 5).expect("oracle");
        assert!(
            (dp.total_cost - oracle).abs() < 1e-6,
            "suurballe={} oracle={}",
            dp.total_cost,
            oracle
        );
    }
    // Potentials from `dijkstra_potentials` must make every reduced cost between
    // reachable vertices non-negative (checked within 1e-9).
    #[test]
    fn reduced_costs_are_nonnegative() {
        let edges = [
            (0, 1, 4.0),
            (0, 2, 1.0),
            (2, 1, 1.0),
            (1, 3, 1.0),
            (2, 3, 5.0),
        ];
        let g = wgraph(4, &edges);
        let d = dijkstra_potentials(&g, 0).expect("ok");
        for u in 0..g.n {
            if d[u].is_infinite() {
                continue;
            }
            for &(v, w) in g.neighbors(u).expect("nb") {
                if d[v].is_infinite() {
                    continue;
                }
                let rc = w + d[u] - d[v];
                assert!(rc >= -1e-9, "reduced cost {rc} negative on {u}->{v}");
            }
        }
    }
    #[test]
    fn rejects_negative_weight() {
        let mut g = WeightedGraph::new(3);
        g.add_edge(0, 1, -2.0).expect("add");
        g.add_edge(1, 2, 1.0).expect("add");
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 0, 2),
            Err(GraphalgError::NegativeWeight { .. })
        ));
    }
    #[test]
    fn rejects_source_equals_target() {
        let g = wgraph(3, &[(0, 1, 1.0), (1, 2, 1.0)]);
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 1, 1),
            Err(GraphalgError::InvalidParameter(_))
        ));
    }
    // An out-of-range target is reported through `SourceOutOfRange` too.
    #[test]
    fn rejects_out_of_range() {
        let g = wgraph(3, &[(0, 1, 1.0)]);
        assert!(matches!(
            suurballe_vertex_disjoint(&g, 0, 9),
            Err(GraphalgError::SourceOutOfRange { .. })
        ));
    }
    // Checks that the cheaper of the two returned paths equals the single-source
    // shortest distance. This holds for this graph, whose optimal pair
    // {0-1-3, 0-2-3} contains the shortest path 0-1-3. It is not a general
    // property of a minimum-total-cost pair (see the shared-cut test above).
    // NOTE(review): assumes `dijkstra` returns a struct with a `dist` vector.
    #[test]
    fn k1_reduces_to_dijkstra_shortest_path() {
        use crate::shortest_path::dijkstra::dijkstra;
        let edges = [
            (0, 1, 1.0),
            (1, 3, 1.0),
            (0, 2, 5.0),
            (2, 3, 1.0),
            (0, 3, 9.0),
        ];
        let g = wgraph(4, &edges);
        let sp = dijkstra(&g, 0).expect("dij");
        let dp = suurballe_vertex_disjoint(&g, 0, 3).expect("ok");
        let ca = path_cost(&g, &dp.path_a).expect("ca");
        let cb = path_cost(&g, &dp.path_b).expect("cb");
        let cheaper = ca.min(cb);
        assert!(
            (cheaper - sp.dist[3]).abs() < 1e-9,
            "cheaper path {cheaper} != dijkstra {}",
            sp.dist[3]
        );
    }
    // Routes 0-1-3 (2 + 3) and 0-2-3 (4 + 1): total 10, equal to the per-path sum.
    #[test]
    fn total_cost_is_sum_of_path_costs() {
        let edges = [(0, 1, 2.0), (1, 3, 3.0), (0, 2, 4.0), (2, 3, 1.0)];
        let g = wgraph(4, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 3).expect("ok");
        let ca = path_cost(&g, &dp.path_a).expect("ca");
        let cb = path_cost(&g, &dp.path_b).expect("cb");
        assert!((dp.total_cost - (ca + cb)).abs() < 1e-12);
        assert!(
            (dp.total_cost - 10.0).abs() < 1e-9,
            "cost={}",
            dp.total_cost
        );
    }
    // Three disjoint routes costing 2, 4 and 6 (vertices 4-6 are isolated);
    // the cheapest two total 6.
    #[test]
    fn three_disjoint_available_picks_cheapest_two() {
        let edges = [
            (0, 1, 1.0),
            (1, 7, 1.0),
            (0, 2, 2.0),
            (2, 7, 2.0),
            (0, 3, 3.0),
            (3, 7, 3.0),
        ];
        let n = 8;
        let g = wgraph(n, &edges);
        let dp = suurballe_vertex_disjoint(&g, 0, 7).expect("ok");
        assert_vertex_disjoint(&dp, 0, 7);
        assert!((dp.total_cost - 6.0).abs() < 1e-9, "cost={}", dp.total_cost);
        let oracle = mcf_vertex_disjoint_cost(n, &edges, 0, 7).expect("oracle");
        assert!((dp.total_cost - oracle).abs() < 1e-6);
    }
}