Skip to main content

quantrs2_circuit/
optimizer.rs

1//! Quantum circuit optimization passes
2//!
3//! This module provides various optimization passes that can be applied to quantum circuits
4//! to reduce gate count, improve fidelity, and optimize for hardware constraints.
5
6use crate::builder::Circuit;
7use quantrs2_core::{
8    gate::{
9        multi::CNOT,
10        single::{PauliX, RotationX, RotationY, RotationZ},
11    },
12    qubit::QubitId,
13};
14use std::collections::{HashMap, HashSet};
15use std::f64::consts::PI;
16
17/// Gate representation for optimization
18#[derive(Debug, Clone, PartialEq)]
19pub enum OptGate {
20    Single(QubitId, String, Vec<f64>),
21    Double(QubitId, QubitId, String, Vec<f64>),
22    Multi(Vec<QubitId>, String, Vec<f64>),
23}
24
25impl OptGate {
26    /// Convert a GateOp to OptGate for optimization analysis
27    fn from_gate_op(gate: &dyn quantrs2_core::gate::GateOp) -> Self {
28        let qubits = gate.qubits();
29        let name = gate.name().to_string();
30        let params = Vec::new(); // Parameters would need to be extracted from specific gate types
31
32        match qubits.len() {
33            1 => Self::Single(qubits[0], name, params),
34            2 => Self::Double(qubits[0], qubits[1], name, params),
35            _ => Self::Multi(qubits, name, params),
36        }
37    }
38}
39
40/// Optimization context that holds circuit information
41pub struct OptimizationContext<const N: usize> {
42    pub circuit: Circuit<N>,
43    pub gate_count: usize,
44    pub depth: usize,
45}
46
47/// Result of applying an optimization pass
48pub struct PassResult<const N: usize> {
49    pub circuit: Circuit<N>,
50    pub improved: bool,
51    pub improvement: f64,
52}
53
54/// Merge consecutive single-qubit gates
55pub struct SingleQubitGateFusion;
56
57impl SingleQubitGateFusion {
58    #[must_use]
59    pub fn apply<const N: usize>(&self, ctx: &OptimizationContext<N>) -> PassResult<N> {
60        let gates = ctx.circuit.gates();
61        if gates.len() < 2 {
62            return PassResult {
63                circuit: ctx.circuit.clone(),
64                improved: false,
65                improvement: 0.0,
66            };
67        }
68
69        // Convert gates to OptGate format
70        let opt_gates: Vec<OptGate> = gates
71            .iter()
72            .map(|g| OptGate::from_gate_op(g.as_ref()))
73            .collect();
74
75        // Find consecutive single-qubit gates on the same qubit
76        let mut fusion_groups: Vec<Vec<usize>> = Vec::new();
77        let mut current_group: Vec<usize> = vec![];
78        let mut current_qubit: Option<QubitId> = None;
79
80        for (idx, opt_gate) in opt_gates.iter().enumerate() {
81            if let OptGate::Single(qubit, _, _) = opt_gate {
82                if current_qubit == Some(*qubit) {
83                    // Continue current group
84                    current_group.push(idx);
85                } else {
86                    // Start new group
87                    if current_group.len() >= 2 {
88                        fusion_groups.push(current_group.clone());
89                    }
90                    current_group = vec![idx];
91                    current_qubit = Some(*qubit);
92                }
93            } else {
94                // Non-single-qubit gate breaks the group
95                if current_group.len() >= 2 {
96                    fusion_groups.push(current_group.clone());
97                }
98                current_group.clear();
99                current_qubit = None;
100            }
101        }
102
103        // Don't forget the last group
104        if current_group.len() >= 2 {
105            fusion_groups.push(current_group);
106        }
107
108        if fusion_groups.is_empty() {
109            return PassResult {
110                circuit: ctx.circuit.clone(),
111                improved: false,
112                improvement: 0.0,
113            };
114        }
115
116        // For now, we report the fusion opportunities but don't actually fuse
117        // Full fusion would require matrix multiplication and creating composite gates
118        // which is complex and requires additional infrastructure
119
120        // Calculate potential improvement
121        let mut gates_that_could_be_fused = 0;
122        for group in &fusion_groups {
123            gates_that_could_be_fused += group.len() - 1; // N gates → 1 gate saves N-1
124        }
125
126        // Since we're not actually implementing fusion yet, return the original circuit
127        // but report that fusion opportunities were found
128        PassResult {
129            circuit: ctx.circuit.clone(),
130            improved: false,
131            improvement: 0.0, // Would be gates_that_could_be_fused if we implemented it
132        }
133    }
134
135    #[must_use]
136    pub const fn name(&self) -> &'static str {
137        "Single-Qubit Gate Fusion"
138    }
139}
140
141/// Remove redundant gates (e.g., X·X = I, H·H = I)
142pub struct RedundantGateElimination;
143
144impl RedundantGateElimination {
145    /// Check if two gates cancel each other
146    #[allow(dead_code)]
147    fn gates_cancel(gate1: &OptGate, gate2: &OptGate) -> bool {
148        match (gate1, gate2) {
149            (OptGate::Single(q1, name1, _), OptGate::Single(q2, name2, _)) => {
150                if q1 != q2 {
151                    return false;
152                }
153
154                // Self-inverse gates
155                matches!(
156                    (name1.as_str(), name2.as_str()),
157                    ("X", "X") | ("Y", "Y") | ("Z", "Z") | ("H", "H") | ("CNOT", "CNOT")
158                )
159            }
160            _ => false,
161        }
162    }
163
164    #[must_use]
165    pub fn apply<const N: usize>(&self, ctx: &OptimizationContext<N>) -> PassResult<N> {
166        let gates = ctx.circuit.gates();
167        if gates.len() < 2 {
168            return PassResult {
169                circuit: ctx.circuit.clone(),
170                improved: false,
171                improvement: 0.0,
172            };
173        }
174
175        // Track which gates should be removed (indices to skip)
176        let mut to_remove = HashSet::new();
177
178        // Convert gates to OptGate format for analysis
179        let opt_gates: Vec<OptGate> = gates
180            .iter()
181            .map(|g| OptGate::from_gate_op(g.as_ref()))
182            .collect();
183
184        // Find consecutive gates that cancel
185        let mut i = 0;
186        while i < opt_gates.len() - 1 {
187            if !to_remove.contains(&i)
188                && !to_remove.contains(&(i + 1))
189                && Self::gates_cancel(&opt_gates[i], &opt_gates[i + 1])
190            {
191                to_remove.insert(i);
192                to_remove.insert(i + 1);
193                i += 2; // Skip both gates
194                continue;
195            }
196            i += 1;
197        }
198
199        if to_remove.is_empty() {
200            return PassResult {
201                circuit: ctx.circuit.clone(),
202                improved: false,
203                improvement: 0.0,
204            };
205        }
206
207        // Build new circuit without redundant gates
208        let mut new_circuit = Circuit::<N>::with_capacity(gates.len() - to_remove.len());
209        for (idx, gate) in gates.iter().enumerate() {
210            if !to_remove.contains(&idx) {
211                let _ = new_circuit.add_gate_arc(gate.clone());
212            }
213        }
214
215        let gates_removed = to_remove.len();
216        let improvement = gates_removed as f64; // Each removed gate reduces cost
217
218        PassResult {
219            circuit: new_circuit,
220            improved: gates_removed > 0,
221            improvement,
222        }
223    }
224
225    #[must_use]
226    pub const fn name(&self) -> &'static str {
227        "Redundant Gate Elimination"
228    }
229}
230
231/// Commutation-based optimization
232pub struct CommutationOptimizer;
233
234impl CommutationOptimizer {
235    /// Check if two gates commute
236    #[allow(dead_code)]
237    fn gates_commute(gate1: &OptGate, gate2: &OptGate) -> bool {
238        match (gate1, gate2) {
239            // Single-qubit gates on different qubits always commute
240            // Single-qubit gates on different qubits always commute
241            (OptGate::Single(q1, name1, _), OptGate::Single(q2, name2, _)) => {
242                if q1 == q2 {
243                    // Z gates commute with each other on same qubit
244                    name1 == "Z" && name2 == "Z"
245                } else {
246                    true
247                }
248            }
249
250            // CNOT gates commute if they don't share qubits
251            (OptGate::Double(c1, t1, name1, _), OptGate::Double(c2, t2, name2, _)) => {
252                name1 == "CNOT" && name2 == "CNOT" && c1 != c2 && c1 != t2 && t1 != c2 && t1 != t2
253            }
254
255            _ => false,
256        }
257    }
258
259    #[must_use]
260    pub fn apply<const N: usize>(&self, ctx: &OptimizationContext<N>) -> PassResult<N> {
261        let gates = ctx.circuit.gates();
262        if gates.len() < 2 {
263            return PassResult {
264                circuit: ctx.circuit.clone(),
265                improved: false,
266                improvement: 0.0,
267            };
268        }
269
270        // Convert gates to OptGate format
271        let opt_gates: Vec<OptGate> = gates
272            .iter()
273            .map(|g| OptGate::from_gate_op(g.as_ref()))
274            .collect();
275
276        // Try to reorder gates by bubbling commuting gates forward
277        // This can reduce circuit depth and enable other optimizations
278        let mut reordered_indices: Vec<usize> = (0..gates.len()).collect();
279        let mut made_changes = false;
280
281        // Multiple passes to propagate commuting gates
282        for _ in 0..3 {
283            let mut i = 0;
284            while i + 1 < reordered_indices.len() {
285                let idx1 = reordered_indices[i];
286                let idx2 = reordered_indices[i + 1];
287
288                if Self::gates_commute(&opt_gates[idx1], &opt_gates[idx2]) {
289                    // Try swapping to see if it reduces depth or enables other optimizations
290                    // For now, prefer moving single-qubit gates earlier
291                    let should_swap = matches!(opt_gates[idx2], OptGate::Single(_, _, _))
292                        && !matches!(opt_gates[idx1], OptGate::Single(_, _, _));
293
294                    if should_swap {
295                        reordered_indices.swap(i, i + 1);
296                        made_changes = true;
297                    }
298                }
299                i += 1;
300            }
301        }
302
303        if !made_changes {
304            return PassResult {
305                circuit: ctx.circuit.clone(),
306                improved: false,
307                improvement: 0.0,
308            };
309        }
310
311        // Build new circuit with reordered gates
312        let mut new_circuit = Circuit::<N>::with_capacity(gates.len());
313        for idx in reordered_indices {
314            let _ = new_circuit.add_gate_arc(gates[idx].clone());
315        }
316
317        // Calculate depth improvement
318        let old_depth = ctx.circuit.calculate_depth() as f64;
319        let new_depth = new_circuit.calculate_depth() as f64;
320        let improvement = (old_depth - new_depth).max(0.0);
321
322        PassResult {
323            circuit: new_circuit,
324            improved: improvement > 0.0,
325            improvement,
326        }
327    }
328
329    #[must_use]
330    pub const fn name(&self) -> &'static str {
331        "Commutation-Based Optimization"
332    }
333}
334
335/// Peephole optimization for common patterns
336pub struct PeepholeOptimizer {
337    #[allow(dead_code)]
338    patterns: Vec<PatternRule>,
339}
340
341#[derive(Clone)]
342#[allow(dead_code)]
343struct PatternRule {
344    pattern: Vec<OptGate>,
345    replacement: Vec<OptGate>,
346    name: String,
347}
348
349impl Default for PeepholeOptimizer {
350    fn default() -> Self {
351        let patterns = vec![
352            // Pattern: H-X-H = Z
353            PatternRule {
354                pattern: vec![
355                    OptGate::Single(QubitId::new(0), "H".to_string(), vec![]),
356                    OptGate::Single(QubitId::new(0), "X".to_string(), vec![]),
357                    OptGate::Single(QubitId::new(0), "H".to_string(), vec![]),
358                ],
359                replacement: vec![OptGate::Single(QubitId::new(0), "Z".to_string(), vec![])],
360                name: "H-X-H to Z".to_string(),
361            },
362            // Pattern: H-Z-H = X
363            PatternRule {
364                pattern: vec![
365                    OptGate::Single(QubitId::new(0), "H".to_string(), vec![]),
366                    OptGate::Single(QubitId::new(0), "Z".to_string(), vec![]),
367                    OptGate::Single(QubitId::new(0), "H".to_string(), vec![]),
368                ],
369                replacement: vec![OptGate::Single(QubitId::new(0), "X".to_string(), vec![])],
370                name: "H-Z-H to X".to_string(),
371            },
372        ];
373
374        Self { patterns }
375    }
376}
377
378impl PeepholeOptimizer {
379    /// Check if three consecutive gates match a pattern
380    fn matches_pattern(gates: &[OptGate], pattern: &[OptGate]) -> bool {
381        if gates.len() != pattern.len() {
382            return false;
383        }
384
385        for (gate, pat) in gates.iter().zip(pattern.iter()) {
386            match (gate, pat) {
387                (OptGate::Single(q1, n1, _), OptGate::Single(q2, n2, _)) => {
388                    if q1 != q2 || n1 != n2 {
389                        return false;
390                    }
391                }
392                _ => return false, // Only support single-qubit patterns for now
393            }
394        }
395        true
396    }
397
398    #[must_use]
399    pub fn apply<const N: usize>(&self, ctx: &OptimizationContext<N>) -> PassResult<N> {
400        let gates = ctx.circuit.gates();
401        if gates.len() < 3 {
402            return PassResult {
403                circuit: ctx.circuit.clone(),
404                improved: false,
405                improvement: 0.0,
406            };
407        }
408
409        // Convert gates to OptGate format
410        let opt_gates: Vec<OptGate> = gates
411            .iter()
412            .map(|g| OptGate::from_gate_op(g.as_ref()))
413            .collect();
414
415        // Track replacements: (start_idx, end_idx, replacement_gate_name)
416        let mut replacements: Vec<(usize, usize, String, QubitId)> = Vec::new();
417
418        // Find pattern matches (H-X-H → Z, H-Z-H → X)
419        let mut i = 0;
420        while i + 2 < opt_gates.len() {
421            // Check for H-X-H → Z
422            if let (
423                OptGate::Single(q1, n1, _),
424                OptGate::Single(q2, n2, _),
425                OptGate::Single(q3, n3, _),
426            ) = (&opt_gates[i], &opt_gates[i + 1], &opt_gates[i + 2])
427            {
428                if q1 == q2 && q2 == q3 {
429                    if n1 == "H" && n2 == "X" && n3 == "H" {
430                        replacements.push((i, i + 2, "Z".to_string(), *q1));
431                        i += 3;
432                        continue;
433                    } else if n1 == "H" && n2 == "Z" && n3 == "H" {
434                        replacements.push((i, i + 2, "X".to_string(), *q1));
435                        i += 3;
436                        continue;
437                    }
438                }
439            }
440            i += 1;
441        }
442
443        if replacements.is_empty() {
444            return PassResult {
445                circuit: ctx.circuit.clone(),
446                improved: false,
447                improvement: 0.0,
448            };
449        }
450
451        // Build new circuit with replacements
452        let mut new_circuit = Circuit::<N>::new();
453        let mut idx = 0;
454        let mut replacement_iter = replacements.iter().peekable();
455
456        while idx < gates.len() {
457            if let Some((start, end, replacement_name, qubit)) = replacement_iter.peek() {
458                if idx == *start {
459                    // Apply replacement
460                    match replacement_name.as_str() {
461                        "X" => {
462                            let _ = new_circuit
463                                .add_gate(quantrs2_core::gate::single::PauliX { target: *qubit });
464                        }
465                        "Z" => {
466                            let _ = new_circuit
467                                .add_gate(quantrs2_core::gate::single::PauliZ { target: *qubit });
468                        }
469                        _ => {}
470                    }
471                    idx = *end + 1;
472                    replacement_iter.next();
473                    continue;
474                }
475            }
476            // Copy original gate
477            let _ = new_circuit.add_gate_arc(gates[idx].clone());
478            idx += 1;
479        }
480
481        let gates_saved = replacements.len() * 2; // Each pattern saves 2 gates (3 → 1)
482        let improvement = gates_saved as f64;
483
484        PassResult {
485            circuit: new_circuit,
486            improved: !replacements.is_empty(),
487            improvement,
488        }
489    }
490
491    #[must_use]
492    pub const fn name(&self) -> &'static str {
493        "Peephole Optimization"
494    }
495}
496
497/// Template matching optimization
498pub struct TemplateOptimizer {
499    #[allow(dead_code)]
500    templates: Vec<Template>,
501}
502
503#[allow(dead_code)]
504struct Template {
505    name: String,
506    pattern: Vec<OptGate>,
507    cost_reduction: f64,
508}
509
510impl Default for TemplateOptimizer {
511    fn default() -> Self {
512        let templates = vec![Template {
513            name: "Toffoli Decomposition".to_string(),
514            pattern: vec![], // Would contain Toffoli gate pattern
515            cost_reduction: 0.3,
516        }];
517
518        Self { templates }
519    }
520}
521
522impl TemplateOptimizer {
523    #[must_use]
524    pub fn apply<const N: usize>(&self, ctx: &OptimizationContext<N>) -> PassResult<N> {
525        let circuit = &ctx.circuit;
526        let gates = circuit.gates();
527        if gates.len() < 2 {
528            return PassResult {
529                circuit: circuit.clone(),
530                improved: false,
531                improvement: 0.0,
532            };
533        }
534
535        let mut new_circuit = Circuit::<N>::new();
536        let n = gates.len();
537        let mut i = 0;
538        let mut gates_saved: usize = 0;
539
540        while i < n {
541            // ---------------------------------------------------------------
542            // Pattern: RX(θ)·RX(φ) → RX(θ+φ)  (and RY, RZ analogues)
543            // ---------------------------------------------------------------
544            if i + 1 < n {
545                let g0 = gates[i].as_ref();
546                let g1 = gates[i + 1].as_ref();
547
548                if g0.qubits().len() == 1
549                    && g1.qubits().len() == 1
550                    && g0.qubits()[0] == g1.qubits()[0]
551                {
552                    let target = g0.qubits()[0];
553
554                    // RZ(θ)·RZ(φ) → RZ(θ+φ)
555                    if let (Some(rz0), Some(rz1)) = (
556                        g0.as_any().downcast_ref::<RotationZ>(),
557                        g1.as_any().downcast_ref::<RotationZ>(),
558                    ) {
559                        let merged_angle = rz0.theta + rz1.theta;
560                        // Normalise into (-π, π]
561                        let normalised = ((merged_angle + PI).rem_euclid(2.0 * PI)) - PI;
562                        let _ = new_circuit.add_gate(RotationZ {
563                            target,
564                            theta: normalised,
565                        });
566                        i += 2;
567                        gates_saved += 1;
568                        continue;
569                    }
570
571                    // RY(θ)·RY(φ) → RY(θ+φ)
572                    if let (Some(ry0), Some(ry1)) = (
573                        g0.as_any().downcast_ref::<RotationY>(),
574                        g1.as_any().downcast_ref::<RotationY>(),
575                    ) {
576                        let merged_angle = ry0.theta + ry1.theta;
577                        let normalised = ((merged_angle + PI).rem_euclid(2.0 * PI)) - PI;
578                        let _ = new_circuit.add_gate(RotationY {
579                            target,
580                            theta: normalised,
581                        });
582                        i += 2;
583                        gates_saved += 1;
584                        continue;
585                    }
586
587                    // RX(θ)·RX(φ) → RX(θ+φ)
588                    if let (Some(rx0), Some(rx1)) = (
589                        g0.as_any().downcast_ref::<RotationX>(),
590                        g1.as_any().downcast_ref::<RotationX>(),
591                    ) {
592                        let merged_angle = rx0.theta + rx1.theta;
593                        let normalised = ((merged_angle + PI).rem_euclid(2.0 * PI)) - PI;
594                        let _ = new_circuit.add_gate(RotationX {
595                            target,
596                            theta: normalised,
597                        });
598                        i += 2;
599                        gates_saved += 1;
600                        continue;
601                    }
602                }
603            }
604
605            // ---------------------------------------------------------------
606            // Pattern: X · RZ(θ) · X → RZ(-θ)   (conjugation by X)
607            // ---------------------------------------------------------------
608            if i + 2 < n {
609                let g0 = gates[i].as_ref();
610                let g1 = gates[i + 1].as_ref();
611                let g2 = gates[i + 2].as_ref();
612
613                if g0.qubits().len() == 1
614                    && g1.qubits().len() == 1
615                    && g2.qubits().len() == 1
616                    && g0.qubits()[0] == g1.qubits()[0]
617                    && g1.qubits()[0] == g2.qubits()[0]
618                {
619                    let target = g1.qubits()[0];
620
621                    if g0.name() == "X" && g2.name() == "X" {
622                        if let Some(rz) = g1.as_any().downcast_ref::<RotationZ>() {
623                            // X·RZ(θ)·X = RZ(-θ)
624                            let _ = new_circuit.add_gate(RotationZ {
625                                target,
626                                theta: -rz.theta,
627                            });
628                            i += 3;
629                            gates_saved += 2;
630                            continue;
631                        }
632                    }
633                }
634
635                // ---------------------------------------------------------------
636                // Pattern: CNOT · (I ⊗ RZ(θ)) · CNOT → (I ⊗ RZ(θ)) · CNOT · CNOT
637                //          simplified: CNOT commutes with target RZ rotations,
638                //          so emit RZ(θ) first then drop the two CNOTs (they cancel).
639                //          Only applicable when the two CNOTs have the same control/target.
640                // ---------------------------------------------------------------
641                let g0 = gates[i].as_ref();
642                let g1 = gates[i + 1].as_ref();
643                let g2 = gates[i + 2].as_ref();
644
645                if let (Some(cnot0), Some(cnot2)) = (
646                    g0.as_any().downcast_ref::<CNOT>(),
647                    g2.as_any().downcast_ref::<CNOT>(),
648                ) {
649                    if cnot0.control == cnot2.control
650                        && cnot0.target == cnot2.target
651                        && g1.qubits().len() == 1
652                        && g1.qubits()[0] == cnot0.target
653                    {
654                        if let Some(rz) = g1.as_any().downcast_ref::<RotationZ>() {
655                            // RZ on target commutes through CNOT → emit RZ then CNOT·CNOT
656                            // CNOT·CNOT = I, so the two cancel; emit just RZ
657                            let target = rz.target;
658                            let _ = new_circuit.add_gate(RotationZ {
659                                target,
660                                theta: rz.theta,
661                            });
662                            i += 3;
663                            gates_saved += 2;
664                            continue;
665                        }
666                    }
667                }
668            }
669
670            // No pattern matched — emit gate unchanged
671            let _ = new_circuit.add_gate_arc(gates[i].clone());
672            i += 1;
673        }
674
675        let improved = gates_saved > 0;
676        PassResult {
677            circuit: new_circuit,
678            improved,
679            improvement: gates_saved as f64,
680        }
681    }
682
683    #[must_use]
684    pub const fn name(&self) -> &'static str {
685        "Template Matching Optimization"
686    }
687}
688
689/// Enum to hold different optimization passes
690pub enum OptimizationPassType {
691    SingleQubitFusion(SingleQubitGateFusion),
692    RedundantElimination(RedundantGateElimination),
693    Commutation(CommutationOptimizer),
694    Peephole(PeepholeOptimizer),
695    Template(TemplateOptimizer),
696    Hardware(HardwareOptimizer),
697}
698
699impl OptimizationPassType {
700    #[must_use]
701    pub fn apply<const N: usize>(&self, ctx: &OptimizationContext<N>) -> PassResult<N> {
702        match self {
703            Self::SingleQubitFusion(p) => p.apply(ctx),
704            Self::RedundantElimination(p) => p.apply(ctx),
705            Self::Commutation(p) => p.apply(ctx),
706            Self::Peephole(p) => p.apply(ctx),
707            Self::Template(p) => p.apply(ctx),
708            Self::Hardware(p) => p.apply(ctx),
709        }
710    }
711
712    #[must_use]
713    pub const fn name(&self) -> &str {
714        match self {
715            Self::SingleQubitFusion(p) => p.name(),
716            Self::RedundantElimination(p) => p.name(),
717            Self::Commutation(p) => p.name(),
718            Self::Peephole(p) => p.name(),
719            Self::Template(p) => p.name(),
720            Self::Hardware(p) => p.name(),
721        }
722    }
723}
724
725/// Main circuit optimizer that applies multiple passes
726pub struct CircuitOptimizer<const N: usize> {
727    passes: Vec<OptimizationPassType>,
728    max_iterations: usize,
729}
730
731impl<const N: usize> Default for CircuitOptimizer<N> {
732    fn default() -> Self {
733        Self::new()
734    }
735}
736
737impl<const N: usize> CircuitOptimizer<N> {
738    /// Create a new circuit optimizer with default passes
739    #[must_use]
740    pub fn new() -> Self {
741        let passes = vec![
742            OptimizationPassType::RedundantElimination(RedundantGateElimination),
743            OptimizationPassType::SingleQubitFusion(SingleQubitGateFusion),
744            OptimizationPassType::Commutation(CommutationOptimizer),
745            OptimizationPassType::Peephole(PeepholeOptimizer::default()),
746            OptimizationPassType::Template(TemplateOptimizer::default()),
747        ];
748
749        Self {
750            passes,
751            max_iterations: 10,
752        }
753    }
754
755    /// Create a custom optimizer with specific passes
756    #[must_use]
757    pub const fn with_passes(passes: Vec<OptimizationPassType>) -> Self {
758        Self {
759            passes,
760            max_iterations: 10,
761        }
762    }
763
764    /// Set the maximum number of optimization iterations
765    #[must_use]
766    pub const fn with_max_iterations(mut self, max_iterations: usize) -> Self {
767        self.max_iterations = max_iterations;
768        self
769    }
770
771    /// Add an optimization pass
772    #[must_use]
773    pub fn add_pass(mut self, pass: OptimizationPassType) -> Self {
774        self.passes.push(pass);
775        self
776    }
777
778    /// Optimize a circuit
779    #[must_use]
780    pub fn optimize(&self, circuit: &Circuit<N>) -> OptimizationResult<N> {
781        let mut current_circuit = circuit.clone();
782        let mut total_iterations = 0;
783        let mut pass_statistics = HashMap::new();
784
785        // Keep track of circuit cost (simplified as gate count for now)
786        let initial_cost = self.estimate_cost(&current_circuit);
787        let mut current_cost = initial_cost;
788
789        // Apply optimization passes iteratively
790        for iteration in 0..self.max_iterations {
791            let iteration_start_cost = current_cost;
792
793            for pass in &self.passes {
794                let pass_name = pass.name().to_string();
795                let before_cost = current_cost;
796
797                let ctx = OptimizationContext {
798                    circuit: current_circuit.clone(),
799                    gate_count: 10, // Placeholder
800                    depth: 5,       // Placeholder
801                };
802
803                let result = pass.apply(&ctx);
804                current_circuit = result.circuit;
805
806                if result.improved {
807                    current_cost -= result.improvement;
808                }
809
810                let improvement = before_cost - current_cost;
811                pass_statistics
812                    .entry(pass_name)
813                    .and_modify(|stats: &mut PassStats| {
814                        stats.applications += 1;
815                        stats.total_improvement += improvement;
816                    })
817                    .or_insert(PassStats {
818                        applications: 1,
819                        total_improvement: improvement,
820                    });
821            }
822
823            total_iterations = iteration + 1;
824
825            // Stop if no improvement in this iteration
826            if (iteration_start_cost - current_cost).abs() < 1e-10 {
827                break;
828            }
829        }
830
831        OptimizationResult {
832            optimized_circuit: current_circuit,
833            initial_cost,
834            final_cost: current_cost,
835            iterations: total_iterations,
836            pass_statistics,
837        }
838    }
839
840    /// Estimate the cost of a circuit based on gate count, types, and depth
841    fn estimate_cost(&self, circuit: &Circuit<N>) -> f64 {
842        let stats = circuit.get_stats();
843
844        // Weight factors for different gate types
845        let single_qubit_cost = 1.0;
846        let two_qubit_cost = 10.0; // Two-qubit gates are much more expensive
847        let multi_qubit_cost = 50.0; // Multi-qubit gates are very expensive
848
849        // Calculate gate cost
850        let single_qubit_gates =
851            stats.total_gates - stats.two_qubit_gates - stats.multi_qubit_gates;
852        let gate_cost = single_qubit_gates as f64 * single_qubit_cost
853            + stats.two_qubit_gates as f64 * two_qubit_cost
854            + stats.multi_qubit_gates as f64 * multi_qubit_cost;
855
856        // Circuit depth adds to cost (deeper circuits are slower and more error-prone)
857        let depth_cost = stats.depth as f64 * 2.0;
858
859        gate_cost + depth_cost
860    }
861}
862
/// Statistics the optimizer accumulates for one optimization pass, keyed by pass name
/// in [`OptimizationResult::pass_statistics`].
#[derive(Debug, Clone)]
pub struct PassStats {
    /// Number of times the optimizer recorded a run of this pass (incremented by one per run).
    pub applications: usize,
    /// Sum of the per-run cost changes (`cost before - cost after`); can be negative if a
    /// run increased the tracked cost.
    pub total_improvement: f64,
}
869
/// Result of circuit optimization, as returned by [`CircuitOptimizer::optimize`].
#[derive(Debug, Clone)]
pub struct OptimizationResult<const N: usize> {
    /// The circuit as it stands after the last optimization iteration.
    pub optimized_circuit: Circuit<N>,
    /// Cost of the input circuit before any pass ran.
    pub initial_cost: f64,
    /// Cost tracked by the optimizer after the final iteration.
    pub final_cost: f64,
    /// Number of optimization iterations actually executed (the loop may stop early once an
    /// iteration yields no cost change).
    pub iterations: usize,
    /// Per-pass statistics, keyed by pass name.
    pub pass_statistics: HashMap<String, PassStats>,
}
879
880impl<const N: usize> OptimizationResult<N> {
881    /// Get the improvement ratio
882    #[must_use]
883    pub fn improvement_ratio(&self) -> f64 {
884        if self.initial_cost > 0.0 {
885            (self.initial_cost - self.final_cost) / self.initial_cost
886        } else {
887            0.0
888        }
889    }
890
891    /// Print optimization summary
892    pub fn print_summary(&self) {
893        println!("Circuit Optimization Summary");
894        println!("===========================");
895        println!("Initial cost: {:.2}", self.initial_cost);
896        println!("Final cost: {:.2}", self.final_cost);
897        println!("Improvement: {:.1}%", self.improvement_ratio() * 100.0);
898        println!("Iterations: {}", self.iterations);
899        println!("\nPass Statistics:");
900
901        for (pass_name, stats) in &self.pass_statistics {
902            if stats.total_improvement > 0.0 {
903                println!(
904                    "  {}: {} applications, {:.2} total improvement",
905                    pass_name, stats.applications, stats.total_improvement
906                );
907            }
908        }
909    }
910}
911
/// Hardware-aware optimization pass.
///
/// Rewrites gates that are not in the device's native gate set into standard
/// `CNOT`/`H` expansions, consulting the coupling map when choosing CNOT orientation.
#[derive(Debug, Clone)]
pub struct HardwareOptimizer {
    /// Device coupling map as `(a, b)` qubit-index pairs.
    connectivity: Vec<(usize, usize)>,
    /// Names of gates the device executes natively; these are never rewritten.
    native_gates: HashSet<String>,
}
919
920impl HardwareOptimizer {
921    #[must_use]
922    pub const fn new(connectivity: Vec<(usize, usize)>, native_gates: HashSet<String>) -> Self {
923        Self {
924            connectivity,
925            native_gates,
926        }
927    }
928
929    #[must_use]
930    pub fn apply<const N: usize>(&self, ctx: &OptimizationContext<N>) -> PassResult<N> {
931        let circuit = &ctx.circuit;
932        let gates = circuit.gates();
933
934        if gates.is_empty() {
935            return PassResult {
936                circuit: circuit.clone(),
937                improved: false,
938                improvement: 0.0,
939            };
940        }
941
942        // Build a set of connected qubit pairs from the connectivity graph
943        let connected_pairs: HashSet<(usize, usize)> = self
944            .connectivity
945            .iter()
946            .flat_map(|&(a, b)| [(a.min(b), a.max(b))])
947            .collect();
948
949        let mut new_circuit = Circuit::<N>::new();
950        let mut gates_decomposed: usize = 0;
951
952        for gate in gates {
953            let gate_ref = gate.as_ref();
954            let name = gate_ref.name();
955
956            // Check if the gate is in the native gate set
957            if self.native_gates.contains(name) {
958                // Native gate — pass through directly
959                let _ = new_circuit.add_gate_arc(gate.clone());
960                continue;
961            }
962
963            // Non-native gate: decompose to native equivalents.
964            // Decomposition table (subset, hardware-independent standard expansions):
965            //   SWAP  → CNOT(a,b) · CNOT(b,a) · CNOT(a,b)
966            //   CZ    → H(b) · CNOT(a,b) · H(b)
967            //   CCX (Toffoli) — too complex; fall through to pass-through
968            match name {
969                "SWAP" if gate_ref.qubits().len() == 2 => {
970                    let a = gate_ref.qubits()[0];
971                    let b = gate_ref.qubits()[1];
972                    let ai = a.id() as usize;
973                    let bi = b.id() as usize;
974
975                    // Prefer the direction that matches hardware connectivity
976                    let (ctrl, tgt) = if connected_pairs.contains(&(ai.min(bi), ai.max(bi))) {
977                        (a, b)
978                    } else {
979                        (b, a)
980                    };
981
982                    // SWAP → CNOT(ctrl, tgt) · CNOT(tgt, ctrl) · CNOT(ctrl, tgt)
983                    let _ = new_circuit.add_gate(CNOT {
984                        control: ctrl,
985                        target: tgt,
986                    });
987                    let _ = new_circuit.add_gate(CNOT {
988                        control: tgt,
989                        target: ctrl,
990                    });
991                    let _ = new_circuit.add_gate(CNOT {
992                        control: ctrl,
993                        target: tgt,
994                    });
995                    gates_decomposed += 1;
996                }
997                "CZ" if gate_ref.qubits().len() == 2 => {
998                    let a = gate_ref.qubits()[0];
999                    let b = gate_ref.qubits()[1];
1000                    // CZ(a,b) → H(b) · CNOT(a,b) · H(b)
1001                    let _ =
1002                        new_circuit.add_gate(quantrs2_core::gate::single::Hadamard { target: b });
1003                    let _ = new_circuit.add_gate(CNOT {
1004                        control: a,
1005                        target: b,
1006                    });
1007                    let _ =
1008                        new_circuit.add_gate(quantrs2_core::gate::single::Hadamard { target: b });
1009                    gates_decomposed += 1;
1010                }
1011                _ => {
1012                    // Unknown non-native gate — pass through unchanged
1013                    let _ = new_circuit.add_gate_arc(gate.clone());
1014                }
1015            }
1016        }
1017
1018        let improved = gates_decomposed > 0;
1019        PassResult {
1020            circuit: new_circuit,
1021            improved,
1022            improvement: gates_decomposed as f64,
1023        }
1024    }
1025
1026    #[must_use]
1027    pub const fn name(&self) -> &'static str {
1028        "Hardware-Aware Optimization"
1029    }
1030}
1031
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_circuit_optimizer_creation() {
        let optimizer = CircuitOptimizer::<4>::new();
        assert_eq!(optimizer.passes.len(), 5);
        assert_eq!(optimizer.max_iterations, 10);
    }

    #[test]
    fn test_optimization_result() {
        let circuit = Circuit::<4>::new();
        let optimizer = CircuitOptimizer::new();
        let result = optimizer.optimize(&circuit);

        assert!(result.improvement_ratio() >= 0.0);
        assert!(result.iterations > 0);
    }

    /// `improvement_ratio` must not divide by zero and must compute the plain ratio otherwise.
    #[test]
    fn test_improvement_ratio_edge_cases() {
        let mut result = OptimizationResult {
            optimized_circuit: Circuit::<2>::new(),
            initial_cost: 0.0,
            final_cost: 0.0,
            iterations: 1,
            pass_statistics: HashMap::new(),
        };
        assert!(result.improvement_ratio().abs() < f64::EPSILON);

        result.initial_cost = 100.0;
        result.final_cost = 75.0;
        assert!((result.improvement_ratio() - 0.25).abs() < 1e-12);
    }

    /// An empty circuit passes through the hardware pass untouched and reports no gain.
    #[test]
    fn test_hardware_optimizer_empty_circuit_is_unchanged() {
        let native: HashSet<String> = ["CNOT".to_string()].into_iter().collect();
        let optimizer = HardwareOptimizer::new(vec![(0, 1)], native);
        let ctx = OptimizationContext {
            circuit: Circuit::<2>::new(),
            gate_count: 0,
            depth: 0,
        };

        let result = optimizer.apply(&ctx);

        assert!(!result.improved);
        assert!(result.improvement.abs() < f64::EPSILON);
        assert!(result.circuit.gates().is_empty());
        assert_eq!(optimizer.name(), "Hardware-Aware Optimization");
    }
}