1use scirs2_core::ndarray::Array2;
8use scirs2_core::Complex64;
9use std::collections::{HashMap, HashSet};
10
11use crate::error::{Result, SimulatorError};
12use crate::sparse::{CSRMatrix, SparseMatrixBuilder};
13use quantrs2_core::gate::GateOp;
14use quantrs2_core::qubit::QubitId;
15
16#[derive(Debug)]
18struct SciRS2MatrixMultiplier;
19
20impl SciRS2MatrixMultiplier {
21 fn multiply_sparse(a: &CSRMatrix, b: &CSRMatrix) -> Result<CSRMatrix> {
22 if a.num_cols != b.num_rows {
24 return Err(SimulatorError::DimensionMismatch(format!(
25 "Cannot multiply {}x{} with {}x{}",
26 a.num_rows, a.num_cols, b.num_rows, b.num_cols
27 )));
28 }
29
30 let mut builder = SparseMatrixBuilder::new(a.num_rows, b.num_cols);
31
32 for i in 0..a.num_rows {
34 for k in a.row_ptr[i]..a.row_ptr[i + 1] {
35 let a_val = a.values[k];
36 let a_col = a.col_indices[k];
37
38 for j_idx in b.row_ptr[a_col]..b.row_ptr[a_col + 1] {
39 let b_val = b.values[j_idx];
40 let b_col = b.col_indices[j_idx];
41
42 builder.add(i, b_col, a_val * b_val);
43 }
44 }
45 }
46
47 Ok(builder.build())
48 }
49
50 fn multiply_dense(a: &Array2<Complex64>, b: &Array2<Complex64>) -> Result<Array2<Complex64>> {
51 if a.ncols() != b.nrows() {
53 return Err(SimulatorError::DimensionMismatch(format!(
54 "Cannot multiply {}x{} with {}x{}",
55 a.nrows(),
56 a.ncols(),
57 b.nrows(),
58 b.ncols()
59 )));
60 }
61
62 Ok(a.dot(b))
63 }
64}
65
/// Policy that decides which gates may join a fusion group and whether a finished group is fused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionStrategy {
    /// Absorbs any gate sharing at least one qubit with the group, and always fuses.
    Aggressive,
    /// Absorbs a gate only when its qubit set and the group's are nested (one contains the
    /// other), and fuses only when the estimated cost is below the cost threshold.
    Conservative,
    /// Absorbs gates while the combined qubit count stays within the qubit limit, and fuses
    /// groups of more than two gates.
    DepthOptimized,
    /// Absorbs gates like `Aggressive`, but fuses only when the estimated cost is below the
    /// cost threshold.
    Custom,
}
78
79#[derive(Debug, Clone)]
81pub struct GateGroup {
82 pub gate_indices: Vec<usize>,
84 pub qubits: Vec<QubitId>,
86 pub fusable: bool,
88 pub fusion_cost: f64,
90}
91
92pub struct GateFusion {
94 strategy: FusionStrategy,
96 max_fusion_qubits: usize,
98 min_fusion_gates: usize,
100 cost_threshold: f64,
102}
103
104impl GateFusion {
105 pub fn new(strategy: FusionStrategy) -> Self {
107 Self {
108 strategy,
109 max_fusion_qubits: 4,
110 min_fusion_gates: 2,
111 cost_threshold: 0.8,
112 }
113 }
114
115 pub fn with_params(mut self, max_qubits: usize, min_gates: usize, threshold: f64) -> Self {
117 self.max_fusion_qubits = max_qubits;
118 self.min_fusion_gates = min_gates;
119 self.cost_threshold = threshold;
120 self
121 }
122
123 pub fn analyze_circuit(&self, gates: &[Box<dyn GateOp>]) -> Result<Vec<GateGroup>> {
125 let mut groups = Vec::new();
126 let mut processed = vec![false; gates.len()];
127
128 for i in 0..gates.len() {
129 if processed[i] {
130 continue;
131 }
132
133 let mut group = GateGroup {
135 gate_indices: vec![i],
136 qubits: gates[i].qubits().to_vec(),
137 fusable: false,
138 fusion_cost: 0.0,
139 };
140
141 for j in i + 1..gates.len() {
143 if processed[j] {
144 continue;
145 }
146
147 if self.can_fuse_with_group(&group, gates[j].as_ref()) {
149 group.gate_indices.push(j);
150
151 for qubit in gates[j].qubits() {
153 if !group.qubits.contains(&qubit) {
154 group.qubits.push(qubit);
155 }
156 }
157
158 if group.qubits.len() > self.max_fusion_qubits {
160 group.gate_indices.pop();
161 break;
162 }
163 } else if self.blocks_fusion(&group, gates[j].as_ref()) {
164 break;
166 }
167 }
168
169 if group.gate_indices.len() >= self.min_fusion_gates {
171 group.fusion_cost = self.compute_fusion_cost(&group, gates)?;
172 group.fusable = self.should_fuse(&group);
173
174 if group.fusable {
176 for &idx in &group.gate_indices {
177 processed[idx] = true;
178 }
179 }
180 }
181
182 groups.push(group);
183 }
184
185 Ok(groups)
186 }
187
188 fn can_fuse_with_group(&self, group: &GateGroup, gate: &dyn GateOp) -> bool {
190 let gate_qubits: HashSet<_> = gate.qubits().iter().cloned().collect();
192 let group_qubits: HashSet<_> = group.qubits.iter().cloned().collect();
193
194 match self.strategy {
195 FusionStrategy::Aggressive => {
196 !gate_qubits.is_disjoint(&group_qubits)
198 }
199 FusionStrategy::Conservative => {
200 gate_qubits.is_subset(&group_qubits) || group_qubits.is_subset(&gate_qubits)
202 }
203 FusionStrategy::DepthOptimized => {
204 let combined_qubits: HashSet<_> =
206 gate_qubits.union(&group_qubits).cloned().collect();
207 combined_qubits.len() <= self.max_fusion_qubits
208 }
209 FusionStrategy::Custom => {
210 !gate_qubits.is_disjoint(&group_qubits)
212 }
213 }
214 }
215
216 fn blocks_fusion(&self, group: &GateGroup, gate: &dyn GateOp) -> bool {
218 let gate_qubits: HashSet<_> = gate.qubits().iter().cloned().collect();
220 let group_qubits: HashSet<_> = group.qubits.iter().cloned().collect();
221
222 let intersection = gate_qubits.intersection(&group_qubits).count();
223 intersection > 0 && intersection < group_qubits.len()
224 }
225
226 fn compute_fusion_cost(&self, group: &GateGroup, gates: &[Box<dyn GateOp>]) -> Result<f64> {
228 let num_qubits = group.qubits.len();
229 let num_gates = group.gate_indices.len();
230
231 let matrix_size_cost = (1 << num_qubits) as f64;
234
235 let ops_saved = (num_gates - 1) as f64;
237
238 let memory_cost = matrix_size_cost * matrix_size_cost * 16.0; let cost = matrix_size_cost / (ops_saved + 1.0) + memory_cost / 1e9;
243
244 Ok(cost)
245 }
246
247 fn should_fuse(&self, group: &GateGroup) -> bool {
249 match self.strategy {
250 FusionStrategy::Aggressive => true,
251 FusionStrategy::Conservative => group.fusion_cost < self.cost_threshold,
252 FusionStrategy::DepthOptimized => group.gate_indices.len() > 2,
253 FusionStrategy::Custom => group.fusion_cost < self.cost_threshold,
254 }
255 }
256
257 pub fn fuse_group(
259 &self,
260 group: &GateGroup,
261 gates: &[Box<dyn GateOp>],
262 num_qubits: usize,
263 ) -> Result<FusedGate> {
264 let group_qubits = &group.qubits;
265 let group_size = group_qubits.len();
266
267 let dim = 1 << group_size;
269 let mut fused_matrix = Array2::eye(dim);
270
271 for &gate_idx in &group.gate_indices {
273 let gate = &gates[gate_idx];
274 let gate_matrix = self.get_gate_matrix(gate.as_ref())?;
275
276 let gate_qubits = gate.qubits();
278 let qubit_map: HashMap<QubitId, usize> = group_qubits
279 .iter()
280 .enumerate()
281 .map(|(i, &q)| (q, i))
282 .collect();
283
284 let expanded =
286 self.expand_gate_matrix(&gate_matrix, &gate_qubits, &qubit_map, group_size)?;
287
288 fused_matrix = SciRS2MatrixMultiplier::multiply_dense(&expanded, &fused_matrix)?;
290 }
291
292 Ok(FusedGate {
293 matrix: fused_matrix,
294 qubits: group_qubits.clone(),
295 original_gates: group.gate_indices.clone(),
296 })
297 }
298
299 fn get_gate_matrix(&self, gate: &dyn GateOp) -> Result<Array2<Complex64>> {
301 match gate.name() {
304 "Hadamard" => Ok(Array2::from_shape_vec(
305 (2, 2),
306 vec![
307 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
308 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
309 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
310 Complex64::new(-1.0 / 2.0_f64.sqrt(), 0.0),
311 ],
312 )
313 .map_err(|_| SimulatorError::InvalidInput("Shape error".to_string()))?),
314 "PauliX" => Ok(Array2::from_shape_vec(
315 (2, 2),
316 vec![
317 Complex64::new(0.0, 0.0),
318 Complex64::new(1.0, 0.0),
319 Complex64::new(1.0, 0.0),
320 Complex64::new(0.0, 0.0),
321 ],
322 )
323 .map_err(|_| SimulatorError::InvalidInput("Shape error".to_string()))?),
324 "CNOT" => Ok(Array2::from_shape_vec(
325 (4, 4),
326 vec![
327 Complex64::new(1.0, 0.0),
328 Complex64::new(0.0, 0.0),
329 Complex64::new(0.0, 0.0),
330 Complex64::new(0.0, 0.0),
331 Complex64::new(0.0, 0.0),
332 Complex64::new(1.0, 0.0),
333 Complex64::new(0.0, 0.0),
334 Complex64::new(0.0, 0.0),
335 Complex64::new(0.0, 0.0),
336 Complex64::new(0.0, 0.0),
337 Complex64::new(0.0, 0.0),
338 Complex64::new(1.0, 0.0),
339 Complex64::new(0.0, 0.0),
340 Complex64::new(0.0, 0.0),
341 Complex64::new(1.0, 0.0),
342 Complex64::new(0.0, 0.0),
343 ],
344 )
345 .map_err(|_| SimulatorError::InvalidInput("Shape error".to_string()))?),
346 _ => {
347 let n = gate.qubits().len();
349 let dim = 1 << n;
350 Ok(Array2::eye(dim))
351 }
352 }
353 }
354
355 fn expand_gate_matrix(
357 &self,
358 gate_matrix: &Array2<Complex64>,
359 gate_qubits: &[QubitId],
360 qubit_map: &HashMap<QubitId, usize>,
361 total_qubits: usize,
362 ) -> Result<Array2<Complex64>> {
363 let dim = 1 << total_qubits;
364 let mut expanded = Array2::zeros((dim, dim));
365
366 let gate_positions: Vec<usize> = gate_qubits
368 .iter()
369 .map(|q| qubit_map.get(q).copied().unwrap_or(0))
370 .collect();
371
372 for i in 0..dim {
374 for j in 0..dim {
375 let mut gate_i = 0;
377 let mut gate_j = 0;
378 let mut other_bits_match = true;
379
380 for (k, &pos) in gate_positions.iter().enumerate() {
381 if (i >> pos) & 1 == 1 {
382 gate_i |= 1 << k;
383 }
384 if (j >> pos) & 1 == 1 {
385 gate_j |= 1 << k;
386 }
387 }
388
389 for k in 0..total_qubits {
391 if !gate_positions.contains(&k) && ((i >> k) & 1) != ((j >> k) & 1) {
392 other_bits_match = false;
393 break;
394 }
395 }
396
397 if other_bits_match {
398 expanded[[i, j]] = gate_matrix[[gate_i, gate_j]];
399 }
400 }
401 }
402
403 Ok(expanded)
404 }
405
406 pub fn optimize_circuit(
408 &self,
409 gates: Vec<Box<dyn GateOp>>,
410 num_qubits: usize,
411 ) -> Result<OptimizedCircuit> {
412 let groups = self.analyze_circuit(&gates)?;
413 let mut optimized_gates = Vec::new();
414 let mut fusion_map = HashMap::new();
415
416 let mut processed = vec![false; gates.len()];
417
418 for group in &groups {
419 if group.fusable && group.gate_indices.len() > 1 {
420 let fused = self.fuse_group(&group, &gates, num_qubits)?;
422 let fused_idx = optimized_gates.len();
423 optimized_gates.push(OptimizedGate::Fused(fused));
424
425 for &gate_idx in &group.gate_indices {
427 fusion_map.insert(gate_idx, fused_idx);
428 processed[gate_idx] = true;
429 }
430 } else {
431 for &gate_idx in &group.gate_indices {
433 if !processed[gate_idx] {
434 optimized_gates.push(OptimizedGate::Original(gate_idx));
435 processed[gate_idx] = true;
436 }
437 }
438 }
439 }
440
441 for (i, &p) in processed.iter().enumerate() {
443 if !p {
444 optimized_gates.push(OptimizedGate::Original(i));
445 }
446 }
447
448 Ok(OptimizedCircuit {
449 gates: optimized_gates,
450 original_gates: gates,
451 fusion_map,
452 stats: self.compute_stats(&groups),
453 })
454 }
455
456 fn compute_stats(&self, groups: &[GateGroup]) -> FusionStats {
458 let total_groups = groups.len();
459 let fused_groups = groups.iter().filter(|g| g.fusable).count();
460 let total_gates: usize = groups.iter().map(|g| g.gate_indices.len()).sum();
461 let fused_gates: usize = groups
462 .iter()
463 .filter(|g| g.fusable)
464 .map(|g| g.gate_indices.len())
465 .sum();
466
467 FusionStats {
468 total_gates,
469 fused_gates,
470 fusion_ratio: fused_gates as f64 / total_gates.max(1) as f64,
471 groups_analyzed: total_groups,
472 groups_fused: fused_groups,
473 }
474 }
475}
476
477#[derive(Debug)]
479pub struct FusedGate {
480 pub matrix: Array2<Complex64>,
482 pub qubits: Vec<QubitId>,
484 pub original_gates: Vec<usize>,
486}
487
488impl FusedGate {
489 pub fn to_sparse(&self) -> Result<CSRMatrix> {
491 let mut builder = SparseMatrixBuilder::new(self.matrix.nrows(), self.matrix.ncols());
492
493 for ((i, j), &val) in self.matrix.indexed_iter() {
494 if val.norm() > 1e-12 {
495 builder.set_value(i, j, val);
496 }
497 }
498
499 Ok(builder.build())
500 }
501
502 pub fn dimension(&self) -> usize {
504 self.matrix.nrows()
505 }
506}
507
508#[derive(Debug)]
510pub enum OptimizedGate {
511 Original(usize),
513 Fused(FusedGate),
515}
516
517#[derive(Debug)]
519pub struct OptimizedCircuit {
520 pub gates: Vec<OptimizedGate>,
522 pub original_gates: Vec<Box<dyn GateOp>>,
524 pub fusion_map: HashMap<usize, usize>,
526 pub stats: FusionStats,
528}
529
530impl OptimizedCircuit {
531 pub fn gate_count(&self) -> usize {
533 self.gates.len()
534 }
535
536 pub fn memory_usage(&self) -> usize {
538 self.gates
539 .iter()
540 .map(|g| match g {
541 OptimizedGate::Original(_) => 64, OptimizedGate::Fused(f) => f.dimension() * f.dimension() * 16,
543 })
544 .sum()
545 }
546}
547
/// Summary statistics of a fusion pass.
///
/// Derives `Clone`, `PartialEq` and `Default` so callers can copy, compare and
/// zero-initialise statistics, as is conventional for public plain-data types.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct FusionStats {
    /// Gates counted across all analysed groups.
    pub total_gates: usize,
    /// Gates belonging to groups marked fusable.
    pub fused_gates: usize,
    /// `fused_gates / total_gates`, or 0 when there are no gates.
    pub fusion_ratio: f64,
    /// Number of groups produced by the analysis.
    pub groups_analyzed: usize,
    /// Number of groups marked fusable.
    pub groups_fused: usize,
}
562
563pub fn benchmark_fusion_strategies(gates: Vec<Box<dyn GateOp>>, num_qubits: usize) -> Result<()> {
565 println!("\nGate Fusion Benchmark");
566 println!("Original circuit: {} gates", gates.len());
567 println!("{:-<60}", "");
568
569 for strategy in [
570 FusionStrategy::Conservative,
571 FusionStrategy::Aggressive,
572 FusionStrategy::DepthOptimized,
573 ] {
574 let fusion = GateFusion::new(strategy);
575 let start = std::time::Instant::now();
576
577 let optimized = fusion.optimize_circuit(gates.clone(), num_qubits)?;
578 let elapsed = start.elapsed();
579
580 println!("\n{:?} Strategy:", strategy);
581 println!(" Gates after fusion: {}", optimized.gate_count());
582 println!(
583 " Fusion ratio: {:.2}%",
584 optimized.stats.fusion_ratio * 100.0
585 );
586 println!(
587 " Groups fused: {}/{}",
588 optimized.stats.groups_fused, optimized.stats.groups_analyzed
589 );
590 println!(
591 " Memory usage: {:.2} MB",
592 optimized.memory_usage() as f64 / 1e6
593 );
594 println!(" Optimization time: {:?}", elapsed);
595 }
596
597 Ok(())
598}
599
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::{Hadamard, PauliX};

    /// Qubits 0 and 1, shared by several tests.
    fn two_qubits() -> Vec<QubitId> {
        vec![QubitId::new(0), QubitId::new(1)]
    }

    #[test]
    fn test_gate_group_creation() {
        let group = GateGroup {
            gate_indices: (0..3).collect(),
            qubits: two_qubits(),
            fusable: true,
            fusion_cost: 0.5,
        };

        assert_eq!(group.gate_indices.len(), 3);
        assert_eq!(group.qubits.len(), 2);
    }

    #[test]
    fn test_fusion_strategy() {
        let GateFusion {
            max_fusion_qubits,
            min_fusion_gates,
            ..
        } = GateFusion::new(FusionStrategy::Conservative);

        assert_eq!(max_fusion_qubits, 4);
        assert_eq!(min_fusion_gates, 2);
    }

    #[test]
    fn test_sparse_matrix_multiplication() {
        let one = Complex64::new(1.0, 0.0);

        let mut identity = SparseMatrixBuilder::new(2, 2);
        identity.set_value(0, 0, one);
        identity.set_value(1, 1, one);
        let lhs = identity.build();

        let mut swap = SparseMatrixBuilder::new(2, 2);
        swap.set_value(0, 1, one);
        swap.set_value(1, 0, one);
        let rhs = swap.build();

        let product = SciRS2MatrixMultiplier::multiply_sparse(&lhs, &rhs)
            .expect("2x2 times 2x2 is well-defined");
        assert_eq!(product.num_rows, 2);
        assert_eq!(product.num_cols, 2);
    }

    #[test]
    fn test_fused_gate() {
        let fused = FusedGate {
            matrix: Array2::eye(4),
            qubits: two_qubits(),
            original_gates: vec![0, 1],
        };

        assert_eq!(fused.dimension(), 4);
        let csr = fused.to_sparse().expect("conversion of a dense matrix succeeds");
        assert_eq!(csr.num_rows, 4);
    }

    #[test]
    fn test_fusion_cost() {
        let fusion = GateFusion::new(FusionStrategy::Conservative);
        let group = GateGroup {
            gate_indices: vec![0, 1],
            qubits: two_qubits(),
            fusable: false,
            fusion_cost: 0.0,
        };

        let q0 = QubitId::new(0);
        let q1 = QubitId::new(1);
        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target: q0 }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
        ];

        let cost = fusion
            .compute_fusion_cost(&group, &gates)
            .expect("cost model does not fail");
        assert!(cost > 0.0);
    }
}