pub mod period_finding;
pub use period_finding::{continued_fraction_convergents, find_period_qft, gcd, modular_exp};
use crate::error::{FFTError, FFTResult};
/// Tuning parameters for [`ShorSimulator`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShorConfig {
    /// Limit on period candidates, forwarded to `find_period_qft`.
    pub max_period_candidates: usize,
    /// Qubit budget, forwarded to `find_period_qft`.
    pub max_qft_qubits: usize,
    /// Maximum number of bases `a` tried by `ShorSimulator::factor`.
    pub max_base_attempts: usize,
}

impl Default for ShorConfig {
    /// 16 period candidates, at most 12 QFT qubits, and 20 base attempts.
    fn default() -> Self {
        Self {
            max_period_candidates: 16,
            max_qft_qubits: 12,
            max_base_attempts: 20,
        }
    }
}
/// Outcome of `ShorSimulator::factor`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShorResult {
    /// `Some((p, q))` with `p * q == n` when a split was found. For `n == 2` this is
    /// `(2, 1)`; for every other input both factors are greater than one.
    pub factors: Option<(u64, u64)>,
    /// The period that produced the split, when period finding was used.
    pub period: Option<u64>,
    /// Number of bases examined before returning (0 for classically handled inputs).
    pub iterations: usize,
}
/// Classical simulator of Shor's integer-factorisation algorithm.
///
/// Period finding is delegated to [`find_period_qft`], with a brute-force classical
/// fallback for small moduli. Everything else is classical processing: trivial-case
/// filtering, base selection, and extracting factors from a period.
#[derive(Debug, Clone)]
pub struct ShorSimulator {
    /// Tuning parameters used by `find_period` and `factor`.
    pub config: ShorConfig,
}

impl ShorSimulator {
    /// Largest modulus for which the brute-force period search is attempted; that
    /// search performs up to `n` modular multiplications.
    const DIRECT_PERIOD_LIMIT: u64 = 1 << 20;

    /// Creates a simulator with the given configuration.
    pub fn new(config: ShorConfig) -> Self {
        Self { config }
    }

    /// Creates a simulator using [`ShorConfig::default`].
    pub fn default_new() -> Self {
        Self::new(ShorConfig::default())
    }

    /// Deterministic primality test by trial division with divisors of the form
    /// `6k ± 1`. Takes `O(√n)` time.
    fn is_prime(n: u64) -> bool {
        if n < 2 {
            return false;
        }
        if n < 4 {
            return true;
        }
        if n % 2 == 0 || n % 3 == 0 {
            return false;
        }
        let mut i = 5u64;
        // `i <= n / i` is equivalent to `i * i <= n` but cannot overflow when `n` is
        // close to `u64::MAX`.
        while i <= n / i {
            if n % i == 0 || n % (i + 2) == 0 {
                return false;
            }
            i += 6;
        }
        true
    }

    /// Returns `Some((b, k))` with `b >= 2`, `k >= 2` and `b^k == n`, choosing the
    /// smallest such `k`; returns `None` if `n` is not a perfect power.
    ///
    /// A floating-point `k`-th root gives a first guess. The integers within ±2 of
    /// it are then checked exactly, which absorbs rounding error.
    fn is_perfect_power(n: u64) -> Option<(u64, u32)> {
        if n < 4 {
            return None;
        }
        let target = u128::from(n);
        // A base >= 2 needs k <= log2(n). For n near u64::MAX the f64 conversion
        // rounds up to 2^64, so k can reach 64. `checked_pow` guards the resulting
        // candidate powers (e.g. 4^64) against overflowing u128.
        let max_k = (n as f64).log2().floor() as u32;
        (2..=max_k).find_map(|k| {
            let guess = (n as f64).powf(1.0 / f64::from(k)).round() as u64;
            (guess.saturating_sub(2)..=guess.saturating_add(2))
                .filter(|&b| b >= 2)
                .find(|&b| u128::from(b).checked_pow(k) == Some(target))
                .map(|b| (b, k))
        })
    }

    /// Brute-force search for the smallest `r >= 1` with `a^r ≡ 1 (mod n)`.
    ///
    /// Gives up and returns `None` after `min(n, DIRECT_PERIOD_LIMIT)` steps.
    fn period_direct(a: u64, n: u64) -> Option<u64> {
        let modulus = u128::from(n);
        let mut x = a % n;
        for r in 1..=n.min(Self::DIRECT_PERIOD_LIMIT) {
            if x == 1 {
                return Some(r);
            }
            // Widen to u128 so the product cannot overflow; the reduced value is < n.
            x = ((u128::from(x) * u128::from(a)) % modulus) as u64;
        }
        None
    }

