// algorithms_edu/problems/graph/tsp/dp.rs

//! This module contains an iterative (bottom-up) dynamic-programming solution to the travelling
//! salesman problem (TSP). A brute-force search would have to try all n! permutations of the
//! nodes, but many of those permutations share the same sub-paths, so caching the results of
//! sub-paths avoids recomputing them and greatly improves performance.
//!
//! For example, if one permutation is `... D A B C`, then when we later need to compute the value
//! of the permutation `... E B A C` we should already have cached the answer for the subgraph
//! containing the nodes `{A, B, C}`.
//!
//! - Time Complexity: O(n^2 * 2^n)
//! - Space Complexity: O(n * 2^n)
//!
//! # Resources
//!
//! - [W. Fiset's video](https://www.youtube.com/watch?v=cY4HiiFHO1o&list=PLDV1Zeh2NRsDGO4--qE8yH72HFL1Km93P&index=25)
15
16use crate::algo::graph::WeightedAdjacencyMatrix;
17use crate::data_structures::bit::Bit;
18
19pub struct TspSolver {}
20
21impl TspSolver {
22    #[allow(clippy::needless_range_loop)]
23    pub fn solve(distance: &WeightedAdjacencyMatrix, start: usize) -> (f64, Vec<usize>) {
24        let n = distance.node_count();
25        let mut memo = vec![vec![f64::INFINITY; 1 << n]; n];
26        // store the optimal distance from the start node to each node `i`
27        for i in 0..n {
28            memo[i][1 << i | 1 << start] = distance[start][i];
29        }
30
31        let mut memo = vec![vec![f64::INFINITY; 1 << n]; n];
32        // store the optimal distance from the start node to each node `i`
33        for i in 0..n {
34            memo[i][1 << i | 1 << start] = distance[start][i];
35        }
36        for r in 3..=n {
37            for state in BinaryCombinations::new(n, r as u32).filter(|state| state.get_bit(start)) {
38                for next in (0..n).filter(|&node| state.get_bit(node) && node != start) {
39                    // the state without the next node
40                    let prev_state = state ^ (1 << next);
41                    let mut min_dist = f64::INFINITY;
42                    for prev_end in
43                        (0..n).filter(|&node| state.get_bit(node) && node != start && node != next)
44                    {
45                        let new_dist = memo[prev_end][prev_state] + distance[prev_end][next];
46                        if new_dist < min_dist {
47                            min_dist = new_dist;
48                        }
49                    }
50                    memo[next][state] = min_dist;
51                }
52            }
53        }
54
55        // the end state is the bit mask with `n` bits set to 1
56        let end_state = (1 << n) - 1;
57        let mut min_dist = f64::INFINITY;
58        for e in (0..start).chain(start + 1..n) {
59            let dist = memo[e][end_state] + distance[e][start];
60            if dist < min_dist {
61                min_dist = dist;
62            }
63        }
64
65        let mut state = end_state;
66        let mut last_index = start;
67        let mut tour = vec![start];
68        for _ in 1..n {
69            let mut best_j = usize::MAX;
70            let mut best_dist = f64::MAX;
71            for j in (0..n).filter(|&j| state.get_bit(j) && j != start) {
72                let dist = memo[j][state] + distance[j][last_index];
73                if dist < best_dist {
74                    best_j = j;
75                    best_dist = dist;
76                }
77            }
78            tour.push(best_j);
79            state ^= 1 << best_j;
80            last_index = best_j;
81        }
82
83        (min_dist, tour)
84    }
85}
/// Iterator over every `n`-bit mask that has exactly `r` bits set, in increasing numeric order.
///
/// Each mask encodes one `r`-element subset of `{0, 1, ..., n - 1}`: bit `k` is set when element
/// `k` belongs to the subset.
pub struct BinaryCombinations {
    /// Smallest mask that has not been examined yet.
    curr: usize,
    /// Required number of set bits.
    r: u32,
    /// Number of bits per mask; candidate masks range over `0..(1 << n)`.
    n: usize,
}

impl Iterator for BinaryCombinations {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let upper: usize = 1 << self.n;
        let wanted = self.r;
        // Scan forward for the next mask with the right popcount; stop with `None` if none remain.
        let found = (self.curr..upper).find(|mask| mask.count_ones() == wanted)?;
        self.curr = found + 1;
        Some(found)
    }
}

impl BinaryCombinations {
    /// Creates an iterator over all `n`-bit masks with exactly `r` bits set.
    pub fn new(n: usize, r: u32) -> Self {
        BinaryCombinations { curr: 0, r, n }
    }
}
110
111// // To find all the combinations of size r we need to recurse until we have
112// // selected r elements (aka r = 0), otherwise if r != 0 then we still need to select
113// // an element which is found after the position of our last selected element
114// fn combinations(mut set: u32, at: u32, r: u32, n: u32, subsets: &mut Vec<u32>) {
115//     // Return early if there are more elements left to select than what is available.
116//     let elements_left_to_pick = n - at;
117//     if elements_left_to_pick < r {
118//         return;
119//     }
120
121//     // We selected 'r' elements so we found a valid subset!
122//     if r == 0 {
123//         subsets.push(set);
124//     } else {
125//         for i in at..n {
126//             // Try including this element
127//             set ^= 1 << i;
128
129//             combinations(set, i + 1, r - 1, n, subsets);
130
131//             // Backtrack and try the instance where we did not include this element
132//             set ^= 1 << i;
133//         }
134//     }
135// }