use rand::prelude::*;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use crate::timing::Timer;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Strategy used to lower the temperature after each temperature level.
///
/// Every variant except `Linear` reads `SaConfig::cooling_rate`; the exact
/// update rules live in the runner's cooling step.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CoolingSchedule {
    /// `T <- T * cooling_rate` (the default).
    #[default]
    Geometric,
    /// `T <- T - step`, where the step spreads `initial_temp - final_temp`
    /// over `max_iterations / iterations_per_temp` levels.
    Linear,
    /// Geometric cooling whose rate is adjusted from the acceptance ratio
    /// observed during the previous level.
    Adaptive,
    /// `T <- T / (1 + cooling_rate * T)`.
    LundyMees,
}
/// Tuning parameters for a simulated-annealing run.
///
/// Start from `SaConfig::default()` / `SaConfig::new()` and adjust with the
/// `with_*` builder methods (several of which clamp their inputs).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SaConfig {
    /// Temperature at the start of the search; reheating never raises the
    /// temperature above this value.
    pub initial_temp: f64,
    /// The search continues only while the temperature is strictly above this.
    pub final_temp: f64,
    /// Rate parameter consumed by the Geometric, Adaptive and LundyMees schedules.
    pub cooling_rate: f64,
    /// Number of neighbour moves attempted at each temperature level.
    pub iterations_per_temp: usize,
    /// Optional cap on the total number of neighbour moves. The Linear schedule
    /// also uses it to size its per-level step (assuming 10 000 when `None`).
    pub max_iterations: Option<u64>,
    /// Rule used to lower the temperature between levels.
    pub cooling_schedule: CoolingSchedule,
    /// Optional wall-clock budget, checked between temperature levels.
    pub time_limit: Option<Duration>,
    /// Stop early once the best objective is >= this value (the search maximises).
    pub target_fitness: Option<f64>,
    /// Whether to raise the temperature when the search stagnates.
    pub enable_reheating: bool,
    /// Moves since the last new best (or last reheat) that trigger a reheat;
    /// evaluated at the end of each temperature level.
    pub reheat_threshold: u64,
    /// Multiplier applied to the temperature when reheating.
    pub reheat_factor: f64,
}
impl Default for SaConfig {
fn default() -> Self {
Self {
initial_temp: 1000.0,
final_temp: 0.001,
cooling_rate: 0.95,
iterations_per_temp: 100,
max_iterations: Some(100_000),
cooling_schedule: CoolingSchedule::Geometric,
time_limit: None,
target_fitness: None,
enable_reheating: false,
reheat_threshold: 1000,
reheat_factor: 2.0,
}
}
}
impl SaConfig {
    /// Same as [`SaConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the starting temperature, clamped to at least `0.001`.
    #[must_use]
    pub fn with_initial_temp(mut self, temp: f64) -> Self {
        self.initial_temp = temp.max(0.001);
        self
    }

    /// Sets the stopping temperature, clamped to at least `0.0001`.
    #[must_use]
    pub fn with_final_temp(mut self, temp: f64) -> Self {
        self.final_temp = temp.max(0.0001);
        self
    }

    /// Sets the cooling-rate parameter, clamped to `[0.001, 0.9999]`.
    #[must_use]
    pub fn with_cooling_rate(mut self, rate: f64) -> Self {
        self.cooling_rate = rate.clamp(0.001, 0.9999);
        self
    }

    /// Sets how many moves are tried per temperature level (at least 1).
    #[must_use]
    pub fn with_iterations_per_temp(mut self, iterations: usize) -> Self {
        self.iterations_per_temp = iterations.max(1);
        self
    }

    /// Caps the total number of neighbour moves across the whole run.
    #[must_use]
    pub fn with_max_iterations(mut self, iterations: u64) -> Self {
        self.max_iterations = Some(iterations);
        self
    }

    /// Selects the cooling schedule.
    #[must_use]
    pub fn with_cooling_schedule(mut self, schedule: CoolingSchedule) -> Self {
        self.cooling_schedule = schedule;
        self
    }

    /// Sets a wall-clock budget. It is checked between temperature levels, so a
    /// run may overshoot it by up to one level.
    #[must_use]
    pub fn with_time_limit(mut self, duration: Duration) -> Self {
        self.time_limit = Some(duration);
        self
    }