    /// Attempts to split `n` using an even candidate period `r` of `a` modulo `n`.
    ///
    /// With `x = a^(r/2) mod n` and `r` a true period, `x^2 ≡ 1 (mod n)`. If `x` is
    /// neither `1` nor `n - 1`, it is a non-trivial square root of unity, so
    /// `gcd(x + 1, n)` and `gcd(x - 1, n)` are proper factors of `n`. Every candidate
    /// is still validated, so a wrong `r` can only yield `None`.
    ///
    /// Returns `None` for odd `r` and for trivial square roots. `n` is expected to be
    /// odd: `factor` removes even inputs before calling this.
    fn factors_from_period(a: u64, r: u64, n: u64) -> Option<(u64, u64)> {
        if r % 2 != 0 {
            return None;
        }
        let x = modular_exp(a, r / 2, n);
        // x == 1 and x == n - 1 are the trivial roots and carry no factor
        // information; x == 0 cannot occur for coprime `a` and would underflow below.
        if x <= 1 || x == n - 1 {
            return None;
        }
        [gcd(x + 1, n), gcd(x - 1, n)]
            .iter()
            .copied()
            .find(|&f| f > 1 && f < n && n % f == 0)
            .map(|f| (f, n / f))
    }

    /// Finds a period `r` of `a` modulo `n`, i.e. a value with `a^r ≡ 1 (mod n)`.
    ///
    /// The search via [`find_period_qft`] is tried first. If it yields nothing and
    /// `n <= 2^20`, the order is computed by brute force. Returns `Ok(None)` when
    /// neither search finds a period.
    ///
    /// # Errors
    /// Returns [`FFTError::ValueError`] unless `1 < a < n` and `gcd(a, n) == 1`, and
    /// propagates any error from [`find_period_qft`].
    pub fn find_period(&self, a: u64, n: u64) -> FFTResult<Option<u64>> {
        if a <= 1 || a >= n {
            return Err(FFTError::ValueError(format!(
                "require 1 < a < n, got a={a} n={n}"
            )));
        }
        if gcd(a, n) != 1 {
            return Err(FFTError::ValueError(format!(
                "a={a} and n={n} are not coprime"
            )));
        }
        let qft_period = find_period_qft(
            a,
            n,
            self.config.max_qft_qubits,
            self.config.max_period_candidates,
        )?;
        if qft_period.is_some() {
            return Ok(qft_period);
        }
        // The brute-force fallback is linear in the order, so keep it to small moduli.
        Ok(if n <= Self::DIRECT_PERIOD_LIMIT {
            Self::period_direct(a, n)
        } else {
            None
        })
    }

