use scirs2_core::ndarray::Array2;
use std::collections::BinaryHeap;
use std::cmp::Ordering;
use crate::error::OptimizeError;
pub type TspResult<T> = Result<T, OptimizeError>;
/// Priority-queue entry used by Prim's algorithm: a candidate `vertex` offered at `cost`.
///
/// [`BinaryHeap`] is a max-heap, so the ordering is inverted on `cost`: the entry with the
/// smallest cost compares as the greatest and is popped first. When costs are equal or cannot
/// be compared (NaN), the entry with the larger `vertex` index compares as greater.
#[derive(Clone, PartialEq)]
struct PrimEntry {
    /// Priority key; smaller values are popped first.
    cost: f64,
    /// Index of the vertex this entry refers to.
    vertex: usize,
}

impl Eq for PrimEntry {}

impl Ord for PrimEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Note the swapped operands: `other` vs `self` turns the max-heap into a min-heap.
        match other.cost.partial_cmp(&self.cost) {
            Some(Ordering::Less) => Ordering::Less,
            Some(Ordering::Greater) => Ordering::Greater,
            Some(Ordering::Equal) | None => self.vertex.cmp(&other.vertex),
        }
    }
}

impl PartialOrd for PrimEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Returns the length of the closed tour `tour` under the distance matrix `dist`.
///
/// The tour is treated as a cycle: the edge from the last city back to the first is included,
/// so a single-city tour contributes `dist[[c, c]]`. An empty tour has length `0.0`.
///
/// Indexing `dist` panics if any city index in `tour` is out of bounds for the matrix.
pub fn tour_length(tour: &[usize], dist: &Array2<f64>) -> f64 {
    let first = match tour.first() {
        Some(&city) => city,
        None => return 0.0,
    };
    // Pair every city with its successor, wrapping the last one back to `first`.
    let successors = tour.iter().skip(1).chain(std::iter::once(&first));
    tour.iter()
        .zip(successors)
        .fold(0.0, |acc, (&from, &to)| acc + dist[[from, to]])
}
/// Builds a tour greedily, starting at `start` and always moving to the closest unvisited city.
///
/// Ties are broken in favour of the lowest city index. If every remaining distance from the
/// current city is `+inf` or NaN, the lowest-indexed unvisited city is taken. This guarantees
/// the returned tour always visits all `n` cities exactly once.
///
/// Returns the tour together with its closed length (see [`tour_length`]). An empty matrix
/// yields an empty tour of length `0.0`.
///
/// # Errors
/// Returns [`OptimizeError::InvalidInput`] if `dist` is not square or `start` is out of range.
pub fn nearest_neighbor_heuristic(
    dist: &Array2<f64>,
    start: usize,
) -> TspResult<(Vec<usize>, f64)> {
    let n = dist.nrows();
    if n == 0 {
        return Ok((vec![], 0.0));
    }
    if dist.ncols() != n {
        // Without this check, a matrix with fewer columns than rows would panic on indexing.
        return Err(OptimizeError::InvalidInput(
            "Distance matrix must be square".to_string(),
        ));
    }
    if start >= n {
        return Err(OptimizeError::InvalidInput(format!(
            "start index {start} out of range for {n} cities"
        )));
    }
    let mut visited = vec![false; n];
    let mut tour = Vec::with_capacity(n);
    let mut current = start;
    visited[current] = true;
    tour.push(current);
    for _ in 1..n {
        let mut best_next = None;
        let mut best_dist = f64::INFINITY;
        for j in (0..n).filter(|&j| !visited[j]) {
            let d = dist[[current, j]];
            if d < best_dist {
                best_dist = d;
                best_next = Some(j);
            }
        }
        // Nothing is strictly closer than `INFINITY` when all remaining distances are +inf/NaN.
        // Fall back to the first unvisited city rather than silently truncating the tour.
        let next = match best_next.or_else(|| visited.iter().position(|&v| !v)) {
            Some(next) => next,
            // Unreachable: this loop runs n - 1 times, so fewer than n cities are visited here.
            None => break,
        };
        visited[next] = true;
        tour.push(next);
        current = next;
    }
    let length = tour_length(&tour, dist);
    Ok((tour, length))
}
/// Improves `tour` in place with 2-opt moves until no improving move remains.
///
/// It returns the resulting closed tour length.
///
/// A move replaces edges `(tour[i], tour[i+1])` and `(tour[j], tour[j+1])` with
/// `(tour[i], tour[j])` and `(tour[i+1], tour[j+1])` by reversing `tour[i+1..=j]`. A move is
/// applied when it shortens those two edges by more than `1e-10`.
///
/// Only the two swapped edges are priced. The edges inside the reversed stretch are assumed
/// to cost the same in both directions, i.e. `dist` is presumed symmetric.
///
/// Tours with fewer than 4 cities are left unchanged.
pub fn two_opt(tour: &mut Vec<usize>, dist: &Array2<f64>) -> f64 {
    let len = tour.len();
    if len >= 4 {
        loop {
            let mut any_swap = false;
            for i in 0..len - 2 {
                // For i == 0 the final j would pair edge (tour[0], tour[1]) with the closing
                // edge (tour[len-1], tour[0]); they share a city, so that j is excluded.
                let j_stop = if i == 0 { len - 1 } else { len };
                for j in i + 2..j_stop {
                    let (p, q) = (tour[i], tour[i + 1]);
                    let (r, s) = (tour[j], tour[(j + 1) % len]);
                    let before = dist[[p, q]] + dist[[r, s]];
                    let after = dist[[p, r]] + dist[[q, s]];
                    if after < before - 1e-10 {
                        tour[i + 1..=j].reverse();
                        any_swap = true;
                    }
                }
            }
            if !any_swap {
                break;
            }
        }
    }
    tour_length(tour, dist)
}
/// Evaluates the seven 3-opt reconnections that remove three edges of `tour`.
///
/// The removed edges are `(tour[i], tour[i+1])`, `(tour[j], tour[j+1])` and
/// `(tour[k], tour[(k+1) % n])`.
///
/// The tour is viewed as `head = tour[..=i]`, `mid = tour[i+1..=j]`, `tail = tour[j+1..=k]`,
/// `rest = tour[k+1..]`. Each reconnection keeps `head` and `rest` in place and re-inserts
/// `mid` and `tail` in some order, each possibly reversed. The cheapest reconnection is chosen;
/// on ties, the earliest in the fixed candidate list wins. If it beats the current three edges
/// by more than `1e-10`, the rebuilt tour is returned.
///
/// Returns `None` when `n < 6`, when the indices do not satisfy `i < j < k < n`, or when no
/// reconnection improves. Edge costs are read in one direction only, so reversed segments are
/// priced as if `dist` were symmetric.
pub fn three_opt_move(
    dist: &Array2<f64>,
    i: usize,
    j: usize,
    k: usize,
    tour: &[usize],
) -> Option<Vec<usize>> {
    let n = tour.len();
    let indices_ok = i < j && j < k && k < n;
    if n < 6 || !indices_ok {
        return None;
    }

    let w = |from: usize, to: usize| dist[[from, to]];
    let (a, b) = (tour[i], tour[i + 1]);
    let (c, d) = (tour[j], tour[j + 1]);
    let (e, f) = (tour[k], tour[(k + 1) % n]);
    let baseline = w(a, b) + w(c, d) + w(e, f);

    // (cost, tail placed before mid, reverse mid, reverse tail)
    let options: [(f64, bool, bool, bool); 7] = [
        (w(a, c) + w(b, d) + w(e, f), false, true, false),
        (w(a, b) + w(c, e) + w(d, f), false, false, true),
        (w(a, c) + w(b, e) + w(d, f), false, true, true),
        (w(a, d) + w(e, b) + w(c, f), true, false, false),
        (w(a, d) + w(e, c) + w(b, f), true, true, false),
        (w(a, e) + w(d, b) + w(c, f), true, false, true),
        (w(a, e) + w(d, c) + w(b, f), true, true, true),
    ];
    let &(best_cost, tail_first, reverse_mid, reverse_tail) = options
        .iter()
        .min_by(|x, y| x.0.partial_cmp(&y.0).unwrap_or(Ordering::Equal))?;
    if best_cost >= baseline - 1e-10 {
        return None;
    }

    fn append(out: &mut Vec<usize>, part: &[usize], reversed: bool) {
        if reversed {
            out.extend(part.iter().rev());
        } else {
            out.extend_from_slice(part);
        }
    }

    let mid = &tour[i + 1..=j];
    let tail = &tour[j + 1..=k];
    let mut rebuilt = tour[..=i].to_vec();
    if tail_first {
        append(&mut rebuilt, tail, reverse_tail);
        append(&mut rebuilt, mid, reverse_mid);
    } else {
        append(&mut rebuilt, mid, reverse_mid);
        append(&mut rebuilt, tail, reverse_tail);
    }
    // `k + 1 == n` yields an empty slice, so no special case is needed.
    rebuilt.extend_from_slice(&tour[k + 1..]);
    Some(rebuilt)
}
/// Improves `tour` in place with Or-opt moves and returns the resulting closed tour length.
///
/// An Or-opt move cuts out a run of 1, 2 or 3 consecutive cities (the run may wrap past the
/// end of the vector). It re-inserts the run between two other adjacent cities, in its
/// original orientation or, for runs longer than one city, reversed. Each pass tries run
/// lengths 1, 2 and 3 in turn and applies at most one move per length. Passes repeat until
/// one applies no move. A move is taken only if its estimated saving exceeds `1e-10`.
///
/// The saving counts only the edges broken and created at the run's ends. Reversed
/// re-insertion is therefore priced as if `dist` were symmetric.
///
/// Tours with fewer than 4 cities are left unchanged.
pub fn or_opt(tour: &mut Vec<usize>, dist: &Array2<f64>) -> f64 {
    let n = tour.len();
    if n >= 4 {
        loop {
            let mut changed = false;
            for seg_len in 1..=3_usize {
                // The run needs at least one city on each side that is not part of it.
                if n >= seg_len + 2 {
                    // `|=` rather than `||` so every run length is still attempted this pass.
                    changed |= or_opt_try_relocate(tour, dist, seg_len);
                }
            }
            if !changed {
                break;
            }
        }
    }
    tour_length(tour, dist)
}

