use crate::ising::{IsingError, IsingModel, IsingResult};
use crate::simulator::{AnnealingError, AnnealingResult, AnnealingSolution};
use scirs2_core::random::ChaCha8Rng;
use scirs2_core::random::{thread_rng, Rng, RngExt, SeedableRng};
use std::time::{Duration, Instant};
/// Piecewise schedule for the annealing parameter `s` used in reverse annealing.
///
/// Over normalized time the schedule ramps `s` from [`s_start`](Self::s_start) down to
/// [`s_target`](Self::s_target), pauses there, ramps back up towards `1.0`, and finally
/// holds at `1.0`.
#[derive(Debug, Clone)]
pub struct ReverseAnnealingSchedule {
    /// Value of `s` at the very start of the schedule; the reverse ramp begins here.
    pub s_start: f64,
    /// Value of `s` reached by the reverse ramp and kept during the pause.
    pub s_target: f64,
    /// Fraction of normalized time spent paused at `s_target`.
    pub pause_duration: f64,
    /// Multiplier applied to the slope of the forward (return) ramp.
    pub quench_rate: f64,
    /// Fraction of normalized time at the end during which `s` stays at `1.0`.
    pub hold_duration: f64,
}

impl Default for ReverseAnnealingSchedule {
    /// Start at `s = 1.0`, reverse to `s = 0.45`, pause for 10% of the run, return at the
    /// nominal rate, and do not hold at the end.
    fn default() -> Self {
        let (s_start, s_target) = (1.0, 0.45);
        let (pause_duration, hold_duration) = (0.1, 0.0);
        Self {
            s_start,
            s_target,
            pause_duration,
            quench_rate: 1.0,
            hold_duration,
        }
    }
}
impl ReverseAnnealingSchedule {
    /// Creates a schedule that starts at `s = 1.0` and reverses to `s_target`.
    ///
    /// `quench_rate` defaults to `1.0` and `hold_duration` to `0.0`.
    ///
    /// # Errors
    ///
    /// Returns [`AnnealingError::InvalidSchedule`] if `s_target` or `pause_duration` lies
    /// outside `[0, 1]` (NaN is rejected too, because it is not contained in the range).
    pub fn new(s_target: f64, pause_duration: f64) -> AnnealingResult<Self> {
        if !(0.0..=1.0).contains(&s_target) {
            return Err(AnnealingError::InvalidSchedule(format!(
                "s_target must be in [0,1], got {s_target}"
            )));
        }
        if !(0.0..=1.0).contains(&pause_duration) {
            return Err(AnnealingError::InvalidSchedule(format!(
                "pause_duration must be in [0,1], got {pause_duration}"
            )));
        }
        Ok(Self {
            s_start: 1.0,
            s_target,
            pause_duration,
            quench_rate: 1.0,
            hold_duration: 0.0,
        })
    }

    /// Evaluates the annealing parameter `s` at normalized time `t_normalized`
    /// (nominally in `[0, 1]`).
    ///
    /// The time not spent pausing or holding is split equally between two linear ramps:
    /// - a reverse ramp from `s_start` down to `s_target`;
    /// - after the pause, a forward ramp from `s_target` towards `1.0`, whose slope is
    ///   scaled by `quench_rate`.
    ///
    /// During the final `hold_duration`, `s` is `1.0`. For finite inputs the result is
    /// clamped to `[0, 1]`.
    #[must_use]
    pub fn s_of_t(&self, t_normalized: f64) -> f64 {
        let ramp_len = (1.0 - self.pause_duration - self.hold_duration) / 2.0;
        let reverse_end = ramp_len;
        let pause_end = reverse_end + self.pause_duration;
        let forward_end = 1.0 - self.hold_duration;

        let s = if t_normalized <= reverse_end {
            // When pause + hold fill the whole run there is no time for the ramp, and
            // `t / ramp_len` would be 0/0 (NaN) at t = 0. Treat the ramp as not yet
            // started so that t = 0 still maps to `s_start`.
            let progress = if ramp_len > 0.0 {
                t_normalized / ramp_len
            } else {
                0.0
            };
            (self.s_target - self.s_start).mul_add(progress, self.s_start)
        } else if t_normalized <= pause_end {
            self.s_target
        } else if t_normalized <= forward_end {
            // Only reachable when forward_end > pause_end, so the divisor is positive.
            let forward_progress = (t_normalized - pause_end) / (forward_end - pause_end);
            ((1.0 - self.s_target) * forward_progress).mul_add(self.quench_rate, self.s_target)
        } else {
            1.0
        };

        // A quench_rate above 1 overshoots 1.0 before the ramp ends. Any `s` outside
        // [0, 1] would make `transverse_field` negative, so cap it.
        s.clamp(0.0, 1.0)
    }

