// quantrs2_sim/scirs2_qft.rs

//! SciRS2-optimized Quantum Fourier Transform implementation.
//!
//! This module provides quantum Fourier transform (QFT) operations optimized
//! using SciRS2's Fast Fourier Transform capabilities. It includes both exact
//! and approximate QFT implementations, with fallback routines for when SciRS2
//! is not available.

use std::collections::HashMap;

use ndarray::{Array1, Array2, ArrayView1, Axis};
use num_complex::Complex64;
use serde::{Deserialize, Serialize};

use crate::dynamic::DynamicCircuit;
use crate::error::{Result, SimulatorError};
use crate::scirs2_integration::SciRS2Backend;
use crate::statevector::StateVectorSimulator;

/// Strategy used by `SciRS2QFT` to apply the transform.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QFTMethod {
    /// Exact QFT using the SciRS2 FFT (the classical FFT path is used when no
    /// backend is attached).
    SciRS2Exact,
    /// Approximate QFT using the SciRS2 FFT, zeroing small amplitudes first.
    SciRS2Approximate,
    /// Gate-level simulation: Hadamard and controlled-phase gates applied
    /// directly to the state vector.
    Circuit,
    /// Built-in radix-2 FFT emulation (fallback).
    Classical,
}

/// Tuning knobs for a `SciRS2QFT` engine.
#[derive(Debug, Clone, PartialEq)]
pub struct QFTConfig {
    /// Implementation method to use.
    pub method: QFTMethod,
    /// Approximation level (0 = exact). For `SciRS2Approximate`, amplitudes whose
    /// magnitude is below `precision_threshold * 10^approximation_level` are zeroed.
    pub approximation_level: usize,
    /// Whether to apply a bit-reversal permutation: after the transform in
    /// `apply_qft`, and before the transform in `apply_inverse_qft`.
    pub bit_reversal: bool,
    /// Request parallel execution.
    /// NOTE(review): not read by `SciRS2QFT` anywhere in this module.
    pub parallel: bool,
    /// Base magnitude threshold used by the approximate method (scaled by the
    /// approximation level, see `approximation_level`).
    pub precision_threshold: f64,
}

impl Default for QFTConfig {
    /// Exact SciRS2 QFT with bit reversal, parallelism requested and a
    /// `1e-10` precision threshold.
    fn default() -> Self {
        Self {
            method: QFTMethod::SciRS2Exact,
            approximation_level: 0,
            bit_reversal: true,
            parallel: true,
            precision_threshold: 1e-10,
        }
    }
}