/// Scans runs of `seg_len` cities starting at each index in turn and relocates the first run
/// that has an improving insertion point, using the best point found for that run.
///
/// Returns `true` if `tour` was modified.
fn or_opt_try_relocate(tour: &mut Vec<usize>, dist: &Array2<f64>, seg_len: usize) -> bool {
    let n = tour.len();
    for seg_start in 0..n {
        let seg_end = (seg_start + seg_len - 1) % n;
        let before = (seg_start + n - 1) % n;
        let after = (seg_end + 1) % n;
        if before == seg_end || after == seg_start {
            continue;
        }
        let head = tour[seg_start];
        let tail = tour[seg_end];
        let (left, right) = (tour[before], tour[after]);
        // Length saved by cutting the run out and joining its neighbours directly.
        let saving = dist[[left, head]] + dist[[tail, right]] - dist[[left, right]];

        let in_run = |idx: usize| {
            if seg_start <= seg_end {
                (seg_start..=seg_end).contains(&idx)
            } else {
                idx >= seg_start || idx <= seg_end
            }
        };

        let mut best_gain = 1e-10;
        let mut choice: Option<(usize, bool)> = None;
        // Inserting after `before` would only restore the run to its current position.
        for ins in (0..n).filter(|&idx| !in_run(idx) && idx != before) {
            let u = tour[ins];
            let v = tour[(ins + 1) % n];
            let forward_gain = saving - (dist[[u, head]] + dist[[tail, v]] - dist[[u, v]]);
            if forward_gain > best_gain {
                best_gain = forward_gain;
                choice = Some((ins, false));
            }
            if seg_len > 1 {
                let reversed_gain = saving - (dist[[u, tail]] + dist[[head, v]] - dist[[u, v]]);
                if reversed_gain > best_gain {
                    best_gain = reversed_gain;
                    choice = Some((ins, true));
                }
            }
        }

        if let Some((ins, reversed)) = choice {
            let rebuilt = or_opt_splice(&tour[..], seg_start, seg_len, ins, reversed);
            // With repeated city ids the rebuild can change length; reject it in that case.
            if rebuilt.len() == n {
                *tour = rebuilt;
                return true;
            }
        }
    }
    false
}