    /// Transverse-field amplitude `A(s) = 2 * (1 - s)`: 2 at `s = 0`, 0 at `s = 1`.
    #[must_use]
    pub fn transverse_field(&self, s: f64) -> f64 {
        const A_MAX: f64 = 2.0;
        A_MAX * (1.0 - s)
    }

    /// Problem-Hamiltonian strength `B(s) = s`: 0 at `s = 0`, 1 at `s = 1`.
    #[must_use]
    pub fn problem_strength(&self, s: f64) -> f64 {
        const B_MAX: f64 = 1.0;
        B_MAX * s
    }
}
/// Configuration for a reverse-annealing run.
#[derive(Debug, Clone)]
pub struct ReverseAnnealingParams {
    /// Schedule for the annealing parameter `s`.
    pub schedule: ReverseAnnealingSchedule,
    /// Spin configuration each repetition starts from (one entry per qubit of the model).
    pub initial_state: Vec<i8>,
    /// Number of sweeps performed in each repetition.
    pub num_sweeps: usize,
    /// Number of independent repetitions.
    pub num_repetitions: usize,
    /// RNG seed; `None` means the simulator seeds itself from the thread RNG.
    pub seed: Option<u64>,
    /// Fraction of spins to re-randomize before each repetition (`0.0` disables it).
    pub reinitialize_fraction: f64,
    /// Neighbourhood radius for local search, or `None` when local search is disabled.
    pub local_search_radius: Option<usize>,
}

impl ReverseAnnealingParams {
    /// Creates parameters for `initial_state` with these defaults:
    /// - the default schedule;
    /// - 1000 sweeps and 10 repetitions;
    /// - no fixed seed, no reinitialization and no local search.
    #[must_use]
    pub fn new(initial_state: Vec<i8>) -> Self {
        let schedule = ReverseAnnealingSchedule::default();
        Self {
            initial_state,
            schedule,
            num_repetitions: 10,
            num_sweeps: 1000,
            seed: None,
            local_search_radius: None,
            reinitialize_fraction: 0.0,
        }
    }

    /// Enables local search with the given neighbourhood `radius`.
    #[must_use]
    pub const fn with_local_search(self, radius: usize) -> Self {
        let mut updated = self;
        updated.local_search_radius = Some(radius);
        updated
    }

    /// Sets the reinitialization fraction, clamped into `[0, 1]`.
    ///
    /// A NaN fraction is stored unchanged (as NaN).
    #[must_use]
    pub const fn with_reinitialization(self, fraction: f64) -> Self {
        let mut updated = self;
        updated.reinitialize_fraction = fraction.clamp(0.0, 1.0);
        updated
    }
}
/// Runs reverse annealing starting from a user-supplied spin configuration.
///
/// Each repetition starts from the configured initial state (optionally perturbed). It then
/// runs Metropolis-style sweeps whose effective temperature follows the transverse field of
/// the schedule. The lowest-energy configuration seen (including the initial state itself)
/// is reported.
pub struct ReverseAnnealingSimulator {
    params: ReverseAnnealingParams,
    rng: ChaCha8Rng,
}

