use crate::sampler::{SampleResult, Sampler};
use quantrs2_anneal::{IsingModel, QuboModel};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::parallel_ops;
use scirs2_core::random::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fmt;
/// Neighbourhood move used for each step of the local-search loop in
/// `HybridOptimizer::refine_solution`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LocalSearchStrategy {
/// Evaluate every single-variable flip of the non-fixed variables and take
/// the one with the lowest energy.
SteepestDescent,
/// Visit the variables in shuffled order and take the first single flip
/// that lowers the energy.
FirstImprovement,
/// Flip one randomly chosen non-fixed variable and keep it only if the
/// energy decreases.
RandomDescent,
/// Best single flip whose resulting assignment is not recorded in the
/// optimizer's tabu list.
TabuSearch,
/// A steepest-descent step followed by the best simultaneous flip of two
/// variables.
VariableNeighborhoodDescent,
}
/// Selects how constraint repair should be performed before local search.
///
/// NOTE(review): in the code visible in this file, `repair_constraints` has an
/// empty match arm for every variant, so all four currently behave the same
/// (the solution is returned unchanged). The intended semantics of each
/// variant are not established here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RepairStrategy {
Greedy,
Random,
Weighted,
Iterative,
}
/// Rule used by `HybridOptimizer::fix_variables` to decide which variables to
/// pin to a value.
///
/// NOTE(review): only `HighFrequency` has logic in the visible code; the other
/// variants are accepted but fix nothing.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum FixingCriterion {
/// Fix a variable when the fraction of samples (among those containing the
/// variable) that assign it the same value is at least `threshold`.
HighFrequency { threshold: f64 },
/// Not implemented in the visible code — `threshold` is currently unused.
LowVariance { threshold: f64 },
/// Not implemented in the visible code — `threshold` is currently unused.
StrongCorrelation { threshold: f64 },
/// Not implemented in the visible code — `threshold` is currently unused.
ReducedCost { threshold: f64 },
}
/// Tunable parameters for [`HybridOptimizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridConfig {
/// Move type applied on every local-search iteration.
pub local_search: LocalSearchStrategy,
/// Upper bound on local-search iterations per `refine_solution` call.
pub max_local_iterations: usize,
/// Strategy handed to the constraint-repair step.
pub repair_strategy: RepairStrategy,
/// When `true`, `refine_solution` runs the repair step before local search.
pub enable_repair: bool,
/// Criterion used by `iterative_refinement` to fix variables; `None` skips
/// variable fixing.
pub fixing_criterion: Option<FixingCriterion>,
/// NOTE(review): not read by any code visible in this file.
pub fixing_percentage: f64,
/// Number of outer iterations performed by `iterative_refinement`.
pub max_qc_iterations: usize,
/// Minimum energy decrease required to accept a local-search step; also the
/// threshold used by the convergence check on the energy history.
pub convergence_tolerance: f64,
/// NOTE(review): not read by any code visible in this file.
pub enable_gradient: bool,
/// NOTE(review): not read by any code visible in this file.
pub learning_rate: f64,
/// NOTE(review): not read by any code visible in this file.
pub parallel: bool,
}
impl Default for HybridConfig {
/// Default configuration: steepest descent for at most 1000 local
/// iterations, repair enabled with the greedy strategy, high-frequency
/// variable fixing at a 0.8 threshold, 10 outer iterations and a
/// convergence tolerance of 1e-6.
fn default() -> Self {
Self {
local_search: LocalSearchStrategy::SteepestDescent,
max_local_iterations: 1000,
repair_strategy: RepairStrategy::Greedy,
enable_repair: true,
fixing_criterion: Some(FixingCriterion::HighFrequency { threshold: 0.8 }),
fixing_percentage: 0.3,
max_qc_iterations: 10,
convergence_tolerance: 1e-6,
enable_gradient: false,
learning_rate: 0.01,
parallel: true,
}
}
}
/// Result of one `HybridOptimizer::refine_solution` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefinedSolution {
/// Final variable assignment, keyed by variable name.
pub assignments: HashMap<String, bool>,
/// Energy of `assignments` under the QUBO matrix that was refined against.
pub energy: f64,
/// Constraint violations reported for `assignments`.
pub violations: Vec<ConstraintViolation>,
/// Number of local-search iterations executed, counting the final
/// iteration that failed to improve (if any).
pub iterations: usize,
/// Energy of the input solution minus `energy`.
pub improvement: f64,
/// `true` when `violations` is empty.
pub is_feasible: bool,
}
/// Description of a single violated constraint.
///
/// NOTE(review): no code visible in this file constructs this type, so the
/// exact meaning of the fields below is inferred from their names — confirm
/// against the producer once one exists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstraintViolation {
/// Identifier of the violated constraint.
pub constraint_id: String,
/// Size of the violation (units not established here).
pub magnitude: f64,
/// Presumably the names of the variables involved — TODO confirm.
pub variables: Vec<String>,
}
/// Record of a variable pinned by `HybridOptimizer::fix_variables`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FixedVariable {
/// Variable name.
pub name: String,
/// Value the variable was fixed to.
pub value: bool,
/// For the high-frequency criterion, the observed fraction of samples that
/// carried `value`.
pub confidence: f64,
/// Human-readable explanation of why the variable was fixed.
pub reason: String,
}
/// Classical post-processor that refines binary assignments against a QUBO
/// matrix using local search, optional repair and variable fixing.
pub struct HybridOptimizer {
/// Settings controlling every stage of the refinement.
config: HybridConfig,
/// Random source used for shuffling, random variable picks and sample
/// generation.
rng: Box<dyn RngCore>,
/// Hashes of assignments already returned by the tabu-search step.
tabu_list: HashSet<u64>,
/// Variables that local-search moves must not flip, with their fixed value.
fixed_variables: HashMap<String, bool>,
/// Energies recorded during the most recent `refine_solution` call.
history: Vec<f64>,
}
impl HybridOptimizer {
/// Creates an optimizer with the given configuration, a `thread_rng()`
/// random source, and empty tabu list, fixed-variable map and history.
pub fn new(config: HybridConfig) -> Self {
Self {
config,
rng: Box::new(thread_rng()),
tabu_list: HashSet::new(),
fixed_variables: HashMap::new(),
history: Vec::new(),
}
}
/// Refines `solution` against `qubo_matrix` with the configured local search.
///
/// The energy history is reset and seeded with the energy of `solution`.
/// If repair is enabled the repair step runs first and its energy is
/// recorded. Then up to `max_local_iterations` steps of the configured
/// strategy are applied; a step is accepted only if it lowers the energy by
/// more than `convergence_tolerance`, otherwise the loop stops.
///
/// The returned `iterations` counts every executed iteration, including a
/// final one that did not improve.
///
/// # Errors
///
/// Returns an error if `qubo_matrix` is not square (the energy evaluation
/// indexes it as `nrows x nrows`), or if the repair or search step fails.
pub fn refine_solution(
    &mut self,
    solution: &HashMap<String, bool>,
    qubo_matrix: &Array2<f64>,
) -> Result<RefinedSolution, String> {
    // Fail with a message instead of an out-of-bounds panic deep inside the
    // energy evaluation.
    if qubo_matrix.nrows() != qubo_matrix.ncols() {
        return Err(format!(
            "QUBO matrix must be square, got {}x{}",
            qubo_matrix.nrows(),
            qubo_matrix.ncols()
        ));
    }

    let initial_energy = self.compute_energy(solution, qubo_matrix);
    let mut current_solution = solution.clone();
    let mut current_energy = initial_energy;
    let mut iterations = 0;

    self.history.clear();
    self.history.push(current_energy);

    if self.config.enable_repair {
        current_solution = self.repair_constraints(&current_solution, qubo_matrix)?;
        current_energy = self.compute_energy(&current_solution, qubo_matrix);
        self.history.push(current_energy);
    }

    for iter in 0..self.config.max_local_iterations {
        iterations = iter + 1;

        let (improved_solution, improved_energy) = match self.config.local_search {
            LocalSearchStrategy::SteepestDescent => {
                self.steepest_descent_step(&current_solution, qubo_matrix)
            }
            LocalSearchStrategy::FirstImprovement => {
                self.first_improvement_step(&current_solution, qubo_matrix)
            }
            LocalSearchStrategy::RandomDescent => {
                self.random_descent_step(&current_solution, qubo_matrix)
            }
            LocalSearchStrategy::TabuSearch => {
                self.tabu_search_step(&current_solution, qubo_matrix)
            }
            LocalSearchStrategy::VariableNeighborhoodDescent => {
                self.vnd_step(&current_solution, qubo_matrix)
            }
        }?;

        // Only accept steps that improve by more than the tolerance; anything
        // less is treated as a local optimum.
        if improved_energy < current_energy - self.config.convergence_tolerance {
            current_solution = improved_solution;
            current_energy = improved_energy;
            self.history.push(current_energy);
        } else {
            break;
        }

        if self.has_converged() {
            break;
        }
    }

    let violations = self.compute_violations(&current_solution);
    let is_feasible = violations.is_empty();

    Ok(RefinedSolution {
        assignments: current_solution,
        energy: current_energy,
        violations,
        iterations,
        improvement: initial_energy - current_energy,
        is_feasible,
    })
}
fn steepest_descent_step(
&self,
solution: &HashMap<String, bool>,
qubo_matrix: &Array2<f64>,
) -> Result<(HashMap<String, bool>, f64), String> {
let current_energy = self.compute_energy(solution, qubo_matrix);
let mut best_solution = solution.clone();
let mut best_energy = current_energy;
let mut improved = false;
for (var_name, ¤t_value) in solution {
if self.fixed_variables.contains_key(var_name) {
continue;
}
let mut neighbor = solution.clone();
neighbor.insert(var_name.clone(), !current_value);
let neighbor_energy = self.compute_energy(&neighbor, qubo_matrix);
if neighbor_energy < best_energy {
best_solution = neighbor;
best_energy = neighbor_energy;
improved = true;
}
}
if improved {
Ok((best_solution, best_energy))
} else {
Ok((solution.clone(), current_energy))
}
}
/// Performs one first-improvement move.
///
/// All variable names are shuffled with the optimizer's random source; the
/// non-fixed ones are then flipped one at a time in that order, and the first
/// flip whose energy is strictly below the current energy is returned. If no
/// flip improves, a copy of `solution` and its energy are returned.
fn first_improvement_step(
    &mut self,
    solution: &HashMap<String, bool>,
    qubo_matrix: &Array2<f64>,
) -> Result<(HashMap<String, bool>, f64), String> {
    let baseline = self.compute_energy(solution, qubo_matrix);

    // The shuffle covers every name (fixed ones included) before filtering.
    let mut visit_order: Vec<_> = solution.keys().cloned().collect();
    visit_order.shuffle(&mut *self.rng);

    let first_better = visit_order
        .into_iter()
        .filter(|name| !self.fixed_variables.contains_key(name))
        .find_map(|name| {
            let flipped_value = !solution[&name];
            let mut candidate = solution.clone();
            candidate.insert(name, flipped_value);
            let candidate_energy = self.compute_energy(&candidate, qubo_matrix);
            if candidate_energy < baseline {
                Some((candidate, candidate_energy))
            } else {
                None
            }
        });

    Ok(first_better.unwrap_or_else(|| (solution.clone(), baseline)))
}
/// Performs one random-descent move.
///
/// Picks a single non-fixed variable uniformly at random, flips it, and
/// returns the flipped assignment only if its energy is strictly lower than
/// the current one. Otherwise (or when every variable is fixed) a copy of
/// `solution` and its energy are returned.
fn random_descent_step(
    &mut self,
    solution: &HashMap<String, bool>,
    qubo_matrix: &Array2<f64>,
) -> Result<(HashMap<String, bool>, f64), String> {
    let baseline = self.compute_energy(solution, qubo_matrix);

    let mut free_vars: Vec<&String> = Vec::new();
    for name in solution.keys() {
        if !self.fixed_variables.contains_key(name) {
            free_vars.push(name);
        }
    }
    if free_vars.is_empty() {
        return Ok((solution.clone(), baseline));
    }

    let pick = self.rng.random_range(0..free_vars.len());
    let chosen = free_vars[pick];

    let mut flipped = solution.clone();
    flipped.insert(chosen.clone(), !solution[chosen]);
    let flipped_energy = self.compute_energy(&flipped, qubo_matrix);

    let outcome = if flipped_energy < baseline {
        (flipped, flipped_energy)
    } else {
        (solution.clone(), baseline)
    };
    Ok(outcome)
}
fn tabu_search_step(
&mut self,
solution: &HashMap<String, bool>,
qubo_matrix: &Array2<f64>,
) -> Result<(HashMap<String, bool>, f64), String> {
let current_energy = self.compute_energy(solution, qubo_matrix);
let mut best_solution = solution.clone();
let mut best_energy = current_energy;
for (var_name, ¤t_value) in solution {
if self.fixed_variables.contains_key(var_name) {
continue;
}
let mut neighbor = solution.clone();
neighbor.insert(var_name.clone(), !current_value);
let move_hash = self.hash_solution(&neighbor);
if self.tabu_list.contains(&move_hash) {
continue;
}
let neighbor_energy = self.compute_energy(&neighbor, qubo_matrix);
if neighbor_energy < best_energy {
best_solution = neighbor;
best_energy = neighbor_energy;
}
}
let move_hash = self.hash_solution(&best_solution);
self.tabu_list.insert(move_hash);
if self.tabu_list.len() > 100 {
self.tabu_list.clear();
}
Ok((best_solution, best_energy))
}
fn vnd_step(
&mut self,
solution: &HashMap<String, bool>,
qubo_matrix: &Array2<f64>,
) -> Result<(HashMap<String, bool>, f64), String> {
let mut current_solution = solution.clone();
let mut current_energy = self.compute_energy(solution, qubo_matrix);
let (sol1, e1) = self.steepest_descent_step(¤t_solution, qubo_matrix)?;
if e1 < current_energy {
current_solution = sol1;
current_energy = e1;
}
let (sol2, e2) = self.two_variable_swap(¤t_solution, qubo_matrix)?;
if e2 < current_energy {
current_solution = sol2;
current_energy = e2;
}
Ok((current_solution, current_energy))
}
/// Searches the two-variable-flip neighbourhood of `solution`.
///
/// Every unordered pair of non-fixed variables is flipped simultaneously and
/// evaluated; the pair with the strictly lowest energy is returned. If no
/// pair improves, a copy of `solution` and its energy are returned.
fn two_variable_swap(
    &self,
    solution: &HashMap<String, bool>,
    qubo_matrix: &Array2<f64>,
) -> Result<(HashMap<String, bool>, f64), String> {
    let baseline = self.compute_energy(solution, qubo_matrix);
    let mut winner = (solution.clone(), baseline);

    let free_vars: Vec<&String> = solution
        .keys()
        .filter(|name| !self.fixed_variables.contains_key(*name))
        .collect();

    for (idx, first) in free_vars.iter().enumerate() {
        // Pair `first` with every later variable so each pair is seen once.
        for second in &free_vars[idx + 1..] {
            let mut candidate = solution.clone();
            candidate.insert(String::clone(first), !solution[*first]);
            candidate.insert(String::clone(second), !solution[*second]);

            let candidate_energy = self.compute_energy(&candidate, qubo_matrix);
            if candidate_energy < winner.1 {
                winner = (candidate, candidate_energy);
            }
        }
    }

    Ok(winner)
}
/// Constraint-repair hook run before local search.
///
/// Currently a placeholder: no strategy performs any work, so the result is
/// always an unmodified copy of `solution`, and `_qubo_matrix` is ignored.
fn repair_constraints(
    &self,
    solution: &HashMap<String, bool>,
    _qubo_matrix: &Array2<f64>,
) -> Result<HashMap<String, bool>, String> {
    // Kept exhaustive (no `_` arm) so adding a strategy forces a decision here.
    match self.config.repair_strategy {
        RepairStrategy::Greedy
        | RepairStrategy::Random
        | RepairStrategy::Weighted
        | RepairStrategy::Iterative => {}
    }
    Ok(solution.clone())
}
/// Fixes variables according to `criterion`, based on `samples`.
///
/// For [`FixingCriterion::HighFrequency`], a variable is fixed to `true`
/// (or, failing that, `false`) when the fraction of samples containing the
/// variable that assign it that value is at least `threshold`. Fixed
/// variables are recorded in the optimizer (overwriting any earlier value
/// for the same name) and excluded from later local-search moves.
///
/// The other criteria are accepted but not yet implemented; they fix nothing.
///
/// Returns the newly fixed variables sorted by name, so the output does not
/// depend on hash-map iteration order. An empty `samples` slice yields an
/// empty result.
///
/// # Errors
///
/// The current implementation never returns `Err`.
pub fn fix_variables(
    &mut self,
    samples: &[HashMap<String, bool>],
    criterion: FixingCriterion,
) -> Result<Vec<FixedVariable>, String> {
    if samples.is_empty() {
        return Ok(Vec::new());
    }

    let mut fixed = Vec::new();

    match criterion {
        FixingCriterion::HighFrequency { threshold } => {
            // Per variable: (times seen true, times seen false). Keys borrow
            // from `samples`, so counting allocates no strings.
            let mut counts: HashMap<&str, (usize, usize)> = HashMap::new();
            for sample in samples {
                for (var, &value) in sample {
                    let entry = counts.entry(var.as_str()).or_insert((0, 0));
                    if value {
                        entry.0 += 1;
                    } else {
                        entry.1 += 1;
                    }
                }
            }

            for (var, (true_count, false_count)) in counts {
                // `total` is at least 1: a key exists only after being counted.
                let total = (true_count + false_count) as f64;
                let true_freq = true_count as f64 / total;
                let false_freq = false_count as f64 / total;

                // `true` takes precedence when both frequencies qualify
                // (possible for thresholds of 0.5 or below).
                let (value, confidence) = if true_freq >= threshold {
                    (true, true_freq)
                } else if false_freq >= threshold {
                    (false, false_freq)
                } else {
                    continue;
                };

                self.fixed_variables.insert(var.to_owned(), value);
                fixed.push(FixedVariable {
                    name: var.to_owned(),
                    value,
                    confidence,
                    reason: format!("High frequency ({confidence})"),
                });
            }

            // Names are unique, so an unstable sort is sufficient.
            fixed.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        }
        // Not implemented yet: these criteria deliberately fix nothing.
        FixingCriterion::LowVariance { .. }
        | FixingCriterion::StrongCorrelation { .. }
        | FixingCriterion::ReducedCost { .. } => {}
    }

    Ok(fixed)
}
/// Removes every fixed variable so all variables may be flipped again.
pub fn unfix_all(&mut self) {
self.fixed_variables.clear();
}
/// Runs the outer sample / fix / refine loop for up to `max_qc_iterations`
/// rounds and returns every refined solution produced.
///
/// Each round generates `num_samples` uniformly random assignments over
/// variables `x0..x{n-1}` (with `n` the row count of `qubo_matrix`), applies
/// the configured fixing criterion to them if one is set, and refines each
/// sample. From the second round on, the loop stops early once the energy
/// history of the last refinement reports convergence. Progress is printed
/// to stdout.
///
/// NOTE(review): `sampler` is not consulted anywhere in this body; samples
/// come from the optimizer's own random source.
///
/// # Errors
///
/// Propagates any error from variable fixing or refinement.
pub fn iterative_refinement<S: Sampler>(
    &mut self,
    sampler: &S,
    qubo_matrix: &Array2<f64>,
    num_samples: usize,
) -> Result<Vec<RefinedSolution>, String> {
    let total_rounds = self.config.max_qc_iterations;
    let dimension = qubo_matrix.nrows();

    let mut refined_solutions = Vec::new();
    let mut best_energy = f64::INFINITY;

    for round in 1..=total_rounds {
        println!("Quantum-Classical iteration {}/{}", round, total_rounds);

        // Draw a fresh batch of random assignments.
        let mut batch = Vec::new();
        for _ in 0..num_samples {
            let assignment: HashMap<String, bool> = (0..dimension)
                .map(|i| (format!("x{i}"), self.rng.random::<bool>()))
                .collect();
            batch.push(assignment);
        }

        if let Some(criterion) = self.config.fixing_criterion {
            let fixed = self.fix_variables(&batch, criterion)?;
            println!("Fixed {} variables", fixed.len());
        }

        for sample in batch {
            let refined = self.refine_solution(&sample, qubo_matrix)?;
            if refined.energy < best_energy {
                best_energy = refined.energy;
                println!("New best energy: {best_energy}");
            }
            refined_solutions.push(refined);
        }

        if round > 1 && self.has_converged() {
            println!("Converged after {} iterations", round);
            break;
        }
    }

    Ok(refined_solutions)
}
/// Evaluates the QUBO energy of `solution`: the sum of
/// `Q[i][j] * x_i * x_j` over all ordered pairs `(i, j)` with
/// `i, j < qubo_matrix.nrows()`.
///
/// Variable `i` is looked up under the name `"x{i}"`; names missing from
/// `solution` are treated as `false`.
///
/// # Panics
///
/// Panics on indexing if the matrix has fewer columns than rows.
fn compute_energy(&self, solution: &HashMap<String, bool>, qubo_matrix: &Array2<f64>) -> f64 {
    let n = qubo_matrix.nrows();

    // Resolve each variable once; formatting the key and hashing it inside
    // the double loop would cost O(n^2) string allocations.
    let values: Vec<f64> = (0..n)
        .map(|i| {
            if solution.get(&format!("x{i}")).copied().unwrap_or(false) {
                1.0
            } else {
                0.0
            }
        })
        .collect();

    let mut energy = 0.0;
    for (i, &x_i) in values.iter().enumerate() {
        for (j, &x_j) in values.iter().enumerate() {
            energy += qubo_matrix[[i, j]] * x_i * x_j;
        }
    }
    energy
}
/// Reports constraint violations for a solution.
///
/// Currently a placeholder: `_solution` is ignored and the result is always
/// an empty list.
const fn compute_violations(
&self,
_solution: &HashMap<String, bool>,
) -> Vec<ConstraintViolation> {
Vec::new()
}
/// Returns `true` when the last three recorded energies differ from one
/// another (consecutively) by less than `convergence_tolerance`.
///
/// With fewer than three history entries the answer is always `false`.
fn has_converged(&self) -> bool {
    match self.history.as_slice() {
        [.., older, middle, newest] => {
            // Seeding with 0.0 keeps the same NaN handling as folding the
            // differences with `f64::max` from a zero start.
            let largest_step = 0.0_f64
                .max((older - middle).abs())
                .max((middle - newest).abs());
            largest_step < self.config.convergence_tolerance
        }
        _ => false,
    }
}
/// Hashes an assignment independently of the map's iteration order.
///
/// Entries are sorted by variable name, then each name and value is fed to a
/// `DefaultHasher` in that order.
fn hash_solution(&self, solution: &HashMap<String, bool>) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut entries: Vec<(&String, &bool)> = solution.iter().collect();
    // Map keys are unique, so an unstable sort gives the same order.
    entries.sort_unstable_by(|left, right| left.0.cmp(right.0));

