1use crate::embedding::Embedding;
8use crate::ising::{IsingError, IsingResult};
9use std::collections::{HashMap, HashSet};
10
/// One sample read back from the hardware, expressed over physical qubits.
///
/// `PartialEq` is derived so samples can be compared directly (for example
/// when de-duplicating reads or asserting on them in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct HardwareSolution {
    /// Spin measured on each physical qubit, indexed by qubit number.
    /// The resolvers in this file expect every entry to be `1` or `-1`.
    pub spins: Vec<i8>,
    /// Energy attached to this sample (presumably as reported by the
    /// sampler; it is copied through unchanged by the vote-based resolvers).
    pub energy: f64,
    /// Occurrence count for this sample. Not read anywhere in this file;
    /// presumably the number of times the sampler returned it.
    pub occurrences: usize,
}
21
22#[derive(Debug, Clone)]
24pub struct ResolvedSolution {
25 pub logical_spins: Vec<i8>,
27 pub chain_breaks: usize,
29 pub energy: f64,
31 pub hardware_solution: HardwareSolution,
33}
34
/// Strategy used to turn a (possibly broken) chain into one logical spin.
///
/// `Default` yields [`ResolutionMethod::MajorityVote`], and `Hash` is derived
/// so the method can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ResolutionMethod {
    /// Take the spin value held by most qubits in the chain.
    #[default]
    MajorityVote,
    /// Start from majority vote, then greedily flip the variables of broken
    /// chains when that lowers the logical problem's energy.
    EnergyMinimization,
    /// Vote in which each qubit is weighted by how many chain members share
    /// its spin.
    WeightedMajority,
    /// Reject any sample that contains at least one broken chain.
    Discard,
}
47
48pub struct ChainBreakResolver {
50 pub method: ResolutionMethod,
52 pub tie_break_random: bool,
54 pub seed: Option<u64>,
56}
57
58impl Default for ChainBreakResolver {
59 fn default() -> Self {
60 Self {
61 method: ResolutionMethod::MajorityVote,
62 tie_break_random: true,
63 seed: None,
64 }
65 }
66}
67
68impl ChainBreakResolver {
69 pub fn resolve_solution(
71 &self,
72 hardware_solution: &HardwareSolution,
73 embedding: &Embedding,
74 logical_problem: Option<&LogicalProblem>,
75 ) -> IsingResult<ResolvedSolution> {
76 match self.method {
77 ResolutionMethod::MajorityVote => {
78 self.resolve_majority_vote(hardware_solution, embedding)
79 }
80 ResolutionMethod::WeightedMajority => {
81 self.resolve_weighted_majority(hardware_solution, embedding)
82 }
83 ResolutionMethod::EnergyMinimization => {
84 let problem = logical_problem.ok_or_else(|| {
85 IsingError::InvalidValue(
86 "Energy minimization requires logical problem".to_string(),
87 )
88 })?;
89 self.resolve_energy_minimization(hardware_solution, embedding, problem)
90 }
91 ResolutionMethod::Discard => self.resolve_discard(hardware_solution, embedding),
92 }
93 }
94
95 pub fn resolve_solutions(
97 &self,
98 hardware_solutions: &[HardwareSolution],
99 embedding: &Embedding,
100 logical_problem: Option<&LogicalProblem>,
101 ) -> IsingResult<Vec<ResolvedSolution>> {
102 let mut resolved = Vec::new();
103
104 for hw_solution in hardware_solutions {
105 match self.resolve_solution(hw_solution, embedding, logical_problem) {
106 Ok(solution) => resolved.push(solution),
107 Err(_) if self.method == ResolutionMethod::Discard => {
108 continue;
110 }
111 Err(e) => return Err(e),
112 }
113 }
114
115 resolved.sort_by(|a, b| {
117 a.energy
118 .partial_cmp(&b.energy)
119 .unwrap_or(std::cmp::Ordering::Equal)
120 });
121
122 Ok(resolved)
123 }
124
125 fn resolve_majority_vote(
127 &self,
128 hardware_solution: &HardwareSolution,
129 embedding: &Embedding,
130 ) -> IsingResult<ResolvedSolution> {
131 let mut logical_spins = Vec::new();
132 let mut chain_breaks = 0;
133 let num_vars = embedding.chains.len();
134
135 for var in 0..num_vars {
136 let chain = embedding
137 .chains
138 .get(&var)
139 .ok_or_else(|| IsingError::InvalidQubit(var))?;
140
141 let mut plus_votes = 0;
143 let mut minus_votes = 0;
144
145 for &qubit in chain {
146 if qubit >= hardware_solution.spins.len() {
147 return Err(IsingError::InvalidQubit(qubit));
148 }
149
150 match hardware_solution.spins[qubit] {
151 1 => plus_votes += 1,
152 -1 => minus_votes += 1,
153 _ => return Err(IsingError::InvalidValue("Invalid spin value".to_string())),
154 }
155 }
156
157 let logical_value = if plus_votes > minus_votes {
159 1
160 } else if minus_votes > plus_votes {
161 -1
162 } else {
163 if self.tie_break_random {
165 if var % 2 == 0 {
167 1
168 } else {
169 -1
170 }
171 } else {
172 1
173 }
174 };
175
176 let unanimous = plus_votes == 0 || minus_votes == 0;
178 if !unanimous {
179 chain_breaks += 1;
180 }
181
182 logical_spins.push(logical_value);
183 }
184
185 Ok(ResolvedSolution {
186 logical_spins,
187 chain_breaks,
188 energy: hardware_solution.energy, hardware_solution: hardware_solution.clone(),
190 })
191 }
192
193 fn resolve_weighted_majority(
195 &self,
196 hardware_solution: &HardwareSolution,
197 embedding: &Embedding,
198 ) -> IsingResult<ResolvedSolution> {
199 let num_vars = embedding.chains.len();
204 let mut logical_spins = vec![0i8; num_vars];
205 let mut chain_breaks = 0;
206
207 for var in 0..num_vars {
208 if let Some(chain) = embedding.chains.get(&var) {
209 if chain.is_empty() {
210 return Err(IsingError::InvalidValue(format!(
211 "Empty chain for variable {var}"
212 )));
213 }
214
215 if chain.len() == 1 {
216 logical_spins[var] = hardware_solution.spins[chain[0]];
218 continue;
219 }
220
221 let mut weight_plus = 0.0;
223 let mut weight_minus = 0.0;
224 let mut has_disagreement = false;
225
226 for &qubit_i in chain {
227 let spin_i = hardware_solution.spins[qubit_i];
228
229 let mut agreement_count = 0.0;
231 for &qubit_j in chain {
232 if qubit_i != qubit_j && hardware_solution.spins[qubit_j] == spin_i {
233 agreement_count += 1.0;
234 }
235 }
236
237 let weight = 1.0 + agreement_count;
239
240 if spin_i == 1 {
241 weight_plus += weight;
242 } else if spin_i == -1 {
243 weight_minus += weight;
244 }
245
246 if hardware_solution.spins[chain[0]] != spin_i {
248 has_disagreement = true;
249 }
250 }
251
252 if weight_plus > weight_minus {
254 logical_spins[var] = 1;
255 } else if weight_minus > weight_plus {
256 logical_spins[var] = -1;
257 } else {
258 if self.tie_break_random {
260 use scirs2_core::random::{thread_rng, Rng};
261 let mut rng = thread_rng();
262 logical_spins[var] = if rng.gen::<bool>() { 1 } else { -1 };
263 } else {
264 logical_spins[var] = hardware_solution.spins[chain[0]];
265 }
266 }
267
268 if has_disagreement {
269 chain_breaks += 1;
270 }
271 }
272 }
273
274 Ok(ResolvedSolution {
275 logical_spins,
276 chain_breaks,
277 energy: hardware_solution.energy,
278 hardware_solution: hardware_solution.clone(),
279 })
280 }
281
282 fn resolve_energy_minimization(
284 &self,
285 hardware_solution: &HardwareSolution,
286 embedding: &Embedding,
287 logical_problem: &LogicalProblem,
288 ) -> IsingResult<ResolvedSolution> {
289 let mut resolved = self.resolve_majority_vote(hardware_solution, embedding)?;
290
291 for var in 0..resolved.logical_spins.len() {
293 if self.is_chain_broken(var, hardware_solution, embedding)? {
294 let current_energy = logical_problem.calculate_energy(&resolved.logical_spins);
296
297 resolved.logical_spins[var] *= -1;
299 let flipped_energy = logical_problem.calculate_energy(&resolved.logical_spins);
300
301 if flipped_energy >= current_energy {
303 resolved.logical_spins[var] *= -1; }
305 }
306 }
307
308 resolved.energy = logical_problem.calculate_energy(&resolved.logical_spins);
310
311 Ok(resolved)
312 }
313
314 fn resolve_discard(
316 &self,
317 hardware_solution: &HardwareSolution,
318 embedding: &Embedding,
319 ) -> IsingResult<ResolvedSolution> {
320 let resolved = self.resolve_majority_vote(hardware_solution, embedding)?;
321
322 if resolved.chain_breaks > 0 {
323 Err(IsingError::HardwareConstraint(format!(
324 "Solution has {} broken chains",
325 resolved.chain_breaks
326 )))
327 } else {
328 Ok(resolved)
329 }
330 }
331
332 fn is_chain_broken(
334 &self,
335 var: usize,
336 hardware_solution: &HardwareSolution,
337 embedding: &Embedding,
338 ) -> IsingResult<bool> {
339 let chain = embedding
340 .chains
341 .get(&var)
342 .ok_or_else(|| IsingError::InvalidQubit(var))?;
343
344 if chain.is_empty() {
345 return Ok(false);
346 }
347
348 let first_spin = hardware_solution.spins[chain[0]];
349
350 for &qubit in &chain[1..] {
351 if hardware_solution.spins[qubit] != first_spin {
352 return Ok(true);
353 }
354 }
355
356 Ok(false)
357 }
358}
359
/// Ising model over logical variables: `offset + Σ h_i s_i + Σ J_ij s_i s_j`.
#[derive(Clone, Debug)]
pub struct LogicalProblem {
    /// Linear coefficient `h_i` for each variable, indexed by variable.
    pub linear: Vec<f64>,
    /// Coupling `J_ij` keyed by the variable pair `(i, j)`.
    pub quadratic: HashMap<(usize, usize), f64>,
    /// Constant term added to every energy.
    pub offset: f64,
}
370
371impl LogicalProblem {
372 #[must_use]
374 pub fn new(num_vars: usize) -> Self {
375 Self {
376 linear: vec![0.0; num_vars],
377 quadratic: HashMap::new(),
378 offset: 0.0,
379 }
380 }
381
382 #[must_use]
384 pub fn calculate_energy(&self, spins: &[i8]) -> f64 {
385 let mut energy = self.offset;
386
387 for (i, &h) in self.linear.iter().enumerate() {
389 if i < spins.len() {
390 energy += h * f64::from(spins[i]);
391 }
392 }
393
394 for (&(i, j), &J) in &self.quadratic {
396 if i < spins.len() && j < spins.len() {
397 energy += J * f64::from(spins[i]) * f64::from(spins[j]);
398 }
399 }
400
401 energy
402 }
403
404 pub fn from_qubo(qubo_matrix: &[Vec<f64>], offset: f64) -> IsingResult<Self> {
406 let n = qubo_matrix.len();
407 let mut problem = Self::new(n);
408 problem.offset = offset;
409
410 for i in 0..n {
415 for j in i..n {
416 let q_ij = qubo_matrix[i][j];
417 if q_ij.abs() > 1e-10 {
418 problem.offset += q_ij / 4.0;
419 if i == j {
420 problem.linear[i] += q_ij / 2.0;
422 } else {
423 problem.quadratic.insert((i, j), q_ij / 4.0);
425 problem.linear[i] += q_ij / 4.0;
426 problem.linear[j] += q_ij / 4.0;
427 }
428 }
429 }
430 }
431
432 Ok(problem)
433 }
434}
435
/// Search bounds and resolution used when choosing a chain strength.
///
/// `Debug`, `Clone` and `PartialEq` are derived so configurations can be
/// logged, copied and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct ChainStrengthOptimizer {
    /// Lower bound of the strengths considered.
    pub min_strength: f64,
    /// Upper bound of the strengths considered.
    pub max_strength: f64,
    /// Number of evenly spaced candidate strengths evaluated by the search.
    pub num_tries: usize,
}
445
446impl Default for ChainStrengthOptimizer {
447 fn default() -> Self {
448 Self {
449 min_strength: 0.1,
450 max_strength: 10.0,
451 num_tries: 10,
452 }
453 }
454}
455
456impl ChainStrengthOptimizer {
457 #[must_use]
459 pub fn find_optimal_strength(&self, logical_problem: &LogicalProblem) -> f64 {
460 let mut all_coeffs = Vec::new();
462
463 for &h in &logical_problem.linear {
465 if h.abs() > 1e-10 {
466 all_coeffs.push(h.abs());
467 }
468 }
469
470 for &J in logical_problem.quadratic.values() {
472 if J.abs() > 1e-10 {
473 all_coeffs.push(J.abs());
474 }
475 }
476
477 if all_coeffs.is_empty() {
478 return 1.0; }
480
481 all_coeffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
483
484 let median = if all_coeffs.len() % 2 == 0 {
486 f64::midpoint(
487 all_coeffs[all_coeffs.len() / 2 - 1],
488 all_coeffs[all_coeffs.len() / 2],
489 )
490 } else {
491 all_coeffs[all_coeffs.len() / 2]
492 };
493
494 (median * 1.5).max(self.min_strength).min(self.max_strength)
497 }
498
499 #[must_use]
501 pub fn optimize_strength(
502 &self,
503 logical_problem: &LogicalProblem,
504 test_solutions: &[Vec<i8>],
505 ) -> f64 {
506 let mut best_strength = self.find_optimal_strength(logical_problem);
507 let mut best_score = f64::INFINITY;
508
509 let step = (self.max_strength - self.min_strength) / (self.num_tries as f64);
511
512 for i in 0..self.num_tries {
513 let strength = (i as f64).mul_add(step, self.min_strength);
514
515 let score = self.evaluate_strength(strength, logical_problem, test_solutions);
517
518 if score < best_score {
519 best_score = score;
520 best_strength = strength;
521 }
522 }
523
524 best_strength
525 }
526
527 fn evaluate_strength(
529 &self,
530 strength: f64,
531 logical_problem: &LogicalProblem,
532 test_solutions: &[Vec<i8>],
533 ) -> f64 {
534 let avg_coeff = self.calculate_average_coefficient(logical_problem);
539
540 (strength / avg_coeff - 1.5).abs()
542 }
543
544 fn calculate_average_coefficient(&self, logical_problem: &LogicalProblem) -> f64 {
546 let mut sum = 0.0;
547 let mut count = 0;
548
549 for &h in &logical_problem.linear {
550 if h.abs() > 1e-10 {
551 sum += h.abs();
552 count += 1;
553 }
554 }
555
556 for &J in logical_problem.quadratic.values() {
557 if J.abs() > 1e-10 {
558 sum += J.abs();
559 count += 1;
560 }
561 }
562
563 if count > 0 {
564 sum / f64::from(count)
565 } else {
566 1.0
567 }
568 }
569}
570
/// Aggregate chain-break statistics over a set of hardware samples.
#[derive(Clone, Debug, Default)]
pub struct ChainBreakStats {
    /// Number of chains in the embedding that was analysed.
    pub total_chains: usize,
    /// Broken-chain count for each sample, in sample order.
    pub broken_chains: Vec<usize>,
    /// Total broken chains divided by `samples × total_chains`
    /// (`0.0` when either factor is zero).
    pub break_rate: f64,
    /// `(variable, break count)` pairs, most frequently broken first,
    /// holding at most ten entries.
    pub frequent_breaks: Vec<(usize, usize)>,
}
583
584impl ChainBreakStats {
585 pub fn analyze(
587 hardware_solutions: &[HardwareSolution],
588 embedding: &Embedding,
589 ) -> IsingResult<Self> {
590 let total_chains = embedding.chains.len();
591 let mut broken_chains = Vec::new();
592 let mut break_counts: HashMap<usize, usize> = HashMap::new();
593
594 for hw_solution in hardware_solutions {
595 let mut breaks_in_solution = 0;
596
597 for (&var, chain) in &embedding.chains {
598 if chain.len() > 1 {
599 let first_spin = hw_solution.spins[chain[0]];
600 let is_broken = chain[1..]
601 .iter()
602 .any(|&q| hw_solution.spins[q] != first_spin);
603
604 if is_broken {
605 breaks_in_solution += 1;
606 *break_counts.entry(var).or_insert(0) += 1;
607 }
608 }
609 }
610
611 broken_chains.push(breaks_in_solution);
612 }
613
614 let total_breaks: usize = broken_chains.iter().sum();
616 let break_rate = if hardware_solutions.is_empty() || total_chains == 0 {
617 0.0
618 } else {
619 total_breaks as f64 / (hardware_solutions.len() * total_chains) as f64
620 };
621
622 let mut frequent_breaks: Vec<(usize, usize)> = break_counts.into_iter().collect();
624 frequent_breaks.sort_by_key(|&(_, count)| std::cmp::Reverse(count));
625 frequent_breaks.truncate(10); Ok(Self {
628 total_chains,
629 broken_chains,
630 break_rate,
631 frequent_breaks,
632 })
633 }
634
635 #[must_use]
637 pub fn get_recommendations(&self) -> Vec<String> {
638 let mut recommendations = Vec::new();
639
640 if self.break_rate > 0.5 {
641 recommendations.push(
642 "High chain break rate detected. Consider increasing chain strength.".to_string(),
643 );
644 }
645
646 if self.break_rate > 0.2 {
647 recommendations.push(
648 "Moderate chain breaks. Try optimizing embedding or chain strength.".to_string(),
649 );
650 }
651
652 if !self.frequent_breaks.is_empty() {
653 let vars: Vec<String> = self
654 .frequent_breaks
655 .iter()
656 .take(3)
657 .map(|(var, _)| var.to_string())
658 .collect();
659 recommendations.push(format!(
660 "Variables {} frequently have broken chains. Check embedding quality.",
661 vars.join(", ")
662 ));
663 }
664
665 recommendations
666 }
667}
668
#[cfg(test)]
mod tests {
    use super::*;