    /// Attempts to split `n` into two factors with Shor's algorithm.
    ///
    /// Trivial inputs are handled classically first:
    /// * `n <= 1`: no factors.
    /// * even `n`: `(2, n / 2)`. Note that this yields `(2, 1)` for `n == 2`.
    /// * prime `n`: no factors.
    /// * perfect power `n = b^k`: `(b, n / b)`.
    ///
    /// Otherwise the bases `a = 2, 3, …` are tried, at most
    /// `config.max_base_attempts` of them. A base sharing a factor with `n` splits it
    /// immediately via `gcd`. A coprime base goes through `find_period` and then
    /// `factors_from_period`. `iterations` counts the bases examined.
    ///
    /// # Errors
    /// Propagates errors from `find_period`.
    pub fn factor(&self, n: u64) -> FFTResult<ShorResult> {
        let no_factors = |iterations| ShorResult {
            factors: None,
            period: None,
            iterations,
        };
        if n <= 1 {
            return Ok(no_factors(0));
        }
        if n % 2 == 0 {
            return Ok(ShorResult {
                factors: Some((2, n / 2)),
                period: None,
                iterations: 0,
            });
        }
        if Self::is_prime(n) {
            return Ok(no_factors(0));
        }
        // Prime powers have no non-trivial square roots of unity, so period finding
        // cannot split them; handle every perfect power classically. `base >= 2` and
        // `base^k == n`, so `base` always divides `n`.
        if let Some((base, _exp)) = Self::is_perfect_power(n) {
            return Ok(ShorResult {
                factors: Some((base, n / base)),
                period: None,
                iterations: 0,
            });
        }

        let mut iterations = 0;
        for a in (2..n).take(self.config.max_base_attempts) {
            iterations += 1;
            // Classical shortcut: a base sharing a factor with `n` splits it without
            // any period finding. Since a < n, such a g is automatically < n.
            let g = gcd(a, n);
            if g > 1 && g < n {
                return Ok(ShorResult {
                    factors: Some((g, n / g)),
                    period: None,
                    iterations,
                });
            }
            if let Some(r) = self.find_period(a, n)? {
                if let Some(factors) = Self::factors_from_period(a, r, n) {
                    return Ok(ShorResult {
                        factors: Some(factors),
                        period: Some(r),
                        iterations,
                    });
                }
            }
        }
        Ok(no_factors(iterations))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `factors` is `Some((a, b))` with `a * b == n` and `a, b > 1`.
    fn assert_nontrivial_split(n: u64, factors: Option<(u64, u64)>) {
        let (a, b) = factors.unwrap_or_else(|| panic!("expected factors of {n}, got None"));
        assert_eq!(a * b, n, "product must equal {n}");
        assert!(
            a > 1 && b > 1,
            "non-trivial factors of {n} expected, got ({a}, {b})"
        );
    }

    #[test]
    fn test_is_prime() {
        for p in [2, 3, 5, 7, 17, 97] {
            assert!(ShorSimulator::is_prime(p), "{p} is prime");
        }
        for c in [0, 1, 4, 9, 15, 21, 25, 49] {
            assert!(!ShorSimulator::is_prime(c), "{c} is not prime");
        }
    }

    #[test]
    fn test_is_perfect_power() {
        assert_eq!(ShorSimulator::is_perfect_power(8), Some((2, 3)));
        assert_eq!(ShorSimulator::is_perfect_power(9), Some((3, 2)));
        assert_eq!(ShorSimulator::is_perfect_power(25), Some((5, 2)));
        assert_eq!(ShorSimulator::is_perfect_power(27), Some((3, 3)));
        // The smallest exponent is reported: 64 = 8^2 = 4^3 = 2^6.
        assert_eq!(ShorSimulator::is_perfect_power(64), Some((8, 2)));
        assert!(ShorSimulator::is_perfect_power(15).is_none());
        assert!(ShorSimulator::is_perfect_power(3).is_none());
    }

    #[test]
    fn test_factor_even_number() {
        let sim = ShorSimulator::default_new();
        let result = sim.factor(10).expect("factor 10");
        assert_nontrivial_split(10, result.factors);
    }

    #[test]
    fn test_factor_prime_returns_none() {
        let sim = ShorSimulator::default_new();
        for p in [7, 17, 97] {
            let result = sim.factor(p).expect("factor prime");
            assert!(result.factors.is_none(), "{p} is prime");
        }
    }

    #[test]
    fn test_factor_perfect_power() {
        let sim = ShorSimulator::default_new();
        for n in [9, 25, 27, 49] {
            let result = sim.factor(n).expect("factor perfect power");
            assert_nontrivial_split(n, result.factors);
        }
    }

    #[test]
    fn test_factor_15() {
        let sim = ShorSimulator::default_new();
        let result = sim.factor(15).expect("factor 15");
        assert_nontrivial_split(15, result.factors);
        assert!(result.iterations >= 1, "15 must go through the base loop");
    }

    #[test]
    fn test_factor_21() {
        let sim = ShorSimulator::default_new();
        let result = sim.factor(21).expect("factor 21");
        assert_nontrivial_split(21, result.factors);
    }

    #[test]
    fn test_factor_35() {
        let sim = ShorSimulator::default_new();
        let result = sim.factor(35).expect("factor 35");
        assert_nontrivial_split(35, result.factors);
    }

    #[test]
    fn test_factor_respects_zero_base_attempts() {
        let sim = ShorSimulator::new(ShorConfig {
            max_base_attempts: 0,
            ..ShorConfig::default()
        });
        let result = sim.factor(15).expect("factor 15");
        assert!(result.factors.is_none());
        assert_eq!(result.iterations, 0);
    }

    #[test]
    fn test_find_period_known() {
        let sim = ShorSimulator::default_new();
        let r = sim
            .find_period(2, 15)
            .expect("find_period 2 mod 15")
            .expect("a period of 2 mod 15 exists");
        assert_eq!(modular_exp(2, r, 15), 1, "2^r mod 15 must equal 1");
    }

    #[test]
    fn test_find_period_bad_input() {
        let sim = ShorSimulator::default_new();
        assert!(sim.find_period(15, 15).is_err(), "a must be < n");
        assert!(sim.find_period(1, 15).is_err(), "a must be > 1");
        assert!(sim.find_period(0, 15).is_err(), "a must be > 1");
        assert!(sim.find_period(3, 15).is_err(), "3 and 15 are not coprime");
    }

    #[test]
    fn test_factor_1_and_2() {
        let sim = ShorSimulator::default_new();
        assert!(sim.factor(0).expect("factor 0").factors.is_none());
        assert!(sim.factor(1).expect("factor 1").factors.is_none());
        assert!(sim.factor(2).expect("factor 2").factors.is_some());
    }

    #[test]
    fn test_factors_from_period_odd_period() {
        assert!(ShorSimulator::factors_from_period(2, 3, 15).is_none());
    }

    #[test]
    fn test_factors_from_period_square_roots() {
        // 14 ≡ -1 (mod 15) is a trivial square root of unity.
        assert!(ShorSimulator::factors_from_period(14, 2, 15).is_none());
        // 11^2 ≡ 1 (mod 15) and 11 is not ±1, so it splits 15.
        assert_nontrivial_split(15, ShorSimulator::factors_from_period(11, 2, 15));
        // 2 has period 4 mod 15; 2^2 = 4 is a non-trivial root.
        assert_nontrivial_split(15, ShorSimulator::factors_from_period(2, 4, 15));
    }

    #[test]
    fn test_shor_config_default() {
        let cfg = ShorConfig::default();
        assert_eq!(cfg.max_period_candidates, 16);
        assert_eq!(cfg.max_qft_qubits, 12);
        assert_eq!(cfg.max_base_attempts, 20);
    }
}