quantrs2_core/gpu/
adaptive_simd.rs

//! Adaptive SIMD dispatch based on CPU capabilities detection
//!
//! This module provides runtime detection of CPU capabilities and dispatches
//! to the most optimized SIMD implementation available on the target hardware.
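//!
//! # Example
//!
//! A minimal usage sketch (illustrative only; the crate path is an
//! assumption, so the example is not compiled as a doc test):
//!
//! ```rust,ignore
//! use quantrs2_core::gpu::adaptive_simd::{
//!     apply_single_qubit_adaptive, initialize_adaptive_simd,
//! };
//! use scirs2_core::Complex64;
//!
//! // Detect CPU features once, then dispatch adaptively.
//! let _ = initialize_adaptive_simd();
//!
//! // |0> state; apply a Pauli-X gate to qubit 0.
//! let mut state = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
//! let x = [
//!     Complex64::new(0.0, 0.0),
//!     Complex64::new(1.0, 0.0),
//!     Complex64::new(1.0, 0.0),
//!     Complex64::new(0.0, 0.0),
//! ];
//! apply_single_qubit_adaptive(&mut state, 0, &x).unwrap();
//! assert!((state[1].re - 1.0).abs() < 1e-10);
//! ```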

use crate::error::{QuantRS2Error, QuantRS2Result};
use crate::platform::PlatformCapabilities;
use scirs2_core::Complex64;
use std::sync::{Mutex, OnceLock};
// use scirs2_core::simd_ops::SimdUnifiedOps;
use crate::simd_ops_stubs::SimdF64;
use scirs2_core::ndarray::ArrayView1;

/// CPU feature detection results
#[derive(Debug, Clone, Copy)]
pub struct CpuFeatures {
    /// AVX2 support (256-bit vectors)
    pub has_avx2: bool,
    /// AVX-512 support (512-bit vectors)
    pub has_avx512: bool,
    /// FMA (Fused Multiply-Add) support
    pub has_fma: bool,
    /// AVX-512 VL (Vector Length) support
    pub has_avx512vl: bool,
    /// AVX-512 DQ (Doubleword and Quadword) support
    pub has_avx512dq: bool,
    /// AVX-512 CD (Conflict Detection) support
    pub has_avx512cd: bool,
    /// SSE 4.1 support
    pub has_sse41: bool,
    /// SSE 4.2 support
    pub has_sse42: bool,
    /// Number of CPU cores
    pub num_cores: usize,
    /// L1 cache size per core (in bytes)
    pub l1_cache_size: usize,
    /// L2 cache size per core (in bytes)
    pub l2_cache_size: usize,
    /// L3 cache size (in bytes)
    pub l3_cache_size: usize,
}

/// SIMD implementation variants
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdVariant {
    /// Scalar fallback implementation
    Scalar,
    /// SSE 4.1/4.2 implementation
    Sse4,
    /// AVX2 implementation (256-bit)
    Avx2,
    /// AVX-512 implementation (512-bit)
    Avx512,
}

/// Adaptive SIMD dispatcher
pub struct AdaptiveSimdDispatcher {
    /// Detected CPU features
    cpu_features: CpuFeatures,
    /// Selected SIMD variant
    selected_variant: SimdVariant,
    /// Performance cache for different operation sizes
    performance_cache: Mutex<std::collections::HashMap<String, PerformanceData>>,
}

/// Performance data for SIMD operations
#[derive(Debug, Clone)]
pub struct PerformanceData {
    /// Average execution time (nanoseconds)
    avg_time: f64,
    /// Number of samples
    samples: usize,
    /// Best SIMD variant for this operation size
    best_variant: SimdVariant,
}

/// Global dispatcher instance
static GLOBAL_DISPATCHER: OnceLock<AdaptiveSimdDispatcher> = OnceLock::new();

impl AdaptiveSimdDispatcher {
    /// Initialize the global adaptive SIMD dispatcher
    pub fn initialize() -> QuantRS2Result<()> {
        let cpu_features = Self::detect_cpu_features();
        let selected_variant = Self::select_optimal_variant(&cpu_features);

        let dispatcher = Self {
            cpu_features,
            selected_variant,
            performance_cache: Mutex::new(std::collections::HashMap::new()),
        };

        GLOBAL_DISPATCHER.set(dispatcher).map_err(|_| {
            QuantRS2Error::RuntimeError("Adaptive SIMD dispatcher already initialized".to_string())
        })?;

        Ok(())
    }

    /// Get the global dispatcher instance
    pub fn instance() -> QuantRS2Result<&'static Self> {
        GLOBAL_DISPATCHER.get().ok_or_else(|| {
            QuantRS2Error::RuntimeError("Adaptive SIMD dispatcher not initialized".to_string())
        })
    }

    /// Detect CPU features at runtime
    fn detect_cpu_features() -> CpuFeatures {
        let platform = PlatformCapabilities::detect();

        CpuFeatures {
            has_avx2: platform.cpu.simd.avx2,
            has_avx512: platform.cpu.simd.avx512,
            has_fma: platform.cpu.simd.fma,
            has_avx512vl: false, // Not detected in current platform capabilities
            has_avx512dq: false, // Not detected in current platform capabilities
            has_avx512cd: false, // Not detected in current platform capabilities
            has_sse41: platform.cpu.simd.sse4_1,
            has_sse42: platform.cpu.simd.sse4_2,
            num_cores: platform.cpu.logical_cores,
            l1_cache_size: platform.cpu.cache.l1_data.unwrap_or(32 * 1024),
            l2_cache_size: platform.cpu.cache.l2.unwrap_or(256 * 1024),
            l3_cache_size: platform.cpu.cache.l3.unwrap_or(8 * 1024 * 1024),
        }
    }

    /// Select the optimal SIMD variant based on CPU features
    const fn select_optimal_variant(features: &CpuFeatures) -> SimdVariant {
        // Note: AVX-512 is only chosen when the VL and DQ sub-features are
        // also present; detect_cpu_features currently never sets those flags,
        // so this cascade falls through to AVX2 with today's detection.
        if features.has_avx512 && features.has_avx512vl && features.has_avx512dq {
            SimdVariant::Avx512
        } else if features.has_avx2 && features.has_fma {
            SimdVariant::Avx2
        } else if features.has_sse41 && features.has_sse42 {
            SimdVariant::Sse4
        } else {
            SimdVariant::Scalar
        }
    }

    /// Apply a single-qubit gate with adaptive SIMD
    pub fn apply_single_qubit_gate_adaptive(
        &self,
        state: &mut [Complex64],
        target: usize,
        matrix: &[Complex64; 4],
    ) -> QuantRS2Result<()> {
        let operation_key = format!("single_qubit_{}", state.len());
        let variant = self.select_variant_for_operation(&operation_key, state.len());

        let start_time = std::time::Instant::now();

        let result = match variant {
            SimdVariant::Avx512 | SimdVariant::Avx2 | SimdVariant::Sse4 => {
                // All SIMD variants currently share the SciRS2 unified path,
                // which selects the widest instruction set available at runtime.
                self.apply_single_qubit_sse4(state, target, matrix)
            }
            SimdVariant::Scalar => self.apply_single_qubit_scalar(state, target, matrix),
        };

        let execution_time = start_time.elapsed().as_nanos() as f64;
        self.update_performance_cache(&operation_key, execution_time, variant);

        result
    }

    /// Apply a two-qubit gate with adaptive SIMD
    pub fn apply_two_qubit_gate_adaptive(
        &self,
        state: &mut [Complex64],
        control: usize,
        target: usize,
        matrix: &[Complex64; 16],
    ) -> QuantRS2Result<()> {
        let operation_key = format!("two_qubit_{}", state.len());
        let variant = self.select_variant_for_operation(&operation_key, state.len());

        let start_time = std::time::Instant::now();

        let result = match variant {
            SimdVariant::Avx512 => self.apply_two_qubit_avx512(state, control, target, matrix),
            SimdVariant::Avx2 => self.apply_two_qubit_avx2(state, control, target, matrix),
            SimdVariant::Sse4 => self.apply_two_qubit_sse4(state, control, target, matrix),
            SimdVariant::Scalar => self.apply_two_qubit_scalar(state, control, target, matrix),
        };

        let execution_time = start_time.elapsed().as_nanos() as f64;
        self.update_performance_cache(&operation_key, execution_time, variant);

        result
    }

    /// Batch apply gates with adaptive SIMD
    pub fn apply_batch_gates_adaptive(
        &self,
        states: &mut [&mut [Complex64]],
        gates: &[Box<dyn crate::gate::GateOp>],
    ) -> QuantRS2Result<()> {
        let batch_size = states.len();
        let operation_key = format!("batch_{}_{}", batch_size, gates.len());
        let variant = self.select_variant_for_operation(&operation_key, batch_size * 1000); // Estimate

        let start_time = std::time::Instant::now();

        let result = match variant {
            SimdVariant::Avx512 => self.apply_batch_gates_avx512(states, gates),
            SimdVariant::Avx2 => self.apply_batch_gates_avx2(states, gates),
            SimdVariant::Sse4 => self.apply_batch_gates_sse4(states, gates),
            SimdVariant::Scalar => self.apply_batch_gates_scalar(states, gates),
        };

        let execution_time = start_time.elapsed().as_nanos() as f64;
        self.update_performance_cache(&operation_key, execution_time, variant);

        result
    }

    /// Select the best SIMD variant for a specific operation
    fn select_variant_for_operation(&self, operation_key: &str, data_size: usize) -> SimdVariant {
        // Check performance cache first
        if let Ok(cache) = self.performance_cache.lock() {
            if let Some(perf_data) = cache.get(operation_key) {
                if perf_data.samples >= 5 {
                    return perf_data.best_variant;
                }
            }
        }

        // Heuristics based on data size and CPU features: small states do not
        // amortize the setup overhead of wide vectors, so wider instruction
        // sets are reserved for progressively larger state vectors.
        if data_size >= 1024 && self.cpu_features.has_avx512 {
            SimdVariant::Avx512
        } else if data_size >= 256 && self.cpu_features.has_avx2 {
            SimdVariant::Avx2
        } else if data_size >= 64 && self.cpu_features.has_sse41 {
            SimdVariant::Sse4
        } else {
            SimdVariant::Scalar
        }
    }

    /// Update performance cache with execution time.
    ///
    /// Maintains a running average, `avg_new = (avg_old * n + t) / (n + 1)`,
    /// and re-selects the best variant when a sample beats the updated
    /// average by more than 10%.
    fn update_performance_cache(
        &self,
        operation_key: &str,
        execution_time: f64,
        variant: SimdVariant,
    ) {
        if let Ok(mut cache) = self.performance_cache.lock() {
            let perf_data = cache
                .entry(operation_key.to_string())
                .or_insert_with(|| PerformanceData {
                    avg_time: execution_time,
                    samples: 0,
                    best_variant: variant,
                });

            // Update running average
            perf_data.avg_time = perf_data
                .avg_time
                .mul_add(perf_data.samples as f64, execution_time)
                / (perf_data.samples + 1) as f64;
            perf_data.samples += 1;

            // Update best variant if this sample is significantly faster than
            // the running average
            if execution_time < perf_data.avg_time * 0.9 {
                perf_data.best_variant = variant;
            }
        }
    }

    /// Get performance report
    pub fn get_performance_report(&self) -> AdaptivePerformanceReport {
        let cache = self
            .performance_cache
            .lock()
            .map(|cache| cache.clone())
            .unwrap_or_default();

        AdaptivePerformanceReport {
            cpu_features: self.cpu_features,
            selected_variant: self.selected_variant,
            performance_cache: cache,
        }
    }

    // SIMD implementation methods (simplified placeholders)

    #[cfg(target_arch = "x86_64")]
    fn apply_single_qubit_avx512(
        &self,
        state: &mut [Complex64],
        target: usize,
        matrix: &[Complex64; 4],
    ) -> QuantRS2Result<()> {
        // AVX-512 implementation using SciRS2 SIMD operations
        // SciRS2 will automatically use AVX-512 if available
        self.apply_single_qubit_simd_unified(state, target, matrix)
    }

    #[cfg(target_arch = "x86_64")]
    fn apply_single_qubit_avx2(
        &self,
        state: &mut [Complex64],
        target: usize,
        matrix: &[Complex64; 4],
    ) -> QuantRS2Result<()> {
        // AVX2 implementation using SciRS2 SIMD operations
        // SciRS2 will automatically use AVX2 if available
        self.apply_single_qubit_simd_unified(state, target, matrix)
    }

    fn apply_single_qubit_sse4(
        &self,
        state: &mut [Complex64],
        target: usize,
        matrix: &[Complex64; 4],
    ) -> QuantRS2Result<()> {
        // SSE4 implementation using SciRS2 SIMD operations
        // SciRS2 will automatically use SSE4 if available
        self.apply_single_qubit_simd_unified(state, target, matrix)
    }

    fn apply_single_qubit_scalar(
        &self,
        state: &mut [Complex64],
        target: usize,
        matrix: &[Complex64; 4],
    ) -> QuantRS2Result<()> {
        // Scalar implementation: visit each basis state with the target bit
        // clear and pair it with the partner that has the target bit set.
        // For example, with target = 1 and n = 8 the pairs are
        // (0,2), (1,3), (4,6), (5,7).
        let n = state.len();
        for i in 0..n {
            if (i >> target) & 1 == 0 {
                let j = i | (1 << target);
                let temp0 = state[i];
                let temp1 = state[j];
                state[i] = matrix[0] * temp0 + matrix[1] * temp1;
                state[j] = matrix[2] * temp0 + matrix[3] * temp1;
            }
        }
        Ok(())
    }

    /// Apply single-qubit gate using SciRS2 unified SIMD operations
    fn apply_single_qubit_simd_unified(
        &self,
        state: &mut [Complex64],
        target: usize,
        matrix: &[Complex64; 4],
    ) -> QuantRS2Result<()> {
        let qubit_mask = 1usize << target;
        let low_mask = qubit_mask - 1;
        let half_size = state.len() / 2;

        // Collect pairs of indices that need to be processed. The i-th pair is
        // formed by inserting a 0 (idx0) or 1 (idx1) at bit position `target`:
        // the bits of i below `target` stay in place and the rest shift up one.
        let mut idx0_list = Vec::with_capacity(half_size);
        let mut idx1_list = Vec::with_capacity(half_size);

        for i in 0..half_size {
            let idx0 = ((i & !low_mask) << 1) | (i & low_mask);
            let idx1 = idx0 | qubit_mask;

            if idx1 < state.len() {
                idx0_list.push(idx0);
                idx1_list.push(idx1);
            }
        }

        let pair_count = idx0_list.len();
        if pair_count == 0 {
            return Ok(());
        }

        // Extract amplitude pairs for SIMD processing
        let mut a0_real = Vec::with_capacity(pair_count);
        let mut a0_imag = Vec::with_capacity(pair_count);
        let mut a1_real = Vec::with_capacity(pair_count);
        let mut a1_imag = Vec::with_capacity(pair_count);

        for i in 0..pair_count {
            let a0 = state[idx0_list[i]];
            let a1 = state[idx1_list[i]];
            a0_real.push(a0.re);
            a0_imag.push(a0.im);
            a1_real.push(a1.re);
            a1_imag.push(a1.im);
        }

        // Convert to array views for SciRS2 SIMD operations
        let a0_real_view = ArrayView1::from(&a0_real);
        let a0_imag_view = ArrayView1::from(&a0_imag);
        let a1_real_view = ArrayView1::from(&a1_real);
        let a1_imag_view = ArrayView1::from(&a1_imag);

        // Extract matrix elements
        let m00_re = matrix[0].re;
        let m00_im = matrix[0].im;
        let m01_re = matrix[1].re;
        let m01_im = matrix[1].im;
        let m10_re = matrix[2].re;
        let m10_im = matrix[2].im;
        let m11_re = matrix[3].re;
        let m11_im = matrix[3].im;

        // Compute new amplitudes using SciRS2 SIMD operations
        // new_a0 = m00 * a0 + m01 * a1
        // new_a1 = m10 * a0 + m11 * a1
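        //
        // Each complex product expands as (x + yi)(u + vi) = (xu - yv) + (xv + yu)i,
        // so each real/imaginary output below is assembled from four
        // scalar-vector products combined with elementwise adds and subtracts.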

        // For new_a0_real: m00_re * a0_re - m00_im * a0_im + m01_re * a1_re - m01_im * a1_im
        let term1 = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, m00_re);
        let term2 = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, m00_im);
        let term3 = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, m01_re);
        let term4 = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, m01_im);
        let sub1 = <f64 as SimdF64>::simd_sub_arrays(&term1.view(), &term2.view());
        let sub2 = <f64 as SimdF64>::simd_sub_arrays(&term3.view(), &term4.view());
        let new_a0_real_arr = <f64 as SimdF64>::simd_add_arrays(&sub1.view(), &sub2.view());

        // For new_a0_imag: m00_re * a0_im + m00_im * a0_re + m01_re * a1_im + m01_im * a1_re
        let term5 = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, m00_re);
        let term6 = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, m00_im);
        let term7 = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, m01_re);
        let term8 = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, m01_im);
        let add1 = <f64 as SimdF64>::simd_add_arrays(&term5.view(), &term6.view());
        let add2 = <f64 as SimdF64>::simd_add_arrays(&term7.view(), &term8.view());
        let new_a0_imag_arr = <f64 as SimdF64>::simd_add_arrays(&add1.view(), &add2.view());

        // For new_a1_real: m10_re * a0_re - m10_im * a0_im + m11_re * a1_re - m11_im * a1_im
        let term9 = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, m10_re);
        let term10 = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, m10_im);
        let term11 = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, m11_re);
        let term12 = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, m11_im);
        let sub3 = <f64 as SimdF64>::simd_sub_arrays(&term9.view(), &term10.view());
        let sub4 = <f64 as SimdF64>::simd_sub_arrays(&term11.view(), &term12.view());
        let new_a1_real_arr = <f64 as SimdF64>::simd_add_arrays(&sub3.view(), &sub4.view());

        // For new_a1_imag: m10_re * a0_im + m10_im * a0_re + m11_re * a1_im + m11_im * a1_re
        let term13 = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, m10_re);
        let term14 = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, m10_im);
        let term15 = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, m11_re);
        let term16 = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, m11_im);
        let add3 = <f64 as SimdF64>::simd_add_arrays(&term13.view(), &term14.view());
        let add4 = <f64 as SimdF64>::simd_add_arrays(&term15.view(), &term16.view());
        let new_a1_imag_arr = <f64 as SimdF64>::simd_add_arrays(&add3.view(), &add4.view());

        // Write back results
        for i in 0..pair_count {
            state[idx0_list[i]] = Complex64::new(new_a0_real_arr[i], new_a0_imag_arr[i]);
            state[idx1_list[i]] = Complex64::new(new_a1_real_arr[i], new_a1_imag_arr[i]);
        }

        Ok(())
    }

    // Similar implementations for two-qubit gates and batch operations

    const fn apply_two_qubit_avx512(
        &self,
        _state: &mut [Complex64],
        _control: usize,
        _target: usize,
        _matrix: &[Complex64; 16],
    ) -> QuantRS2Result<()> {
        // Placeholder
        Ok(())
    }

    const fn apply_two_qubit_avx2(
        &self,
        _state: &mut [Complex64],
        _control: usize,
        _target: usize,
        _matrix: &[Complex64; 16],
    ) -> QuantRS2Result<()> {
        // Placeholder
        Ok(())
    }

    const fn apply_two_qubit_sse4(
        &self,
        _state: &mut [Complex64],
        _control: usize,
        _target: usize,
        _matrix: &[Complex64; 16],
    ) -> QuantRS2Result<()> {
        // Placeholder
        Ok(())
    }

    const fn apply_two_qubit_scalar(
        &self,
        _state: &mut [Complex64],
        _control: usize,
        _target: usize,
        _matrix: &[Complex64; 16],
    ) -> QuantRS2Result<()> {
        // Placeholder
        Ok(())
    }

    fn apply_batch_gates_avx512(
        &self,
        _states: &mut [&mut [Complex64]],
        _gates: &[Box<dyn crate::gate::GateOp>],
    ) -> QuantRS2Result<()> {
        // Placeholder
        Ok(())
    }

    fn apply_batch_gates_avx2(
        &self,
        _states: &mut [&mut [Complex64]],
        _gates: &[Box<dyn crate::gate::GateOp>],
    ) -> QuantRS2Result<()> {
        // Placeholder
        Ok(())
    }

    fn apply_batch_gates_sse4(
        &self,
        _states: &mut [&mut [Complex64]],
        _gates: &[Box<dyn crate::gate::GateOp>],
    ) -> QuantRS2Result<()> {
        // Placeholder
        Ok(())
    }

    fn apply_batch_gates_scalar(
        &self,
        _states: &mut [&mut [Complex64]],
        _gates: &[Box<dyn crate::gate::GateOp>],
    ) -> QuantRS2Result<()> {
        // Placeholder
        Ok(())
    }
}

/// Performance report for adaptive SIMD
#[derive(Debug, Clone)]
pub struct AdaptivePerformanceReport {
    pub cpu_features: CpuFeatures,
    pub selected_variant: SimdVariant,
    pub performance_cache: std::collections::HashMap<String, PerformanceData>,
}

/// Convenience functions for adaptive SIMD operations
pub fn apply_single_qubit_adaptive(
    state: &mut [Complex64],
    target: usize,
    matrix: &[Complex64; 4],
) -> QuantRS2Result<()> {
    AdaptiveSimdDispatcher::instance()?.apply_single_qubit_gate_adaptive(state, target, matrix)
}

pub fn apply_two_qubit_adaptive(
    state: &mut [Complex64],
    control: usize,
    target: usize,
    matrix: &[Complex64; 16],
) -> QuantRS2Result<()> {
    AdaptiveSimdDispatcher::instance()?
        .apply_two_qubit_gate_adaptive(state, control, target, matrix)
}

pub fn apply_batch_gates_adaptive(
    states: &mut [&mut [Complex64]],
    gates: &[Box<dyn crate::gate::GateOp>],
) -> QuantRS2Result<()> {
    AdaptiveSimdDispatcher::instance()?.apply_batch_gates_adaptive(states, gates)
}

/// Initialize the adaptive SIMD system
pub fn initialize_adaptive_simd() -> QuantRS2Result<()> {
    AdaptiveSimdDispatcher::initialize()
}

/// Get the performance report
pub fn get_adaptive_performance_report() -> QuantRS2Result<AdaptivePerformanceReport> {
    Ok(AdaptiveSimdDispatcher::instance()?.get_performance_report())
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    #[test]
    fn test_cpu_feature_detection() {
        let features = AdaptiveSimdDispatcher::detect_cpu_features();
        println!("Detected CPU features: {:?}", features);

        // Basic sanity checks
        assert!(features.num_cores >= 1);
        assert!(features.l1_cache_size > 0);
    }

    #[test]
    fn test_simd_variant_selection() {
        let features = CpuFeatures {
            has_avx2: true,
            has_avx512: false,
            has_fma: true,
            has_avx512vl: false,
            has_avx512dq: false,
            has_avx512cd: false,
            has_sse41: true,
            has_sse42: true,
            num_cores: 8,
            l1_cache_size: 32768,
            l2_cache_size: 262144,
            l3_cache_size: 8388608,
        };

        let variant = AdaptiveSimdDispatcher::select_optimal_variant(&features);
        assert_eq!(variant, SimdVariant::Avx2);
    }

    #[test]
    fn test_adaptive_single_qubit_gate() {
        let _ = AdaptiveSimdDispatcher::initialize();

        let mut state = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];

        let hadamard_matrix = [
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex64::new(-1.0 / 2.0_f64.sqrt(), 0.0),
        ];

        let result = apply_single_qubit_adaptive(&mut state, 0, &hadamard_matrix);
        assert!(result.is_ok());

        // Check that the state has been modified
        let expected_amplitude = 1.0 / 2.0_f64.sqrt();
        assert!((state[0].re - expected_amplitude).abs() < 1e-10);
        assert!((state[1].re - expected_amplitude).abs() < 1e-10);
    }
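
    // An added sketch of a test for the size heuristics in
    // `select_variant_for_operation`; the dispatcher is built by hand so the
    // expected variants do not depend on the host CPU.
    #[test]
    fn test_variant_size_heuristics() {
        let dispatcher = AdaptiveSimdDispatcher {
            cpu_features: CpuFeatures {
                has_avx2: true,
                has_avx512: false,
                has_fma: true,
                has_avx512vl: false,
                has_avx512dq: false,
                has_avx512cd: false,
                has_sse41: true,
                has_sse42: true,
                num_cores: 8,
                l1_cache_size: 32768,
                l2_cache_size: 262144,
                l3_cache_size: 8388608,
            },
            selected_variant: SimdVariant::Avx2,
            performance_cache: Mutex::new(std::collections::HashMap::new()),
        };

        // Below 64 elements the dispatcher stays scalar; AVX2 needs >= 256.
        assert_eq!(
            dispatcher.select_variant_for_operation("tiny", 16),
            SimdVariant::Scalar
        );
        assert_eq!(
            dispatcher.select_variant_for_operation("medium", 128),
            SimdVariant::Sse4
        );
        assert_eq!(
            dispatcher.select_variant_for_operation("large", 512),
            SimdVariant::Avx2
        );
    }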

    #[test]
    fn test_performance_caching() {
        let dispatcher = AdaptiveSimdDispatcher {
            cpu_features: AdaptiveSimdDispatcher::detect_cpu_features(),
            selected_variant: SimdVariant::Avx2,
            performance_cache: Mutex::new(std::collections::HashMap::new()),
        };

        dispatcher.update_performance_cache("test_op", 100.0, SimdVariant::Avx2);
        dispatcher.update_performance_cache("test_op", 150.0, SimdVariant::Avx2);

        let perf_data = dispatcher
            .performance_cache
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .get("test_op")
            .expect("Performance data for 'test_op' should exist after updates")
            .clone();
        assert_eq!(perf_data.samples, 2);
        assert!((perf_data.avg_time - 125.0).abs() < 1e-10);
    }
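
    // An added regression sketch for the pair-index computation in
    // `apply_single_qubit_simd_unified`, exercised directly on a two-qubit
    // state (the dispatcher is built by hand to avoid the global instance).
    #[test]
    fn test_simd_unified_pair_indexing() {
        let dispatcher = AdaptiveSimdDispatcher {
            cpu_features: AdaptiveSimdDispatcher::detect_cpu_features(),
            selected_variant: SimdVariant::Scalar,
            performance_cache: Mutex::new(std::collections::HashMap::new()),
        };

        // Start in |01> (qubit 0 set); Pauli-X on qubit 1 should give |11>.
        let mut state = vec![Complex64::new(0.0, 0.0); 4];
        state[1] = Complex64::new(1.0, 0.0);
        let x_matrix = [
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];

        dispatcher
            .apply_single_qubit_simd_unified(&mut state, 1, &x_matrix)
            .expect("unified SIMD path should succeed");

        assert!((state[3].re - 1.0).abs() < 1e-10);
        assert!(state[1].norm() < 1e-10);
    }
}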
656}