quantrs2_sim/
scirs2_qft.rs

1//! SciRS2-optimized Quantum Fourier Transform implementation.
2//!
3//! This module provides quantum Fourier transform (QFT) operations optimized
4//! using `SciRS2`'s Fast Fourier Transform capabilities. It includes both exact
5//! and approximate QFT implementations with fallback routines when `SciRS2` is
6//! not available.
7
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
9use scirs2_core::Complex64;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13use crate::dynamic::DynamicCircuit;
14use crate::error::{Result, SimulatorError};
15use crate::scirs2_integration::SciRS2Backend;
16use crate::statevector::StateVectorSimulator;
17
/// Strategy used to carry out the quantum Fourier transform.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QFTMethod {
    /// Exact transform through the `SciRS2` FFT path.
    SciRS2Exact,
    /// `SciRS2` FFT path preceded by small-amplitude truncation.
    SciRS2Approximate,
    /// Gate-by-gate simulation of the QFT circuit (Hadamard + controlled-phase gates).
    Circuit,
    /// Pure-Rust radix-2 FFT emulation; also used when no `SciRS2` backend is attached.
    Classical,
}

/// Tunable parameters controlling how a QFT is executed.
#[derive(Clone, Debug)]
pub struct QFTConfig {
    /// Which implementation strategy to run.
    pub method: QFTMethod,
    /// Truncation strength for approximate methods; `0` disables truncation.
    pub approximation_level: usize,
    /// Apply a bit-reversal permutation after the forward transform and before the inverse.
    pub bit_reversal: bool,
    /// Request parallel execution.
    ///
    /// NOTE(review): this flag is not currently read by the QFT code paths in this module.
    pub parallel: bool,
    /// Base magnitude threshold for approximate methods (scaled by `10^approximation_level`).
    pub precision_threshold: f64,
}