impl ReverseAnnealingSimulator {
    /// Creates a simulator.
    ///
    /// The RNG is seeded from `params.seed`, or from the thread RNG when no seed is given.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn new(params: ReverseAnnealingParams) -> AnnealingResult<Self> {
        let rng = match params.seed {
            Some(seed) => ChaCha8Rng::seed_from_u64(seed),
            None => ChaCha8Rng::seed_from_u64(thread_rng().random()),
        };
        Ok(Self { params, rng })
    }

    /// Runs `num_repetitions` reverse-annealing passes on `model` and returns the best result.
    ///
    /// The unmodified initial state is the starting candidate, so the reported energy is never
    /// worse than the initial state's energy.
    ///
    /// # Errors
    ///
    /// * [`AnnealingError::InvalidParameter`] if `initial_state` does not have exactly one
    ///   spin per qubit of `model`.
    /// * [`AnnealingError::IsingError`] if `model` fails to evaluate an energy.
    pub fn solve(&mut self, model: &IsingModel) -> AnnealingResult<AnnealingSolution> {
        let start_time = Instant::now();
        let num_qubits = model.num_qubits;
        if self.params.initial_state.len() != num_qubits {
            return Err(AnnealingError::InvalidParameter(format!(
                "Initial state length {} doesn't match model size {}",
                self.params.initial_state.len(),
                num_qubits
            )));
        }

        let mut best_solution = self.params.initial_state.clone();
        let mut best_energy = model
            .energy(&best_solution)
            .map_err(AnnealingError::IsingError)?;

        for _ in 0..self.params.num_repetitions {
            // RNG draws happen in a fixed order: reinitialization, then local-search
            // centers, then the sweeps themselves.
            let state = self.prepare_initial_state(self.params.initial_state.clone());
            let sites = self.local_search_sites(num_qubits);
            let solution = self.run_reverse_annealing(model, state, sites.as_deref())?;
            let energy = model
                .energy(&solution)
                .map_err(AnnealingError::IsingError)?;
            if energy < best_energy {
                best_energy = energy;
                best_solution = solution;
            }
        }

        Ok(AnnealingSolution {
            best_spins: best_solution,
            best_energy,
            repetitions: self.params.num_repetitions,
            total_sweeps: self.params.num_sweeps * self.params.num_repetitions,
            runtime: start_time.elapsed(),
            info: format!(
                "Reverse annealing with {} repetitions, {} sweeps each, s_target={}",
                self.params.num_repetitions, self.params.num_sweeps, self.params.schedule.s_target
            ),
        })
    }

    /// Returns `state` with roughly `reinitialize_fraction` of its spins re-randomized.
    ///
    /// Indices are drawn with replacement, and each drawn spin is set to +1 or -1 with equal
    /// probability. The number of spins that actually change is therefore usually smaller
    /// than `len * reinitialize_fraction`.
    fn prepare_initial_state(&mut self, mut state: Vec<i8>) -> Vec<i8> {
        let fraction = self.params.reinitialize_fraction;
        if fraction > 0.0 {
            let num_to_reinit = (state.len() as f64 * fraction) as usize;
            for _ in 0..num_to_reinit {
                let idx = self.rng.random_range(0..state.len());
                state[idx] = if self.rng.random_bool(0.5) { 1 } else { -1 };
            }
        }
        state
    }

    /// Chooses which spins may be flipped in the next repetition.
    ///
    /// Returns `None` when local search is disabled (every spin may flip). Otherwise it draws
    /// `max(1, floor(0.1 * num_spins))` random centers. It returns, in ascending order, every
    /// index within `local_search_radius` of at least one center.
    fn local_search_sites(&mut self, num_spins: usize) -> Option<Vec<usize>> {
        let radius = self.params.local_search_radius?;
        if num_spins == 0 {
            // Nothing to choose from; sampling from the empty range `0..0` would panic.
            return Some(Vec::new());
        }
        let num_centers = (num_spins as f64 * 0.1).max(1.0) as usize;
        let mut can_update = vec![false; num_spins];
        for _ in 0..num_centers {
            let center = self.rng.random_range(0..num_spins);
            // Saturating arithmetic instead of `as i32` casts, which wrap for large
            // radii or very large models.
            let lo = center.saturating_sub(radius);
            let hi = center.saturating_add(radius).min(num_spins - 1);
            can_update[lo..=hi].fill(true);
        }
        Some(
            can_update
                .iter()
                .enumerate()
                .filter_map(|(i, &ok)| ok.then_some(i))
                .collect(),
        )
    }

    /// Runs `num_sweeps` Metropolis-style sweeps over `state` and returns the final spins.
    ///
    /// `sites` restricts which spins may flip (`None` means all spins). Each sweep makes one
    /// flip attempt per eligible spin, picking spins uniformly at random with replacement.
    fn run_reverse_annealing(
        &mut self,
        model: &IsingModel,
        mut state: Vec<i8>,
        sites: Option<&[usize]>,
    ) -> AnnealingResult<Vec<i8>> {
        let num_qubits = model.num_qubits;
        let attempts_per_sweep = sites.map_or(num_qubits, |eligible| eligible.len());
        if attempts_per_sweep == 0 {
            return Ok(state);
        }

        let num_sweeps = self.params.num_sweeps;
        let schedule = &self.params.schedule;
        for sweep in 0..num_sweeps {
            let t_norm = sweep as f64 / num_sweeps as f64;
            // Defensive clamp: an `s` outside [0, 1] would give a negative transverse field
            // and hence a non-positive effective temperature.
            let s = schedule.s_of_t(t_norm).clamp(0.0, 1.0);
            let problem_strength = schedule.problem_strength(s);
            // The transverse field is emulated as extra thermal noise on top of a 0.1 floor.
            let effective_temp = schedule.transverse_field(s).mul_add(0.5, 0.1);

            for _ in 0..attempts_per_sweep {
                let pick = self.rng.random_range(0..attempts_per_sweep);
                let i = sites.map_or(pick, |eligible| eligible[pick]);

                let mut h_local = 0.0;
                if let Ok(bias) = model.get_bias(i) {
                    h_local += bias * problem_strength;
                }
                for j in (0..num_qubits).filter(|&j| j != i) {
                    if let Ok(coupling) = model.get_coupling(i, j) {
                        h_local += coupling * f64::from(state[j]) * problem_strength;
                    }
                }

                // NOTE(review): ΔE = +2·s_i·h_local matches E = -Σh·s - ΣJ·s·s. If
                // `IsingModel::energy` uses the + convention (E = Σh·s + ΣJ·s·s), the
                // sign here should be flipped. Confirm against `IsingModel::energy`.
                let delta_e = 2.0 * f64::from(state[i]) * h_local;
                let accept_prob = (-delta_e / effective_temp).exp().min(1.0);
                if self.rng.random_bool(accept_prob) {
                    state[i] *= -1;
                }
            }
        }
        Ok(state)
    }
}
/// Fluent builder for [`ReverseAnnealingSchedule`]; all validation happens in
/// [`build`](Self::build).
#[derive(Debug, Clone)]
pub struct ReverseAnnealingScheduleBuilder {
    s_target: f64,
    pause_duration: f64,
    quench_rate: f64,
    hold_duration: f64,
}