/// Returns a copy of `tour` with the cities of the run moved to just after the city currently
/// at index `ins`, optionally reversed.
///
/// The run is the `seg_len` cities starting at index `seg_start`, wrapping past the end.
fn or_opt_splice(
    tour: &[usize],
    seg_start: usize,
    seg_len: usize,
    ins: usize,
    reversed: bool,
) -> Vec<usize> {
    let n = tour.len();
    let run: Vec<usize> = (0..seg_len).map(|k| tour[(seg_start + k) % n]).collect();
    let run_cities: std::collections::HashSet<usize> = run.iter().copied().collect();
    // The other cities keep their relative order; a wrapping run only rotates the cycle.
    let rest: Vec<usize> = tour
        .iter()
        .copied()
        .filter(|city| !run_cities.contains(city))
        .collect();
    let anchor = tour[ins];
    let cut = rest.iter().position(|&city| city == anchor).unwrap_or(0);
    let mut rebuilt = Vec::with_capacity(n);
    rebuilt.extend_from_slice(&rest[..=cut]);
    if reversed {
        rebuilt.extend(run.iter().rev());
    } else {
        rebuilt.extend_from_slice(&run);
    }
    rebuilt.extend_from_slice(&rest[cut + 1..]);
    rebuilt
}
/// Returns the weight of a minimum spanning tree over the cities of `dist`.
///
/// The tree is computed with Prim's algorithm rooted at city 0, reading edge weights as
/// `dist[[u, v]]`. For a symmetric matrix, removing one edge from any tour leaves a spanning
/// tree, so this value is a lower bound on the optimal tour length. Cities reachable only
/// through `+inf`/NaN edges are never added and contribute nothing.
///
/// Matrices with 0 or 1 rows yield `0.0`.
pub fn mst_lower_bound(dist: &Array2<f64>) -> f64 {
    let n = dist.nrows();
    if n <= 1 {
        return 0.0;
    }
    let mut done = vec![false; n];
    let mut cheapest = vec![f64::INFINITY; n];
    cheapest[0] = 0.0;
    let mut frontier: BinaryHeap<PrimEntry> = BinaryHeap::new();
    frontier.push(PrimEntry {
        cost: 0.0,
        vertex: 0,
    });
    let mut total = 0.0;
    while let Some(entry) = frontier.pop() {
        let u = entry.vertex;
        // Stale entries remain in the heap after a cheaper one was pushed; skip them.
        if done[u] {
            continue;
        }
        done[u] = true;
        total += entry.cost;
        for v in 0..n {
            if done[v] {
                continue;
            }
            let weight = dist[[u, v]];
            if weight < cheapest[v] {
                cheapest[v] = weight;
                frontier.push(PrimEntry {
                    cost: weight,
                    vertex: v,
                });
            }
        }
    }
    total
}
/// Multi-start heuristic solver for the travelling salesman problem.
///
/// `dist[[a, b]]` is the cost of travelling from city `a` to city `b`; the matrix is validated
/// to be square on construction.
#[derive(Debug, Clone)]
pub struct TspSolver {
    dist: Array2<f64>,
}
impl TspSolver {
    /// Creates a solver for the distance matrix `dist`.
    ///
    /// # Errors
    /// Returns [`OptimizeError::InvalidInput`] if `dist` is not square.
    pub fn new(dist: Array2<f64>) -> TspResult<Self> {
        if dist.nrows() != dist.ncols() {
            return Err(OptimizeError::InvalidInput(
                "Distance matrix must be square".to_string(),
            ));
        }
        Ok(Self { dist })
    }

