use std::cmp::Ordering;
/// Result of a solve: which options to visit, in order, and the total route cost.
///
/// Derives `Debug`, `Clone` and `PartialEq` so results can be logged, copied
/// and compared directly in assertions.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct Solution {
    /// Chosen option indices in visiting order (the solvers in this file emit
    /// one option per cluster).
    pub order: Vec<usize>,
    /// Total cost of the route: start cost of the first option, every
    /// transition along `order`, and the end cost of the last option.
    pub cost: f64,
}
/// Borrowed description of a routing instance: a set of options grouped into
/// clusters, with start, end and pairwise transition costs.
///
/// The struct only holds shared references and integers, so it is cheap to
/// copy; it derives `Debug`, `Clone` and `Copy` accordingly.
///
/// Non-finite costs are treated by the solvers as "not allowed".
#[derive(Debug, Clone, Copy)]
pub(crate) struct Problem<'a> {
    /// `cluster_ids[i]` is the cluster that option `i` belongs to; each value
    /// is used as an index into a table of `n_clusters` entries.
    pub cluster_ids: &'a [usize],
    /// Number of clusters.
    pub n_clusters: usize,
    /// Flat row-major transition table: the cost of going from option `i` to
    /// option `j` lives at `i * options + j`.
    pub transition: &'a [f64],
    /// Row length (stride) of `transition`; `solve` debug-asserts that it
    /// equals `cluster_ids.len()`.
    pub options: usize,
    /// `start[i]` is the cost of beginning the route at option `i`.
    pub start: &'a [f64],
    /// `end[i]` is the cost of finishing the route at option `i`.
    pub end: &'a [f64],
}
impl Problem<'_> {
    /// Cost of moving from option `i` to option `j`.
    ///
    /// `transition` is a flat row-major table whose rows are `options` long,
    /// so the entry sits `j` cells into the row that starts at `i * options`.
    #[inline]
    fn trans(&self, i: usize, j: usize) -> f64 {
        let row_start = i * self.options;
        self.transition[row_start + j]
    }
}
/// A cell budget callers can hand to `solve` (the tests in this file do).
/// `solve` compares the budget against `2^n_clusters * options`, the number of
/// cells in the exact solver's table.
pub(crate) const DEFAULT_CELL_BUDGET: usize = 16 * 1024 * 1024;
/// Largest cluster count `solve` will ever route to the exact solver,
/// whatever the budget. The exact solver stores visited-set masks in a `u32`
/// (`Back::prev_mask`).
const MAX_EXACT_CLUSTERS: usize = 32;
/// Finds a low-cost route that visits one option from every cluster.
///
/// Chooses between two strategies:
/// * the exact bitmask dynamic program (`solve_exact`) when its table of
///   `2^n_clusters * options` cells fits in `cell_budget` and the cluster
///   count does not exceed `MAX_EXACT_CLUSTERS`;
/// * the cluster-ordering heuristic (`solve_heuristic`) otherwise.
///
/// Returns `None` when the chosen strategy finds no route with finite cost.
/// A problem with zero clusters yields an empty order with cost `0.0`.
pub(crate) fn solve(problem: &Problem, cell_budget: usize) -> Option<Solution> {
    if problem.n_clusters == 0 {
        return Some(Solution {
            order: Vec::new(),
            cost: 0.0,
        });
    }
    debug_assert_eq!(
        problem.options,
        problem.cluster_ids.len(),
        "transition stride (problem.options) must equal the option count"
    );
    let options = problem.options;
    // Size of the exact table, or `None` when it cannot even be expressed in
    // `usize`. A plain `1usize << n_clusters` overflows the shift when
    // `n_clusters` equals the pointer width (32 clusters on a 32-bit target):
    // that panics in debug builds and wraps to `1` in release builds, which
    // would wrongly report the table as tiny.
    let exact_cells = if problem.n_clusters > MAX_EXACT_CLUSTERS {
        None
    } else {
        // The cast is lossless: `n_clusters <= MAX_EXACT_CLUSTERS` here.
        1usize
            .checked_shl(problem.n_clusters as u32)
            .and_then(|states| states.checked_mul(options))
    };
    let too_big = match exact_cells {
        Some(cells) => cells > cell_budget,
        None => true,
    };
    if too_big {
        tracing::debug!(
            n_clusters = problem.n_clusters,
            options,
            "deke-multipath: AGTSP over cell budget; using cluster-optimization heuristic"
        );
        solve_heuristic(problem)
    } else {
        solve_exact(problem)
    }
}
/// Back-pointer stored per cell of the exact solver's table: the state from
/// which the best known path into that cell arrived.
#[derive(Clone, Copy)]
struct Back {
/// Visited-cluster bitmask of the predecessor state.
prev_mask: u32,
/// Option index of the predecessor; negative when there is none (the
/// traceback in `solve_exact` stops on a negative value).
prev_opt: i32,
}
/// Fill value for the back-pointer table: its negative `prev_opt` means
/// "no predecessor recorded".
const NO_BACK: Back = Back {
prev_mask: 0,
prev_opt: -1,
};
/// Inverts `problem.cluster_ids`: entry `c` of the result lists, in ascending
/// order, the option indices that belong to cluster `c`.
///
/// The result always has `problem.n_clusters` entries; clusters without any
/// option get an empty list. Panics if a cluster id is `>= n_clusters`.
fn group_by_cluster(problem: &Problem) -> Vec<Vec<usize>> {
    let mut groups: Vec<Vec<usize>> = vec![Vec::new(); problem.n_clusters];
    problem
        .cluster_ids
        .iter()
        .enumerate()
        .for_each(|(option, &cluster)| groups[cluster].push(option));
    groups
}
/// Exact solver: dynamic program over (set of visited clusters, last option).
///
/// `best[cell(visited, last)]` holds the cheapest cost of a path that starts
/// anywhere, has visited exactly the clusters in the bitmask `visited`, and
/// currently sits at option `last`. States are expanded in increasing mask
/// order, so every state is final before it is extended. Non-finite start and
/// transition costs are skipped.
///
/// Returns the cheapest complete route including the end cost, or `None` if
/// no complete route has finite cost.
fn solve_exact(problem: &Problem) -> Option<Solution> {
    let stride = problem.options;
    let n_states = 1usize << problem.n_clusters;
    let mut best = vec![f64::INFINITY; n_states * stride];
    let mut parent = vec![NO_BACK; n_states * stride];
    let cell = |visited: usize, last: usize| visited * stride + last;
    let bit = |cluster: usize| 1usize << cluster;
    let members = group_by_cluster(problem);

    // Seed: a path consisting of a single option costs that option's start cost.
    for (option, &cluster) in problem.cluster_ids.iter().enumerate() {
        let entry = problem.start[option];
        if entry.is_finite() {
            let slot = cell(bit(cluster), option);
            if entry < best[slot] {
                best[slot] = entry;
                parent[slot] = Back {
                    prev_mask: u32::MAX,
                    prev_opt: -1,
                };
            }
        }
    }

    for visited in 1..n_states {
        // The last option of a state must belong to a cluster inside the mask.
        for (cluster, group) in members.iter().enumerate() {
            if visited & bit(cluster) == 0 {
                continue;
            }
            for &from in group {
                let so_far = best[cell(visited, from)];
                if !so_far.is_finite() {
                    continue;
                }
                // Extend the path into every cluster not visited yet.
                for (next_cluster, candidates) in members.iter().enumerate() {
                    if visited & bit(next_cluster) != 0 {
                        continue;
                    }
                    let grown = visited | bit(next_cluster);
                    for &to in candidates {
                        let hop = problem.trans(from, to);
                        if hop.is_finite() {
                            let total = so_far + hop;
                            let slot = cell(grown, to);
                            if total < best[slot] {
                                best[slot] = total;
                                parent[slot] = Back {
                                    prev_mask: visited as u32,
                                    prev_opt: from as i32,
                                };
                            }
                        }
                    }
                }
            }
        }
    }

    // Close the route: add the end cost to every all-clusters-visited state.
    let everything = n_states - 1;
    let (final_opt, total) = (0..stride)
        .map(|option| (option, best[cell(everything, option)] + problem.end[option]))
        .filter(|&(_, candidate)| candidate.is_finite())
        .min_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal))?;

    // Walk the back-pointers from the final state to the seed state.
    let mut path = Vec::with_capacity(problem.n_clusters);
    let (mut visited, mut at) = (everything, final_opt);
    path.push(at);
    loop {
        let link = parent[cell(visited, at)];
        if link.prev_opt < 0 {
            break;
        }
        visited = link.prev_mask as usize;
        at = link.prev_opt as usize;
        path.push(at);
    }
    path.reverse();

    Some(Solution {
        order: path,
        cost: total,
    })
}
/// Upper bound on the number of seed clusters `solve_heuristic` builds a
/// tour from.
const MAX_HEURISTIC_STARTS: usize = 8;
/// Heuristic solver for instances too large for the exact table.
///
/// Builds a cluster-level surrogate of the problem, picks up to
/// `MAX_HEURISTIC_STARTS` seed clusters with the cheapest finite surrogate
/// start cost, grows and polishes one tour per seed (`solve_from_seed`), and
/// returns the cheapest resulting solution. With the `rayon` feature the
/// seeds are processed in parallel.
///
/// Returns `None` when no seed produces a solution.
fn solve_heuristic(problem: &Problem) -> Option<Solution> {
    let by_cluster = group_by_cluster(problem);
    let surrogate = Surrogate::build(problem, &by_cluster);
    let entry_cost = |cluster: usize| surrogate.start[cluster];

    // Candidate first clusters: reachable from the start, cheapest first.
    let mut seeds: Vec<usize> = (0..problem.n_clusters)
        .filter(|&cluster| entry_cost(cluster).is_finite())
        .collect();
    seeds.sort_by(|&lhs, &rhs| {
        entry_cost(lhs)
            .partial_cmp(&entry_cost(rhs))
            .unwrap_or(Ordering::Equal)
    });
    seeds.truncate(MAX_HEURISTIC_STARTS.max(1));

    #[cfg(feature = "rayon")]
    let candidates: Vec<Solution> = {
        use rayon::prelude::*;
        seeds
            .par_iter()
            .filter_map(|&first| solve_from_seed(problem, &by_cluster, &surrogate, first))
            .collect()
    };
    #[cfg(not(feature = "rayon"))]
    let candidates: Vec<Solution> = seeds
        .iter()
        .filter_map(|&first| solve_from_seed(problem, &by_cluster, &surrogate, first))
        .collect();

    candidates
        .into_iter()
        .min_by(|lhs, rhs| lhs.cost.partial_cmp(&rhs.cost).unwrap_or(Ordering::Equal))
}
/// Produces one heuristic solution whose cluster tour starts at `first`.
///
/// The cluster order comes from a nearest-neighbour tour on the surrogate,
/// refined by 2-opt and then or-opt; `cluster_optimize` then picks the
/// concrete option for every cluster along that order. Returns `None` when
/// that order admits no finite-cost choice of options.
fn solve_from_seed(
    problem: &Problem,
    by_cluster: &[Vec<usize>],
    surrogate: &Surrogate,
    first: usize,
) -> Option<Solution> {
    let tour = surrogate.nearest_neighbour_order(problem.n_clusters, first);
    let tour = surrogate.or_opt(surrogate.two_opt(tour));
    cluster_optimize(problem, by_cluster, &tour).map(|(order, cost)| Solution { order, cost })
}
/// For a fixed cluster visiting order, picks one option per cluster so that
/// start cost + transitions + end cost is minimal.
///
/// This is a shortest path through a layered graph: layer `l` holds the
/// options of `cluster_order[l]`, and edges only run between consecutive
/// layers. Options whose best known cost is non-finite are never extended.
///
/// Returns the chosen options (one per layer, in order) and the total cost,
/// `Some((vec![], 0.0))` for an empty order, or `None` when no combination
/// has finite cost.
fn cluster_optimize(
    problem: &Problem,
    by_cluster: &[Vec<usize>],
    cluster_order: &[usize],
) -> Option<(Vec<usize>, f64)> {
    if cluster_order.is_empty() {
        return Some((Vec::new(), 0.0));
    }
    let layers: Vec<&[usize]> = cluster_order.iter().map(|&c| &by_cluster[c][..]).collect();
    let depth = layers.len();

    // `reach[j]`: cheapest cost of arriving at option `j` of the current layer.
    let mut reach: Vec<f64> = layers[0].iter().map(|&opt| problem.start[opt]).collect();
    // `parents[l][j]`: index inside layer `l` that the best path into option
    // `j` of layer `l + 1` came from (`None` if that option is unreachable).
    let mut parents: Vec<Vec<Option<usize>>> = Vec::with_capacity(depth - 1);

    for pair in layers.windows(2) {
        let (from, to) = (pair[0], pair[1]);
        let mut to_cost = Vec::with_capacity(to.len());
        let mut to_parent = Vec::with_capacity(to.len());
        for &dst in to {
            let mut cheapest = f64::INFINITY;
            let mut via = None;
            for (i, &src) in from.iter().enumerate() {
                let so_far = reach[i];
                if so_far.is_finite() {
                    let total = so_far + problem.trans(src, dst);
                    if total < cheapest {
                        cheapest = total;
                        via = Some(i);
                    }
                }
            }
            to_cost.push(cheapest);
            to_parent.push(via);
        }
        reach = to_cost;
        parents.push(to_parent);
    }

    // Close the path with the end cost and keep the cheapest finite result.
    let final_layer = layers[depth - 1];
    let (mut pick, total) = final_layer
        .iter()
        .enumerate()
        .map(|(j, &opt)| (j, reach[j] + problem.end[opt]))
        .filter(|&(_, candidate)| candidate.is_finite())
        .min_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal))?;

    // Follow the parent links back to the first layer.
    let mut chosen = Vec::with_capacity(depth);
    chosen.push(final_layer[pick]);
    for (layer, parent_row) in layers[..depth - 1].iter().zip(&parents).rev() {
        pick = parent_row[pick]?;
        chosen.push(layer[pick]);
    }
    chosen.reverse();
    Some((chosen, total))
}
/// Cluster-level summary of a `Problem`, used to decide the order in which
/// clusters are visited before concrete options are chosen.
struct Surrogate {
/// `dist[a][b]`: cheapest transition from any option of cluster `a` to any
/// option of cluster `b`. The diagonal is left at infinity.
dist: Vec<Vec<f64>>,
/// `start[c]`: cheapest start cost among the options of cluster `c`.
start: Vec<f64>,
/// `end[c]`: cheapest end cost among the options of cluster `c`.
end: Vec<f64>,
}
impl Surrogate {
    /// Collapses `problem` to cluster granularity by taking, for every pair
    /// of distinct clusters and for the start/end costs, the minimum over the
    /// member options listed in `by_cluster`.
    fn build(problem: &Problem, by_cluster: &[Vec<usize>]) -> Self {
        /// Smallest entry of `costs` over the indices in `members`
        /// (infinity for an empty cluster).
        fn cheapest(members: &[usize], costs: &[f64]) -> f64 {
            members
                .iter()
                .map(|&i| costs[i])
                .fold(f64::INFINITY, f64::min)
        }

        let cheapest_link = |from: &[usize], to: &[usize]| {
            from.iter()
                .flat_map(|&i| to.iter().map(move |&j| problem.trans(i, j)))
                .fold(f64::INFINITY, f64::min)
        };

        let n = problem.n_clusters;
        let mut dist = vec![vec![f64::INFINITY; n]; n];
        for (a, from) in by_cluster.iter().enumerate() {
            for (b, to) in by_cluster.iter().enumerate() {
                if a != b {
                    dist[a][b] = cheapest_link(from, to);
                }
            }
        }

        let start = by_cluster
            .iter()
            .map(|members| cheapest(members, problem.start))
            .collect();
        let end = by_cluster
            .iter()
            .map(|members| cheapest(members, problem.end))
            .collect();
        Surrogate { dist, start, end }
    }