    /// Stops the search early once the best objective is `>= fitness`.
    #[must_use]
    pub fn with_target_fitness(mut self, fitness: f64) -> Self {
        self.target_fitness = Some(fitness);
        self
    }

    /// Enables reheating. Once `threshold` moves have passed without a new best,
    /// the temperature is multiplied by `factor` (clamped to at least `1.1` so it
    /// actually rises), capped at the initial temperature.
    #[must_use]
    pub fn with_reheating(mut self, threshold: u64, factor: f64) -> Self {
        self.enable_reheating = true;
        self.reheat_threshold = threshold;
        self.reheat_factor = factor.max(1.1);
        self
    }
}
/// A candidate solution that caches its own objective value.
///
/// The annealing runner treats larger objective values as better (it maximises).
/// NOTE(review): `Send + Sync` is presumably required so solutions can cross
/// threads in the rayon-based `run_parallel`.
pub trait SaSolution: Clone + Send + Sync {
    /// Returns the cached objective value (higher is better).
    fn objective(&self) -> f64;
    /// Stores an objective value, typically from `SaProblem::evaluate`.
    fn set_objective(&mut self, value: f64);
}
/// Tag naming which kind of neighbourhood move the runner requests.
///
/// The enum carries no behaviour: `SaProblem::neighbor` decides what each
/// variant means for a given problem. `PermutationSolution` provides one
/// `apply_*` method per variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NeighborhoodOperator {
    /// Exchange two elements.
    Swap,
    /// Move one element to a different position.
    Relocate,
    /// Reverse a contiguous segment.
    Inversion,
    /// Change the rotation/orientation value of one element.
    Rotation,
    /// Compound move (for `PermutationSolution`: exchange two adjacent segments).
    Chain,
}
/// An optimisation problem that the simulated-annealing runner can search.
///
/// The runner maximises `SaSolution::objective`. `Send + Sync` allows one
/// problem instance to be shared across threads (e.g. by `run_parallel`).
pub trait SaProblem: Send + Sync {
    /// Solution type produced and consumed by this problem.
    type Solution: SaSolution;
    /// Builds the starting solution; the runner calls `evaluate` on it immediately.
    fn initial_solution<R: Rng>(&self, rng: &mut R) -> Self::Solution;
    /// Returns a new candidate derived from `solution` using `operator`.
    ///
    /// The runner calls `evaluate` on the returned value before reading its
    /// objective, so implementations do not need to score it.
    fn neighbor<R: Rng>(
        &self,
        solution: &Self::Solution,
        operator: NeighborhoodOperator,
        rng: &mut R,
    ) -> Self::Solution;
    /// Scores `solution`. Implementations are expected to store the score via
    /// `set_objective`, because the runner reads it back through `objective()`
    /// right after this call.
    fn evaluate(&self, solution: &mut Self::Solution);
    /// Operators the runner samples from, uniformly at random, for each move.
    ///
    /// Defaults to `Swap`, `Relocate` and `Inversion`. Return at least one
    /// operator: with an empty list the runner cannot generate neighbours.
    fn available_operators(&self) -> Vec<NeighborhoodOperator> {
        vec![
            NeighborhoodOperator::Swap,
            NeighborhoodOperator::Relocate,
            NeighborhoodOperator::Inversion,
        ]
    }
    /// Hook invoked after each temperature level's batch of moves.
    ///
    /// It runs before any reheating/cooling and receives the temperature used
    /// during that level, the total move count so far, and the best and current
    /// solutions. The default does nothing.
    fn on_temperature_change(
        &self,
        _temperature: f64,
        _iteration: u64,
        _best: &Self::Solution,
        _current: &Self::Solution,
    ) {
    }
}
/// Snapshot of a run's state, presumably for progress reporting.
///
/// NOTE(review): nothing in the visible code constructs `SaProgress`. The
/// field meanings are inferred from their names and from the matching
/// quantities tracked by `SaRunner::run_with_rng`; confirm against the producer.
#[derive(Debug, Clone)]
pub struct SaProgress {
    pub temperature: f64,
    pub iteration: u64,
    pub best_fitness: f64,
    pub current_fitness: f64,
    pub acceptance_rate: f64,
    pub elapsed: Duration,
    pub running: bool,
}
/// Outcome of a simulated-annealing run.
#[derive(Debug, Clone)]
pub struct SaResult<S: SaSolution> {
    /// Best solution found (highest objective).
    pub best: S,
    /// Temperature at the moment the search loop ended.
    pub final_temperature: f64,
    /// Total number of neighbour moves attempted.
    pub iterations: u64,
    /// Time from the start of the run until the result was built, as measured
    /// by the crate's `Timer`.
    pub elapsed: Duration,
    /// Whether the configured `target_fitness` was reached.
    pub target_reached: bool,
    /// Number of reheats that were triggered.
    pub reheat_count: u32,
    /// Best objective recorded after each temperature level, followed by one
    /// final entry when the run ends.
    pub history: Vec<f64>,
}
/// Runs simulated-annealing searches for problem `P` with a fixed configuration.
pub struct SaRunner<P: SaProblem> {
    /// Parameters used by every run started from this runner.
    config: SaConfig,
    /// Problem being optimised.
    problem: P,
    /// Cancellation flag shared with callers through `cancel_handle`; runs poll it
    /// once per temperature level.
    cancelled: Arc<AtomicBool>,
}
impl<P: SaProblem> SaRunner<P> {
    /// Creates a runner that searches `problem` using `config`.
    ///
    /// The cancellation flag starts out cleared.
    pub fn new(config: SaConfig, problem: P) -> Self {
        Self {
            config,
            problem,
            cancelled: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Returns a shared handle to the cancellation flag.
    ///
    /// Storing `true` makes any current or later run on this runner stop at the
    /// next temperature-level boundary. The runner never clears the flag itself.
    pub fn cancel_handle(&self) -> Arc<AtomicBool> {
        Arc::clone(&self.cancelled)
    }

    /// Runs one annealing search using the thread-local RNG.
    pub fn run(&self) -> SaResult<P::Solution> {
        self.run_with_rng(&mut rand::rng())
    }

    /// Runs one annealing search driven by `rng` (use a seeded RNG for reproducible runs).
    ///
    /// The search maximises the solution objective. Each temperature level tries up
    /// to `iterations_per_temp` moves. For each move, a uniformly chosen operator
    /// produces a neighbour, which is accepted by the Metropolis criterion.
    ///
    /// Before each level the run stops on cancellation, time limit, iteration cap
    /// or reaching `target_fitness`. After each level it optionally reheats and
    /// then cools. The loop also ends once the temperature is no longer strictly
    /// above `final_temp`.
    ///
    /// If the problem offers no neighbourhood operators, no moves are attempted
    /// and the evaluated initial solution is returned.
    pub fn run_with_rng<R: Rng>(&self, rng: &mut R) -> SaResult<P::Solution> {
        let start = Timer::now();
        let mut history = Vec::new();
        let mut current = self.problem.initial_solution(rng);
        self.problem.evaluate(&mut current);
        let mut best = current.clone();
        let mut best_fitness = best.objective();
        let mut temperature = self.config.initial_temp;
        let mut iteration = 0u64;
        let mut reheat_count = 0u32;
        let mut stagnation_count = 0u64;
        let operators = self.problem.available_operators();
        let temp_delta = self.linear_step();
        let mut accepted_count = 0usize;
        let mut total_count = 0usize;

        // Sampling an operator index from an empty list would panic, so without at
        // least one operator there is nothing to search.
        let can_move = !operators.is_empty();

        while can_move && temperature > self.config.final_temp {
            if self.cancelled.load(Ordering::Relaxed) {
                break;
            }
            if let Some(limit) = self.config.time_limit {
                if start.elapsed() > limit {
                    break;
                }
            }
            if self.iteration_cap_hit(iteration) || self.meets_target(best_fitness) {
                break;
            }
            for _ in 0..self.config.iterations_per_temp {
                iteration += 1;
                total_count += 1;
                let operator = operators[rng.random_range(0..operators.len())];
                let mut neighbor = self.problem.neighbor(&current, operator, rng);
                self.problem.evaluate(&mut neighbor);
                let delta = neighbor.objective() - current.objective();
                if Self::metropolis_accept(delta, temperature, rng) {
                    accepted_count += 1;
                    current = neighbor;
                    if current.objective() > best_fitness {
                        best = current.clone();
                        best_fitness = best.objective();
                        stagnation_count = 0;
                    } else {
                        stagnation_count += 1;
                    }
                } else {
                    stagnation_count += 1;
                }
                // End the level as soon as the target is met (no need to keep
                // searching) or the iteration cap is reached.
                if self.meets_target(best_fitness) || self.iteration_cap_hit(iteration) {
                    break;
                }
            }
            history.push(best_fitness);
            self.problem
                .on_temperature_change(temperature, iteration, &best, &current);
            if self.config.enable_reheating && stagnation_count >= self.config.reheat_threshold {
                // Escape a plateau by raising the temperature, never above the start value.
                temperature = (temperature * self.config.reheat_factor).min(self.config.initial_temp);
                stagnation_count = 0;
                reheat_count += 1;
            }
            temperature = self.cool_down(temperature, temp_delta, accepted_count, total_count);
            accepted_count = 0;
            total_count = 0;
        }
        history.push(best_fitness);
        SaResult {
            best,
            final_temperature: temperature,
            iterations: iteration,
            elapsed: start.elapsed(),
            // Judge the target on the final best value, so a target hit during the
            // last level is reported even when another condition ended the loop.
            target_reached: self.meets_target(best_fitness),
            reheat_count,
            history,
        }
    }

    /// Metropolis acceptance rule for maximisation.
    ///
    /// Non-worsening moves (`delta >= 0`) are always accepted. Worsening moves
    /// are accepted with probability `exp(delta / temperature)`. A random number
    /// is drawn only in the latter case, which includes a NaN `delta`; NaN moves
    /// are then always rejected.
    fn metropolis_accept<R: Rng>(delta: f64, temperature: f64, rng: &mut R) -> bool {
        delta >= 0.0 || rng.random::<f64>() < (delta / temperature).exp()
    }

    /// Whether `best_fitness` satisfies the configured `target_fitness` (false when unset).
    fn meets_target(&self, best_fitness: f64) -> bool {
        matches!(self.config.target_fitness, Some(target) if best_fitness >= target)
    }

    /// Whether `iteration` has reached the configured `max_iterations` (false when unset).
    fn iteration_cap_hit(&self, iteration: u64) -> bool {
        matches!(self.config.max_iterations, Some(max) if iteration >= max)
    }

    /// Per-level temperature decrement for the Linear schedule (0 for other schedules).
    ///
    /// Spreads `initial_temp - final_temp` evenly over the expected number of levels,
    /// `max_iterations / iterations_per_temp`. When no cap is set, 10 000
    /// iterations are assumed.
    fn linear_step(&self) -> f64 {
        if !matches!(self.config.cooling_schedule, CoolingSchedule::Linear) {
            return 0.0;
        }
        let levels = self.config.max_iterations.unwrap_or(10000) as f64
            / self.config.iterations_per_temp as f64;
        (self.config.initial_temp - self.config.final_temp) / levels
    }

    /// Returns the temperature for the next level under the configured schedule.
    ///
    /// * `Geometric`: `T * cooling_rate`.
    /// * `Linear`: `T - delta`, floored at `final_temp`.
    /// * `Adaptive`: `T` times a rate derived from `cooling_rate` and the level's
    ///   acceptance ratio `accepted / total`. Above 0.5 the rate is multiplied by
    ///   0.95 (cool faster). Below 0.1 it is replaced by its square root (cool
    ///   slower). Otherwise it is unchanged.
    /// * `LundyMees`: `T / (1 + cooling_rate * T)`.
    fn cool_down(&self, current_temp: f64, delta: f64, accepted: usize, total: usize) -> f64 {
        let rate = self.config.cooling_rate;
        match self.config.cooling_schedule {
            CoolingSchedule::Geometric => current_temp * rate,
            CoolingSchedule::Linear => (current_temp - delta).max(self.config.final_temp),
            CoolingSchedule::Adaptive => {
                // With no moves recorded, use a neutral ratio that leaves the rate as is.
                let acceptance_rate = if total > 0 {
                    accepted as f64 / total as f64
                } else {
                    0.5
                };
                let adjusted_rate = if acceptance_rate > 0.5 {
                    rate * 0.95
                } else if acceptance_rate < 0.1 {
                    rate.powf(0.5)
                } else {
                    rate
                };
                current_temp * adjusted_rate
            }
            CoolingSchedule::LundyMees => current_temp / (1.0 + rate * current_temp),
        }
    }

    /// Runs `num_restarts` independent searches (at least one) in parallel and
    /// returns the result with the highest best objective.
    ///
    /// Each restart draws from the thread-local RNG of the worker thread it runs on.
    /// All restarts share this runner's cancellation flag. Incomparable objectives
    /// (NaN) are treated as equal when picking the winner.
    #[cfg(feature = "parallel")]
    pub fn run_parallel(&self, num_restarts: usize) -> SaResult<P::Solution>
    where
        P: Clone,
    {
        let num_restarts = num_restarts.max(1);
        let results: Vec<SaResult<P::Solution>> = (0..num_restarts)
            .into_par_iter()
            .map(|_| {
                let mut rng = rand::rng();
                self.run_with_rng(&mut rng)
            })
            .collect();
        results
            .into_iter()
            .max_by(|a, b| {
                a.best
                    .objective()
                    .partial_cmp(&b.best.objective())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .expect("At least one result should exist")
    }
}
/// Generic solution for ordering problems: a permutation of `0..len` plus one
/// rotation value per index.
#[derive(Debug, Clone)]
pub struct PermutationSolution {
    /// The ordering; the constructors produce a permutation of `0..size`.
    pub sequence: Vec<usize>,
    /// One rotation value per index. NOTE(review): whether index `i` refers to
    /// sequence position `i` or element id `i` is up to the problem, since
    /// nothing here fixes it.
    pub rotations: Vec<usize>,
    /// Number of distinct rotation values; `random` draws from
    /// `0..max(rotation_options, 1)`.
    pub rotation_options: usize,
    /// Cached objective; constructors and moves set it to `f64::NEG_INFINITY`
    /// to mark the solution as unevaluated.
    objective: f64,
}
impl PermutationSolution {
    /// Identity permutation `0..size` with every rotation set to 0 and the
    /// objective unset (`f64::NEG_INFINITY`).
    pub fn new(size: usize, rotation_options: usize) -> Self {
        Self {
            sequence: (0..size).collect(),
            rotations: vec![0; size],
            rotation_options,
            objective: f64::NEG_INFINITY,
        }
    }

    /// Uniformly shuffled permutation of `0..size`. Each rotation is drawn
    /// uniformly from `0..rotation_options`, with `rotation_options` treated as
    /// at least 1. The objective is left unset (`f64::NEG_INFINITY`).
    pub fn random<R: Rng>(size: usize, rotation_options: usize, rng: &mut R) -> Self {
        let mut sequence: Vec<usize> = (0..size).collect();
        sequence.shuffle(rng);
        let rotations: Vec<usize> = (0..size)
            .map(|_| rng.random_range(0..rotation_options.max(1)))
            .collect();
        Self {
            sequence,
            rotations,
            rotation_options,
            objective: f64::NEG_INFINITY,
        }
    }

    /// Number of elements in the sequence.
    pub fn len(&self) -> usize {
        self.sequence.len()
    }

    /// Whether the sequence is empty.
    pub fn is_empty(&self) -> bool {
        self.sequence.is_empty()
    }

    /// Draws two *distinct* indices in `0..n`, in random order. Requires `n >= 2`.
    ///
    /// The second index is sampled from the `n - 1` remaining slots and then
    /// shifted past the first, so no rejection loop is needed.
    fn distinct_pair<R: Rng>(n: usize, rng: &mut R) -> (usize, usize) {
        let a = rng.random_range(0..n);
        let mut b = rng.random_range(0..n - 1);
        if b >= a {
            b += 1;
        }
        (a, b)
    }

    /// Returns a copy with two distinct positions exchanged.
    ///
    /// The copy's objective is reset to `f64::NEG_INFINITY`. With fewer than two
    /// elements an unchanged clone is returned.
    pub fn apply_swap<R: Rng>(&self, rng: &mut R) -> Self {
        let mut result = self.clone();
        if result.sequence.len() < 2 {
            return result;
        }
        let (i, j) = Self::distinct_pair(result.sequence.len(), rng);
        result.sequence.swap(i, j);
        result.objective = f64::NEG_INFINITY;
        result
    }

    /// Returns a copy where the element at a random index `from` is moved so it
    /// ends up at a different random index `to`. The elements in between shift
    /// by one, and every target position, including the last, is reachable.
    ///
    /// The copy's objective is reset to `f64::NEG_INFINITY`. With fewer than two
    /// elements an unchanged clone is returned.
    pub fn apply_relocate<R: Rng>(&self, rng: &mut R) -> Self {
        let mut result = self.clone();
        let n = result.sequence.len();
        if n < 2 {
            return result;
        }
        let (from, to) = Self::distinct_pair(n, rng);
        // Rotating the sub-slice spanning both indices moves the element in place.
        if from < to {
            result.sequence[from..=to].rotate_left(1);
        } else {
            result.sequence[to..=from].rotate_right(1);
        }
        result.objective = f64::NEG_INFINITY;
        result
    }

    /// Returns a copy with the segment between two distinct positions
    /// (inclusive) reversed.
    ///
    /// The copy's objective is reset to `f64::NEG_INFINITY`. With fewer than two
    /// elements an unchanged clone is returned.
    pub fn apply_inversion<R: Rng>(&self, rng: &mut R) -> Self {
        let mut result = self.clone();
        let n = result.sequence.len();
        if n < 2 {
            return result;
        }
        let (a, b) = Self::distinct_pair(n, rng);
        let (lo, hi) = (a.min(b), a.max(b));
        result.sequence[lo..=hi].reverse();
        result.objective = f64::NEG_INFINITY;
        result
    }

    /// Returns a copy where one random entry of `rotations` is set to a
    /// *different* value in `0..rotation_options`.
    ///
    /// The copy's objective is reset to `f64::NEG_INFINITY`. If there are no
    /// rotations or at most one option, an unchanged clone is returned.
    pub fn apply_rotation<R: Rng>(&self, rng: &mut R) -> Self {
        let mut result = self.clone();
        if result.rotations.is_empty() || result.rotation_options <= 1 {
            return result;
        }
        let idx = rng.random_range(0..result.rotations.len());
        let old = result.rotations[idx];
        // Choose among the other `rotation_options - 1` values, skipping `old`.
        // If `old` is already out of range, the draw stays below it and is valid as is.
        let mut new = rng.random_range(0..result.rotation_options - 1);
        if new >= old {
            new += 1;
        }
        result.rotations[idx] = new;
        result.objective = f64::NEG_INFINITY;
        result
    }

    /// Returns a copy in which two adjacent segments are exchanged.
    ///
    /// Three distinct cut points `p1 < p2 < p3` are picked, and the segments
    /// `[p1, p2)` and `[p2, p3)` trade places. With fewer than four elements
    /// this falls back to [`Self::apply_swap`]. The copy's objective is reset
    /// to `f64::NEG_INFINITY`.
    pub fn apply_chain<R: Rng>(&self, rng: &mut R) -> Self {
        let n = self.sequence.len();
        if n < 4 {
            return self.apply_swap(rng);
        }
        let mut result = self.clone();
        let mut positions: Vec<usize> = (0..n).collect();
        positions.shuffle(rng);
        let mut cuts: Vec<usize> = positions.into_iter().take(3).collect();
        cuts.sort_unstable();
        let (p1, p2, p3) = (cuts[0], cuts[1], cuts[2]);
        // Rotating [p1, p3) left by the first segment's length puts the second
        // segment first: [seg2, seg3] -> [seg3, seg2], with no temporary vectors.
        result.sequence[p1..p3].rotate_left(p2 - p1);
        result.objective = f64::NEG_INFINITY;
        result
    }
}
impl SaSolution for PermutationSolution {
    /// Reads the cached objective (the constructors initialise it to `f64::NEG_INFINITY`).
    fn objective(&self) -> f64 {
        let Self { objective, .. } = self;
        *objective
    }

    /// Overwrites the cached objective with `value`.
    fn set_objective(&mut self, value: f64) {
        let Self { objective, .. } = self;
        *objective = value;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Toy problem: maximise the negated inversion count of a permutation of
    /// `0..size`. The identity permutation attains the optimum, 0.
    #[derive(Clone)]
    struct SimpleMaxProblem {
        size: usize,
    }

    impl SaProblem for SimpleMaxProblem {
        type Solution = PermutationSolution;

        fn initial_solution<R: Rng>(&self, rng: &mut R) -> Self::Solution {
            PermutationSolution::random(self.size, 1, rng)
        }

        fn neighbor<R: Rng>(
            &self,
            solution: &Self::Solution,
            operator: NeighborhoodOperator,
            rng: &mut R,
        ) -> Self::Solution {
            match operator {
                NeighborhoodOperator::Swap => solution.apply_swap(rng),
                NeighborhoodOperator::Relocate => solution.apply_relocate(rng),
                NeighborhoodOperator::Inversion => solution.apply_inversion(rng),
                NeighborhoodOperator::Rotation => solution.apply_rotation(rng),
                NeighborhoodOperator::Chain => solution.apply_chain(rng),
            }
        }

        fn evaluate(&self, solution: &mut Self::Solution) {
            let seq = &solution.sequence;
            // Count pairs (i, j) with i < j and seq[i] > seq[j].
            let inversions: usize = seq
                .iter()
                .enumerate()
                .map(|(i, &a)| seq[i + 1..].iter().filter(|&&b| a > b).count())
                .sum();
            solution.set_objective(-(inversions as f64));
        }
    }

    #[test]
    fn test_sa_basic() {
        let config = SaConfig::default()
            .with_initial_temp(100.0)
            .with_final_temp(0.1)
            .with_cooling_rate(0.9)
            .with_iterations_per_temp(50)
            .with_max_iterations(5000);
        let problem = SimpleMaxProblem { size: 10 };
        let runner = SaRunner::new(config, problem);
        let result = runner.run();
        assert!(result.best.objective() > -20.0);
        assert!(result.iterations > 0);
    }

    #[test]
    fn test_cooling_schedules() {
        let problem = SimpleMaxProblem { size: 5 };
        for schedule in [
            CoolingSchedule::Geometric,
            CoolingSchedule::Linear,
            CoolingSchedule::Adaptive,
            CoolingSchedule::LundyMees,
        ] {
            let config = SaConfig::default()
                .with_cooling_schedule(schedule)
                .with_max_iterations(1000);
            let runner = SaRunner::new(config, problem.clone());
            let result = runner.run();
            assert!(result.iterations > 0);
        }
    }

    #[test]
    fn test_neighborhood_operators() {
        let mut rng = rand::rng();
        let solution = PermutationSolution::random(10, 4, &mut rng);
        let swap = solution.apply_swap(&mut rng);
        let relocate = solution.apply_relocate(&mut rng);
        let inversion = solution.apply_inversion(&mut rng);
        let rotation = solution.apply_rotation(&mut rng);
        let chain = solution.apply_chain(&mut rng);
        for sol in [&swap, &relocate, &inversion, &rotation, &chain] {
            // Every move must keep the sequence a permutation of 0..10 ...
            let mut sorted = sol.sequence.clone();
            sorted.sort();
            assert_eq!(sorted, (0..10).collect::<Vec<_>>());
            // ... and keep every rotation within the configured range.
            assert_eq!(sol.rotations.len(), 10);
            assert!(sol.rotations.iter().all(|&r| r < 4));
        }
    }

    #[test]
    fn test_reheating() {
        let config = SaConfig::default()
            .with_initial_temp(10.0)
            .with_final_temp(0.1)
            .with_max_iterations(500)
            .with_reheating(50, 1.5);
        let problem = SimpleMaxProblem { size: 8 };
        let runner = SaRunner::new(config, problem);
        let result = runner.run();
        assert!(result.iterations > 0);
    }

    #[test]
    fn test_cancel_before_run() {
        let runner = SaRunner::new(SaConfig::default(), SimpleMaxProblem { size: 6 });
        runner
            .cancel_handle()
            .store(true, std::sync::atomic::Ordering::SeqCst);
        let result = runner.run();
        assert_eq!(result.iterations, 0);
        assert!(!result.target_reached);
        // Only the final history entry is recorded when no level runs.
        assert_eq!(result.history.len(), 1);
    }

    #[test]
    fn test_target_reached_immediately() {
        // Any permutation of 6 elements scores >= -15, so this target holds at the start.
        let config = SaConfig::default().with_target_fitness(-1.0e9);
        let result = SaRunner::new(config, SimpleMaxProblem { size: 6 }).run();
        assert!(result.target_reached);
        assert_eq!(result.iterations, 0);
    }
}