58/// QFT execution statistics
59#[derive(Debug, Clone, Default, Serialize, Deserialize)]
60pub struct QFTStats {
61    /// Execution time in milliseconds
62    pub execution_time_ms: f64,
63    /// Memory usage in bytes
64    pub memory_usage_bytes: usize,
65    /// Number of FFT operations performed
66    pub fft_operations: usize,
67    /// Approximation error (if applicable)
68    pub approximation_error: f64,
69    /// Number of circuit gates (for circuit method)
70    pub circuit_gates: usize,
71    /// Method used for execution
72    pub method_used: String,
73}
74
75/// SciRS2-optimized Quantum Fourier Transform
76pub struct SciRS2QFT {
77    /// Number of qubits
78    num_qubits: usize,
79    /// SciRS2 backend
80    backend: Option<SciRS2Backend>,
81    /// Configuration
82    config: QFTConfig,
83    /// Execution statistics
84    stats: QFTStats,
85    /// Precomputed twiddle factors
86    twiddle_cache: HashMap<usize, Array1<Complex64>>,
87}
88
89impl SciRS2QFT {
90    /// Create new SciRS2 QFT instance
91    pub fn new(num_qubits: usize, config: QFTConfig) -> Result<Self> {
92        Ok(Self {
93            num_qubits,
94            backend: None,
95            config,
96            stats: QFTStats::default(),
97            twiddle_cache: HashMap::new(),
98        })
99    }
100
101    /// Initialize with SciRS2 backend
102    pub fn with_backend(mut self) -> Result<Self> {
103        self.backend = Some(SciRS2Backend::new());
104        Ok(self)
105    }
106
107    /// Apply forward QFT to state vector
108    pub fn apply_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
109        let start_time = std::time::Instant::now();
110
111        if state.len() != 1 << self.num_qubits {
112            return Err(SimulatorError::DimensionMismatch(format!(
113                "State vector length {} doesn't match 2^{} qubits",
114                state.len(),
115                self.num_qubits
116            )));
117        }
118
119        match self.config.method {
120            QFTMethod::SciRS2Exact => self.apply_scirs2_exact_qft(state)?,
121            QFTMethod::SciRS2Approximate => self.apply_scirs2_approximate_qft(state)?,
122            QFTMethod::Circuit => self.apply_circuit_qft(state)?,
123            QFTMethod::Classical => self.apply_classical_qft(state)?,
124        }
125
126        // Apply bit reversal if requested
127        if self.config.bit_reversal {
128            self.apply_bit_reversal(state)?;
129        }
130
131        self.stats.execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
132        self.stats.memory_usage_bytes = state.len() * std::mem::size_of::<Complex64>();
133
134        Ok(())
135    }
136
137    /// Apply inverse QFT to state vector
138    pub fn apply_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
139        let start_time = std::time::Instant::now();
140
141        // For inverse QFT, apply bit reversal first if configured
142        if self.config.bit_reversal {
143            self.apply_bit_reversal(state)?;
144        }
145
146        match self.config.method {
147            QFTMethod::SciRS2Exact => self.apply_scirs2_exact_inverse_qft(state)?,
148            QFTMethod::SciRS2Approximate => self.apply_scirs2_approximate_inverse_qft(state)?,
149            QFTMethod::Circuit => self.apply_circuit_inverse_qft(state)?,
150            QFTMethod::Classical => self.apply_classical_inverse_qft(state)?,
151        }
152
153        self.stats.execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
154
155        Ok(())
156    }
157
158    /// SciRS2 exact QFT implementation
159    fn apply_scirs2_exact_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
160        if let Some(backend) = &mut self.backend {
161            // Use SciRS2's optimized FFT
162            let mut complex_data: Vec<Complex64> = state.to_vec();
163
164            // SciRS2 FFT call (simulated - would call actual SciRS2 FFT)
165            self.scirs2_fft_forward(&mut complex_data)?;
166
167            // Normalize by 1/sqrt(N) for quantum normalization
168            let normalization = 1.0 / (complex_data.len() as f64).sqrt();
169            for elem in &mut complex_data {
170                *elem *= normalization;
171            }
172
173            // Copy back to state
174            for (i, &val) in complex_data.iter().enumerate() {
175                state[i] = val;
176            }
177
178            self.stats.fft_operations += 1;
179            self.stats.method_used = "SciRS2Exact".to_string();
180        } else {
181            // Fallback to classical implementation
182            self.apply_classical_qft(state)?;
183        }
184
185        Ok(())
186    }
187
188    /// SciRS2 approximate QFT implementation
189    fn apply_scirs2_approximate_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
190        if let Some(_backend) = &mut self.backend {
191            // Use SciRS2's approximate FFT with precision control
192            let mut complex_data: Vec<Complex64> = state.to_vec();
193
194            // Apply approximation based on level
195            if self.config.approximation_level > 0 {
196                self.apply_qft_approximation(&mut complex_data)?;
197            }
198
199            // SciRS2 approximate FFT
200            self.scirs2_fft_forward(&mut complex_data)?;
201
202            // Quantum normalization
203            let normalization = 1.0 / (complex_data.len() as f64).sqrt();
204            for elem in &mut complex_data {
205                *elem *= normalization;
206            }
207
208            // Copy back to state
209            for (i, &val) in complex_data.iter().enumerate() {
210                state[i] = val;
211            }
212
213            self.stats.fft_operations += 1;
214            self.stats.method_used = "SciRS2Approximate".to_string();
215        } else {
216            // Fallback to classical implementation
217            self.apply_classical_qft(state)?;
218        }
219
220        Ok(())
221    }
222
223    /// Circuit-based QFT implementation
224    fn apply_circuit_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
225        // Apply QFT gates directly to the state vector
226        for i in 0..self.num_qubits {
227            // Hadamard gate
228            self.apply_hadamard_to_state(state, i)?;
229
230            // Controlled phase gates
231            for j in (i + 1)..self.num_qubits {
232                let angle = std::f64::consts::PI / 2.0_f64.powi((j - i) as i32);
233                self.apply_controlled_phase_to_state(state, j, i, angle)?;
234            }
235        }
236
237        self.stats.circuit_gates = self.num_qubits * (self.num_qubits + 1) / 2;
238        self.stats.method_used = "Circuit".to_string();
239
240        Ok(())
241    }
242
243    /// Classical FFT fallback implementation
244    fn apply_classical_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
245        let mut temp_state = state.clone();
246
247        // Apply Cooley-Tukey FFT algorithm
248        self.cooley_tukey_fft(&mut temp_state, false)?;
249
250        // Quantum normalization
251        let normalization = 1.0 / (temp_state.len() as f64).sqrt();
252        for elem in &mut temp_state {
253            *elem *= normalization;
254        }
255
256        // Copy back
257        *state = temp_state;
258
259        self.stats.method_used = "Classical".to_string();
260
261        Ok(())
262    }
263
264    /// SciRS2 exact inverse QFT
265    fn apply_scirs2_exact_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
266        if let Some(backend) = &mut self.backend {
267            let mut complex_data: Vec<Complex64> = state.to_vec();
268
269            // Reverse normalization
270            let normalization = (complex_data.len() as f64).sqrt();
271            for elem in &mut complex_data {
272                *elem *= normalization;
273            }
274
275            // SciRS2 inverse FFT
276            self.scirs2_fft_inverse(&mut complex_data)?;
277
278            // Copy back
279            for (i, &val) in complex_data.iter().enumerate() {
280                state[i] = val;
281            }
282
283            self.stats.fft_operations += 1;
284            self.stats.method_used = "SciRS2ExactInverse".to_string();
285        } else {
286            self.apply_classical_inverse_qft(state)?;
287        }
288
289        Ok(())
290    }
291
292    /// SciRS2 approximate inverse QFT
293    fn apply_scirs2_approximate_inverse_qft(
294        &mut self,
295        state: &mut Array1<Complex64>,
296    ) -> Result<()> {
297        if let Some(_backend) = &mut self.backend {
298            let mut complex_data: Vec<Complex64> = state.to_vec();
299
300            // Reverse normalization
301            let normalization = (complex_data.len() as f64).sqrt();
302            for elem in &mut complex_data {
303                *elem *= normalization;
304            }
305
306            // SciRS2 inverse FFT
307            self.scirs2_fft_inverse(&mut complex_data)?;
308
309            // Apply inverse approximation if needed
310            if self.config.approximation_level > 0 {
311                self.apply_inverse_qft_approximation(&mut complex_data)?;
312            }
313
314            // Copy back
315            for (i, &val) in complex_data.iter().enumerate() {
316                state[i] = val;
317            }
318
319            self.stats.method_used = "SciRS2ApproximateInverse".to_string();
320        } else {
321            self.apply_classical_inverse_qft(state)?;
322        }
323
324        Ok(())
325    }
326
327    /// Circuit-based inverse QFT
328    fn apply_circuit_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
329        // Apply inverse QFT gates directly to the state vector
330        for i in (0..self.num_qubits).rev() {
331            // Controlled phase gates (reversed)
332            for j in ((i + 1)..self.num_qubits).rev() {
333                let angle = -std::f64::consts::PI / 2.0_f64.powi((j - i) as i32);
334                self.apply_controlled_phase_to_state(state, j, i, angle)?;
335            }
336
337            // Hadamard gate
338            self.apply_hadamard_to_state(state, i)?;
339        }
340
341        self.stats.circuit_gates = self.num_qubits * (self.num_qubits + 1) / 2;
342        self.stats.method_used = "CircuitInverse".to_string();
343
344        Ok(())
345    }
346
347    /// Classical inverse QFT
348    fn apply_classical_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
349        let mut temp_state = state.clone();
350
351        // Apply inverse Cooley-Tukey FFT
352        self.cooley_tukey_fft(&mut temp_state, true)?;
353
354        // Quantum normalization
355        let normalization = 1.0 / (temp_state.len() as f64).sqrt();
356        for elem in &mut temp_state {
357            *elem *= normalization;
358        }
359
360        *state = temp_state;
361
362        self.stats.method_used = "ClassicalInverse".to_string();
363
364        Ok(())
365    }
366
367    /// SciRS2 forward FFT call using actual SciRS2 backend
368    fn scirs2_fft_forward(&self, data: &mut [Complex64]) -> Result<()> {
369        if let Some(ref backend) = self.backend {
370            if backend.is_available() {
371                // Use actual SciRS2 FFT implementation
372                use crate::scirs2_integration::{SciRS2MemoryAllocator, SciRS2Vector};
373                use ndarray::Array1;
374
375                let _allocator = SciRS2MemoryAllocator::new();
376                let input_array = Array1::from_vec(data.to_vec());
377                let scirs2_vector = SciRS2Vector::from_array1(input_array);
378
379                // Perform forward FFT using SciRS2 engine
380                #[cfg(feature = "advanced_math")]
381                {
382                    let result_vector =
383                        backend.fft_engine.forward(&scirs2_vector).map_err(|e| {
384                            SimulatorError::ComputationError(format!("SciRS2 FFT failed: {}", e))
385                        })?;
386
387                    // Copy result back to data
388                    let result_array = result_vector.to_array1().map_err(|e| {
389                        SimulatorError::ComputationError(format!(
390                            "Failed to extract FFT result: {}",
391                            e
392                        ))
393                    })?;
394                    data.copy_from_slice(result_array.as_slice().unwrap());
395                }
396                #[cfg(not(feature = "advanced_math"))]
397                {
398                    // Fallback when advanced_math feature is not available
399                    self.radix2_fft(data, false)?;
400                }
401
402                Ok(())
403            } else {
404                // Fallback to radix-2 FFT
405                self.radix2_fft(data, false)?;
406                Ok(())
407            }
408        } else {
409            // Fallback to radix-2 FFT
410            self.radix2_fft(data, false)?;
411            Ok(())
412        }
413    }
414
415    /// SciRS2 inverse FFT call using actual SciRS2 backend
416    fn scirs2_fft_inverse(&self, data: &mut [Complex64]) -> Result<()> {
417        if let Some(ref backend) = self.backend {
418            if backend.is_available() {
419                // Use actual SciRS2 inverse FFT implementation
420                use crate::scirs2_integration::{SciRS2MemoryAllocator, SciRS2Vector};
421                use ndarray::Array1;
422
423                let _allocator = SciRS2MemoryAllocator::new();
424                let input_array = Array1::from_vec(data.to_vec());
425                let scirs2_vector = SciRS2Vector::from_array1(input_array);
426
427                // Perform inverse FFT using SciRS2 engine
428                #[cfg(feature = "advanced_math")]
429                {
430                    let result_vector =
431                        backend.fft_engine.inverse(&scirs2_vector).map_err(|e| {
432                            SimulatorError::ComputationError(format!(
433                                "SciRS2 inverse FFT failed: {}",
434                                e
435                            ))
436                        })?;
437
438                    // Copy result back to data
439                    let result_array = result_vector.to_array1().map_err(|e| {
440                        SimulatorError::ComputationError(format!(
441                            "Failed to extract inverse FFT result: {}",
442                            e
443                        ))
444                    })?;
445                    data.copy_from_slice(result_array.as_slice().unwrap());
446                }
447                #[cfg(not(feature = "advanced_math"))]
448                {
449                    // Fallback when advanced_math feature is not available
450                    self.radix2_fft(data, true)?;
451                }
452
453                Ok(())
454            } else {
455                // Fallback to radix-2 FFT
456                self.radix2_fft(data, true)?;
457                Ok(())
458            }
459        } else {
460            // Fallback to radix-2 FFT
461            self.radix2_fft(data, true)?;
462            Ok(())
463        }
464    }
465
466    /// Radix-2 FFT implementation (fallback)
467    fn radix2_fft(&self, data: &mut [Complex64], inverse: bool) -> Result<()> {
468        let n = data.len();
469        if !n.is_power_of_two() {
470            return Err(SimulatorError::InvalidInput(
471                "FFT size must be power of 2".to_string(),
472            ));
473        }
474
475        // Bit reversal
476        let mut j = 0;
477        for i in 1..n {
478            let mut bit = n >> 1;
479            while j & bit != 0 {
480                j ^= bit;
481                bit >>= 1;
482            }
483            j ^= bit;
484
485            if i < j {
486                data.swap(i, j);
487            }
488        }
489
490        // FFT computation
491        let mut length = 2;
492        while length <= n {
493            let angle = if inverse { 2.0 } else { -2.0 } * std::f64::consts::PI / length as f64;
494            let wlen = Complex64::new(angle.cos(), angle.sin());
495
496            for i in (0..n).step_by(length) {
497                let mut w = Complex64::new(1.0, 0.0);
498                for j in 0..length / 2 {
499                    let u = data[i + j];
500                    let v = data[i + j + length / 2] * w;
501                    data[i + j] = u + v;
502                    data[i + j + length / 2] = u - v;
503                    w *= wlen;
504                }
505            }
506            length <<= 1;
507        }
508
509        // Normalize for inverse FFT
510        if inverse {
511            let norm = 1.0 / n as f64;
512            for elem in data {
513                *elem *= norm;
514            }
515        }
516
517        Ok(())
518    }
519
520    /// Cooley-Tukey FFT algorithm
521    fn cooley_tukey_fft(&self, data: &mut Array1<Complex64>, inverse: bool) -> Result<()> {
522        let mut temp_data = data.to_vec();
523        self.radix2_fft(&mut temp_data, inverse)?;
524
525        for (i, &val) in temp_data.iter().enumerate() {
526            data[i] = val;
527        }
528
529        Ok(())
530    }
531
532    /// Apply approximation to QFT
533    fn apply_qft_approximation(&self, data: &mut [Complex64]) -> Result<()> {
534        // Truncate small amplitudes based on approximation level
535        let threshold =
536            self.config.precision_threshold * 10.0_f64.powi(self.config.approximation_level as i32);
537
538        for elem in data.iter_mut() {
539            if elem.norm() < threshold {
540                *elem = Complex64::new(0.0, 0.0);
541            }
542        }
543
544        Ok(())
545    }
546
547    /// Apply inverse approximation
548    fn apply_inverse_qft_approximation(&self, data: &mut [Complex64]) -> Result<()> {
549        // Similar to forward approximation
550        self.apply_qft_approximation(data)
551    }
552
553    /// Apply bit reversal permutation
554    fn apply_bit_reversal(&self, state: &mut Array1<Complex64>) -> Result<()> {
555        let n = state.len();
556        let num_bits = self.num_qubits;
557
558        for i in 0..n {
559            let j = self.bit_reverse(i, num_bits);
560            if i < j {
561                let temp = state[i];
562                state[i] = state[j];
563                state[j] = temp;
564            }
565        }
566
567        Ok(())
568    }
569
570    /// Bit reversal helper
571    fn bit_reverse(&self, num: usize, bits: usize) -> usize {
572        let mut result = 0;
573        let mut n = num;
574        for _ in 0..bits {
575            result = (result << 1) | (n & 1);
576            n >>= 1;
577        }
578        result
579    }
580
581    /// Apply Hadamard gate to specific qubit in state vector
582    fn apply_hadamard_to_state(&self, state: &mut Array1<Complex64>, target: usize) -> Result<()> {
583        let n = state.len();
584        let sqrt_half = 1.0 / 2.0_f64.sqrt();
585
586        for i in 0..n {
587            let bit_mask = 1 << (self.num_qubits - 1 - target);
588            let partner = i ^ bit_mask;
589
590            if i < partner {
591                let (val_i, val_partner) = (state[i], state[partner]);
592                state[i] = sqrt_half * (val_i + val_partner);
593                state[partner] = sqrt_half * (val_i - val_partner);
594            }
595        }
596
597        Ok(())
598    }
599
600    /// Apply controlled phase gate to state vector
601    fn apply_controlled_phase_to_state(
602        &self,
603        state: &mut Array1<Complex64>,
604        control: usize,
605        target: usize,
606        angle: f64,
607    ) -> Result<()> {
608        let n = state.len();
609        let phase = Complex64::new(angle.cos(), angle.sin());
610
611        let control_mask = 1 << (self.num_qubits - 1 - control);
612        let target_mask = 1 << (self.num_qubits - 1 - target);
613
614        for i in 0..n {
615            // Apply phase only when both control and target bits are 1
616            if (i & control_mask) != 0 && (i & target_mask) != 0 {
617                state[i] *= phase;
618            }
619        }
620
621        Ok(())
622    }
623
624    /// Get execution statistics
625    pub fn get_stats(&self) -> &QFTStats {
626        &self.stats
627    }
628
629    /// Reset statistics
630    pub fn reset_stats(&mut self) {
631        self.stats = QFTStats::default();
632    }
633
634    /// Set configuration
635    pub fn set_config(&mut self, config: QFTConfig) {
636        self.config = config;
637    }
638
639    /// Get configuration
640    pub fn get_config(&self) -> &QFTConfig {
641        &self.config
642    }
643}
644
645/// QFT utilities for common operations
646pub struct QFTUtils;
647
648impl QFTUtils {
649    /// Create a quantum state prepared for QFT testing
650    pub fn create_test_state(num_qubits: usize, pattern: &str) -> Result<Array1<Complex64>> {
651        let dim = 1 << num_qubits;
652        let mut state = Array1::zeros(dim);
653
654        match pattern {
655            "uniform" => {
656                // Uniform superposition
657                let amplitude = 1.0 / (dim as f64).sqrt();
658                for i in 0..dim {
659                    state[i] = Complex64::new(amplitude, 0.0);
660                }
661            }
662            "basis" => {
663                // Computational basis state |0...0⟩
664                state[0] = Complex64::new(1.0, 0.0);
665            }
666            "alternating" => {
667                // Alternating pattern
668                for i in 0..dim {
669                    let amplitude = if i % 2 == 0 { 1.0 } else { -1.0 };
670                    state[i] = Complex64::new(amplitude / (dim as f64).sqrt(), 0.0);
671                }
672            }
673            "random" => {
674                // Random state
675                for i in 0..dim {
676                    state[i] = Complex64::new(fastrand::f64() - 0.5, fastrand::f64() - 0.5);
677                }
678                // Normalize
679                let norm = state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
680                for elem in &mut state {
681                    *elem /= norm;
682                }
683            }
684            _ => {
685                return Err(SimulatorError::InvalidInput(format!(
686                    "Unknown test pattern: {}",
687                    pattern
688                )));
689            }
690        }
691
692        Ok(state)
693    }
694
695    /// Verify QFT correctness by applying QFT and inverse QFT
696    pub fn verify_qft_roundtrip(
697        qft: &mut SciRS2QFT,
698        initial_state: &Array1<Complex64>,
699        tolerance: f64,
700    ) -> Result<bool> {
701        let mut state = initial_state.clone();
702
703        // Apply QFT
704        qft.apply_qft(&mut state)?;
705
706        // Apply inverse QFT
707        qft.apply_inverse_qft(&mut state)?;
708
709        // Check fidelity with initial state (overlap magnitude)
710        let overlap = initial_state
711            .iter()
712            .zip(state.iter())
713            .map(|(a, b)| a.conj() * b)
714            .sum::<Complex64>();
715        let fidelity = overlap.norm();
716
717        Ok((1.0 - fidelity).abs() < tolerance)
718    }
719
720    /// Calculate QFT of a classical signal for comparison
721    pub fn classical_dft(signal: &[Complex64]) -> Result<Vec<Complex64>> {
722        let n = signal.len();
723        let mut result = vec![Complex64::new(0.0, 0.0); n];
724
725        for k in 0..n {
726            for t in 0..n {
727                let angle = -2.0 * std::f64::consts::PI * k as f64 * t as f64 / n as f64;
728                let twiddle = Complex64::new(angle.cos(), angle.sin());
729                result[k] += signal[t] * twiddle;
730            }
731        }
732
733        Ok(result)
734    }
735}
736
737/// Benchmark different QFT methods
738pub fn benchmark_qft_methods(num_qubits: usize) -> Result<HashMap<String, QFTStats>> {
739    let mut results = HashMap::new();
740    let test_state = QFTUtils::create_test_state(num_qubits, "random")?;
741
742    // Test different methods
743    let methods = vec![
744        ("SciRS2Exact", QFTMethod::SciRS2Exact),
745        ("SciRS2Approximate", QFTMethod::SciRS2Approximate),
746        ("Circuit", QFTMethod::Circuit),
747        ("Classical", QFTMethod::Classical),
748    ];
749
750    for (name, method) in methods {
751        let config = QFTConfig {
752            method,
753            approximation_level: if method == QFTMethod::SciRS2Approximate {
754                1
755            } else {
756                0
757            },
758            bit_reversal: true,
759            parallel: true,
760            precision_threshold: 1e-10,
761        };
762
763        let mut qft = if method == QFTMethod::SciRS2Exact || method == QFTMethod::SciRS2Approximate
764        {
765            SciRS2QFT::new(num_qubits, config.clone())?
766                .with_backend()
767                .unwrap_or_else(|_| SciRS2QFT::new(num_qubits, config).unwrap())
768        } else {
769            SciRS2QFT::new(num_qubits, config)?
770        };
771
772        let mut state = test_state.clone();
773
774        // Apply QFT
775        qft.apply_qft(&mut state)?;
776
777        results.insert(name.to_string(), qft.get_stats().clone());
778    }
779
780    Ok(results)
781}
782
783/// Compare QFT implementations for accuracy
784pub fn compare_qft_accuracy(num_qubits: usize) -> Result<HashMap<String, f64>> {
785    let mut errors = HashMap::new();
786    let test_state = QFTUtils::create_test_state(num_qubits, "random")?;
787
788    // Reference: Classical DFT
789    let classical_signal: Vec<Complex64> = test_state.to_vec();
790    let reference_result = QFTUtils::classical_dft(&classical_signal)?;
791
792    // Test quantum methods
793    let methods = vec![
794        ("SciRS2Exact", QFTMethod::SciRS2Exact),
795        ("SciRS2Approximate", QFTMethod::SciRS2Approximate),
796        ("Circuit", QFTMethod::Circuit),
797        ("Classical", QFTMethod::Classical),
798    ];
799
800    for (name, method) in methods {
801        let config = QFTConfig {
802            method,
803            approximation_level: if method == QFTMethod::SciRS2Approximate {
804                1
805            } else {
806                0
807            },
808            bit_reversal: false, // Compare without bit reversal for accuracy
809            parallel: true,
810            precision_threshold: 1e-10,
811        };
812
813        let mut qft = if method == QFTMethod::SciRS2Exact || method == QFTMethod::SciRS2Approximate
814        {
815            SciRS2QFT::new(num_qubits, config.clone())?
816                .with_backend()
817                .unwrap_or_else(|_| SciRS2QFT::new(num_qubits, config).unwrap())
818        } else {
819            SciRS2QFT::new(num_qubits, config)?
820        };
821
822        let mut state = test_state.clone();
823        qft.apply_qft(&mut state)?;
824
825        // Calculate error compared to reference
826        let error = reference_result
827            .iter()
828            .zip(state.iter())
829            .map(|(ref_val, qft_val)| (ref_val - qft_val).norm())
830            .sum::<f64>()
831            / reference_result.len() as f64;
832
833        errors.insert(name.to_string(), error);
834    }
835
836    Ok(errors)
837}
838
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_qft_config_default() {
        let config = QFTConfig::default();
        assert_eq!(config.method, QFTMethod::SciRS2Exact);
        assert_eq!(config.approximation_level, 0);
        assert!(config.bit_reversal);
        assert!(config.parallel);
    }

    #[test]
    fn test_scirs2_qft_creation() {
        let config = QFTConfig::default();
        let qft = SciRS2QFT::new(3, config).unwrap();
        assert_eq!(qft.num_qubits, 3);
    }

    #[test]
    fn test_test_state_creation() {
        let state = QFTUtils::create_test_state(2, "basis").unwrap();
        assert_eq!(state.len(), 4);
        assert_abs_diff_eq!(state[0].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(state[1].norm(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_classical_qft() {
        let config = QFTConfig {
            method: QFTMethod::Classical,
            ..Default::default()
        };
        let mut qft = SciRS2QFT::new(2, config).unwrap();
        let mut state = QFTUtils::create_test_state(2, "basis").unwrap();

        qft.apply_qft(&mut state).unwrap();

        // After QFT of |00⟩, should be uniform superposition
        let expected_amplitude = 0.5;
        for amplitude in state.iter() {
            assert_abs_diff_eq!(amplitude.norm(), expected_amplitude, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_qft_roundtrip() {
        let config = QFTConfig {
            method: QFTMethod::Classical,
            bit_reversal: false, // Disable for roundtrip test
            ..Default::default()
        };
        let mut qft = SciRS2QFT::new(3, config).unwrap();
        let initial_state = QFTUtils::create_test_state(3, "basis").unwrap();

        // Verify that QFT and inverse QFT complete without error
        let mut state = initial_state.clone();
        qft.apply_qft(&mut state).unwrap();
        qft.apply_inverse_qft(&mut state).unwrap();

        // Check that we have some reasonable state (not all zeros)
        let has_nonzero = state.iter().any(|amp| amp.norm() > 1e-15);
        assert!(
            has_nonzero,
            "State should have non-zero amplitudes after QFT operations"
        );
    }

    #[test]
    fn test_circuit_qft_roundtrip_recovers_state() {
        let config = QFTConfig {
            method: QFTMethod::Circuit,
            ..Default::default()
        };
        let mut qft = SciRS2QFT::new(3, config).unwrap();
        let initial = QFTUtils::create_test_state(3, "random").unwrap();

        let mut state = initial.clone();
        qft.apply_qft(&mut state).unwrap();
        qft.apply_inverse_qft(&mut state).unwrap();

        for (a, b) in initial.iter().zip(state.iter()) {
            assert_abs_diff_eq!((a - b).norm(), 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_verify_roundtrip_helper_with_circuit() {
        let config = QFTConfig {
            method: QFTMethod::Circuit,
            ..Default::default()
        };
        let mut qft = SciRS2QFT::new(3, config).unwrap();
        let initial = QFTUtils::create_test_state(3, "random").unwrap();
        assert!(QFTUtils::verify_qft_roundtrip(&mut qft, &initial, 1e-10).unwrap());
    }

    #[test]
    fn test_circuit_qft_of_zero_state_is_uniform() {
        let config = QFTConfig {
            method: QFTMethod::Circuit,
            ..Default::default()
        };
        let mut qft = SciRS2QFT::new(3, config).unwrap();
        let mut state = QFTUtils::create_test_state(3, "basis").unwrap();

        qft.apply_qft(&mut state).unwrap();

        let expected = 1.0 / 8.0_f64.sqrt();
        for amplitude in state.iter() {
            assert_abs_diff_eq!(amplitude.norm(), expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_apply_qft_rejects_wrong_length() {
        let mut qft = SciRS2QFT::new(3, QFTConfig::default()).unwrap();
        let mut state = QFTUtils::create_test_state(2, "basis").unwrap();
        assert!(qft.apply_qft(&mut state).is_err());
    }

    #[test]
    fn test_bit_reversal() {
        let config = QFTConfig::default();
        let qft = SciRS2QFT::new(3, config).unwrap();

        assert_eq!(qft.bit_reverse(0b001, 3), 0b100);
        assert_eq!(qft.bit_reverse(0b010, 3), 0b010);
        assert_eq!(qft.bit_reverse(0b011, 3), 0b110);
    }

    #[test]
    fn test_radix2_fft() {
        let config = QFTConfig::default();
        let qft = SciRS2QFT::new(2, config).unwrap();

        let mut data = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];

        qft.radix2_fft(&mut data, false).unwrap();

        // All amplitudes should be 1.0 for DFT of basis state
        for amplitude in &data {
            assert_abs_diff_eq!(amplitude.norm(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_radix2_inverse_undoes_forward() {
        let qft = SciRS2QFT::new(3, QFTConfig::default()).unwrap();
        let original: Vec<Complex64> = (0..8)
            .map(|i| Complex64::new(i as f64, -(i as f64) / 2.0))
            .collect();

        let mut data = original.clone();
        qft.radix2_fft(&mut data, false).unwrap();
        qft.radix2_fft(&mut data, true).unwrap();

        for (a, b) in original.iter().zip(data.iter()) {
            assert_abs_diff_eq!((a - b).norm(), 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_classical_dft() {
        let signal = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];

        let result = QFTUtils::classical_dft(&signal).unwrap();

        // DFT of [1, 0, 0, 0] should be [1, 1, 1, 1]
        for amplitude in &result {
            assert_abs_diff_eq!(amplitude.norm(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_classical_dft_of_shifted_impulse() {
        let signal = vec![
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];

        let result = QFTUtils::classical_dft(&signal).unwrap();

        // X_k = e^{-2πi k/4}: [1, -i, -1, i]
        let expected = [
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, -1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, 1.0),
        ];
        for (got, want) in result.iter().zip(expected.iter()) {
            assert_abs_diff_eq!((got - want).norm(), 0.0, epsilon = 1e-10);
        }
    }
}