#![allow(dead_code)]
use super::fft::{Cplx, Fft, add_fft, adj_fft, div_fft, mul_fft};
use super::fpr::Fpr;
use super::sampler::{SamplerRng, sampler_z};
use super::zint::{Zint, ext_gcd};
use alloc::vec::Vec;
/// Modulus `q` used throughout this file: the number-theoretic evaluation in `compute_h`
/// works modulo `Q`, and `Q` is the constant on the right-hand side that `ntru_solve`
/// scales its base-case solution by.
const Q: i64 = 12289;
/// Tuple returned by `ntru_gen`, in the order `(f, g, F, G, h)`: the two sampled
/// polynomials, the two polynomials produced by `ntru_solve` (converted to `i64`),
/// and the value computed by `compute_h`.
pub(crate) type RawNtruKey = (Vec<i64>, Vec<i64>, Vec<i64>, Vec<i64>, Vec<u16>);
/// Computes `b^e mod q` by binary (square-and-multiply) exponentiation.
///
/// * `b` - base; it is first reduced with `rem_euclid`, so negative bases are accepted.
/// * `e` - exponent; any `e <= 0` yields the empty product, `1 % q`.
/// * `q` - modulus; must be non-zero (`rem_euclid` panics on a zero modulus).
///
/// Intermediate products are formed in `i128`. Both factors are always smaller in
/// magnitude than `q`, so the product fits in 126 bits and can never overflow, whatever
/// `i64` modulus is supplied (plain `i64` products overflow once `q` exceeds ~3.03e9).
fn pow_mod(b: i64, e: i64, q: i64) -> i64 {
    let modulus = i128::from(q);
    let mut base = i128::from(b.rem_euclid(q));
    let mut exp = e;
    // `1 % modulus` instead of `1` so that the degenerate modulus 1 yields 0.
    let mut acc: i128 = 1 % modulus;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % modulus;
        }
        base = base * base % modulus;
        exp >>= 1;
    }
    // |acc| < |q| <= 2^63, so narrowing back to i64 is lossless.
    acc as i64
}
/// Returns `a^(q - 2) mod q`.
///
/// By Fermat's little theorem this is the multiplicative inverse of `a` when `q` is
/// prime and `a` is not a multiple of `q`; for `a ≡ 0` with `q > 2` the result is `0`.
fn inv_mod(a: i64, q: i64) -> i64 {
    let reduced = a.rem_euclid(q);
    let fermat_exp = q - 2;
    pow_mod(reduced, fermat_exp, q)
}
/// Searches for `psi` with `psi^n ≡ -1 (mod Q)`.
///
/// Each candidate base in `2..Q` is raised to `(Q - 1) / (2n)`; the first result whose
/// `n`-th power equals `Q - 1` is returned.
///
/// # Panics
/// Panics on `n == 0` (division by zero) and hits `unreachable!` if no base qualifies.
fn find_psi(n: usize) -> i64 {
    let half_order = n as i64;
    let exponent = (Q - 1) / ((2 * n) as i64);
    let hit = (2..Q)
        .map(|candidate| pow_mod(candidate, exponent, Q))
        .find(|&root| pow_mod(root, half_order, Q) == Q - 1);
    match hit {
        Some(psi) => psi,
        None => unreachable!("q supports a 2n-th root of unity for Falcon degrees"),
    }
}
fn negacyclic_eval(coeffs: &[i64], psi_pows: &[i64], omega_pows: &[i64], n: usize) -> Vec<i64> {
let b: Vec<i64> = (0..n)
.map(|k| coeffs[k].rem_euclid(Q) * psi_pows[k] % Q)
.collect();
let mut out = alloc::vec![0i64; n];
for (i, oi) in out.iter_mut().enumerate() {
let mut acc = 0i64;
for k in 0..n {
acc += b[k] * omega_pows[(i * k) % n] % Q;
acc %= Q;
}
*oi = acc;
}
out
}
fn negacyclic_interp(
evals: &[i64],
psi_inv_pows: &[i64],
omega_inv_pows: &[i64],
n: usize,
) -> Vec<i64> {
let ninv = inv_mod(n as i64, Q);
let mut out = alloc::vec![0i64; n];
for (k, ok) in out.iter_mut().enumerate() {
let mut acc = 0i64;
for (i, &fi) in evals.iter().enumerate() {
acc += fi * omega_inv_pows[(i * k) % n] % Q;
acc %= Q;
}
let bk = acc * ninv % Q;
*ok = bk * psi_inv_pows[k] % Q;
}
out
}
/// Rebuilds the fourth key polynomial from `f`, `g` and `cap_f`.
///
/// In the FFT domain it forms `g * cap_f` coefficient by coefficient, adds `Q` to the real
/// part of each product, divides by the transform of `f`, transforms back and rounds every
/// coefficient with `rint`.
///
/// NOTE(review): this presumably solves `f*G - g*F = Q` for `G` in the ring used by `Fft`;
/// confirm against the `fft` module.
pub(crate) fn recompute_g(f: &[i64], g: &[i64], cap_f: &[i64], n: usize) -> Vec<i64> {
    /// Converts integer coefficients to `Fpr`.
    fn lift_to_fpr(poly: &[i64]) -> Vec<Fpr> {
        poly.iter().map(|&c| Fpr::of_i64(c)).collect()
    }

    let transform = Fft::new(n);
    let f_hat = transform.fft(&lift_to_fpr(f));
    let g_hat = transform.fft(&lift_to_fpr(g));
    let cap_f_hat = transform.fft(&lift_to_fpr(cap_f));
    let q = Fpr::of_i64(Q);

    let mut numerator: Vec<Cplx> = Vec::with_capacity(n);
    for i in 0..n {
        let product = g_hat[i].mul(cap_f_hat[i]);
        numerator.push(Cplx::new(product.re.add(q), product.im));
    }

    let quotient = div_fft(&numerator, &f_hat);
    let rounded: Vec<i64> = transform.ifft(&quotient).iter().map(|v| v.rint()).collect();
    rounded
}
pub(crate) fn compute_h(f: &[i64], g: &[i64], n: usize) -> Option<Vec<u16>> {
let psi = find_psi(n);
let omega = psi * psi % Q;
let psi_inv = inv_mod(psi, Q);
let omega_inv = inv_mod(omega, Q);
let pow_table = |base: i64| -> Vec<i64> {
let mut v = alloc::vec![1i64; n];
for i in 1..n {
v[i] = v[i - 1] * base % Q;
}
v
};
let psi_pows = pow_table(psi);
let omega_pows = pow_table(omega);
let psi_inv_pows = pow_table(psi_inv);
let omega_inv_pows = pow_table(omega_inv);
let fe = negacyclic_eval(f, &psi_pows, &omega_pows, n);
if fe.contains(&0) {
return None; }
let ge = negacyclic_eval(g, &psi_pows, &omega_pows, n);
let he: Vec<i64> = (0..n).map(|i| ge[i] * inv_mod(fe[i], Q) % Q).collect();
let h = negacyclic_interp(&he, &psi_inv_pows, &omega_inv_pows, n);
Some(h.iter().map(|&x| x.rem_euclid(Q) as u16).collect())
}
fn karatsuba(a: &[Zint], b: &[Zint], n: usize) -> Vec<Zint> {
if n == 1 {
return alloc::vec![a[0].mul(&b[0]), Zint::zero()];
}
let n2 = n / 2;
let (a0, a1) = (&a[..n2], &a[n2..]);
let (b0, b1) = (&b[..n2], &b[n2..]);
let ax: Vec<Zint> = (0..n2).map(|i| a0[i].add(&a1[i])).collect();
let bx: Vec<Zint> = (0..n2).map(|i| b0[i].add(&b1[i])).collect();
let a0b0 = karatsuba(a0, b0, n2);
let a1b1 = karatsuba(a1, b1, n2);
let mut axbx = karatsuba(&ax, &bx, n2);
for i in 0..n {
axbx[i] = axbx[i].sub(&a0b0[i].add(&a1b1[i]));
}
let mut ab = alloc::vec![Zint::zero(); 2 * n];
for i in 0..n {
ab[i] = ab[i].add(&a0b0[i]);
ab[i + n] = ab[i + n].add(&a1b1[i]);
ab[i + n2] = ab[i + n2].add(&axbx[i]);
}
ab
}
/// Multiplies `a` and `b` modulo `x^n + 1`, where `n = a.len()`.
///
/// The full Karatsuba product is folded by subtracting its upper half from its lower half.
fn karamul(a: &[Zint], b: &[Zint]) -> Vec<Zint> {
    let len = a.len();
    let full = karatsuba(a, b, len);
    let (low, high) = full.split_at(len);
    low.iter().zip(high).map(|(lo, hi)| lo.sub(hi)).collect()
}
/// Returns `a` with every odd-index coefficient negated, i.e. the coefficients of `a(-x)`.
fn galois_conjugate(a: &[Zint]) -> Vec<Zint> {
    let mut flipped = Vec::with_capacity(a.len());
    for (idx, coeff) in a.iter().enumerate() {
        let entry = if idx % 2 == 0 {
            coeff.clone()
        } else {
            coeff.neg()
        };
        flipped.push(entry);
    }
    flipped
}
/// Maps a polynomial of length `2m` to one of length `m`.
///
/// With `a_e` / `a_o` the even- and odd-index coefficients of `a`, the result is
/// `a_e^2 - x * a_o^2` modulo `x^m + 1`: the odd square is shifted up by one position and
/// its top coefficient wraps around to position 0 with its sign flipped.
fn field_norm(a: &[Zint]) -> Vec<Zint> {
    let half = a.len() / 2;
    let evens: Vec<Zint> = a.iter().step_by(2).take(half).cloned().collect();
    let odds: Vec<Zint> = a.iter().skip(1).step_by(2).take(half).cloned().collect();

    let mut norm = karamul(&evens, &evens);
    let odd_sq = karamul(&odds, &odds);

    for i in 1..half {
        norm[i] = norm[i].sub(&odd_sq[i - 1]);
    }
    // Wrap-around term: x^m ≡ -1, so the subtraction becomes an addition.
    norm[0] = norm[0].add(&odd_sq[half - 1]);
    norm
}
fn lift(a: &[Zint]) -> Vec<Zint> {
let n = a.len();
let mut res = alloc::vec![Zint::zero(); 2 * n];
for i in 0..n {
res[2 * i] = a[i].clone();
}
res
}
/// Largest `bitsize()` over all coefficients of `f` and `g`, never less than 53.
fn max_bitsize(f: &[Zint], g: &[Zint]) -> usize {
    f.iter()
        .chain(g)
        .map(|coeff| coeff.bitsize())
        .fold(53usize, usize::max)
}
/// Shrinks `(cap_f, cap_g)` in place by repeatedly subtracting multiples of `(f, g)`.
///
/// Each round approximates all four polynomials by their top 53 bits, computes
/// `k = round((F * adj(f) + G * adj(g)) / (f * adj(f) + g * adj(g)))` in the FFT domain and
/// subtracts `k * f` and `k * g`, shifted left by the bit-size gap, from `cap_f` / `cap_g`.
/// The loop ends once the coefficients of `cap_f` / `cap_g` have a smaller maximum bit size
/// than those of `f` / `g`, or when `k` rounds to zero everywhere.
fn reduce(f: &[Zint], g: &[Zint], cap_f: &mut [Zint], cap_g: &mut [Zint]) {
    /// Keeps the top 53 of `bits` bits of each coefficient and converts them to `Fpr`.
    /// A coefficient that still does not fit in `i64` after the shift is treated as 0.
    fn top_bits(poly: &[Zint], bits: usize) -> Vec<Fpr> {
        let dropped = bits - 53;
        poly.iter()
            .map(|c| Fpr::of_i64(c.shr(dropped).to_i64().unwrap_or(0)))
            .collect()
    }

    let degree = f.len();
    let transform = Fft::new(degree);
    let small_bits = max_bitsize(f, g);

    let f_hat = transform.fft(&top_bits(f, small_bits));
    let g_hat = transform.fft(&top_bits(g, small_bits));
    let f_adj = adj_fft(&f_hat);
    let g_adj = adj_fft(&g_hat);
    // The denominator depends only on f and g, so it is computed once.
    let denom = add_fft(&mul_fft(&f_hat, &f_adj), &mul_fft(&g_hat, &g_adj));

    loop {
        let large_bits = max_bitsize(cap_f, cap_g);
        if large_bits < small_bits {
            return;
        }

        let cap_f_hat = transform.fft(&top_bits(cap_f, large_bits));
        let cap_g_hat = transform.fft(&top_bits(cap_g, large_bits));
        let numer = add_fft(&mul_fft(&cap_f_hat, &f_adj), &mul_fft(&cap_g_hat, &g_adj));
        let quotient: Vec<Zint> = transform
            .ifft(&div_fft(&numer, &denom))
            .iter()
            .map(|v| Zint::from_i64(v.rint()))
            .collect();
        if !quotient.iter().any(|z| !z.is_zero()) {
            return;
        }

        let f_times_k = karamul(f, &quotient);
        let g_times_k = karamul(g, &quotient);
        let shift = large_bits - small_bits;
        for i in 0..degree {
            cap_f[i] = cap_f[i].sub(&f_times_k[i].shl(shift));
        }
        for i in 0..degree {
            cap_g[i] = cap_g[i].sub(&g_times_k[i].shl(shift));
        }
    }
}
/// Recursively builds `(cap_f, cap_g)` from `(f, g)`.
///
/// * Length 1: runs `ext_gcd` on the two integers and returns `None` unless the gcd is 1;
///   otherwise returns `(-Q * v, Q * u)`.
///   NOTE(review): presumably `u`, `v` are Bézout coefficients with `f*u + g*v = 1`, which
///   would give `f*G - g*F = Q`; confirm against `ext_gcd`.
/// * Otherwise: solves the problem for the field norms of `f` and `g`, lifts the half-size
///   solution, multiplies by the Galois conjugate of the opposite input, and shrinks the
///   result with `reduce`.
///
/// Any failure in a recursive call is propagated as `None`.
fn ntru_solve(f: &[Zint], g: &[Zint]) -> Option<(Vec<Zint>, Vec<Zint>)> {
    if f.len() != 1 {
        let (half_cap_f, half_cap_g) = ntru_solve(&field_norm(f), &field_norm(g))?;
        let mut cap_f = karamul(&lift(&half_cap_f), &galois_conjugate(g));
        let mut cap_g = karamul(&lift(&half_cap_g), &galois_conjugate(f));
        reduce(f, g, &mut cap_f, &mut cap_g);
        return Some((cap_f, cap_g));
    }

    let (d, u, v) = ext_gcd(&f[0], &g[0]);
    let one = Zint::from_i64(1);
    if d != one {
        return None;
    }
    let q = Zint::from_i64(Q);
    let cap_f0 = q.mul(&v).neg();
    let cap_g0 = q.mul(&u);
    Some((alloc::vec![cap_f0], alloc::vec![cap_g0]))
}
/// Second argument handed to `sampler_z` for every base sample drawn in `gen_poly`
/// (presumably the Gaussian standard deviation — confirm against the sampler module).
/// Numerically the value equals `1.17 * sqrt(12289 / 8192)`.
const SIGMA_FG: Fpr = Fpr::from_f64(1.43300980528773);
/// Samples a polynomial with `n` integer coefficients.
///
/// Always draws 4096 values from `sampler_z` (centre 0, `SIGMA_FG`, lower bound
/// `SIGMA_FG - 0.001`), then sums consecutive groups of `4096 / n` draws to form each
/// coefficient.
///
/// # Panics
/// Panics with a division by zero when `n == 0`. For `n > 4096` the group size is 0 and
/// every coefficient is 0.
fn gen_poly<R: SamplerRng>(n: usize, rng: &mut R) -> Vec<i64> {
    const POOL: usize = 4096;
    let sigma_min = Fpr::from_f64(1.43300980528773 - 0.001);
    let centre = Fpr::from_f64(0.0);

    let mut pool: Vec<i64> = Vec::with_capacity(POOL);
    for _ in 0..POOL {
        pool.push(sampler_z(centre, SIGMA_FG, sigma_min, rng));
    }

    let group = POOL / n;
    (0..n)
        .map(|i| pool[i * group..(i + 1) * group].iter().sum::<i64>())
        .collect()
}
/// Returns the larger of two squared norms derived from `(f, g)`:
///
/// 1. the sum of squares of all coefficients of `f` and `g`;
/// 2. `Q^2` times the sum of squares of the coefficients of
///    `adj(g) / (f*adj(f) + g*adj(g))` and `adj(f) / (f*adj(f) + g*adj(g))`,
///    both computed in the FFT domain and transformed back.
fn gs_norm(f: &[i64], g: &[i64], n: usize) -> Fpr {
    let transform = Fft::new(n);
    let f_real: Vec<Fpr> = f.iter().map(|&c| Fpr::of_i64(c)).collect();
    let g_real: Vec<Fpr> = g.iter().map(|&c| Fpr::of_i64(c)).collect();

    // Squared norm of (f, g), accumulated f-first.
    let mut norm_fg = Fpr::from_f64(0.0);
    for v in f_real.iter().chain(g_real.iter()) {
        norm_fg = norm_fg.add(v.mul(*v));
    }

    let f_hat = transform.fft(&f_real);
    let g_hat = transform.fft(&g_real);
    let f_adj = adj_fft(&f_hat);
    let g_adj = adj_fft(&g_hat);
    let denom = add_fft(&mul_fft(&f_hat, &f_adj), &mul_fft(&g_hat, &g_adj));

    let dual_f = transform.ifft(&div_fft(&g_adj, &denom));
    let dual_g = transform.ifft(&div_fft(&f_adj, &denom));
    let mut dual_sum = Fpr::from_f64(0.0);
    for v in dual_f.iter().chain(dual_g.iter()) {
        dual_sum = dual_sum.add(v.mul(*v));
    }
    let norm_dual = Fpr::of_i64(Q * Q).mul(dual_sum);

    // Maximum of the two norms; ties resolve to `norm_fg`.
    if !norm_fg.lt(norm_dual) {
        norm_fg
    } else {
        norm_dual
    }
}
/// Generates a raw key `(f, g, F, G, h)` for degree `n`, retrying until every step succeeds.
///
/// A candidate `(f, g)` from `gen_poly` is rejected when
/// * `gs_norm(f, g)` exceeds `1.17^2 * Q`,
/// * `compute_h` returns `None`,
/// * `ntru_solve` returns `None`, or
/// * a coefficient of the solved polynomials does not fit in `i64`.
///
/// The loop has no iteration limit.
pub(crate) fn ntru_gen<R: SamplerRng>(n: usize, rng: &mut R) -> RawNtruKey {
    /// Widens `i64` coefficients to big integers.
    fn to_zints(poly: &[i64]) -> Vec<Zint> {
        poly.iter().map(|&c| Zint::from_i64(c)).collect()
    }
    /// Narrows big-integer coefficients to `i64`; `None` if any does not fit.
    fn to_i64s(poly: &[Zint]) -> Option<Vec<i64>> {
        poly.iter().map(|z| z.to_i64()).collect()
    }

    let gs_limit = Fpr::of_i64(Q).mul(Fpr::from_f64(1.17 * 1.17));
    loop {
        let f = gen_poly(n, rng);
        let g = gen_poly(n, rng);
        if gs_limit.lt(gs_norm(&f, &g, n)) {
            continue;
        }
        let h = if let Some(public) = compute_h(&f, &g, n) {
            public
        } else {
            continue;
        };
        let f_big = to_zints(&f);
        let g_big = to_zints(&g);
        if let Some((cap_f, cap_g)) = ntru_solve(&f_big, &g_big) {
            if let (Some(cf), Some(cg)) = (to_i64s(&cap_f), to_i64s(&cap_g)) {
                return (f, g, cf, cg, h);
            }
        }
    }
}
// Unit tests for this module live in `keygen_tests.rs` and are compiled only in test builds.
#[cfg(test)]
#[path = "keygen_tests.rs"]
mod keygen_tests;