use super::common::{GraphView, NodeId};
use std::collections::{HashMap, VecDeque};
/// Outcome of a maximum-flow computation.
///
/// Derives `Debug`, `Clone` and `PartialEq` so results can be logged,
/// duplicated and compared directly in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct FlowResult {
    /// Total amount of flow pushed from the source to the sink.
    pub max_flow: f64,
}
/// Computes the maximum flow from `source` to `sink` with the Edmonds–Karp
/// algorithm (repeated breadth-first search for augmenting paths in the
/// residual graph).
///
/// Capacities are taken from `view.weights(u)` when weights are present;
/// otherwise every edge has capacity `1.0`. Parallel edges between the same
/// pair of nodes have their capacities summed.
///
/// # Returns
///
/// * `None` if `source` or `sink` is not present in `view.node_to_index`.
/// * `Some(FlowResult { max_flow: 0.0 })` if `source` and `sink` are the same
///   node. No flow has to cross any edge in that case; previously this input
///   made the function loop forever.
/// * `Some(FlowResult)` holding the total flow otherwise.
///
/// # Panics
///
/// Panics if a weight list is shorter than the matching successor list.
pub fn edmonds_karp(view: &GraphView, source: NodeId, sink: NodeId) -> Option<FlowResult> {
    let s_idx = *view.node_to_index.get(&source)?;
    let t_idx = *view.node_to_index.get(&sink)?;

    // Without this guard the search would report the empty path (bottleneck
    // INFINITY) on every iteration, never modify the residual graph, and
    // therefore never terminate.
    if s_idx == t_idx {
        return Some(FlowResult { max_flow: 0.0 });
    }

    let n = view.node_count;

    // residual[u][v] is the remaining capacity on the arc u -> v. Every
    // forward arc gets a matching reverse entry (initially 0.0) so that flow
    // can be cancelled later.
    let mut residual: Vec<HashMap<usize, f64>> = vec![HashMap::new(); n];
    for u in 0..n {
        let edges = view.successors(u);
        let weights = view.weights(u);
        for (i, &v) in edges.iter().enumerate() {
            let cap = if let Some(w) = weights { w[i] } else { 1.0 };
            *residual[u].entry(v).or_insert(0.0) += cap;
            residual[v].entry(u).or_insert(0.0);
        }
    }

    let mut total_flow = 0.0;
    // Scratch buffers are allocated once and reused by every search.
    let mut parent: Vec<Option<usize>> = vec![None; n];
    let mut queue: VecDeque<usize> = VecDeque::with_capacity(n);

    while find_augmenting_path(&residual, s_idx, t_idx, &mut parent, &mut queue) {
        // First walk (sink -> source): find the bottleneck capacity.
        let mut path_flow = f64::INFINITY;
        let mut curr = t_idx;
        while curr != s_idx {
            let prev = parent[curr].expect("every node on the found path has a parent");
            path_flow = path_flow.min(residual[prev][&curr]);
            curr = prev;
        }

        // Second walk: push the bottleneck along the path and credit the
        // reverse arcs.
        curr = t_idx;
        while curr != s_idx {
            let prev = parent[curr].expect("every node on the found path has a parent");
            *residual[prev]
                .get_mut(&curr)
                .expect("forward arc exists because the search traversed it") -= path_flow;
            *residual[curr]
                .get_mut(&prev)
                .expect("reverse arc is created together with the forward arc") += path_flow;
            curr = prev;
        }

        total_flow += path_flow;
    }

    Some(FlowResult { max_flow: total_flow })
}

/// Breadth-first search for a path from `s_idx` to `t_idx` that only uses
/// arcs with usable residual capacity.
///
/// On success returns `true` and leaves the path in `parent` (follow
/// `parent[v]` from `t_idx` back to `s_idx`; `parent[s_idx]` stays `None`).
/// `parent` and `queue` are scratch buffers that are reset on entry.
/// Callers must ensure `s_idx != t_idx`.
fn find_augmenting_path(
    residual: &[HashMap<usize, f64>],
    s_idx: usize,
    t_idx: usize,
    parent: &mut [Option<usize>],
    queue: &mut VecDeque<usize>,
) -> bool {
    // Residual capacities at or below this value are treated as saturated so
    // that floating-point residue does not produce useless augmentations.
    const RESIDUAL_EPSILON: f64 = 1e-9;

    parent.fill(None);
    queue.clear();
    queue.push_back(s_idx);

    while let Some(u) = queue.pop_front() {
        for (&v, &cap) in &residual[u] {
            // A node counts as visited if it is the source or already has a
            // parent recorded.
            if v != s_idx && parent[v].is_none() && cap > RESIDUAL_EPSILON {
                parent[v] = Some(u);
                if v == t_idx {
                    return true;
                }
                queue.push_back(v);
            }
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Four-node network (ids 1..=4 at indices 0..=3) with five weighted
    /// arcs; the flow from id 1 to id 4 is expected to be 150.
    #[test]
    fn test_edmonds_karp() {
        // (from index, to index, capacity)
        let arcs = [
            (0, 1, 100.0),
            (0, 2, 50.0),
            (1, 2, 50.0),
            (1, 3, 50.0),
            (2, 3, 100.0),
        ];

        let ids = vec![1, 2, 3, 4];
        let lookup: HashMap<_, _> = ids
            .iter()
            .enumerate()
            .map(|(pos, &id)| (id, pos))
            .collect();

        let mut successors = vec![Vec::new(); 4];
        let mut capacities = vec![Vec::new(); 4];
        for &(from, to, cap) in &arcs {
            successors[from].push(to);
            capacities[from].push(cap);
        }

        let graph = GraphView::from_adjacency_list(
            4,
            ids,
            lookup,
            successors,
            vec![vec![]; 4],
            Some(capacities),
        );

        let flow = edmonds_karp(&graph, 1, 4).unwrap();
        assert_eq!(flow.max_flow, 150.0);
    }
}