impl ReverseAnnealingScheduleBuilder {
    /// Starts from `s_target = 0.45`, `pause_duration = 0.1`, `quench_rate = 1.0` and
    /// `hold_duration = 0.0`, the same values as `ReverseAnnealingSchedule::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            s_target: 0.45,
            pause_duration: 0.1,
            quench_rate: 1.0,
            hold_duration: 0.0,
        }
    }

    /// Sets the value of `s` the reverse ramp reaches and pauses at.
    #[must_use]
    pub const fn s_target(mut self, s: f64) -> Self {
        self.s_target = s;
        self
    }

    /// Sets the fraction of normalized time spent paused at `s_target`.
    #[must_use]
    pub const fn pause_duration(mut self, duration: f64) -> Self {
        self.pause_duration = duration;
        self
    }

    /// Sets the multiplier on the slope of the forward (return) ramp.
    #[must_use]
    pub const fn quench_rate(mut self, rate: f64) -> Self {
        self.quench_rate = rate;
        self
    }

    /// Sets the fraction of normalized time at the end during which `s` stays at `1.0`.
    #[must_use]
    pub const fn hold_duration(mut self, duration: f64) -> Self {
        self.hold_duration = duration;
        self
    }

    /// Validates the configured values and builds the schedule (with `s_start = 1.0`).
    ///
    /// # Errors
    ///
    /// Returns [`AnnealingError::InvalidSchedule`] if any of the following holds:
    /// - `s_target`, `pause_duration` or `hold_duration` lies outside `[0, 1]` (or is NaN);
    /// - `pause_duration + hold_duration > 1`, which would leave a negative ramp length;
    /// - `quench_rate` is negative or not finite.
    pub fn build(self) -> AnnealingResult<ReverseAnnealingSchedule> {
        let Self {
            s_target,
            pause_duration,
            quench_rate,
            hold_duration,
        } = self;

        if !(0.0..=1.0).contains(&s_target) {
            return Err(AnnealingError::InvalidSchedule(
                "s_target must be in [0,1]".to_string(),
            ));
        }
        if !(0.0..=1.0).contains(&pause_duration) {
            return Err(AnnealingError::InvalidSchedule(format!(
                "pause_duration must be in [0,1], got {pause_duration}"
            )));
        }
        if !(0.0..=1.0).contains(&hold_duration) {
            return Err(AnnealingError::InvalidSchedule(format!(
                "hold_duration must be in [0,1], got {hold_duration}"
            )));
        }
        let occupied = pause_duration + hold_duration;
        if occupied > 1.0 {
            return Err(AnnealingError::InvalidSchedule(format!(
                "pause_duration + hold_duration must not exceed 1, got {occupied}"
            )));
        }
        if !(quench_rate.is_finite() && quench_rate >= 0.0) {
            return Err(AnnealingError::InvalidSchedule(format!(
                "quench_rate must be finite and non-negative, got {quench_rate}"
            )));
        }

        Ok(ReverseAnnealingSchedule {
            s_start: 1.0,
            s_target,
            pause_duration,
            quench_rate,
            hold_duration,
        })
    }
}