    /// Greedy tour over all `n_clusters` clusters beginning at `first`: always
    /// step to the unvisited cluster with the smallest surrogate distance from
    /// the current one. If none is reachable at finite distance, the
    /// lowest-numbered unvisited cluster is taken instead.
    fn nearest_neighbour_order(&self, n_clusters: usize, first: usize) -> Vec<usize> {
        let mut visited = vec![false; n_clusters];
        visited[first] = true;
        let mut tour = Vec::with_capacity(n_clusters);
        tour.push(first);
        let mut here = first;
        while tour.len() < n_clusters {
            let row = &self.dist[here];
            let (_, nearest) = visited
                .iter()
                .enumerate()
                .filter(|&(_, &seen)| !seen)
                .fold((f64::INFINITY, None), |(best, pick), (cluster, _)| {
                    if row[cluster] < best {
                        (row[cluster], Some(cluster))
                    } else {
                        (best, pick)
                    }
                });
            let next = nearest
                .or_else(|| visited.iter().position(|&seen| !seen))
                .unwrap();
            visited[next] = true;
            tour.push(next);
            here = next;
        }
        tour
    }

    /// Surrogate cost of arriving at `node` from `prev`, or from the route
    /// start when `prev` is `None`.
    fn arrive(&self, prev: Option<usize>, node: usize) -> f64 {
        match prev {
            Some(p) => self.dist[p][node],
            None => self.start[node],
        }
    }

