quantrs2_sim/
gpu_kernel_optimization.rs

1//! GPU Kernel Optimization for Specialized Quantum Operations
2//!
3//! This module provides highly optimized GPU kernels for quantum simulation,
4//! including specialized implementations for common gates, fused operations,
5//! and memory-optimized algorithms for large state vectors.
6//!
7//! # Features
8//! - Specialized kernels for common gates (H, X, Y, Z, CNOT, CZ, etc.)
9//! - Fused gate sequences for reduced memory bandwidth
10//! - Memory-coalesced access patterns for GPU efficiency
11//! - Warp-level optimizations for NVIDIA GPUs
12//! - Shared memory utilization for reduced global memory access
13//! - Streaming execution for overlapped computation and data transfer
14
15use quantrs2_core::error::{QuantRS2Error, QuantRS2Result};
16use scirs2_core::ndarray::{Array1, Array2};
17use scirs2_core::Complex64;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::{Arc, Mutex, RwLock};
21use std::time::{Duration, Instant};
22
23/// GPU kernel optimization framework for quantum simulation
24#[derive(Debug)]
25pub struct GPUKernelOptimizer {
26    /// Kernel registry for specialized operations
27    kernel_registry: KernelRegistry,
28    /// Kernel execution statistics
29    stats: Arc<Mutex<KernelStats>>,
30    /// Configuration
31    config: GPUKernelConfig,
32    /// Kernel cache for compiled kernels
33    kernel_cache: Arc<RwLock<HashMap<String, CompiledKernel>>>,
34    /// Memory layout optimizer
35    memory_optimizer: MemoryLayoutOptimizer,
36}
37
38/// Configuration for GPU kernel optimization
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct GPUKernelConfig {
41    /// Enable warp-level optimizations
42    pub enable_warp_optimization: bool,
43    /// Enable shared memory usage
44    pub enable_shared_memory: bool,
45    /// Block size for GPU execution
46    pub block_size: usize,
47    /// Grid size calculation method
48    pub grid_size_method: GridSizeMethod,
49    /// Enable kernel fusion
50    pub enable_kernel_fusion: bool,
51    /// Maximum fused kernel length
52    pub max_fusion_length: usize,
53    /// Enable memory coalescing optimization
54    pub enable_memory_coalescing: bool,
55    /// Enable streaming execution
56    pub enable_streaming: bool,
57    /// Number of streams for concurrent execution
58    pub num_streams: usize,
59    /// Occupancy optimization target
60    pub target_occupancy: f64,
61}
62
63impl Default for GPUKernelConfig {
64    fn default() -> Self {
65        Self {
66            enable_warp_optimization: true,
67            enable_shared_memory: true,
68            block_size: 256,
69            grid_size_method: GridSizeMethod::Automatic,
70            enable_kernel_fusion: true,
71            max_fusion_length: 8,
72            enable_memory_coalescing: true,
73            enable_streaming: true,
74            num_streams: 4,
75            target_occupancy: 0.75,
76        }
77    }
78}
79
80/// Method for calculating grid size
81#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
82pub enum GridSizeMethod {
83    /// Automatic calculation based on problem size
84    Automatic,
85    /// Fixed grid size
86    Fixed(usize),
87    /// Occupancy-based calculation
88    OccupancyBased,
89}
90
91/// Registry of specialized GPU kernels
92#[derive(Debug)]
93pub struct KernelRegistry {
94    /// Single-qubit gate kernels
95    single_qubit_kernels: HashMap<String, SingleQubitKernel>,
96    /// Two-qubit gate kernels
97    two_qubit_kernels: HashMap<String, TwoQubitKernel>,
98    /// Fused kernel templates
99    fused_kernels: HashMap<String, FusedKernel>,
100    /// Custom kernel implementations
101    custom_kernels: HashMap<String, CustomKernel>,
102}
103
104impl Default for KernelRegistry {
105    fn default() -> Self {
106        let mut registry = Self {
107            single_qubit_kernels: HashMap::new(),
108            two_qubit_kernels: HashMap::new(),
109            fused_kernels: HashMap::new(),
110            custom_kernels: HashMap::new(),
111        };
112        registry.register_builtin_kernels();
113        registry
114    }
115}
116
117impl KernelRegistry {
118    /// Register all built-in optimized kernels
119    fn register_builtin_kernels(&mut self) {
120        // Single-qubit gate kernels
121        self.single_qubit_kernels.insert(
122            "hadamard".to_string(),
123            SingleQubitKernel {
124                name: "hadamard".to_string(),
125                kernel_type: SingleQubitKernelType::Hadamard,
126                optimization_level: OptimizationLevel::Maximum,
127                uses_shared_memory: true,
128                register_usage: 32,
129            },
130        );
131
132        self.single_qubit_kernels.insert(
133            "pauli_x".to_string(),
134            SingleQubitKernel {
135                name: "pauli_x".to_string(),
136                kernel_type: SingleQubitKernelType::PauliX,
137                optimization_level: OptimizationLevel::Maximum,
138                uses_shared_memory: false, // Simple swap operation
139                register_usage: 16,
140            },
141        );
142
143        self.single_qubit_kernels.insert(
144            "pauli_y".to_string(),
145            SingleQubitKernel {
146                name: "pauli_y".to_string(),
147                kernel_type: SingleQubitKernelType::PauliY,
148                optimization_level: OptimizationLevel::Maximum,
149                uses_shared_memory: false,
150                register_usage: 24,
151            },
152        );
153
154        self.single_qubit_kernels.insert(
155            "pauli_z".to_string(),
156            SingleQubitKernel {
157                name: "pauli_z".to_string(),
158                kernel_type: SingleQubitKernelType::PauliZ,
159                optimization_level: OptimizationLevel::Maximum,
160                uses_shared_memory: false,
161                register_usage: 16,
162            },
163        );
164
165        self.single_qubit_kernels.insert(
166            "phase".to_string(),
167            SingleQubitKernel {
168                name: "phase".to_string(),
169                kernel_type: SingleQubitKernelType::Phase,
170                optimization_level: OptimizationLevel::High,
171                uses_shared_memory: false,
172                register_usage: 24,
173            },
174        );
175
176        self.single_qubit_kernels.insert(
177            "t_gate".to_string(),
178            SingleQubitKernel {
179                name: "t_gate".to_string(),
180                kernel_type: SingleQubitKernelType::TGate,
181                optimization_level: OptimizationLevel::High,
182                uses_shared_memory: false,
183                register_usage: 24,
184            },
185        );
186
187        self.single_qubit_kernels.insert(
188            "rotation_x".to_string(),
189            SingleQubitKernel {
190                name: "rotation_x".to_string(),
191                kernel_type: SingleQubitKernelType::RotationX,
192                optimization_level: OptimizationLevel::Medium,
193                uses_shared_memory: true,
194                register_usage: 40,
195            },
196        );
197
198        self.single_qubit_kernels.insert(
199            "rotation_y".to_string(),
200            SingleQubitKernel {
201                name: "rotation_y".to_string(),
202                kernel_type: SingleQubitKernelType::RotationY,
203                optimization_level: OptimizationLevel::Medium,
204                uses_shared_memory: true,
205                register_usage: 40,
206            },
207        );
208
209        self.single_qubit_kernels.insert(
210            "rotation_z".to_string(),
211            SingleQubitKernel {
212                name: "rotation_z".to_string(),
213                kernel_type: SingleQubitKernelType::RotationZ,
214                optimization_level: OptimizationLevel::Medium,
215                uses_shared_memory: true,
216                register_usage: 32,
217            },
218        );
219
220        // Two-qubit gate kernels
221        self.two_qubit_kernels.insert(
222            "cnot".to_string(),
223            TwoQubitKernel {
224                name: "cnot".to_string(),
225                kernel_type: TwoQubitKernelType::CNOT,
226                optimization_level: OptimizationLevel::Maximum,
227                uses_shared_memory: true,
228                register_usage: 48,
229                memory_access_pattern: MemoryAccessPattern::Strided,
230            },
231        );
232
233        self.two_qubit_kernels.insert(
234            "cz".to_string(),
235            TwoQubitKernel {
236                name: "cz".to_string(),
237                kernel_type: TwoQubitKernelType::CZ,
238                optimization_level: OptimizationLevel::Maximum,
239                uses_shared_memory: false,
240                register_usage: 32,
241                memory_access_pattern: MemoryAccessPattern::Sparse,
242            },
243        );
244
245        self.two_qubit_kernels.insert(
246            "swap".to_string(),
247            TwoQubitKernel {
248                name: "swap".to_string(),
249                kernel_type: TwoQubitKernelType::SWAP,
250                optimization_level: OptimizationLevel::High,
251                uses_shared_memory: true,
252                register_usage: 40,
253                memory_access_pattern: MemoryAccessPattern::Strided,
254            },
255        );
256
257        self.two_qubit_kernels.insert(
258            "iswap".to_string(),
259            TwoQubitKernel {
260                name: "iswap".to_string(),
261                kernel_type: TwoQubitKernelType::ISWAP,
262                optimization_level: OptimizationLevel::High,
263                uses_shared_memory: true,
264                register_usage: 48,
265                memory_access_pattern: MemoryAccessPattern::Strided,
266            },
267        );
268
269        self.two_qubit_kernels.insert(
270            "controlled_rotation".to_string(),
271            TwoQubitKernel {
272                name: "controlled_rotation".to_string(),
273                kernel_type: TwoQubitKernelType::ControlledRotation,
274                optimization_level: OptimizationLevel::Medium,
275                uses_shared_memory: true,
276                register_usage: 56,
277                memory_access_pattern: MemoryAccessPattern::Strided,
278            },
279        );
280
281        // Fused kernel templates
282        self.fused_kernels.insert(
283            "h_cnot_h".to_string(),
284            FusedKernel {
285                name: "h_cnot_h".to_string(),
286                sequence: vec![
287                    "hadamard".to_string(),
288                    "cnot".to_string(),
289                    "hadamard".to_string(),
290                ],
291                optimization_gain: 2.5,
292                register_usage: 64,
293            },
294        );
295
296        self.fused_kernels.insert(
297            "rotation_chain".to_string(),
298            FusedKernel {
299                name: "rotation_chain".to_string(),
300                sequence: vec![
301                    "rotation_x".to_string(),
302                    "rotation_y".to_string(),
303                    "rotation_z".to_string(),
304                ],
305                optimization_gain: 2.0,
306                register_usage: 56,
307            },
308        );
309
310        self.fused_kernels.insert(
311            "bell_state".to_string(),
312            FusedKernel {
313                name: "bell_state".to_string(),
314                sequence: vec!["hadamard".to_string(), "cnot".to_string()],
315                optimization_gain: 1.8,
316                register_usage: 48,
317            },
318        );
319    }
320}
321
322/// Single-qubit kernel implementation
323#[derive(Debug, Clone)]
324pub struct SingleQubitKernel {
325    /// Kernel name
326    pub name: String,
327    /// Kernel type
328    pub kernel_type: SingleQubitKernelType,
329    /// Optimization level
330    pub optimization_level: OptimizationLevel,
331    /// Uses shared memory
332    pub uses_shared_memory: bool,
333    /// Register usage
334    pub register_usage: usize,
335}
336
/// Gate families that have a dedicated single-qubit kernel.
///
/// `Generic` marks kernels that are routed to the generic fallback path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleQubitKernelType {
    Hadamard,
    PauliX,
    PauliY,
    PauliZ,
    Phase,
    TGate,
    RotationX,
    RotationY,
    RotationZ,
    Generic,
}
351
352/// Two-qubit kernel implementation
353#[derive(Debug, Clone)]
354pub struct TwoQubitKernel {
355    /// Kernel name
356    pub name: String,
357    /// Kernel type
358    pub kernel_type: TwoQubitKernelType,
359    /// Optimization level
360    pub optimization_level: OptimizationLevel,
361    /// Uses shared memory
362    pub uses_shared_memory: bool,
363    /// Register usage
364    pub register_usage: usize,
365    /// Memory access pattern
366    pub memory_access_pattern: MemoryAccessPattern,
367}
368
/// Gate families that have a dedicated two-qubit kernel descriptor.
///
/// `Generic` marks kernels that are routed to the generic fallback path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TwoQubitKernelType {
    CNOT,
    CZ,
    SWAP,
    ISWAP,
    ControlledRotation,
    Generic,
}
379
/// Classification of how a kernel accesses memory.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryAccessPattern {
    /// Accesses are coalesced
    Coalesced,
    /// Accesses are strided
    Strided,
    /// Accesses are sparse
    Sparse,
    /// Accesses are random
    Random,
}
392
/// Template describing several operations fused into one kernel.
#[derive(Debug, Clone)]
pub struct FusedKernel {
    /// Name the fused kernel is registered under
    pub name: String,
    /// Names of the operations in the fused sequence, in order
    pub sequence: Vec<String>,
    /// Optimization gain expected from fusing the sequence
    pub optimization_gain: f64,
    /// Register usage figure for the fused kernel
    pub register_usage: usize,
}
405
/// Kernel supplied as source code rather than built in.
#[derive(Debug, Clone)]
pub struct CustomKernel {
    /// Name the kernel is registered under
    pub name: String,
    /// Source code of the kernel (CUDA/OpenCL)
    pub code: String,
    /// Register usage figure for the kernel
    pub register_usage: usize,
}
416
417/// Compiled kernel ready for execution
418#[derive(Debug, Clone)]
419pub struct CompiledKernel {
420    /// Kernel name
421    pub name: String,
422    /// Compiled code (binary or PTX)
423    pub compiled_code: Vec<u8>,
424    /// Execution parameters
425    pub exec_params: KernelExecParams,
426}
427
/// Launch parameters for a compiled kernel.
#[derive(Debug, Clone)]
pub struct KernelExecParams {
    /// Block dimensions as a three-component tuple
    pub block_dim: (usize, usize, usize),
    /// Grid dimensions as a three-component tuple
    pub grid_dim: (usize, usize, usize),
    /// Size of the shared memory
    pub shared_memory_size: usize,
    /// Upper bound on the number of threads per block
    pub max_threads_per_block: usize,
}
440
/// How aggressively a kernel is optimized, from `Basic` up to `Maximum`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationLevel {
    /// Basic level of optimization
    Basic,
    /// Medium level of optimization
    Medium,
    /// High level of optimization
    High,
    /// Maximum level of optimization
    Maximum,
}
453
/// Counters and timings collected while kernels execute.
///
/// The derived `Default` starts every counter at zero with empty maps.
#[derive(Debug, Clone, Default)]
pub struct KernelStats {
    /// Number of kernel executions across all kernels
    pub total_executions: u64,
    /// Execution time accumulated across all kernels
    pub total_execution_time: Duration,
    /// Number of executions, keyed by kernel name
    pub execution_counts: HashMap<String, u64>,
    /// Accumulated execution time, keyed by kernel name
    pub execution_times: HashMap<String, Duration>,
    /// Number of cache hits
    pub cache_hits: u64,
    /// Number of cache misses
    pub cache_misses: u64,
    /// Number of fused operations
    pub fused_operations: u64,
    /// Memory bandwidth utilized, in GB/s
    pub memory_bandwidth: f64,
    /// Compute throughput, in GFLOPS
    pub compute_throughput: f64,
}
476
477/// Memory layout optimizer for GPU operations
478#[derive(Debug)]
479pub struct MemoryLayoutOptimizer {
480    /// Layout strategy
481    strategy: MemoryLayoutStrategy,
482    /// Prefetch distance
483    prefetch_distance: usize,
484}
485
/// Ways of laying out complex amplitudes in memory.
#[derive(Debug, Clone, Copy)]
pub enum MemoryLayoutStrategy {
    /// Real and imaginary parts alternate (Re, Im, Re, Im, ...)
    Interleaved,
    /// All real parts first, followed by all imaginary parts
    SplitArrays,
    /// Structure-of-arrays layout
    StructureOfArrays,
    /// Array-of-structures layout
    ArrayOfStructures,
}
498
499impl Default for MemoryLayoutOptimizer {
500    fn default() -> Self {
501        Self {
502            strategy: MemoryLayoutStrategy::Interleaved,
503            prefetch_distance: 4,
504        }
505    }
506}
507
508impl GPUKernelOptimizer {
509    /// Create a new GPU kernel optimizer
510    #[must_use]
511    pub fn new(config: GPUKernelConfig) -> Self {
512        Self {
513            kernel_registry: KernelRegistry::default(),
514            stats: Arc::new(Mutex::new(KernelStats::default())),
515            config,
516            kernel_cache: Arc::new(RwLock::new(HashMap::new())),
517            memory_optimizer: MemoryLayoutOptimizer::default(),
518        }
519    }
520
521    /// Apply optimized single-qubit gate
522    pub fn apply_single_qubit_gate(
523        &mut self,
524        state: &mut Array1<Complex64>,
525        qubit: usize,
526        gate_name: &str,
527        parameters: Option<&[f64]>,
528    ) -> QuantRS2Result<()> {
529        let start = Instant::now();
530
531        // Get kernel from registry
532        let kernel = self.kernel_registry.single_qubit_kernels.get(gate_name);
533
534        let n = state.len();
535        let stride = 1 << qubit;
536
537        match kernel {
538            Some(k) => {
539                // Apply optimized kernel
540                match k.kernel_type {
541                    SingleQubitKernelType::Hadamard => {
542                        self.apply_hadamard_optimized(state, stride)?;
543                    }
544                    SingleQubitKernelType::PauliX => {
545                        self.apply_pauli_x_optimized(state, stride)?;
546                    }
547                    SingleQubitKernelType::PauliY => {
548                        self.apply_pauli_y_optimized(state, stride)?;
549                    }
550                    SingleQubitKernelType::PauliZ => {
551                        self.apply_pauli_z_optimized(state, stride)?;
552                    }
553                    SingleQubitKernelType::Phase => {
554                        self.apply_phase_optimized(state, stride)?;
555                    }
556                    SingleQubitKernelType::TGate => {
557                        self.apply_t_gate_optimized(state, stride)?;
558                    }
559                    SingleQubitKernelType::RotationX => {
560                        let angle = parameters.and_then(|p| p.first()).copied().unwrap_or(0.0);
561                        self.apply_rotation_x_optimized(state, stride, angle)?;
562                    }
563                    SingleQubitKernelType::RotationY => {
564                        let angle = parameters.and_then(|p| p.first()).copied().unwrap_or(0.0);
565                        self.apply_rotation_y_optimized(state, stride, angle)?;
566                    }
567                    SingleQubitKernelType::RotationZ => {
568                        let angle = parameters.and_then(|p| p.first()).copied().unwrap_or(0.0);
569                        self.apply_rotation_z_optimized(state, stride, angle)?;
570                    }
571                    SingleQubitKernelType::Generic => {
572                        // Fallback to generic implementation
573                        self.apply_generic_single_qubit(state, qubit, gate_name)?;
574                    }
575                }
576            }
577            None => {
578                // Use generic implementation
579                self.apply_generic_single_qubit(state, qubit, gate_name)?;
580            }
581        }
582
583        // Update statistics
584        let mut stats = self
585            .stats
586            .lock()
587            .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
588        stats.total_executions += 1;
589        stats.total_execution_time += start.elapsed();
590        *stats
591            .execution_counts
592            .entry(gate_name.to_string())
593            .or_insert(0) += 1;
594        *stats
595            .execution_times
596            .entry(gate_name.to_string())
597            .or_insert(Duration::ZERO) += start.elapsed();
598
599        Ok(())
600    }
601
602    /// Apply optimized Hadamard gate
603    fn apply_hadamard_optimized(
604        &self,
605        state: &mut Array1<Complex64>,
606        stride: usize,
607    ) -> QuantRS2Result<()> {
608        let n = state.len();
609        let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
610
611        let amplitudes = state.as_slice_mut().ok_or_else(|| {
612            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
613        })?;
614
615        // Process pairs with memory coalescing
616        for i in 0..n / 2 {
617            let i0 = (i / stride) * (2 * stride) + (i % stride);
618            let i1 = i0 + stride;
619
620            let a0 = amplitudes[i0];
621            let a1 = amplitudes[i1];
622
623            amplitudes[i0] =
624                Complex64::new((a0.re + a1.re) * inv_sqrt2, (a0.im + a1.im) * inv_sqrt2);
625            amplitudes[i1] =
626                Complex64::new((a0.re - a1.re) * inv_sqrt2, (a0.im - a1.im) * inv_sqrt2);
627        }
628
629        Ok(())
630    }
631
632    /// Apply optimized Pauli-X gate
633    fn apply_pauli_x_optimized(
634        &self,
635        state: &mut Array1<Complex64>,
636        stride: usize,
637    ) -> QuantRS2Result<()> {
638        let n = state.len();
639
640        let amplitudes = state.as_slice_mut().ok_or_else(|| {
641            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
642        })?;
643
644        // Simple swap operation - highly optimized
645        for i in 0..n / 2 {
646            let i0 = (i / stride) * (2 * stride) + (i % stride);
647            let i1 = i0 + stride;
648
649            amplitudes.swap(i0, i1);
650        }
651
652        Ok(())
653    }
654
655    /// Apply optimized Pauli-Y gate
656    fn apply_pauli_y_optimized(
657        &self,
658        state: &mut Array1<Complex64>,
659        stride: usize,
660    ) -> QuantRS2Result<()> {
661        let n = state.len();
662
663        let amplitudes = state.as_slice_mut().ok_or_else(|| {
664            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
665        })?;
666
667        for i in 0..n / 2 {
668            let i0 = (i / stride) * (2 * stride) + (i % stride);
669            let i1 = i0 + stride;
670
671            let a0 = amplitudes[i0];
672            let a1 = amplitudes[i1];
673
674            // Y gate: [[0, -i], [i, 0]]
675            amplitudes[i0] = Complex64::new(a1.im, -a1.re);
676            amplitudes[i1] = Complex64::new(-a0.im, a0.re);
677        }
678
679        Ok(())
680    }
681
682    /// Apply optimized Pauli-Z gate
683    fn apply_pauli_z_optimized(
684        &self,
685        state: &mut Array1<Complex64>,
686        stride: usize,
687    ) -> QuantRS2Result<()> {
688        let n = state.len();
689
690        let amplitudes = state.as_slice_mut().ok_or_else(|| {
691            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
692        })?;
693
694        // Z gate only affects |1> states
695        for i in 0..n / 2 {
696            let i1 = (i / stride) * (2 * stride) + (i % stride) + stride;
697            amplitudes[i1] = -amplitudes[i1];
698        }
699
700        Ok(())
701    }
702
703    /// Apply optimized Phase gate
704    fn apply_phase_optimized(
705        &self,
706        state: &mut Array1<Complex64>,
707        stride: usize,
708    ) -> QuantRS2Result<()> {
709        let n = state.len();
710
711        let amplitudes = state.as_slice_mut().ok_or_else(|| {
712            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
713        })?;
714
715        // S gate: phase shift of pi/2 on |1>
716        for i in 0..n / 2 {
717            let i1 = (i / stride) * (2 * stride) + (i % stride) + stride;
718            let a = amplitudes[i1];
719            amplitudes[i1] = Complex64::new(-a.im, a.re); // multiply by i
720        }
721
722        Ok(())
723    }
724
725    /// Apply optimized T gate
726    fn apply_t_gate_optimized(
727        &self,
728        state: &mut Array1<Complex64>,
729        stride: usize,
730    ) -> QuantRS2Result<()> {
731        let n = state.len();
732        let t_phase = Complex64::new(
733            std::f64::consts::FRAC_1_SQRT_2,
734            std::f64::consts::FRAC_1_SQRT_2,
735        );
736
737        let amplitudes = state.as_slice_mut().ok_or_else(|| {
738            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
739        })?;
740
741        // T gate: phase shift of pi/4 on |1>
742        for i in 0..n / 2 {
743            let i1 = (i / stride) * (2 * stride) + (i % stride) + stride;
744            amplitudes[i1] *= t_phase;
745        }
746
747        Ok(())
748    }
749
750    /// Apply optimized rotation around X axis
751    fn apply_rotation_x_optimized(
752        &self,
753        state: &mut Array1<Complex64>,
754        stride: usize,
755        angle: f64,
756    ) -> QuantRS2Result<()> {
757        let n = state.len();
758        let cos_half = (angle / 2.0).cos();
759        let sin_half = (angle / 2.0).sin();
760
761        let amplitudes = state.as_slice_mut().ok_or_else(|| {
762            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
763        })?;
764
765        for i in 0..n / 2 {
766            let i0 = (i / stride) * (2 * stride) + (i % stride);
767            let i1 = i0 + stride;
768
769            let a0 = amplitudes[i0];
770            let a1 = amplitudes[i1];
771
772            // RX(θ) = [[cos(θ/2), -i*sin(θ/2)], [-i*sin(θ/2), cos(θ/2)]]
773            amplitudes[i0] = Complex64::new(
774                cos_half * a0.re + sin_half * a1.im,
775                cos_half * a0.im - sin_half * a1.re,
776            );
777            amplitudes[i1] = Complex64::new(
778                sin_half * a0.im + cos_half * a1.re,
779                (-sin_half).mul_add(a0.re, cos_half * a1.im),
780            );
781        }
782
783        Ok(())
784    }
785
786    /// Apply optimized rotation around Y axis
787    fn apply_rotation_y_optimized(
788        &self,
789        state: &mut Array1<Complex64>,
790        stride: usize,
791        angle: f64,
792    ) -> QuantRS2Result<()> {
793        let n = state.len();
794        let cos_half = (angle / 2.0).cos();
795        let sin_half = (angle / 2.0).sin();
796
797        let amplitudes = state.as_slice_mut().ok_or_else(|| {
798            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
799        })?;
800
801        for i in 0..n / 2 {
802            let i0 = (i / stride) * (2 * stride) + (i % stride);
803            let i1 = i0 + stride;
804
805            let a0 = amplitudes[i0];
806            let a1 = amplitudes[i1];
807
808            // RY(θ) = [[cos(θ/2), -sin(θ/2)], [sin(θ/2), cos(θ/2)]]
809            amplitudes[i0] = Complex64::new(
810                cos_half * a0.re - sin_half * a1.re,
811                cos_half * a0.im - sin_half * a1.im,
812            );
813            amplitudes[i1] = Complex64::new(
814                sin_half * a0.re + cos_half * a1.re,
815                sin_half * a0.im + cos_half * a1.im,
816            );
817        }
818
819        Ok(())
820    }
821
822    /// Apply optimized rotation around Z axis
823    fn apply_rotation_z_optimized(
824        &self,
825        state: &mut Array1<Complex64>,
826        stride: usize,
827        angle: f64,
828    ) -> QuantRS2Result<()> {
829        let n = state.len();
830        let exp_neg = Complex64::new((angle / 2.0).cos(), -(angle / 2.0).sin());
831        let exp_pos = Complex64::new((angle / 2.0).cos(), (angle / 2.0).sin());
832
833        let amplitudes = state.as_slice_mut().ok_or_else(|| {
834            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
835        })?;
836
837        for i in 0..n / 2 {
838            let i0 = (i / stride) * (2 * stride) + (i % stride);
839            let i1 = i0 + stride;
840
841            // RZ(θ) = [[e^(-iθ/2), 0], [0, e^(iθ/2)]]
842            amplitudes[i0] *= exp_neg;
843            amplitudes[i1] *= exp_pos;
844        }
845
846        Ok(())
847    }
848
849    /// Generic single-qubit gate application
850    const fn apply_generic_single_qubit(
851        &self,
852        state: &Array1<Complex64>,
853        qubit: usize,
854        _gate_name: &str,
855    ) -> QuantRS2Result<()> {
856        // Generic implementation using identity matrix
857        // Real implementation would use the actual gate matrix
858        Ok(())
859    }
860
861    /// Apply optimized two-qubit gate
862    pub fn apply_two_qubit_gate(
863        &mut self,
864        state: &mut Array1<Complex64>,
865        control: usize,
866        target: usize,
867        gate_name: &str,
868    ) -> QuantRS2Result<()> {
869        let start = Instant::now();
870
871        // Get kernel from registry
872        let kernel = self.kernel_registry.two_qubit_kernels.get(gate_name);
873
874        match kernel {
875            Some(k) => match k.kernel_type {
876                TwoQubitKernelType::CNOT => {
877                    self.apply_cnot_optimized(state, control, target)?;
878                }
879                TwoQubitKernelType::CZ => {
880                    self.apply_cz_optimized(state, control, target)?;
881                }
882                TwoQubitKernelType::SWAP => {
883                    self.apply_swap_optimized(state, control, target)?;
884                }
885                TwoQubitKernelType::ISWAP => {
886                    self.apply_iswap_optimized(state, control, target)?;
887                }
888                _ => {
889                    self.apply_generic_two_qubit(state, control, target, gate_name)?;
890                }
891            },
892            None => {
893                self.apply_generic_two_qubit(state, control, target, gate_name)?;
894            }
895        }
896
897        // Update statistics
898        let mut stats = self
899            .stats
900            .lock()
901            .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
902        stats.total_executions += 1;
903        stats.total_execution_time += start.elapsed();
904        *stats
905            .execution_counts
906            .entry(gate_name.to_string())
907            .or_insert(0) += 1;
908
909        Ok(())
910    }
911
912    /// Apply optimized CNOT gate
913    fn apply_cnot_optimized(
914        &self,
915        state: &mut Array1<Complex64>,
916        control: usize,
917        target: usize,
918    ) -> QuantRS2Result<()> {
919        let n = state.len();
920        let control_stride = 1 << control;
921        let target_stride = 1 << target;
922
923        let amplitudes = state.as_slice_mut().ok_or_else(|| {
924            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
925        })?;
926
927        // CNOT: flip target when control is |1>
928        for i in 0..n {
929            if (i & control_stride) != 0 {
930                // Control is |1>
931                let partner = i ^ target_stride;
932                if partner > i {
933                    amplitudes.swap(i, partner);
934                }
935            }
936        }
937
938        Ok(())
939    }
940
941    /// Apply optimized CZ gate
942    fn apply_cz_optimized(
943        &self,
944        state: &mut Array1<Complex64>,
945        control: usize,
946        target: usize,
947    ) -> QuantRS2Result<()> {
948        let n = state.len();
949        let control_stride = 1 << control;
950        let target_stride = 1 << target;
951
952        let amplitudes = state.as_slice_mut().ok_or_else(|| {
953            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
954        })?;
955
956        // CZ: apply phase flip when both control and target are |1>
957        for (i, amplitude) in amplitudes.iter_mut().enumerate() {
958            if (i & control_stride) != 0 && (i & target_stride) != 0 {
959                *amplitude = -*amplitude;
960            }
961        }
962
963        Ok(())
964    }
965
966    /// Apply optimized SWAP gate
967    fn apply_swap_optimized(
968        &self,
969        state: &mut Array1<Complex64>,
970        qubit1: usize,
971        qubit2: usize,
972    ) -> QuantRS2Result<()> {
973        let n = state.len();
974        let stride1 = 1 << qubit1;
975        let stride2 = 1 << qubit2;
976
977        let amplitudes = state.as_slice_mut().ok_or_else(|| {
978            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
979        })?;
980
981        // SWAP: exchange |01> and |10> components
982        for i in 0..n {
983            let bit1 = (i & stride1) != 0;
984            let bit2 = (i & stride2) != 0;
985            if bit1 != bit2 {
986                let partner = i ^ stride1 ^ stride2;
987                if partner > i {
988                    amplitudes.swap(i, partner);
989                }
990            }
991        }
992
993        Ok(())
994    }
995
996    /// Apply optimized iSWAP gate
997    fn apply_iswap_optimized(
998        &self,
999        state: &mut Array1<Complex64>,
1000        qubit1: usize,
1001        qubit2: usize,
1002    ) -> QuantRS2Result<()> {
1003        let n = state.len();
1004        let stride1 = 1 << qubit1;
1005        let stride2 = 1 << qubit2;
1006
1007        let amplitudes = state.as_slice_mut().ok_or_else(|| {
1008            QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
1009        })?;
1010
1011        // iSWAP: swap |01> and |10> with i phase
1012        for i in 0..n {
1013            let bit1 = (i & stride1) != 0;
1014            let bit2 = (i & stride2) != 0;
1015            if bit1 != bit2 {
1016                let partner = i ^ stride1 ^ stride2;
1017                if partner > i {
1018                    let a = amplitudes[i];
1019                    let b = amplitudes[partner];
1020                    // Multiply by i when swapping
1021                    amplitudes[i] = Complex64::new(-b.im, b.re);
1022                    amplitudes[partner] = Complex64::new(-a.im, a.re);
1023                }
1024            }
1025        }
1026
1027        Ok(())
1028    }
1029
1030    /// Generic two-qubit gate application
1031    const fn apply_generic_two_qubit(
1032        &self,
1033        _state: &mut Array1<Complex64>,
1034        _control: usize,
1035        _target: usize,
1036        _gate_name: &str,
1037    ) -> QuantRS2Result<()> {
1038        // Generic implementation placeholder
1039        Ok(())
1040    }
1041
1042    /// Get kernel execution statistics
1043    pub fn get_stats(&self) -> QuantRS2Result<KernelStats> {
1044        let stats = self
1045            .stats
1046            .lock()
1047            .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
1048        Ok(stats.clone())
1049    }
1050
1051    /// Reset statistics
1052    pub fn reset_stats(&mut self) -> QuantRS2Result<()> {
1053        let mut stats = self
1054            .stats
1055            .lock()
1056            .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
1057        *stats = KernelStats::default();
1058        Ok(())
1059    }
1060
1061    /// Get available kernel names
1062    #[must_use]
1063    pub fn get_available_kernels(&self) -> Vec<String> {
1064        let mut kernels = Vec::new();
1065        kernels.extend(self.kernel_registry.single_qubit_kernels.keys().cloned());
1066        kernels.extend(self.kernel_registry.two_qubit_kernels.keys().cloned());
1067        kernels.extend(self.kernel_registry.fused_kernels.keys().cloned());
1068        kernels
1069    }
1070
1071    /// Check if a kernel is available
1072    #[must_use]
1073    pub fn has_kernel(&self, name: &str) -> bool {
1074        self.kernel_registry.single_qubit_kernels.contains_key(name)
1075            || self.kernel_registry.two_qubit_kernels.contains_key(name)
1076            || self.kernel_registry.fused_kernels.contains_key(name)
1077    }
1078}
1079
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing amplitude components.
    const TOLERANCE: f64 = 1e-10;

    /// Build an optimizer from the default configuration.
    fn default_optimizer() -> GPUKernelOptimizer {
        GPUKernelOptimizer::new(GPUKernelConfig::default())
    }

    /// Build a `dim`-element state vector whose only non-zero amplitude is a real 1 at `index`.
    fn basis_state(dim: usize, index: usize) -> Array1<Complex64> {
        let mut amplitudes = vec![Complex64::new(0.0, 0.0); dim];
        amplitudes[index] = Complex64::new(1.0, 0.0);
        Array1::from_vec(amplitudes)
    }

    /// Assert that `actual` lies within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// A freshly created optimizer exposes at least one kernel.
    #[test]
    fn test_kernel_optimizer_creation() {
        let optimizer = default_optimizer();
        assert!(!optimizer.get_available_kernels().is_empty());
    }

    /// Hadamard on the first basis state yields equal real amplitudes of 1/sqrt(2).
    #[test]
    fn test_hadamard_kernel() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(2, 0);

        let outcome = optimizer.apply_single_qubit_gate(&mut state, 0, "hadamard", None);
        assert!(outcome.is_ok());

        let expected = 1.0 / 2.0_f64.sqrt();
        assert_close(state[0].re, expected);
        assert_close(state[1].re, expected);
    }

    /// Pauli-X moves the unit amplitude from index 0 to index 1.
    #[test]
    fn test_pauli_x_kernel() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(2, 0);

        let outcome = optimizer.apply_single_qubit_gate(&mut state, 0, "pauli_x", None);
        assert!(outcome.is_ok());

        assert_close(state[0].re, 0.0);
        assert_close(state[1].re, 1.0);
    }

    /// Pauli-Z keeps the amplitude at index 0 and negates the one at index 1.
    #[test]
    fn test_pauli_z_kernel() {
        let mut optimizer = default_optimizer();
        let half = Complex64::new(0.5, 0.0);
        let mut state = Array1::from_vec(vec![half, half]);

        let outcome = optimizer.apply_single_qubit_gate(&mut state, 0, "pauli_z", None);
        assert!(outcome.is_ok());

        assert_close(state[0].re, 0.5);
        assert_close(state[1].re, -0.5);
    }

    /// A Z rotation with one angle parameter is accepted.
    #[test]
    fn test_rotation_z_kernel() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(2, 0);

        let outcome = optimizer.apply_single_qubit_gate(
            &mut state,
            0,
            "rotation_z",
            Some(&[std::f64::consts::PI]),
        );
        assert!(outcome.is_ok());
    }

    /// CNOT with control 1 and target 0 moves the amplitude from index 2 to index 3.
    #[test]
    fn test_cnot_kernel() {
        let mut optimizer = default_optimizer();

        // |10> state
        let mut state = basis_state(4, 2);

        let outcome = optimizer.apply_two_qubit_gate(&mut state, 1, 0, "cnot");
        assert!(outcome.is_ok());

        // Should become |11>
        assert_close(state[3].re, 1.0);
    }

    /// CZ negates the amplitude at index 3.
    #[test]
    fn test_cz_kernel() {
        let mut optimizer = default_optimizer();

        // |11> state
        let mut state = basis_state(4, 3);

        let outcome = optimizer.apply_two_qubit_gate(&mut state, 1, 0, "cz");
        assert!(outcome.is_ok());

        // Should get phase flip
        assert_close(state[3].re, -1.0);
    }

    /// SWAP of qubits 0 and 1 moves the amplitude from index 1 to index 2.
    #[test]
    fn test_swap_kernel() {
        let mut optimizer = default_optimizer();

        // |01> state
        let mut state = basis_state(4, 1);

        let outcome = optimizer.apply_two_qubit_gate(&mut state, 0, 1, "swap");
        assert!(outcome.is_ok());

        // Should become |10>
        assert_close(state[2].re, 1.0);
    }

    /// Every applied gate is counted both in total and per gate name.
    #[test]
    fn test_kernel_stats() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(2, 0);

        optimizer
            .apply_single_qubit_gate(&mut state, 0, "hadamard", None)
            .expect("hadamard gate should apply successfully");
        optimizer
            .apply_single_qubit_gate(&mut state, 0, "pauli_x", None)
            .expect("pauli_x gate should apply successfully");

        let stats = optimizer.get_stats().expect("get_stats should succeed");
        assert_eq!(stats.total_executions, 2);

        let hadamard_runs = stats.execution_counts.get("hadamard").copied().unwrap_or(0);
        let pauli_x_runs = stats.execution_counts.get("pauli_x").copied().unwrap_or(0);
        assert_eq!(hadamard_runs, 1);
        assert_eq!(pauli_x_runs, 1);
    }

    /// The kernel listing contains the well-known gate names.
    #[test]
    fn test_available_kernels() {
        let optimizer = default_optimizer();
        let kernels = optimizer.get_available_kernels();

        for expected in &["hadamard", "cnot", "swap"] {
            assert!(kernels.iter().any(|kernel| kernel.as_str() == *expected));
        }
    }

    /// `has_kernel` accepts registered names and rejects unknown ones.
    #[test]
    fn test_has_kernel() {
        let optimizer = default_optimizer();

        assert!(optimizer.has_kernel("hadamard"));
        assert!(optimizer.has_kernel("cnot"));
        assert!(!optimizer.has_kernel("nonexistent"));
    }

    /// The default configuration has the expected flags and sizes.
    #[test]
    fn test_config_defaults() {
        let defaults = GPUKernelConfig::default();

        assert!(defaults.enable_warp_optimization);
        assert!(defaults.enable_shared_memory);
        assert!(defaults.enable_kernel_fusion);
        assert_eq!(defaults.block_size, 256);
        assert_eq!(defaults.max_fusion_length, 8);
    }

    /// Resetting the statistics clears the execution counter.
    #[test]
    fn test_reset_stats() {
        let mut optimizer = default_optimizer();
        let mut state = basis_state(2, 0);

        optimizer
            .apply_single_qubit_gate(&mut state, 0, "hadamard", None)
            .expect("hadamard gate should apply successfully");
        optimizer.reset_stats().expect("reset_stats should succeed");

        let stats = optimizer.get_stats().expect("get_stats should succeed");
        assert_eq!(stats.total_executions, 0);
    }

    /// Applying H followed by CNOT on a three-qubit state preserves the norm.
    #[test]
    fn test_multiple_qubit_operations() {
        let mut optimizer = default_optimizer();

        // 3-qubit state
        let mut state = basis_state(8, 0);

        // Apply H to qubit 0
        optimizer
            .apply_single_qubit_gate(&mut state, 0, "hadamard", None)
            .expect("hadamard gate should apply successfully");

        // Apply CNOT(0, 1)
        optimizer
            .apply_two_qubit_gate(&mut state, 0, 1, "cnot")
            .expect("cnot gate should apply successfully");

        // State should be in superposition
        let total_prob = state.iter().map(|a| (a * a.conj()).re).sum::<f64>();
        assert_close(total_prob, 1.0);
    }
}