algorithms_edu/problems/graph/tsp/dp.rs
//! This module contains an iterative (bottom-up) dynamic programming solution to the travelling
//! salesman problem (TSP). A brute-force search would have to try all n! permutations of the
//! nodes to find the optimal tour; caching the results of sub-paths avoids recomputing work that
//! many permutations share, which greatly improves performance.
//!
//! For example, if one permutation is `... D A B C`, then when we later compute the value of the
//! permutation `... E B A C`, we should already have cached the answer for the sub-path covering
//! the nodes `{A, B, C}` and ending at `C`.
//!
//! - Time complexity: O(n^2 * 2^n)
//! - Space complexity: O(n * 2^n)
10//!
11//! # Resources
12//!
//! - [W. Fiset's video](https://www.youtube.com/watch?v=cY4HiiFHO1o&list=PLDV1Zeh2NRsDGO4--qE8yH72HFL1Km93P&index=25)
15
16use crate::algo::graph::WeightedAdjacencyMatrix;
17use crate::data_structures::bit::Bit;
18
19pub struct TspSolver {}
20
21impl TspSolver {
22 #[allow(clippy::needless_range_loop)]
23 pub fn solve(distance: &WeightedAdjacencyMatrix, start: usize) -> (f64, Vec<usize>) {
24 let n = distance.node_count();
25 let mut memo = vec![vec![f64::INFINITY; 1 << n]; n];
26 // store the optimal distance from the start node to each node `i`
27 for i in 0..n {
28 memo[i][1 << i | 1 << start] = distance[start][i];
29 }
30
31 let mut memo = vec![vec![f64::INFINITY; 1 << n]; n];
32 // store the optimal distance from the start node to each node `i`
33 for i in 0..n {
34 memo[i][1 << i | 1 << start] = distance[start][i];
35 }
36 for r in 3..=n {
37 for state in BinaryCombinations::new(n, r as u32).filter(|state| state.get_bit(start)) {
38 for next in (0..n).filter(|&node| state.get_bit(node) && node != start) {
39 // the state without the next node
40 let prev_state = state ^ (1 << next);
41 let mut min_dist = f64::INFINITY;
42 for prev_end in
43 (0..n).filter(|&node| state.get_bit(node) && node != start && node != next)
44 {
45 let new_dist = memo[prev_end][prev_state] + distance[prev_end][next];
46 if new_dist < min_dist {
47 min_dist = new_dist;
48 }
49 }
50 memo[next][state] = min_dist;
51 }
52 }
53 }
54
55 // the end state is the bit mask with `n` bits set to 1
56 let end_state = (1 << n) - 1;
57 let mut min_dist = f64::INFINITY;
58 for e in (0..start).chain(start + 1..n) {
59 let dist = memo[e][end_state] + distance[e][start];
60 if dist < min_dist {
61 min_dist = dist;
62 }
63 }
64
65 let mut state = end_state;
66 let mut last_index = start;
67 let mut tour = vec![start];
68 for _ in 1..n {
69 let mut best_j = usize::MAX;
70 let mut best_dist = f64::MAX;
71 for j in (0..n).filter(|&j| state.get_bit(j) && j != start) {
72 let dist = memo[j][state] + distance[j][last_index];
73 if dist < best_dist {
74 best_j = j;
75 best_dist = dist;
76 }
77 }
78 tour.push(best_j);
79 state ^= 1 << best_j;
80 last_index = best_j;
81 }
82
83 (min_dist, tour)
84 }
85}
/// Iterator over every `n`-bit mask that has exactly `r` bits set, yielded in ascending numeric
/// order.
///
/// Each mask encodes a subset of `{0, 1, ..., n - 1}`: bit `i` is set when element `i` belongs to
/// the subset.
pub struct BinaryCombinations {
    /// Smallest mask that has not been examined yet.
    curr: usize,
    /// Number of set bits every yielded mask must have.
    r: u32,
    /// Number of bits (elements) to choose from; masks range over `0..1 << n`.
    n: usize,
}

impl Iterator for BinaryCombinations {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let upper: usize = 1 << self.n;
        // Scan forward for the next mask with the right popcount; `?` ends the iteration when
        // none is left, leaving the cursor untouched.
        let found = (self.curr..upper).find(|&mask| mask.count_ones() == self.r)?;
        self.curr = found + 1;
        Some(found)
    }
}

impl BinaryCombinations {
    /// Creates an iterator over the `r`-element subsets of an `n`-element set, encoded as bit
    /// masks.
    pub fn new(n: usize, r: u32) -> Self {
        BinaryCombinations { n, r, curr: 0 }
    }
}
110
111// // To find all the combinations of size r we need to recurse until we have
112// // selected r elements (aka r = 0), otherwise if r != 0 then we still need to select
113// // an element which is found after the position of our last selected element
114// fn combinations(mut set: u32, at: u32, r: u32, n: u32, subsets: &mut Vec<u32>) {
115// // Return early if there are more elements left to select than what is available.
116// let elements_left_to_pick = n - at;
117// if elements_left_to_pick < r {
118// return;
119// }
120
121// // We selected 'r' elements so we found a valid subset!
122// if r == 0 {
123// subsets.push(set);
124// } else {
125// for i in at..n {
126// // Try including this element
127// set ^= 1 << i;
128
129// combinations(set, i + 1, r - 1, n, subsets);
130
131// // Backtrack and try the instance where we did not include this element
132// set ^= 1 << i;
133// }
134// }
135// }