1use crate::sampler::{SampleResult, Sampler};
42use quantrs2_anneal::{IsingModel, QuboModel};
43use scirs2_core::ndarray::{Array1, Array2};
44use scirs2_core::parallel_ops;
45use scirs2_core::random::prelude::*;
46use serde::{Deserialize, Serialize};
47use std::collections::{HashMap, HashSet};
48use std::fmt;
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
52pub enum LocalSearchStrategy {
53 SteepestDescent,
55 FirstImprovement,
57 RandomDescent,
59 TabuSearch,
61 VariableNeighborhoodDescent,
63}
64
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
67pub enum RepairStrategy {
68 Greedy,
70 Random,
72 Weighted,
74 Iterative,
76}
77
78#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
80pub enum FixingCriterion {
81 HighFrequency { threshold: f64 },
83 LowVariance { threshold: f64 },
85 StrongCorrelation { threshold: f64 },
87 ReducedCost { threshold: f64 },
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct HybridConfig {
94 pub local_search: LocalSearchStrategy,
96 pub max_local_iterations: usize,
98 pub repair_strategy: RepairStrategy,
100 pub enable_repair: bool,
102 pub fixing_criterion: Option<FixingCriterion>,
104 pub fixing_percentage: f64,
106 pub max_qc_iterations: usize,
108 pub convergence_tolerance: f64,
110 pub enable_gradient: bool,
112 pub learning_rate: f64,
114 pub parallel: bool,
116}
117
118impl Default for HybridConfig {
119 fn default() -> Self {
120 Self {
121 local_search: LocalSearchStrategy::SteepestDescent,
122 max_local_iterations: 1000,
123 repair_strategy: RepairStrategy::Greedy,
124 enable_repair: true,
125 fixing_criterion: Some(FixingCriterion::HighFrequency { threshold: 0.8 }),
126 fixing_percentage: 0.3,
127 max_qc_iterations: 10,
128 convergence_tolerance: 1e-6,
129 enable_gradient: false,
130 learning_rate: 0.01,
131 parallel: true,
132 }
133 }
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct RefinedSolution {
139 pub assignments: HashMap<String, bool>,
141 pub energy: f64,
143 pub violations: Vec<ConstraintViolation>,
145 pub iterations: usize,
147 pub improvement: f64,
149 pub is_feasible: bool,
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct ConstraintViolation {
156 pub constraint_id: String,
158 pub magnitude: f64,
160 pub variables: Vec<String>,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct FixedVariable {
167 pub name: String,
169 pub value: bool,
171 pub confidence: f64,
173 pub reason: String,
175}
176
177pub struct HybridOptimizer {
179 config: HybridConfig,
181 rng: Box<dyn RngCore>,
183 tabu_list: HashSet<u64>,
185 fixed_variables: HashMap<String, bool>,
187 history: Vec<f64>,
189}
190
191impl HybridOptimizer {
192 pub fn new(config: HybridConfig) -> Self {
194 Self {
195 config,
196 rng: Box::new(thread_rng()),
197 tabu_list: HashSet::new(),
198 fixed_variables: HashMap::new(),
199 history: Vec::new(),
200 }
201 }
202
203 pub fn refine_solution(
205 &mut self,
206 solution: &HashMap<String, bool>,
207 qubo_matrix: &Array2<f64>,
208 ) -> Result<RefinedSolution, String> {
209 let initial_energy = self.compute_energy(solution, qubo_matrix);
210 let mut current_solution = solution.clone();
211 let mut current_energy = initial_energy;
212 let mut iterations = 0;
213
214 self.history.clear();
215 self.history.push(current_energy);
216
217 if self.config.enable_repair {
219 current_solution = self.repair_constraints(¤t_solution, qubo_matrix)?;
220 current_energy = self.compute_energy(¤t_solution, qubo_matrix);
221 self.history.push(current_energy);
222 }
223
224 for iter in 0..self.config.max_local_iterations {
226 iterations = iter + 1;
227
228 let (improved_solution, improved_energy) = match self.config.local_search {
229 LocalSearchStrategy::SteepestDescent => {
230 self.steepest_descent_step(¤t_solution, qubo_matrix)
231 }
232 LocalSearchStrategy::FirstImprovement => {
233 self.first_improvement_step(¤t_solution, qubo_matrix)
234 }
235 LocalSearchStrategy::RandomDescent => {
236 self.random_descent_step(¤t_solution, qubo_matrix)
237 }
238 LocalSearchStrategy::TabuSearch => {
239 self.tabu_search_step(¤t_solution, qubo_matrix)
240 }
241 LocalSearchStrategy::VariableNeighborhoodDescent => {
242 self.vnd_step(¤t_solution, qubo_matrix)
243 }
244 }?;
245
246 if improved_energy < current_energy - self.config.convergence_tolerance {
248 current_solution = improved_solution;
249 current_energy = improved_energy;
250 self.history.push(current_energy);
251 } else {
252 break;
254 }
255
256 if self.has_converged() {
258 break;
259 }
260 }
261
262 let violations = self.compute_violations(¤t_solution);
264 let is_feasible = violations.is_empty();
265
266 Ok(RefinedSolution {
267 assignments: current_solution,
268 energy: current_energy,
269 violations,
270 iterations,
271 improvement: initial_energy - current_energy,
272 is_feasible,
273 })
274 }
275
276 fn steepest_descent_step(
278 &self,
279 solution: &HashMap<String, bool>,
280 qubo_matrix: &Array2<f64>,
281 ) -> Result<(HashMap<String, bool>, f64), String> {
282 let current_energy = self.compute_energy(solution, qubo_matrix);
283 let mut best_solution = solution.clone();
284 let mut best_energy = current_energy;
285 let mut improved = false;
286
287 for (var_name, ¤t_value) in solution {
289 if self.fixed_variables.contains_key(var_name) {
291 continue;
292 }
293
294 let mut neighbor = solution.clone();
295 neighbor.insert(var_name.clone(), !current_value);
296
297 let neighbor_energy = self.compute_energy(&neighbor, qubo_matrix);
298
299 if neighbor_energy < best_energy {
300 best_solution = neighbor;
301 best_energy = neighbor_energy;
302 improved = true;
303 }
304 }
305
306 if improved {
307 Ok((best_solution, best_energy))
308 } else {
309 Ok((solution.clone(), current_energy))
310 }
311 }
312
313 fn first_improvement_step(
315 &mut self,
316 solution: &HashMap<String, bool>,
317 qubo_matrix: &Array2<f64>,
318 ) -> Result<(HashMap<String, bool>, f64), String> {
319 let current_energy = self.compute_energy(solution, qubo_matrix);
320
321 let mut var_names: Vec<_> = solution.keys().cloned().collect();
323 var_names.shuffle(&mut *self.rng);
324
325 for var_name in var_names {
326 if self.fixed_variables.contains_key(&var_name) {
328 continue;
329 }
330
331 let current_value = solution[&var_name];
332 let mut neighbor = solution.clone();
333 neighbor.insert(var_name, !current_value);
334
335 let neighbor_energy = self.compute_energy(&neighbor, qubo_matrix);
336
337 if neighbor_energy < current_energy {
338 return Ok((neighbor, neighbor_energy));
339 }
340 }
341
342 Ok((solution.clone(), current_energy))
343 }
344
345 fn random_descent_step(
347 &mut self,
348 solution: &HashMap<String, bool>,
349 qubo_matrix: &Array2<f64>,
350 ) -> Result<(HashMap<String, bool>, f64), String> {
351 let current_energy = self.compute_energy(solution, qubo_matrix);
352
353 let var_names: Vec<_> = solution
355 .keys()
356 .filter(|k| !self.fixed_variables.contains_key(*k))
357 .cloned()
358 .collect();
359
360 if var_names.is_empty() {
361 return Ok((solution.clone(), current_energy));
362 }
363
364 let var_name = &var_names[self.rng.gen_range(0..var_names.len())];
365 let current_value = solution[var_name];
366
367 let mut neighbor = solution.clone();
368 neighbor.insert(var_name.clone(), !current_value);
369
370 let neighbor_energy = self.compute_energy(&neighbor, qubo_matrix);
371
372 if neighbor_energy < current_energy {
373 Ok((neighbor, neighbor_energy))
374 } else {
375 Ok((solution.clone(), current_energy))
376 }
377 }
378
379 fn tabu_search_step(
381 &mut self,
382 solution: &HashMap<String, bool>,
383 qubo_matrix: &Array2<f64>,
384 ) -> Result<(HashMap<String, bool>, f64), String> {
385 let current_energy = self.compute_energy(solution, qubo_matrix);
386 let mut best_solution = solution.clone();
387 let mut best_energy = current_energy;
388
389 for (var_name, ¤t_value) in solution {
391 if self.fixed_variables.contains_key(var_name) {
392 continue;
393 }
394
395 let mut neighbor = solution.clone();
396 neighbor.insert(var_name.clone(), !current_value);
397
398 let move_hash = self.hash_solution(&neighbor);
400 if self.tabu_list.contains(&move_hash) {
401 continue;
402 }
403
404 let neighbor_energy = self.compute_energy(&neighbor, qubo_matrix);
405
406 if neighbor_energy < best_energy {
407 best_solution = neighbor;
408 best_energy = neighbor_energy;
409 }
410 }
411
412 let move_hash = self.hash_solution(&best_solution);
414 self.tabu_list.insert(move_hash);
415
416 if self.tabu_list.len() > 100 {
418 self.tabu_list.clear();
419 }
420
421 Ok((best_solution, best_energy))
422 }
423
424 fn vnd_step(
426 &mut self,
427 solution: &HashMap<String, bool>,
428 qubo_matrix: &Array2<f64>,
429 ) -> Result<(HashMap<String, bool>, f64), String> {
430 let mut current_solution = solution.clone();
431 let mut current_energy = self.compute_energy(solution, qubo_matrix);
432
433 let (sol1, e1) = self.steepest_descent_step(¤t_solution, qubo_matrix)?;
435 if e1 < current_energy {
436 current_solution = sol1;
437 current_energy = e1;
438 }
439
440 let (sol2, e2) = self.two_variable_swap(¤t_solution, qubo_matrix)?;
442 if e2 < current_energy {
443 current_solution = sol2;
444 current_energy = e2;
445 }
446
447 Ok((current_solution, current_energy))
448 }
449
450 fn two_variable_swap(
452 &self,
453 solution: &HashMap<String, bool>,
454 qubo_matrix: &Array2<f64>,
455 ) -> Result<(HashMap<String, bool>, f64), String> {
456 let current_energy = self.compute_energy(solution, qubo_matrix);
457 let mut best_solution = solution.clone();
458 let mut best_energy = current_energy;
459
460 let var_names: Vec<_> = solution
461 .keys()
462 .filter(|k| !self.fixed_variables.contains_key(*k))
463 .cloned()
464 .collect();
465
466 for i in 0..var_names.len() {
467 for j in (i + 1)..var_names.len() {
468 let mut neighbor = solution.clone();
469 let val_i = solution[&var_names[i]];
470 let val_j = solution[&var_names[j]];
471
472 neighbor.insert(var_names[i].clone(), !val_i);
473 neighbor.insert(var_names[j].clone(), !val_j);
474
475 let neighbor_energy = self.compute_energy(&neighbor, qubo_matrix);
476
477 if neighbor_energy < best_energy {
478 best_solution = neighbor;
479 best_energy = neighbor_energy;
480 }
481 }
482 }
483
484 Ok((best_solution, best_energy))
485 }
486
487 fn repair_constraints(
489 &self,
490 solution: &HashMap<String, bool>,
491 _qubo_matrix: &Array2<f64>,
492 ) -> Result<HashMap<String, bool>, String> {
493 let mut repaired = solution.clone();
496
497 match self.config.repair_strategy {
498 RepairStrategy::Greedy => {
499 }
502 RepairStrategy::Random => {
503 }
506 RepairStrategy::Weighted => {
507 }
510 RepairStrategy::Iterative => {
511 }
514 }
515
516 Ok(repaired)
517 }
518
519 pub fn fix_variables(
521 &mut self,
522 samples: &[HashMap<String, bool>],
523 criterion: FixingCriterion,
524 ) -> Result<Vec<FixedVariable>, String> {
525 if samples.is_empty() {
526 return Ok(Vec::new());
527 }
528
529 let mut fixed = Vec::new();
530
531 match criterion {
532 FixingCriterion::HighFrequency { threshold } => {
533 let mut frequencies: HashMap<String, (usize, usize)> = HashMap::new();
535
536 for sample in samples {
537 for (var, &value) in sample {
538 let entry = frequencies.entry(var.clone()).or_insert((0, 0));
539 if value {
540 entry.0 += 1;
541 } else {
542 entry.1 += 1;
543 }
544 }
545 }
546
547 for (var, (true_count, false_count)) in frequencies {
549 let total = (true_count + false_count) as f64;
550 let true_freq = true_count as f64 / total;
551 let false_freq = false_count as f64 / total;
552
553 if true_freq >= threshold {
554 self.fixed_variables.insert(var.clone(), true);
555 fixed.push(FixedVariable {
556 name: var,
557 value: true,
558 confidence: true_freq,
559 reason: format!("High frequency ({true_freq})"),
560 });
561 } else if false_freq >= threshold {
562 self.fixed_variables.insert(var.clone(), false);
563 fixed.push(FixedVariable {
564 name: var,
565 value: false,
566 confidence: false_freq,
567 reason: format!("High frequency ({false_freq})"),
568 });
569 }
570 }
571 }
572 FixingCriterion::LowVariance { threshold } => {
573 }
576 FixingCriterion::StrongCorrelation { threshold } => {
577 }
580 FixingCriterion::ReducedCost { threshold } => {
581 }
584 }
585
586 Ok(fixed)
587 }
588
589 pub fn unfix_all(&mut self) {
591 self.fixed_variables.clear();
592 }
593
594 pub fn iterative_refinement<S: Sampler>(
596 &mut self,
597 sampler: &S,
598 qubo_matrix: &Array2<f64>,
599 num_samples: usize,
600 ) -> Result<Vec<RefinedSolution>, String> {
601 let mut refined_solutions = Vec::new();
602 let mut best_energy = f64::INFINITY;
603
604 for iteration in 0..self.config.max_qc_iterations {
605 println!(
606 "Quantum-Classical iteration {}/{}",
607 iteration + 1,
608 self.config.max_qc_iterations
609 );
610
611 let mut samples = Vec::new();
615 for _ in 0..num_samples {
616 let mut sample = HashMap::new();
617 for i in 0..qubo_matrix.nrows() {
618 sample.insert(format!("x{i}"), self.rng.gen::<bool>());
619 }
620 samples.push(sample);
621 }
622
623 if let Some(criterion) = self.config.fixing_criterion {
625 let fixed = self.fix_variables(&samples, criterion)?;
626 println!("Fixed {} variables", fixed.len());
627 }
628
629 for sample in samples {
631 let refined = self.refine_solution(&sample, qubo_matrix)?;
632
633 if refined.energy < best_energy {
634 best_energy = refined.energy;
635 println!("New best energy: {best_energy}");
636 }
637
638 refined_solutions.push(refined);
639 }
640
641 if iteration > 0 && self.has_converged() {
643 println!("Converged after {} iterations", iteration + 1);
644 break;
645 }
646 }
647
648 Ok(refined_solutions)
649 }
650
651 fn compute_energy(&self, solution: &HashMap<String, bool>, qubo_matrix: &Array2<f64>) -> f64 {
653 let n = qubo_matrix.nrows();
654 let mut energy = 0.0;
655
656 for i in 0..n {
657 for j in 0..n {
658 let x_i = if solution.get(&format!("x{i}")).copied().unwrap_or(false) {
659 1.0
660 } else {
661 0.0
662 };
663 let x_j = if solution.get(&format!("x{j}")).copied().unwrap_or(false) {
664 1.0
665 } else {
666 0.0
667 };
668 energy += qubo_matrix[[i, j]] * x_i * x_j;
669 }
670 }
671
672 energy
673 }
674
675 const fn compute_violations(
677 &self,
678 _solution: &HashMap<String, bool>,
679 ) -> Vec<ConstraintViolation> {
680 Vec::new()
682 }
683
684 fn has_converged(&self) -> bool {
686 if self.history.len() < 3 {
687 return false;
688 }
689
690 let recent = &self.history[self.history.len() - 3..];
691 let max_change = recent
692 .windows(2)
693 .map(|w| (w[0] - w[1]).abs())
694 .fold(0.0, f64::max);
695
696 max_change < self.config.convergence_tolerance
697 }
698
699 fn hash_solution(&self, solution: &HashMap<String, bool>) -> u64 {
701 use std::collections::hash_map::DefaultHasher;
702 use std::hash::{Hash, Hasher};
703
704 let mut hasher = DefaultHasher::new();
705 let mut sorted: Vec<_> = solution.iter().collect();
706 sorted.sort_by_key(|(k, _)| k.as_str());
707
708 for (k, v) in sorted {
709 k.hash(&mut hasher);
710 v.hash(&mut hasher);
711 }
712
713 hasher.finish()
714 }
715
716 pub fn get_history(&self) -> &[f64] {
718 &self.history
719 }
720
721 pub const fn get_fixed_variables(&self) -> &HashMap<String, bool> {
723 &self.fixed_variables
724 }
725}
726
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-variable QUBO whose one-hot and two-hot assignments all have the
    /// minimum energy of -1 (none: 0, all three: 0).
    fn three_var_qubo() -> Array2<f64> {
        Array2::from_shape_fn((3, 3), |(i, j)| if i == j { -1.0 } else { 0.5 })
    }

    fn all_false_three() -> HashMap<String, bool> {
        HashMap::from([
            ("x0".to_string(), false),
            ("x1".to_string(), false),
            ("x2".to_string(), false),
        ])
    }

    #[test]
    fn test_hybrid_optimizer_creation() {
        let optimizer = HybridOptimizer::new(HybridConfig::default());

        assert!(optimizer.fixed_variables.is_empty());
        assert!(optimizer.history.is_empty());
    }

    #[test]
    fn test_energy_computation() {
        let optimizer = HybridOptimizer::new(HybridConfig::default());
        let qubo = Array2::from_shape_fn((2, 2), |(i, j)| if i == j { -1.0 } else { 2.0 });

        let one_hot = HashMap::from([("x0".to_string(), true), ("x1".to_string(), false)]);
        assert_eq!(optimizer.compute_energy(&one_hot, &qubo), -1.0);

        let both = HashMap::from([("x0".to_string(), true), ("x1".to_string(), true)]);
        assert_eq!(optimizer.compute_energy(&both, &qubo), 2.0);

        // Missing variables count as false.
        assert_eq!(optimizer.compute_energy(&HashMap::new(), &qubo), 0.0);
    }

    #[test]
    fn test_local_search_refinement() {
        let config = HybridConfig {
            max_local_iterations: 10,
            ..Default::default()
        };
        let mut optimizer = HybridOptimizer::new(config);
        let qubo = three_var_qubo();
        let initial_solution = all_false_three();

        let refined = optimizer
            .refine_solution(&initial_solution, &qubo)
            .expect("refinement should succeed");

        assert_eq!(refined.energy, -1.0);
        assert_eq!(refined.improvement, 1.0);
        assert!(refined.energy <= optimizer.compute_energy(&initial_solution, &qubo));
    }

    #[test]
    fn test_every_strategy_reaches_optimum() {
        let qubo = three_var_qubo();
        let start = all_false_three();

        for strategy in [
            LocalSearchStrategy::SteepestDescent,
            LocalSearchStrategy::FirstImprovement,
            LocalSearchStrategy::RandomDescent,
            LocalSearchStrategy::TabuSearch,
            LocalSearchStrategy::VariableNeighborhoodDescent,
        ] {
            let config = HybridConfig {
                local_search: strategy,
                ..Default::default()
            };
            let mut optimizer = HybridOptimizer::new(config);

            let refined = optimizer
                .refine_solution(&start, &qubo)
                .expect("refinement should succeed");

            assert_eq!(refined.energy, -1.0, "strategy {strategy:?}");
            assert_eq!(refined.improvement, 1.0, "strategy {strategy:?}");
            assert!(refined.is_feasible, "strategy {strategy:?}");
        }
    }

    #[test]
    fn test_variable_fixing() {
        let mut optimizer = HybridOptimizer::new(HybridConfig::default());

        let samples = vec![
            HashMap::from([("x0".to_string(), true), ("x1".to_string(), false)]),
            HashMap::from([("x0".to_string(), true), ("x1".to_string(), true)]),
            HashMap::from([("x0".to_string(), true), ("x1".to_string(), false)]),
        ];

        let criterion = FixingCriterion::HighFrequency { threshold: 0.8 };
        let fixed = optimizer
            .fix_variables(&samples, criterion)
            .expect("variable fixing should succeed");

        // x0 is true in 3/3 samples; x1 agrees in at most 2/3 < 0.8.
        assert_eq!(fixed.len(), 1);
        assert!(fixed.iter().any(|f| f.name == "x0" && f.value));
        assert_eq!(fixed[0].confidence, 1.0);
        assert_eq!(optimizer.get_fixed_variables().get("x0"), Some(&true));
        assert!(optimizer.get_fixed_variables().get("x1").is_none());
    }

    #[test]
    fn test_fixed_variable_is_not_flipped() {
        let mut optimizer = HybridOptimizer::new(HybridConfig::default());
        let qubo = Array2::from_elem((1, 1), -1.0);
        let start = HashMap::from([("x0".to_string(), false)]);

        optimizer
            .fix_variables(
                &[start.clone()],
                FixingCriterion::HighFrequency { threshold: 0.8 },
            )
            .expect("variable fixing should succeed");
        let refined = optimizer
            .refine_solution(&start, &qubo)
            .expect("refinement should succeed");
        assert!(!refined.assignments["x0"]);
        assert_eq!(refined.energy, 0.0);

        optimizer.unfix_all();
        let refined = optimizer
            .refine_solution(&start, &qubo)
            .expect("refinement should succeed");
        assert!(refined.assignments["x0"]);
        assert_eq!(refined.energy, -1.0);
    }

    #[test]
    fn test_convergence_detection() {
        let config = HybridConfig {
            convergence_tolerance: 0.001,
            ..Default::default()
        };
        let mut optimizer = HybridOptimizer::new(config);

        optimizer.history = vec![10.0, 10.00001, 10.00002];
        assert!(optimizer.has_converged());

        optimizer.history = vec![10.0, 9.0, 8.0];
        assert!(!optimizer.has_converged());

        // Fewer than three entries never counts as converged.
        optimizer.history = vec![10.0, 10.0];
        assert!(!optimizer.has_converged());
    }
}