quantrs2_sim/
opencl_amd_backend.rs

1//! OpenCL Backend for AMD GPU Acceleration
2//!
3//! This module provides high-performance quantum circuit simulation using OpenCL
4//! to leverage AMD GPU compute capabilities. It implements parallel state vector
5//! operations, gate applications, and quantum algorithm acceleration on AMD
6//! graphics processing units.
7//!
8//! Key features:
9//! - OpenCL kernel compilation and execution
10//! - AMD GPU-optimized quantum gate operations
11//! - Parallel state vector manipulation
12//! - Memory management for large quantum states
13//! - Support for AMD ROCm and OpenCL 2.0+
14//! - Automatic device detection and selection
15//! - Performance profiling and optimization
16//! - Fallback to CPU when GPU is unavailable
17
18use crate::prelude::{SimulatorError, StateVectorSimulator};
19use scirs2_core::Complex64;
20use scirs2_core::parallel_ops::*;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23
24use crate::error::Result;
25
/// Descriptor for a single OpenCL platform.
///
/// Values of this type are produced by platform discovery, and one of them is
/// stored on the simulator as the selected platform. In this file the
/// discovery routine returns hard-coded entries rather than querying a driver.
#[derive(Debug, Clone)]
pub struct OpenCLPlatform {
    /// Numeric identifier of the platform.
    pub platform_id: usize,
    /// Human-readable platform name.
    pub name: String,
    /// Vendor string; platform selection performs a substring match against it.
    pub vendor: String,
    /// Version string reported for the platform.
    pub version: String,
    /// Names of the extensions advertised by the platform.
    pub extensions: Vec<String>,
}
40
/// Descriptor for a single OpenCL device and its capability limits.
///
/// Device selection filters on `device_type` and then ranks candidates by
/// `compute_units`; the remaining fields are carried along as metadata.
#[derive(Debug, Clone)]
pub struct OpenCLDevice {
    /// Numeric identifier of the device; copied into the context and command queue.
    pub device_id: usize,
    /// Human-readable device name.
    pub name: String,
    /// Vendor string of the device.
    pub vendor: String,
    /// Device type (GPU, CPU, etc.); compared for equality during device selection.
    pub device_type: OpenCLDeviceType,
    /// Number of compute units; the device with the highest value is preferred.
    pub compute_units: u32,
    /// Maximum work group size.
    pub max_work_group_size: usize,
    /// Maximum number of work item dimensions.
    pub max_work_item_dimensions: u32,
    /// Maximum work item size per dimension
    /// (presumably one entry per dimension — TODO confirm).
    pub max_work_item_sizes: Vec<usize>,
    /// Global memory size in bytes (the simulated devices report 16 and 24 GiB).
    pub global_memory_size: u64,
    /// Local memory size in bytes (the simulated devices report 64 KiB).
    pub local_memory_size: u64,
    /// Maximum constant buffer size (presumably bytes, like the other sizes — TODO confirm).
    pub max_constant_buffer_size: u64,
    /// Whether the device supports double precision; enables an extra build option.
    pub supports_double: bool,
    /// Names of the extensions advertised by the device.
    pub extensions: Vec<String>,
}
71
/// Category of an OpenCL device.
///
/// Used both to describe a discovered device and, in the configuration, to
/// state which category the caller prefers.
///
/// NOTE(review): device selection compares this value with `==`, so `All`
/// does not act as a wildcard there — confirm whether that is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpenCLDeviceType {
    /// Graphics processing unit.
    GPU,
    /// Host CPU exposed as an OpenCL device.
    CPU,
    /// Dedicated accelerator device.
    Accelerator,
    /// Custom device.
    Custom,
    /// Any device type (presumably mirrors `CL_DEVICE_TYPE_ALL` — TODO confirm).
    All,
}
81
/// User-facing configuration of the OpenCL backend.
///
/// A `Default` implementation is provided that targets AMD GPUs.
#[derive(Debug, Clone)]
pub struct OpenCLConfig {
    /// Preferred platform vendor. When set, the first platform whose vendor
    /// string contains this text is selected; otherwise (or when nothing
    /// matches) the first discovered platform is used.
    pub preferred_vendor: Option<String>,
    /// Preferred device type; only devices of exactly this type are considered.
    pub preferred_device_type: OpenCLDeviceType,
    /// Enable performance profiling; copied onto the command queue.
    pub enable_profiling: bool,
    /// Maximum size of a single buffer in bytes; larger requests are rejected.
    pub max_buffer_size: usize,
    /// Work group size recorded on every kernel descriptor.
    pub work_group_size: usize,
    /// Enable kernel caching.
    /// NOTE(review): not read by any code visible in this part of the file.
    pub enable_kernel_cache: bool,
    /// Optimization level, translated into an `-O<n>` build option.
    pub optimization_level: OptimizationLevel,
    /// When `true`, a CPU state-vector simulator is constructed alongside the
    /// OpenCL state so it can serve as a fallback.
    pub enable_cpu_fallback: bool,
}
102
/// Optimization level requested for kernel compilation.
///
/// Each variant is translated into the corresponding `-O<n>` flag when the
/// build option string is assembled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationLevel {
    /// No optimization (-O0)
    None,
    /// Basic optimization (-O1)
    Basic,
    /// Standard optimization (-O2)
    Standard,
    /// Aggressive optimization (-O3)
    Aggressive,
}
115
116impl Default for OpenCLConfig {
117    fn default() -> Self {
118        Self {
119            preferred_vendor: Some("Advanced Micro Devices".to_string()),
120            preferred_device_type: OpenCLDeviceType::GPU,
121            enable_profiling: true,
122            max_buffer_size: 1 << 30, // 1GB
123            work_group_size: 256,
124            enable_kernel_cache: true,
125            optimization_level: OptimizationLevel::Standard,
126            enable_cpu_fallback: true,
127        }
128    }
129}
130
/// Descriptor for one OpenCL program held by the simulator.
///
/// The simulator stores these in a map keyed by kernel name. The descriptor
/// only carries the source text and launch metadata; nothing in this type
/// holds a compiled binary.
#[derive(Debug, Clone)]
pub struct OpenCLKernel {
    /// Kernel name; matches the key under which the descriptor is stored.
    pub name: String,
    /// OpenCL C source code of the program.
    pub source: String,
    /// Space-separated compiler options for building the program.
    pub build_options: String,
    /// Local memory usage
    /// (presumably bytes; the creators use `work_group_size * 8` or `* 16` — TODO confirm).
    pub local_memory_usage: usize,
    /// Work group size, copied from the configuration.
    pub work_group_size: usize,
}
145
/// AMD GPU-optimized quantum simulator using OpenCL
///
/// The OpenCL objects held here (context, command queue, buffers) are
/// simulated stand-ins defined in this file rather than driver handles.
pub struct AMDOpenCLSimulator {
    /// Configuration supplied at construction time.
    config: OpenCLConfig,
    /// Selected platform; `None` until initialization has run.
    platform: Option<OpenCLPlatform>,
    /// Selected device; `None` until initialization has run.
    device: Option<OpenCLDevice>,
    /// OpenCL context (simulated)
    context: Option<OpenCLContext>,
    /// Command queue (simulated)
    command_queue: Option<OpenCLCommandQueue>,
    /// Kernel descriptors keyed by kernel name.
    kernels: HashMap<String, OpenCLKernel>,
    /// Memory buffers (presumably keyed by the name passed to `create_buffer` — TODO confirm).
    buffers: HashMap<String, OpenCLBuffer>,
    /// Performance statistics
    stats: OpenCLStats,
    /// Fallback CPU simulator; populated only when `enable_cpu_fallback` is set.
    cpu_fallback: Option<StateVectorSimulator>,
}
167
/// Simulated OpenCL context
///
/// A plain record standing in for a driver context; it only remembers its own
/// id and the ids of the devices attached to it.
#[derive(Debug, Clone)]
pub struct OpenCLContext {
    /// Context ID
    pub context_id: usize,
    /// IDs of the devices associated with this context.
    pub devices: Vec<usize>,
}
176
/// Simulated OpenCL command queue
///
/// A plain record standing in for a driver queue; it links a context id to a
/// device id and remembers whether profiling was requested.
#[derive(Debug, Clone)]
pub struct OpenCLCommandQueue {
    /// Queue ID
    pub queue_id: usize,
    /// ID of the context the queue belongs to.
    pub context_id: usize,
    /// ID of the device the queue submits work to.
    pub device_id: usize,
    /// Whether profiling is enabled; copied from the configuration.
    pub profiling_enabled: bool,
}
189
/// Simulated OpenCL buffer
///
/// Device memory is modelled by an optional host-side byte vector.
#[derive(Debug, Clone)]
pub struct OpenCLBuffer {
    /// Buffer ID
    pub buffer_id: usize,
    /// Buffer size in bytes
    pub size: usize,
    /// Memory flags
    pub flags: MemoryFlags,
    /// Host-side backing storage used by the simulation; buffer creation
    /// fills it with `size` zero bytes.
    pub host_data: Option<Vec<u8>>,
}
202
/// OpenCL memory flags
///
/// Modelled as a plain enum, so a buffer carries exactly one of these values
/// rather than a combination. The variant names presumably mirror the
/// `CL_MEM_*` flags of the OpenCL API — TODO confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryFlags {
    /// Buffer may be read and written.
    ReadWrite,
    /// Buffer is read-only.
    ReadOnly,
    /// Buffer is write-only.
    WriteOnly,
    /// Use host pointer.
    UseHostPtr,
    /// Allocate host pointer.
    AllocHostPtr,
    /// Copy host pointer.
    CopyHostPtr,
}
213
/// OpenCL performance statistics
///
/// All counters start at zero (`Default`). The type is serializable so the
/// collected numbers can be exported.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OpenCLStats {
    /// Total kernel executions; incremented by `update_kernel_execution`.
    pub total_kernel_executions: usize,
    /// Total execution time (ms); accumulated by `update_kernel_execution`.
    pub total_execution_time: f64,
    /// Average kernel execution time (ms); recomputed on every recorded execution.
    pub avg_kernel_time: f64,
    /// Memory transfer time (ms)
    pub memory_transfer_time: f64,
    /// Compilation time (ms); set when the kernel descriptors are built.
    pub compilation_time: f64,
    /// GPU memory usage (bytes)
    pub gpu_memory_usage: u64,
    /// GPU utilization percentage (divided by 100 when reported as efficiency).
    pub gpu_utilization: f64,
    /// Number of state vector operations
    pub state_vector_operations: usize,
    /// Number of gate operations
    pub gate_operations: usize,
    /// Number of times execution fell back to the CPU.
    pub cpu_fallback_count: usize,
}
238
239impl OpenCLStats {
240    /// Update statistics after kernel execution
241    pub fn update_kernel_execution(&mut self, execution_time: f64) {
242        self.total_kernel_executions += 1;
243        self.total_execution_time += execution_time;
244        self.avg_kernel_time = self.total_execution_time / self.total_kernel_executions as f64;
245    }
246
247    /// Calculate performance metrics
248    pub fn get_performance_metrics(&self) -> HashMap<String, f64> {
249        let mut metrics = HashMap::new();
250        metrics.insert(
251            "kernel_executions_per_second".to_string(),
252            self.total_kernel_executions as f64 / (self.total_execution_time / 1000.0),
253        );
254        metrics.insert(
255            "memory_bandwidth_gb_s".to_string(),
256            self.gpu_memory_usage as f64 / (self.memory_transfer_time / 1000.0) / 1e9,
257        );
258        metrics.insert("gpu_efficiency".to_string(), self.gpu_utilization / 100.0);
259        metrics
260    }
261}
262
263impl AMDOpenCLSimulator {
264    /// Create new AMD OpenCL simulator
265    pub fn new(config: OpenCLConfig) -> Result<Self> {
266        let mut simulator = Self {
267            config,
268            platform: None,
269            device: None,
270            context: None,
271            command_queue: None,
272            kernels: HashMap::new(),
273            buffers: HashMap::new(),
274            stats: OpenCLStats::default(),
275            cpu_fallback: None,
276        };
277
278        // Initialize OpenCL environment
279        simulator.initialize_opencl()?;
280
281        // Compile kernels
282        simulator.compile_kernels()?;
283
284        // Initialize CPU fallback if enabled
285        if simulator.config.enable_cpu_fallback {
286            simulator.cpu_fallback = Some(StateVectorSimulator::new()); // Default size
287        }
288
289        Ok(simulator)
290    }
291
292    /// Initialize OpenCL platform and device
293    fn initialize_opencl(&mut self) -> Result<()> {
294        // Simulate platform discovery
295        let platforms = self.discover_platforms()?;
296
297        // Select preferred platform
298        let selected_platform = self.select_platform(&platforms)?;
299        self.platform = Some(selected_platform);
300
301        // Discover devices
302        let devices = self.discover_devices()?;
303
304        // Select preferred device
305        let selected_device = self.select_device(&devices)?;
306        self.device = Some(selected_device);
307
308        // Create context
309        self.context = Some(OpenCLContext {
310            context_id: 1,
311            devices: vec![self.device.as_ref().unwrap().device_id],
312        });
313
314        // Create command queue
315        self.command_queue = Some(OpenCLCommandQueue {
316            queue_id: 1,
317            context_id: 1,
318            device_id: self.device.as_ref().unwrap().device_id,
319            profiling_enabled: self.config.enable_profiling,
320        });
321
322        Ok(())
323    }
324
325    /// Discover available OpenCL platforms
326    fn discover_platforms(&self) -> Result<Vec<OpenCLPlatform>> {
327        // Simulate AMD platform discovery
328        let platforms = vec![
329            OpenCLPlatform {
330                platform_id: 0,
331                name: "AMD Accelerated Parallel Processing".to_string(),
332                vendor: "Advanced Micro Devices, Inc.".to_string(),
333                version: "OpenCL 2.1 AMD-APP (3444.0)".to_string(),
334                extensions: vec![
335                    "cl_khr_icd".to_string(),
336                    "cl_khr_d3d10_sharing".to_string(),
337                    "cl_khr_d3d11_sharing".to_string(),
338                    "cl_khr_dx9_media_sharing".to_string(),
339                    "cl_amd_event_callback".to_string(),
340                    "cl_amd_offline_devices".to_string(),
341                ],
342            },
343            OpenCLPlatform {
344                platform_id: 1,
345                name: "Intel(R) OpenCL".to_string(),
346                vendor: "Intel(R) Corporation".to_string(),
347                version: "OpenCL 2.1".to_string(),
348                extensions: vec!["cl_khr_icd".to_string()],
349            },
350        ];
351
352        Ok(platforms)
353    }
354
355    /// Select optimal platform
356    fn select_platform(&self, platforms: &[OpenCLPlatform]) -> Result<OpenCLPlatform> {
357        // Prefer AMD platform if available
358        if let Some(preferred_vendor) = &self.config.preferred_vendor {
359            for platform in platforms {
360                if platform.vendor.contains(preferred_vendor) {
361                    return Ok(platform.clone());
362                }
363            }
364        }
365
366        // Fallback to first available platform
367        platforms.first().cloned().ok_or_else(|| {
368            SimulatorError::InitializationError("No OpenCL platforms found".to_string())
369        })
370    }
371
372    /// Discover devices for selected platform
373    fn discover_devices(&self) -> Result<Vec<OpenCLDevice>> {
374        // Simulate AMD GPU device discovery
375        let devices = vec![
376            OpenCLDevice {
377                device_id: 0,
378                name: "Radeon RX 7900 XTX".to_string(),
379                vendor: "Advanced Micro Devices, Inc.".to_string(),
380                device_type: OpenCLDeviceType::GPU,
381                compute_units: 96,
382                max_work_group_size: 256,
383                max_work_item_dimensions: 3,
384                max_work_item_sizes: vec![256, 256, 256],
385                global_memory_size: 24 * (1 << 30), // 24GB
386                local_memory_size: 64 * 1024,       // 64KB
387                max_constant_buffer_size: 64 * 1024,
388                supports_double: true,
389                extensions: vec![
390                    "cl_khr_fp64".to_string(),
391                    "cl_amd_fp64".to_string(),
392                    "cl_khr_global_int32_base_atomics".to_string(),
393                ],
394            },
395            OpenCLDevice {
396                device_id: 1,
397                name: "Radeon RX 6800 XT".to_string(),
398                vendor: "Advanced Micro Devices, Inc.".to_string(),
399                device_type: OpenCLDeviceType::GPU,
400                compute_units: 72,
401                max_work_group_size: 256,
402                max_work_item_dimensions: 3,
403                max_work_item_sizes: vec![256, 256, 256],
404                global_memory_size: 16 * (1 << 30), // 16GB
405                local_memory_size: 64 * 1024,
406                max_constant_buffer_size: 64 * 1024,
407                supports_double: true,
408                extensions: vec!["cl_khr_fp64".to_string(), "cl_amd_fp64".to_string()],
409            },
410        ];
411
412        Ok(devices)
413    }
414
415    /// Select optimal device
416    fn select_device(&self, devices: &[OpenCLDevice]) -> Result<OpenCLDevice> {
417        // Filter by device type
418        let filtered_devices: Vec<&OpenCLDevice> = devices
419            .iter()
420            .filter(|device| device.device_type == self.config.preferred_device_type)
421            .collect();
422
423        if filtered_devices.is_empty() {
424            return Err(SimulatorError::InitializationError(
425                "No suitable devices found".to_string(),
426            ));
427        }
428
429        // Select device with most compute units
430        let best_device = filtered_devices
431            .iter()
432            .max_by_key(|device| device.compute_units)
433            .unwrap();
434
435        Ok((*best_device).clone())
436    }
437
438    /// Compile OpenCL kernels
439    fn compile_kernels(&mut self) -> Result<()> {
440        let start_time = std::time::Instant::now();
441
442        // Single qubit gate kernel
443        let single_qubit_kernel = self.create_single_qubit_kernel();
444        self.kernels
445            .insert("single_qubit_gate".to_string(), single_qubit_kernel);
446
447        // Two qubit gate kernel
448        let two_qubit_kernel = self.create_two_qubit_kernel();
449        self.kernels
450            .insert("two_qubit_gate".to_string(), two_qubit_kernel);
451
452        // State vector operations kernel
453        let state_vector_kernel = self.create_state_vector_kernel();
454        self.kernels
455            .insert("state_vector_ops".to_string(), state_vector_kernel);
456
457        // Measurement kernel
458        let measurement_kernel = self.create_measurement_kernel();
459        self.kernels
460            .insert("measurement".to_string(), measurement_kernel);
461
462        // Expectation value kernel
463        let expectation_kernel = self.create_expectation_kernel();
464        self.kernels
465            .insert("expectation_value".to_string(), expectation_kernel);
466
467        self.stats.compilation_time = start_time.elapsed().as_secs_f64() * 1000.0;
468
469        Ok(())
470    }
471
472    /// Create single qubit gate kernel
473    fn create_single_qubit_kernel(&self) -> OpenCLKernel {
474        let source = r#"
475            #pragma OPENCL EXTENSION cl_khr_fp64 : enable
476
477            typedef double2 complex_t;
478
479            complex_t complex_mul(complex_t a, complex_t b) {
480                return (complex_t)(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
481            }
482
483            complex_t complex_add(complex_t a, complex_t b) {
484                return (complex_t)(a.x + b.x, a.y + b.y);
485            }
486
487            __kernel void single_qubit_gate(
488                __global complex_t* state,
489                __global const double* gate_matrix,
490                const int target_qubit,
491                const int num_qubits
492            ) {
493                const int global_id = get_global_id(0);
494                const int total_states = 1 << num_qubits;
495
496                if (global_id >= total_states / 2) return;
497
498                const int target_mask = 1 << target_qubit;
499                const int i = global_id;
500                const int j = i | target_mask;
501
502                if ((i & target_mask) == 0) {
503                    // Extract gate matrix elements
504                    complex_t gate_00 = (complex_t)(gate_matrix[0], gate_matrix[1]);
505                    complex_t gate_01 = (complex_t)(gate_matrix[2], gate_matrix[3]);
506                    complex_t gate_10 = (complex_t)(gate_matrix[4], gate_matrix[5]);
507                    complex_t gate_11 = (complex_t)(gate_matrix[6], gate_matrix[7]);
508
509                    complex_t state_i = state[i];
510                    complex_t state_j = state[j];
511
512                    state[i] = complex_add(complex_mul(gate_00, state_i), complex_mul(gate_01, state_j));
513                    state[j] = complex_add(complex_mul(gate_10, state_i), complex_mul(gate_11, state_j));
514                }
515            }
516        "#;
517
518        OpenCLKernel {
519            name: "single_qubit_gate".to_string(),
520            source: source.to_string(),
521            build_options: self.get_build_options(),
522            local_memory_usage: 0,
523            work_group_size: self.config.work_group_size,
524        }
525    }
526
527    /// Create two qubit gate kernel
528    fn create_two_qubit_kernel(&self) -> OpenCLKernel {
529        let source = r#"
530            #pragma OPENCL EXTENSION cl_khr_fp64 : enable
531
532            typedef double2 complex_t;
533
534            complex_t complex_mul(complex_t a, complex_t b) {
535                return (complex_t)(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
536            }
537
538            complex_t complex_add(complex_t a, complex_t b) {
539                return (complex_t)(a.x + b.x, a.y + b.y);
540            }
541
542            __kernel void two_qubit_gate(
543                __global complex_t* state,
544                __global const double* gate_matrix,
545                const int control_qubit,
546                const int target_qubit,
547                const int num_qubits
548            ) {
549                const int global_id = get_global_id(0);
550                const int total_states = 1 << num_qubits;
551
552                if (global_id >= total_states / 4) return;
553
554                const int control_mask = 1 << control_qubit;
555                const int target_mask = 1 << target_qubit;
556                const int both_mask = control_mask | target_mask;
557
558                int base = global_id;
559                // Remove bits at control and target positions
560                if (global_id & (target_mask - 1)) base = (base & ~(target_mask - 1)) << 1 | (base & (target_mask - 1));
561                if (base & (control_mask - 1)) base = (base & ~(control_mask - 1)) << 1 | (base & (control_mask - 1));
562
563                int state_00 = base;
564                int state_01 = base | target_mask;
565                int state_10 = base | control_mask;
566                int state_11 = base | both_mask;
567
568                // Load gate matrix (16 elements for 4x4 matrix)
569                complex_t gate[4][4];
570                for (int i = 0; i < 4; i++) {
571                    for (int j = 0; j < 4; j++) {
572                        gate[i][j] = (complex_t)(gate_matrix[(i*4+j)*2], gate_matrix[(i*4+j)*2+1]);
573                    }
574                }
575
576                complex_t old_states[4];
577                old_states[0] = state[state_00];
578                old_states[1] = state[state_01];
579                old_states[2] = state[state_10];
580                old_states[3] = state[state_11];
581
582                // Apply gate matrix
583                complex_t new_states[4] = {0};
584                for (int i = 0; i < 4; i++) {
585                    for (int j = 0; j < 4; j++) {
586                        new_states[i] = complex_add(new_states[i], complex_mul(gate[i][j], old_states[j]));
587                    }
588                }
589
590                state[state_00] = new_states[0];
591                state[state_01] = new_states[1];
592                state[state_10] = new_states[2];
593                state[state_11] = new_states[3];
594            }
595        "#;
596
597        OpenCLKernel {
598            name: "two_qubit_gate".to_string(),
599            source: source.to_string(),
600            build_options: self.get_build_options(),
601            local_memory_usage: 128, // Local memory for gate matrix
602            work_group_size: self.config.work_group_size,
603        }
604    }
605
606    /// Create state vector operations kernel
607    fn create_state_vector_kernel(&self) -> OpenCLKernel {
608        let source = r#"
609            #pragma OPENCL EXTENSION cl_khr_fp64 : enable
610
611            typedef double2 complex_t;
612
613            __kernel void normalize_state(
614                __global complex_t* state,
615                const int num_states,
616                const double norm_factor
617            ) {
618                const int global_id = get_global_id(0);
619
620                if (global_id >= num_states) return;
621
622                state[global_id].x *= norm_factor;
623                state[global_id].y *= norm_factor;
624            }
625
626            __kernel void compute_probabilities(
627                __global const complex_t* state,
628                __global double* probabilities,
629                const int num_states
630            ) {
631                const int global_id = get_global_id(0);
632
633                if (global_id >= num_states) return;
634
635                complex_t amplitude = state[global_id];
636                probabilities[global_id] = amplitude.x * amplitude.x + amplitude.y * amplitude.y;
637            }
638
639            __kernel void inner_product(
640                __global const complex_t* state1,
641                __global const complex_t* state2,
642                __global complex_t* partial_results,
643                __local complex_t* local_data,
644                const int num_states
645            ) {
646                const int global_id = get_global_id(0);
647                const int local_id = get_local_id(0);
648                const int local_size = get_local_size(0);
649                const int group_id = get_group_id(0);
650
651                // Initialize local memory
652                if (global_id < num_states) {
653                    complex_t a = state1[global_id];
654                    complex_t b = state2[global_id];
655                    // Conjugate of a times b
656                    local_data[local_id] = (complex_t)(a.x * b.x + a.y * b.y, a.x * b.y - a.y * b.x);
657                } else {
658                    local_data[local_id] = (complex_t)(0.0, 0.0);
659                }
660
661                barrier(CLK_LOCAL_MEM_FENCE);
662
663                // Reduction
664                for (int stride = local_size / 2; stride > 0; stride /= 2) {
665                    if (local_id < stride) {
666                        local_data[local_id].x += local_data[local_id + stride].x;
667                        local_data[local_id].y += local_data[local_id + stride].y;
668                    }
669                    barrier(CLK_LOCAL_MEM_FENCE);
670                }
671
672                if (local_id == 0) {
673                    partial_results[group_id] = local_data[0];
674                }
675            }
676        "#;
677
678        OpenCLKernel {
679            name: "state_vector_ops".to_string(),
680            source: source.to_string(),
681            build_options: self.get_build_options(),
682            local_memory_usage: self.config.work_group_size * 16, // Complex doubles
683            work_group_size: self.config.work_group_size,
684        }
685    }
686
687    /// Create measurement kernel
688    fn create_measurement_kernel(&self) -> OpenCLKernel {
689        let source = r#"
690            #pragma OPENCL EXTENSION cl_khr_fp64 : enable
691
692            typedef double2 complex_t;
693
694            __kernel void measure_qubit(
695                __global complex_t* state,
696                __global double* probabilities,
697                const int target_qubit,
698                const int num_qubits,
699                const int measurement_result
700            ) {
701                const int global_id = get_global_id(0);
702                const int total_states = 1 << num_qubits;
703
704                if (global_id >= total_states) return;
705
706                const int target_mask = 1 << target_qubit;
707                const int qubit_value = (global_id & target_mask) ? 1 : 0;
708
709                if (qubit_value != measurement_result) {
710                    // Set amplitude to zero for inconsistent measurement
711                    state[global_id] = (complex_t)(0.0, 0.0);
712                }
713            }
714
715            __kernel void compute_measurement_probabilities(
716                __global const complex_t* state,
717                __global double* prob_0,
718                __global double* prob_1,
719                __local double* local_data,
720                const int target_qubit,
721                const int num_qubits
722            ) {
723                const int global_id = get_global_id(0);
724                const int local_id = get_local_id(0);
725                const int local_size = get_local_size(0);
726                const int group_id = get_group_id(0);
727                const int total_states = 1 << num_qubits;
728
729                double local_prob_0 = 0.0;
730                double local_prob_1 = 0.0;
731
732                if (global_id < total_states) {
733                    const int target_mask = 1 << target_qubit;
734                    complex_t amplitude = state[global_id];
735                    double prob = amplitude.x * amplitude.x + amplitude.y * amplitude.y;
736
737                    if (global_id & target_mask) {
738                        local_prob_1 = prob;
739                    } else {
740                        local_prob_0 = prob;
741                    }
742                }
743
744                local_data[local_id * 2] = local_prob_0;
745                local_data[local_id * 2 + 1] = local_prob_1;
746
747                barrier(CLK_LOCAL_MEM_FENCE);
748
749                // Reduction
750                for (int stride = local_size / 2; stride > 0; stride /= 2) {
751                    if (local_id < stride) {
752                        local_data[local_id * 2] += local_data[(local_id + stride) * 2];
753                        local_data[local_id * 2 + 1] += local_data[(local_id + stride) * 2 + 1];
754                    }
755                    barrier(CLK_LOCAL_MEM_FENCE);
756                }
757
758                if (local_id == 0) {
759                    prob_0[group_id] = local_data[0];
760                    prob_1[group_id] = local_data[1];
761                }
762            }
763        "#;
764
765        OpenCLKernel {
766            name: "measurement".to_string(),
767            source: source.to_string(),
768            build_options: self.get_build_options(),
769            local_memory_usage: self.config.work_group_size * 16, // 2 doubles per work item
770            work_group_size: self.config.work_group_size,
771        }
772    }
773
774    /// Create expectation value kernel
775    fn create_expectation_kernel(&self) -> OpenCLKernel {
776        let source = r#"
777            #pragma OPENCL EXTENSION cl_khr_fp64 : enable
778
779            typedef double2 complex_t;
780
781            complex_t complex_mul(complex_t a, complex_t b) {
782                return (complex_t)(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
783            }
784
785            __kernel void expectation_value_pauli(
786                __global const complex_t* state,
787                __global double* partial_results,
788                __local double* local_data,
789                const int pauli_string,
790                const int num_qubits
791            ) {
792                const int global_id = get_global_id(0);
793                const int local_id = get_local_id(0);
794                const int local_size = get_local_size(0);
795                const int group_id = get_group_id(0);
796                const int total_states = 1 << num_qubits;
797
798                double local_expectation = 0.0;
799
800                if (global_id < total_states) {
801                    complex_t amplitude = state[global_id];
802
803                    // Apply Pauli operators
804                    int target_state = global_id;
805                    complex_t result_amplitude = amplitude;
806                    double sign = 1.0;
807
808                    // Process each Pauli operator in the string
809                    for (int qubit = 0; qubit < num_qubits; qubit++) {
810                        int pauli_op = (pauli_string >> (2 * qubit)) & 3;
811                        int qubit_mask = 1 << qubit;
812
813                        switch (pauli_op) {
814                            case 0: // I (identity)
815                                break;
816                            case 1: // X (bit flip)
817                                target_state ^= qubit_mask;
818                                break;
819                            case 2: // Y (bit and phase flip)
820                                target_state ^= qubit_mask;
821                                if (global_id & qubit_mask) sign *= -1.0;
822                                else result_amplitude = (complex_t)(-result_amplitude.y, result_amplitude.x);
823                                break;
824                            case 3: // Z (phase flip)
825                                if (global_id & qubit_mask) sign *= -1.0;
826                                break;
827                        }
828                    }
829
830                    if (target_state == global_id) {
831                        // Diagonal element
832                        local_expectation = sign * (amplitude.x * amplitude.x + amplitude.y * amplitude.y);
833                    }
834                }
835
836                local_data[local_id] = local_expectation;
837                barrier(CLK_LOCAL_MEM_FENCE);
838
839                // Reduction
840                for (int stride = local_size / 2; stride > 0; stride /= 2) {
841                    if (local_id < stride) {
842                        local_data[local_id] += local_data[local_id + stride];
843                    }
844                    barrier(CLK_LOCAL_MEM_FENCE);
845                }
846
847                if (local_id == 0) {
848                    partial_results[group_id] = local_data[0];
849                }
850            }
851        "#;
852
853        OpenCLKernel {
854            name: "expectation_value".to_string(),
855            source: source.to_string(),
856            build_options: self.get_build_options(),
857            local_memory_usage: self.config.work_group_size * 8, // Double per work item
858            work_group_size: self.config.work_group_size,
859        }
860    }
861
862    /// Get build options for kernel compilation
863    fn get_build_options(&self) -> String {
864        let mut options = Vec::new();
865
866        match self.config.optimization_level {
867            OptimizationLevel::None => options.push("-O0"),
868            OptimizationLevel::Basic => options.push("-O1"),
869            OptimizationLevel::Standard => options.push("-O2"),
870            OptimizationLevel::Aggressive => options.push("-O3"),
871        }
872
873        // Add AMD-specific optimizations
874        options.push("-cl-mad-enable");
875        options.push("-cl-fast-relaxed-math");
876
877        // Double precision support
878        if let Some(device) = &self.device {
879            if device.supports_double {
880                options.push("-cl-fp64");
881            }
882        }
883
884        options.join(" ")
885    }
886
887    /// Create memory buffer
888    pub fn create_buffer(&mut self, name: &str, size: usize, flags: MemoryFlags) -> Result<()> {
889        if size > self.config.max_buffer_size {
890            return Err(SimulatorError::MemoryError(format!(
891                "Buffer size {} exceeds maximum {}",
892                size, self.config.max_buffer_size
893            )));
894        }
895
896        let buffer = OpenCLBuffer {
897            buffer_id: self.buffers.len(),
898            size,
899            flags,
900            host_data: Some(vec![0u8; size]),
901        };
902
903        self.buffers.insert(name.to_string(), buffer);
904        self.stats.gpu_memory_usage += size as u64;
905
906        Ok(())
907    }
908
909    /// Execute kernel
910    pub fn execute_kernel(
911        &mut self,
912        kernel_name: &str,
913        global_work_size: &[usize],
914        local_work_size: Option<&[usize]>,
915        args: &[KernelArg],
916    ) -> Result<f64> {
917        let start_time = std::time::Instant::now();
918
919        if !self.kernels.contains_key(kernel_name) {
920            return Err(SimulatorError::InvalidInput(format!(
921                "Kernel {} not found",
922                kernel_name
923            )));
924        }
925
926        // Simulate kernel execution
927        let execution_time = self.simulate_kernel_execution(kernel_name, global_work_size, args)?;
928
929        let total_time = start_time.elapsed().as_secs_f64() * 1000.0;
930        self.stats.update_kernel_execution(total_time);
931
932        match kernel_name {
933            "single_qubit_gate" | "two_qubit_gate" => {
934                self.stats.gate_operations += 1;
935            }
936            "state_vector_ops" | "normalize_state" | "compute_probabilities" => {
937                self.stats.state_vector_operations += 1;
938            }
939            _ => {}
940        }
941
942        Ok(execution_time)
943    }
944
945    /// Simulate kernel execution (for demonstration)
946    fn simulate_kernel_execution(
947        &mut self,
948        kernel_name: &str,
949        global_work_size: &[usize],
950        _args: &[KernelArg],
951    ) -> Result<f64> {
952        let total_work_items: usize = global_work_size.iter().product();
953
954        // Simulate execution time based on work items and device capabilities
955        let device = self.device.as_ref().unwrap();
956        let work_groups =
957            (total_work_items + self.config.work_group_size - 1) / self.config.work_group_size;
958        let parallel_work_groups = device.compute_units as usize;
959
960        let execution_cycles = (work_groups + parallel_work_groups - 1) / parallel_work_groups;
961
962        // Base execution time per cycle (microseconds)
963        let base_time_per_cycle = match kernel_name {
964            "single_qubit_gate" => 1.0,
965            "two_qubit_gate" => 2.5,
966            "state_vector_ops" => 0.5,
967            "measurement" => 1.5,
968            "expectation_value" => 2.0,
969            _ => 1.0,
970        };
971
972        let execution_time = execution_cycles as f64 * base_time_per_cycle;
973
974        // Add random variation
975        let variation = fastrand::f64() * 0.2 + 0.9; // 90-110% of base time
976        Ok(execution_time * variation)
977    }
978
979    /// Apply single qubit gate using OpenCL
980    pub fn apply_single_qubit_gate_opencl(
981        &mut self,
982        gate_matrix: &[Complex64; 4],
983        target_qubit: usize,
984        num_qubits: usize,
985    ) -> Result<f64> {
986        // Convert gate matrix to real array for OpenCL
987        let mut gate_real = [0.0; 8];
988        for (i, &complex_val) in gate_matrix.iter().enumerate() {
989            gate_real[i * 2] = complex_val.re;
990            gate_real[i * 2 + 1] = complex_val.im;
991        }
992
993        let total_states = 1 << num_qubits;
994        let global_work_size = vec![total_states / 2];
995
996        let args = vec![
997            KernelArg::Buffer("state".to_string()),
998            KernelArg::ConstantBuffer("gate_matrix".to_string()),
999            KernelArg::Int(target_qubit as i32),
1000            KernelArg::Int(num_qubits as i32),
1001        ];
1002
1003        self.execute_kernel("single_qubit_gate", &global_work_size, None, &args)
1004    }
1005
1006    /// Apply two qubit gate using OpenCL
1007    pub fn apply_two_qubit_gate_opencl(
1008        &mut self,
1009        gate_matrix: &[Complex64; 16],
1010        control_qubit: usize,
1011        target_qubit: usize,
1012        num_qubits: usize,
1013    ) -> Result<f64> {
1014        // Convert gate matrix to real array for OpenCL
1015        let mut gate_real = [0.0; 32];
1016        for (i, &complex_val) in gate_matrix.iter().enumerate() {
1017            gate_real[i * 2] = complex_val.re;
1018            gate_real[i * 2 + 1] = complex_val.im;
1019        }
1020
1021        let total_states = 1 << num_qubits;
1022        let global_work_size = vec![total_states / 4];
1023
1024        let args = vec![
1025            KernelArg::Buffer("state".to_string()),
1026            KernelArg::ConstantBuffer("gate_matrix".to_string()),
1027            KernelArg::Int(control_qubit as i32),
1028            KernelArg::Int(target_qubit as i32),
1029            KernelArg::Int(num_qubits as i32),
1030        ];
1031
1032        self.execute_kernel("two_qubit_gate", &global_work_size, None, &args)
1033    }
1034
1035    /// Compute expectation value using OpenCL
1036    pub fn compute_expectation_value_opencl(
1037        &mut self,
1038        pauli_string: u32,
1039        num_qubits: usize,
1040    ) -> Result<(f64, f64)> {
1041        let total_states = 1 << num_qubits;
1042        let global_work_size = vec![total_states];
1043
1044        let args = vec![
1045            KernelArg::Buffer("state".to_string()),
1046            KernelArg::Buffer("partial_results".to_string()),
1047            KernelArg::LocalMemory(self.config.work_group_size * 8),
1048            KernelArg::Int(pauli_string as i32),
1049            KernelArg::Int(num_qubits as i32),
1050        ];
1051
1052        let execution_time = self.execute_kernel(
1053            "expectation_value",
1054            &global_work_size,
1055            Some(&[self.config.work_group_size]),
1056            &args,
1057        )?;
1058
1059        // Simulate expectation value result
1060        let expectation_value = fastrand::f64() * 2.0 - 1.0; // Random value between -1 and 1
1061
1062        Ok((expectation_value, execution_time))
1063    }
1064
1065    /// Get device information
1066    pub fn get_device_info(&self) -> Option<&OpenCLDevice> {
1067        self.device.as_ref()
1068    }
1069
1070    /// Get performance statistics
1071    pub fn get_stats(&self) -> &OpenCLStats {
1072        &self.stats
1073    }
1074
1075    /// Reset performance statistics
1076    pub fn reset_stats(&mut self) {
1077        self.stats = OpenCLStats::default();
1078    }
1079
1080    /// Check if OpenCL is available
1081    pub fn is_opencl_available(&self) -> bool {
1082        self.context.is_some() && self.device.is_some()
1083    }
1084
1085    /// Fallback to CPU simulation
1086    pub fn fallback_to_cpu(&mut self, num_qubits: usize) -> Result<()> {
1087        if self.config.enable_cpu_fallback {
1088            self.cpu_fallback = Some(StateVectorSimulator::new());
1089            self.stats.cpu_fallback_count += 1;
1090            Ok(())
1091        } else {
1092            Err(SimulatorError::OperationNotSupported(
1093                "CPU fallback disabled".to_string(),
1094            ))
1095        }
1096    }
1097}
1098
/// Kernel argument types
///
/// One value is supplied per kernel parameter, in declaration order.
/// `PartialEq` is derived so argument lists can be compared directly.
/// `Eq` is not derived because of the floating-point variants.
#[derive(Debug, Clone, PartialEq)]
pub enum KernelArg {
    /// A buffer, referenced by the name it was created under (e.g. `"state"`).
    Buffer(String),
    /// A constant buffer, referenced by name (e.g. `"gate_matrix"`).
    ConstantBuffer(String),
    /// A 32-bit signed integer scalar, such as a qubit index or qubit count.
    Int(i32),
    /// A single-precision floating-point scalar.
    Float(f32),
    /// A double-precision floating-point scalar.
    Double(f64),
    /// A local-memory allocation of the given size (callers pass
    /// `work_group_size * 8`, presumably bytes).
    LocalMemory(usize),
}
1109
1110/// Benchmark AMD OpenCL backend performance
1111pub fn benchmark_amd_opencl_backend() -> Result<HashMap<String, f64>> {
1112    let mut results = HashMap::new();
1113
1114    // Test different configurations
1115    let configs = vec![
1116        OpenCLConfig {
1117            work_group_size: 64,
1118            optimization_level: OptimizationLevel::Standard,
1119            ..Default::default()
1120        },
1121        OpenCLConfig {
1122            work_group_size: 128,
1123            optimization_level: OptimizationLevel::Aggressive,
1124            ..Default::default()
1125        },
1126        OpenCLConfig {
1127            work_group_size: 256,
1128            optimization_level: OptimizationLevel::Aggressive,
1129            ..Default::default()
1130        },
1131    ];
1132
1133    for (i, config) in configs.into_iter().enumerate() {
1134        let start = std::time::Instant::now();
1135
1136        let mut simulator = AMDOpenCLSimulator::new(config)?;
1137
1138        // Benchmark single qubit gates
1139        let single_qubit_matrix = [
1140            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
1141            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
1142            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
1143            Complex64::new(-1.0 / 2.0_f64.sqrt(), 0.0),
1144        ];
1145
1146        for num_qubits in [10, 15, 20] {
1147            simulator.create_buffer("state", (1 << num_qubits) * 16, MemoryFlags::ReadWrite)?;
1148
1149            for qubit in 0..num_qubits.min(5) {
1150                let _time = simulator.apply_single_qubit_gate_opencl(
1151                    &single_qubit_matrix,
1152                    qubit,
1153                    num_qubits,
1154                )?;
1155            }
1156        }
1157
1158        // Benchmark two qubit gates
1159        let cnot_matrix = [
1160            Complex64::new(1.0, 0.0),
1161            Complex64::new(0.0, 0.0),
1162            Complex64::new(0.0, 0.0),
1163            Complex64::new(0.0, 0.0),
1164            Complex64::new(0.0, 0.0),
1165            Complex64::new(1.0, 0.0),
1166            Complex64::new(0.0, 0.0),
1167            Complex64::new(0.0, 0.0),
1168            Complex64::new(0.0, 0.0),
1169            Complex64::new(0.0, 0.0),
1170            Complex64::new(0.0, 0.0),
1171            Complex64::new(1.0, 0.0),
1172            Complex64::new(0.0, 0.0),
1173            Complex64::new(0.0, 0.0),
1174            Complex64::new(1.0, 0.0),
1175            Complex64::new(0.0, 0.0),
1176        ];
1177
1178        for num_qubits in [10usize, 15, 20] {
1179            for pair in 0..num_qubits.saturating_sub(1).min(3) {
1180                let _time = simulator.apply_two_qubit_gate_opencl(
1181                    &cnot_matrix,
1182                    pair,
1183                    pair + 1,
1184                    num_qubits,
1185                )?;
1186            }
1187        }
1188
1189        // Benchmark expectation values
1190        for num_qubits in [10, 15, 20] {
1191            let _result = simulator.compute_expectation_value_opencl(0b1010, num_qubits)?;
1192        }
1193
1194        let time = start.elapsed().as_secs_f64() * 1000.0;
1195        results.insert(format!("config_{}", i), time);
1196
1197        // Add performance metrics
1198        let stats = simulator.get_stats();
1199        results.insert(
1200            format!("config_{}_gate_ops", i),
1201            stats.gate_operations as f64,
1202        );
1203        results.insert(
1204            format!("config_{}_avg_kernel_time", i),
1205            stats.avg_kernel_time,
1206        );
1207        results.insert(
1208            format!("config_{}_gpu_utilization", i),
1209            stats.gpu_utilization,
1210        );
1211    }
1212
1213    Ok(results)
1214}
1215
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Build a simulator from `config`, panicking when construction fails.
    fn simulator_with(config: OpenCLConfig) -> AMDOpenCLSimulator {
        AMDOpenCLSimulator::new(config).unwrap()
    }

    /// Build a simulator from the default configuration.
    fn default_simulator() -> AMDOpenCLSimulator {
        simulator_with(OpenCLConfig::default())
    }

    /// Row-major Hadamard matrix.
    fn hadamard() -> [Complex64; 4] {
        let h = 1.0 / 2.0_f64.sqrt();
        [
            Complex64::new(h, 0.0),
            Complex64::new(h, 0.0),
            Complex64::new(h, 0.0),
            Complex64::new(-h, 0.0),
        ]
    }

    /// Row-major CNOT matrix.
    #[rustfmt::skip]
    fn cnot() -> [Complex64; 16] {
        let o = Complex64::new(1.0, 0.0);
        let z = Complex64::new(0.0, 0.0);
        [
            o, z, z, z,
            z, o, z, z,
            z, z, z, o,
            z, z, o, z,
        ]
    }

    /// Register a 1024-amplitude `state` buffer (16 bytes per amplitude).
    fn add_state_buffer(simulator: &mut AMDOpenCLSimulator) {
        simulator
            .create_buffer("state", 1024 * 16, MemoryFlags::ReadWrite)
            .unwrap();
    }

    #[test]
    fn test_opencl_simulator_creation() {
        let outcome = AMDOpenCLSimulator::new(OpenCLConfig::default());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_platform_discovery() {
        let simulator = default_simulator();
        let platforms = simulator.discover_platforms().unwrap();

        assert!(!platforms.is_empty());
        let has_amd_platform = platforms
            .iter()
            .any(|platform| platform.vendor.contains("Advanced Micro Devices"));
        assert!(has_amd_platform);
    }

    #[test]
    fn test_device_discovery() {
        let simulator = default_simulator();
        let devices = simulator.discover_devices().unwrap();

        assert!(!devices.is_empty());
        let has_gpu = devices
            .iter()
            .any(|device| device.device_type == OpenCLDeviceType::GPU);
        assert!(has_gpu);
    }

    #[test]
    fn test_kernel_creation() {
        let simulator = default_simulator();

        let expected_kernels = [
            "single_qubit_gate",
            "two_qubit_gate",
            "state_vector_ops",
            "measurement",
            "expectation_value",
        ];
        for kernel_name in expected_kernels.iter() {
            assert!(simulator.kernels.contains_key(*kernel_name));
        }
    }

    #[test]
    fn test_buffer_creation() {
        let mut simulator = default_simulator();

        let outcome = simulator.create_buffer("test_buffer", 1024, MemoryFlags::ReadWrite);

        assert!(outcome.is_ok());
        assert!(simulator.buffers.contains_key("test_buffer"));
        assert_eq!(simulator.stats.gpu_memory_usage, 1024);
    }

    #[test]
    fn test_buffer_size_limit() {
        let mut simulator = simulator_with(OpenCLConfig {
            max_buffer_size: 512,
            ..Default::default()
        });

        let outcome = simulator.create_buffer("large_buffer", 1024, MemoryFlags::ReadWrite);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_kernel_execution() {
        let mut simulator = default_simulator();

        let work_items = [256usize];
        let kernel_args = [
            KernelArg::Buffer("state".to_string()),
            KernelArg::Int(0),
            KernelArg::Int(8),
        ];

        let outcome =
            simulator.execute_kernel("single_qubit_gate", &work_items, None, &kernel_args);
        assert!(outcome.is_ok());

        let execution_time = outcome.unwrap();
        assert!(execution_time > 0.0);
    }

    #[test]
    fn test_single_qubit_gate_application() {
        let mut simulator = default_simulator();
        add_state_buffer(&mut simulator);

        let outcome = simulator.apply_single_qubit_gate_opencl(&hadamard(), 0, 8);
        assert!(outcome.is_ok());

        let execution_time = outcome.unwrap();
        assert!(execution_time > 0.0);
    }

    #[test]
    fn test_two_qubit_gate_application() {
        let mut simulator = default_simulator();
        add_state_buffer(&mut simulator);

        let outcome = simulator.apply_two_qubit_gate_opencl(&cnot(), 0, 1, 8);
        assert!(outcome.is_ok());

        let execution_time = outcome.unwrap();
        assert!(execution_time > 0.0);
    }

    #[test]
    fn test_expectation_value_computation() {
        let mut simulator = default_simulator();
        add_state_buffer(&mut simulator);
        simulator
            .create_buffer("partial_results", 64 * 8, MemoryFlags::ReadWrite)
            .unwrap();

        let outcome = simulator.compute_expectation_value_opencl(0b1010, 8);
        assert!(outcome.is_ok());

        let (expectation, execution_time) = outcome.unwrap();
        assert!((-1.0..=1.0).contains(&expectation));
        assert!(execution_time > 0.0);
    }

    #[test]
    fn test_build_options() {
        let simulator = simulator_with(OpenCLConfig {
            optimization_level: OptimizationLevel::Aggressive,
            ..Default::default()
        });

        let build_options = simulator.get_build_options();
        for expected_flag in ["-O3", "-cl-mad-enable", "-cl-fast-relaxed-math"].iter() {
            assert!(build_options.contains(*expected_flag));
        }
    }

    #[test]
    fn test_stats_update() {
        let mut simulator = default_simulator();

        for sample_ms in [10.0, 20.0].iter() {
            simulator.stats.update_kernel_execution(*sample_ms);
        }

        let stats = &simulator.stats;
        assert_eq!(stats.total_kernel_executions, 2);
        assert_abs_diff_eq!(stats.total_execution_time, 30.0, epsilon = 1e-10);
        assert_abs_diff_eq!(stats.avg_kernel_time, 15.0, epsilon = 1e-10);
    }

    #[test]
    fn test_performance_metrics() {
        let mut simulator = default_simulator();

        {
            let stats = &mut simulator.stats;
            stats.total_kernel_executions = 100;
            stats.total_execution_time = 1000.0; // 1 second
            stats.gpu_memory_usage = 1000000000; // 1GB
            stats.memory_transfer_time = 100.0; // 0.1 second
            stats.gpu_utilization = 85.0;
        }

        let metrics = simulator.stats.get_performance_metrics();

        for key in [
            "kernel_executions_per_second",
            "memory_bandwidth_gb_s",
            "gpu_efficiency",
        ]
        .iter()
        {
            assert!(metrics.contains_key(*key));
        }

        assert_abs_diff_eq!(
            metrics["kernel_executions_per_second"],
            100.0,
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(metrics["gpu_efficiency"], 0.85, epsilon = 1e-10);
    }

    #[test]
    fn test_cpu_fallback() {
        let mut simulator = simulator_with(OpenCLConfig {
            enable_cpu_fallback: true,
            ..Default::default()
        });

        let outcome = simulator.fallback_to_cpu(10);

        assert!(outcome.is_ok());
        assert_eq!(simulator.stats.cpu_fallback_count, 1);
        assert!(simulator.cpu_fallback.is_some());
    }

    #[test]
    fn test_device_selection() {
        let simulator = simulator_with(OpenCLConfig {
            preferred_device_type: OpenCLDeviceType::GPU,
            ..Default::default()
        });

        let selected = simulator.get_device_info().unwrap();
        assert_eq!(selected.device_type, OpenCLDeviceType::GPU);
        assert!(selected.name.contains("Radeon"));
        assert_eq!(selected.vendor, "Advanced Micro Devices, Inc.");
    }
}