Skip to main content

proof_engine/graph/
flow.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2use glam::{Vec2, Vec4};
3use super::graph_core::{Graph, GraphKind, NodeId, EdgeId};
4
5#[derive(Debug, Clone)]
6pub struct FlowNetwork {
7    pub graph: Graph<(), f32>,
8    pub source: NodeId,
9    pub sink: NodeId,
10}
11
12impl FlowNetwork {
13    /// Create a flow network. Edge data stores capacity.
14    pub fn new(source: NodeId, sink: NodeId, graph: Graph<(), f32>) -> Self {
15        Self { graph, source, sink }
16    }
17
18    /// Build a flow network from scratch.
19    pub fn builder() -> FlowNetworkBuilder {
20        FlowNetworkBuilder {
21            graph: Graph::new(GraphKind::Directed),
22            source: None,
23            sink: None,
24        }
25    }
26
27    pub fn capacity(&self, edge: EdgeId) -> f32 {
28        self.graph.get_edge(edge).map(|e| e.data).unwrap_or(0.0)
29    }
30}
31
32pub struct FlowNetworkBuilder {
33    graph: Graph<(), f32>,
34    source: Option<NodeId>,
35    sink: Option<NodeId>,
36}
37
38impl FlowNetworkBuilder {
39    pub fn add_node(&mut self) -> NodeId {
40        self.graph.add_node(())
41    }
42
43    pub fn set_source(&mut self, id: NodeId) {
44        self.source = Some(id);
45    }
46
47    pub fn set_sink(&mut self, id: NodeId) {
48        self.sink = Some(id);
49    }
50
51    pub fn add_capacity(&mut self, from: NodeId, to: NodeId, capacity: f32) -> EdgeId {
52        self.graph.add_edge(from, to, capacity)
53    }
54
55    pub fn build(self) -> FlowNetwork {
56        FlowNetwork {
57            graph: self.graph,
58            source: self.source.expect("source must be set"),
59            sink: self.sink.expect("sink must be set"),
60        }
61    }
62}
63
64#[derive(Debug, Clone)]
65pub struct FlowResult {
66    pub max_flow: f32,
67    pub edge_flows: HashMap<EdgeId, f32>,
68}
69
70/// Ford-Fulkerson max flow using BFS (Edmonds-Karp).
71pub fn ford_fulkerson(network: &FlowNetwork, source: NodeId, sink: NodeId) -> FlowResult {
72    let node_ids = network.graph.node_ids();
73    let edge_ids = network.graph.edge_ids();
74
75    // Build residual capacity structure
76    // For each edge, track forward capacity and flow
77    let mut flow: HashMap<EdgeId, f32> = HashMap::new();
78    for &eid in &edge_ids {
79        flow.insert(eid, 0.0);
80    }
81
82    // Build adjacency with edge info for residual graph traversal
83    // We need both forward and backward edges
84    // residual_adj[node] = Vec<(neighbor, edge_id, is_forward)>
85    let mut residual_adj: HashMap<NodeId, Vec<(NodeId, EdgeId, bool)>> = HashMap::new();
86    for &nid in &node_ids {
87        residual_adj.insert(nid, Vec::new());
88    }
89    for &eid in &edge_ids {
90        if let Some(edge) = network.graph.get_edge(eid) {
91            residual_adj.get_mut(&edge.from).unwrap().push((edge.to, eid, true));
92            // Backward edge
93            if !residual_adj.contains_key(&edge.to) {
94                residual_adj.insert(edge.to, Vec::new());
95            }
96            residual_adj.get_mut(&edge.to).unwrap().push((edge.from, eid, false));
97        }
98    }
99
100    let mut total_flow = 0.0f32;
101
102    // BFS to find augmenting path
103    loop {
104        // BFS
105        let mut visited: HashMap<NodeId, (NodeId, EdgeId, bool)> = HashMap::new();
106        let mut queue = VecDeque::new();
107        queue.push_back(source);
108        let mut found_sink = false;
109
110        while let Some(node) = queue.pop_front() {
111            if node == sink {
112                found_sink = true;
113                break;
114            }
115            for &(nbr, eid, is_forward) in residual_adj.get(&node).unwrap_or(&Vec::new()) {
116                if visited.contains_key(&nbr) || nbr == source {
117                    continue;
118                }
119                let residual = if is_forward {
120                    let cap = network.graph.get_edge(eid).map(|e| e.data).unwrap_or(0.0);
121                    cap - flow.get(&eid).copied().unwrap_or(0.0)
122                } else {
123                    flow.get(&eid).copied().unwrap_or(0.0)
124                };
125                if residual > 0.0 {
126                    visited.insert(nbr, (node, eid, is_forward));
127                    queue.push_back(nbr);
128                }
129            }
130        }
131
132        if !found_sink { break; }
133
134        // Find bottleneck
135        let mut bottleneck = f32::INFINITY;
136        let mut current = sink;
137        while current != source {
138            let (prev, eid, is_forward) = visited[&current];
139            let residual = if is_forward {
140                let cap = network.graph.get_edge(eid).map(|e| e.data).unwrap_or(0.0);
141                cap - flow.get(&eid).copied().unwrap_or(0.0)
142            } else {
143                flow.get(&eid).copied().unwrap_or(0.0)
144            };
145            bottleneck = bottleneck.min(residual);
146            current = prev;
147        }
148
149        // Update flows
150        current = sink;
151        while current != source {
152            let (prev, eid, is_forward) = visited[&current];
153            if is_forward {
154                *flow.get_mut(&eid).unwrap() += bottleneck;
155            } else {
156                *flow.get_mut(&eid).unwrap() -= bottleneck;
157            }
158            current = prev;
159        }
160
161        total_flow += bottleneck;
162    }
163
164    FlowResult {
165        max_flow: total_flow,
166        edge_flows: flow,
167    }
168}
169
170impl FlowResult {
171    /// Extract min-cut: returns (S, T) partition where S contains the source.
172    /// S = nodes reachable from source in the residual graph after max flow.
173    pub fn min_cut<N, E>(&self, network: &FlowNetwork) -> (HashSet<NodeId>, HashSet<NodeId>) {
174        let node_ids = network.graph.node_ids();
175
176        // BFS from source in residual graph
177        let mut reachable = HashSet::new();
178        let mut queue = VecDeque::new();
179        reachable.insert(network.source);
180        queue.push_back(network.source);
181
182        while let Some(node) = queue.pop_front() {
183            for (nbr, eid) in network.graph.neighbor_edges(node) {
184                let cap = network.graph.get_edge(eid).map(|e| e.data).unwrap_or(0.0);
185                let f = self.edge_flows.get(&eid).copied().unwrap_or(0.0);
186                if cap - f > 1e-9 && !reachable.contains(&nbr) {
187                    reachable.insert(nbr);
188                    queue.push_back(nbr);
189                }
190            }
191        }
192
193        let s_set = reachable;
194        let t_set: HashSet<NodeId> = node_ids.into_iter().filter(|n| !s_set.contains(n)).collect();
195        (s_set, t_set)
196    }
197}
198
199/// Visualizes flow as particle speed along edges.
200pub struct FlowVisualizer {
201    pub base_speed: f32,
202    pub max_speed: f32,
203    pub particle_color: Vec4,
204}
205
206impl FlowVisualizer {
207    pub fn new() -> Self {
208        Self {
209            base_speed: 1.0,
210            max_speed: 10.0,
211            particle_color: Vec4::new(0.3, 0.6, 1.0, 1.0),
212        }
213    }
214
215    /// Generate particle data for each edge based on flow.
216    /// Returns Vec of (edge_id, start, end, speed, color).
217    pub fn generate_particles(&self, network: &FlowNetwork, result: &FlowResult) -> Vec<FlowParticle> {
218        let mut particles = Vec::new();
219        let max_flow = result.max_flow.max(1e-6);
220
221        for (&eid, &flow) in &result.edge_flows {
222            if flow <= 0.0 { continue; }
223            if let Some(edge) = network.graph.get_edge(eid) {
224                let start = network.graph.node_position(edge.from);
225                let end = network.graph.node_position(edge.to);
226                let ratio = flow / max_flow;
227                let speed = self.base_speed + ratio * (self.max_speed - self.base_speed);
228                let alpha = 0.3 + 0.7 * ratio;
229                let color = Vec4::new(
230                    self.particle_color.x,
231                    self.particle_color.y,
232                    self.particle_color.z,
233                    alpha,
234                );
235                particles.push(FlowParticle {
236                    edge_id: eid,
237                    start,
238                    end,
239                    speed,
240                    color,
241                    flow,
242                });
243            }
244        }
245        particles
246    }
247}
248
249#[derive(Debug, Clone)]
250pub struct FlowParticle {
251    pub edge_id: EdgeId,
252    pub start: Vec2,
253    pub end: Vec2,
254    pub speed: f32,
255    pub color: Vec4,
256    pub flow: f32,
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    const TOLERANCE: f32 = 0.01;

    // Five-edge fixture:
    //   s -> a (10)    a -> t (8)
    //   s -> b (5)     b -> t (7)
    //   a -> b (3)
    // Both edges into t can be saturated, so the max flow is 8 + 7 = 15.
    fn simple_network() -> FlowNetwork {
        let mut b = FlowNetwork::builder();
        let s = b.add_node();
        let a = b.add_node();
        let bb = b.add_node();
        let t = b.add_node();
        b.set_source(s);
        b.set_sink(t);
        b.add_capacity(s, a, 10.0);
        b.add_capacity(s, bb, 5.0);
        b.add_capacity(a, t, 8.0);
        b.add_capacity(bb, t, 7.0);
        b.add_capacity(a, bb, 3.0);
        b.build()
    }

    /// Asserts `result` is a feasible flow on `net`.
    ///
    /// Checks that every edge flow lies in `[0, capacity]`. Also checks that net
    /// outflow is `max_flow` at the source, `-max_flow` at the sink, and zero at
    /// every other node.
    fn assert_feasible(net: &FlowNetwork, result: &FlowResult) {
        let mut balance = HashMap::new();
        for (&eid, &f) in &result.edge_flows {
            assert!(f >= -TOLERANCE, "edge flow must be non-negative");
            assert!(f <= net.capacity(eid) + TOLERANCE, "edge flow must not exceed capacity");
            let edge = net
                .graph
                .get_edge(eid)
                .expect("flow reported for an edge not in the graph");
            *balance.entry(edge.from).or_insert(0.0f32) += f;
            *balance.entry(edge.to).or_insert(0.0f32) -= f;
        }
        for (&node, &out) in &balance {
            let expected = if node == net.source {
                result.max_flow
            } else if node == net.sink {
                -result.max_flow
            } else {
                0.0
            };
            assert!((out - expected).abs() < TOLERANCE, "flow must be conserved");
        }
    }

    #[test]
    fn test_max_flow_simple() {
        let net = simple_network();
        let result = ford_fulkerson(&net, net.source, net.sink);
        // a receives up to 10 but can only forward 8 to t; 2 of the surplus
        // goes a -> b, which together with s -> b (5) saturates b -> t (7).
        assert!((result.max_flow - 15.0).abs() < TOLERANCE);
        assert_feasible(&net, &result);
    }

    #[test]
    fn test_min_cut_capacity_equals_max_flow() {
        let net = simple_network();
        let result = ford_fulkerson(&net, net.source, net.sink);
        let (s_set, t_set) = result.min_cut::<(), f32>(&net);

        // Max-flow/min-cut: edges crossing from S to T sum to the max flow.
        let mut cut_capacity = 0.0f32;
        let edge_ids = net.graph.edge_ids();
        for &eid in &edge_ids {
            if let Some(edge) = net.graph.get_edge(eid) {
                if s_set.contains(&edge.from) && t_set.contains(&edge.to) {
                    cut_capacity += edge.data;
                }
            }
        }
        assert!((cut_capacity - result.max_flow).abs() < TOLERANCE);
    }

    #[test]
    fn test_max_flow_single_edge() {
        let mut b = FlowNetwork::builder();
        let s = b.add_node();
        let t = b.add_node();
        b.set_source(s);
        b.set_sink(t);
        b.add_capacity(s, t, 42.0);
        let net = b.build();
        let result = ford_fulkerson(&net, net.source, net.sink);
        assert!((result.max_flow - 42.0).abs() < TOLERANCE);
        assert_feasible(&net, &result);
    }

    #[test]
    fn test_max_flow_no_path() {
        let mut b = FlowNetwork::builder();
        let s = b.add_node();
        let t = b.add_node();
        b.set_source(s);
        b.set_sink(t);
        // No edge
        let net = b.build();
        let result = ford_fulkerson(&net, net.source, net.sink);
        assert_eq!(result.max_flow, 0.0);
    }

    #[test]
    fn test_min_cut() {
        let mut b = FlowNetwork::builder();
        let s = b.add_node();
        let a = b.add_node();
        let t = b.add_node();
        b.set_source(s);
        b.set_sink(t);
        b.add_capacity(s, a, 5.0);
        b.add_capacity(a, t, 3.0);
        let net = b.build();
        let result = ford_fulkerson(&net, net.source, net.sink);
        assert!((result.max_flow - 3.0).abs() < TOLERANCE);
        let (s_set, t_set) = result.min_cut::<(), f32>(&net);
        assert!(s_set.contains(&s));
        // s -> a keeps 2 units of spare capacity, so a stays on the source side.
        assert!(s_set.contains(&a));
        assert!(t_set.contains(&t));
        assert!(s_set.is_disjoint(&t_set));
        assert_eq!(s_set.len() + t_set.len(), 3);
    }

    #[test]
    fn test_flow_visualizer() {
        let mut b = FlowNetwork::builder();
        let s = b.add_node();
        let t = b.add_node();
        b.set_source(s);
        b.set_sink(t);
        b.add_capacity(s, t, 10.0);
        let net = b.build();
        let result = ford_fulkerson(&net, net.source, net.sink);
        let viz = FlowVisualizer::new();
        let particles = viz.generate_particles(&net, &result);
        assert_eq!(particles.len(), 1);
        assert!(particles[0].speed > 0.0);
    }

    #[test]
    fn test_parallel_paths() {
        let mut b = FlowNetwork::builder();
        let s = b.add_node();
        let t = b.add_node();
        b.set_source(s);
        b.set_sink(t);
        b.add_capacity(s, t, 5.0);
        b.add_capacity(s, t, 3.0);
        let net = b.build();
        let result = ford_fulkerson(&net, net.source, net.sink);
        assert!((result.max_flow - 8.0).abs() < TOLERANCE);
        assert_feasible(&net, &result);
    }
}