impl Default for ReverseAnnealingScheduleBuilder {
    /// Equivalent to [`ReverseAnnealingScheduleBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing schedule values.
    const EPS: f64 = 1e-6;

    /// True when `actual` is within `EPS` of `expected`.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_reverse_schedule_creation() {
        let schedule =
            ReverseAnnealingSchedule::new(0.45, 0.1).expect("Schedule creation should succeed");
        assert_eq!(schedule.s_start, 1.0);
        assert_eq!(schedule.s_target, 0.45);
    }

    #[test]
    fn test_schedule_s_of_t() {
        let schedule = ReverseAnnealingSchedule::default();
        // The schedule starts at s = 1.
        assert!(approx(schedule.s_of_t(0.0), 1.0));
        // Half-way through the default reverse ramp, s lies strictly between target and start.
        let s_mid = schedule.s_of_t(0.45 / 2.0);
        assert!(schedule.s_target < s_mid && s_mid < schedule.s_start);
        // ...and it ends back at s = 1.
        assert!(approx(schedule.s_of_t(1.0), 1.0));
    }

    #[test]
    fn test_reverse_annealing_params() {
        let spins: Vec<i8> = vec![1, -1, 1, -1];
        let params = ReverseAnnealingParams::new(spins.clone())
            .with_local_search(2)
            .with_reinitialization(0.25);
        assert_eq!(params.initial_state, spins);
        assert_eq!(params.local_search_radius, Some(2));
        assert_eq!(params.reinitialize_fraction, 0.25);
    }
}