    /// Surrogate cost of leaving `node` towards `next`, or towards the route
    /// end when `next` is `None`.
    fn depart(&self, node: usize, next: Option<usize>) -> f64 {
        match next {
            Some(n) => self.dist[node][n],
            None => self.end[node],
        }
    }

    /// Change in surrogate tour cost if `order[i..=j]` were reversed
    /// (negative means the reversal is an improvement).
    ///
    /// Distances are not assumed symmetric, so besides the two boundary links
    /// every link inside the segment is re-priced in the opposite direction.
    fn reverse_delta(&self, order: &[usize], i: usize, j: usize) -> f64 {
        let (head, tail) = (order[i], order[j]);
        let before = if i == 0 { None } else { Some(order[i - 1]) };
        let after = if j + 1 == order.len() {
            None
        } else {
            Some(order[j + 1])
        };
        let (old_int, new_int) = (i..j).fold((0.0_f64, 0.0_f64), |(fwd, bwd), k| {
            let (a, b) = (order[k], order[k + 1]);
            (fwd + self.dist[a][b], bwd + self.dist[b][a])
        });
        let reversed = self.arrive(before, tail) + new_int + self.depart(head, after);
        let current = self.arrive(before, head) + old_int + self.depart(tail, after);
        reversed - current
    }