    /// Finds a good (not necessarily optimal) tour and returns it with its closed length.
    ///
    /// Every city is tried as the start of [`nearest_neighbor_heuristic`]. Each tour is then
    /// refined with [`two_opt`] followed by [`or_opt`], and the shortest result is returned.
    /// A non-empty instance always yields a tour, even if every candidate length is `+inf` or
    /// NaN (e.g. when some city is unreachable). A finite length always beats a NaN one.
    /// An empty matrix yields an empty tour of length `0.0`.
    ///
    /// # Errors
    /// Propagates any error from [`nearest_neighbor_heuristic`].
    pub fn solve(&self) -> TspResult<(Vec<usize>, f64)> {
        let n = self.dist.nrows();
        if n == 0 {
            return Ok((vec![], 0.0));
        }
        let mut best: Option<(Vec<usize>, f64)> = None;
        for start in 0..n {
            let (mut tour, _) = nearest_neighbor_heuristic(&self.dist, start)?;
            two_opt(&mut tour, &self.dist);
            // `or_opt` returns the closed length of the tour it leaves behind.
            let len = or_opt(&mut tour, &self.dist);
            // Always keep the first candidate: comparing against an `INFINITY` sentinel would
            // discard every tour whose length is itself infinite or NaN.
            let is_better = match &best {
                None => true,
                Some((_, best_len)) => len < *best_len || best_len.is_nan(),
            };
            if is_better {
                best = Some((tour, len));
            }
        }
        // `n > 0`, so the loop ran at least once and `best` is populated.
        Ok(best.unwrap_or_default())
    }

