use super::{max_independent_set, min_vertex_cover};
use crate::{
    algorithm::{
        dynamic_programming::utils::remap_vertices,
        nice_tree_decomposition::{get_children, NiceTdNodeType, NiceTreeDecomposition},
    },
    utils::convert::{to_hash_map_graph, UndirectedGraph},
};
use arboretum_td::{
    graph::{BaseGraph, HashMapGraph},
    solver::Solver,
    tree_decomposition::TreeDecomposition,
};
use bitvec::vec::BitVec;
use fxhash::FxHashSet;
use std::collections::{HashMap, HashSet};
24
/// Dynamic-programming table of a single node of the nice tree decomposition.
///
/// Keys are the bit vectors that [`DpTableEntry::children`] refers to as "subsets"
/// (presumably a subset of the node's bag, encoded by the problem-specific handlers — confirm
/// there); values hold the DP value plus the back-pointers used to rebuild a solution.
pub type DpTable = HashMap<BitVec, DpTableEntry>;
30
/// A single entry of a [`DpTable`].
///
/// Besides the DP value, an entry stores back-pointers into child tables so the solution can
/// be reconstructed by walking down from the root entry (see
/// `dp_read_solution_from_table_rec`).
#[derive(Debug, Clone)]
pub struct DpTableEntry {
    /// DP value of this entry; at the root, the best value according to the problem's
    /// [`DpObjective`] is selected.
    pub val: i32,
    /// Back-pointers `(child node id, key into that child's table)` that are followed during
    /// solution reconstruction.
    pub children: HashSet<(usize, BitVec)>,
    /// Vertex added to the solution when this entry is visited during reconstruction, if any.
    pub vertex_used: Option<usize>,
}
44
45impl DpTableEntry {
46 pub fn new_leaf(val: i32, vertex_used: Option<usize>) -> Self {
48 Self {
49 val,
50 children: HashSet::new(),
51 vertex_used,
52 }
53 }
54
55 pub fn new_forget(val: i32, child_id: usize, child_subset: BitVec) -> Self {
57 Self {
58 val,
59 children: vec![(child_id, child_subset)].into_iter().collect(),
60 vertex_used: None,
61 }
62 }
63
64 pub fn new_intro(
66 val: i32,
67 child_id: usize,
68 child_subset: BitVec,
69 vertex_used: Option<usize>,
70 ) -> Self {
71 Self {
72 val,
73 children: vec![(child_id, child_subset)].into_iter().collect(),
74 vertex_used,
75 }
76 }
77
78 pub fn new_join(val: i32, left_id: usize, right_id: usize, subset: BitVec) -> Self {
80 Self {
81 val,
82 children: vec![(left_id, subset.clone()), (right_id, subset)]
83 .into_iter()
84 .collect(),
85 vertex_used: None,
86 }
87 }
88}
89
/// Handler invoked for a leaf node `id` of the nice tree decomposition.
///
/// Receives the (remapped) graph, the node id, the DP tables of all nodes (indexed by node id)
/// and a vertex of the leaf's bag — `dp_solve_rec` passes the first vertex the bag yields.
/// Presumably populates `tables[id]`; solution reconstruction reads each node's entries from
/// `tables[node id]`.
type LeafNodeHandler = fn(graph: &HashMapGraph, id: usize, tables: &mut [DpTable], vertex: usize);

/// Handler invoked for a join node `id` after the tables of both children have been filled.
///
/// `vertex_set` is the bag of node `id`.
type JoinNodeHandler = fn(
    graph: &HashMapGraph,
    id: usize,
    left_child_id: usize,
    right_child_id: usize,
    tables: &mut [DpTable],
    vertex_set: &FxHashSet<usize>,
);

/// Handler invoked for a forget node `id` after its child's table has been filled.
///
/// `vertex_set` is the bag of node `id`; `forgotten_vertex` is the vertex carried by the
/// node's `NiceTdNodeType::Forget` tag.
type ForgetNodeHandler = fn(
    graph: &HashMapGraph,
    id: usize,
    child_id: usize,
    tables: &mut [DpTable],
    vertex_set: &FxHashSet<usize>,
    forgotten_vertex: usize,
);

