1use scirs2_core::ndarray::Array2;
8use scirs2_core::Complex64;
9use std::collections::{HashMap, HashSet};
10
11use crate::error::{Result, SimulatorError};
12use crate::sparse::{CSRMatrix, SparseMatrixBuilder};
13use quantrs2_core::gate::GateOp;
14use quantrs2_core::qubit::QubitId;
15
16#[derive(Debug)]
18struct SciRS2MatrixMultiplier;
19
20impl SciRS2MatrixMultiplier {
21 fn multiply_sparse(a: &CSRMatrix, b: &CSRMatrix) -> Result<CSRMatrix> {
22 if a.num_cols != b.num_rows {
24 return Err(SimulatorError::DimensionMismatch(format!(
25 "Cannot multiply {}x{} with {}x{}",
26 a.num_rows, a.num_cols, b.num_rows, b.num_cols
27 )));
28 }
29
30 let mut builder = SparseMatrixBuilder::new(a.num_rows, b.num_cols);
31
32 for i in 0..a.num_rows {
34 for k in a.row_ptr[i]..a.row_ptr[i + 1] {
35 let a_val = a.values[k];
36 let a_col = a.col_indices[k];
37
38 for j_idx in b.row_ptr[a_col]..b.row_ptr[a_col + 1] {
39 let b_val = b.values[j_idx];
40 let b_col = b.col_indices[j_idx];
41
42 builder.add(i, b_col, a_val * b_val);
43 }
44 }
45 }
46
47 Ok(builder.build())
48 }
49
50 #[must_use]
51 fn multiply_dense(a: &Array2<Complex64>, b: &Array2<Complex64>) -> Result<Array2<Complex64>> {
52 if a.ncols() != b.nrows() {
54 return Err(SimulatorError::DimensionMismatch(format!(
55 "Cannot multiply {}x{} with {}x{}",
56 a.nrows(),
57 a.ncols(),
58 b.nrows(),
59 b.ncols()
60 )));
61 }
62
63 Ok(a.dot(b))
64 }
65}
66
/// Policy that decides which later gates may join a candidate group and
/// whether a finished group is worth fusing.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum FusionStrategy {
    /// Accept any gate sharing at least one qubit with the group; fuse every
    /// group that reaches the minimum size.
    Aggressive,
    /// Accept a gate only when its qubits and the group's qubits are nested
    /// (one set contains the other); fuse when the cost is below the threshold.
    Conservative,
    /// Accept a gate while the combined qubit count stays within the limit;
    /// fuse only groups of more than two gates.
    DepthOptimized,
    /// Grows groups like `Aggressive` but fuses only below the cost threshold.
    Custom,
}
79
80#[derive(Debug, Clone)]
82pub struct GateGroup {
83 pub gate_indices: Vec<usize>,
85 pub qubits: Vec<QubitId>,
87 pub fusable: bool,
89 pub fusion_cost: f64,
91}
92
93pub struct GateFusion {
95 strategy: FusionStrategy,
97 max_fusion_qubits: usize,
99 min_fusion_gates: usize,
101 cost_threshold: f64,
103}
104
105impl GateFusion {
106 #[must_use]
108 pub const fn new(strategy: FusionStrategy) -> Self {
109 Self {
110 strategy,
111 max_fusion_qubits: 4,
112 min_fusion_gates: 2,
113 cost_threshold: 0.8,
114 }
115 }
116
117 #[must_use]
119 pub const fn with_params(
120 mut self,
121 max_qubits: usize,
122 min_gates: usize,
123 threshold: f64,
124 ) -> Self {
125 self.max_fusion_qubits = max_qubits;
126 self.min_fusion_gates = min_gates;
127 self.cost_threshold = threshold;
128 self
129 }
130
131 pub fn analyze_circuit(&self, gates: &[Box<dyn GateOp>]) -> Result<Vec<GateGroup>> {
133 let mut groups = Vec::new();
134 let mut processed = vec![false; gates.len()];
135
136 for i in 0..gates.len() {
137 if processed[i] {
138 continue;
139 }
140
141 let mut group = GateGroup {
143 gate_indices: vec![i],
144 qubits: gates[i].qubits().clone(),
145 fusable: false,
146 fusion_cost: 0.0,
147 };
148
149 for j in i + 1..gates.len() {
151 if processed[j] {
152 continue;
153 }
154
155 if self.can_fuse_with_group(&group, gates[j].as_ref()) {
157 group.gate_indices.push(j);
158
159 for qubit in gates[j].qubits() {
161 if !group.qubits.contains(&qubit) {
162 group.qubits.push(qubit);
163 }
164 }
165
166 if group.qubits.len() > self.max_fusion_qubits {
168 group.gate_indices.pop();
169 break;
170 }
171 } else if self.blocks_fusion(&group, gates[j].as_ref()) {
172 break;
174 }
175 }
176
177 if group.gate_indices.len() >= self.min_fusion_gates {
179 group.fusion_cost = self.compute_fusion_cost(&group, gates)?;
180 group.fusable = self.should_fuse(&group);
181
182 if group.fusable {
184 for &idx in &group.gate_indices {
185 processed[idx] = true;
186 }
187 }
188 }
189
190 groups.push(group);
191 }
192
193 Ok(groups)
194 }
195
196 fn can_fuse_with_group(&self, group: &GateGroup, gate: &dyn GateOp) -> bool {
198 let gate_qubits: HashSet<_> = gate.qubits().iter().copied().collect();
200 let group_qubits: HashSet<_> = group.qubits.iter().copied().collect();
201
202 match self.strategy {
203 FusionStrategy::Aggressive => {
204 !gate_qubits.is_disjoint(&group_qubits)
206 }
207 FusionStrategy::Conservative => {
208 gate_qubits.is_subset(&group_qubits) || group_qubits.is_subset(&gate_qubits)
210 }
211 FusionStrategy::DepthOptimized => {
212 let combined_qubits: HashSet<_> =
214 gate_qubits.union(&group_qubits).copied().collect();
215 combined_qubits.len() <= self.max_fusion_qubits
216 }
217 FusionStrategy::Custom => {
218 !gate_qubits.is_disjoint(&group_qubits)
220 }
221 }
222 }
223
224 fn blocks_fusion(&self, group: &GateGroup, gate: &dyn GateOp) -> bool {
226 let gate_qubits: HashSet<_> = gate.qubits().iter().copied().collect();
228 let group_qubits: HashSet<_> = group.qubits.iter().copied().collect();
229
230 let intersection = gate_qubits.intersection(&group_qubits).count();
231 intersection > 0 && intersection < group_qubits.len()
232 }
233
234 fn compute_fusion_cost(&self, group: &GateGroup, gates: &[Box<dyn GateOp>]) -> Result<f64> {
236 let num_qubits = group.qubits.len();
237 let num_gates = group.gate_indices.len();
238
239 let matrix_size_cost = f64::from(1 << num_qubits);
242
243 let ops_saved = (num_gates - 1) as f64;
245
246 let memory_cost = matrix_size_cost * matrix_size_cost * 16.0; let cost = matrix_size_cost / (ops_saved + 1.0) + memory_cost / 1e9;
251
252 Ok(cost)
253 }
254
255 fn should_fuse(&self, group: &GateGroup) -> bool {
257 match self.strategy {
258 FusionStrategy::Aggressive => true,
259 FusionStrategy::Conservative => group.fusion_cost < self.cost_threshold,
260 FusionStrategy::DepthOptimized => group.gate_indices.len() > 2,
261 FusionStrategy::Custom => group.fusion_cost < self.cost_threshold,
262 }
263 }
264
265 pub fn fuse_group(
267 &self,
268 group: &GateGroup,
269 gates: &[Box<dyn GateOp>],
270 num_qubits: usize,
271 ) -> Result<FusedGate> {
272 let group_qubits = &group.qubits;
273 let group_size = group_qubits.len();
274
275 let dim = 1 << group_size;
277 let mut fused_matrix = Array2::eye(dim);
278
279 for &gate_idx in &group.gate_indices {
281 let gate = &gates[gate_idx];
282 let gate_matrix = self.get_gate_matrix(gate.as_ref())?;
283
284 let gate_qubits = gate.qubits();
286 let qubit_map: HashMap<QubitId, usize> = group_qubits
287 .iter()
288 .enumerate()
289 .map(|(i, &q)| (q, i))
290 .collect();
291
292 let expanded =
294 self.expand_gate_matrix(&gate_matrix, &gate_qubits, &qubit_map, group_size)?;
295
296 fused_matrix = SciRS2MatrixMultiplier::multiply_dense(&expanded, &fused_matrix)?;
298 }
299
300 Ok(FusedGate {
301 matrix: fused_matrix,
302 qubits: group_qubits.clone(),
303 original_gates: group.gate_indices.clone(),
304 })
305 }
306
307 fn get_gate_matrix(&self, gate: &dyn GateOp) -> Result<Array2<Complex64>> {
309 match gate.name() {
312 "Hadamard" => Ok(Array2::from_shape_vec(
313 (2, 2),
314 vec![
315 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
316 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
317 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
318 Complex64::new(-1.0 / 2.0_f64.sqrt(), 0.0),
319 ],
320 )
321 .map_err(|_| SimulatorError::InvalidInput("Shape error".to_string()))?),
322 "PauliX" => Ok(Array2::from_shape_vec(
323 (2, 2),
324 vec![
325 Complex64::new(0.0, 0.0),
326 Complex64::new(1.0, 0.0),
327 Complex64::new(1.0, 0.0),
328 Complex64::new(0.0, 0.0),
329 ],
330 )
331 .map_err(|_| SimulatorError::InvalidInput("Shape error".to_string()))?),
332 "CNOT" => Ok(Array2::from_shape_vec(
333 (4, 4),
334 vec![
335 Complex64::new(1.0, 0.0),
336 Complex64::new(0.0, 0.0),
337 Complex64::new(0.0, 0.0),
338 Complex64::new(0.0, 0.0),
339 Complex64::new(0.0, 0.0),
340 Complex64::new(1.0, 0.0),
341 Complex64::new(0.0, 0.0),
342 Complex64::new(0.0, 0.0),
343 Complex64::new(0.0, 0.0),
344 Complex64::new(0.0, 0.0),
345 Complex64::new(0.0, 0.0),
346 Complex64::new(1.0, 0.0),
347 Complex64::new(0.0, 0.0),
348 Complex64::new(0.0, 0.0),
349 Complex64::new(1.0, 0.0),
350 Complex64::new(0.0, 0.0),
351 ],
352 )
353 .map_err(|_| SimulatorError::InvalidInput("Shape error".to_string()))?),
354 _ => {
355 let n = gate.qubits().len();
357 let dim = 1 << n;
358 Ok(Array2::eye(dim))
359 }
360 }
361 }
362
363 fn expand_gate_matrix(
365 &self,
366 gate_matrix: &Array2<Complex64>,
367 gate_qubits: &[QubitId],
368 qubit_map: &HashMap<QubitId, usize>,
369 total_qubits: usize,
370 ) -> Result<Array2<Complex64>> {
371 let dim = 1 << total_qubits;
372 let mut expanded = Array2::zeros((dim, dim));
373
374 let gate_positions: Vec<usize> = gate_qubits
376 .iter()
377 .map(|q| qubit_map.get(q).copied().unwrap_or(0))
378 .collect();
379
380 for i in 0..dim {
382 for j in 0..dim {
383 let mut gate_i = 0;
385 let mut gate_j = 0;
386 let mut other_bits_match = true;
387
388 for (k, &pos) in gate_positions.iter().enumerate() {
389 if (i >> pos) & 1 == 1 {
390 gate_i |= 1 << k;
391 }
392 if (j >> pos) & 1 == 1 {
393 gate_j |= 1 << k;
394 }
395 }
396
397 for k in 0..total_qubits {
399 if !gate_positions.contains(&k) && ((i >> k) & 1) != ((j >> k) & 1) {
400 other_bits_match = false;
401 break;
402 }
403 }
404
405 if other_bits_match {
406 expanded[[i, j]] = gate_matrix[[gate_i, gate_j]];
407 }
408 }
409 }
410
411 Ok(expanded)
412 }
413
414 pub fn optimize_circuit(
416 &self,
417 gates: Vec<Box<dyn GateOp>>,
418 num_qubits: usize,
419 ) -> Result<OptimizedCircuit> {
420 let groups = self.analyze_circuit(&gates)?;
421 let mut optimized_gates = Vec::new();
422 let mut fusion_map = HashMap::new();
423
424 let mut processed = vec![false; gates.len()];
425
426 for group in &groups {
427 if group.fusable && group.gate_indices.len() > 1 {
428 let fused = self.fuse_group(group, &gates, num_qubits)?;
430 let fused_idx = optimized_gates.len();
431 optimized_gates.push(OptimizedGate::Fused(fused));
432
433 for &gate_idx in &group.gate_indices {
435 fusion_map.insert(gate_idx, fused_idx);
436 processed[gate_idx] = true;
437 }
438 } else {
439 for &gate_idx in &group.gate_indices {
441 if !processed[gate_idx] {
442 optimized_gates.push(OptimizedGate::Original(gate_idx));
443 processed[gate_idx] = true;
444 }
445 }
446 }
447 }
448
449 for (i, &p) in processed.iter().enumerate() {
451 if !p {
452 optimized_gates.push(OptimizedGate::Original(i));
453 }
454 }
455
456 Ok(OptimizedCircuit {
457 gates: optimized_gates,
458 original_gates: gates,
459 fusion_map,
460 stats: self.compute_stats(&groups),
461 })
462 }
463
464 fn compute_stats(&self, groups: &[GateGroup]) -> FusionStats {
466 let total_groups = groups.len();
467 let fused_groups = groups.iter().filter(|g| g.fusable).count();
468 let total_gates: usize = groups.iter().map(|g| g.gate_indices.len()).sum();
469 let fused_gates: usize = groups
470 .iter()
471 .filter(|g| g.fusable)
472 .map(|g| g.gate_indices.len())
473 .sum();
474
475 FusionStats {
476 total_gates,
477 fused_gates,
478 fusion_ratio: fused_gates as f64 / total_gates.max(1) as f64,
479 groups_analyzed: total_groups,
480 groups_fused: fused_groups,
481 }
482 }
483}
484
485#[derive(Debug)]
487pub struct FusedGate {
488 pub matrix: Array2<Complex64>,
490 pub qubits: Vec<QubitId>,
492 pub original_gates: Vec<usize>,
494}
495
496impl FusedGate {
497 pub fn to_sparse(&self) -> Result<CSRMatrix> {
499 let mut builder = SparseMatrixBuilder::new(self.matrix.nrows(), self.matrix.ncols());
500
501 for ((i, j), &val) in self.matrix.indexed_iter() {
502 if val.norm() > 1e-12 {
503 builder.set_value(i, j, val);
504 }
505 }
506
507 Ok(builder.build())
508 }
509
510 #[must_use]
512 pub fn dimension(&self) -> usize {
513 self.matrix.nrows()
514 }
515}
516
517#[derive(Debug)]
519pub enum OptimizedGate {
520 Original(usize),
522 Fused(FusedGate),
524}
525
526#[derive(Debug)]
528pub struct OptimizedCircuit {
529 pub gates: Vec<OptimizedGate>,
531 pub original_gates: Vec<Box<dyn GateOp>>,
533 pub fusion_map: HashMap<usize, usize>,
535 pub stats: FusionStats,
537}
538
539impl OptimizedCircuit {
540 #[must_use]
542 pub fn gate_count(&self) -> usize {
543 self.gates.len()
544 }
545
546 #[must_use]
548 pub fn memory_usage(&self) -> usize {
549 self.gates
550 .iter()
551 .map(|g| match g {
552 OptimizedGate::Original(_) => 64, OptimizedGate::Fused(f) => f.dimension() * f.dimension() * 16,
554 })
555 .sum()
556 }
557}
558
/// Summary numbers reported by the fusion analysis.
#[derive(Debug)]
pub struct FusionStats {
    /// Gate count taken over the analyzed groups.
    pub total_gates: usize,
    /// Gates that belong to fused groups.
    pub fused_gates: usize,
    /// `fused_gates / total_gates`, with the denominator clamped to at least 1.
    pub fusion_ratio: f64,
    /// Number of candidate groups examined.
    pub groups_analyzed: usize,
    /// Number of candidate groups accepted for fusion.
    pub groups_fused: usize,
}
573
574pub fn benchmark_fusion_strategies(gates: Vec<Box<dyn GateOp>>, num_qubits: usize) -> Result<()> {
576 println!("\nGate Fusion Benchmark");
577 println!("Original circuit: {} gates", gates.len());
578 println!("{:-<60}", "");
579
580 for strategy in [
581 FusionStrategy::Conservative,
582 FusionStrategy::Aggressive,
583 FusionStrategy::DepthOptimized,
584 ] {
585 let fusion = GateFusion::new(strategy);
586 let start = std::time::Instant::now();
587
588 let optimized = fusion.optimize_circuit(gates.clone(), num_qubits)?;
589 let elapsed = start.elapsed();
590
591 println!("\n{strategy:?} Strategy:");
592 println!(" Gates after fusion: {}", optimized.gate_count());
593 println!(
594 " Fusion ratio: {:.2}%",
595 optimized.stats.fusion_ratio * 100.0
596 );
597 println!(
598 " Groups fused: {}/{}",
599 optimized.stats.groups_fused, optimized.stats.groups_analyzed
600 );
601 println!(
602 " Memory usage: {:.2} MB",
603 optimized.memory_usage() as f64 / 1e6
604 );
605 println!(" Optimization time: {elapsed:?}");
606 }
607
608 Ok(())
609}
610
611#[cfg(test)]
612mod tests {
613 use super::*;
614 use quantrs2_core::gate::multi::CNOT;
615 use quantrs2_core::gate::single::{Hadamard, PauliX};
616
617 #[test]
618 fn test_gate_group_creation() {
619 let group = GateGroup {
620 gate_indices: vec![0, 1, 2],
621 qubits: vec![QubitId::new(0), QubitId::new(1)],
622 fusable: true,
623 fusion_cost: 0.5,
624 };
625
626 assert_eq!(group.gate_indices.len(), 3);
627 assert_eq!(group.qubits.len(), 2);
628 }
629
630 #[test]
631 fn test_fusion_strategy() {
632 let fusion = GateFusion::new(FusionStrategy::Conservative);
633 assert_eq!(fusion.max_fusion_qubits, 4);
634 assert_eq!(fusion.min_fusion_gates, 2);
635 }
636
637 #[test]
638 fn test_sparse_matrix_multiplication() {
639 let mut builder1 = SparseMatrixBuilder::new(2, 2);
640 builder1.set_value(0, 0, Complex64::new(1.0, 0.0));
641 builder1.set_value(1, 1, Complex64::new(1.0, 0.0));
642 let m1 = builder1.build();
643
644 let mut builder2 = SparseMatrixBuilder::new(2, 2);
645 builder2.set_value(0, 1, Complex64::new(1.0, 0.0));
646 builder2.set_value(1, 0, Complex64::new(1.0, 0.0));
647 let m2 = builder2.build();
648
649 let result = SciRS2MatrixMultiplier::multiply_sparse(&m1, &m2)
650 .expect("sparse matrix multiplication should succeed");
651 assert_eq!(result.num_rows, 2);
652 assert_eq!(result.num_cols, 2);
653 }
654
655 #[test]
656 fn test_fused_gate() {
657 let matrix = Array2::eye(4);
658 let fused = FusedGate {
659 matrix,
660 qubits: vec![QubitId::new(0), QubitId::new(1)],
661 original_gates: vec![0, 1],
662 };
663
664 assert_eq!(fused.dimension(), 4);
665 let sparse = fused
666 .to_sparse()
667 .expect("conversion to sparse should succeed");
668 assert_eq!(sparse.num_rows, 4);
669 }
670
671 #[test]
672 fn test_fusion_cost() {
673 let fusion = GateFusion::new(FusionStrategy::Conservative);
674 let group = GateGroup {
675 gate_indices: vec![0, 1],
676 qubits: vec![QubitId::new(0), QubitId::new(1)],
677 fusable: false,
678 fusion_cost: 0.0,
679 };
680
681 let gates: Vec<Box<dyn GateOp>> = vec![
682 Box::new(Hadamard {
683 target: QubitId::new(0),
684 }),
685 Box::new(CNOT {
686 control: QubitId::new(0),
687 target: QubitId::new(1),
688 }),
689 ];
690
691 let cost = fusion
692 .compute_fusion_cost(&group, &gates)
693 .expect("fusion cost computation should succeed");
694 assert!(cost > 0.0);
695 }
696}