    /// Majority vote picks the dominant spin per chain and counts the one
    /// chain whose qubits disagree.
    #[test]
    fn test_majority_vote_resolution() {
        let mut embedding = Embedding::new();
        embedding
            .add_chain(0, vec![0, 1, 2])
            .expect("failed to add chain in test");
        embedding
            .add_chain(1, vec![3, 4, 5])
            .expect("failed to add chain in test");

        // Chain 0 reads (+, +, -) and is broken; chain 1 is unanimous.
        let sample = HardwareSolution {
            spins: vec![1, 1, -1, -1, -1, -1],
            energy: -1.0,
            occurrences: 1,
        };

        let outcome = ChainBreakResolver::default()
            .resolve_solution(&sample, &embedding, None)
            .expect("failed to resolve solution in test");

        assert_eq!(outcome.logical_spins, vec![1, -1]);
        assert_eq!(outcome.chain_breaks, 1);
    }

    /// The heuristic strength for a small problem lands inside a sane band.
    #[test]
    fn test_chain_strength_optimizer() {
        let mut problem = LogicalProblem::new(3);
        problem.linear = vec![1.0, -0.5, 0.0];
        problem.quadratic.insert((0, 1), -2.0);
        problem.quadratic.insert((1, 2), 1.5);

        let strength = ChainStrengthOptimizer::default().find_optimal_strength(&problem);

        assert!(strength > 0.5 && strength < 5.0);
    }

    /// One broken chain across two samples of two chains is a 25% break rate.
    #[test]
    fn test_chain_break_stats() {
        let mut embedding = Embedding::new();
        embedding
            .add_chain(0, vec![0, 1])
            .expect("failed to add chain in test");
        embedding
            .add_chain(1, vec![2, 3])
            .expect("failed to add chain in test");

        let intact = HardwareSolution {
            spins: vec![1, 1, -1, -1],
            energy: -1.0,
            occurrences: 1,
        };
        let broken = HardwareSolution {
            spins: vec![1, -1, -1, -1],
            energy: -0.5,
            occurrences: 1,
        };
        let samples = vec![intact, broken];

        let stats = ChainBreakStats::analyze(&samples, &embedding)
            .expect("failed to analyze chain break stats in test");

        assert_eq!(stats.total_chains, 2);
        assert_eq!(stats.broken_chains, vec![0, 1]);
        assert!((stats.break_rate - 0.25).abs() < 1e-12);
    }
}