use std::f64::consts::PI;
use crate::error::{FFTError, FFTResult};
/// Computes `base^exp mod modulus` by square-and-multiply exponentiation.
///
/// Every product of two residues is formed in `u128` before reduction, so the result is exact
/// for all `u64` inputs. Any value is `0` modulo `1`, so `modulus == 1` yields `0` even when
/// `exp == 0`.
///
/// # Panics
///
/// Panics if `modulus` is `0`.
pub fn modular_exp(mut base: u64, mut exp: u64, modulus: u64) -> u64 {
    if modulus == 1 {
        return 0;
    }
    // Both operands are < modulus, so the reduced product always fits back into a u64.
    let mul_mod = |lhs: u64, rhs: u64| -> u64 {
        ((u128::from(lhs) * u128::from(rhs)) % u128::from(modulus)) as u64
    };
    base %= modulus;
    let mut acc: u64 = 1;
    while exp != 0 {
        if exp & 1 == 1 {
            acc = mul_mod(acc, base);
        }
        base = mul_mod(base, base);
        exp >>= 1;
    }
    acc
}
/// Greatest common divisor of `a` and `b` by Euclid's remainder algorithm.
///
/// `gcd(x, 0) == gcd(0, x) == x`; in particular `gcd(0, 0) == 0`.
pub fn gcd(mut a: u64, mut b: u64) -> u64 {
    while b > 0 {
        // (a, b) -> (b, a mod b)
        a %= b;
        std::mem::swap(&mut a, &mut b);
    }
    a
}
/// Returns the successive continued-fraction convergents `(p, q)` of `|x|`, ordered by
/// increasing denominator.
///
/// Collection stops at the first convergent whose denominator exceeds `max_denominator`. It also
/// stops once `max_candidates` convergents have been gathered, when the remaining fractional
/// part drops below `1e-12` (the expansion is treated as exact), or after 64 terms.
///
/// The output is unsigned, so the expansion is always computed for the absolute value of `x`.
/// A negative `x` therefore yields the convergents of `-x`. A NaN or infinite `x`, or
/// `max_candidates == 0`, yields an empty vector. Expansion also stops, instead of overflowing,
/// as soon as a partial quotient or a convergent term no longer fits in a `u64`.
pub fn continued_fraction_convergents(
    x: f64,
    max_denominator: u64,
    max_candidates: usize,
) -> Vec<(u64, u64)> {
    // 2^64: a partial quotient at or above this cannot be stored in a u64.
    const TWO_POW_64: f64 = 18_446_744_073_709_551_616.0;

    let mut convergents = Vec::new();
    if max_candidates == 0 || x.is_nan() {
        return convergents;
    }
    let mut xi = x.abs();
    // Seeds of the recurrence h_n = a_n*h_{n-1} + h_{n-2} (likewise for k):
    // h_{-2} = 0, h_{-1} = 1, k_{-2} = 1, k_{-1} = 0.
    let (mut h_prev2, mut h_prev1) = (0u64, 1u64);
    let (mut k_prev2, mut k_prev1) = (1u64, 0u64);
    for _ in 0..64 {
        if !xi.is_finite() || xi >= TWO_POW_64 {
            break;
        }
        let a_floor = xi.floor();
        let a = a_floor as u64;
        // Checked arithmetic: stop cleanly instead of overflowing (a panic in debug builds,
        // silent wrap-around producing bogus convergents in release builds).
        let step =
            |prev1: u64, prev2: u64| a.checked_mul(prev1).and_then(|v| v.checked_add(prev2));
        let (h, k) = match (step(h_prev1, h_prev2), step(k_prev1, k_prev2)) {
            (Some(h), Some(k)) => (h, k),
            _ => break,
        };
        if k == 0 || k > max_denominator {
            break;
        }
        convergents.push((h, k));
        if convergents.len() >= max_candidates {
            break;
        }
        let frac = xi - a_floor;
        // A vanishing remainder means the expansion has (numerically) terminated.
        if frac < 1e-12 {
            break;
        }
        xi = 1.0 / frac;
        h_prev2 = h_prev1;
        h_prev1 = h;
        k_prev2 = k_prev1;
        k_prev1 = k;
    }
    convergents
}
/// Forward radix-2 decimation-in-time FFT, computed in place.
///
/// Each element is a complex number stored as `(re, im)`. On success `data[k]` holds
/// `sum_j x_j * e^{-2*pi*i*j*k/n}` (negative exponent, no `1/n` scaling), where `x` is the input
/// and `n = data.len()`.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `data.len()` is zero or not a power of two.
fn fft_inplace(data: &mut [(f64, f64)]) -> FFTResult<()> {
    let n = data.len();
    if !n.is_power_of_two() {
        return Err(FFTError::ValueError(format!(
            "fft_inplace requires power-of-2 length, got {n}"
        )));
    }

    // Bit-reversal permutation: element `i` goes to the index whose log2(n) low bits are those
    // of `i` reversed. For n == 1 the loop body never runs, so the shift by 64 never happens.
    let shift = usize::BITS - n.trailing_zeros();
    for i in 1..n {
        let j = i.reverse_bits() >> shift;
        if i < j {
            data.swap(i, j);
        }
    }

    // Butterfly stages of width 2, 4, ..., n. Within a stage the twiddle factor is advanced by
    // complex multiplication with the stage's root of unity, so rounding error accumulates
    // along the stage.
    let mut width = 2;
    while width <= n {
        let half = width / 2;
        let theta = -2.0 * PI / width as f64;
        let (root_re, root_im) = (theta.cos(), theta.sin());
        for chunk in data.chunks_exact_mut(width) {
            let (lo, hi) = chunk.split_at_mut(half);
            let (mut tw_re, mut tw_im) = (1.0_f64, 0.0_f64);
            for (top, bottom) in lo.iter_mut().zip(hi.iter_mut()) {
                let (u_re, u_im) = *top;
                let (b_re, b_im) = *bottom;
                let t_re = b_re * tw_re - b_im * tw_im;
                let t_im = b_re * tw_im + b_im * tw_re;
                *top = (u_re + t_re, u_im + t_im);
                *bottom = (u_re - t_re, u_im - t_im);
                let next_re = tw_re * root_re - tw_im * root_im;
                tw_im = tw_re * root_im + tw_im * root_re;
                tw_re = next_re;
            }
        }
        width *= 2;
    }
    Ok(())
}
pub fn find_period_qft(
a: u64,
n: u64,
n_qubits: usize,
max_cands: usize,
) -> FFTResult<Option<u64>> {
if a <= 1 || a >= n {
return Err(FFTError::ValueError(format!(
"require 1 < a < n, got a={a}, n={n}"
)));
}
if n_qubits == 0 || n_qubits > 24 {
return Err(FFTError::ValueError(format!(
"n_qubits must be in 1..=24, got {n_qubits}"
)));
}
let m: usize = 1 << n_qubits;
let f_vals: Vec<u64> = (0..m as u64).map(|x| modular_exp(a, x, n)).collect();
let mut distinct: Vec<u64> = f_vals.clone();
distinct.sort_unstable();
distinct.dedup();
let mut prob: Vec<f64> = vec![0.0; m];
for &v in &distinct {
let mut buf: Vec<(f64, f64)> = f_vals
.iter()
.map(|&fv| if fv == v { (1.0, 0.0) } else { (0.0, 0.0) })
.collect();
fft_inplace(&mut buf)?;
for (s, amp) in buf.iter().enumerate() {
prob[s] += amp.0 * amp.0 + amp.1 * amp.1;
}
}
let total: f64 = prob.iter().sum();
if total < 1e-15 {
return Ok(None);
}
let inv_total = 1.0 / total;
for p in prob.iter_mut() {
*p *= inv_total;
}
let mut indexed: Vec<(f64, usize)> = prob
.iter()
.copied()
.enumerate()
.map(|(i, p)| (p, i))
.collect();
indexed.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let max_denominator = n; for &(_prob, s) in indexed.iter().take(max_cands * 4) {
if s == 0 {
continue;
}
let phase = s as f64 / m as f64;
let convergents = continued_fraction_convergents(phase, max_denominator, max_cands);
for &(_p, q) in &convergents {
if q < 2 {
continue;
}
if modular_exp(a, q, n) == 1 {
return Ok(Some(q));
}
for k in 2..=4u64 {
let rk = q * k;
if rk < n && modular_exp(a, rk, n) == 1 {
return Ok(Some(rk));
}
}
}
}
Ok(None)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_modular_exp_basic() {
        assert_eq!(modular_exp(2, 10, 1000), 1024 % 1000);
        assert_eq!(modular_exp(7, 0, 13), 1);
        assert_eq!(modular_exp(5, 5, 1), 0);
    }

    #[test]
    fn test_modular_exp_fermat() {
        let p = 17u64;
        for a in 1..p {
            assert_eq!(modular_exp(a, p - 1, p), 1, "a={a}");
        }
    }

    /// Compares exact values against an independent step-by-step reference
    /// (a range check alone would accept almost any wrong answer).
    #[test]
    fn test_modular_exp_large() {
        const P: u64 = 1_000_000_007;
        let mut expected = 1u64;
        for e in 0..200u64 {
            assert_eq!(modular_exp(3, e, P), expected, "3^{e} mod {P}");
            expected = expected * 3 % P;
        }
    }

    /// Moduli near u64::MAX force products that do not fit in 64 bits.
    #[test]
    fn test_modular_exp_wide_products() {
        assert_eq!(modular_exp(u64::MAX - 1, 2, u64::MAX), 1);
        assert_eq!(modular_exp(u64::MAX - 1, 3, u64::MAX), u64::MAX - 1);
        // 2^61 ≡ 1 (mod 2^61 - 1), hence 2^64 ≡ 2^3.
        assert_eq!(modular_exp(2, 64, (1u64 << 61) - 1), 8);
    }

    #[test]
    fn test_gcd_basic() {
        assert_eq!(gcd(12, 8), 4);
        assert_eq!(gcd(7, 13), 1);
        assert_eq!(gcd(0, 5), 5);
        assert_eq!(gcd(5, 0), 5);
        assert_eq!(gcd(0, 0), 0);
        assert_eq!(gcd(100, 75), 25);
    }

    #[test]
    fn test_continued_fraction_convergents_half() {
        let convs = continued_fraction_convergents(0.5, 100, 10);
        assert!(convs.iter().any(|&(p, q)| p == 1 && q == 2));
    }

    #[test]
    fn test_continued_fraction_convergents_third() {
        let convs = continued_fraction_convergents(1.0 / 3.0, 100, 10);
        assert!(
            convs.iter().any(|&(p, q)| p == 1 && q == 3),
            "convergents: {convs:?}"
        );
    }

    #[test]
    fn test_continued_fraction_golden_ratio() {
        let phi = (5.0_f64.sqrt() - 1.0) / 2.0;
        let convs = continued_fraction_convergents(phi, 1000, 15);
        assert!(
            convs.iter().any(|&(p, q)| p == 3 && q == 5),
            "convergents: {convs:?}"
        );
        assert!(
            convs.iter().any(|&(p, q)| p == 5 && q == 8),
            "convergents: {convs:?}"
        );
    }

    #[test]
    fn test_continued_fraction_convergents_pi_limits() {
        let pi = std::f64::consts::PI;
        // Denominator cap: 355/113 exceeds 110.
        assert_eq!(
            continued_fraction_convergents(pi, 110, 10),
            vec![(3, 1), (22, 7), (333, 106)]
        );
        // Candidate-count cap.
        assert_eq!(
            continued_fraction_convergents(pi, 1000, 2),
            vec![(3, 1), (22, 7)]
        );
    }

    #[test]
    fn test_continued_fraction_degenerate_inputs() {
        assert!(continued_fraction_convergents(f64::NAN, 100, 10).is_empty());
        assert!(continued_fraction_convergents(f64::INFINITY, 100, 10).is_empty());
        assert!(continued_fraction_convergents(0.5, 100, 0).is_empty());
        assert!(continued_fraction_convergents(0.5, 0, 10).is_empty());
    }

    #[test]
    fn test_fft_inplace_length_check() {
        let mut v = vec![(1.0, 0.0), (0.0, 0.0), (0.0, 0.0)];
        assert!(fft_inplace(&mut v).is_err());
        let mut empty: Vec<(f64, f64)> = Vec::new();
        assert!(fft_inplace(&mut empty).is_err());
    }

    #[test]
    fn test_fft_inplace_single_element_is_identity() {
        let mut one = vec![(2.5, -1.0)];
        fft_inplace(&mut one).expect("fft ok");
        assert_eq!(one, vec![(2.5, -1.0)]);
    }

    #[test]
    fn test_fft_inplace_trivial() {
        let n = 8;
        let mut v: Vec<(f64, f64)> = vec![(1.0, 0.0); n];
        fft_inplace(&mut v).expect("fft ok");
        let (re0, im0) = v[0];
        assert!((re0 - n as f64).abs() < 1e-10, "DC bin={re0}");
        assert!(im0.abs() < 1e-10);
        for (i, &(re, im)) in v.iter().enumerate().skip(1) {
            let mag = (re * re + im * im).sqrt();
            assert!(mag < 1e-9, "bin {i} mag={mag}");
        }
    }

    #[test]
    fn test_fft_inplace_delta_is_flat() {
        let mut v = vec![(0.0, 0.0); 16];
        v[0] = (1.0, 0.0);
        fft_inplace(&mut v).expect("fft ok");
        for (i, &(re, im)) in v.iter().enumerate() {
            assert!((re - 1.0).abs() < 1e-12 && im.abs() < 1e-12, "bin {i}: ({re}, {im})");
        }
    }

    #[test]
    fn test_fft_inplace_matches_naive_dft() {
        let n = 8;
        let input: Vec<(f64, f64)> = (0..n)
            .map(|k| (k as f64, (k * k % 5) as f64))
            .collect();
        let mut fast = input.clone();
        fft_inplace(&mut fast).expect("fft ok");
        for (j, &(re, im)) in fast.iter().enumerate() {
            let (mut want_re, mut want_im) = (0.0_f64, 0.0_f64);
            for (k, &(xr, xim)) in input.iter().enumerate() {
                let ang = -2.0 * std::f64::consts::PI * (j * k) as f64 / n as f64;
                let (c, s) = (ang.cos(), ang.sin());
                want_re += xr * c - xim * s;
                want_im += xr * s + xim * c;
            }
            assert!(
                (re - want_re).abs() < 1e-9 && (im - want_im).abs() < 1e-9,
                "bin {j}: got ({re}, {im}), expected ({want_re}, {want_im})"
            );
        }
    }

    #[test]
    fn test_find_period_qft_rejects_invalid_arguments() {
        assert!(find_period_qft(1, 15, 8, 4).is_err());
        assert!(find_period_qft(15, 15, 8, 4).is_err());
        assert!(find_period_qft(7, 15, 0, 4).is_err());
        assert!(find_period_qft(7, 15, 25, 4).is_err());
    }

    #[test]
    fn test_find_period_qft_known_periods() {
        // 7^4 = 2401 ≡ 1 (mod 15); the period 4 divides m = 256.
        assert_eq!(find_period_qft(7, 15, 8, 4).expect("valid args"), Some(4));
        // 2^6 = 64 ≡ 1 (mod 21); the period 6 does not divide m = 256.
        assert_eq!(find_period_qft(2, 21, 8, 4).expect("valid args"), Some(6));
    }

    #[test]
    fn test_find_period_qft_non_coprime_base() {
        // gcd(6, 15) = 3, so no positive power of 6 is ≡ 1 (mod 15).
        assert_eq!(find_period_qft(6, 15, 4, 4).expect("valid args"), None);
    }
}