    let mut state = DefaultHasher::new();
    for (name, flag) in entries {
        name.hash(&mut state);
        flag.hash(&mut state);
    }
    state.finish()
}
/// Returns the energies recorded during the most recent refinement, oldest
/// first.
pub fn get_history(&self) -> &[f64] {
&self.history
}
/// Returns the currently fixed variables and the value each is fixed to.
pub const fn get_fixed_variables(&self) -> &HashMap<String, bool> {
&self.fixed_variables
}
}
#[cfg(test)]
// Allowed because `test_convergence_detection` builds a default config and
// then overrides a single field by assignment.
#[allow(clippy::field_reassign_with_default)]
mod tests {
use super::*;
// A freshly built optimizer starts with no fixed variables and no history.
#[test]
fn test_hybrid_optimizer_creation() {
let config = HybridConfig::default();
let optimizer = HybridOptimizer::new(config);
assert_eq!(optimizer.fixed_variables.len(), 0);
assert_eq!(optimizer.history.len(), 0);
}
// With only x0 set, the energy reduces to the single diagonal term Q[0][0].
#[test]
fn test_energy_computation() {
let config = HybridConfig::default();
let optimizer = HybridOptimizer::new(config);
let qubo = Array2::from_shape_fn((2, 2), |(i, j)| if i == j { -1.0 } else { 2.0 });
let solution = HashMap::from([("x0".to_string(), true), ("x1".to_string(), false)]);
let energy = optimizer.compute_energy(&solution, &qubo);
assert_eq!(energy, -1.0); }
// Refinement must never return a solution worse than the one it was given.
#[test]
fn test_local_search_refinement() {
let config = HybridConfig {
max_local_iterations: 10,
..Default::default()
};
let mut optimizer = HybridOptimizer::new(config);
let qubo = Array2::from_shape_fn((3, 3), |(i, j)| if i == j { -1.0 } else { 0.5 });
let initial_solution = HashMap::from([
("x0".to_string(), false),
("x1".to_string(), false),
("x2".to_string(), false),
]);
let refined = optimizer
.refine_solution(&initial_solution, &qubo)
.expect("refinement should succeed");
assert!(refined.improvement >= 0.0);
assert!(refined.energy <= optimizer.compute_energy(&initial_solution, &qubo));
}
// x0 is true in 3 of 3 samples (>= 0.8), so it must be fixed to true;
// x1 is false in only 2 of 3 samples, which is below the threshold.
#[test]
fn test_variable_fixing() {
let config = HybridConfig::default();
let mut optimizer = HybridOptimizer::new(config);
let samples = vec![
HashMap::from([("x0".to_string(), true), ("x1".to_string(), false)]),
HashMap::from([("x0".to_string(), true), ("x1".to_string(), true)]),
HashMap::from([("x0".to_string(), true), ("x1".to_string(), false)]),
];
let criterion = FixingCriterion::HighFrequency { threshold: 0.8 };
let fixed = optimizer
.fix_variables(&samples, criterion)
.expect("variable fixing should succeed");
assert!(!fixed.is_empty());
assert!(fixed.iter().any(|f| f.name == "x0" && f.value));
}
// Steps of 1e-5 are below the 0.001 tolerance (converged); steps of 1.0
// are not.
#[test]
fn test_convergence_detection() {
let mut config = HybridConfig::default();
config.convergence_tolerance = 0.001; let mut optimizer = HybridOptimizer::new(config);
optimizer.history = vec![10.0, 10.00001, 10.00002];
assert!(optimizer.has_converged());
optimizer.history = vec![10.0, 9.0, 8.0];
assert!(!optimizer.has_converged());
}
}