quantrs2_anneal/
chain_break.rs

//! Chain break resolution algorithms for quantum annealing.
//!
//! When logical variables are embedded onto physical qubits using chains,
//! the physical qubits in a chain may disagree in the solution. This module
//! provides algorithms to resolve these chain breaks.
6
7use crate::embedding::Embedding;
8use crate::ising::{IsingError, IsingResult};
9use std::collections::{HashMap, HashSet};
10
/// A single solution (read-out) returned by quantum annealing hardware.
///
/// Spins are indexed by physical qubit: `spins[q]` is the value read from
/// physical qubit `q`.
#[derive(Debug, Clone)]
pub struct HardwareSolution {
    /// Values of the physical qubits (spin values: `+1` or `-1`).
    pub spins: Vec<i8>,
    /// Energy of this solution.
    pub energy: f64,
    /// Number of occurrences (for multiple reads).
    pub occurrences: usize,
}
21
22/// Resolved solution after chain break resolution
23#[derive(Debug, Clone)]
24pub struct ResolvedSolution {
25    /// Values of logical variables
26    pub logical_spins: Vec<i8>,
27    /// Number of broken chains
28    pub chain_breaks: usize,
29    /// Energy after resolution
30    pub energy: f64,
31    /// Original hardware solution
32    pub hardware_solution: HardwareSolution,
33}
34
/// Strategy used to turn the physical qubits of a chain into one logical spin.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResolutionMethod {
    /// Take the majority vote within each chain.
    MajorityVote,
    /// Start from the majority vote, then flip broken variables when that
    /// lowers the energy of the logical problem.
    EnergyMinimization,
    /// Majority vote in which each qubit's vote is weighted by how many other
    /// qubits of the chain agree with it.
    WeightedMajority,
    /// Discard solutions that contain broken chains.
    Discard,
}
47
48/// Chain break resolver
49pub struct ChainBreakResolver {
50    /// Resolution method to use
51    pub method: ResolutionMethod,
52    /// Tie-breaking strategy for majority vote
53    pub tie_break_random: bool,
54    /// Random seed for tie-breaking
55    pub seed: Option<u64>,
56}
57
58impl Default for ChainBreakResolver {
59    fn default() -> Self {
60        Self {
61            method: ResolutionMethod::MajorityVote,
62            tie_break_random: true,
63            seed: None,
64        }
65    }
66}
67
68impl ChainBreakResolver {
69    /// Resolve chain breaks in a single hardware solution
70    pub fn resolve_solution(
71        &self,
72        hardware_solution: &HardwareSolution,
73        embedding: &Embedding,
74        logical_problem: Option<&LogicalProblem>,
75    ) -> IsingResult<ResolvedSolution> {
76        match self.method {
77            ResolutionMethod::MajorityVote => {
78                self.resolve_majority_vote(hardware_solution, embedding)
79            }
80            ResolutionMethod::WeightedMajority => {
81                self.resolve_weighted_majority(hardware_solution, embedding)
82            }
83            ResolutionMethod::EnergyMinimization => {
84                let problem = logical_problem.ok_or_else(|| {
85                    IsingError::InvalidValue(
86                        "Energy minimization requires logical problem".to_string(),
87                    )
88                })?;
89                self.resolve_energy_minimization(hardware_solution, embedding, problem)
90            }
91            ResolutionMethod::Discard => self.resolve_discard(hardware_solution, embedding),
92        }
93    }
94
95    /// Resolve multiple hardware solutions
96    pub fn resolve_solutions(
97        &self,
98        hardware_solutions: &[HardwareSolution],
99        embedding: &Embedding,
100        logical_problem: Option<&LogicalProblem>,
101    ) -> IsingResult<Vec<ResolvedSolution>> {
102        let mut resolved = Vec::new();
103
104        for hw_solution in hardware_solutions {
105            match self.resolve_solution(hw_solution, embedding, logical_problem) {
106                Ok(solution) => resolved.push(solution),
107                Err(_) if self.method == ResolutionMethod::Discard => {
108                    // Skip broken solutions when using discard method
109                    continue;
110                }
111                Err(e) => return Err(e),
112            }
113        }
114
115        // Sort by energy
116        resolved.sort_by(|a, b| {
117            a.energy
118                .partial_cmp(&b.energy)
119                .unwrap_or(std::cmp::Ordering::Equal)
120        });
121
122        Ok(resolved)
123    }
124
125    /// Resolve using majority vote
126    fn resolve_majority_vote(
127        &self,
128        hardware_solution: &HardwareSolution,
129        embedding: &Embedding,
130    ) -> IsingResult<ResolvedSolution> {
131        let mut logical_spins = Vec::new();
132        let mut chain_breaks = 0;
133        let num_vars = embedding.chains.len();
134
135        for var in 0..num_vars {
136            let chain = embedding
137                .chains
138                .get(&var)
139                .ok_or_else(|| IsingError::InvalidQubit(var))?;
140
141            // Count votes
142            let mut plus_votes = 0;
143            let mut minus_votes = 0;
144
145            for &qubit in chain {
146                if qubit >= hardware_solution.spins.len() {
147                    return Err(IsingError::InvalidQubit(qubit));
148                }
149
150                match hardware_solution.spins[qubit] {
151                    1 => plus_votes += 1,
152                    -1 => minus_votes += 1,
153                    _ => return Err(IsingError::InvalidValue("Invalid spin value".to_string())),
154                }
155            }
156
157            // Determine logical value
158            let logical_value = if plus_votes > minus_votes {
159                1
160            } else if minus_votes > plus_votes {
161                -1
162            } else {
163                // Tie - use random or default to +1
164                if self.tie_break_random {
165                    // Simple deterministic tie-break based on variable index
166                    if var % 2 == 0 {
167                        1
168                    } else {
169                        -1
170                    }
171                } else {
172                    1
173                }
174            };
175
176            // Check for chain breaks
177            let unanimous = plus_votes == 0 || minus_votes == 0;
178            if !unanimous {
179                chain_breaks += 1;
180            }
181
182            logical_spins.push(logical_value);
183        }
184
185        Ok(ResolvedSolution {
186            logical_spins,
187            chain_breaks,
188            energy: hardware_solution.energy, // Will be recalculated if needed
189            hardware_solution: hardware_solution.clone(),
190        })
191    }
192
193    /// Resolve using weighted majority based on coupling strengths
194    fn resolve_weighted_majority(
195        &self,
196        hardware_solution: &HardwareSolution,
197        embedding: &Embedding,
198    ) -> IsingResult<ResolvedSolution> {
199        // Weighted majority voting: weight each qubit's vote by the number of
200        // other qubits in the chain that agree with it. This gives more influence
201        // to qubits that are part of a larger consensus.
202
203        let num_vars = embedding.chains.len();
204        let mut logical_spins = vec![0i8; num_vars];
205        let mut chain_breaks = 0;
206
207        for var in 0..num_vars {
208            if let Some(chain) = embedding.chains.get(&var) {
209                if chain.is_empty() {
210                    return Err(IsingError::InvalidValue(format!(
211                        "Empty chain for variable {var}"
212                    )));
213                }
214
215                if chain.len() == 1 {
216                    // Single qubit chain - no possibility of chain break
217                    logical_spins[var] = hardware_solution.spins[chain[0]];
218                    continue;
219                }
220
221                // Calculate weighted votes for +1 and -1
222                let mut weight_plus = 0.0;
223                let mut weight_minus = 0.0;
224                let mut has_disagreement = false;
225
226                for &qubit_i in chain {
227                    let spin_i = hardware_solution.spins[qubit_i];
228
229                    // Calculate weight: count how many qubits in the chain agree with this one
230                    let mut agreement_count = 0.0;
231                    for &qubit_j in chain {
232                        if qubit_i != qubit_j && hardware_solution.spins[qubit_j] == spin_i {
233                            agreement_count += 1.0;
234                        }
235                    }
236
237                    // Weight is: 1.0 (base) + agreement_count (bonus for consensus)
238                    let weight = 1.0 + agreement_count;
239
240                    if spin_i == 1 {
241                        weight_plus += weight;
242                    } else if spin_i == -1 {
243                        weight_minus += weight;
244                    }
245
246                    // Check for disagreement
247                    if hardware_solution.spins[chain[0]] != spin_i {
248                        has_disagreement = true;
249                    }
250                }
251
252                // Choose the spin value with higher weighted vote
253                if weight_plus > weight_minus {
254                    logical_spins[var] = 1;
255                } else if weight_minus > weight_plus {
256                    logical_spins[var] = -1;
257                } else {
258                    // Tie - use random or first qubit
259                    if self.tie_break_random {
260                        use scirs2_core::random::{thread_rng, Rng};
261                        let mut rng = thread_rng();
262                        logical_spins[var] = if rng.gen::<bool>() { 1 } else { -1 };
263                    } else {
264                        logical_spins[var] = hardware_solution.spins[chain[0]];
265                    }
266                }
267
268                if has_disagreement {
269                    chain_breaks += 1;
270                }
271            }
272        }
273
274        Ok(ResolvedSolution {
275            logical_spins,
276            chain_breaks,
277            energy: hardware_solution.energy,
278            hardware_solution: hardware_solution.clone(),
279        })
280    }
281
282    /// Resolve by minimizing energy of logical problem
283    fn resolve_energy_minimization(
284        &self,
285        hardware_solution: &HardwareSolution,
286        embedding: &Embedding,
287        logical_problem: &LogicalProblem,
288    ) -> IsingResult<ResolvedSolution> {
289        let mut resolved = self.resolve_majority_vote(hardware_solution, embedding)?;
290
291        // For each broken chain, try flipping the logical variable
292        for var in 0..resolved.logical_spins.len() {
293            if self.is_chain_broken(var, hardware_solution, embedding)? {
294                // Calculate energy with current value
295                let current_energy = logical_problem.calculate_energy(&resolved.logical_spins);
296
297                // Flip and calculate energy
298                resolved.logical_spins[var] *= -1;
299                let flipped_energy = logical_problem.calculate_energy(&resolved.logical_spins);
300
301                // Keep the flip if it lowers energy
302                if flipped_energy >= current_energy {
303                    resolved.logical_spins[var] *= -1; // Flip back
304                }
305            }
306        }
307
308        // Recalculate final energy
309        resolved.energy = logical_problem.calculate_energy(&resolved.logical_spins);
310
311        Ok(resolved)
312    }
313
314    /// Discard solutions with broken chains
315    fn resolve_discard(
316        &self,
317        hardware_solution: &HardwareSolution,
318        embedding: &Embedding,
319    ) -> IsingResult<ResolvedSolution> {
320        let resolved = self.resolve_majority_vote(hardware_solution, embedding)?;
321
322        if resolved.chain_breaks > 0 {
323            Err(IsingError::HardwareConstraint(format!(
324                "Solution has {} broken chains",
325                resolved.chain_breaks
326            )))
327        } else {
328            Ok(resolved)
329        }
330    }
331
332    /// Check if a chain is broken
333    fn is_chain_broken(
334        &self,
335        var: usize,
336        hardware_solution: &HardwareSolution,
337        embedding: &Embedding,
338    ) -> IsingResult<bool> {
339        let chain = embedding
340            .chains
341            .get(&var)
342            .ok_or_else(|| IsingError::InvalidQubit(var))?;
343
344        if chain.is_empty() {
345            return Ok(false);
346        }
347
348        let first_spin = hardware_solution.spins[chain[0]];
349
350        for &qubit in &chain[1..] {
351            if hardware_solution.spins[qubit] != first_spin {
352                return Ok(true);
353            }
354        }
355
356        Ok(false)
357    }
358}
359
360/// Represents a logical problem (QUBO or Ising)
361#[derive(Debug, Clone)]
362pub struct LogicalProblem {
363    /// Linear coefficients (`h_i` in Ising, diagonal in QUBO)
364    pub linear: Vec<f64>,
365    /// Quadratic coefficients as adjacency list
366    pub quadratic: HashMap<(usize, usize), f64>,
367    /// Constant offset
368    pub offset: f64,
369}
370
371impl LogicalProblem {
372    /// Create a new logical problem
373    #[must_use]
374    pub fn new(num_vars: usize) -> Self {
375        Self {
376            linear: vec![0.0; num_vars],
377            quadratic: HashMap::new(),
378            offset: 0.0,
379        }
380    }
381
382    /// Calculate energy for a given spin configuration
383    #[must_use]
384    pub fn calculate_energy(&self, spins: &[i8]) -> f64 {
385        let mut energy = self.offset;
386
387        // Linear terms
388        for (i, &h) in self.linear.iter().enumerate() {
389            if i < spins.len() {
390                energy += h * f64::from(spins[i]);
391            }
392        }
393
394        // Quadratic terms
395        for (&(i, j), &J) in &self.quadratic {
396            if i < spins.len() && j < spins.len() {
397                energy += J * f64::from(spins[i]) * f64::from(spins[j]);
398            }
399        }
400
401        energy
402    }
403
404    /// Convert from QUBO to Ising representation
405    pub fn from_qubo(qubo_matrix: &[Vec<f64>], offset: f64) -> IsingResult<Self> {
406        let n = qubo_matrix.len();
407        let mut problem = Self::new(n);
408        problem.offset = offset;
409
410        // Convert QUBO Q_ij to Ising h_i and J_ij
411        // x_i = (s_i + 1) / 2
412        // Minimize x^T Q x becomes minimize sum_i h_i s_i + sum_{i<j} J_ij s_i s_j
413
414        for i in 0..n {
415            for j in i..n {
416                let q_ij = qubo_matrix[i][j];
417                if q_ij.abs() > 1e-10 {
418                    problem.offset += q_ij / 4.0;
419                    if i == j {
420                        // Diagonal term contributes to linear coefficient
421                        problem.linear[i] += q_ij / 2.0;
422                    } else {
423                        // Off-diagonal term
424                        problem.quadratic.insert((i, j), q_ij / 4.0);
425                        problem.linear[i] += q_ij / 4.0;
426                        problem.linear[j] += q_ij / 4.0;
427                    }
428                }
429            }
430        }
431
432        Ok(problem)
433    }
434}
435
436/// Chain strength optimizer
437pub struct ChainStrengthOptimizer {
438    /// Minimum chain strength
439    pub min_strength: f64,
440    /// Maximum chain strength
441    pub max_strength: f64,
442    /// Number of strength values to try
443    pub num_tries: usize,
444}
445
446impl Default for ChainStrengthOptimizer {
447    fn default() -> Self {
448        Self {
449            min_strength: 0.1,
450            max_strength: 10.0,
451            num_tries: 10,
452        }
453    }
454}
455
456impl ChainStrengthOptimizer {
457    /// Find optimal chain strength by analyzing the problem
458    #[must_use]
459    pub fn find_optimal_strength(&self, logical_problem: &LogicalProblem) -> f64 {
460        // Calculate statistics of the logical problem coefficients
461        let mut all_coeffs = Vec::new();
462
463        // Add linear coefficients
464        for &h in &logical_problem.linear {
465            if h.abs() > 1e-10 {
466                all_coeffs.push(h.abs());
467            }
468        }
469
470        // Add quadratic coefficients
471        for &J in logical_problem.quadratic.values() {
472            if J.abs() > 1e-10 {
473                all_coeffs.push(J.abs());
474            }
475        }
476
477        if all_coeffs.is_empty() {
478            return 1.0; // Default strength
479        }
480
481        // Sort coefficients
482        all_coeffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
483
484        // Use median as base strength
485        let median = if all_coeffs.len() % 2 == 0 {
486            f64::midpoint(
487                all_coeffs[all_coeffs.len() / 2 - 1],
488                all_coeffs[all_coeffs.len() / 2],
489            )
490        } else {
491            all_coeffs[all_coeffs.len() / 2]
492        };
493
494        // Chain strength should be strong enough to keep chains together
495        // but not so strong as to dominate the problem
496        (median * 1.5).max(self.min_strength).min(self.max_strength)
497    }
498
499    /// Optimize chain strength through multiple runs
500    #[must_use]
501    pub fn optimize_strength(
502        &self,
503        logical_problem: &LogicalProblem,
504        test_solutions: &[Vec<i8>],
505    ) -> f64 {
506        let mut best_strength = self.find_optimal_strength(logical_problem);
507        let mut best_score = f64::INFINITY;
508
509        // Try different strengths
510        let step = (self.max_strength - self.min_strength) / (self.num_tries as f64);
511
512        for i in 0..self.num_tries {
513            let strength = (i as f64).mul_add(step, self.min_strength);
514
515            // Evaluate this strength
516            let score = self.evaluate_strength(strength, logical_problem, test_solutions);
517
518            if score < best_score {
519                best_score = score;
520                best_strength = strength;
521            }
522        }
523
524        best_strength
525    }
526
527    /// Evaluate a chain strength
528    fn evaluate_strength(
529        &self,
530        strength: f64,
531        logical_problem: &LogicalProblem,
532        test_solutions: &[Vec<i8>],
533    ) -> f64 {
534        // Simple evaluation: prefer strengths that maintain solution quality
535        // In practice, this would run actual annealing with different strengths
536
537        // For now, return a score based on the ratio to problem coefficients
538        let avg_coeff = self.calculate_average_coefficient(logical_problem);
539
540        // Penalty for being too different from problem scale
541        (strength / avg_coeff - 1.5).abs()
542    }
543
544    /// Calculate average coefficient magnitude
545    fn calculate_average_coefficient(&self, logical_problem: &LogicalProblem) -> f64 {
546        let mut sum = 0.0;
547        let mut count = 0;
548
549        for &h in &logical_problem.linear {
550            if h.abs() > 1e-10 {
551                sum += h.abs();
552                count += 1;
553            }
554        }
555
556        for &J in logical_problem.quadratic.values() {
557            if J.abs() > 1e-10 {
558                sum += J.abs();
559                count += 1;
560            }
561        }
562
563        if count > 0 {
564            sum / f64::from(count)
565        } else {
566            1.0
567        }
568    }
569}
570
571/// Statistics about chain breaks
572#[derive(Debug, Clone, Default)]
573pub struct ChainBreakStats {
574    /// Total number of chains
575    pub total_chains: usize,
576    /// Number of broken chains per solution
577    pub broken_chains: Vec<usize>,
578    /// Chain break rate
579    pub break_rate: f64,
580    /// Most frequently broken variables
581    pub frequent_breaks: Vec<(usize, usize)>,
582}
583
584impl ChainBreakStats {
585    /// Analyze chain breaks across multiple solutions
586    pub fn analyze(
587        hardware_solutions: &[HardwareSolution],
588        embedding: &Embedding,
589    ) -> IsingResult<Self> {
590        let total_chains = embedding.chains.len();
591        let mut broken_chains = Vec::new();
592        let mut break_counts: HashMap<usize, usize> = HashMap::new();
593
594        for hw_solution in hardware_solutions {
595            let mut breaks_in_solution = 0;
596
597            for (&var, chain) in &embedding.chains {
598                if chain.len() > 1 {
599                    let first_spin = hw_solution.spins[chain[0]];
600                    let is_broken = chain[1..]
601                        .iter()
602                        .any(|&q| hw_solution.spins[q] != first_spin);
603
604                    if is_broken {
605                        breaks_in_solution += 1;
606                        *break_counts.entry(var).or_insert(0) += 1;
607                    }
608                }
609            }
610
611            broken_chains.push(breaks_in_solution);
612        }
613
614        // Calculate statistics
615        let total_breaks: usize = broken_chains.iter().sum();
616        let break_rate = if hardware_solutions.is_empty() || total_chains == 0 {
617            0.0
618        } else {
619            total_breaks as f64 / (hardware_solutions.len() * total_chains) as f64
620        };
621
622        // Find most frequently broken variables
623        let mut frequent_breaks: Vec<(usize, usize)> = break_counts.into_iter().collect();
624        frequent_breaks.sort_by_key(|&(_, count)| std::cmp::Reverse(count));
625        frequent_breaks.truncate(10); // Keep top 10
626
627        Ok(Self {
628            total_chains,
629            broken_chains,
630            break_rate,
631            frequent_breaks,
632        })
633    }
634
635    /// Get recommendations based on statistics
636    #[must_use]
637    pub fn get_recommendations(&self) -> Vec<String> {
638        let mut recommendations = Vec::new();
639
640        if self.break_rate > 0.5 {
641            recommendations.push(
642                "High chain break rate detected. Consider increasing chain strength.".to_string(),
643            );
644        }
645
646        if self.break_rate > 0.2 {
647            recommendations.push(
648                "Moderate chain breaks. Try optimizing embedding or chain strength.".to_string(),
649            );
650        }
651
652        if !self.frequent_breaks.is_empty() {
653            let vars: Vec<String> = self
654                .frequent_breaks
655                .iter()
656                .take(3)
657                .map(|(var, _)| var.to_string())
658                .collect();
659            recommendations.push(format!(
660                "Variables {} frequently have broken chains. Check embedding quality.",
661                vars.join(", ")
662            ));
663        }
664
665        recommendations
666    }
667}
668
#[cfg(test)]
mod tests {
    use super::*;

