1use quantrs2_core::error::{QuantRS2Error, QuantRS2Result};
16use scirs2_core::ndarray::{Array1, Array2};
17use scirs2_core::Complex64;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::{Arc, Mutex, RwLock};
21use std::time::{Duration, Instant};
22
23#[derive(Debug)]
25pub struct GPUKernelOptimizer {
26 kernel_registry: KernelRegistry,
28 stats: Arc<Mutex<KernelStats>>,
30 config: GPUKernelConfig,
32 kernel_cache: Arc<RwLock<HashMap<String, CompiledKernel>>>,
34 memory_optimizer: MemoryLayoutOptimizer,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct GPUKernelConfig {
41 pub enable_warp_optimization: bool,
43 pub enable_shared_memory: bool,
45 pub block_size: usize,
47 pub grid_size_method: GridSizeMethod,
49 pub enable_kernel_fusion: bool,
51 pub max_fusion_length: usize,
53 pub enable_memory_coalescing: bool,
55 pub enable_streaming: bool,
57 pub num_streams: usize,
59 pub target_occupancy: f64,
61}
62
63impl Default for GPUKernelConfig {
64 fn default() -> Self {
65 Self {
66 enable_warp_optimization: true,
67 enable_shared_memory: true,
68 block_size: 256,
69 grid_size_method: GridSizeMethod::Automatic,
70 enable_kernel_fusion: true,
71 max_fusion_length: 8,
72 enable_memory_coalescing: true,
73 enable_streaming: true,
74 num_streams: 4,
75 target_occupancy: 0.75,
76 }
77 }
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
82pub enum GridSizeMethod {
83 Automatic,
85 Fixed(usize),
87 OccupancyBased,
89}
90
91#[derive(Debug)]
93pub struct KernelRegistry {
94 single_qubit_kernels: HashMap<String, SingleQubitKernel>,
96 two_qubit_kernels: HashMap<String, TwoQubitKernel>,
98 fused_kernels: HashMap<String, FusedKernel>,
100 custom_kernels: HashMap<String, CustomKernel>,
102}
103
104impl Default for KernelRegistry {
105 fn default() -> Self {
106 let mut registry = Self {
107 single_qubit_kernels: HashMap::new(),
108 two_qubit_kernels: HashMap::new(),
109 fused_kernels: HashMap::new(),
110 custom_kernels: HashMap::new(),
111 };
112 registry.register_builtin_kernels();
113 registry
114 }
115}
116
117impl KernelRegistry {
118 fn register_builtin_kernels(&mut self) {
120 self.single_qubit_kernels.insert(
122 "hadamard".to_string(),
123 SingleQubitKernel {
124 name: "hadamard".to_string(),
125 kernel_type: SingleQubitKernelType::Hadamard,
126 optimization_level: OptimizationLevel::Maximum,
127 uses_shared_memory: true,
128 register_usage: 32,
129 },
130 );
131
132 self.single_qubit_kernels.insert(
133 "pauli_x".to_string(),
134 SingleQubitKernel {
135 name: "pauli_x".to_string(),
136 kernel_type: SingleQubitKernelType::PauliX,
137 optimization_level: OptimizationLevel::Maximum,
138 uses_shared_memory: false, register_usage: 16,
140 },
141 );
142
143 self.single_qubit_kernels.insert(
144 "pauli_y".to_string(),
145 SingleQubitKernel {
146 name: "pauli_y".to_string(),
147 kernel_type: SingleQubitKernelType::PauliY,
148 optimization_level: OptimizationLevel::Maximum,
149 uses_shared_memory: false,
150 register_usage: 24,
151 },
152 );
153
154 self.single_qubit_kernels.insert(
155 "pauli_z".to_string(),
156 SingleQubitKernel {
157 name: "pauli_z".to_string(),
158 kernel_type: SingleQubitKernelType::PauliZ,
159 optimization_level: OptimizationLevel::Maximum,
160 uses_shared_memory: false,
161 register_usage: 16,
162 },
163 );
164
165 self.single_qubit_kernels.insert(
166 "phase".to_string(),
167 SingleQubitKernel {
168 name: "phase".to_string(),
169 kernel_type: SingleQubitKernelType::Phase,
170 optimization_level: OptimizationLevel::High,
171 uses_shared_memory: false,
172 register_usage: 24,
173 },
174 );
175
176 self.single_qubit_kernels.insert(
177 "t_gate".to_string(),
178 SingleQubitKernel {
179 name: "t_gate".to_string(),
180 kernel_type: SingleQubitKernelType::TGate,
181 optimization_level: OptimizationLevel::High,
182 uses_shared_memory: false,
183 register_usage: 24,
184 },
185 );
186
187 self.single_qubit_kernels.insert(
188 "rotation_x".to_string(),
189 SingleQubitKernel {
190 name: "rotation_x".to_string(),
191 kernel_type: SingleQubitKernelType::RotationX,
192 optimization_level: OptimizationLevel::Medium,
193 uses_shared_memory: true,
194 register_usage: 40,
195 },
196 );
197
198 self.single_qubit_kernels.insert(
199 "rotation_y".to_string(),
200 SingleQubitKernel {
201 name: "rotation_y".to_string(),
202 kernel_type: SingleQubitKernelType::RotationY,
203 optimization_level: OptimizationLevel::Medium,
204 uses_shared_memory: true,
205 register_usage: 40,
206 },
207 );
208
209 self.single_qubit_kernels.insert(
210 "rotation_z".to_string(),
211 SingleQubitKernel {
212 name: "rotation_z".to_string(),
213 kernel_type: SingleQubitKernelType::RotationZ,
214 optimization_level: OptimizationLevel::Medium,
215 uses_shared_memory: true,
216 register_usage: 32,
217 },
218 );
219
220 self.two_qubit_kernels.insert(
222 "cnot".to_string(),
223 TwoQubitKernel {
224 name: "cnot".to_string(),
225 kernel_type: TwoQubitKernelType::CNOT,
226 optimization_level: OptimizationLevel::Maximum,
227 uses_shared_memory: true,
228 register_usage: 48,
229 memory_access_pattern: MemoryAccessPattern::Strided,
230 },
231 );
232
233 self.two_qubit_kernels.insert(
234 "cz".to_string(),
235 TwoQubitKernel {
236 name: "cz".to_string(),
237 kernel_type: TwoQubitKernelType::CZ,
238 optimization_level: OptimizationLevel::Maximum,
239 uses_shared_memory: false,
240 register_usage: 32,
241 memory_access_pattern: MemoryAccessPattern::Sparse,
242 },
243 );
244
245 self.two_qubit_kernels.insert(
246 "swap".to_string(),
247 TwoQubitKernel {
248 name: "swap".to_string(),
249 kernel_type: TwoQubitKernelType::SWAP,
250 optimization_level: OptimizationLevel::High,
251 uses_shared_memory: true,
252 register_usage: 40,
253 memory_access_pattern: MemoryAccessPattern::Strided,
254 },
255 );
256
257 self.two_qubit_kernels.insert(
258 "iswap".to_string(),
259 TwoQubitKernel {
260 name: "iswap".to_string(),
261 kernel_type: TwoQubitKernelType::ISWAP,
262 optimization_level: OptimizationLevel::High,
263 uses_shared_memory: true,
264 register_usage: 48,
265 memory_access_pattern: MemoryAccessPattern::Strided,
266 },
267 );
268
269 self.two_qubit_kernels.insert(
270 "controlled_rotation".to_string(),
271 TwoQubitKernel {
272 name: "controlled_rotation".to_string(),
273 kernel_type: TwoQubitKernelType::ControlledRotation,
274 optimization_level: OptimizationLevel::Medium,
275 uses_shared_memory: true,
276 register_usage: 56,
277 memory_access_pattern: MemoryAccessPattern::Strided,
278 },
279 );
280
281 self.fused_kernels.insert(
283 "h_cnot_h".to_string(),
284 FusedKernel {
285 name: "h_cnot_h".to_string(),
286 sequence: vec![
287 "hadamard".to_string(),
288 "cnot".to_string(),
289 "hadamard".to_string(),
290 ],
291 optimization_gain: 2.5,
292 register_usage: 64,
293 },
294 );
295
296 self.fused_kernels.insert(
297 "rotation_chain".to_string(),
298 FusedKernel {
299 name: "rotation_chain".to_string(),
300 sequence: vec![
301 "rotation_x".to_string(),
302 "rotation_y".to_string(),
303 "rotation_z".to_string(),
304 ],
305 optimization_gain: 2.0,
306 register_usage: 56,
307 },
308 );
309
310 self.fused_kernels.insert(
311 "bell_state".to_string(),
312 FusedKernel {
313 name: "bell_state".to_string(),
314 sequence: vec!["hadamard".to_string(), "cnot".to_string()],
315 optimization_gain: 1.8,
316 register_usage: 48,
317 },
318 );
319 }
320}
321
322#[derive(Debug, Clone)]
324pub struct SingleQubitKernel {
325 pub name: String,
327 pub kernel_type: SingleQubitKernelType,
329 pub optimization_level: OptimizationLevel,
331 pub uses_shared_memory: bool,
333 pub register_usage: usize,
335}
336
/// Identifies which single-qubit gate implementation a kernel stands for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleQubitKernelType {
    /// Hadamard gate.
    Hadamard,
    /// Pauli-X gate.
    PauliX,
    /// Pauli-Y gate.
    PauliY,
    /// Pauli-Z gate.
    PauliZ,
    /// Phase gate.
    Phase,
    /// T gate.
    TGate,
    /// Rotation about the X axis.
    RotationX,
    /// Rotation about the Y axis.
    RotationY,
    /// Rotation about the Z axis.
    RotationZ,
    /// Any gate without a specialised kernel.
    Generic,
}
351
352#[derive(Debug, Clone)]
354pub struct TwoQubitKernel {
355 pub name: String,
357 pub kernel_type: TwoQubitKernelType,
359 pub optimization_level: OptimizationLevel,
361 pub uses_shared_memory: bool,
363 pub register_usage: usize,
365 pub memory_access_pattern: MemoryAccessPattern,
367}
368
/// Identifies which two-qubit gate implementation a kernel stands for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TwoQubitKernelType {
    /// Controlled-NOT gate.
    CNOT,
    /// Controlled-Z gate.
    CZ,
    /// SWAP gate.
    SWAP,
    /// iSWAP gate.
    ISWAP,
    /// Controlled rotation gate.
    ControlledRotation,
    /// Any gate without a specialised kernel.
    Generic,
}
379
/// Label describing how a kernel is expected to touch memory.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryAccessPattern {
    /// Coalesced accesses.
    Coalesced,
    /// Strided accesses.
    Strided,
    /// Sparse accesses.
    Sparse,
    /// Random accesses.
    Random,
}
392
/// Descriptor of a kernel that stands for a fixed sequence of gates.
#[derive(Debug, Clone)]
pub struct FusedKernel {
    /// Name of the fused kernel.
    pub name: String,
    /// Gate names making up the fused sequence, in order.
    pub sequence: Vec<String>,
    /// Gain figure recorded for the fusion (presumably a speed-up factor — TODO confirm).
    pub optimization_gain: f64,
    /// Register usage recorded for the kernel.
    pub register_usage: usize,
}
405
/// Descriptor of a user-supplied kernel.
#[derive(Debug, Clone)]
pub struct CustomKernel {
    /// Name of the kernel.
    pub name: String,
    /// Kernel code, stored as text.
    pub code: String,
    /// Register usage recorded for the kernel.
    pub register_usage: usize,
}
416
417#[derive(Debug, Clone)]
419pub struct CompiledKernel {
420 pub name: String,
422 pub compiled_code: Vec<u8>,
424 pub exec_params: KernelExecParams,
426}
427
/// Launch parameters attached to a compiled kernel.
#[derive(Debug, Clone)]
pub struct KernelExecParams {
    /// Block dimensions as a three-component tuple.
    pub block_dim: (usize, usize, usize),
    /// Grid dimensions as a three-component tuple.
    pub grid_dim: (usize, usize, usize),
    /// Shared memory size (presumably in bytes — TODO confirm).
    pub shared_memory_size: usize,
    /// Maximum number of threads per block.
    pub max_threads_per_block: usize,
}
440
/// Optimization tier recorded on a kernel descriptor.
///
/// The variants are declared from the least to the most aggressive tier
/// (judging by their names), and the derived ordering follows that
/// declaration order, so `Basic < Medium < High < Maximum`. `Hash` is derived
/// as well so a level can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OptimizationLevel {
    /// Lowest tier.
    Basic,
    /// Middle tier.
    Medium,
    /// High tier.
    High,
    /// Highest tier.
    Maximum,
}
453
/// Counters and timings collected while kernels run.
///
/// `Default` yields all-zero counters, zero durations and empty maps.
#[derive(Debug, Clone, Default)]
pub struct KernelStats {
    /// Number of gate applications recorded.
    pub total_executions: u64,
    /// Accumulated wall-clock time over all recorded applications.
    pub total_execution_time: Duration,
    /// Number of applications per gate name.
    pub execution_counts: HashMap<String, u64>,
    /// Accumulated wall-clock time per gate name.
    pub execution_times: HashMap<String, Duration>,
    /// Kernel cache hit counter.
    pub cache_hits: u64,
    /// Kernel cache miss counter.
    pub cache_misses: u64,
    /// Counter of fused operations.
    pub fused_operations: u64,
    /// Memory bandwidth figure (units not established here — TODO confirm).
    pub memory_bandwidth: f64,
    /// Compute throughput figure (units not established here — TODO confirm).
    pub compute_throughput: f64,
}
476
477#[derive(Debug)]
479pub struct MemoryLayoutOptimizer {
480 strategy: MemoryLayoutStrategy,
482 prefetch_distance: usize,
484}
485
/// Layout strategy for state-vector data.
///
/// Derives `PartialEq`/`Eq` like the other field-less enums in this module so
/// strategies can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryLayoutStrategy {
    /// Interleaved layout.
    Interleaved,
    /// Split-arrays layout.
    SplitArrays,
    /// Structure-of-arrays layout.
    StructureOfArrays,
    /// Array-of-structures layout.
    ArrayOfStructures,
}
498
499impl Default for MemoryLayoutOptimizer {
500 fn default() -> Self {
501 Self {
502 strategy: MemoryLayoutStrategy::Interleaved,
503 prefetch_distance: 4,
504 }
505 }
506}
507
508impl GPUKernelOptimizer {
509 pub fn new(config: GPUKernelConfig) -> Self {
511 Self {
512 kernel_registry: KernelRegistry::default(),
513 stats: Arc::new(Mutex::new(KernelStats::default())),
514 config,
515 kernel_cache: Arc::new(RwLock::new(HashMap::new())),
516 memory_optimizer: MemoryLayoutOptimizer::default(),
517 }
518 }
519
520 pub fn apply_single_qubit_gate(
522 &mut self,
523 state: &mut Array1<Complex64>,
524 qubit: usize,
525 gate_name: &str,
526 parameters: Option<&[f64]>,
527 ) -> QuantRS2Result<()> {
528 let start = Instant::now();
529
530 let kernel = self.kernel_registry.single_qubit_kernels.get(gate_name);
532
533 let n = state.len();
534 let stride = 1 << qubit;
535
536 match kernel {
537 Some(k) => {
538 match k.kernel_type {
540 SingleQubitKernelType::Hadamard => {
541 self.apply_hadamard_optimized(state, stride)?;
542 }
543 SingleQubitKernelType::PauliX => {
544 self.apply_pauli_x_optimized(state, stride)?;
545 }
546 SingleQubitKernelType::PauliY => {
547 self.apply_pauli_y_optimized(state, stride)?;
548 }
549 SingleQubitKernelType::PauliZ => {
550 self.apply_pauli_z_optimized(state, stride)?;
551 }
552 SingleQubitKernelType::Phase => {
553 self.apply_phase_optimized(state, stride)?;
554 }
555 SingleQubitKernelType::TGate => {
556 self.apply_t_gate_optimized(state, stride)?;
557 }
558 SingleQubitKernelType::RotationX => {
559 let angle = parameters.and_then(|p| p.first()).copied().unwrap_or(0.0);
560 self.apply_rotation_x_optimized(state, stride, angle)?;
561 }
562 SingleQubitKernelType::RotationY => {
563 let angle = parameters.and_then(|p| p.first()).copied().unwrap_or(0.0);
564 self.apply_rotation_y_optimized(state, stride, angle)?;
565 }
566 SingleQubitKernelType::RotationZ => {
567 let angle = parameters.and_then(|p| p.first()).copied().unwrap_or(0.0);
568 self.apply_rotation_z_optimized(state, stride, angle)?;
569 }
570 SingleQubitKernelType::Generic => {
571 self.apply_generic_single_qubit(state, qubit, gate_name)?;
573 }
574 }
575 }
576 None => {
577 self.apply_generic_single_qubit(state, qubit, gate_name)?;
579 }
580 }
581
582 let mut stats = self
584 .stats
585 .lock()
586 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
587 stats.total_executions += 1;
588 stats.total_execution_time += start.elapsed();
589 *stats
590 .execution_counts
591 .entry(gate_name.to_string())
592 .or_insert(0) += 1;
593 *stats
594 .execution_times
595 .entry(gate_name.to_string())
596 .or_insert(Duration::ZERO) += start.elapsed();
597
598 Ok(())
599 }
600
601 fn apply_hadamard_optimized(
603 &self,
604 state: &mut Array1<Complex64>,
605 stride: usize,
606 ) -> QuantRS2Result<()> {
607 let n = state.len();
608 let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
609
610 let amplitudes = state.as_slice_mut().ok_or_else(|| {
611 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
612 })?;
613
614 for i in 0..n / 2 {
616 let i0 = (i / stride) * (2 * stride) + (i % stride);
617 let i1 = i0 + stride;
618
619 let a0 = amplitudes[i0];
620 let a1 = amplitudes[i1];
621
622 amplitudes[i0] =
623 Complex64::new((a0.re + a1.re) * inv_sqrt2, (a0.im + a1.im) * inv_sqrt2);
624 amplitudes[i1] =
625 Complex64::new((a0.re - a1.re) * inv_sqrt2, (a0.im - a1.im) * inv_sqrt2);
626 }
627
628 Ok(())
629 }
630
631 fn apply_pauli_x_optimized(
633 &self,
634 state: &mut Array1<Complex64>,
635 stride: usize,
636 ) -> QuantRS2Result<()> {
637 let n = state.len();
638
639 let amplitudes = state.as_slice_mut().ok_or_else(|| {
640 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
641 })?;
642
643 for i in 0..n / 2 {
645 let i0 = (i / stride) * (2 * stride) + (i % stride);
646 let i1 = i0 + stride;
647
648 amplitudes.swap(i0, i1);
649 }
650
651 Ok(())
652 }
653
654 fn apply_pauli_y_optimized(
656 &self,
657 state: &mut Array1<Complex64>,
658 stride: usize,
659 ) -> QuantRS2Result<()> {
660 let n = state.len();
661
662 let amplitudes = state.as_slice_mut().ok_or_else(|| {
663 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
664 })?;
665
666 for i in 0..n / 2 {
667 let i0 = (i / stride) * (2 * stride) + (i % stride);
668 let i1 = i0 + stride;
669
670 let a0 = amplitudes[i0];
671 let a1 = amplitudes[i1];
672
673 amplitudes[i0] = Complex64::new(a1.im, -a1.re);
675 amplitudes[i1] = Complex64::new(-a0.im, a0.re);
676 }
677
678 Ok(())
679 }
680
681 fn apply_pauli_z_optimized(
683 &self,
684 state: &mut Array1<Complex64>,
685 stride: usize,
686 ) -> QuantRS2Result<()> {
687 let n = state.len();
688
689 let amplitudes = state.as_slice_mut().ok_or_else(|| {
690 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
691 })?;
692
693 for i in 0..n / 2 {
695 let i1 = (i / stride) * (2 * stride) + (i % stride) + stride;
696 amplitudes[i1] = -amplitudes[i1];
697 }
698
699 Ok(())
700 }
701
702 fn apply_phase_optimized(
704 &self,
705 state: &mut Array1<Complex64>,
706 stride: usize,
707 ) -> QuantRS2Result<()> {
708 let n = state.len();
709
710 let amplitudes = state.as_slice_mut().ok_or_else(|| {
711 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
712 })?;
713
714 for i in 0..n / 2 {
716 let i1 = (i / stride) * (2 * stride) + (i % stride) + stride;
717 let a = amplitudes[i1];
718 amplitudes[i1] = Complex64::new(-a.im, a.re); }
720
721 Ok(())
722 }
723
724 fn apply_t_gate_optimized(
726 &self,
727 state: &mut Array1<Complex64>,
728 stride: usize,
729 ) -> QuantRS2Result<()> {
730 let n = state.len();
731 let t_phase = Complex64::new(
732 std::f64::consts::FRAC_1_SQRT_2,
733 std::f64::consts::FRAC_1_SQRT_2,
734 );
735
736 let amplitudes = state.as_slice_mut().ok_or_else(|| {
737 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
738 })?;
739
740 for i in 0..n / 2 {
742 let i1 = (i / stride) * (2 * stride) + (i % stride) + stride;
743 amplitudes[i1] *= t_phase;
744 }
745
746 Ok(())
747 }
748
749 fn apply_rotation_x_optimized(
751 &self,
752 state: &mut Array1<Complex64>,
753 stride: usize,
754 angle: f64,
755 ) -> QuantRS2Result<()> {
756 let n = state.len();
757 let cos_half = (angle / 2.0).cos();
758 let sin_half = (angle / 2.0).sin();
759
760 let amplitudes = state.as_slice_mut().ok_or_else(|| {
761 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
762 })?;
763
764 for i in 0..n / 2 {
765 let i0 = (i / stride) * (2 * stride) + (i % stride);
766 let i1 = i0 + stride;
767
768 let a0 = amplitudes[i0];
769 let a1 = amplitudes[i1];
770
771 amplitudes[i0] = Complex64::new(
773 cos_half * a0.re + sin_half * a1.im,
774 cos_half * a0.im - sin_half * a1.re,
775 );
776 amplitudes[i1] = Complex64::new(
777 sin_half * a0.im + cos_half * a1.re,
778 (-sin_half).mul_add(a0.re, cos_half * a1.im),
779 );
780 }
781
782 Ok(())
783 }
784
785 fn apply_rotation_y_optimized(
787 &self,
788 state: &mut Array1<Complex64>,
789 stride: usize,
790 angle: f64,
791 ) -> QuantRS2Result<()> {
792 let n = state.len();
793 let cos_half = (angle / 2.0).cos();
794 let sin_half = (angle / 2.0).sin();
795
796 let amplitudes = state.as_slice_mut().ok_or_else(|| {
797 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
798 })?;
799
800 for i in 0..n / 2 {
801 let i0 = (i / stride) * (2 * stride) + (i % stride);
802 let i1 = i0 + stride;
803
804 let a0 = amplitudes[i0];
805 let a1 = amplitudes[i1];
806
807 amplitudes[i0] = Complex64::new(
809 cos_half * a0.re - sin_half * a1.re,
810 cos_half * a0.im - sin_half * a1.im,
811 );
812 amplitudes[i1] = Complex64::new(
813 sin_half * a0.re + cos_half * a1.re,
814 sin_half * a0.im + cos_half * a1.im,
815 );
816 }
817
818 Ok(())
819 }
820
821 fn apply_rotation_z_optimized(
823 &self,
824 state: &mut Array1<Complex64>,
825 stride: usize,
826 angle: f64,
827 ) -> QuantRS2Result<()> {
828 let n = state.len();
829 let exp_neg = Complex64::new((angle / 2.0).cos(), -(angle / 2.0).sin());
830 let exp_pos = Complex64::new((angle / 2.0).cos(), (angle / 2.0).sin());
831
832 let amplitudes = state.as_slice_mut().ok_or_else(|| {
833 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
834 })?;
835
836 for i in 0..n / 2 {
837 let i0 = (i / stride) * (2 * stride) + (i % stride);
838 let i1 = i0 + stride;
839
840 amplitudes[i0] *= exp_neg;
842 amplitudes[i1] *= exp_pos;
843 }
844
845 Ok(())
846 }
847
848 const fn apply_generic_single_qubit(
850 &self,
851 state: &mut Array1<Complex64>,
852 qubit: usize,
853 _gate_name: &str,
854 ) -> QuantRS2Result<()> {
855 Ok(())
858 }
859
860 pub fn apply_two_qubit_gate(
862 &mut self,
863 state: &mut Array1<Complex64>,
864 control: usize,
865 target: usize,
866 gate_name: &str,
867 ) -> QuantRS2Result<()> {
868 let start = Instant::now();
869
870 let kernel = self.kernel_registry.two_qubit_kernels.get(gate_name);
872
873 match kernel {
874 Some(k) => match k.kernel_type {
875 TwoQubitKernelType::CNOT => {
876 self.apply_cnot_optimized(state, control, target)?;
877 }
878 TwoQubitKernelType::CZ => {
879 self.apply_cz_optimized(state, control, target)?;
880 }
881 TwoQubitKernelType::SWAP => {
882 self.apply_swap_optimized(state, control, target)?;
883 }
884 TwoQubitKernelType::ISWAP => {
885 self.apply_iswap_optimized(state, control, target)?;
886 }
887 _ => {
888 self.apply_generic_two_qubit(state, control, target, gate_name)?;
889 }
890 },
891 None => {
892 self.apply_generic_two_qubit(state, control, target, gate_name)?;
893 }
894 }
895
896 let mut stats = self
898 .stats
899 .lock()
900 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
901 stats.total_executions += 1;
902 stats.total_execution_time += start.elapsed();
903 *stats
904 .execution_counts
905 .entry(gate_name.to_string())
906 .or_insert(0) += 1;
907
908 Ok(())
909 }
910
911 fn apply_cnot_optimized(
913 &self,
914 state: &mut Array1<Complex64>,
915 control: usize,
916 target: usize,
917 ) -> QuantRS2Result<()> {
918 let n = state.len();
919 let control_stride = 1 << control;
920 let target_stride = 1 << target;
921
922 let amplitudes = state.as_slice_mut().ok_or_else(|| {
923 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
924 })?;
925
926 for i in 0..n {
928 if (i & control_stride) != 0 {
929 let partner = i ^ target_stride;
931 if partner > i {
932 amplitudes.swap(i, partner);
933 }
934 }
935 }
936
937 Ok(())
938 }
939
940 fn apply_cz_optimized(
942 &self,
943 state: &mut Array1<Complex64>,
944 control: usize,
945 target: usize,
946 ) -> QuantRS2Result<()> {
947 let n = state.len();
948 let control_stride = 1 << control;
949 let target_stride = 1 << target;
950
951 let amplitudes = state.as_slice_mut().ok_or_else(|| {
952 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
953 })?;
954
955 for i in 0..n {
957 if (i & control_stride) != 0 && (i & target_stride) != 0 {
958 amplitudes[i] = -amplitudes[i];
959 }
960 }
961
962 Ok(())
963 }
964
965 fn apply_swap_optimized(
967 &self,
968 state: &mut Array1<Complex64>,
969 qubit1: usize,
970 qubit2: usize,
971 ) -> QuantRS2Result<()> {
972 let n = state.len();
973 let stride1 = 1 << qubit1;
974 let stride2 = 1 << qubit2;
975
976 let amplitudes = state.as_slice_mut().ok_or_else(|| {
977 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
978 })?;
979
980 for i in 0..n {
982 let bit1 = (i & stride1) != 0;
983 let bit2 = (i & stride2) != 0;
984 if bit1 != bit2 {
985 let partner = i ^ stride1 ^ stride2;
986 if partner > i {
987 amplitudes.swap(i, partner);
988 }
989 }
990 }
991
992 Ok(())
993 }
994
995 fn apply_iswap_optimized(
997 &self,
998 state: &mut Array1<Complex64>,
999 qubit1: usize,
1000 qubit2: usize,
1001 ) -> QuantRS2Result<()> {
1002 let n = state.len();
1003 let stride1 = 1 << qubit1;
1004 let stride2 = 1 << qubit2;
1005
1006 let amplitudes = state.as_slice_mut().ok_or_else(|| {
1007 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
1008 })?;
1009
1010 for i in 0..n {
1012 let bit1 = (i & stride1) != 0;
1013 let bit2 = (i & stride2) != 0;
1014 if bit1 != bit2 {
1015 let partner = i ^ stride1 ^ stride2;
1016 if partner > i {
1017 let a = amplitudes[i];
1018 let b = amplitudes[partner];
1019 amplitudes[i] = Complex64::new(-b.im, b.re);
1021 amplitudes[partner] = Complex64::new(-a.im, a.re);
1022 }
1023 }
1024 }
1025
1026 Ok(())
1027 }
1028
1029 const fn apply_generic_two_qubit(
1031 &self,
1032 _state: &mut Array1<Complex64>,
1033 _control: usize,
1034 _target: usize,
1035 _gate_name: &str,
1036 ) -> QuantRS2Result<()> {
1037 Ok(())
1039 }
1040
1041 pub fn get_stats(&self) -> QuantRS2Result<KernelStats> {
1043 let stats = self
1044 .stats
1045 .lock()
1046 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
1047 Ok(stats.clone())
1048 }
1049
1050 pub fn reset_stats(&mut self) -> QuantRS2Result<()> {
1052 let mut stats = self
1053 .stats
1054 .lock()
1055 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
1056 *stats = KernelStats::default();
1057 Ok(())
1058 }
1059
1060 pub fn get_available_kernels(&self) -> Vec<String> {
1062 let mut kernels = Vec::new();
1063 kernels.extend(self.kernel_registry.single_qubit_kernels.keys().cloned());
1064 kernels.extend(self.kernel_registry.two_qubit_kernels.keys().cloned());
1065 kernels.extend(self.kernel_registry.fused_kernels.keys().cloned());
1066 kernels
1067 }
1068
1069 pub fn has_kernel(&self, name: &str) -> bool {
1071 self.kernel_registry.single_qubit_kernels.contains_key(name)
1072 || self.kernel_registry.two_qubit_kernels.contains_key(name)
1073 || self.kernel_registry.fused_kernels.contains_key(name)
1074 }
1075}
1076
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing amplitudes.
    const TOLERANCE: f64 = 1e-10;

    /// Builds an optimizer with the default configuration.
    fn default_optimizer() -> GPUKernelOptimizer {
        GPUKernelOptimizer::new(GPUKernelConfig::default())
    }

    /// Builds a state vector of `dimension` amplitudes with amplitude 1 at `index`.
    fn basis_state(dimension: usize, index: usize) -> Array1<Complex64> {
        let mut state = Array1::zeros(dimension);
        state[index] = Complex64::new(1.0, 0.0);
        state
    }

    /// Asserts that `actual` equals `re + im * i` within `TOLERANCE`.
    fn assert_amplitude(actual: Complex64, re: f64, im: f64) {
        assert!(
            (actual.re - re).abs() < TOLERANCE && (actual.im - im).abs() < TOLERANCE,
            "expected ({re}, {im}), got {actual:?}"
        );
    }

    #[test]
    fn test_kernel_optimizer_creation() {
        let optimizer = default_optimizer();
        assert!(!optimizer.get_available_kernels().is_empty());
    }

    #[test]
    fn test_hadamard_kernel() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(2, 0);

        let result = optimizer.apply_single_qubit_gate(&mut state, 0, "hadamard", None);
        assert!(result.is_ok());

        let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
        assert_amplitude(state[0], inv_sqrt2, 0.0);
        assert_amplitude(state[1], inv_sqrt2, 0.0);
    }

    #[test]
    fn test_pauli_x_kernel() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(2, 0);

        let result = optimizer.apply_single_qubit_gate(&mut state, 0, "pauli_x", None);
        assert!(result.is_ok());

        assert_amplitude(state[0], 0.0, 0.0);
        assert_amplitude(state[1], 1.0, 0.0);
    }

    #[test]
    fn test_pauli_y_kernel() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(2, 0);

        let result = optimizer.apply_single_qubit_gate(&mut state, 0, "pauli_y", None);
        assert!(result.is_ok());

        // Y|0> = i|1>
        assert_amplitude(state[0], 0.0, 0.0);
        assert_amplitude(state[1], 0.0, 1.0);
    }

    #[test]
    fn test_pauli_z_kernel() {
        let mut optimizer = default_optimizer();
        let mut state = Array1::from_vec(vec![Complex64::new(0.5, 0.0), Complex64::new(0.5, 0.0)]);

        let result = optimizer.apply_single_qubit_gate(&mut state, 0, "pauli_z", None);
        assert!(result.is_ok());

        assert_amplitude(state[0], 0.5, 0.0);
        assert_amplitude(state[1], -0.5, 0.0);
    }

    #[test]
    fn test_phase_and_t_kernels() {
        let mut optimizer = default_optimizer();

        // S|1> = i|1>
        let mut state = basis_state(2, 1);
        optimizer
            .apply_single_qubit_gate(&mut state, 0, "phase", None)
            .unwrap();
        assert_amplitude(state[0], 0.0, 0.0);
        assert_amplitude(state[1], 0.0, 1.0);

        // T|1> = exp(i * pi / 4)|1>
        let mut state = basis_state(2, 1);
        optimizer
            .apply_single_qubit_gate(&mut state, 0, "t_gate", None)
            .unwrap();
        let inv_sqrt2 = std::f64::consts::FRAC_1_SQRT_2;
        assert_amplitude(state[0], 0.0, 0.0);
        assert_amplitude(state[1], inv_sqrt2, inv_sqrt2);
    }

    #[test]
    fn test_rotation_x_and_y_kernels() {
        let mut optimizer = default_optimizer();
        let pi = [std::f64::consts::PI];

        // Rx(pi)|0> = -i|1>
        let mut state = basis_state(2, 0);
        optimizer
            .apply_single_qubit_gate(&mut state, 0, "rotation_x", Some(&pi))
            .unwrap();
        assert_amplitude(state[0], 0.0, 0.0);
        assert_amplitude(state[1], 0.0, -1.0);

        // Ry(pi)|0> = |1>
        let mut state = basis_state(2, 0);
        optimizer
            .apply_single_qubit_gate(&mut state, 0, "rotation_y", Some(&pi))
            .unwrap();
        assert_amplitude(state[0], 0.0, 0.0);
        assert_amplitude(state[1], 1.0, 0.0);
    }

    #[test]
    fn test_rotation_z_kernel() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(2, 0);

        let result = optimizer.apply_single_qubit_gate(
            &mut state,
            0,
            "rotation_z",
            Some(&[std::f64::consts::PI]),
        );
        assert!(result.is_ok());

        // Rz(pi)|0> = exp(-i * pi / 2)|0> = -i|0>
        assert_amplitude(state[0], 0.0, -1.0);
        assert_amplitude(state[1], 0.0, 0.0);
    }

    #[test]
    fn test_cnot_kernel() {
        let mut optimizer = default_optimizer();
        // Index 2: bit 1 (control) set, bit 0 (target) clear.
        let mut state = basis_state(4, 2);

        let result = optimizer.apply_two_qubit_gate(&mut state, 1, 0, "cnot");
        assert!(result.is_ok());

        assert_amplitude(state[2], 0.0, 0.0);
        assert_amplitude(state[3], 1.0, 0.0);
    }

    #[test]
    fn test_cz_kernel() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(4, 3);

        let result = optimizer.apply_two_qubit_gate(&mut state, 1, 0, "cz");
        assert!(result.is_ok());

        assert_amplitude(state[3], -1.0, 0.0);
    }

    #[test]
    fn test_swap_kernel() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(4, 1);

        let result = optimizer.apply_two_qubit_gate(&mut state, 0, 1, "swap");
        assert!(result.is_ok());

        assert_amplitude(state[1], 0.0, 0.0);
        assert_amplitude(state[2], 1.0, 0.0);
    }

    #[test]
    fn test_iswap_kernel() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(4, 1);

        let result = optimizer.apply_two_qubit_gate(&mut state, 0, 1, "iswap");
        assert!(result.is_ok());

        // The amplitude moves to the swapped basis state and picks up a factor i.
        assert_amplitude(state[1], 0.0, 0.0);
        assert_amplitude(state[2], 0.0, 1.0);
    }

    #[test]
    fn test_kernel_stats() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(2, 0);

        optimizer
            .apply_single_qubit_gate(&mut state, 0, "hadamard", None)
            .unwrap();
        optimizer
            .apply_single_qubit_gate(&mut state, 0, "pauli_x", None)
            .unwrap();

        let stats = optimizer.get_stats().unwrap();
        assert_eq!(stats.total_executions, 2);
        assert_eq!(*stats.execution_counts.get("hadamard").unwrap(), 1);
        assert_eq!(*stats.execution_counts.get("pauli_x").unwrap(), 1);
        assert!(stats.execution_times.contains_key("hadamard"));
        assert!(stats.execution_times.contains_key("pauli_x"));
    }

    #[test]
    fn test_available_kernels() {
        let optimizer = default_optimizer();

        let kernels = optimizer.get_available_kernels();
        assert!(kernels.contains(&"hadamard".to_string()));
        assert!(kernels.contains(&"cnot".to_string()));
        assert!(kernels.contains(&"swap".to_string()));
    }

    #[test]
    fn test_has_kernel() {
        let optimizer = default_optimizer();

        assert!(optimizer.has_kernel("hadamard"));
        assert!(optimizer.has_kernel("cnot"));
        assert!(!optimizer.has_kernel("nonexistent"));
    }

    #[test]
    fn test_config_defaults() {
        let config = GPUKernelConfig::default();

        assert!(config.enable_warp_optimization);
        assert!(config.enable_shared_memory);
        assert_eq!(config.block_size, 256);
        assert!(config.enable_kernel_fusion);
        assert_eq!(config.max_fusion_length, 8);
    }

    #[test]
    fn test_reset_stats() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(2, 0);

        optimizer
            .apply_single_qubit_gate(&mut state, 0, "hadamard", None)
            .unwrap();
        optimizer.reset_stats().unwrap();

        let stats = optimizer.get_stats().unwrap();
        assert_eq!(stats.total_executions, 0);
        assert!(stats.execution_counts.is_empty());
    }

    #[test]
    fn test_multiple_qubit_operations() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(8, 0);

        optimizer
            .apply_single_qubit_gate(&mut state, 0, "hadamard", None)
            .unwrap();
        optimizer
            .apply_two_qubit_gate(&mut state, 0, 1, "cnot")
            .unwrap();

        // Hadamard on qubit 0 followed by CNOT(0 -> 1) yields (|000> + |011>) / sqrt(2).
        let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
        assert_amplitude(state[0], inv_sqrt2, 0.0);
        assert_amplitude(state[3], inv_sqrt2, 0.0);

        let total_prob: f64 = state.iter().map(|a| (a * a.conj()).re).sum();
        assert!((total_prob - 1.0).abs() < TOLERANCE);
    }
}