use scirs2_core::ndarray::Array2;
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
use crate::error::{FFTError, FFTResult};
/// Dense statevector over `n` qubits: amplitude `i` is the coefficient of the
/// computational basis state `|i⟩`, so a valid state has `2^n` entries.
///
/// Throughout this module qubit `q` corresponds to bit `n_qubits - 1 - q` of the
/// amplitude index, i.e. qubit 0 is the most significant bit.
pub type Statevector = Vec<Complex64>;
/// Configuration for the [`qft`] / [`iqft`] statevector simulators.
#[derive(Debug, Clone, PartialEq)]
pub struct QftConfig {
    /// Number of qubits; the statevector must contain exactly `2^n_qubits` amplitudes.
    pub n_qubits: usize,
    /// Whether to include the qubit-reversal swaps (qubit `i` ↔ qubit `n_qubits - 1 - i`).
    /// `qft` applies them after the rotation circuit, `iqft` before it.
    pub apply_swap: bool,
    /// When `true`, controlled-phase rotations whose angle `2π / 2^k` is below
    /// `approx_threshold` are skipped (approximate QFT).
    pub approximate: bool,
    /// Angle threshold in radians used only when `approximate` is `true`.
    pub approx_threshold: f64,
}

impl Default for QftConfig {
    /// Three qubits, swaps enabled, exact transform (threshold `1e-10` unused).
    fn default() -> Self {
        Self {
            n_qubits: 3,
            apply_swap: true,
            approximate: false,
            approx_threshold: 1e-10,
        }
    }
}
/// Applies a Hadamard gate to `qubit` in place.
///
/// Qubit `q` is bit `n_qubits - 1 - q` of the amplitude index. Amplitudes are
/// processed in pairs `(lo, lo + stride)` whose indices differ only in that bit,
/// and each pair `(a, b)` becomes `((a + b)/√2, (a - b)/√2)`.
fn apply_hadamard(state: &mut Statevector, qubit: usize, n_qubits: usize) {
    let len = state.len();
    let stride = 1_usize << (n_qubits - 1 - qubit);
    let scale = 1.0 / 2.0_f64.sqrt();
    for base in (0..len).step_by(2 * stride) {
        for lo in base..base + stride {
            let hi = lo + stride;
            let (a, b) = (state[lo], state[hi]);
            state[lo] = Complex64::new(scale * (a.re + b.re), scale * (a.im + b.im));
            state[hi] = Complex64::new(scale * (a.re - b.re), scale * (a.im - b.im));
        }
    }
}
/// Applies a controlled-`R_k` gate in place: every amplitude whose index has both
/// the `control` and `target` qubit bits set is multiplied by `e^{2πi / 2^k}`.
///
/// The gate only depends on both bits being set, so it is symmetric in
/// `control` and `target`.
fn apply_controlled_phase(
    state: &mut Statevector,
    control: usize,
    target: usize,
    k: usize,
    n_qubits: usize,
) {
    let angle = 2.0 * PI / (1_u64 << k) as f64;
    let rotation = Complex64::new(angle.cos(), angle.sin());
    let mask = (1_usize << (n_qubits - 1 - control)) | (1_usize << (n_qubits - 1 - target));
    state
        .iter_mut()
        .enumerate()
        .filter(|(idx, _)| (*idx & mask) == mask)
        .for_each(|(_, amp)| *amp *= rotation);
}
/// Swaps qubits `q1` and `q2` in place.
///
/// Each index with the `q1` bit set and the `q2` bit clear is exchanged with the
/// index that has those two bits reversed. The partner index fails the filter,
/// so every pair is swapped exactly once.
fn apply_swap(state: &mut Statevector, q1: usize, q2: usize, n_qubits: usize) {
    let mask1 = 1_usize << (n_qubits - 1 - q1);
    let mask2 = 1_usize << (n_qubits - 1 - q2);
    for idx in (0..state.len()).filter(|&i| (i & mask1) != 0 && (i & mask2) == 0) {
        state.swap(idx, (idx & !mask1) | mask2);
    }
}
/// Applies the quantum Fourier transform to `state` by simulating the
/// Hadamard + controlled-phase circuit on a dense statevector.
///
/// Qubit `q` is bit `config.n_qubits - 1 - q` of the amplitude index. With
/// `config.apply_swap == true` and no approximation, the basis state `|j⟩` maps to
/// `1/√N · Σ_k e^{2πi·jk/N} |k⟩` with `N = 2^n_qubits`.
///
/// # Errors
///
/// Returns [`FFTError::DimensionError`] if `2^config.n_qubits` does not fit in
/// `usize`, or if `state.len() != 2^config.n_qubits`.
pub fn qft(state: &Statevector, config: &QftConfig) -> FFTResult<Statevector> {
    let n = config.n_qubits;
    // `1_usize << n` overflows for n >= usize::BITS: debug builds panic, and release
    // builds mask the shift amount. The masked value lets a wrong-sized state pass
    // validation and then index out of bounds, so reject such sizes up front.
    if n >= usize::BITS as usize {
        return Err(FFTError::DimensionError(format!(
            "n_qubits = {n} is too large: 2^{n} does not fit in usize"
        )));
    }
    let expected_len = 1_usize << n;
    if state.len() != expected_len {
        return Err(FFTError::DimensionError(format!(
            "Statevector length {} does not match 2^{} = {}",
            state.len(),
            n,
            expected_len
        )));
    }
    if n == 0 {
        // A single amplitude: the 1×1 transform is the identity.
        return Ok(state.clone());
    }
    let mut out = state.clone();
    for t in 0..n {
        apply_hadamard(&mut out, t, n);
        for j in (t + 1)..n {
            // Controlled-R_k between qubits j and t, rotating by 2π / 2^k.
            // k <= n < 64 here, so the u64 shift cannot overflow.
            let k = j - t + 1;
            let theta = 2.0 * PI / (1_u64 << k) as f64;
            if config.approximate && theta.abs() < config.approx_threshold {
                continue;
            }
            apply_controlled_phase(&mut out, j, t, k, n);
        }
    }
    if config.apply_swap {
        // Reverse the qubit order so the output is in natural (not bit-reversed) order.
        for i in 0..(n / 2) {
            apply_swap(&mut out, i, n - 1 - i, n);
        }
    }
    Ok(out)
}
/// Applies the inverse quantum Fourier transform to `state`.
///
/// It runs the [`qft`] circuit in reverse: the optional swaps first, then for each
/// qubit (last to first) the conjugated controlled phases followed by a Hadamard.
/// The same approximation rule is used, so for a given `config`
/// `iqft(qft(s)) == s` up to rounding.
///
/// # Errors
///
/// Returns [`FFTError::DimensionError`] if `2^config.n_qubits` does not fit in
/// `usize`, or if `state.len() != 2^config.n_qubits`.
pub fn iqft(state: &Statevector, config: &QftConfig) -> FFTResult<Statevector> {
    let n = config.n_qubits;
    // `1_usize << n` overflows for n >= usize::BITS: debug builds panic, and release
    // builds mask the shift amount. The masked value lets a wrong-sized state pass
    // validation and then index out of bounds, so reject such sizes up front.
    if n >= usize::BITS as usize {
        return Err(FFTError::DimensionError(format!(
            "n_qubits = {n} is too large: 2^{n} does not fit in usize"
        )));
    }
    let expected_len = 1_usize << n;
    if state.len() != expected_len {
        return Err(FFTError::DimensionError(format!(
            "Statevector length {} does not match 2^{} = {}",
            state.len(),
            n,
            expected_len
        )));
    }
    if n == 0 {
        // A single amplitude: the 1×1 transform is the identity.
        return Ok(state.clone());
    }
    let mut out = state.clone();
    if config.apply_swap {
        // Undo the final qubit reversal performed by `qft`.
        for i in 0..(n / 2) {
            apply_swap(&mut out, i, n - 1 - i, n);
        }
    }
    for t in (0..n).rev() {
        for j in (t + 1..n).rev() {
            // Must skip exactly the rotations `qft` skips, so use the same angle test.
            // k <= n < 64 here, so the u64 shift cannot overflow.
            let k = j - t + 1;
            let theta = 2.0 * PI / (1_u64 << k) as f64;
            if config.approximate && theta.abs() < config.approx_threshold {
                continue;
            }
            apply_controlled_phase_neg(&mut out, j, t, k, n);
        }
        apply_hadamard(&mut out, t, n);
    }
    Ok(out)
}
/// Applies the adjoint of a controlled-`R_k` gate in place.
///
/// Every amplitude whose index has both the `control` and `target` qubit bits set
/// is multiplied by `e^{-2πi / 2^k}`. Used by [`iqft`].
fn apply_controlled_phase_neg(
    state: &mut Statevector,
    control: usize,
    target: usize,
    k: usize,
    n_qubits: usize,
) {
    let both_bits = (1_usize << (n_qubits - 1 - control)) | (1_usize << (n_qubits - 1 - target));
    let angle = -2.0 * PI / (1_u64 << k) as f64;
    let conj_rotation = Complex64::new(angle.cos(), angle.sin());
    for (idx, amp) in state.iter_mut().enumerate() {
        if (idx & both_bits) == both_bits {
            *amp *= conj_rotation;
        }
    }
}
/// Builds the dense `N × N` QFT unitary with `N = 2^n_qubits`:
/// `F[row, col] = e^{+2πi·row·col/N} / √N`.
///
/// This is the same transform [`qft`] applies to basis states when swaps are
/// enabled. It needs `N²` storage.
///
/// # Panics
///
/// Panics if `n_qubits >= usize::BITS`, because `2^n_qubits` would not fit in `usize`.
pub fn qft_matrix(n_qubits: usize) -> Array2<Complex64> {
    // Without this guard, release builds would mask the shift amount and return a
    // silently wrong (tiny) matrix.
    assert!(
        n_qubits < usize::BITS as usize,
        "qft_matrix: 2^{n_qubits} does not fit in usize"
    );
    let n = 1_usize << n_qubits;
    let inv_sqrt_n = 1.0 / (n as f64).sqrt();
    // Each entry depends only on row·col mod N. Computing the N distinct values once
    // costs N trig evaluations instead of N². It also keeps every angle in [0, 2π),
    // which avoids the precision loss of cos/sin on large arguments.
    let twiddles: Vec<Complex64> = (0..n)
        .map(|m| {
            let angle = 2.0 * PI * m as f64 / n as f64;
            Complex64::new(inv_sqrt_n * angle.cos(), inv_sqrt_n * angle.sin())
        })
        .collect();
    let mask = n - 1;
    // N is a power of two dividing 2^usize::BITS, so masking the wrapped product
    // still yields row·col mod N exactly.
    Array2::from_shape_fn((n, n), |(row, col)| twiddles[row.wrapping_mul(col) & mask])
}
/// Returns the computational basis state `|j⟩` on `n_qubits` qubits.
///
/// The result has `2^n_qubits` zero amplitudes except for a `1` at index `j`.
/// If `j` is out of range, the all-zero vector is returned.
pub fn basis_state(j: usize, n_qubits: usize) -> Statevector {
    let dim = 1_usize << n_qubits;
    (0..dim)
        .map(|idx| {
            if idx == j {
                Complex64::new(1.0, 0.0)
            } else {
                Complex64::new(0.0, 0.0)
            }
        })
        .collect()
}
/// Returns the measurement probability `|amplitude|²` for every basis state,
/// in index order.
pub fn measure_probs(state: &Statevector) -> Vec<f64> {
    let mut probs = Vec::with_capacity(state.len());
    for amp in state {
        probs.push(amp.re * amp.re + amp.im * amp.im);
    }
    probs
}
/// Computes the unnormalized forward DFT `X[k] = Σ_j x[j]·e^{-2πi·jk/N}` of a
/// real signal by running it through the [`qft`] simulator.
///
/// An empty signal yields an empty spectrum.
///
/// # Errors
///
/// Returns [`FFTError::DimensionError`] if the signal length is not a power of
/// two. Any error from [`qft`] is propagated.
pub fn qft_dft(signal: &[f64]) -> FFTResult<Vec<Complex64>> {
    let n = signal.len();
    if n == 0 {
        return Ok(Vec::new());
    }
    if !n.is_power_of_two() {
        return Err(FFTError::DimensionError(format!(
            "Signal length {n} is not a power of two; QFT requires power-of-two sizes"
        )));
    }
    let config = QftConfig {
        n_qubits: n.trailing_zeros() as usize,
        apply_swap: true,
        approximate: false,
        approx_threshold: 1e-10,
    };
    let amplitudes: Statevector = signal.iter().map(|&x| Complex64::new(x, 0.0)).collect();
    let spectrum = qft(&amplitudes, &config)?;
    // The QFT uses the +i sign convention and a 1/√N factor. Because the input is
    // real, conjugating and rescaling by √N gives the standard -i, unnormalized DFT.
    let scale = (n as f64).sqrt();
    Ok(spectrum
        .into_iter()
        .map(|z| Complex64::new(z.re * scale, -z.im * scale))
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f64::consts::PI;

    /// `|0…0⟩` must map to the uniform superposition for every qubit count.
    #[test]
    fn test_qft_basis_state_0_uniform_superposition() {
        for n in 1..=4 {
            let cfg = QftConfig {
                n_qubits: n,
                ..Default::default()
            };
            let state = basis_state(0, n);
            let out = qft(&state, &cfg).expect("qft failed");
            let probs = measure_probs(&out);
            let expected = 1.0 / (1_usize << n) as f64;
            for p in &probs {
                assert!(
                    (*p - expected).abs() < 1e-12,
                    "n={n}: probability should be {expected} but got {p}"
                );
            }
        }
    }

    /// The reference matrix must match the `e^{+2πi jk/N}/√N` formula entry-wise.
    #[test]
    fn test_qft_matrix_equals_dft_matrix() {
        let n = 3;
        let mat = qft_matrix(n);
        let dim = 1_usize << n;
        for k in 0..dim {
            for j in 0..dim {
                let angle = 2.0 * PI * (j * k) as f64 / dim as f64;
                let expected_re = angle.cos() / (dim as f64).sqrt();
                let expected_im = angle.sin() / (dim as f64).sqrt();
                assert_relative_eq!(mat[[k, j]].re, expected_re, epsilon = 1e-12);
                assert_relative_eq!(mat[[k, j]].im, expected_im, epsilon = 1e-12);
            }
        }
    }

    /// The simulated circuit must equal multiplication by `qft_matrix`.
    #[test]
    fn test_qft_circuit_matches_qft_matrix() {
        let n = 3;
        let dim = 1_usize << n;
        let cfg = QftConfig {
            n_qubits: n,
            ..Default::default()
        };
        let state: Statevector = (0..dim)
            .map(|j| Complex64::new(0.1 * j as f64 + 0.2, 0.05 * (j * j) as f64 - 0.3))
            .collect();
        let out = qft(&state, &cfg).expect("qft failed");
        let mat = qft_matrix(n);
        for row in 0..dim {
            let mut expected = Complex64::new(0.0, 0.0);
            for col in 0..dim {
                expected += mat[[row, col]] * state[col];
            }
            assert!(
                (out[row].re - expected.re).abs() < 1e-10,
                "row={row}: real mismatch circuit vs matrix"
            );
            assert!(
                (out[row].im - expected.im).abs() < 1e-10,
                "row={row}: imag mismatch circuit vs matrix"
            );
        }
    }

    #[test]
    fn test_qft_iqft_roundtrip() {
        let n = 3;
        let cfg = QftConfig {
            n_qubits: n,
            ..Default::default()
        };
        let dim = 1_usize << n;
        let norm_factor = 1.0 / (dim as f64).sqrt();
        let state: Statevector = (0..dim)
            .map(|j| {
                let angle = 2.0 * PI * j as f64 / dim as f64;
                Complex64::new(norm_factor * angle.cos(), norm_factor * angle.sin())
            })
            .collect();
        let transformed = qft(&state, &cfg).expect("qft failed");
        let recovered = iqft(&transformed, &cfg).expect("iqft failed");
        for (a, b) in state.iter().zip(recovered.iter()) {
            assert!(
                (a.re - b.re).abs() < 1e-10,
                "real part mismatch in IQFT∘QFT roundtrip"
            );
            assert!(
                (a.im - b.im).abs() < 1e-10,
                "imag part mismatch in IQFT∘QFT roundtrip"
            );
        }
    }

    /// `iqft` must also invert `qft` when the qubit-reversal swaps are disabled.
    #[test]
    fn test_qft_iqft_roundtrip_without_swap() {
        let n = 3;
        let cfg = QftConfig {
            n_qubits: n,
            apply_swap: false,
            ..Default::default()
        };
        let dim = 1_usize << n;
        let state: Statevector = (0..dim)
            .map(|j| Complex64::new(j as f64 * 0.1, 1.0 - j as f64 * 0.05))
            .collect();
        let transformed = qft(&state, &cfg).expect("qft failed");
        let recovered = iqft(&transformed, &cfg).expect("iqft failed");
        for (a, b) in state.iter().zip(recovered.iter()) {
            assert!((a.re - b.re).abs() < 1e-10, "real mismatch without swap");
            assert!((a.im - b.im).abs() < 1e-10, "imag mismatch without swap");
        }
    }

    #[test]
    fn test_qft_computational_basis_states() {
        let n = 2;
        let dim = 1_usize << n;
        let cfg = QftConfig {
            n_qubits: n,
            ..Default::default()
        };
        for j in 0..dim {
            let state = basis_state(j, n);
            let out = qft(&state, &cfg).expect("qft failed");
            let inv_sqrt_n = 1.0 / (dim as f64).sqrt();
            for k in 0..dim {
                let angle = 2.0 * PI * (j * k) as f64 / dim as f64;
                let expected = Complex64::new(inv_sqrt_n * angle.cos(), inv_sqrt_n * angle.sin());
                assert!(
                    (out[k].re - expected.re).abs() < 1e-12,
                    "j={j} k={k}: real mismatch"
                );
                assert!(
                    (out[k].im - expected.im).abs() < 1e-12,
                    "j={j} k={k}: imag mismatch"
                );
            }
        }
    }

    /// For one qubit the QFT is exactly a Hadamard gate, with purely real outputs.
    #[test]
    fn test_qft_1qubit_is_hadamard() {
        let cfg = QftConfig {
            n_qubits: 1,
            ..Default::default()
        };
        let s0 = basis_state(0, 1);
        let out0 = qft(&s0, &cfg).expect("qft failed");
        assert_relative_eq!(out0[0].re, 1.0 / 2.0_f64.sqrt(), epsilon = 1e-12);
        assert_relative_eq!(out0[1].re, 1.0 / 2.0_f64.sqrt(), epsilon = 1e-12);
        let s1 = basis_state(1, 1);
        let out1 = qft(&s1, &cfg).expect("qft failed");
        assert_relative_eq!(out1[0].re, 1.0 / 2.0_f64.sqrt(), epsilon = 1e-12);
        assert_relative_eq!(out1[1].re, -1.0 / 2.0_f64.sqrt(), epsilon = 1e-12);
        for amp in out0.iter().chain(out1.iter()) {
            assert_relative_eq!(amp.im, 0.0, epsilon = 1e-12);
        }
    }

    /// With zero qubits there is a single amplitude and both transforms are the identity.
    #[test]
    fn test_qft_zero_qubits_is_identity() {
        let cfg = QftConfig {
            n_qubits: 0,
            ..Default::default()
        };
        let state: Statevector = vec![Complex64::new(0.6, -0.8)];
        let out = qft(&state, &cfg).expect("qft failed");
        assert_eq!(out.len(), 1);
        assert_relative_eq!(out[0].re, 0.6, epsilon = 1e-15);
        assert_relative_eq!(out[0].im, -0.8, epsilon = 1e-15);
        let back = iqft(&state, &cfg).expect("iqft failed");
        assert_eq!(back.len(), 1);
        assert_relative_eq!(back[0].re, 0.6, epsilon = 1e-15);
        assert_relative_eq!(back[0].im, -0.8, epsilon = 1e-15);
    }

    /// The QFT is unitary, so the total probability must be preserved.
    #[test]
    fn test_qft_preserves_norm() {
        let n = 4;
        let cfg = QftConfig {
            n_qubits: n,
            ..Default::default()
        };
        let dim = 1_usize << n;
        let state: Statevector = (0..dim)
            .map(|j| Complex64::new((j as f64).sin(), (j as f64 * 0.7).cos()))
            .collect();
        let norm_in: f64 = measure_probs(&state).iter().sum();
        let out = qft(&state, &cfg).expect("qft failed");
        let norm_out: f64 = measure_probs(&out).iter().sum();
        assert_relative_eq!(norm_in, norm_out, epsilon = 1e-10);
    }

    #[test]
    fn test_qft_vs_classical_dft() {
        let n = 8;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                (2.0 * PI * i as f64 / n as f64).cos()
                    + 0.5 * (4.0 * PI * i as f64 / n as f64).sin()
            })
            .collect();
        let qft_result = qft_dft(&signal).expect("qft_dft failed");
        let mut dft_ref = vec![Complex64::new(0.0, 0.0); n];
        for k in 0..n {
            for j in 0..n {
                let angle = -2.0 * PI * (j * k) as f64 / n as f64;
                dft_ref[k] += Complex64::new(signal[j] * angle.cos(), signal[j] * angle.sin());
            }
        }
        for k in 0..n {
            assert!(
                (qft_result[k].re - dft_ref[k].re).abs() < 1e-8,
                "k={k}: real mismatch QFT vs DFT"
            );
            assert!(
                (qft_result[k].im - dft_ref[k].im).abs() < 1e-8,
                "k={k}: imag mismatch QFT vs DFT"
            );
        }
    }

    #[test]
    fn test_qft_dft_empty_and_non_power_of_two() {
        let empty: [f64; 0] = [];
        let spectrum = qft_dft(&empty).expect("empty signal should succeed");
        assert!(spectrum.is_empty());
        assert!(qft_dft(&[1.0, 2.0, 3.0]).is_err());
    }

    /// The threshold 1e-15 is below every rotation angle used for n = 4 (the
    /// smallest is 2π/16), so no gate is skipped and the outputs must match.
    #[test]
    fn test_qft_approximate_close_to_exact() {
        let n = 4;
        let exact_cfg = QftConfig {
            n_qubits: n,
            ..Default::default()
        };
        let approx_cfg = QftConfig {
            n_qubits: n,
            approximate: true,
            approx_threshold: 1e-15,
            ..Default::default()
        };
        let state = basis_state(5, n);
        let exact = qft(&state, &exact_cfg).expect("exact qft failed");
        let approx = qft(&state, &approx_cfg).expect("approx qft failed");
        for (a, b) in exact.iter().zip(approx.iter()) {
            assert_relative_eq!(a.re, b.re, epsilon = 1e-12);
            assert_relative_eq!(a.im, b.im, epsilon = 1e-12);
        }
    }

    /// With threshold 0.5 the R_4 rotation (2π/16 ≈ 0.39 rad) is skipped, so the
    /// output must differ from the exact QFT while remaining normalized.
    #[test]
    fn test_qft_approximate_drops_small_rotations() {
        let n = 4;
        let exact_cfg = QftConfig {
            n_qubits: n,
            ..Default::default()
        };
        let approx_cfg = QftConfig {
            n_qubits: n,
            approximate: true,
            approx_threshold: 0.5,
            ..Default::default()
        };
        let state = basis_state(5, n);
        let exact = qft(&state, &exact_cfg).expect("exact qft failed");
        let approx = qft(&state, &approx_cfg).expect("approx qft failed");
        let max_diff = exact
            .iter()
            .zip(approx.iter())
            .map(|(a, b)| ((a.re - b.re).powi(2) + (a.im - b.im).powi(2)).sqrt())
            .fold(0.0_f64, f64::max);
        assert!(
            max_diff > 1e-3,
            "approximation should have dropped a gate, max diff {max_diff}"
        );
        let norm: f64 = measure_probs(&approx).iter().sum();
        assert_relative_eq!(norm, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_qft_error_wrong_state_size() {
        let cfg = QftConfig {
            n_qubits: 3,
            ..Default::default()
        };
        let bad_state = vec![Complex64::new(1.0, 0.0); 5];
        let result = qft(&bad_state, &cfg);
        assert!(result.is_err());
    }

    #[test]
    fn test_iqft_error_wrong_state_size() {
        let cfg = QftConfig {
            n_qubits: 2,
            ..Default::default()
        };
        let bad_state = vec![Complex64::new(1.0, 0.0); 3];
        assert!(iqft(&bad_state, &cfg).is_err());
    }
}