    /// Majority vote picks the dominant spin of each chain and counts chains
    /// whose qubits disagree.
    #[test]
    fn test_majority_vote_resolution() {
        let mut embedding = Embedding::new();
        embedding
            .add_chain(0, vec![0, 1, 2])
            .expect("failed to add chain in test");
        embedding
            .add_chain(1, vec![3, 4, 5])
            .expect("failed to add chain in test");

        let hw_solution = HardwareSolution {
            spins: vec![1, 1, -1, -1, -1, -1], // First chain: 2 vs 1, second: unanimous
            energy: -1.0,
            occurrences: 1,
        };

        let resolver = ChainBreakResolver::default();
        let resolved = resolver
            .resolve_solution(&hw_solution, &embedding, None)
            .expect("failed to resolve solution in test");

        assert_eq!(resolved.logical_spins, vec![1, -1]);
        assert_eq!(resolved.chain_breaks, 1); // First chain is broken
    }

    /// The suggested strength tracks the scale of the problem coefficients.
    #[test]
    fn test_chain_strength_optimizer() {
        let mut problem = LogicalProblem::new(3);
        problem.linear = vec![1.0, -0.5, 0.0];
        problem.quadratic.insert((0, 1), -2.0);
        problem.quadratic.insert((1, 2), 1.5);

        let optimizer = ChainStrengthOptimizer::default();
        let strength = optimizer.find_optimal_strength(&problem);

        // Should be around the median of coefficients
        assert!(strength > 0.5 && strength < 5.0);
    }

    /// Break statistics count broken chains per solution and overall.
    #[test]
    fn test_chain_break_stats() {
        let mut embedding = Embedding::new();
        embedding
            .add_chain(0, vec![0, 1])
            .expect("failed to add chain in test");
        embedding
            .add_chain(1, vec![2, 3])
            .expect("failed to add chain in test");

        let solutions = vec![
            HardwareSolution {
                spins: vec![1, 1, -1, -1], // No breaks
                energy: -1.0,
                occurrences: 1,
            },
            HardwareSolution {
                spins: vec![1, -1, -1, -1], // First chain broken
                energy: -0.5,
                occurrences: 1,
            },
        ];

        let stats = ChainBreakStats::analyze(&solutions, &embedding)
            .expect("failed to analyze chain break stats in test");

        assert_eq!(stats.total_chains, 2);
        assert_eq!(stats.broken_chains, vec![0, 1]);
        // 1 break out of 4 chain instances
        assert!((stats.break_rate - 0.25).abs() < f64::EPSILON);
    }
}