    /// Returns the minimum-spanning-tree weight of the distance matrix (see
    /// [`mst_lower_bound`]); for symmetric matrices this bounds the optimal tour length from
    /// below.
    pub fn lower_bound(&self) -> f64 {
        mst_lower_bound(&self.dist)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Unit square: 1.0 between adjacent corners, ~sqrt(2) across the diagonals.
    fn square_dist() -> Array2<f64> {
        array![
            [0.0, 1.0, 1.414, 1.0],
            [1.0, 0.0, 1.0, 1.414],
            [1.414, 1.0, 0.0, 1.0],
            [1.0, 1.414, 1.0, 0.0]
        ]
    }

    /// Cities at positions 0..n on a line; the distance is the absolute difference.
    fn line_dist(n: usize) -> Array2<f64> {
        let mut dist = Array2::<f64>::zeros((n, n));
        for r in 0..n {
            for c in 0..n {
                dist[[r, c]] = ((r as f64) - (c as f64)).abs();
            }
        }
        dist
    }

    /// True if `tour` visits every city in `0..n` exactly once.
    fn is_permutation(tour: &[usize], n: usize) -> bool {
        let mut seen = vec![false; n];
        tour.len() == n
            && tour
                .iter()
                .all(|&c| c < n && !std::mem::replace(&mut seen[c], true))
    }

    #[test]
    fn test_tour_length() {
        let dist = square_dist();
        let tour = vec![0, 1, 2, 3];
        let len = tour_length(&tour, &dist);
        assert!((len - 4.0).abs() < 1e-6);
        assert_eq!(tour_length(&[], &dist), 0.0);
        assert_eq!(tour_length(&[2], &dist), 0.0);
    }

    #[test]
    fn test_nearest_neighbor() {
        let dist = square_dist();
        let (tour, len) = nearest_neighbor_heuristic(&dist, 0).expect("unexpected None or Err");
        assert_eq!(tour, vec![0, 1, 2, 3]);
        assert!((len - 4.0).abs() < 1e-9);
    }

    #[test]
    fn test_two_opt_improves() {
        let dist = square_dist();
        let mut tour = vec![0, 2, 1, 3];
        let original_len = tour_length(&tour, &dist);
        let new_len = two_opt(&mut tour, &dist);
        assert!(new_len < original_len);
        assert!((new_len - 4.0).abs() < 1e-9);
        assert!(is_permutation(&tour, 4));
    }

    #[test]
    fn test_or_opt() {
        let dist = square_dist();
        let mut tour = vec![0, 1, 2, 3];
        let len = or_opt(&mut tour, &dist);
        // Already optimal: no move may make it worse or drop a city.
        assert!((len - 4.0).abs() < 1e-9);
        assert!(is_permutation(&tour, 4));
    }

    #[test]
    fn test_mst_lower_bound() {
        let dist = square_dist();
        let lb = mst_lower_bound(&dist);
        // Three unit edges span the square.
        assert!((lb - 3.0).abs() < 1e-9);
        assert_eq!(mst_lower_bound(&Array2::zeros((0, 0))), 0.0);
    }

    #[test]
    fn test_solver_small() {
        let dist = square_dist();
        let solver = TspSolver::new(dist).expect("failed to create solver");
        let (tour, len) = solver.solve().expect("unexpected None or Err");
        assert!(is_permutation(&tour, 4));
        assert!((len - 4.0).abs() < 1e-9);
        assert!(solver.lower_bound() <= len + 1e-9);
    }

    #[test]
    fn test_solver_rejects_non_square() {
        assert!(TspSolver::new(Array2::zeros((2, 3))).is_err());
    }

    #[test]
    fn test_three_opt_move() {
        let big_dist = line_dist(6);

        // The in-order tour is optimal on a line, so no reconnection can help.
        let optimal: Vec<usize> = vec![0, 1, 2, 3, 4, 5];
        assert!(three_opt_move(&big_dist, 0, 2, 4, &optimal).is_none());

        // Swapping cities 1 and 2 back is an improving 3-opt move.
        let scrambled: Vec<usize> = vec![0, 2, 1, 3, 4, 5];
        let improved = three_opt_move(&big_dist, 0, 1, 2, &scrambled).expect("improving move");
        assert_eq!(improved, optimal);
        assert!(tour_length(&improved, &big_dist) < tour_length(&scrambled, &big_dist));

        // Invalid index order and too-short tours are rejected.
        assert!(three_opt_move(&big_dist, 2, 1, 4, &optimal).is_none());
        assert!(three_opt_move(&big_dist, 0, 1, 2, &[0, 1, 2, 3, 4]).is_none());
    }

    #[test]
    fn test_invalid_start() {
        let dist = square_dist();
        assert!(nearest_neighbor_heuristic(&dist, 10).is_err());
    }

    #[test]
    fn test_empty_tour() {
        let dist: Array2<f64> = Array2::zeros((0, 0));
        let (tour, len) = nearest_neighbor_heuristic(&dist, 0).expect("unexpected None or Err");
        assert!(tour.is_empty());
        assert_eq!(len, 0.0);
    }
}