use scirs2_core::ndarray::{Array, Ix2};
use scirs2_core::random::prelude::*;
use scirs2_core::random::rngs::StdRng;
use std::collections::HashMap;
use quantrs2_anneal::{
simulator::{AnnealingParams, ClassicalAnnealingSimulator, TemperatureSchedule},
QuboModel,
};
use super::{SampleResult, Sampler, SamplerResult};
#[cfg(feature = "parallel")]
use scirs2_core::parallel_ops::*;
/// Simulated-annealing sampler for QUBO matrices and HOBO tensors.
///
/// Rank-2 problems are delegated to `ClassicalAnnealingSimulator` using
/// `params`. Higher-rank problems go through the sampler's own single-flip
/// annealer (`run_hobo_tensor`), which is seeded from `seed`.
#[derive(Clone)]
pub struct SASampler {
    /// Parameters handed to `ClassicalAnnealingSimulator` for QUBO runs.
    params: AnnealingParams,
    /// Optional RNG seed. The constructors also copy it into `params.seed`.
    seed: Option<u64>,
}
impl SASampler {
#[must_use]
pub fn new(seed: Option<u64>) -> Self {
let mut params = AnnealingParams::default();
if let Some(seed) = seed {
params.seed = Some(seed);
}
Self { seed, params }
}
#[must_use]
pub const fn with_params(seed: Option<u64>, params: AnnealingParams) -> Self {
let mut params = params;
if let Some(seed) = seed {
params.seed = Some(seed);
}
Self { seed, params }
}
pub fn with_beta_range(mut self, beta_min: f64, beta_max: f64) -> Self {
self.params.initial_temperature = 1.0 / beta_max;
self.params.temperature_schedule = TemperatureSchedule::Exponential(beta_min / beta_max);
self
}
pub const fn with_sweeps(mut self, sweeps: usize) -> Self {
self.params.num_sweeps = sweeps;
self
}
fn run_generic<D>(
&self,
matrix_or_tensor: &Array<f64, D>,
var_map: &HashMap<String, usize>,
shots: usize,
) -> SamplerResult<Vec<SampleResult>>
where
D: scirs2_core::ndarray::Dimension + 'static,
{
let shots = std::cmp::max(shots, 1);
let n_vars = var_map.len();
let idx_to_var: HashMap<usize, String> = var_map
.iter()
.map(|(var, &idx)| (idx, var.clone()))
.collect();
if matrix_or_tensor.ndim() == 2 {
let mut qubo = QuboModel::new(n_vars);
for i in 0..n_vars {
let diag_val = match matrix_or_tensor.ndim() {
2 => {
let matrix = matrix_or_tensor
.to_owned()
.into_dimensionality::<Ix2>()
.ok();
matrix.map_or(0.0, |m| m[[i, i]])
}
_ => 0.0, };
if diag_val != 0.0 {
qubo.set_linear(i, diag_val)?;
}
for j in (i + 1)..n_vars {
let quad_val = match matrix_or_tensor.ndim() {
2 => {
let matrix = matrix_or_tensor
.to_owned()
.into_dimensionality::<Ix2>()
.ok();
matrix.map_or(0.0, |m| m[[i, j]])
}
_ => 0.0, };
if quad_val != 0.0 {
qubo.set_quadratic(i, j, quad_val)?;
}
}
}
let params = self.params.clone();
let simulator = ClassicalAnnealingSimulator::new(params)?;
let (ising_model, _) = qubo.to_ising();
let annealing_result = simulator.solve(&ising_model)?;
let mut results = Vec::new();
let binary_vars: Vec<bool> = annealing_result
.best_spins
.iter()
.map(|&spin| spin > 0)
.collect();
let assignments: HashMap<String, bool> = binary_vars
.iter()
.enumerate()
.filter_map(|(idx, &value)| {
idx_to_var
.get(&idx)
.map(|var_name| (var_name.clone(), value))
})
.collect();
let result = SampleResult {
assignments,
energy: annealing_result.best_energy,
occurrences: 1,
};
results.push(result);
return Ok(results);
}
self.run_hobo_tensor(matrix_or_tensor, var_map, shots)
}
fn run_hobo_tensor<D>(
&self,
tensor: &Array<f64, D>,
var_map: &HashMap<String, usize>,
shots: usize,
) -> SamplerResult<Vec<SampleResult>>
where
D: scirs2_core::ndarray::Dimension + 'static,
{
let n_vars = var_map.len();
let idx_to_var: HashMap<usize, String> = var_map
.iter()
.map(|(var, &idx)| (idx, var.clone()))
.collect();
let mut rng = if let Some(seed) = self.seed {
StdRng::seed_from_u64(seed)
} else {
let seed: u64 = thread_rng().random();
StdRng::seed_from_u64(seed)
};
let tensor_dyn: scirs2_core::ndarray::ArrayD<f64> = tensor.to_owned().into_dyn();
let mut solution_counts: HashMap<Vec<bool>, (f64, usize)> = HashMap::new();
#[cfg(feature = "parallel")]
let num_threads = scirs2_core::parallel_ops::current_num_threads();
#[cfg(not(feature = "parallel"))]
let num_threads = 1;
let shots_per_thread = shots / num_threads + usize::from(shots % num_threads > 0);
let total_runs = shots_per_thread * num_threads;
let initial_temp = 10.0;
let final_temp = 0.1;
let sweeps = 1000;
let evaluate_energy = |state: &[bool]| -> f64 {
super::energy::hobo_energy_full_dispatch(state, &tensor_dyn)
};
#[allow(unused_assignments)]
let mut all_solutions = Vec::with_capacity(total_runs);
#[cfg(feature = "parallel")]
{
let seeds: Vec<u64> = (0..total_runs)
.map(|i| match self.seed {
Some(seed) => seed.wrapping_add(i as u64),
None => thread_rng().random(),
})
.collect();
all_solutions = seeds
.into_par_iter()
.map(|seed| {
let mut thread_rng = StdRng::seed_from_u64(seed);
let mut state = vec![false; n_vars];
for bit in &mut state {
*bit = thread_rng.random_bool(0.5);
}
let mut energy = evaluate_energy(&state);
let mut best_state = state.clone();
let mut best_energy = energy;
for sweep in 0..sweeps {
let temp = initial_temp
* f64::powf(final_temp / initial_temp, sweep as f64 / sweeps as f64);
for _ in 0..n_vars {
let idx = thread_rng.random_range(0..n_vars);
state[idx] = !state[idx];
let new_energy = evaluate_energy(&state);
let delta_e = new_energy - energy;
let accept = delta_e <= 0.0
|| thread_rng.random_range(0.0..1.0) < (-delta_e / temp).exp();
if accept {
energy = new_energy;
if energy < best_energy {
best_energy = energy;
best_state = state.clone();
}
} else {
state[idx] = !state[idx];
}
}
}
(best_state, best_energy)
})
.collect();
}
#[cfg(not(feature = "parallel"))]
{
for _ in 0..total_runs {
let mut state = vec![false; n_vars];
for bit in &mut state {
*bit = rng.random_bool(0.5);
}
let mut energy = evaluate_energy(&state);
let mut best_state = state.clone();
let mut best_energy = energy;
for sweep in 0..sweeps {
let temp = initial_temp
* f64::powf(final_temp / initial_temp, sweep as f64 / sweeps as f64);
for _ in 0..n_vars {
let mut idx = rng.random_range(0..n_vars);
state[idx] = !state[idx];
let new_energy = evaluate_energy(&state);
let delta_e = new_energy - energy;
let accept = delta_e <= 0.0
|| rng.random_range(0.0..1.0) < f64::exp(-delta_e / temp);
if accept {
energy = new_energy;
if energy < best_energy {
best_energy = energy;
best_state = state.clone();
}
} else {
state[idx] = !state[idx];
}
}
}
all_solutions.push((best_state, best_energy));
}
}
for (state, energy) in all_solutions {
let entry = solution_counts.entry(state).or_insert((energy, 0));
entry.1 += 1;
}
let mut results: Vec<SampleResult> = solution_counts
.into_iter()
.map(|(state, (energy, count))| {
let assignments: HashMap<String, bool> = state
.iter()
.enumerate()
.filter_map(|(idx, &value)| {
idx_to_var
.get(&idx)
.map(|var_name| (var_name.clone(), value))
})
.collect();
SampleResult {
assignments,
energy,
occurrences: count,
}
})
.collect();
results.sort_by(|a, b| {
a.energy
.partial_cmp(&b.energy)
.unwrap_or(std::cmp::Ordering::Equal)
});
if results.len() > shots {
results.truncate(shots);
}
Ok(results)
}
}
impl Sampler for SASampler {
    /// Samples a QUBO passed as `(matrix, variable name -> index map)`.
    ///
    /// Delegates to `run_generic`, which takes the rank-2 code path.
    fn run_qubo(
        &self,
        qubo: &(
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        shots: usize,
    ) -> SamplerResult<Vec<SampleResult>> {
        let (matrix, var_map) = qubo;
        self.run_generic(matrix, var_map, shots)
    }

    /// Samples a HOBO passed as `(tensor, variable name -> index map)`.
    ///
    /// Delegates to `run_generic`. That function still dispatches on the
    /// runtime rank, so a rank-2 dynamic array takes the QUBO path.
    fn run_hobo(
        &self,
        hobo: &(
            Array<f64, scirs2_core::ndarray::IxDyn>,
            HashMap<String, usize>,
        ),
        shots: usize,
    ) -> SamplerResult<Vec<SampleResult>> {
        let (tensor, var_map) = hobo;
        self.run_generic(tensor, var_map, shots)
    }
}