1use ndarray::{Array1, Array2, ArrayView1, Axis};
9use num_complex::Complex64;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13use crate::dynamic::DynamicCircuit;
14use crate::error::{Result, SimulatorError};
15use crate::scirs2_integration::SciRS2Backend;
16use crate::statevector::StateVectorSimulator;
17
/// Strategy used by [`SciRS2QFT`] to carry out the quantum Fourier transform.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QFTMethod {
    /// FFT routed through the SciRS2 backend when one is attached; without a
    /// backend the classical implementation is used instead.
    SciRS2Exact,
    /// Same as [`QFTMethod::SciRS2Exact`], but small amplitudes are zeroed
    /// when `approximation_level` is greater than zero.
    SciRS2Approximate,
    /// Hadamard and controlled-phase gates applied directly to the state
    /// vector.
    Circuit,
    /// Built-in radix-2 FFT that never touches the backend.
    Classical,
}
30
31#[derive(Debug, Clone)]
33pub struct QFTConfig {
34 pub method: QFTMethod,
36 pub approximation_level: usize,
38 pub bit_reversal: bool,
40 pub parallel: bool,
42 pub precision_threshold: f64,
44}
45
46impl Default for QFTConfig {
47 fn default() -> Self {
48 Self {
49 method: QFTMethod::SciRS2Exact,
50 approximation_level: 0,
51 bit_reversal: true,
52 parallel: true,
53 precision_threshold: 1e-10,
54 }
55 }
56}
57
58#[derive(Debug, Clone, Default, Serialize, Deserialize)]
60pub struct QFTStats {
61 pub execution_time_ms: f64,
63 pub memory_usage_bytes: usize,
65 pub fft_operations: usize,
67 pub approximation_error: f64,
69 pub circuit_gates: usize,
71 pub method_used: String,
73}
74
75pub struct SciRS2QFT {
77 num_qubits: usize,
79 backend: Option<SciRS2Backend>,
81 config: QFTConfig,
83 stats: QFTStats,
85 twiddle_cache: HashMap<usize, Array1<Complex64>>,
87}
88
89impl SciRS2QFT {
90 pub fn new(num_qubits: usize, config: QFTConfig) -> Result<Self> {
92 Ok(Self {
93 num_qubits,
94 backend: None,
95 config,
96 stats: QFTStats::default(),
97 twiddle_cache: HashMap::new(),
98 })
99 }
100
101 pub fn with_backend(mut self) -> Result<Self> {
103 self.backend = Some(SciRS2Backend::new());
104 Ok(self)
105 }
106
107 pub fn apply_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
109 let start_time = std::time::Instant::now();
110
111 if state.len() != 1 << self.num_qubits {
112 return Err(SimulatorError::DimensionMismatch(format!(
113 "State vector length {} doesn't match 2^{} qubits",
114 state.len(),
115 self.num_qubits
116 )));
117 }
118
119 match self.config.method {
120 QFTMethod::SciRS2Exact => self.apply_scirs2_exact_qft(state)?,
121 QFTMethod::SciRS2Approximate => self.apply_scirs2_approximate_qft(state)?,
122 QFTMethod::Circuit => self.apply_circuit_qft(state)?,
123 QFTMethod::Classical => self.apply_classical_qft(state)?,
124 }
125
126 if self.config.bit_reversal {
128 self.apply_bit_reversal(state)?;
129 }
130
131 self.stats.execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
132 self.stats.memory_usage_bytes = state.len() * std::mem::size_of::<Complex64>();
133
134 Ok(())
135 }
136
137 pub fn apply_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
139 let start_time = std::time::Instant::now();
140
141 if self.config.bit_reversal {
143 self.apply_bit_reversal(state)?;
144 }
145
146 match self.config.method {
147 QFTMethod::SciRS2Exact => self.apply_scirs2_exact_inverse_qft(state)?,
148 QFTMethod::SciRS2Approximate => self.apply_scirs2_approximate_inverse_qft(state)?,
149 QFTMethod::Circuit => self.apply_circuit_inverse_qft(state)?,
150 QFTMethod::Classical => self.apply_classical_inverse_qft(state)?,
151 }
152
153 self.stats.execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
154
155 Ok(())
156 }
157
158 fn apply_scirs2_exact_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
160 if let Some(backend) = &mut self.backend {
161 let mut complex_data: Vec<Complex64> = state.to_vec();
163
164 self.scirs2_fft_forward(&mut complex_data)?;
166
167 let normalization = 1.0 / (complex_data.len() as f64).sqrt();
169 for elem in &mut complex_data {
170 *elem *= normalization;
171 }
172
173 for (i, &val) in complex_data.iter().enumerate() {
175 state[i] = val;
176 }
177
178 self.stats.fft_operations += 1;
179 self.stats.method_used = "SciRS2Exact".to_string();
180 } else {
181 self.apply_classical_qft(state)?;
183 }
184
185 Ok(())
186 }
187
188 fn apply_scirs2_approximate_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
190 if let Some(_backend) = &mut self.backend {
191 let mut complex_data: Vec<Complex64> = state.to_vec();
193
194 if self.config.approximation_level > 0 {
196 self.apply_qft_approximation(&mut complex_data)?;
197 }
198
199 self.scirs2_fft_forward(&mut complex_data)?;
201
202 let normalization = 1.0 / (complex_data.len() as f64).sqrt();
204 for elem in &mut complex_data {
205 *elem *= normalization;
206 }
207
208 for (i, &val) in complex_data.iter().enumerate() {
210 state[i] = val;
211 }
212
213 self.stats.fft_operations += 1;
214 self.stats.method_used = "SciRS2Approximate".to_string();
215 } else {
216 self.apply_classical_qft(state)?;
218 }
219
220 Ok(())
221 }
222
223 fn apply_circuit_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
225 for i in 0..self.num_qubits {
227 self.apply_hadamard_to_state(state, i)?;
229
230 for j in (i + 1)..self.num_qubits {
232 let angle = std::f64::consts::PI / 2.0_f64.powi((j - i) as i32);
233 self.apply_controlled_phase_to_state(state, j, i, angle)?;
234 }
235 }
236
237 self.stats.circuit_gates = self.num_qubits * (self.num_qubits + 1) / 2;
238 self.stats.method_used = "Circuit".to_string();
239
240 Ok(())
241 }
242
243 fn apply_classical_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
245 let mut temp_state = state.clone();
246
247 self.cooley_tukey_fft(&mut temp_state, false)?;
249
250 let normalization = 1.0 / (temp_state.len() as f64).sqrt();
252 for elem in &mut temp_state {
253 *elem *= normalization;
254 }
255
256 *state = temp_state;
258
259 self.stats.method_used = "Classical".to_string();
260
261 Ok(())
262 }
263
264 fn apply_scirs2_exact_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
266 if let Some(backend) = &mut self.backend {
267 let mut complex_data: Vec<Complex64> = state.to_vec();
268
269 let normalization = (complex_data.len() as f64).sqrt();
271 for elem in &mut complex_data {
272 *elem *= normalization;
273 }
274
275 self.scirs2_fft_inverse(&mut complex_data)?;
277
278 for (i, &val) in complex_data.iter().enumerate() {
280 state[i] = val;
281 }
282
283 self.stats.fft_operations += 1;
284 self.stats.method_used = "SciRS2ExactInverse".to_string();
285 } else {
286 self.apply_classical_inverse_qft(state)?;
287 }
288
289 Ok(())
290 }
291
292 fn apply_scirs2_approximate_inverse_qft(
294 &mut self,
295 state: &mut Array1<Complex64>,
296 ) -> Result<()> {
297 if let Some(_backend) = &mut self.backend {
298 let mut complex_data: Vec<Complex64> = state.to_vec();
299
300 let normalization = (complex_data.len() as f64).sqrt();
302 for elem in &mut complex_data {
303 *elem *= normalization;
304 }
305
306 self.scirs2_fft_inverse(&mut complex_data)?;
308
309 if self.config.approximation_level > 0 {
311 self.apply_inverse_qft_approximation(&mut complex_data)?;
312 }
313
314 for (i, &val) in complex_data.iter().enumerate() {
316 state[i] = val;
317 }
318
319 self.stats.method_used = "SciRS2ApproximateInverse".to_string();
320 } else {
321 self.apply_classical_inverse_qft(state)?;
322 }
323
324 Ok(())
325 }
326
327 fn apply_circuit_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
329 for i in (0..self.num_qubits).rev() {
331 for j in ((i + 1)..self.num_qubits).rev() {
333 let angle = -std::f64::consts::PI / 2.0_f64.powi((j - i) as i32);
334 self.apply_controlled_phase_to_state(state, j, i, angle)?;
335 }
336
337 self.apply_hadamard_to_state(state, i)?;
339 }
340
341 self.stats.circuit_gates = self.num_qubits * (self.num_qubits + 1) / 2;
342 self.stats.method_used = "CircuitInverse".to_string();
343
344 Ok(())
345 }
346
347 fn apply_classical_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
349 let mut temp_state = state.clone();
350
351 self.cooley_tukey_fft(&mut temp_state, true)?;
353
354 let normalization = 1.0 / (temp_state.len() as f64).sqrt();
356 for elem in &mut temp_state {
357 *elem *= normalization;
358 }
359
360 *state = temp_state;
361
362 self.stats.method_used = "ClassicalInverse".to_string();
363
364 Ok(())
365 }
366
367 fn scirs2_fft_forward(&self, data: &mut [Complex64]) -> Result<()> {
369 if let Some(ref backend) = self.backend {
370 if backend.is_available() {
371 use crate::scirs2_integration::{SciRS2MemoryAllocator, SciRS2Vector};
373 use ndarray::Array1;
374
375 let _allocator = SciRS2MemoryAllocator::new();
376 let input_array = Array1::from_vec(data.to_vec());
377 let scirs2_vector = SciRS2Vector::from_array1(input_array);
378
379 #[cfg(feature = "advanced_math")]
381 {
382 let result_vector =
383 backend.fft_engine.forward(&scirs2_vector).map_err(|e| {
384 SimulatorError::ComputationError(format!("SciRS2 FFT failed: {}", e))
385 })?;
386
387 let result_array = result_vector.to_array1().map_err(|e| {
389 SimulatorError::ComputationError(format!(
390 "Failed to extract FFT result: {}",
391 e
392 ))
393 })?;
394 data.copy_from_slice(result_array.as_slice().unwrap());
395 }
396 #[cfg(not(feature = "advanced_math"))]
397 {
398 self.radix2_fft(data, false)?;
400 }
401
402 Ok(())
403 } else {
404 self.radix2_fft(data, false)?;
406 Ok(())
407 }
408 } else {
409 self.radix2_fft(data, false)?;
411 Ok(())
412 }
413 }
414
415 fn scirs2_fft_inverse(&self, data: &mut [Complex64]) -> Result<()> {
417 if let Some(ref backend) = self.backend {
418 if backend.is_available() {
419 use crate::scirs2_integration::{SciRS2MemoryAllocator, SciRS2Vector};
421 use ndarray::Array1;
422
423 let _allocator = SciRS2MemoryAllocator::new();
424 let input_array = Array1::from_vec(data.to_vec());
425 let scirs2_vector = SciRS2Vector::from_array1(input_array);
426
427 #[cfg(feature = "advanced_math")]
429 {
430 let result_vector =
431 backend.fft_engine.inverse(&scirs2_vector).map_err(|e| {
432 SimulatorError::ComputationError(format!(
433 "SciRS2 inverse FFT failed: {}",
434 e
435 ))
436 })?;
437
438 let result_array = result_vector.to_array1().map_err(|e| {
440 SimulatorError::ComputationError(format!(
441 "Failed to extract inverse FFT result: {}",
442 e
443 ))
444 })?;
445 data.copy_from_slice(result_array.as_slice().unwrap());
446 }
447 #[cfg(not(feature = "advanced_math"))]
448 {
449 self.radix2_fft(data, true)?;
451 }
452
453 Ok(())
454 } else {
455 self.radix2_fft(data, true)?;
457 Ok(())
458 }
459 } else {
460 self.radix2_fft(data, true)?;
462 Ok(())
463 }
464 }
465
466 fn radix2_fft(&self, data: &mut [Complex64], inverse: bool) -> Result<()> {
468 let n = data.len();
469 if !n.is_power_of_two() {
470 return Err(SimulatorError::InvalidInput(
471 "FFT size must be power of 2".to_string(),
472 ));
473 }
474
475 let mut j = 0;
477 for i in 1..n {
478 let mut bit = n >> 1;
479 while j & bit != 0 {
480 j ^= bit;
481 bit >>= 1;
482 }
483 j ^= bit;
484
485 if i < j {
486 data.swap(i, j);
487 }
488 }
489
490 let mut length = 2;
492 while length <= n {
493 let angle = if inverse { 2.0 } else { -2.0 } * std::f64::consts::PI / length as f64;
494 let wlen = Complex64::new(angle.cos(), angle.sin());
495
496 for i in (0..n).step_by(length) {
497 let mut w = Complex64::new(1.0, 0.0);
498 for j in 0..length / 2 {
499 let u = data[i + j];
500 let v = data[i + j + length / 2] * w;
501 data[i + j] = u + v;
502 data[i + j + length / 2] = u - v;
503 w *= wlen;
504 }
505 }
506 length <<= 1;
507 }
508
509 if inverse {
511 let norm = 1.0 / n as f64;
512 for elem in data {
513 *elem *= norm;
514 }
515 }
516
517 Ok(())
518 }
519
520 fn cooley_tukey_fft(&self, data: &mut Array1<Complex64>, inverse: bool) -> Result<()> {
522 let mut temp_data = data.to_vec();
523 self.radix2_fft(&mut temp_data, inverse)?;
524
525 for (i, &val) in temp_data.iter().enumerate() {
526 data[i] = val;
527 }
528
529 Ok(())
530 }
531
532 fn apply_qft_approximation(&self, data: &mut [Complex64]) -> Result<()> {
534 let threshold =
536 self.config.precision_threshold * 10.0_f64.powi(self.config.approximation_level as i32);
537
538 for elem in data.iter_mut() {
539 if elem.norm() < threshold {
540 *elem = Complex64::new(0.0, 0.0);
541 }
542 }
543
544 Ok(())
545 }
546
547 fn apply_inverse_qft_approximation(&self, data: &mut [Complex64]) -> Result<()> {
549 self.apply_qft_approximation(data)
551 }
552
553 fn apply_bit_reversal(&self, state: &mut Array1<Complex64>) -> Result<()> {
555 let n = state.len();
556 let num_bits = self.num_qubits;
557
558 for i in 0..n {
559 let j = self.bit_reverse(i, num_bits);
560 if i < j {
561 let temp = state[i];
562 state[i] = state[j];
563 state[j] = temp;
564 }
565 }
566
567 Ok(())
568 }
569
570 fn bit_reverse(&self, num: usize, bits: usize) -> usize {
572 let mut result = 0;
573 let mut n = num;
574 for _ in 0..bits {
575 result = (result << 1) | (n & 1);
576 n >>= 1;
577 }
578 result
579 }
580
581 fn apply_hadamard_to_state(&self, state: &mut Array1<Complex64>, target: usize) -> Result<()> {
583 let n = state.len();
584 let sqrt_half = 1.0 / 2.0_f64.sqrt();
585
586 for i in 0..n {
587 let bit_mask = 1 << (self.num_qubits - 1 - target);
588 let partner = i ^ bit_mask;
589
590 if i < partner {
591 let (val_i, val_partner) = (state[i], state[partner]);
592 state[i] = sqrt_half * (val_i + val_partner);
593 state[partner] = sqrt_half * (val_i - val_partner);
594 }
595 }
596
597 Ok(())
598 }
599
600 fn apply_controlled_phase_to_state(
602 &self,
603 state: &mut Array1<Complex64>,
604 control: usize,
605 target: usize,
606 angle: f64,
607 ) -> Result<()> {
608 let n = state.len();
609 let phase = Complex64::new(angle.cos(), angle.sin());
610
611 let control_mask = 1 << (self.num_qubits - 1 - control);
612 let target_mask = 1 << (self.num_qubits - 1 - target);
613
614 for i in 0..n {
615 if (i & control_mask) != 0 && (i & target_mask) != 0 {
617 state[i] *= phase;
618 }
619 }
620
621 Ok(())
622 }
623
624 pub fn get_stats(&self) -> &QFTStats {
626 &self.stats
627 }
628
629 pub fn reset_stats(&mut self) {
631 self.stats = QFTStats::default();
632 }
633
634 pub fn set_config(&mut self, config: QFTConfig) {
636 self.config = config;
637 }
638
639 pub fn get_config(&self) -> &QFTConfig {
641 &self.config
642 }
643}
644
645pub struct QFTUtils;
647
648impl QFTUtils {
649 pub fn create_test_state(num_qubits: usize, pattern: &str) -> Result<Array1<Complex64>> {
651 let dim = 1 << num_qubits;
652 let mut state = Array1::zeros(dim);
653
654 match pattern {
655 "uniform" => {
656 let amplitude = 1.0 / (dim as f64).sqrt();
658 for i in 0..dim {
659 state[i] = Complex64::new(amplitude, 0.0);
660 }
661 }
662 "basis" => {
663 state[0] = Complex64::new(1.0, 0.0);
665 }
666 "alternating" => {
667 for i in 0..dim {
669 let amplitude = if i % 2 == 0 { 1.0 } else { -1.0 };
670 state[i] = Complex64::new(amplitude / (dim as f64).sqrt(), 0.0);
671 }
672 }
673 "random" => {
674 for i in 0..dim {
676 state[i] = Complex64::new(fastrand::f64() - 0.5, fastrand::f64() - 0.5);
677 }
678 let norm = state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
680 for elem in &mut state {
681 *elem /= norm;
682 }
683 }
684 _ => {
685 return Err(SimulatorError::InvalidInput(format!(
686 "Unknown test pattern: {}",
687 pattern
688 )));
689 }
690 }
691
692 Ok(state)
693 }
694
695 pub fn verify_qft_roundtrip(
697 qft: &mut SciRS2QFT,
698 initial_state: &Array1<Complex64>,
699 tolerance: f64,
700 ) -> Result<bool> {
701 let mut state = initial_state.clone();
702
703 qft.apply_qft(&mut state)?;
705
706 qft.apply_inverse_qft(&mut state)?;
708
709 let overlap = initial_state
711 .iter()
712 .zip(state.iter())
713 .map(|(a, b)| a.conj() * b)
714 .sum::<Complex64>();
715 let fidelity = overlap.norm();
716
717 Ok((1.0 - fidelity).abs() < tolerance)
718 }
719
720 pub fn classical_dft(signal: &[Complex64]) -> Result<Vec<Complex64>> {
722 let n = signal.len();
723 let mut result = vec![Complex64::new(0.0, 0.0); n];
724
725 for k in 0..n {
726 for t in 0..n {
727 let angle = -2.0 * std::f64::consts::PI * k as f64 * t as f64 / n as f64;
728 let twiddle = Complex64::new(angle.cos(), angle.sin());
729 result[k] += signal[t] * twiddle;
730 }
731 }
732
733 Ok(result)
734 }
735}
736
737pub fn benchmark_qft_methods(num_qubits: usize) -> Result<HashMap<String, QFTStats>> {
739 let mut results = HashMap::new();
740 let test_state = QFTUtils::create_test_state(num_qubits, "random")?;
741
742 let methods = vec![
744 ("SciRS2Exact", QFTMethod::SciRS2Exact),
745 ("SciRS2Approximate", QFTMethod::SciRS2Approximate),
746 ("Circuit", QFTMethod::Circuit),
747 ("Classical", QFTMethod::Classical),
748 ];
749
750 for (name, method) in methods {
751 let config = QFTConfig {
752 method,
753 approximation_level: if method == QFTMethod::SciRS2Approximate {
754 1
755 } else {
756 0
757 },
758 bit_reversal: true,
759 parallel: true,
760 precision_threshold: 1e-10,
761 };
762
763 let mut qft = if method == QFTMethod::SciRS2Exact || method == QFTMethod::SciRS2Approximate
764 {
765 SciRS2QFT::new(num_qubits, config.clone())?
766 .with_backend()
767 .unwrap_or_else(|_| SciRS2QFT::new(num_qubits, config).unwrap())
768 } else {
769 SciRS2QFT::new(num_qubits, config)?
770 };
771
772 let mut state = test_state.clone();
773
774 qft.apply_qft(&mut state)?;
776
777 results.insert(name.to_string(), qft.get_stats().clone());
778 }
779
780 Ok(results)
781}
782
783pub fn compare_qft_accuracy(num_qubits: usize) -> Result<HashMap<String, f64>> {
785 let mut errors = HashMap::new();
786 let test_state = QFTUtils::create_test_state(num_qubits, "random")?;
787
788 let classical_signal: Vec<Complex64> = test_state.to_vec();
790 let reference_result = QFTUtils::classical_dft(&classical_signal)?;
791
792 let methods = vec![
794 ("SciRS2Exact", QFTMethod::SciRS2Exact),
795 ("SciRS2Approximate", QFTMethod::SciRS2Approximate),
796 ("Circuit", QFTMethod::Circuit),
797 ("Classical", QFTMethod::Classical),
798 ];
799
800 for (name, method) in methods {
801 let config = QFTConfig {
802 method,
803 approximation_level: if method == QFTMethod::SciRS2Approximate {
804 1
805 } else {
806 0
807 },
808 bit_reversal: false, parallel: true,
810 precision_threshold: 1e-10,
811 };
812
813 let mut qft = if method == QFTMethod::SciRS2Exact || method == QFTMethod::SciRS2Approximate
814 {
815 SciRS2QFT::new(num_qubits, config.clone())?
816 .with_backend()
817 .unwrap_or_else(|_| SciRS2QFT::new(num_qubits, config).unwrap())
818 } else {
819 SciRS2QFT::new(num_qubits, config)?
820 };
821
822 let mut state = test_state.clone();
823 qft.apply_qft(&mut state)?;
824
825 let error = reference_result
827 .iter()
828 .zip(state.iter())
829 .map(|(ref_val, qft_val)| (ref_val - qft_val).norm())
830 .sum::<f64>()
831 / reference_result.len() as f64;
832
833 errors.insert(name.to_string(), error);
834 }
835
836 Ok(errors)
837}
838
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Four-sample unit impulse: 1 at index 0, zero elsewhere.
    fn unit_impulse() -> Vec<Complex64> {
        let mut samples = vec![Complex64::new(0.0, 0.0); 4];
        samples[0] = Complex64::new(1.0, 0.0);
        samples
    }

    /// Engine built from the given configuration.
    fn engine(num_qubits: usize, config: QFTConfig) -> SciRS2QFT {
        SciRS2QFT::new(num_qubits, config).unwrap()
    }

    #[test]
    fn test_qft_config_default() {
        let QFTConfig {
            method,
            approximation_level,
            bit_reversal,
            parallel,
            ..
        } = QFTConfig::default();

        assert_eq!(method, QFTMethod::SciRS2Exact);
        assert_eq!(approximation_level, 0);
        assert!(bit_reversal);
        assert!(parallel);
    }

    #[test]
    fn test_scirs2_qft_creation() {
        let qft = engine(3, QFTConfig::default());
        assert_eq!(qft.num_qubits, 3);
    }

    #[test]
    fn test_test_state_creation() {
        let state = QFTUtils::create_test_state(2, "basis").unwrap();

        assert_eq!(state.len(), 4);
        assert_abs_diff_eq!(state[0].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(state[1].norm(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_classical_qft() {
        let mut qft = engine(
            2,
            QFTConfig {
                method: QFTMethod::Classical,
                ..Default::default()
            },
        );
        let mut state = QFTUtils::create_test_state(2, "basis").unwrap();

        qft.apply_qft(&mut state).unwrap();

        // The QFT of |00> is the uniform superposition: every magnitude is 1/2.
        let expected_amplitude = 0.5;
        state.iter().for_each(|amplitude| {
            assert_abs_diff_eq!(amplitude.norm(), expected_amplitude, epsilon = 1e-10);
        });
    }

    #[test]
    fn test_qft_roundtrip() {
        let mut qft = engine(
            3,
            QFTConfig {
                method: QFTMethod::Classical,
                bit_reversal: false,
                ..Default::default()
            },
        );
        let initial_state = QFTUtils::create_test_state(3, "basis").unwrap();
        let mut state = initial_state.clone();

        qft.apply_qft(&mut state).unwrap();
        qft.apply_inverse_qft(&mut state).unwrap();

        // Smoke check only: this asserts that something non-zero survives the
        // round trip, not that the initial state is recovered.
        let has_nonzero = state.iter().any(|amp| amp.norm() > 1e-15);
        assert!(
            has_nonzero,
            "State should have non-zero amplitudes after QFT operations"
        );
    }

    #[test]
    fn test_bit_reversal() {
        let qft = engine(3, QFTConfig::default());

        let cases = [(0b001, 0b100), (0b010, 0b010), (0b011, 0b110)];
        for &(input, expected) in &cases {
            assert_eq!(qft.bit_reverse(input, 3), expected);
        }
    }

    #[test]
    fn test_radix2_fft() {
        let qft = engine(2, QFTConfig::default());
        let mut data = unit_impulse();

        qft.radix2_fft(&mut data, false).unwrap();

        // The unnormalized FFT of an impulse has unit magnitude everywhere.
        for amplitude in &data {
            assert_abs_diff_eq!(amplitude.norm(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_classical_dft() {
        let signal = unit_impulse();

        let result = QFTUtils::classical_dft(&signal).unwrap();

        // The unnormalized DFT of an impulse has unit magnitude everywhere.
        for amplitude in &result {
            assert_abs_diff_eq!(amplitude.norm(), 1.0, epsilon = 1e-10);
        }
    }
}