graph_algo_ptas/algorithm/dynamic_programming/
solve.rs

1//! Contains data structures and algorithms for dynamic programming on tree decompositions.
2//!
3//! ```rust
4//! use graph_algo_ptas::generation::erdos_renyi::generate_petgraph;
5//! use graph_algo_ptas::algorithm::dynamic_programming::solve::dp_solve;
6//! use graph_algo_ptas::algorithm::dynamic_programming::solve::DpProblem;
7//!
8//! let graph = generate_petgraph(20, 0.1, None);
9//! let sol = dp_solve(&graph, None, &DpProblem::max_independent_set());
10//! ```
11
12use super::{max_independent_set, min_vertex_cover};
13use crate::{
14    algorithm::{
15        dynamic_programming::utils::remap_vertices,
16        nice_tree_decomposition::{get_children, NiceTdNodeType, NiceTreeDecomposition},
17    },
18    utils::convert::{to_hash_map_graph, UndirectedGraph},
19};
20use arboretum_td::{graph::HashMapGraph, solver::Solver, tree_decomposition::TreeDecomposition};
21use bitvec::vec::BitVec;
22use fxhash::FxHashSet;
23use std::collections::{HashMap, HashSet};
24
25/// For each bag in the tree decomposition a table is calculated.
26/// Such a table is represented by `HashMap`.
27///
28/// The `BitVec` key represents the subset to which the table entry belongs
29pub type DpTable = HashMap<BitVec, DpTableEntry>;
30
31/// Represents a single entry in a dynamic programming table.
32///
33/// Contains the value of the entry and additional information needed for
34/// retrieving the actual solution at the end of the algorithm.
35#[derive(Debug, Clone)]
36pub struct DpTableEntry {
37    /// Value of the table entry. Its meaning depends on the problem to be solved.
38    pub val: i32,
39    /// References to table entries of child nodes.
40    pub children: HashSet<(usize, BitVec)>,
41    /// The vertex which is used for calculating the table entry.
42    pub vertex_used: Option<usize>,
43}
44
45impl DpTableEntry {
46    /// Create a table entry for a Leaf node.
47    pub fn new_leaf(val: i32, vertex_used: Option<usize>) -> Self {
48        Self {
49            val,
50            children: HashSet::new(),
51            vertex_used,
52        }
53    }
54
55    /// Create a table entry for a Forget node.
56    pub fn new_forget(val: i32, child_id: usize, child_subset: BitVec) -> Self {
57        Self {
58            val,
59            children: vec![(child_id, child_subset)].into_iter().collect(),
60            vertex_used: None,
61        }
62    }
63
64    /// Create a table entry for an Introduce node.
65    pub fn new_intro(
66        val: i32,
67        child_id: usize,
68        child_subset: BitVec,
69        vertex_used: Option<usize>,
70    ) -> Self {
71        Self {
72            val,
73            children: vec![(child_id, child_subset)].into_iter().collect(),
74            vertex_used,
75        }
76    }
77
78    /// Create a table entry for a Join node.
79    pub fn new_join(val: i32, left_id: usize, right_id: usize, subset: BitVec) -> Self {
80        Self {
81            val,
82            children: vec![(left_id, subset.clone()), (right_id, subset)]
83                .into_iter()
84                .collect(),
85            vertex_used: None,
86        }
87    }
88}
89
90type LeafNodeHandler = fn(graph: &HashMapGraph, id: usize, tables: &mut [DpTable], vertex: usize);
91
92type JoinNodeHandler = fn(
93    graph: &HashMapGraph,
94    id: usize,
95    left_child_id: usize,
96    right_child_id: usize,
97    tables: &mut [DpTable],
98    vertex_set: &FxHashSet<usize>,
99);
100
101type ForgetNodeHandler = fn(
102    graph: &HashMapGraph,
103    id: usize,
104    child_id: usize,
105    tables: &mut [DpTable],
106    vertex_set: &FxHashSet<usize>,
107    forgotten_vertex: usize,
108);
109
110type IntroduceNodeHandler = fn(
111    graph: &HashMapGraph,
112    id: usize,
113    child_id: usize,
114    tables: &mut [DpTable],
115    vertex_set: &FxHashSet<usize>,
116    child_vertex_set: &FxHashSet<usize>,
117    introduced_vertex: usize,
118);
119
/// Tells whether a dynamic-programming problem looks for the smallest or the
/// largest table value at the root.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DpObjective {
    /// Minimization problem: the root entry with the smallest value wins.
    Minimize,
    /// Maximization problem: the root entry with the largest value wins.
    Maximize,
}
128
129/// Contains the neccessary information for solving a (hard) problem
130/// using dynamic programming on tree decompositions.
131pub struct DpProblem {
132    /// Indicates whether the problem is a maximization or minimization problem.
133    pub objective: DpObjective,
134    /// Function for calculating the the table entries at a Leaf node.
135    pub handle_leaf_node: LeafNodeHandler,
136    /// Function for calculating the the table entries at a Join node.
137    pub handle_join_node: JoinNodeHandler,
138    /// Function for calculating the the table entries at a Forget node.
139    pub handle_forget_node: ForgetNodeHandler,
140    /// Function for calculating the the table entries at a Introduce node.
141    pub handle_introduce_node: IntroduceNodeHandler,
142}
143
144impl DpProblem {
145    /// Return a `DpProblem` instance for maximum independent set.
146    pub fn max_independent_set() -> DpProblem {
147        DpProblem {
148            objective: DpObjective::Maximize,
149            handle_leaf_node: max_independent_set::handle_leaf_node,
150            handle_join_node: max_independent_set::handle_join_node,
151            handle_forget_node: max_independent_set::handle_forget_node,
152            handle_introduce_node: max_independent_set::handle_introduce_node,
153        }
154    }
155
156    /// Return a `DpProblem` instance for minimum vertex cover.
157    pub fn min_vertex_cover() -> DpProblem {
158        DpProblem {
159            objective: DpObjective::Minimize,
160            handle_leaf_node: min_vertex_cover::handle_leaf_node,
161            handle_join_node: min_vertex_cover::handle_join_node,
162            handle_forget_node: min_vertex_cover::handle_forget_node,
163            handle_introduce_node: min_vertex_cover::handle_introduce_node,
164        }
165    }
166}
167
168/// Solves the given problem on the input graph using dynamic programming.
169///
170/// When `td` is `None`, an optimal tree decomposition is calculated and used
171/// for the algorithm.
172///
173/// The `prob` parameter specifies whether the problem is a minimization
174/// or maximization problem and contains the "recipe" for how to calculate
175/// the dynamic programming tables in order to arrive at the solution.
176pub fn dp_solve(
177    graph: &UndirectedGraph,
178    td: Option<TreeDecomposition>,
179    prob: &DpProblem,
180) -> HashSet<usize> {
181    dp_solve_hashmap_graph(&to_hash_map_graph(graph), td, prob)
182}
183
184/// For convenience.
185pub fn dp_solve_hashmap_graph(
186    graph: &HashMapGraph,
187    td: Option<TreeDecomposition>,
188    prob: &DpProblem,
189) -> HashSet<usize> {
190    let (graph, mapping) = remap_vertices(graph);
191    let td = td.unwrap_or_else(|| Solver::auto(&graph).solve(&graph));
192    let nice_td = NiceTreeDecomposition::new(td);
193
194    assert!(nice_td.td.verify(&graph).is_ok());
195
196    let mut tables: Vec<_> = vec![DpTable::new(); nice_td.td.bags().len()];
197    let root = nice_td.td.root.unwrap();
198
199    dp_solve_rec(
200        &nice_td.td,
201        &graph,
202        prob,
203        root,
204        usize::max_value(),
205        &nice_td.mapping,
206        &mut tables,
207    );
208
209    let mut sol = HashSet::new();
210    dp_read_solution_from_table(prob.objective, &tables, root, &mut sol);
211
212    sol.iter()
213        .map(|v| mapping.get(v).unwrap())
214        .copied()
215        .collect()
216}
217
218fn dp_solve_rec(
219    td: &TreeDecomposition,
220    graph: &HashMapGraph,
221    prob: &DpProblem,
222    id: usize,
223    parent_id: usize,
224    mapping: &[NiceTdNodeType],
225    tables: &mut Vec<DpTable>,
226) {
227    let children = get_children(td, id, parent_id);
228
229    for child_id in &children {
230        dp_solve_rec(td, graph, prob, *child_id, id, mapping, tables);
231    }
232
233    let vertex_set = &td.bags()[id].vertex_set;
234
235    match mapping[id] {
236        NiceTdNodeType::Leaf => {
237            let vertex = vertex_set.iter().next().unwrap();
238            (prob.handle_leaf_node)(graph, id, tables, *vertex);
239        }
240        NiceTdNodeType::Join => {
241            let mut it = children.iter();
242            let left_child_id = *it.next().unwrap();
243            let right_child_id = *it.next().unwrap();
244            (prob.handle_join_node)(graph, id, left_child_id, right_child_id, tables, vertex_set);
245        }
246        NiceTdNodeType::Forget(v) => {
247            let child_id = *children.iter().next().unwrap();
248            (prob.handle_forget_node)(graph, id, child_id, tables, vertex_set, v);
249        }
250        NiceTdNodeType::Introduce(v) => {
251            let child_id = *children.iter().next().unwrap();
252            let child_vertex_set = &td.bags()[child_id].vertex_set;
253            (prob.handle_introduce_node)(
254                graph,
255                id,
256                child_id,
257                tables,
258                vertex_set,
259                child_vertex_set,
260                v,
261            );
262        }
263    }
264}
265
266fn dp_read_solution_from_table(
267    objective: DpObjective,
268    tables: &[DpTable],
269    root: usize,
270    sol: &mut HashSet<usize>,
271) {
272    let root_entry = match objective {
273        DpObjective::Maximize => tables[root].values().max_by(|e1, e2| e1.val.cmp(&e2.val)),
274        DpObjective::Minimize => tables[root].values().min_by(|e1, e2| e1.val.cmp(&e2.val)),
275    }
276    .unwrap();
277    dp_read_solution_from_table_rec(tables, root_entry, sol);
278}
279
280fn dp_read_solution_from_table_rec(
281    tables: &[DpTable],
282    entry: &DpTableEntry,
283    sol: &mut HashSet<usize>,
284) {
285    if let Some(v) = entry.vertex_used {
286        sol.insert(v);
287    }
288
289    for (v, subset) in &entry.children {
290        dp_read_solution_from_table_rec(tables, tables[*v].get(subset).unwrap(), sol);
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::dp_solve_hashmap_graph;
    use crate::{
        algorithm::dynamic_programming::{
            solve::{remap_vertices, DpProblem},
            utils::init_bit_vec,
        },
        generation::erdos_renyi::generate_hash_map_graph,
        utils::{
            max_independent_set::{brute_force_max_independent_set, is_independent_set},
            min_vertex_cover::{brute_force_min_vertex_cover, is_vertex_cover},
        },
    };
    use arboretum_td::graph::{BaseGraph, HashMapGraph, MutableGraph};
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use std::collections::HashSet;

    fn solve_max_independent_set(graph: &HashMapGraph) -> HashSet<usize> {
        dp_solve_hashmap_graph(graph, None, &DpProblem::max_independent_set())
    }

    fn solve_min_vertex_cover(graph: &HashMapGraph) -> HashSet<usize> {
        dp_solve_hashmap_graph(graph, None, &DpProblem::min_vertex_cover())
    }

    #[test]
    fn remapping() {
        let mut graph = HashMapGraph::new();
        graph.add_vertex(10);
        graph.add_vertex(11);
        graph.add_vertex(12);
        graph.add_edge(10, 11);

        let (remapped_graph, _) = remap_vertices(&graph);

        assert_eq!(remapped_graph.order(), graph.order());
        assert!(remapped_graph.has_vertex(0));
        assert!(remapped_graph.has_vertex(1));
        assert!(remapped_graph.has_vertex(2));
        assert!(remapped_graph.has_edge(0, 1) ^ remapped_graph.has_edge(1, 2));
    }

    #[test]
    fn large_bit_vec() {
        // `BitVec::set` panics on an out-of-range index, so this checks that a
        // bit vector created for 65 elements can still address bit 127.
        let mut bit_vec = init_bit_vec(65);
        bit_vec.set(127, true);
    }

    #[test]
    fn max_independent_set_isolated() {
        for n in 1..10 {
            let graph = generate_hash_map_graph(n, 0., Some(n as u64));

            let sol = solve_max_independent_set(&graph);

            assert_eq!(sol.len(), n);
        }
    }

    #[test]
    fn max_independent_set_clique() {
        for n in 1..10 {
            let graph = generate_hash_map_graph(n, 1., Some(n as u64));
            let sol = solve_max_independent_set(&graph);

            assert_eq!(sol.len(), 1);
        }
    }

    #[test]
    fn max_independent_set_random() {
        let seed = [1; 32];
        let mut rng = StdRng::from_seed(seed);

        for i in 0..30 {
            let graph = generate_hash_map_graph(
                rng.gen_range(1..15),
                rng.gen_range(0.05..0.1),
                Some(i as u64),
            );
            let sol = solve_max_independent_set(&graph);

            assert!(is_independent_set(&graph, &sol), "{:?} {:?}", graph, sol);

            let sol2 = brute_force_max_independent_set(&graph);
            assert_eq!(sol.len(), sol2.len(), "{:?} {:?} {:?}", graph, sol, sol2);
        }
    }

    #[test]
    fn min_vertex_cover_isolated() {
        for n in 1..10 {
            let graph = generate_hash_map_graph(n, 0., Some(n as u64));
            let sol = solve_min_vertex_cover(&graph);

            assert!(sol.is_empty(), "{:?}", sol);
        }
    }

    #[test]
    fn min_vertex_cover_clique() {
        for n in 1..10 {
            let graph = generate_hash_map_graph(n, 1., Some(n as u64));
            let sol = solve_min_vertex_cover(&graph);

            assert_eq!(sol.len(), graph.order() - 1);
        }
    }

    #[test]
    fn min_vertex_cover_random() {
        let seed = [2; 32];
        let mut rng = StdRng::from_seed(seed);

        for i in 0..30 {
            let graph = generate_hash_map_graph(
                rng.gen_range(1..15),
                rng.gen_range(0.2..0.5),
                Some(i as u64),
            );
            let sol = solve_min_vertex_cover(&graph);

            assert!(is_vertex_cover(&graph, &sol), "{:?} {:?}", graph, sol);

            let sol2 = brute_force_min_vertex_cover(&graph);
            assert_eq!(sol.len(), sol2.len(), "{:?} {:?} {:?}", graph, sol, sol2);
        }
    }
}