// competitive_programming_rs/graph/cost_scaling_push_relabel.rs

/// Minimum-cost flow solver based on cost-scaling push-relabel.
///
/// Edge costs are multiplied by `2 * num_nodes` and node potentials are kept as
/// `f64`. Each refinement phase first saturates every edge with negative reduced
/// cost (and empties every edge with positive reduced cost), then discharges the
/// nodes with positive excess in FIFO order using push/relabel. `eps` is halved
/// between phases until it reaches 1.
pub mod cost_scaling_push_relabel {
    use std::collections::VecDeque;

    type Flow = i64;
    type Cost = i64;

    /// Divisor applied to `eps` after every refinement phase.
    const SCALING_FACTOR: f64 = 2.0;

    #[derive(Clone)]
    struct Edge {
        to: usize,
        /// Index of the paired edge inside `graph[to]`.
        rev: usize,
        flow: Flow,
        capacity: Flow,
        cost: Cost,
        /// `true` for the backward half created by `add_edge`. Backward edges start
        /// with `flow == capacity`, i.e. with zero residual capacity.
        is_rev: bool,
    }

    impl Edge {
        /// Capacity still available on this residual edge.
        fn residual(&self) -> Flow {
            self.capacity - self.flow
        }
    }

    #[derive(Clone)]
    struct Node {
        /// Supply assigned by `solve`, plus inflow, minus outflow.
        excess_flow: Flow,
        potential: f64,
    }

    /// Cost-scaling push-relabel minimum-cost-flow solver.
    pub struct Solver {
        nodes: Vec<Node>,
        graph: Vec<Vec<Edge>>,
        /// FIFO queue of nodes with positive excess waiting to be discharged.
        active_nodes: VecDeque<usize>,

        /// Multiplier applied to every edge cost (`2 * num_nodes`).
        cost_scaling_factor: f64,
        /// Current optimality tolerance, in scaled cost units.
        eps: f64,
    }

    impl Solver {
        /// Creates a solver for a graph with nodes `0..num_nodes` and no edges.
        pub fn new(num_nodes: usize) -> Self {
            Self {
                nodes: vec![
                    Node {
                        excess_flow: 0,
                        potential: 0.0,
                    };
                    num_nodes
                ],
                graph: vec![vec![]; num_nodes],
                active_nodes: VecDeque::new(),
                eps: 1.0,
                // The last phase runs with 1 < eps <= 2. Scaling costs by 2n makes
                // that tolerance fine enough to yield an optimal flow for integer costs.
                cost_scaling_factor: num_nodes as f64 * 2.0,
            }
        }

        /// Adds a directed edge `from -> to` with the given capacity and per-unit
        /// cost, together with its backward residual edge.
        ///
        /// Self-loops (`from == to`) are supported.
        ///
        /// # Panics
        ///
        /// Panics if `from` or `to` is not a valid node index.
        pub fn add_edge(&mut self, from: usize, to: usize, capacity: Flow, cost: Cost) {
            let forward_index = self.graph[from].len();
            // For a self-loop both halves land in the same adjacency list, so the
            // backward edge sits one slot after the forward edge.
            let backward_index = self.graph[to].len() + usize::from(from == to);

            self.graph[from].push(Edge {
                to,
                rev: backward_index,
                flow: 0,
                capacity,
                cost,
                is_rev: false,
            });
            self.graph[to].push(Edge {
                to: from,
                rev: forward_index,
                flow: capacity,
                capacity,
                cost: -cost,
                is_rev: true,
            });

            self.eps = self.eps.max(cost.abs() as f64 * self.cost_scaling_factor);
        }

        /// Sends `flow` units from `source` to `sink` and returns the total cost,
        /// i.e. the sum of `flow * cost` over the edges added with `add_edge`.
        ///
        /// The caller must ensure that `flow` is feasible. Infeasibility is not
        /// detected, and the discharge loop may then panic or fail to terminate.
        ///
        /// The solver keeps its flows, potentials and `eps` after returning, so it
        /// is meant to be solved once. If every edge has cost 0, no phase runs and
        /// `0` is returned without routing any flow.
        ///
        /// # Panics
        ///
        /// Panics if `source` or `sink` is out of range, or if a node with positive
        /// excess has no residual edge to another node.
        pub fn solve(&mut self, source: usize, sink: usize, flow: Flow) -> Cost {
            self.nodes[source].excess_flow = flow;
            self.nodes[sink].excess_flow = -flow;

            while self.eps > 1.0 {
                self.saturate_negative_edges();
                self.discharge_active_nodes();
                self.eps = (self.eps / SCALING_FACTOR).max(1.0);
            }

            self.graph
                .iter()
                .flatten()
                .filter(|e| !e.is_rev)
                .map(|e| e.flow * e.cost)
                .sum()
        }

        /// Starts a refinement phase. Every forward edge with negative reduced cost
        /// is saturated, and all flow is removed from every forward edge with
        /// positive reduced cost. Afterwards no residual edge has negative reduced
        /// cost.
        fn saturate_negative_edges(&mut self) {
            for node in 0..self.nodes.len() {
                for edge in 0..self.graph[node].len() {
                    if self.graph[node][edge].is_rev {
                        continue;
                    }

                    let reduced_cost = self.calc_reduced_cost(node, edge);
                    if reduced_cost < 0.0 && self.graph[node][edge].residual() > 0 {
                        let amount = self.graph[node][edge].residual();
                        self.push_flow(node, edge, amount);
                    }
                    if reduced_cost > 0.0 && self.graph[node][edge].flow > 0 {
                        let amount = -self.graph[node][edge].flow;
                        self.push_flow(node, edge, amount);
                    }
                }
            }
        }

