use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
use scirs2_core::Complex64;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::dynamic::DynamicCircuit;
use crate::error::{Result, SimulatorError};
use crate::scirs2_integration::SciRS2Backend;
use crate::statevector::StateVectorSimulator;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QFTMethod {
SciRS2Exact,
SciRS2Approximate,
Circuit,
Classical,
}
#[derive(Debug, Clone)]
pub struct QFTConfig {
pub method: QFTMethod,
pub approximation_level: usize,
pub bit_reversal: bool,
pub parallel: bool,
pub precision_threshold: f64,
}
impl Default for QFTConfig {
fn default() -> Self {
Self {
method: QFTMethod::SciRS2Exact,
approximation_level: 0,
bit_reversal: true,
parallel: true,
precision_threshold: 1e-10,
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct QFTStats {
pub execution_time_ms: f64,
pub memory_usage_bytes: usize,
pub fft_operations: usize,
pub approximation_error: f64,
pub circuit_gates: usize,
pub method_used: String,
}
pub struct SciRS2QFT {
num_qubits: usize,
backend: Option<SciRS2Backend>,
config: QFTConfig,
stats: QFTStats,
twiddle_cache: HashMap<usize, Array1<Complex64>>,
}
impl SciRS2QFT {
pub fn new(num_qubits: usize, config: QFTConfig) -> Result<Self> {
Ok(Self {
num_qubits,
backend: None,
config,
stats: QFTStats::default(),
twiddle_cache: HashMap::new(),
})
}
pub fn with_backend(mut self) -> Result<Self> {
self.backend = Some(SciRS2Backend::new());
Ok(self)
}
pub fn apply_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
let start_time = std::time::Instant::now();
if state.len() != 1 << self.num_qubits {
return Err(SimulatorError::DimensionMismatch(format!(
"State vector length {} doesn't match 2^{} qubits",
state.len(),
self.num_qubits
)));
}
match self.config.method {
QFTMethod::SciRS2Exact => self.apply_scirs2_exact_qft(state)?,
QFTMethod::SciRS2Approximate => self.apply_scirs2_approximate_qft(state)?,
QFTMethod::Circuit => self.apply_circuit_qft(state)?,
QFTMethod::Classical => self.apply_classical_qft(state)?,
}
if self.config.bit_reversal {
self.apply_bit_reversal(state)?;
}
self.stats.execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
self.stats.memory_usage_bytes = state.len() * std::mem::size_of::<Complex64>();
Ok(())
}
pub fn apply_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
let start_time = std::time::Instant::now();
if self.config.bit_reversal {
self.apply_bit_reversal(state)?;
}
match self.config.method {
QFTMethod::SciRS2Exact => self.apply_scirs2_exact_inverse_qft(state)?,
QFTMethod::SciRS2Approximate => self.apply_scirs2_approximate_inverse_qft(state)?,
QFTMethod::Circuit => self.apply_circuit_inverse_qft(state)?,
QFTMethod::Classical => self.apply_classical_inverse_qft(state)?,
}
self.stats.execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
Ok(())
}
fn apply_scirs2_exact_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
if let Some(backend) = &mut self.backend {
let mut complex_data: Vec<Complex64> = state.to_vec();
self.scirs2_fft_forward(&mut complex_data)?;
let normalization = 1.0 / (complex_data.len() as f64).sqrt();
for elem in &mut complex_data {
*elem *= normalization;
}
for (i, &val) in complex_data.iter().enumerate() {
state[i] = val;
}
self.stats.fft_operations += 1;
self.stats.method_used = "SciRS2Exact".to_string();
} else {
self.apply_classical_qft(state)?;
}
Ok(())
}
fn apply_scirs2_approximate_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
if let Some(_backend) = &mut self.backend {
let mut complex_data: Vec<Complex64> = state.to_vec();
if self.config.approximation_level > 0 {
self.apply_qft_approximation(&mut complex_data)?;
}
self.scirs2_fft_forward(&mut complex_data)?;
let normalization = 1.0 / (complex_data.len() as f64).sqrt();
for elem in &mut complex_data {
*elem *= normalization;
}
for (i, &val) in complex_data.iter().enumerate() {
state[i] = val;
}
self.stats.fft_operations += 1;
self.stats.method_used = "SciRS2Approximate".to_string();
} else {
self.apply_classical_qft(state)?;
}
Ok(())
}
fn apply_circuit_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
for i in 0..self.num_qubits {
self.apply_hadamard_to_state(state, i)?;
for j in (i + 1)..self.num_qubits {
let angle = std::f64::consts::PI / 2.0_f64.powi((j - i) as i32);
self.apply_controlled_phase_to_state(state, j, i, angle)?;
}
}
self.stats.circuit_gates = self.num_qubits * (self.num_qubits + 1) / 2;
self.stats.method_used = "Circuit".to_string();
Ok(())
}
fn apply_classical_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
let mut temp_state = state.clone();
self.cooley_tukey_fft(&mut temp_state, false)?;
let normalization = 1.0 / (temp_state.len() as f64).sqrt();
for elem in &mut temp_state {
*elem *= normalization;
}
*state = temp_state;
self.stats.method_used = "Classical".to_string();
Ok(())
}
fn apply_scirs2_exact_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
if let Some(backend) = &mut self.backend {
let mut complex_data: Vec<Complex64> = state.to_vec();
let normalization = (complex_data.len() as f64).sqrt();
for elem in &mut complex_data {
*elem *= normalization;
}
self.scirs2_fft_inverse(&mut complex_data)?;
for (i, &val) in complex_data.iter().enumerate() {
state[i] = val;
}
self.stats.fft_operations += 1;
self.stats.method_used = "SciRS2ExactInverse".to_string();
} else {
self.apply_classical_inverse_qft(state)?;
}
Ok(())
}
fn apply_scirs2_approximate_inverse_qft(
&mut self,
state: &mut Array1<Complex64>,
) -> Result<()> {
if let Some(_backend) = &mut self.backend {
let mut complex_data: Vec<Complex64> = state.to_vec();
let normalization = (complex_data.len() as f64).sqrt();
for elem in &mut complex_data {
*elem *= normalization;
}
self.scirs2_fft_inverse(&mut complex_data)?;
if self.config.approximation_level > 0 {
self.apply_inverse_qft_approximation(&mut complex_data)?;
}
for (i, &val) in complex_data.iter().enumerate() {
state[i] = val;
}
self.stats.method_used = "SciRS2ApproximateInverse".to_string();
} else {
self.apply_classical_inverse_qft(state)?;
}
Ok(())
}
fn apply_circuit_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
for i in (0..self.num_qubits).rev() {
for j in ((i + 1)..self.num_qubits).rev() {
let angle = -std::f64::consts::PI / 2.0_f64.powi((j - i) as i32);
self.apply_controlled_phase_to_state(state, j, i, angle)?;
}
self.apply_hadamard_to_state(state, i)?;
}
self.stats.circuit_gates = self.num_qubits * (self.num_qubits + 1) / 2;
self.stats.method_used = "CircuitInverse".to_string();
Ok(())
}
fn apply_classical_inverse_qft(&mut self, state: &mut Array1<Complex64>) -> Result<()> {
let mut temp_state = state.clone();
self.cooley_tukey_fft(&mut temp_state, true)?;
let normalization = 1.0 / (temp_state.len() as f64).sqrt();
for elem in &mut temp_state {
*elem *= normalization;
}
*state = temp_state;
self.stats.method_used = "ClassicalInverse".to_string();
Ok(())
}
fn scirs2_fft_forward(&self, data: &mut [Complex64]) -> Result<()> {
if let Some(ref backend) = self.backend {
if backend.is_available() {
use crate::scirs2_integration::{SciRS2MemoryAllocator, SciRS2Vector};
use scirs2_core::ndarray::Array1;
let _allocator = SciRS2MemoryAllocator::new();
let input_array = Array1::from_vec(data.to_vec());
let scirs2_vector = SciRS2Vector::from_array1(input_array);
#[cfg(feature = "advanced_math")]
{
let result_vector =
backend.fft_engine.forward(&scirs2_vector).map_err(|e| {
SimulatorError::ComputationError(format!("SciRS2 FFT failed: {e}"))
})?;
let result_array = result_vector.to_array1().map_err(|e| {
SimulatorError::ComputationError(format!(
"Failed to extract FFT result: {e}"
))
})?;
data.copy_from_slice(
result_array
.as_slice()
.expect("1D contiguous array has a valid slice"),
);
}
#[cfg(not(feature = "advanced_math"))]
{
self.radix2_fft(data, false)?;
}
Ok(())
} else {
self.radix2_fft(data, false)?;
Ok(())
}
} else {
self.radix2_fft(data, false)?;
Ok(())
}
}
fn scirs2_fft_inverse(&self, data: &mut [Complex64]) -> Result<()> {
if let Some(ref backend) = self.backend {
if backend.is_available() {
use crate::scirs2_integration::{SciRS2MemoryAllocator, SciRS2Vector};
use scirs2_core::ndarray::Array1;
let _allocator = SciRS2MemoryAllocator::new();
let input_array = Array1::from_vec(data.to_vec());
let scirs2_vector = SciRS2Vector::from_array1(input_array);
#[cfg(feature = "advanced_math")]
{
let result_vector =
backend.fft_engine.inverse(&scirs2_vector).map_err(|e| {
SimulatorError::ComputationError(format!(
"SciRS2 inverse FFT failed: {e}"
))
})?;
let result_array = result_vector.to_array1().map_err(|e| {
SimulatorError::ComputationError(format!(
"Failed to extract inverse FFT result: {e}"
))
})?;
data.copy_from_slice(
result_array
.as_slice()
.expect("1D contiguous array has a valid slice"),
);
}
#[cfg(not(feature = "advanced_math"))]
{
self.radix2_fft(data, true)?;
}
Ok(())
} else {
self.radix2_fft(data, true)?;
Ok(())
}
} else {
self.radix2_fft(data, true)?;
Ok(())
}
}
fn radix2_fft(&self, data: &mut [Complex64], inverse: bool) -> Result<()> {
let n = data.len();
if !n.is_power_of_two() {
return Err(SimulatorError::InvalidInput(
"FFT size must be power of 2".to_string(),
));
}
let mut j = 0;
for i in 1..n {
let mut bit = n >> 1;
while j & bit != 0 {
j ^= bit;
bit >>= 1;
}
j ^= bit;
if i < j {
data.swap(i, j);
}
}
let mut length = 2;
while length <= n {
let angle = if inverse { 2.0 } else { -2.0 } * std::f64::consts::PI / length as f64;
let wlen = Complex64::new(angle.cos(), angle.sin());
for i in (0..n).step_by(length) {
let mut w = Complex64::new(1.0, 0.0);
for j in 0..length / 2 {
let u = data[i + j];
let v = data[i + j + length / 2] * w;
data[i + j] = u + v;
data[i + j + length / 2] = u - v;
w *= wlen;
}
}
length <<= 1;
}
if inverse {
let norm = 1.0 / n as f64;
for elem in data {
*elem *= norm;
}
}
Ok(())
}
fn cooley_tukey_fft(&self, data: &mut Array1<Complex64>, inverse: bool) -> Result<()> {
let mut temp_data = data.to_vec();
self.radix2_fft(&mut temp_data, inverse)?;
for (i, &val) in temp_data.iter().enumerate() {
data[i] = val;
}
Ok(())
}
fn apply_qft_approximation(&self, data: &mut [Complex64]) -> Result<()> {
let threshold =
self.config.precision_threshold * 10.0_f64.powi(self.config.approximation_level as i32);
for elem in data.iter_mut() {
if elem.norm() < threshold {
*elem = Complex64::new(0.0, 0.0);
}
}
Ok(())
}
fn apply_inverse_qft_approximation(&self, data: &mut [Complex64]) -> Result<()> {
self.apply_qft_approximation(data)
}
fn apply_bit_reversal(&self, state: &mut Array1<Complex64>) -> Result<()> {
let n = state.len();
let num_bits = self.num_qubits;
for i in 0..n {
let j = self.bit_reverse(i, num_bits);
if i < j {
let temp = state[i];
state[i] = state[j];
state[j] = temp;
}
}
Ok(())
}
fn bit_reverse(&self, num: usize, bits: usize) -> usize {
let mut result = 0;
let mut n = num;
for _ in 0..bits {
result = (result << 1) | (n & 1);
n >>= 1;
}
result
}
fn apply_hadamard_to_state(&self, state: &mut Array1<Complex64>, target: usize) -> Result<()> {
let n = state.len();
let sqrt_half = 1.0 / 2.0_f64.sqrt();
for i in 0..n {
let bit_mask = 1 << (self.num_qubits - 1 - target);
let partner = i ^ bit_mask;
if i < partner {
let (val_i, val_partner) = (state[i], state[partner]);
state[i] = sqrt_half * (val_i + val_partner);
state[partner] = sqrt_half * (val_i - val_partner);
}
}
Ok(())
}
fn apply_controlled_phase_to_state(
&self,
state: &mut Array1<Complex64>,
control: usize,
target: usize,
angle: f64,
) -> Result<()> {
let n = state.len();
let phase = Complex64::new(angle.cos(), angle.sin());
let control_mask = 1 << (self.num_qubits - 1 - control);
let target_mask = 1 << (self.num_qubits - 1 - target);
for i in 0..n {
if (i & control_mask) != 0 && (i & target_mask) != 0 {
state[i] *= phase;
}
}
Ok(())
}
#[must_use]
pub const fn get_stats(&self) -> &QFTStats {
&self.stats
}
pub fn reset_stats(&mut self) {
self.stats = QFTStats::default();
}
pub const fn set_config(&mut self, config: QFTConfig) {
self.config = config;
}
#[must_use]
pub const fn get_config(&self) -> &QFTConfig {
&self.config
}
}
pub struct QFTUtils;
impl QFTUtils {
pub fn create_test_state(num_qubits: usize, pattern: &str) -> Result<Array1<Complex64>> {
let dim = 1 << num_qubits;
let mut state = Array1::zeros(dim);
match pattern {
"uniform" => {
let amplitude = 1.0 / (dim as f64).sqrt();
for i in 0..dim {
state[i] = Complex64::new(amplitude, 0.0);
}
}
"basis" => {
state[0] = Complex64::new(1.0, 0.0);
}
"alternating" => {
for i in 0..dim {
let amplitude = if i % 2 == 0 { 1.0 } else { -1.0 };
state[i] = Complex64::new(amplitude / (dim as f64).sqrt(), 0.0);
}
}
"random" => {
for i in 0..dim {
state[i] = Complex64::new(fastrand::f64() - 0.5, fastrand::f64() - 0.5);
}
let norm = state
.iter()
.map(scirs2_core::Complex::norm_sqr)
.sum::<f64>()
.sqrt();
for elem in &mut state {
*elem /= norm;
}
}
_ => {
return Err(SimulatorError::InvalidInput(format!(
"Unknown test pattern: {pattern}"
)));
}
}
Ok(state)
}
pub fn verify_qft_roundtrip(
qft: &mut SciRS2QFT,
initial_state: &Array1<Complex64>,
tolerance: f64,
) -> Result<bool> {
let mut state = initial_state.clone();
qft.apply_qft(&mut state)?;
qft.apply_inverse_qft(&mut state)?;
let overlap = initial_state
.iter()
.zip(state.iter())
.map(|(a, b)| a.conj() * b)
.sum::<Complex64>();
let fidelity = overlap.norm();
Ok((1.0 - fidelity).abs() < tolerance)
}
pub fn classical_dft(signal: &[Complex64]) -> Result<Vec<Complex64>> {
let n = signal.len();
let mut result = vec![Complex64::new(0.0, 0.0); n];
for k in 0..n {
for t in 0..n {
let angle = -2.0 * std::f64::consts::PI * k as f64 * t as f64 / n as f64;
let twiddle = Complex64::new(angle.cos(), angle.sin());
result[k] += signal[t] * twiddle;
}
}
Ok(result)
}
}
pub fn benchmark_qft_methods(num_qubits: usize) -> Result<HashMap<String, QFTStats>> {
let mut results = HashMap::new();
let test_state = QFTUtils::create_test_state(num_qubits, "random")?;
let methods = vec![
("SciRS2Exact", QFTMethod::SciRS2Exact),
("SciRS2Approximate", QFTMethod::SciRS2Approximate),
("Circuit", QFTMethod::Circuit),
("Classical", QFTMethod::Classical),
];
for (name, method) in methods {
let config = QFTConfig {
method,
approximation_level: usize::from(method == QFTMethod::SciRS2Approximate),
bit_reversal: true,
parallel: true,
precision_threshold: 1e-10,
};
let mut qft = if method == QFTMethod::SciRS2Exact || method == QFTMethod::SciRS2Approximate
{
match SciRS2QFT::new(num_qubits, config.clone())?.with_backend() {
Ok(qft_with_backend) => qft_with_backend,
Err(_) => SciRS2QFT::new(num_qubits, config)
.expect("QFT creation should succeed with same config"),
}
} else {
SciRS2QFT::new(num_qubits, config)?
};
let mut state = test_state.clone();
qft.apply_qft(&mut state)?;
results.insert(name.to_string(), qft.get_stats().clone());
}
Ok(results)
}
pub fn compare_qft_accuracy(num_qubits: usize) -> Result<HashMap<String, f64>> {
let mut errors = HashMap::new();
let test_state = QFTUtils::create_test_state(num_qubits, "random")?;
let classical_signal: Vec<Complex64> = test_state.to_vec();
let reference_result = QFTUtils::classical_dft(&classical_signal)?;
let methods = vec![
("SciRS2Exact", QFTMethod::SciRS2Exact),
("SciRS2Approximate", QFTMethod::SciRS2Approximate),
("Circuit", QFTMethod::Circuit),
("Classical", QFTMethod::Classical),
];
for (name, method) in methods {
let config = QFTConfig {
method,
approximation_level: usize::from(method == QFTMethod::SciRS2Approximate),
bit_reversal: false, parallel: true,
precision_threshold: 1e-10,
};
let mut qft = if method == QFTMethod::SciRS2Exact || method == QFTMethod::SciRS2Approximate
{
match SciRS2QFT::new(num_qubits, config.clone())?.with_backend() {
Ok(qft_with_backend) => qft_with_backend,
Err(_) => SciRS2QFT::new(num_qubits, config)
.expect("QFT creation should succeed with same config"),
}
} else {
SciRS2QFT::new(num_qubits, config)?
};
let mut state = test_state.clone();
qft.apply_qft(&mut state)?;
let error = reference_result
.iter()
.zip(state.iter())
.map(|(ref_val, qft_val)| (ref_val - qft_val).norm())
.sum::<f64>()
/ reference_result.len() as f64;
errors.insert(name.to_string(), error);
}
Ok(errors)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
#[test]
fn test_qft_config_default() {
let config = QFTConfig::default();
assert_eq!(config.method, QFTMethod::SciRS2Exact);
assert_eq!(config.approximation_level, 0);
assert!(config.bit_reversal);
assert!(config.parallel);
}
#[test]
fn test_scirs2_qft_creation() {
let config = QFTConfig::default();
let qft = SciRS2QFT::new(3, config).expect("should create SciRS2 QFT");
assert_eq!(qft.num_qubits, 3);
}
#[test]
fn test_test_state_creation() {
let state = QFTUtils::create_test_state(2, "basis").expect("should create test state");
assert_eq!(state.len(), 4);
assert_abs_diff_eq!(state[0].re, 1.0, epsilon = 1e-10);
assert_abs_diff_eq!(state[1].norm(), 0.0, epsilon = 1e-10);
}
#[test]
fn test_classical_qft() {
let config = QFTConfig {
method: QFTMethod::Classical,
..Default::default()
};
let mut qft = SciRS2QFT::new(2, config).expect("should create SciRS2 QFT");
let mut state = QFTUtils::create_test_state(2, "basis").expect("should create test state");
qft.apply_qft(&mut state).expect("should apply QFT");
let expected_amplitude = 0.5;
for amplitude in &state {
assert_abs_diff_eq!(amplitude.norm(), expected_amplitude, epsilon = 1e-10);
}
}
#[test]
fn test_qft_roundtrip() {
let config = QFTConfig {
method: QFTMethod::Classical,
bit_reversal: false, ..Default::default()
};
let mut qft = SciRS2QFT::new(3, config).expect("should create SciRS2 QFT");
let initial_state =
QFTUtils::create_test_state(3, "basis").expect("should create test state");
let mut state = initial_state;
qft.apply_qft(&mut state).expect("should apply QFT");
qft.apply_inverse_qft(&mut state)
.expect("should apply inverse QFT");
let has_nonzero = state.iter().any(|amp| amp.norm() > 1e-15);
assert!(
has_nonzero,
"State should have non-zero amplitudes after QFT operations"
);
}
#[test]
fn test_bit_reversal() {
let config = QFTConfig::default();
let qft = SciRS2QFT::new(3, config).expect("should create SciRS2 QFT");
assert_eq!(qft.bit_reverse(0b001, 3), 0b100);
assert_eq!(qft.bit_reverse(0b010, 3), 0b010);
assert_eq!(qft.bit_reverse(0b011, 3), 0b110);
}
#[test]
fn test_radix2_fft() {
let config = QFTConfig::default();
let qft = SciRS2QFT::new(2, config).expect("should create SciRS2 QFT");
let mut data = vec![
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
];
qft.radix2_fft(&mut data, false)
.expect("should apply radix2 FFT");
for amplitude in &data {
assert_abs_diff_eq!(amplitude.norm(), 1.0, epsilon = 1e-10);
}
}
#[test]
fn test_classical_dft() {
let signal = vec![
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
];
let result = QFTUtils::classical_dft(&signal).expect("should compute classical DFT");
for amplitude in &result {
assert_abs_diff_eq!(amplitude.norm(), 1.0, epsilon = 1e-10);
}
}
}