    /// 2-opt local search: keeps reversing any segment whose reversal lowers
    /// the surrogate cost by more than `1e-9` until a full sweep finds none.
    /// Tours with fewer than four clusters are returned untouched.
    fn two_opt(&self, mut order: Vec<usize>) -> Vec<usize> {
        let n = order.len();
        if n >= 4 {
            loop {
                let mut changed = false;
                for i in 0..n - 1 {
                    for j in (i + 1)..n {
                        if self.reverse_delta(&order, i, j) < -1e-9 {
                            order[i..=j].reverse();
                            changed = true;
                        }
                    }
                }
                if !changed {
                    break;
                }
            }
        }
        order
    }

    /// Or-opt local search with single-cluster moves: for each position, take
    /// the cluster out and re-insert it at the slot with the largest cost
    /// reduction (if it exceeds `1e-9`); repeat sweeps until nothing moves.
    /// Tours with fewer than three clusters are returned untouched.
    fn or_opt(&self, mut order: Vec<usize>) -> Vec<usize> {
        let n = order.len();
        if n < 3 {
            return order;
        }
        loop {
            let mut changed = false;
            for i in 0..n {
                let node = order[i];
                let prev = if i == 0 { None } else { Some(order[i - 1]) };
                let next = if i == n - 1 { None } else { Some(order[i + 1]) };
                // Link that closes the gap once `node` is taken out.
                let bridge = match prev {
                    None => self.start[order[1]],
                    Some(p) => self.depart(p, next),
                };
                let removal = bridge - self.arrive(prev, node) - self.depart(node, next);

                let mut rest = order.clone();
                rest.remove(i);
                let w = rest.len();
                let mut best_delta = -1e-9;
                let mut best_slot: Option<usize> = None;
                for p in 0..=w {
                    let left = if p == 0 { None } else { Some(rest[p - 1]) };
                    let right = if p == w { None } else { Some(rest[p]) };
                    // Link of `rest` that is broken by inserting at slot `p`.
                    let removed = match left {
                        None => self.start[rest[0]],
                        Some(l) => self.depart(l, right),
                    };
                    let delta = removal
                        + (self.arrive(left, node) + self.depart(node, right) - removed);
                    if delta < best_delta {
                        best_delta = delta;
                        best_slot = Some(p);
                    }
                }
                if let Some(p) = best_slot {
                    rest.insert(p, node);
                    order = rest;
                    changed = true;
                }
            }
            if !changed {
                break;
            }
        }
        order
    }
}
// Unit tests for the exact solver, the heuristic solver and `cluster_optimize`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a `Problem` whose transition stride equals the number of options.
fn problem<'a>(
cluster_ids: &'a [usize],
n_clusters: usize,
transition: &'a [f64],
start: &'a [f64],
end: &'a [f64],
) -> Problem<'a> {
Problem {
cluster_ids,
n_clusters,
transition,
options: cluster_ids.len(),
start,
end,
}
}
// Flattens a list of rows into the row-major layout `Problem::transition` uses.
fn flat(rows: &[&[f64]]) -> Vec<f64> {
rows.iter().flat_map(|r| r.iter().copied()).collect()
}
// Zero clusters: `solve` returns an empty order at zero cost.
#[test]
fn empty_problem_is_trivial() {
let sol = solve(&problem(&[], 0, &[], &[], &[]), DEFAULT_CELL_BUDGET).unwrap();
assert!(sol.order.is_empty());
assert_eq!(sol.cost, 0.0);
}
// Two single-option clusters with symmetric transitions: the route must
// begin at the option with the cheaper start cost (1 + 5 = 6).
#[test]
fn picks_cheaper_start() {
let cluster_ids = [0, 1];
let transition = flat(&[&[0.0, 5.0], &[5.0, 0.0]]);
let start = [1.0, 10.0];
let end = [0.0, 0.0];
let sol = solve(
&problem(&cluster_ids, 2, &transition, &start, &end),
DEFAULT_CELL_BUDGET,
)
.unwrap();
assert_eq!(sol.order, vec![0, 1]);
assert!((sol.cost - 6.0).abs() < 1e-9);
}
// Cluster 0 offers two options; option 1 has the cheaper start and the
// cheaper transition to option 2, so it must be the one selected.
#[test]
fn generalized_picks_best_option_in_cluster() {
let cluster_ids = [0, 0, 1];
let inf = f64::INFINITY;
let transition = flat(&[&[0.0, inf, 9.0], &[inf, 0.0, 1.0], &[9.0, 1.0, 0.0]]);
let start = [5.0, 1.0, 5.0];
let end = [0.0, 0.0, 0.0];
let sol = solve(
&problem(&cluster_ids, 2, &transition, &start, &end),
DEFAULT_CELL_BUDGET,
)
.unwrap();
assert_eq!(sol.order, vec![1, 2]);
assert!((sol.cost - 2.0).abs() < 1e-9);
}
// One cluster, two options with equal start cost: the end cost decides.
#[test]
fn end_cost_breaks_the_tie() {
let cluster_ids = [0, 0];
let transition = flat(&[&[0.0, 0.0], &[0.0, 0.0]]);
let start = [1.0, 1.0];
let end = [5.0, 0.0];
let sol = solve(
&problem(&cluster_ids, 1, &transition, &start, &end),
DEFAULT_CELL_BUDGET,
)
.unwrap();
assert_eq!(sol.order, vec![1]);
assert!((sol.cost - 1.0).abs() < 1e-9);
}
// With the cluster order fixed to [0, 1], the cheapest option pair is
// option 1 (start 1) followed by option 2 (transition 1, end 0).
#[test]
fn cluster_optimize_finds_layered_shortest_path() {
let cluster_ids = [0, 0, 1, 1];
let inf = f64::INFINITY;
let transition = flat(&[
&[0.0, 0.0, 8.0, 8.0],
&[0.0, 0.0, 1.0, 9.0],
&[inf, inf, 0.0, 0.0],
&[inf, inf, 0.0, 0.0],
]);
let start = [7.0, 1.0, inf, inf];
let end = [inf, inf, 0.0, 5.0];
let p = problem(&cluster_ids, 2, &transition, &start, &end);
let by_cluster = group_by_cluster(&p);
let (order, cost) = cluster_optimize(&p, &by_cluster, &[0, 1]).unwrap();
assert_eq!(order, vec![1, 2]);
assert!((cost - 2.0).abs() < 1e-9);
}
// A cell budget of 0 makes `solve` take the heuristic path; on this instance
// it must reproduce the answer the exact solver gives in the test above.
#[test]
fn heuristic_matches_exact_generalized() {
let cluster_ids = [0, 0, 1];
let inf = f64::INFINITY;
let transition = flat(&[&[0.0, inf, 9.0], &[inf, 0.0, 1.0], &[9.0, 1.0, 0.0]]);
let start = [5.0, 1.0, 5.0];
let end = [0.0, 0.0, 0.0];
let sol = solve(&problem(&cluster_ids, 2, &transition, &start, &end), 0).unwrap();
assert_eq!(sol.order, vec![1, 2]);
assert!((sol.cost - 2.0).abs() < 1e-9);
}
// Same instance solved with the default budget (exact path) and with budget 0
// (heuristic path): the heuristic may not beat the optimum and is expected
// to match it here.
#[test]
fn heuristic_matches_exact_on_small_instance() {
let cluster_ids = [0, 1, 2];
let transition = flat(&[&[0.0, 2.0, 9.0], &[3.0, 0.0, 2.0], &[4.0, 6.0, 0.0]]);
let start = [0.0, 7.0, 7.0];
let end = [0.0, 0.0, 0.0];
let exact = solve(
&problem(&cluster_ids, 3, &transition, &start, &end),
DEFAULT_CELL_BUDGET,
)
.unwrap();
let heuristic = solve(&problem(&cluster_ids, 3, &transition, &start, &end), 0).unwrap();
assert!(
heuristic.cost + 1e-9 >= exact.cost,
"heuristic beat optimal"
);
assert!(
(heuristic.cost - exact.cost).abs() < 1e-9,
"heuristic {} != exact {}",
heuristic.cost,
exact.cost
);
}
}