1use quantrs2_circuit::builder::Circuit;
7use quantrs2_core::{
8 error::{QuantRS2Error, QuantRS2Result},
9 gate::{multi::*, single::*, GateOp},
10 qubit::QubitId,
11};
12use scirs2_core::Complex64;
13use std::collections::{HashMap, HashSet, VecDeque};
14
/// Switches and limits that control which passes the circuit optimizer runs.
///
/// `PartialEq`/`Eq` are derived so two configurations can be compared directly
/// (e.g. in tests, or when caching results per configuration).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OptimizationConfig {
    /// Enable gate fusion.
    pub enable_gate_fusion: bool,
    /// Enable elimination of redundant (mutually cancelling) gates.
    pub enable_redundant_elimination: bool,
    /// Enable reordering of commuting gates.
    pub enable_commutation_reordering: bool,
    /// Enable single-qubit gate optimizations.
    pub enable_single_qubit_optimization: bool,
    /// Enable two-qubit gate optimizations.
    pub enable_two_qubit_optimization: bool,
    /// Upper bound on the number of optimization passes.
    pub max_passes: usize,
    /// Enable circuit-depth reduction.
    pub enable_depth_reduction: bool,
}
33
34impl Default for OptimizationConfig {
35 fn default() -> Self {
36 Self {
37 enable_gate_fusion: true,
38 enable_redundant_elimination: true,
39 enable_commutation_reordering: true,
40 enable_single_qubit_optimization: true,
41 enable_two_qubit_optimization: true,
42 max_passes: 3,
43 enable_depth_reduction: true,
44 }
45 }
46}
47
48#[derive(Debug)]
50pub struct CircuitOptimizer {
51 config: OptimizationConfig,
52 statistics: OptimizationStatistics,
53}
54
/// Counters describing an optimization run.
///
/// `PartialEq`/`Eq` are derived so statistics snapshots can be compared directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct OptimizationStatistics {
    /// Gate count of the input circuit.
    pub original_gate_count: usize,
    /// Gate count of the returned circuit.
    pub optimized_gate_count: usize,
    /// Depth of the input circuit.
    pub original_depth: usize,
    /// Depth of the returned circuit.
    pub optimized_depth: usize,
    /// Gates reported by the redundant-gate-elimination pass.
    pub redundant_gates_eliminated: usize,
    /// Count reported by the single-qubit fusion pass.
    pub gates_fused: usize,
    /// Count reported by the commutation-reordering pass.
    pub gates_reordered: usize,
    /// Number of optimization passes executed.
    pub passes_performed: usize,
    /// Wall-clock duration of the run, in nanoseconds.
    pub optimization_time_ns: u128,
}
77
/// Per-gate bookkeeping built from a circuit's gate list.
#[derive(Debug, Clone)]
struct DependencyGraph {
    /// For each gate position, earlier gate positions that act on at least one of
    /// the same qubits.
    dependencies: HashMap<usize, Vec<usize>>,
    /// One entry per gate, in circuit order.
    gate_info: Vec<GateInfo>,
    /// For each qubit, the positions of the gates that touch it, in circuit order.
    qubit_usage: HashMap<QubitId, Vec<usize>>,
}
88
/// Summary of a single gate as recorded in a `DependencyGraph`.
#[derive(Debug, Clone)]
struct GateInfo {
    /// Index of the gate in the circuit's gate list.
    position: usize,
    /// The gate's `name()`.
    gate_type: String,
    /// Qubits the gate acts on, as returned by `qubits()`.
    qubits: Vec<QubitId>,
    /// Initialised to `false`; no visible code sets it.
    optimized_away: bool,
    /// Initialised empty; no visible code populates it.
    fused_with: Vec<usize>,
}
103
/// A run of consecutive single-qubit gates on one qubit that could be fused.
#[derive(Debug, Clone)]
struct SingleQubitFusion {
    /// Positions (in the circuit's gate list) of the gates in the run.
    gates: Vec<usize>,
    /// The qubit the run acts on.
    qubit: QubitId,
    /// 2x2 matrix of the fused gate. NOTE(review): currently filled with the identity
    /// placeholder rather than the product of the run's gate matrices.
    fused_matrix: [[Complex64; 2]; 2],
}
114
/// Outcome reported by a single optimization pass.
///
/// `PartialEq`/`Eq` are derived so results can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OptimizationResult {
    /// Whether the pass found anything to optimize.
    pub success: bool,
    /// Number of gates the pass reports as removable.
    pub gates_eliminated: usize,
    /// Number of gates (or gate groups) the pass reports as affected.
    pub gates_modified: usize,
    /// Estimated depth change; positive means the circuit got shallower.
    pub depth_improvement: i32,
    /// Human-readable summary of the pass outcome.
    pub description: String,
}
129
130impl CircuitOptimizer {
131 pub fn new() -> Self {
133 Self {
134 config: OptimizationConfig::default(),
135 statistics: OptimizationStatistics::default(),
136 }
137 }
138
139 pub fn with_config(config: OptimizationConfig) -> Self {
141 Self {
142 config,
143 statistics: OptimizationStatistics::default(),
144 }
145 }
146
147 pub fn optimize<const N: usize>(&mut self, circuit: &Circuit<N>) -> QuantRS2Result<Circuit<N>> {
149 let start_time = std::time::Instant::now();
150
151 self.statistics.original_gate_count = circuit.gates().len();
153 self.statistics.original_depth = self.calculate_circuit_depth(circuit);
154
155 let mut dependency_graph = self.build_dependency_graph(circuit)?;
157
158 let mut optimized_circuit = circuit.clone();
160
161 for pass in 0..self.config.max_passes {
162 let mut pass_improved = false;
163
164 if self.config.enable_redundant_elimination {
166 let result = self.eliminate_redundant_gates(&mut optimized_circuit)?;
167 if result.success {
168 pass_improved = true;
169 self.statistics.redundant_gates_eliminated += result.gates_eliminated;
170 }
171 }
172
173 if self.config.enable_single_qubit_optimization {
175 let result = self.fuse_single_qubit_gates(&mut optimized_circuit)?;
176 if result.success {
177 pass_improved = true;
178 self.statistics.gates_fused += result.gates_modified;
179 }
180 }
181
182 if self.config.enable_commutation_reordering {
184 let result = self.reorder_commuting_gates(&mut optimized_circuit)?;
185 if result.success {
186 pass_improved = true;
187 self.statistics.gates_reordered += result.gates_modified;
188 }
189 }
190
191 if self.config.enable_two_qubit_optimization {
193 let result = self.optimize_two_qubit_gates(&mut optimized_circuit)?;
194 if result.success {
195 pass_improved = true;
196 }
197 }
198
199 if self.config.enable_depth_reduction {
201 let result = self.reduce_circuit_depth(&mut optimized_circuit)?;
202 if result.success {
203 pass_improved = true;
204 }
205 }
206
207 self.statistics.passes_performed = pass + 1;
208
209 if !pass_improved {
211 break;
212 }
213 }
214
215 self.statistics.optimized_gate_count = optimized_circuit.gates().len();
217 self.statistics.optimized_depth = self.calculate_circuit_depth(&optimized_circuit);
218 self.statistics.optimization_time_ns = start_time.elapsed().as_nanos();
219
220 Ok(optimized_circuit)
221 }
222
223 pub const fn get_statistics(&self) -> &OptimizationStatistics {
225 &self.statistics
226 }
227
228 pub fn reset_statistics(&mut self) {
230 self.statistics = OptimizationStatistics::default();
231 }
232
233 fn build_dependency_graph<const N: usize>(
235 &self,
236 circuit: &Circuit<N>,
237 ) -> QuantRS2Result<DependencyGraph> {
238 let mut graph = DependencyGraph {
239 dependencies: HashMap::new(),
240 gate_info: Vec::new(),
241 qubit_usage: HashMap::new(),
242 };
243
244 for (pos, gate) in circuit.gates().iter().enumerate() {
246 let qubits = gate.qubits();
247 let gate_info = GateInfo {
248 position: pos,
249 gate_type: gate.name().to_string(),
250 qubits: qubits.clone(),
251 optimized_away: false,
252 fused_with: Vec::new(),
253 };
254
255 graph.gate_info.push(gate_info);
256
257 for &qubit in &qubits {
259 graph.qubit_usage.entry(qubit).or_default().push(pos);
260 }
261
262 let mut deps = Vec::new();
264 for &qubit in &qubits {
265 if let Some(previous_uses) = graph.qubit_usage.get(&qubit) {
266 for &prev_pos in previous_uses {
267 if prev_pos < pos {
268 deps.push(prev_pos);
269 }
270 }
271 }
272 }
273
274 graph.dependencies.insert(pos, deps);
275 }
276
277 Ok(graph)
278 }
279
280 fn calculate_circuit_depth<const N: usize>(&self, circuit: &Circuit<N>) -> usize {
282 let mut qubit_depths = HashMap::new();
283 let mut max_depth = 0;
284
285 for gate in circuit.gates() {
286 let qubits = gate.qubits();
287
288 let input_depth = qubits
290 .iter()
291 .map(|&q| qubit_depths.get(&q).copied().unwrap_or(0))
292 .max()
293 .unwrap_or(0);
294
295 let new_depth = input_depth + 1;
296
297 for &qubit in &qubits {
299 qubit_depths.insert(qubit, new_depth);
300 }
301
302 max_depth = max_depth.max(new_depth);
303 }
304
305 max_depth
306 }
307
308 fn eliminate_redundant_gates<const N: usize>(
310 &self,
311 circuit: &mut Circuit<N>,
312 ) -> QuantRS2Result<OptimizationResult> {
313 let gates = circuit.gates();
315 let mut redundant_pairs = Vec::new();
316
317 for i in 0..gates.len().saturating_sub(1) {
319 let gate1 = &gates[i];
320 let gate2 = &gates[i + 1];
321
322 if gate1.name() == gate2.name() && gate1.qubits() == gate2.qubits() {
324 match gate1.name() {
326 "H" | "X" | "Y" | "Z" | "CNOT" | "SWAP" => {
327 redundant_pairs.push((i, i + 1));
328 }
329 _ => {}
330 }
331 }
332 }
333
334 let gates_eliminated = redundant_pairs.len() * 2; Ok(OptimizationResult {
339 success: gates_eliminated > 0,
340 gates_eliminated,
341 gates_modified: redundant_pairs.len(),
342 depth_improvement: redundant_pairs.len() as i32, description: format!(
344 "Found {} redundant gate pairs for elimination",
345 redundant_pairs.len()
346 ),
347 })
348 }
349
350 fn fuse_single_qubit_gates<const N: usize>(
352 &self,
353 circuit: &mut Circuit<N>,
354 ) -> QuantRS2Result<OptimizationResult> {
355 let fusion_candidates = self.find_single_qubit_fusion_candidates(circuit)?;
357
358 let mut gates_fused = 0;
360 let candidates_count = fusion_candidates.len();
361 for candidate in &fusion_candidates {
362 if candidate.gates.len() > 1 {
363 gates_fused += candidate.gates.len() - 1; }
365 }
366
367 Ok(OptimizationResult {
368 success: gates_fused > 0,
369 gates_eliminated: gates_fused,
370 gates_modified: candidates_count,
371 depth_improvement: 0,
372 description: format!("Fused {candidates_count} single-qubit gate sequences"),
373 })
374 }
375
376 fn find_single_qubit_fusion_candidates<const N: usize>(
378 &self,
379 circuit: &Circuit<N>,
380 ) -> QuantRS2Result<Vec<SingleQubitFusion>> {
381 let mut candidates = Vec::new();
382 let mut qubit_gate_sequences: HashMap<QubitId, Vec<usize>> = HashMap::new();
383
384 for (pos, gate) in circuit.gates().iter().enumerate() {
386 let qubits = gate.qubits();
387 if qubits.len() == 1 {
388 let qubit = qubits[0];
389 qubit_gate_sequences.entry(qubit).or_default().push(pos);
390 } else {
391 for &qubit in &qubits {
393 if let Some(sequence) = qubit_gate_sequences.get(&qubit) {
394 if sequence.len() > 1 {
395 candidates
396 .push(self.create_fusion_candidate(circuit, sequence, qubit)?);
397 }
398 }
399 qubit_gate_sequences.insert(qubit, Vec::new());
400 }
401 }
402 }
403
404 for (qubit, sequence) in qubit_gate_sequences {
406 if sequence.len() > 1 {
407 candidates.push(self.create_fusion_candidate(circuit, &sequence, qubit)?);
408 }
409 }
410
411 Ok(candidates)
412 }
413
414 fn create_fusion_candidate<const N: usize>(
416 &self,
417 circuit: &Circuit<N>,
418 gate_positions: &[usize],
419 qubit: QubitId,
420 ) -> QuantRS2Result<SingleQubitFusion> {
421 let identity_matrix = [
424 [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
425 [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
426 ];
427
428 Ok(SingleQubitFusion {
429 gates: gate_positions.to_vec(),
430 qubit,
431 fused_matrix: identity_matrix,
432 })
433 }
434
435 fn reorder_commuting_gates<const N: usize>(
437 &self,
438 circuit: &mut Circuit<N>,
439 ) -> QuantRS2Result<OptimizationResult> {
440 let gates = circuit.gates();
442 let mut reordering_opportunities = 0;
443
444 for i in 0..gates.len().saturating_sub(1) {
446 let gate1 = &gates[i];
447 let gate2 = &gates[i + 1];
448
449 let qubits1: std::collections::HashSet<_> = gate1.qubits().into_iter().collect();
451 let qubits2: std::collections::HashSet<_> = gate2.qubits().into_iter().collect();
452
453 if qubits1.is_disjoint(&qubits2) {
454 reordering_opportunities += 1;
455 }
456
457 match (gate1.name(), gate2.name()) {
459 (
461 "H" | "X" | "Y" | "Z" | "S" | "T" | "RX" | "RY" | "RZ",
462 "H" | "X" | "Y" | "Z" | "S" | "T" | "RX" | "RY" | "RZ",
463 ) if qubits1.is_disjoint(&qubits2) => {
464 reordering_opportunities += 1;
465 }
466 ("CNOT", "CNOT") if qubits1.is_disjoint(&qubits2) => {
468 reordering_opportunities += 1;
469 }
470 _ => {}
471 }
472 }
473
474 Ok(OptimizationResult {
475 success: reordering_opportunities > 0,
476 gates_eliminated: 0,
477 gates_modified: reordering_opportunities,
478 depth_improvement: (reordering_opportunities / 2) as i32, description: format!(
480 "Found {reordering_opportunities} gate reordering opportunities for parallelization"
481 ),
482 })
483 }
484
485 fn optimize_two_qubit_gates<const N: usize>(
487 &self,
488 circuit: &mut Circuit<N>,
489 ) -> QuantRS2Result<OptimizationResult> {
490 let gates = circuit.gates();
492 let mut optimization_count = 0;
493
494 for i in 0..gates.len().saturating_sub(2) {
496 if gates[i].name() == "CNOT"
497 && gates[i + 1].name() == "CNOT"
498 && gates[i + 2].name() == "CNOT"
499 {
500 let qubits1 = gates[i].qubits();
501 let qubits2 = gates[i + 1].qubits();
502 let qubits3 = gates[i + 2].qubits();
503
504 if qubits1.len() == 2
506 && qubits2.len() == 2
507 && qubits3.len() == 2
508 && qubits1 == qubits3
509 && qubits1[1] == qubits2[0]
510 {
511 optimization_count += 1;
513 }
514 }
515 }
516
517 for i in 0..gates.len().saturating_sub(2) {
519 if gates[i].name() == "CNOT"
520 && gates[i + 1].name() == "CNOT"
521 && gates[i + 2].name() == "CNOT"
522 {
523 let qubits1 = gates[i].qubits();
524 let qubits2 = gates[i + 1].qubits();
525 let qubits3 = gates[i + 2].qubits();
526
527 if qubits1.len() == 2
529 && qubits2.len() == 2
530 && qubits3.len() == 2
531 && qubits1[0] == qubits3[0]
532 && qubits1[1] == qubits3[1]
533 && qubits1[0] == qubits2[1]
534 && qubits1[1] == qubits2[0]
535 {
536 optimization_count += 1;
537 }
538 }
539 }
540
541 Ok(OptimizationResult {
542 success: optimization_count > 0,
543 gates_eliminated: optimization_count, gates_modified: optimization_count,
545 depth_improvement: optimization_count as i32,
546 description: format!(
547 "Found {optimization_count} two-qubit gate optimization opportunities"
548 ),
549 })
550 }
551
552 fn reduce_circuit_depth<const N: usize>(
554 &self,
555 circuit: &mut Circuit<N>,
556 ) -> QuantRS2Result<OptimizationResult> {
557 let original_depth = self.calculate_circuit_depth(circuit);
559
560 let new_depth = original_depth; Ok(OptimizationResult {
564 success: false,
565 gates_eliminated: 0,
566 gates_modified: 0,
567 depth_improvement: (original_depth as i32) - (new_depth as i32),
568 description: "Circuit depth reduction".to_string(),
569 })
570 }
571}
572
573impl Default for CircuitOptimizer {
574 fn default() -> Self {
575 Self::new()
576 }
577}
578
579impl OptimizationStatistics {
580 pub fn gate_count_reduction(&self) -> f64 {
582 if self.original_gate_count == 0 {
583 0.0
584 } else {
585 (self.original_gate_count as f64 - self.optimized_gate_count as f64)
586 / self.original_gate_count as f64
587 * 100.0
588 }
589 }
590
591 pub fn depth_reduction(&self) -> f64 {
593 if self.original_depth == 0 {
594 0.0
595 } else {
596 (self.original_depth as f64 - self.optimized_depth as f64) / self.original_depth as f64
597 * 100.0
598 }
599 }
600
601 pub fn generate_report(&self) -> String {
603 format!(
604 r"
605📊 Circuit Optimization Report
606==============================
607
608📈 Gate Count Optimization
609 • Original Gates: {}
610 • Optimized Gates: {}
611 • Reduction: {:.1}%
612
613🔍 Circuit Depth Optimization
614 • Original Depth: {}
615 • Optimized Depth: {}
616 • Reduction: {:.1}%
617
618⚡ Optimization Details
619 • Redundant Gates Eliminated: {}
620 • Gates Fused: {}
621 • Gates Reordered: {}
622 • Optimization Passes: {}
623 • Optimization Time: {:.2}ms
624
625✅ Summary
626Circuit optimization {} with {:.1}% gate reduction and {:.1}% depth reduction.
627",
628 self.original_gate_count,
629 self.optimized_gate_count,
630 self.gate_count_reduction(),
631 self.original_depth,
632 self.optimized_depth,
633 self.depth_reduction(),
634 self.redundant_gates_eliminated,
635 self.gates_fused,
636 self.gates_reordered,
637 self.passes_performed,
638 self.optimization_time_ns as f64 / 1_000_000.0,
639 if self.gate_count_reduction() > 0.0 || self.depth_reduction() > 0.0 {
640 "successful"
641 } else {
642 "completed"
643 },
644 self.gate_count_reduction(),
645 self.depth_reduction()
646 )
647 }
648}
649
650pub fn optimize_circuit<const N: usize>(circuit: &Circuit<N>) -> QuantRS2Result<Circuit<N>> {
652 let mut optimizer = CircuitOptimizer::new();
653 optimizer.optimize(circuit)
654}
655
656pub fn optimize_circuit_with_config<const N: usize>(
658 circuit: &Circuit<N>,
659 config: OptimizationConfig,
660) -> QuantRS2Result<(Circuit<N>, OptimizationStatistics)> {
661 let mut optimizer = CircuitOptimizer::with_config(config);
662 let optimized_circuit = optimizer.optimize(circuit)?;
663 Ok((optimized_circuit, optimizer.statistics.clone()))
664}
665
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimizer_creation() {
        let optimizer = CircuitOptimizer::new();
        let cfg = &optimizer.config;
        assert!(cfg.enable_gate_fusion);
        assert!(cfg.enable_redundant_elimination);
    }

    #[test]
    fn test_optimization_config() {
        // Struct-update syntax instead of mutating a default (clippy: field_reassign_with_default).
        let config = OptimizationConfig {
            enable_gate_fusion: false,
            max_passes: 5,
            ..OptimizationConfig::default()
        };

        let optimizer = CircuitOptimizer::with_config(config);
        assert!(!optimizer.config.enable_gate_fusion);
        assert_eq!(optimizer.config.max_passes, 5);
    }

    #[test]
    fn test_statistics_calculations() {
        let stats = OptimizationStatistics {
            original_gate_count: 100,
            optimized_gate_count: 80,
            original_depth: 50,
            optimized_depth: 40,
            ..Default::default()
        };

        assert_eq!(stats.gate_count_reduction(), 20.0);
        assert_eq!(stats.depth_reduction(), 20.0);
    }

    #[test]
    fn test_report_generation() {
        let stats = OptimizationStatistics {
            original_gate_count: 100,
            optimized_gate_count: 80,
            original_depth: 50,
            optimized_depth: 40,
            redundant_gates_eliminated: 10,
            gates_fused: 5,
            gates_reordered: 3,
            passes_performed: 2,
            optimization_time_ns: 1_000_000,
        };

        let report = stats.generate_report();
        for needle in ["20.0%", "100", "80"] {
            assert!(report.contains(needle), "report is missing {needle:?}");
        }
    }
}