/// Handler invoked for an introduce node `id` after its child's table has been filled.
///
/// `vertex_set` is the bag of node `id`, `child_vertex_set` the bag of `child_id`, and
/// `introduced_vertex` the vertex carried by the node's `NiceTdNodeType::Introduce` tag.
type IntroduceNodeHandler = fn(
    graph: &HashMapGraph,
    id: usize,
    child_id: usize,
    tables: &mut [DpTable],
    vertex_set: &FxHashSet<usize>,
    child_vertex_set: &FxHashSet<usize>,
    introduced_vertex: usize,
);
119
/// Direction of optimisation: decides which root table entry is taken as the answer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DpObjective {
    /// Take the root entry with the smallest `val`.
    Minimize,
    /// Take the root entry with the largest `val`.
    Maximize,
}
128
/// A dynamic program over a nice tree decomposition, described by its objective and one
/// handler per node kind.
pub struct DpProblem {
    /// Whether the answer is the root entry with the minimum or the maximum value.
    pub objective: DpObjective,
    /// Invoked for leaf nodes.
    pub handle_leaf_node: LeafNodeHandler,
    /// Invoked for join nodes once both children have been processed.
    pub handle_join_node: JoinNodeHandler,
    /// Invoked for forget nodes once the child has been processed.
    pub handle_forget_node: ForgetNodeHandler,
    /// Invoked for introduce nodes once the child has been processed.
    pub handle_introduce_node: IntroduceNodeHandler,
}
143
144impl DpProblem {
145 pub fn max_independent_set() -> DpProblem {
147 DpProblem {
148 objective: DpObjective::Maximize,
149 handle_leaf_node: max_independent_set::handle_leaf_node,
150 handle_join_node: max_independent_set::handle_join_node,
151 handle_forget_node: max_independent_set::handle_forget_node,
152 handle_introduce_node: max_independent_set::handle_introduce_node,
153 }
154 }
155
156 pub fn min_vertex_cover() -> DpProblem {
158 DpProblem {
159 objective: DpObjective::Minimize,
160 handle_leaf_node: min_vertex_cover::handle_leaf_node,
161 handle_join_node: min_vertex_cover::handle_join_node,
162 handle_forget_node: min_vertex_cover::handle_forget_node,
163 handle_introduce_node: min_vertex_cover::handle_introduce_node,
164 }
165 }
166}
167
168pub fn dp_solve(
177 graph: &UndirectedGraph,
178 td: Option<TreeDecomposition>,
179 prob: &DpProblem,
180) -> HashSet<usize> {
181 dp_solve_hashmap_graph(&to_hash_map_graph(graph), td, prob)
182}
183
184pub fn dp_solve_hashmap_graph(
186 graph: &HashMapGraph,
187 td: Option<TreeDecomposition>,
188 prob: &DpProblem,
189) -> HashSet<usize> {
190 let (graph, mapping) = remap_vertices(graph);
191 let td = td.unwrap_or_else(|| Solver::auto(&graph).solve(&graph));
192 let nice_td = NiceTreeDecomposition::new(td);
193
194 assert!(nice_td.td.verify(&graph).is_ok());
195
196 let mut tables: Vec<_> = vec![DpTable::new(); nice_td.td.bags().len()];
197 let root = nice_td.td.root.unwrap();
198
199 dp_solve_rec(
200 &nice_td.td,
201 &graph,
202 prob,
203 root,
204 usize::max_value(),
205 &nice_td.mapping,
206 &mut tables,
207 );
208
209 let mut sol = HashSet::new();
210 dp_read_solution_from_table(prob.objective, &tables, root, &mut sol);
211
212 sol.iter()
213 .map(|v| mapping.get(v).unwrap())
214 .copied()
215 .collect()
216}
217
218fn dp_solve_rec(
219 td: &TreeDecomposition,
220 graph: &HashMapGraph,
221 prob: &DpProblem,
222 id: usize,
223 parent_id: usize,
224 mapping: &[NiceTdNodeType],
225 tables: &mut Vec<DpTable>,
226) {
227 let children = get_children(td, id, parent_id);
228
229 for child_id in &children {
230 dp_solve_rec(td, graph, prob, *child_id, id, mapping, tables);
231 }
232
233 let vertex_set = &td.bags()[id].vertex_set;
234
235 match mapping[id] {
236 NiceTdNodeType::Leaf => {
237 let vertex = vertex_set.iter().next().unwrap();
238 (prob.handle_leaf_node)(graph, id, tables, *vertex);
239 }
240 NiceTdNodeType::Join => {
241 let mut it = children.iter();
242 let left_child_id = *it.next().unwrap();
243 let right_child_id = *it.next().unwrap();
244 (prob.handle_join_node)(graph, id, left_child_id, right_child_id, tables, vertex_set);
245 }
246 NiceTdNodeType::Forget(v) => {
247 let child_id = *children.iter().next().unwrap();
248 (prob.handle_forget_node)(graph, id, child_id, tables, vertex_set, v);
249 }
250 NiceTdNodeType::Introduce(v) => {
251 let child_id = *children.iter().next().unwrap();
252 let child_vertex_set = &td.bags()[child_id].vertex_set;
253 (prob.handle_introduce_node)(
254 graph,
255 id,
256 child_id,
257 tables,
258 vertex_set,
259 child_vertex_set,
260 v,
261 );
262 }
263 }
264}
265
266fn dp_read_solution_from_table(
267 objective: DpObjective,
268 tables: &[DpTable],
269 root: usize,
270 sol: &mut HashSet<usize>,
271) {
272 let root_entry = match objective {
273 DpObjective::Maximize => tables[root].values().max_by(|e1, e2| e1.val.cmp(&e2.val)),
274 DpObjective::Minimize => tables[root].values().min_by(|e1, e2| e1.val.cmp(&e2.val)),
275 }
276 .unwrap();
277 dp_read_solution_from_table_rec(tables, root_entry, sol);
278}
279
280fn dp_read_solution_from_table_rec(
281 tables: &[DpTable],
282 entry: &DpTableEntry,
283 sol: &mut HashSet<usize>,
284) {
285 if let Some(v) = entry.vertex_used {
286 sol.insert(v);
287 }
288
289 for (v, subset) in &entry.children {
290 dp_read_solution_from_table_rec(tables, tables[*v].get(subset).unwrap(), sol);
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::dp_solve_hashmap_graph;
    use crate::{
        algorithm::dynamic_programming::{
            solve::{remap_vertices, DpProblem},
            utils::init_bit_vec,
        },
        generation::erdos_renyi::generate_hash_map_graph,
        utils::{
            max_independent_set::{brute_force_max_independent_set, is_independent_set},
            min_vertex_cover::{brute_force_min_vertex_cover, is_vertex_cover},
        },
    };
    use arboretum_td::graph::{BaseGraph, HashMapGraph, MutableGraph};
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use std::collections::HashSet;

    fn solve_max_independent_set(graph: &HashMapGraph) -> HashSet<usize> {
        dp_solve_hashmap_graph(graph, None, &DpProblem::max_independent_set())
    }

    fn solve_min_vertex_cover(graph: &HashMapGraph) -> HashSet<usize> {
        dp_solve_hashmap_graph(graph, None, &DpProblem::min_vertex_cover())
    }

    #[test]
    fn remapping() {
        let mut graph = HashMapGraph::new();
        graph.add_vertex(10);
        graph.add_vertex(11);
        graph.add_vertex(12);
        graph.add_edge(10, 11);

        let (remapped_graph, mapping) = remap_vertices(&graph);

        assert_eq!(remapped_graph.order(), graph.order());
        assert!(remapped_graph.has_vertex(0));
        assert!(remapped_graph.has_vertex(1));
        assert!(remapped_graph.has_vertex(2));

        // The relabelling may put the single edge on any pair of new ids, so compare every
        // pair against the original graph through the new -> old mapping instead of assuming
        // a particular assignment.
        for u in 0..3 {
            for v in (u + 1)..3 {
                let old_u = *mapping.get(&u).unwrap();
                let old_v = *mapping.get(&v).unwrap();
                assert_eq!(
                    remapped_graph.has_edge(u, v),
                    graph.has_edge(old_u, old_v),
                    "edge mismatch for remapped pair ({}, {})",
                    u,
                    v
                );
            }
        }
    }

    #[test]
    fn large_bit_vec() {
        // `set` panics on an out-of-range index, so this only passes if `init_bit_vec(65)`
        // provides room for bit 127 (presumably it rounds up to whole words).
        let mut bit_vec = init_bit_vec(65);
        bit_vec.set(127, true);
    }

    #[test]
    fn max_independent_set_isolated() {
        for n in 1..10 {
            let graph = generate_hash_map_graph(n, 0., Some(n as u64));

            let sol = solve_max_independent_set(&graph);

            assert_eq!(sol.len(), n);
        }
    }

    #[test]
    fn max_independent_set_clique() {
        for n in 1..10 {
            let graph = generate_hash_map_graph(n, 1., Some(n as u64));
            let sol = solve_max_independent_set(&graph);

            assert_eq!(sol.len(), 1);
        }
    }

    #[test]
    fn max_independent_set_random() {
        let seed = [1; 32];
        let mut rng = StdRng::from_seed(seed);

        for i in 0..30 {
            let graph = generate_hash_map_graph(
                rng.gen_range(1..15),
                rng.gen_range(0.05..0.1),
                Some(i as u64),
            );
            let sol = solve_max_independent_set(&graph);

            assert!(is_independent_set(&graph, &sol), "{:?} {:?}", graph, sol);

            let sol2 = brute_force_max_independent_set(&graph);
            assert_eq!(sol.len(), sol2.len(), "{:?}", graph);
        }
    }

    #[test]
    fn min_vertex_cover_isolated() {
        for n in 1..10 {
            let graph = generate_hash_map_graph(n, 0., Some(n as u64));
            let sol = solve_min_vertex_cover(&graph);

            assert!(sol.is_empty());
        }
    }

    #[test]
    fn min_vertex_cover_clique() {
        for n in 1..10 {
            let graph = generate_hash_map_graph(n, 1., Some(n as u64));
            let sol = solve_min_vertex_cover(&graph);

            assert_eq!(sol.len(), graph.order() - 1);
        }
    }

    #[test]
    fn min_vertex_cover_random() {
        let seed = [2; 32];
        let mut rng = StdRng::from_seed(seed);

        for i in 0..30 {
            let graph = generate_hash_map_graph(
                rng.gen_range(1..15),
                rng.gen_range(0.2..0.5),
                Some(i as u64),
            );
            let sol = solve_min_vertex_cover(&graph);

            assert!(is_vertex_cover(&graph, &sol));

            let sol2 = brute_force_min_vertex_cover(&graph);
            assert_eq!(sol.len(), sol2.len(), "{:?}", graph);
        }
    }
}