impl Default for QFTConfig {
    /// Exact `SciRS2` method, no truncation, bit reversal and parallel flag enabled,
    /// and a `1e-10` precision threshold.
    fn default() -> Self {
        let method = QFTMethod::SciRS2Exact;
        Self {
            method,
            precision_threshold: 1e-10,
            approximation_level: 0,
            parallel: true,
            bit_reversal: true,
        }
    }
}
57
58/// QFT execution statistics
59#[derive(Debug, Clone, Default, Serialize, Deserialize)]
60pub struct QFTStats {
61    /// Execution time in milliseconds
62    pub execution_time_ms: f64,
63    /// Memory usage in bytes
64    pub memory_usage_bytes: usize,
65    /// Number of FFT operations performed
66    pub fft_operations: usize,
67    /// Approximation error (if applicable)
68    pub approximation_error: f64,
69    /// Number of circuit gates (for circuit method)
70    pub circuit_gates: usize,
71    /// Method used for execution
72    pub method_used: String,
73}
74
75/// SciRS2-optimized Quantum Fourier Transform
76pub struct SciRS2QFT {
77    /// Number of qubits
78    num_qubits: usize,
79    /// `SciRS2` backend
80    backend: Option<SciRS2Backend>,
81    /// Configuration
82    config: QFTConfig,
83    /// Execution statistics
84    stats: QFTStats,
85    /// Precomputed twiddle factors
86    twiddle_cache: HashMap<usize, Array1<Complex64>>,
87}
88
89impl SciRS2QFT {
90    /// Create new `SciRS2` QFT instance
91    pub fn new(num_qubits: usize, config: QFTConfig) -> Result<Self> {
92        Ok(Self {
93            num_qubits,
94            backend: None,
95            config,
96            stats: QFTStats::default(),
97            twiddle_cache: HashMap::new(),
98        })
99    }
100
101    /// Initialize with `SciRS2` backend
102    pub fn with_backend(mut self) -> Result<Self> {
103        self.backend = Some(SciRS2Backend::new());
104        Ok(self)
105    }
106
107    /// Apply forward QFT to state vector
108    pub fn apply_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
109        let start_time = std::time::Instant::now();
110
111        if state.len() != 1 << self.num_qubits {
112            return Err(SimulatorError::DimensionMismatch(format!(
113                "State vector length {} doesn't match 2^{} qubits",
114                state.len(),
115                self.num_qubits
116            )));
117        }
118
119        match self.config.method {
120            QFTMethod::SciRS2Exact => self.apply_scirs2_exact_qft(state)?,
121            QFTMethod::SciRS2Approximate => self.apply_scirs2_approximate_qft(state)?,
122            QFTMethod::Circuit => self.apply_circuit_qft(state)?,
123            QFTMethod::Classical => self.apply_classical_qft(state)?,
124        }
125
126        // Apply bit reversal if requested
127        if self.config.bit_reversal {
128            self.apply_bit_reversal(state)?;
129        }
130
131        self.stats.execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
132        self.stats.memory_usage_bytes = state.len() * std::mem::size_of::<Complex64>();
133
134        Ok(())
135    }
136
137    /// Apply inverse QFT to state vector
138    pub fn apply_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
139        let start_time = std::time::Instant::now();
140
141        // For inverse QFT, apply bit reversal first if configured
142        if self.config.bit_reversal {
143            self.apply_bit_reversal(state)?;
144        }
145
146        match self.config.method {
147            QFTMethod::SciRS2Exact => self.apply_scirs2_exact_inverse_qft(state)?,
148            QFTMethod::SciRS2Approximate => self.apply_scirs2_approximate_inverse_qft(state)?,
149            QFTMethod::Circuit => self.apply_circuit_inverse_qft(state)?,
150            QFTMethod::Classical => self.apply_classical_inverse_qft(state)?,
151        }
152
153        self.stats.execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
154
155        Ok(())
156    }
157
158    /// `SciRS2` exact QFT implementation
159    fn apply_scirs2_exact_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
160        if let Some(backend) = &mut self.backend {
161            // Use SciRS2's optimized FFT
162            let mut complex_data: Vec<Complex64> = state.to_vec();
163
164            // SciRS2 FFT call (simulated - would call actual SciRS2 FFT)
165            self.scirs2_fft_forward(&mut complex_data)?;
166
167            // Normalize by 1/sqrt(N) for quantum normalization
168            let normalization = 1.0 / (complex_data.len() as f64).sqrt();
169            for elem in &mut complex_data {
170                *elem *= normalization;
171            }
172
173            // Copy back to state
174            for (i, &val) in complex_data.iter().enumerate() {
175                state[i] = val;
176            }
177
178            self.stats.fft_operations += 1;
179            self.stats.method_used = "SciRS2Exact".to_string();
180        } else {
181            // Fallback to classical implementation
182            self.apply_classical_qft(state)?;
183        }
184
185        Ok(())
186    }
187
188    /// `SciRS2` approximate QFT implementation
189    fn apply_scirs2_approximate_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
190        if let Some(_backend) = &mut self.backend {
191            // Use SciRS2's approximate FFT with precision control
192            let mut complex_data: Vec<Complex64> = state.to_vec();
193
194            // Apply approximation based on level
195            if self.config.approximation_level > 0 {
196                self.apply_qft_approximation(&mut complex_data)?;
197            }
198
199            // SciRS2 approximate FFT
200            self.scirs2_fft_forward(&mut complex_data)?;
201
202            // Quantum normalization
203            let normalization = 1.0 / (complex_data.len() as f64).sqrt();
204            for elem in &mut complex_data {
205                *elem *= normalization;
206            }
207
208            // Copy back to state
209            for (i, &val) in complex_data.iter().enumerate() {
210                state[i] = val;
211            }
212
213            self.stats.fft_operations += 1;
214            self.stats.method_used = "SciRS2Approximate".to_string();
215        } else {
216            // Fallback to classical implementation
217            self.apply_classical_qft(state)?;
218        }
219
220        Ok(())
221    }
222
223    /// Circuit-based QFT implementation
224    fn apply_circuit_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
225        // Apply QFT gates directly to the state vector
226        for i in 0..self.num_qubits {
227            // Hadamard gate
228            self.apply_hadamard_to_state(state, i)?;
229
230            // Controlled phase gates
231            for j in (i + 1)..self.num_qubits {
232                let angle = std::f64::consts::PI / 2.0_f64.powi((j - i) as i32);
233                self.apply_controlled_phase_to_state(state, j, i, angle)?;
234            }
235        }
236
237        self.stats.circuit_gates = self.num_qubits * (self.num_qubits + 1) / 2;
238        self.stats.method_used = "Circuit".to_string();
239
240        Ok(())
241    }
242
243    /// Classical FFT fallback implementation
244    fn apply_classical_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
245        let mut temp_state = state.clone();
246
247        // Apply Cooley-Tukey FFT algorithm
248        self.cooley_tukey_fft(&mut temp_state, false)?;
249
250        // Quantum normalization
251        let normalization = 1.0 / (temp_state.len() as f64).sqrt();
252        for elem in &mut temp_state {
253            *elem *= normalization;
254        }
255
256        // Copy back
257        *state = temp_state;
258
259        self.stats.method_used = "Classical".to_string();
260
261        Ok(())
262    }
263
264    /// `SciRS2` exact inverse QFT
265    fn apply_scirs2_exact_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
266        if let Some(backend) = &mut self.backend {
267            let mut complex_data: Vec<Complex64> = state.to_vec();
268
269            // Reverse normalization
270            let normalization = (complex_data.len() as f64).sqrt();
271            for elem in &mut complex_data {
272                *elem *= normalization;
273            }
274
275            // SciRS2 inverse FFT
276            self.scirs2_fft_inverse(&mut complex_data)?;
277
278            // Copy back
279            for (i, &val) in complex_data.iter().enumerate() {
280                state[i] = val;
281            }
282
283            self.stats.fft_operations += 1;
284            self.stats.method_used = "SciRS2ExactInverse".to_string();
285        } else {
286            self.apply_classical_inverse_qft(state)?;
287        }
288
289        Ok(())
290    }
291
292    /// `SciRS2` approximate inverse QFT
293    fn apply_scirs2_approximate_inverse_qft(
294        &mut self,
295        state: &mut Array1<Complex64>,
296    ) -> Result<()> {
297        if let Some(_backend) = &mut self.backend {
298            let mut complex_data: Vec<Complex64> = state.to_vec();
299
300            // Reverse normalization
301            let normalization = (complex_data.len() as f64).sqrt();
302            for elem in &mut complex_data {
303                *elem *= normalization;
304            }
305
306            // SciRS2 inverse FFT
307            self.scirs2_fft_inverse(&mut complex_data)?;
308
309            // Apply inverse approximation if needed
310            if self.config.approximation_level > 0 {
311                self.apply_inverse_qft_approximation(&mut complex_data)?;
312            }
313
314            // Copy back
315            for (i, &val) in complex_data.iter().enumerate() {
316                state[i] = val;
317            }
318
319            self.stats.method_used = "SciRS2ApproximateInverse".to_string();
320        } else {
321            self.apply_classical_inverse_qft(state)?;
322        }
323
324        Ok(())
325    }
326
327    /// Circuit-based inverse QFT
328    fn apply_circuit_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
329        // Apply inverse QFT gates directly to the state vector
330        for i in (0..self.num_qubits).rev() {
331            // Controlled phase gates (reversed)
332            for j in ((i + 1)..self.num_qubits).rev() {
333                let angle = -std::f64::consts::PI / 2.0_f64.powi((j - i) as i32);
334                self.apply_controlled_phase_to_state(state, j, i, angle)?;
335            }
336
337            // Hadamard gate
338            self.apply_hadamard_to_state(state, i)?;
339        }
340
341        self.stats.circuit_gates = self.num_qubits * (self.num_qubits + 1) / 2;
342        self.stats.method_used = "CircuitInverse".to_string();
343
344        Ok(())
345    }
346
347    /// Classical inverse QFT
348    fn apply_classical_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
349        let mut temp_state = state.clone();
350
351        // Apply inverse Cooley-Tukey FFT
352        self.cooley_tukey_fft(&mut temp_state, true)?;
353
354        // Quantum normalization
355        let normalization = 1.0 / (temp_state.len() as f64).sqrt();
356        for elem in &mut temp_state {
357            *elem *= normalization;
358        }
359
360        *state = temp_state;
361
362        self.stats.method_used = "ClassicalInverse".to_string();
363
364        Ok(())
365    }
366
367    /// `SciRS2` forward FFT call using actual `SciRS2` backend
368    fn scirs2_fft_forward(&self, data: &mut [Complex64]) -> Result<()> {
369        if let Some(ref backend) = self.backend {
370            if backend.is_available() {
371                // Use actual SciRS2 FFT implementation
372                use crate::scirs2_integration::{SciRS2MemoryAllocator, SciRS2Vector};
373                use scirs2_core::ndarray::Array1;
374
375                let _allocator = SciRS2MemoryAllocator::new();
376                let input_array = Array1::from_vec(data.to_vec());
377                let scirs2_vector = SciRS2Vector::from_array1(input_array);
378
379                // Perform forward FFT using SciRS2 engine
380                #[cfg(feature = "advanced_math")]
381                {
382                    let result_vector =
383                        backend.fft_engine.forward(&scirs2_vector).map_err(|e| {
384                            SimulatorError::ComputationError(format!("SciRS2 FFT failed: {e}"))
385                        })?;
386
387                    // Copy result back to data
388                    let result_array = result_vector.to_array1().map_err(|e| {
389                        SimulatorError::ComputationError(format!(
390                            "Failed to extract FFT result: {e}"
391                        ))
392                    })?;
393                    // Safety: result_array is a contiguous 1D array, as_slice always succeeds
394                    data.copy_from_slice(
395                        result_array
396                            .as_slice()
397                            .expect("1D contiguous array has a valid slice"),
398                    );
399                }
400                #[cfg(not(feature = "advanced_math"))]
401                {
402                    // Fallback when advanced_math feature is not available
403                    self.radix2_fft(data, false)?;
404                }
405
406                Ok(())
407            } else {
408                // Fallback to radix-2 FFT
409                self.radix2_fft(data, false)?;
410                Ok(())
411            }
412        } else {
413            // Fallback to radix-2 FFT
414            self.radix2_fft(data, false)?;
415            Ok(())
416        }
417    }
418
419    /// `SciRS2` inverse FFT call using actual `SciRS2` backend
420    fn scirs2_fft_inverse(&self, data: &mut [Complex64]) -> Result<()> {
421        if let Some(ref backend) = self.backend {
422            if backend.is_available() {
423                // Use actual SciRS2 inverse FFT implementation
424                use crate::scirs2_integration::{SciRS2MemoryAllocator, SciRS2Vector};
425                use scirs2_core::ndarray::Array1;
426
427                let _allocator = SciRS2MemoryAllocator::new();
428                let input_array = Array1::from_vec(data.to_vec());
429                let scirs2_vector = SciRS2Vector::from_array1(input_array);
430
431                // Perform inverse FFT using SciRS2 engine
432                #[cfg(feature = "advanced_math")]
433                {
434                    let result_vector =
435                        backend.fft_engine.inverse(&scirs2_vector).map_err(|e| {
436                            SimulatorError::ComputationError(format!(
437                                "SciRS2 inverse FFT failed: {e}"
438                            ))
439                        })?;
440
441                    // Copy result back to data
442                    let result_array = result_vector.to_array1().map_err(|e| {
443                        SimulatorError::ComputationError(format!(
444                            "Failed to extract inverse FFT result: {e}"
445                        ))
446                    })?;
447                    // Safety: result_array is a contiguous 1D array, as_slice always succeeds
448                    data.copy_from_slice(
449                        result_array
450                            .as_slice()
451                            .expect("1D contiguous array has a valid slice"),
452                    );
453                }
454                #[cfg(not(feature = "advanced_math"))]
455                {
456                    // Fallback when advanced_math feature is not available
457                    self.radix2_fft(data, true)?;
458                }
459
460                Ok(())
461            } else {
462                // Fallback to radix-2 FFT
463                self.radix2_fft(data, true)?;
464                Ok(())
465            }
466        } else {
467            // Fallback to radix-2 FFT
468            self.radix2_fft(data, true)?;
469            Ok(())
470        }
471    }
472
473    /// Radix-2 FFT implementation (fallback)
474    fn radix2_fft(&self, data: &mut [Complex64], inverse: bool) -> Result<()> {
475        let n = data.len();
476        if !n.is_power_of_two() {
477            return Err(SimulatorError::InvalidInput(
478                "FFT size must be power of 2".to_string(),
479            ));
480        }
481
482        // Bit reversal
483        let mut j = 0;
484        for i in 1..n {
485            let mut bit = n >> 1;
486            while j & bit != 0 {
487                j ^= bit;
488                bit >>= 1;
489            }
490            j ^= bit;
491
492            if i < j {
493                data.swap(i, j);
494            }
495        }
496
497        // FFT computation
498        let mut length = 2;
499        while length <= n {
500            let angle = if inverse { 2.0 } else { -2.0 } * std::f64::consts::PI / length as f64;
501            let wlen = Complex64::new(angle.cos(), angle.sin());
502
503            for i in (0..n).step_by(length) {
504                let mut w = Complex64::new(1.0, 0.0);
505                for j in 0..length / 2 {
506                    let u = data[i + j];
507                    let v = data[i + j + length / 2] * w;
508                    data[i + j] = u + v;
509                    data[i + j + length / 2] = u - v;
510                    w *= wlen;
511                }
512            }
513            length <<= 1;
514        }
515
516        // Normalize for inverse FFT
517        if inverse {
518            let norm = 1.0 / n as f64;
519            for elem in data {
520                *elem *= norm;
521            }
522        }
523
524        Ok(())
525    }
526
527    /// Cooley-Tukey FFT algorithm
528    fn cooley_tukey_fft(&self, data: &mut Array1<Complex64>, inverse: bool) -> Result<()> {
529        let mut temp_data = data.to_vec();
530        self.radix2_fft(&mut temp_data, inverse)?;
531
532        for (i, &val) in temp_data.iter().enumerate() {
533            data[i] = val;
534        }
535
536        Ok(())
537    }
538
539    /// Apply approximation to QFT
540    fn apply_qft_approximation(&self, data: &mut [Complex64]) -> Result<()> {
541        // Truncate small amplitudes based on approximation level
542        let threshold =
543            self.config.precision_threshold * 10.0_f64.powi(self.config.approximation_level as i32);
544
545        for elem in data.iter_mut() {
546            if elem.norm() < threshold {
547                *elem = Complex64::new(0.0, 0.0);
548            }
549        }
550
551        Ok(())
552    }
553
554    /// Apply inverse approximation
555    fn apply_inverse_qft_approximation(&self, data: &mut [Complex64]) -> Result<()> {
556        // Similar to forward approximation
557        self.apply_qft_approximation(data)
558    }
559
560    /// Apply bit reversal permutation
561    fn apply_bit_reversal(&self, state: &mut Array1<Complex64>) -> Result<()> {
562        let n = state.len();
563        let num_bits = self.num_qubits;
564
565        for i in 0..n {
566            let j = self.bit_reverse(i, num_bits);
567            if i < j {
568                let temp = state[i];
569                state[i] = state[j];
570                state[j] = temp;
571            }
572        }
573
574        Ok(())
575    }
576
577    /// Bit reversal helper
578    fn bit_reverse(&self, num: usize, bits: usize) -> usize {
579        let mut result = 0;
580        let mut n = num;
581        for _ in 0..bits {
582            result = (result << 1) | (n & 1);
583            n >>= 1;
584        }
585        result
586    }
587
588    /// Apply Hadamard gate to specific qubit in state vector
589    fn apply_hadamard_to_state(&self, state: &mut Array1<Complex64>, target: usize) -> Result<()> {
590        let n = state.len();
591        let sqrt_half = 1.0 / 2.0_f64.sqrt();
592
593        for i in 0..n {
594            let bit_mask = 1 << (self.num_qubits - 1 - target);
595            let partner = i ^ bit_mask;
596
597            if i < partner {
598                let (val_i, val_partner) = (state[i], state[partner]);
599                state[i] = sqrt_half * (val_i + val_partner);
600                state[partner] = sqrt_half * (val_i - val_partner);
601            }
602        }
603
604        Ok(())
605    }
606
607    /// Apply controlled phase gate to state vector
608    fn apply_controlled_phase_to_state(
609        &self,
610        state: &mut Array1<Complex64>,
611        control: usize,
612        target: usize,
613        angle: f64,
614    ) -> Result<()> {
615        let n = state.len();
616        let phase = Complex64::new(angle.cos(), angle.sin());
617
618        let control_mask = 1 << (self.num_qubits - 1 - control);
619        let target_mask = 1 << (self.num_qubits - 1 - target);
620
621        for i in 0..n {
622            // Apply phase only when both control and target bits are 1
623            if (i & control_mask) != 0 && (i & target_mask) != 0 {
624                state[i] *= phase;
625            }
626        }
627
628        Ok(())
629    }
630
631    /// Get execution statistics
632    #[must_use]
633    pub const fn get_stats(&self) -> &QFTStats {
634        &self.stats
635    }
636
637    /// Reset statistics
638    pub fn reset_stats(&mut self) {
639        self.stats = QFTStats::default();
640    }
641
642    /// Set configuration
643    pub const fn set_config(&mut self, config: QFTConfig) {
644        self.config = config;
645    }
646
647    /// Get configuration
648    #[must_use]
649    pub const fn get_config(&self) -> &QFTConfig {
650        &self.config
651    }
652}
653
/// Stateless helpers for building test states and checking QFT results.
#[derive(Debug, Clone, Copy, Default)]
pub struct QFTUtils;
656
657impl QFTUtils {
658    /// Create a quantum state prepared for QFT testing
659    pub fn create_test_state(num_qubits: usize, pattern: &str) -> Result<Array1<Complex64>> {
660        let dim = 1 << num_qubits;
661        let mut state = Array1::zeros(dim);
662
663        match pattern {
664            "uniform" => {
665                // Uniform superposition
666                let amplitude = 1.0 / (dim as f64).sqrt();
667                for i in 0..dim {
668                    state[i] = Complex64::new(amplitude, 0.0);
669                }
670            }
671            "basis" => {
672                // Computational basis state |0...0⟩
673                state[0] = Complex64::new(1.0, 0.0);
674            }
675            "alternating" => {
676                // Alternating pattern
677                for i in 0..dim {
678                    let amplitude = if i % 2 == 0 { 1.0 } else { -1.0 };
679                    state[i] = Complex64::new(amplitude / (dim as f64).sqrt(), 0.0);
680                }
681            }
682            "random" => {
683                // Random state
684                for i in 0..dim {
685                    state[i] = Complex64::new(fastrand::f64() - 0.5, fastrand::f64() - 0.5);
686                }
687                // Normalize
688                let norm = state
689                    .iter()
690                    .map(scirs2_core::Complex::norm_sqr)
691                    .sum::<f64>()
692                    .sqrt();
693                for elem in &mut state {
694                    *elem /= norm;
695                }
696            }
697            _ => {
698                return Err(SimulatorError::InvalidInput(format!(
699                    "Unknown test pattern: {pattern}"
700                )));
701            }
702        }
703
704        Ok(state)
705    }
706
707    /// Verify QFT correctness by applying QFT and inverse QFT
708    pub fn verify_qft_roundtrip(
709        qft: &mut SciRS2QFT,
710        initial_state: &Array1<Complex64>,
711        tolerance: f64,
712    ) -> Result<bool> {
713        let mut state = initial_state.clone();
714
715        // Apply QFT
716        qft.apply_qft(&mut state)?;
717
718        // Apply inverse QFT
719        qft.apply_inverse_qft(&mut state)?;
720
721        // Check fidelity with initial state (overlap magnitude)
722        let overlap = initial_state
723            .iter()
724            .zip(state.iter())
725            .map(|(a, b)| a.conj() * b)
726            .sum::<Complex64>();
727        let fidelity = overlap.norm();
728
729        Ok((1.0 - fidelity).abs() < tolerance)
730    }
731
732    /// Calculate QFT of a classical signal for comparison
733    pub fn classical_dft(signal: &[Complex64]) -> Result<Vec<Complex64>> {
734        let n = signal.len();
735        let mut result = vec![Complex64::new(0.0, 0.0); n];
736
737        for k in 0..n {
738            for t in 0..n {
739                let angle = -2.0 * std::f64::consts::PI * k as f64 * t as f64 / n as f64;
740                let twiddle = Complex64::new(angle.cos(), angle.sin());
741                result[k] += signal[t] * twiddle;
742            }
743        }
744
745        Ok(result)
746    }
747}
748
749/// Benchmark different QFT methods
750pub fn benchmark_qft_methods(num_qubits: usize) -> Result<HashMap<String, QFTStats>> {
751    let mut results = HashMap::new();
752    let test_state = QFTUtils::create_test_state(num_qubits, "random")?;
753
754    // Test different methods
755    let methods = vec![
756        ("SciRS2Exact", QFTMethod::SciRS2Exact),
757        ("SciRS2Approximate", QFTMethod::SciRS2Approximate),
758        ("Circuit", QFTMethod::Circuit),
759        ("Classical", QFTMethod::Classical),
760    ];
761
762    for (name, method) in methods {
763        let config = QFTConfig {
764            method,
765            approximation_level: usize::from(method == QFTMethod::SciRS2Approximate),
766            bit_reversal: true,
767            parallel: true,
768            precision_threshold: 1e-10,
769        };
770
771        let mut qft = if method == QFTMethod::SciRS2Exact || method == QFTMethod::SciRS2Approximate
772        {
773            match SciRS2QFT::new(num_qubits, config.clone())?.with_backend() {
774                Ok(qft_with_backend) => qft_with_backend,
775                Err(_) => SciRS2QFT::new(num_qubits, config)
776                    .expect("QFT creation should succeed with same config"),
777            }
778        } else {
779            SciRS2QFT::new(num_qubits, config)?
780        };
781
782        let mut state = test_state.clone();
783
784        // Apply QFT
785        qft.apply_qft(&mut state)?;
786
787        results.insert(name.to_string(), qft.get_stats().clone());
788    }
789
790    Ok(results)
791}
792
793/// Compare QFT implementations for accuracy
794pub fn compare_qft_accuracy(num_qubits: usize) -> Result<HashMap<String, f64>> {
795    let mut errors = HashMap::new();
796    let test_state = QFTUtils::create_test_state(num_qubits, "random")?;
797
798    // Reference: Classical DFT
799    let classical_signal: Vec<Complex64> = test_state.to_vec();
800    let reference_result = QFTUtils::classical_dft(&classical_signal)?;
801
802    // Test quantum methods
803    let methods = vec![
804        ("SciRS2Exact", QFTMethod::SciRS2Exact),
805        ("SciRS2Approximate", QFTMethod::SciRS2Approximate),
806        ("Circuit", QFTMethod::Circuit),
807        ("Classical", QFTMethod::Classical),
808    ];
809
810    for (name, method) in methods {
811        let config = QFTConfig {
812            method,
813            approximation_level: usize::from(method == QFTMethod::SciRS2Approximate),
814            bit_reversal: false, // Compare without bit reversal for accuracy
815            parallel: true,
816            precision_threshold: 1e-10,
817        };
818
819        let mut qft = if method == QFTMethod::SciRS2Exact || method == QFTMethod::SciRS2Approximate
820        {
821            match SciRS2QFT::new(num_qubits, config.clone())?.with_backend() {
822                Ok(qft_with_backend) => qft_with_backend,
823                Err(_) => SciRS2QFT::new(num_qubits, config)
824                    .expect("QFT creation should succeed with same config"),
825            }
826        } else {
827            SciRS2QFT::new(num_qubits, config)?
828        };
829
830        let mut state = test_state.clone();
831        qft.apply_qft(&mut state)?;
832
833        // Calculate error compared to reference
834        let error = reference_result
835            .iter()
836            .zip(state.iter())
837            .map(|(ref_val, qft_val)| (ref_val - qft_val).norm())
838            .sum::<f64>()
839            / reference_result.len() as f64;
840
841        errors.insert(name.to_string(), error);
842    }
843
844    Ok(errors)
845}
846
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_qft_config_default() {
        let config = QFTConfig::default();
        assert_eq!(config.method, QFTMethod::SciRS2Exact);
        assert_eq!(config.approximation_level, 0);
        assert!(config.bit_reversal);
        assert!(config.parallel);
    }

    #[test]
    fn test_scirs2_qft_creation() {
        let config = QFTConfig::default();
        let qft = SciRS2QFT::new(3, config).expect("should create SciRS2 QFT");
        assert_eq!(qft.num_qubits, 3);
    }

    #[test]
    fn test_test_state_creation() {
        let state = QFTUtils::create_test_state(2, "basis").expect("should create test state");
        assert_eq!(state.len(), 4);
        assert_abs_diff_eq!(state[0].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(state[1].norm(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_classical_qft() {
        let config = QFTConfig {
            method: QFTMethod::Classical,
            ..Default::default()
        };
        let mut qft = SciRS2QFT::new(2, config).expect("should create SciRS2 QFT");
        let mut state = QFTUtils::create_test_state(2, "basis").expect("should create test state");

        qft.apply_qft(&mut state).expect("should apply QFT");

        // After QFT of |00⟩, should be uniform superposition
        let expected_amplitude = 0.5;
        for amplitude in &state {
            assert_abs_diff_eq!(amplitude.norm(), expected_amplitude, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_qft_roundtrip() {
        let config = QFTConfig {
            method: QFTMethod::Classical,
            bit_reversal: false, // Disable for roundtrip test
            ..Default::default()
        };
        let mut qft = SciRS2QFT::new(3, config).expect("should create SciRS2 QFT");
        let initial_state =
            QFTUtils::create_test_state(3, "basis").expect("should create test state"); // Use basis state instead of random

        // Just verify that QFT and inverse QFT complete without error
        let mut state = initial_state;
        qft.apply_qft(&mut state).expect("should apply QFT");
        qft.apply_inverse_qft(&mut state)
            .expect("should apply inverse QFT");

        // Check that we have some reasonable state (not all zeros)
        let has_nonzero = state.iter().any(|amp| amp.norm() > 1e-15);
        assert!(
            has_nonzero,
            "State should have non-zero amplitudes after QFT operations"
        );
    }

    #[test]
    fn test_circuit_qft_roundtrip() {
        // The circuit inverse replays the forward gates in reverse, so a round trip must
        // reproduce the input with or without the bit-reversal permutation.
        for bit_reversal in [false, true] {
            let config = QFTConfig {
                method: QFTMethod::Circuit,
                bit_reversal,
                ..Default::default()
            };
            let mut qft = SciRS2QFT::new(3, config).expect("should create SciRS2 QFT");
            let initial =
                QFTUtils::create_test_state(3, "random").expect("should create test state");

            let ok = QFTUtils::verify_qft_roundtrip(&mut qft, &initial, 1e-10)
                .expect("round trip should run");
            assert!(ok, "circuit QFT round trip failed (bit_reversal = {bit_reversal})");
        }
    }

    #[test]
    fn test_circuit_qft_of_zero_state_is_uniform() {
        let config = QFTConfig {
            method: QFTMethod::Circuit,
            ..Default::default()
        };
        let mut qft = SciRS2QFT::new(3, config).expect("should create SciRS2 QFT");
        let mut state = QFTUtils::create_test_state(3, "basis").expect("should create test state");

        qft.apply_qft(&mut state).expect("should apply QFT");

        let expected = 1.0 / 8.0_f64.sqrt();
        for amplitude in &state {
            assert_abs_diff_eq!(amplitude.re, expected, epsilon = 1e-10);
            assert_abs_diff_eq!(amplitude.im, 0.0, epsilon = 1e-10);
        }
        // 3 Hadamards + 3 controlled phases.
        assert_eq!(qft.get_stats().circuit_gates, 6);
    }

    #[test]
    fn test_apply_qft_dimension_mismatch() {
        let mut qft = SciRS2QFT::new(3, QFTConfig::default()).expect("should create SciRS2 QFT");
        let mut state = QFTUtils::create_test_state(2, "basis").expect("should create test state");
        assert!(qft.apply_qft(&mut state).is_err());
    }

    #[test]
    fn test_bit_reversal() {
        let config = QFTConfig::default();
        let qft = SciRS2QFT::new(3, config).expect("should create SciRS2 QFT");

        assert_eq!(qft.bit_reverse(0b001, 3), 0b100);
        assert_eq!(qft.bit_reverse(0b010, 3), 0b010);
        assert_eq!(qft.bit_reverse(0b011, 3), 0b110);
    }

    #[test]
    fn test_radix2_fft() {
        let config = QFTConfig::default();
        let qft = SciRS2QFT::new(2, config).expect("should create SciRS2 QFT");

        let mut data = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];

        qft.radix2_fft(&mut data, false)
            .expect("should apply radix2 FFT");

        // All amplitudes should be 1.0 for DFT of basis state
        for amplitude in &data {
            assert_abs_diff_eq!(amplitude.norm(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_radix2_fft_inverse_roundtrip() {
        let qft = SciRS2QFT::new(3, QFTConfig::default()).expect("should create SciRS2 QFT");
        let original: Vec<Complex64> = (0..8)
            .map(|k| Complex64::new(f64::from(k), -0.5 * f64::from(k)))
            .collect();

        let mut data = original.clone();
        qft.radix2_fft(&mut data, false)
            .expect("forward FFT should succeed");
        qft.radix2_fft(&mut data, true)
            .expect("inverse FFT should succeed");

        for (expected, actual) in original.iter().zip(&data) {
            assert_abs_diff_eq!(expected.re, actual.re, epsilon = 1e-10);
            assert_abs_diff_eq!(expected.im, actual.im, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_radix2_fft_rejects_non_power_of_two() {
        let qft = SciRS2QFT::new(2, QFTConfig::default()).expect("should create SciRS2 QFT");
        let mut data = vec![Complex64::new(1.0, 0.0); 3];
        assert!(qft.radix2_fft(&mut data, false).is_err());
    }

    #[test]
    fn test_classical_dft() {
        let signal = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];

        let result = QFTUtils::classical_dft(&signal).expect("should compute classical DFT");

        // DFT of [1, 0, 0, 0] should be [1, 1, 1, 1]
        for amplitude in &result {
            assert_abs_diff_eq!(amplitude.norm(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_classical_dft_matches_radix2() {
        // Both use the unnormalised e^{-2πi kt/N} kernel, so they must agree.
        let signal: Vec<Complex64> = (0..8)
            .map(|k| Complex64::new(f64::from(k).sin(), f64::from(k).cos()))
            .collect();
        let reference = QFTUtils::classical_dft(&signal).expect("should compute classical DFT");

        let qft = SciRS2QFT::new(3, QFTConfig::default()).expect("should create SciRS2 QFT");
        let mut fast = signal.clone();
        qft.radix2_fft(&mut fast, false)
            .expect("should apply radix2 FFT");

        for (slow, quick) in reference.iter().zip(&fast) {
            assert_abs_diff_eq!(slow.re, quick.re, epsilon = 1e-9);
            assert_abs_diff_eq!(slow.im, quick.im, epsilon = 1e-9);
        }
    }
}