        /// Discharges every node with positive excess in FIFO order. A node with no
        /// admissible edge is relabeled and re-queued.
        fn discharge_active_nodes(&mut self) {
            self.enqueue_active_nodes();
            while let Some(node) = self.active_nodes.pop_front() {
                while self.nodes[node].excess_flow > 0 {
                    if !self.push(node) {
                        self.relabel(node);
                        self.active_nodes.push_back(node);
                        break;
                    }
                }
            }
        }

        /// Moves `flow` units along `graph[node][edge]`. A negative amount cancels
        /// flow. The paired edge and both endpoints' excesses are updated to match.
        fn push_flow(&mut self, node: usize, edge: usize, flow: Flow) {
            self.graph[node][edge].flow += flow;

            let to = self.graph[node][edge].to;
            let rev = self.graph[node][edge].rev;

            self.graph[to][rev].flow -= flow;
            self.nodes[node].excess_flow -= flow;
            self.nodes[to].excess_flow += flow;
        }

        /// Reduced cost of `graph[node][edge]` in scaled units:
        /// `cost * scale - p(from) + p(to)`. The edge is admissible when this is
        /// negative.
        fn calc_reduced_cost(&self, node: usize, edge: usize) -> f64 {
            let e = &self.graph[node][edge];
            e.cost as f64 * self.cost_scaling_factor - self.nodes[node].potential
                + self.nodes[e.to].potential
        }

        /// Appends every node with positive excess to the active queue.
        fn enqueue_active_nodes(&mut self) {
            for u in 0..self.nodes.len() {
                if self.nodes[u].excess_flow > 0 {
                    self.active_nodes.push_back(u);
                }
            }
        }

        /// Pushes along the first admissible residual edge of `from`, scanning edges
        /// from last to first. Returns `false` when `from` has no excess or no
        /// admissible edge, i.e. when it needs a relabel.
        fn push(&mut self, from: usize) -> bool {
            if self.nodes[from].excess_flow == 0 {
                return false;
            }
            assert!(self.nodes[from].excess_flow > 0);
            for i in (0..self.graph[from].len()).rev() {
                if self.graph[from][i].residual() == 0 {
                    continue;
                }
                if self.calc_reduced_cost(from, i) < 0.0 {
                    let amount = self.graph[from][i]
                        .residual()
                        .min(self.nodes[from].excess_flow);
                    self.push_flow(from, i, amount);

                    let to = self.graph[from][i].to;
                    // Enqueue `to` only when this push turned its excess positive;
                    // otherwise it is already queued or still has a deficit.
                    let excess = self.nodes[to].excess_flow;
                    if excess > 0 && excess <= amount {
                        self.active_nodes.push_back(to);
                    }
                    return true;
                }
            }
            false
        }

        /// Raises the potential of `from` so its cheapest residual edge becomes
        /// admissible. The new value is the minimum of `cost * scale + p(to) + eps`
        /// over residual edges.
        ///
        /// Self-loops are ignored because their reduced cost does not depend on the
        /// potential. The minimum starts at `f64::INFINITY` because scaled
        /// potentials can exceed any fixed finite sentinel.
        fn relabel(&mut self, from: usize) {
            let scale = self.cost_scaling_factor;
            let eps = self.eps;
            let nodes = &self.nodes;
            let min_potential = self.graph[from]
                .iter()
                .filter(|e| e.to != from && e.residual() > 0)
                .map(|e| e.cost as f64 * scale + nodes[e.to].potential + eps)
                .fold(f64::INFINITY, f64::min);

            assert!(
                min_potential.is_finite(),
                "relabel: active node has no residual edge (flow is probably infeasible)"
            );
            self.nodes[from].potential = min_potential;
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::graph::cost_scaling_push_relabel::cost_scaling_push_relabel;
    use crate::graph::min_cost_flow::primal_dual;
    use crate::utils::test_helper::Tester;

    /// Cross-checks the cost-scaling solver against the primal-dual solver on the
    /// GRL_6_B fixtures. The primal-dual answer (or -1 when it reports no
    /// solution) is written as the expected output.
    #[test]
    fn solve_grl_6_b() {
        let tester = Tester::new("./assets/GRL_6_B/in/", "./assets/GRL_6_B/out/");
        tester.test_solution(|sc| {
            let num_vertices: usize = sc.read();
            let num_edges: usize = sc.read();
            let required_flow: i64 = sc.read();

            let mut solver = cost_scaling_push_relabel::Solver::new(num_vertices);
            let mut reference = primal_dual::MinimumCostFlowSolver::new(num_vertices);
            for _ in 0..num_edges {
                let from: usize = sc.read();
                let to: usize = sc.read();
                let capacity: i64 = sc.read();
                let cost: i64 = sc.read();
                solver.add_edge(from, to, capacity, cost);
                reference.add_edge(from, to, capacity, cost);
            }

            let sink = num_vertices - 1;
            if let Some(expected) = reference.solve(0, sink, required_flow) {
                sc.write(format!("{}\n", expected));
                assert_eq!(expected, solver.solve(0, sink, required_flow));
            } else {
                sc.write("-1\n");
            }
        });
    }
}