Skip to main content

falcon_rust/
math.rs

1use std::vec::IntoIter;
2
3use falcon_profiler::profiling;
4use itertools::Itertools;
5use num::{BigInt, FromPrimitive, One, Zero};
6use num_complex::Complex64;
7use rand::Rng;
8
9use crate::{
10    cyclotomic_fourier::CyclotomicFourier,
11    falcon_field::{Felt, Q},
12    fast_fft::FastFft,
13    fixed_point::FixedPoint64,
14    inverse::Inverse,
15    polynomial::Polynomial,
16    samplerz::sampler_z,
17    u32_field::U32Field,
18};
19
20/// Reduce the vector (F,G) relative to (f,g). This method follows the python
21/// implementation [1].
22///
23/// Algorithm 7 in the spec [2, p.35]
24///
25/// [1]: https://github.com/tprest/falcon.py
26///
27/// [2]: https://falcon-sign.info/falcon.pdf
28///
29/// This function is marked pub for the purpose of benchmarking; it is not
30/// considered part of the public API.
31#[doc(hidden)]
32#[profiling]
33pub fn babai_reduce_bigint(
34    f: &Polynomial<BigInt>,
35    g: &Polynomial<BigInt>,
36    capital_f: &mut Polynomial<BigInt>,
37    capital_g: &mut Polynomial<BigInt>,
38) -> Result<(), String> {
39    let bitsize = |bi: &BigInt| bi.bits();
40    let n = f.coefficients.len();
41    let size = [
42        f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
43        g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
44        53,
45    ]
46    .into_iter()
47    .max()
48    .unwrap();
49    let shift = (size as i64) - 53;
50    let f_adjusted = f
51        .map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
52        .fft();
53    let g_adjusted = g
54        .map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
55        .fft();
56
57    let f_star_adjusted = f_adjusted.map(|c| c.conj());
58    let g_star_adjusted = g_adjusted.map(|c| c.conj());
59    let denominator_fft =
60        f_adjusted.hadamard_mul(&f_star_adjusted) + g_adjusted.hadamard_mul(&g_star_adjusted);
61
62    let mut prev_capital_size = u64::MAX;
63    loop {
64        let capital_size = [
65            capital_f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
66            capital_g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
67            53,
68        ]
69        .into_iter()
70        .max()
71        .unwrap();
72
73        // Stop when we've reached the target size, or when capital_size stopped
74        // strictly decreasing (floating-point precision limit reached).
75        if capital_size < size {
76            break;
77        }
78        if capital_size >= prev_capital_size {
79            break;
80        }
81        prev_capital_size = capital_size;
82
83        // When D = capital_size - size > 53, scaling both capital_F and f to
84        // ~2^53 makes the FFT quotient ≈ 1, capturing only ~1 bit of k_true per
85        // iteration.  Instead, scale capital_F to ~2^106 (shift 53 less) so the
86        // quotient ≈ 2^53, extracting 53 bits of k_true per iteration.  The
87        // back-shift on kf decreases by 53 to compensate.
88        let d = capital_size - size;
89        let (capital_shift, back_shift) = if d > 53 {
90            ((capital_size as i64) - 106, d - 53)
91        } else {
92            ((capital_size as i64) - 53, d)
93        };
94
95        let capital_f_adjusted = capital_f
96            .map(|bi| Complex64::new(i128::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
97            .fft();
98        let capital_g_adjusted = capital_g
99            .map(|bi| Complex64::new(i128::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
100            .fft();
101
102        let numerator = capital_f_adjusted.hadamard_mul(&f_star_adjusted)
103            + capital_g_adjusted.hadamard_mul(&g_star_adjusted);
104        let quotient = numerator.hadamard_div(&denominator_fft).ifft();
105
106        // Use i128 to avoid i64 saturation when the FFT quotient is large
107        // (can happen for small n when |f_fft[i]|^2 + |g_fft[i]|^2 is small at
108        // some frequency).
109        let k = quotient.map(|f| BigInt::from(f.re.round() as i128));
110
111        if k.is_zero() {
112            break;
113        }
114        let kf = (k.clone().karatsuba(f)).reduce_by_cyclotomic(n);
115        let kg = (k.clone().karatsuba(g)).reduce_by_cyclotomic(n);
116        let shifted_kf = kf.map(|bi| bi << back_shift);
117        let shifted_kg = kg.map(|bi| bi << back_shift);
118
119        // Tentative check: if applying shifted_kf would make capital_F grow,
120        // the step overshot (common when |f_fft|^2+|g_fft|^2 is very small at
121        // one frequency for small n).  In that case fall back to the old
122        // single-bit formula for this iteration.
123        if d > 53 {
124            let new_cs_f = capital_f
125                .coefficients
126                .iter()
127                .zip(shifted_kf.coefficients.iter())
128                .map(|(a, b)| (a - b).bits())
129                .max()
130                .unwrap_or(0);
131            let new_cs_g = capital_g
132                .coefficients
133                .iter()
134                .zip(shifted_kg.coefficients.iter())
135                .map(|(a, b)| (a - b).bits())
136                .max()
137                .unwrap_or(0);
138            if u64::max(new_cs_f, new_cs_g) >= capital_size {
139                // Recompute with old formula (capital_shift = capital_size - 53)
140                let cs_old = (capital_size as i64) - 53;
141                let cf_old = capital_f
142                    .map(|bi| Complex64::new(i64::try_from(bi >> cs_old).unwrap() as f64, 0.0))
143                    .fft();
144                let cg_old = capital_g
145                    .map(|bi| Complex64::new(i64::try_from(bi >> cs_old).unwrap() as f64, 0.0))
146                    .fft();
147                let num_old = cf_old.hadamard_mul(&f_star_adjusted)
148                    + cg_old.hadamard_mul(&g_star_adjusted);
149                let quot_old = num_old.hadamard_div(&denominator_fft).ifft();
150                let k_old = quot_old.map(|f| BigInt::from(f.re.round() as i64));
151                if k_old.is_zero() {
152                    break;
153                }
154                let kf_old = (k_old.clone().karatsuba(f)).reduce_by_cyclotomic(n);
155                let kg_old = (k_old.karatsuba(g)).reduce_by_cyclotomic(n);
156                *capital_f -= kf_old.map(|bi| bi << d);
157                *capital_g -= kg_old.map(|bi| bi << d);
158                continue;
159            }
160        }
161
162        *capital_f -= shifted_kf;
163        *capital_g -= shifted_kg;
164    }
165    Ok(())
166}
167
168/// Reduce the vector (F,G) relative to (f,g). This method follows the python
169/// implementation [1] but uses multimodular arithmetic for fast operations on
170/// big integer polynomials in a cyclotomic ring.
171///
172/// Algorithm 7 in the spec [2, p.35]
173///
174/// [1]: https://github.com/tprest/falcon.py
175///
176/// [2]: https://falcon-sign.info/falcon.pdf
177///
178///
179/// This function is marked pub for the purpose of benchmarking; it is not
180/// considered part of the public API.
181#[doc(hidden)]
182#[profiling]
183pub fn babai_reduce_i32(
184    f: &Polynomial<i32>,
185    g: &Polynomial<i32>,
186    capital_f: &mut Polynomial<i32>,
187    capital_g: &mut Polynomial<i32>,
188) -> Result<(), String> {
189    let f_ntt: Polynomial<U32Field> = f.map(|&i| U32Field::new(i)).fft();
190    let g_ntt: Polynomial<U32Field> = g.map(|&i| U32Field::new(i)).fft();
191
192    let bitsize = |itr: IntoIter<i32>| {
193        (itr.map(|i| i.abs()).max().unwrap() * 2)
194            .ilog2()
195            .next_multiple_of(8) as usize
196    };
197    let size = usize::max(
198        bitsize(
199            f.coefficients
200                .iter()
201                .chain(g.coefficients.iter())
202                .cloned()
203                .collect_vec()
204                .into_iter(),
205        ),
206        53,
207    );
208
209    let shift = (size as i64) - 53;
210    let f_adjusted = f
211        .map(|i| Complex64::new(i64::from(*i >> shift) as f64, 0.0))
212        .fft();
213    let g_adjusted = g
214        .map(|i| Complex64::new(i64::from(*i >> shift) as f64, 0.0))
215        .fft();
216
217    let f_star_adjusted = f_adjusted.map(|c| c.conj());
218    let g_star_adjusted = g_adjusted.map(|c| c.conj());
219    let denominator_fft =
220        f_adjusted.hadamard_mul(&f_star_adjusted) + g_adjusted.hadamard_mul(&g_star_adjusted);
221
222    let mut prev_capital_size = usize::MAX;
223    loop {
224        let capital_size = [
225            bitsize(
226                capital_f
227                    .coefficients
228                    .iter()
229                    .chain(capital_g.coefficients.iter())
230                    .copied()
231                    .collect_vec()
232                    .into_iter(),
233            ),
234            53,
235        ]
236        .into_iter()
237        .max()
238        .unwrap();
239
240        if capital_size < size || capital_size >= prev_capital_size {
241            break;
242        }
243        prev_capital_size = capital_size;
244
245        let capital_shift = (capital_size as i64) - 53;
246        let capital_f_adjusted = capital_f
247            .map(|bi| Complex64::new(i64::from(*bi >> capital_shift) as f64, 0.0))
248            .fft();
249        let capital_g_adjusted = capital_g
250            .map(|bi| Complex64::new(i64::from(*bi >> capital_shift) as f64, 0.0))
251            .fft();
252
253        let numerator = capital_f_adjusted.hadamard_mul(&f_star_adjusted)
254            + capital_g_adjusted.hadamard_mul(&g_star_adjusted);
255        let quotient = numerator.hadamard_div(&denominator_fft).ifft();
256
257        let k_ntt = quotient.map(|f| U32Field::new(f.re.round() as i32)).fft();
258
259        if k_ntt.is_zero() {
260            break;
261        }
262
263        let kf_ntt = k_ntt.hadamard_mul(&f_ntt).ifft();
264        let kg_ntt = k_ntt.hadamard_mul(&g_ntt).ifft();
265
266        let kf = kf_ntt.map(|p| p.balanced_value());
267        let kg = kg_ntt.map(|p| p.balanced_value());
268
269        *capital_f -= kf;
270        *capital_g -= kg;
271    }
272    Ok(())
273}
274
275/// Extended Euclidean algorithm for computing the greatest common divisor (g) and
276/// Bézout coefficients (u, v) for the relation
277///
278///  u a + v b = g .
279///
280/// Implementation adapted from Wikipedia [1].
281///
282/// [1]: https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Pseudocode
283#[profiling]
284fn xgcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
285    let (mut old_r, mut r) = (a.clone(), b.clone());
286    let (mut old_s, mut s) = (BigInt::one(), BigInt::zero());
287    let (mut old_t, mut t) = (BigInt::zero(), BigInt::one());
288
289    while r != BigInt::zero() {
290        let quotient = old_r.clone() / r.clone();
291        (old_r, r) = (r.clone(), old_r.clone() - quotient.clone() * r);
292        (old_s, s) = (s.clone(), old_s.clone() - quotient.clone() * s);
293        (old_t, t) = (t.clone(), old_t.clone() - quotient * t);
294    }
295
296    (old_r, old_s, old_t)
297}
298
299/// Solve the NTRU equation. Given f, g in ZZ[X], find F, G in ZZ[X].
300/// such that
301///
302///    f G - g F = q  mod (X^n + 1)
303///
304/// Algorithm 6 of the specification [1, p.35].
305///
306/// [1]: https://falcon-sign.info/falcon.pdf
307#[profiling]
308fn ntru_solve(
309    f: &Polynomial<BigInt>,
310    g: &Polynomial<BigInt>,
311) -> Option<(Polynomial<BigInt>, Polynomial<BigInt>)> {
312    let n = f.coefficients.len();
313    if n == 1 {
314        let (gcd, u, v) = xgcd(&f.coefficients[0], &g.coefficients[0]);
315        if gcd != BigInt::one() {
316            return None;
317        }
318        return Some((
319            (Polynomial::new(vec![-v * BigInt::from_u32(Q as u32).unwrap()])),
320            Polynomial::new(vec![u * BigInt::from_u32(Q as u32).unwrap()]),
321        ));
322    }
323
324    let f_prime = f.field_norm();
325    let g_prime = g.field_norm();
326    let (capital_f_prime, capital_g_prime) = ntru_solve(&f_prime, &g_prime)?;
327
328    let capital_f_prime_xsq = capital_f_prime.lift_next_cyclotomic();
329    let capital_g_prime_xsq = capital_g_prime.lift_next_cyclotomic();
330    let f_minx = f.galois_adjoint();
331    let g_minx = g.galois_adjoint();
332
333    let mut capital_f = (capital_f_prime_xsq.karatsuba(&g_minx)).reduce_by_cyclotomic(n);
334    let mut capital_g = (capital_g_prime_xsq.karatsuba(&f_minx)).reduce_by_cyclotomic(n);
335
336    match babai_reduce_bigint(f, g, &mut capital_f, &mut capital_g) {
337        Ok(_) => Some((capital_f, capital_g)),
338        Err(_e) => {
339            #[cfg(test)]
340            {
341                panic!("{}", _e);
342            }
343            #[cfg(not(test))]
344            {
345                None
346            }
347        }
348    }
349}
350
351#[profiling]
352fn ntru_solve_entrypoint(
353    f: Polynomial<i32>,
354    g: Polynomial<i32>,
355) -> Option<(Polynomial<i32>, Polynomial<i32>)> {
356    let n = f.coefficients.len();
357
358    let g_prime = g.field_norm().map(|c| BigInt::from(*c));
359    let f_prime = f.field_norm().map(|c| BigInt::from(*c));
360    let (capital_f_prime_bi, capital_g_prime_bi) = ntru_solve(&f_prime, &g_prime)?;
361
362    let capital_f_prime_coefficients = capital_f_prime_bi
363        .coefficients
364        .into_iter()
365        .map(i32::try_from)
366        .collect_vec();
367    let capital_g_prime_coefficients = capital_g_prime_bi
368        .coefficients
369        .into_iter()
370        .map(i32::try_from)
371        .collect_vec();
372
373    if !capital_f_prime_coefficients
374        .iter()
375        .chain(capital_g_prime_coefficients.iter())
376        .all(|c| c.is_ok())
377    {
378        return None;
379    }
380    let capital_f_prime = Polynomial::new(
381        capital_f_prime_coefficients
382            .into_iter()
383            .map(|c| c.unwrap())
384            .collect_vec(),
385    );
386    let capital_g_prime = Polynomial::new(
387        capital_g_prime_coefficients
388            .into_iter()
389            .map(|c| c.unwrap())
390            .collect_vec(),
391    );
392
393    let capital_f_prime_xsq = capital_f_prime.lift_next_cyclotomic();
394    let capital_g_prime_xsq = capital_g_prime.lift_next_cyclotomic();
395    let f_minx = f.galois_adjoint();
396    let g_minx = g.galois_adjoint();
397
398    let psi_rev = U32Field::bitreversed_powers(n);
399    let psi_rev_inv = U32Field::bitreversed_powers_inverse(n);
400    let ninv = U32Field::new(n as i32).inverse_or_zero();
401    let mut cfp_ntt = capital_f_prime_xsq.map(|c| U32Field::new(*c));
402    let mut cgp_ntt = capital_g_prime_xsq.map(|c| U32Field::new(*c));
403    let mut gm_ntt = g_minx.map(|c| U32Field::new(*c));
404    let mut fm_ntt = f_minx.map(|c| U32Field::new(*c));
405    U32Field::fft(&mut cfp_ntt.coefficients, &psi_rev);
406    U32Field::fft(&mut cgp_ntt.coefficients, &psi_rev);
407    U32Field::fft(&mut gm_ntt.coefficients, &psi_rev);
408    U32Field::fft(&mut fm_ntt.coefficients, &psi_rev);
409    let mut cf_ntt = cfp_ntt.hadamard_mul(&gm_ntt);
410    let mut cg_ntt = cgp_ntt.hadamard_mul(&fm_ntt);
411    U32Field::ifft(&mut cf_ntt.coefficients, &psi_rev_inv, ninv);
412    U32Field::ifft(&mut cg_ntt.coefficients, &psi_rev_inv, ninv);
413
414    let mut capital_f = cf_ntt.map(|c| c.balanced_value());
415    let mut capital_g = cg_ntt.map(|c| c.balanced_value());
416
417    match babai_reduce_i32(&f, &g, &mut capital_f, &mut capital_g) {
418        Ok(_) => Some((capital_f, capital_g)),
419        Err(_e) => {
420            #[cfg(test)]
421            {
422                panic!("{}", _e);
423            }
424            #[cfg(not(test))]
425            {
426                None
427            }
428        }
429    }
430}
431
432/// Sample 4 small polynomials f, g, F, G such that f * G - g * F = q mod (X^n + 1).
433/// Algorithm 5 (NTRUgen) of the documentation [1, p.34].
434///
435/// This function is marked pub for benchmarking purposes only. Not considered part
436/// of the public API.
437///
438/// [1]: https://falcon-sign.info/falcon.pdf
439#[doc(hidden)]
440#[profiling]
441pub fn ntru_gen(
442    n: usize,
443    rng: &mut dyn Rng,
444) -> (
445    Polynomial<i16>,
446    Polynomial<i16>,
447    Polynomial<i16>,
448    Polynomial<i16>,
449) {
450    // let mut rng: StdRng = SeedableRng::from_seed(seed);
451
452    loop {
453        let f = gen_poly(n, rng);
454        let g = gen_poly(n, rng);
455
456        let f_ntt = f.map(|&i| Felt::from(i)).fft();
457        if f_ntt.coefficients.iter().any(|e| e.is_zero()) {
458            continue;
459        }
460        let gamma = gram_schmidt_norm_squared(&f, &g);
461        if gamma > 1.3689f64 * (Q as f64) {
462            continue;
463        }
464
465        if let Some((capital_f, capital_g)) =
466            ntru_solve_entrypoint(f.map(|&i| i as i32), g.map(|&i| i as i32))
467        {
468            return (
469                f,
470                g,
471                capital_f.map(|&i| i as i16),
472                capital_g.map(|&i| i as i16),
473            );
474        }
475    }
476}
477
478/// Generate a polynomial of degree at most n-1 whose coefficients are
479/// distributed according to a discrete Gaussian with mu = 0 and
480/// sigma = 1.17 * sqrt(Q / (2n)).
481// fn gen_poly(n: usize, rng: &mut dyn Rng) -> Polynomial<i16> {
482#[profiling]
483fn gen_poly(n: usize, rng: &mut dyn Rng) -> Polynomial<i16> {
484    let mu = FixedPoint64::ZERO;
485    let sigma_star = FixedPoint64::from(1.43300980528773f64);
486    const NUM_COEFFICIENTS: usize = 4096;
487    Polynomial {
488        coefficients: (0..NUM_COEFFICIENTS)
489            .map(|_| sampler_z(mu, sigma_star, sigma_star - FixedPoint64::from(0.001f64), rng))
490            .collect_vec()
491            .chunks(NUM_COEFFICIENTS / n)
492            .map(|ch| ch.iter().sum())
493            .collect_vec(),
494    }
495}
496
497/// Compute the Gram-Schmidt norm of B = [[g, -f], [G, -F]] from f and g.
498/// Corresponds to line 9 in algorithm 5 of the spec [1, p.34]
499///
500/// [1]: https://falcon-sign.info/falcon.pdf
501#[profiling]
502fn gram_schmidt_norm_squared(f: &Polynomial<i16>, g: &Polynomial<i16>) -> f64 {
503    let n = f.coefficients.len();
504    let gamma1 = f64::from(f.l2_norm_squared() + g.l2_norm_squared());
505
506    let fp = |i: &i16| Complex64::new(*i as f64, 0.0);
507    let q_fp = Q as f64;
508    let n_fp = n as f64;
509
510    let f_fft = f.map(fp).fft();
511    let g_fft = g.map(fp).fft();
512    let f_adj_fft = f_fft.map(|c| c.conj());
513    let g_adj_fft = g_fft.map(|c| c.conj());
514    let ffgg_fft = f_fft.hadamard_mul(&f_adj_fft) + g_fft.hadamard_mul(&g_adj_fft);
515    let ffgg_fft_inverse = ffgg_fft.hadamard_inv();
516    let qf_over_ffgg_fft = f_adj_fft
517        .map(|c| c * q_fp)
518        .hadamard_mul(&ffgg_fft_inverse);
519    let qg_over_ffgg_fft = g_adj_fft
520        .map(|c| c * q_fp)
521        .hadamard_mul(&ffgg_fft_inverse);
522    let norm_f_over_ffgg_squared = qf_over_ffgg_fft
523        .coefficients
524        .iter()
525        .map(|c| (c * c.conj()).re)
526        .sum::<f64>()
527        / n_fp;
528    let norm_g_over_ffgg_squared = qg_over_ffgg_fft
529        .coefficients
530        .iter()
531        .map(|c| (c * c.conj()).re)
532        .sum::<f64>()
533        / n_fp;
534
535    let gamma2 = norm_f_over_ffgg_squared + norm_g_over_ffgg_squared;
536
537    f64::max(gamma1, gamma2)
538}
539
#[cfg(test)]
mod test {

    use std::str::FromStr;

    use itertools::Itertools;
    use num::{BigInt, FromPrimitive};
    use proptest::collection::vec;
    use proptest::prop_assert_eq;
    use proptest::strategy::Just;
    use rand::{rngs::StdRng, SeedableRng};
    use test_strategy::proptest as strategy_proptest;

    use crate::{
        math::{babai_reduce_i32, gram_schmidt_norm_squared, ntru_gen, ntru_solve},
        polynomial::Polynomial,
    };

    use super::babai_reduce_bigint;

    /// Fixed (f, g, F, G) quadruple fed to `babai_oscillation_terminates`;
    /// going by its name, an input on which the reduction once failed to
    /// terminate.
    fn babai_infinite_loop_polynomials() -> (
        Polynomial<BigInt>,
        Polynomial<BigInt>,
        Polynomial<BigInt>,
        Polynomial<BigInt>,
    ) {
        // Build a degree-3 polynomial from decimal coefficient strings.
        let parse = |decimals: [&str; 4]| {
            Polynomial::new(
                decimals
                    .iter()
                    .map(|d| BigInt::from_str(d).unwrap())
                    .collect_vec(),
            )
        };

        let f = parse([
            "6426042728002",
            "-20675284604736",
            "-12121913318466",
            "-27836101162563",
        ]);

        let g = parse([
            "-1001246212",
            "-1347303037",
            "987026048",
            "-1001311747",
        ]);

        let capital_f = parse([
            "563985131491945032326798334533872091781886676547754689048287010878681928",
            "-348444005402208553421931883447687919671423051554816023996113866522386058",
            "-85657170778585026649528684432821341936755757853602491207147473952485632",
            "135623655239747178410899900677875843487151183900794566193191499131611018",
        ]);

        let capital_g = parse([
            "49040356584788663746447138446729467702643846166576265941049418069366",
            "-57075549200927059197269512430308877512934179841045274854176350745681",
            "18442173959410247991253446345066800376513088376845717824090327663990",
            "19528334302175388221061434098432127604592213845277598673231565264960",
        ]);

        (f, g, capital_f, capital_g)
    }

    /// The reduction must return (and return `Ok`) on the fixture above.
    #[test]
    fn babai_oscillation_terminates() {
        let (f, g, mut capital_f, mut capital_g) = babai_infinite_loop_polynomials();
        let outcome = babai_reduce_bigint(&f, &g, &mut capital_f, &mut capital_g);
        assert!(outcome.is_ok())
    }

    // Disabled test, kept for reference.
    //
    // #[test]
    // fn test_gen_poly() {
    //     let mut rng = rng();
    //     let n = 1024;
    //     let mut sum_norms = 0.0;
    //     let num_iterations = 100;
    //     for _ in 0..num_iterations {
    //         let f = gen_poly(n, &mut rng);
    //         sum_norms += f.l2_norm();
    //     }
    //     let average = sum_norms / (num_iterations as f64);
    //     assert!(90.0 < average);
    //     assert!(average < 94.0);
    // }

    /// Squared Gram-Schmidt norm of a fixed periodic (f, g) pair against a
    /// precomputed reference value.
    #[test]
    fn test_gs_norm() {
        let n = 512;
        let f = Polynomial::new((0..n).map(|i| i % 5).collect_vec());
        let g = Polynomial::new((0..n).map(|i| (i % 7) - 4).collect_vec());

        let norm_squared = gram_schmidt_norm_squared(&f, &g);
        let expected = 5992556.183229722f64;

        let difference = (norm_squared - expected).abs();
        assert!(
            difference < 1.0,
            "norm squared was {norm_squared} =/= {expected} (expected)",
        );
    }

    /// `ntru_solve` on a fixed degree-64 instance: checks the exact (F, G)
    /// and that they satisfy the NTRU equation.
    #[test]
    fn test_ntru_solve() {
        let n = 64;
        let f_coefficients = (0..n).map(|i| ((i % 7) as i32) - 4).collect_vec();
        let g_coefficients = (0..n).map(|i| ((i % 5) as i32) - 3).collect_vec();
        let f = Polynomial::new(f_coefficients).map(|&i| i.into());
        let g = Polynomial::new(g_coefficients).map(|&i| i.into());

        let (capital_f, capital_g) = ntru_solve(&f, &g).unwrap();

        let to_bigints =
            |values: [i16; 64]| values.map(|v| BigInt::from_i16(v).unwrap()).to_vec();

        let expected_capital_f: [i16; 64] = [
            -221, -19, 133, 81, -488, -112, 189, -75, -112, -223, 143, 241, -249, 33, 47, -16, 32,
            -145, 183, -57, -99, 104, -44, 78, -129, 26, 77, -88, 52, -36, 69, -66, -37, 80, -45,
            32, -67, 93, -24, -79, 87, -49, 68, -116, 60, 108, -158, 68, -52, 87, -32, -116, 233,
            -120, -111, 65, 119, 144, -307, -98, 295, -163, -194, -325,
        ];
        let expected_capital_g: [i16; 64] = [
            -861, 625, -531, 151, 80, 11, 132, 547, -308, 4, 184, -134, -74, -61, 215, -2, -188,
            40, 104, -38, -59, 21, 51, -12, -101, 86, 12, -40, 0, -31, 86, -72, 7, 24, -32, 46,
            -71, 53, 0, -21, 23, -49, 60, -16, -38, 30, 18, 3, -41, -42, 114, 2, -119, 80, -64, 95,
            -37, -18, 238, -429, 87, 193, -3, -111,
        ];
        assert_eq!(to_bigints(expected_capital_f), capital_f.coefficients);
        assert_eq!(to_bigints(expected_capital_g), capital_g.coefficients);

        // f G - g F must equal the constant q = 12289 modulo X^n + 1.
        let ntru = (f * capital_g - g * capital_f).reduce_by_cyclotomic(n);
        assert_eq!(Polynomial::constant(12289.into()), ntru);
    }

    /// On small random inputs, the i32 and BigInt reductions must agree on
    /// both the error status and the reduced (F, G).
    #[strategy_proptest]
    fn bigint_and_smallint_babai_reduce_agree(
        #[strategy(1usize..5)] _logn: usize,
        #[strategy(Just(1<<#_logn))] _n: usize,
        #[strategy(vec(-5..5, #_n))] f_coefficients: Vec<i32>,
        #[strategy(vec(-5..5, #_n))] g_coefficients: Vec<i32>,
        #[strategy(vec(-115..115, #_n))] capital_f_coefficients: Vec<i32>,
        #[strategy(vec(-115..115, #_n))] capital_g_coefficients: Vec<i32>,
    ) {
        let widen = |p: &Polynomial<i32>| p.map(|c| BigInt::from(*c));

        let f_i32 = Polynomial::new(f_coefficients);
        let g_i32 = Polynomial::new(g_coefficients);
        let mut capital_f_i32 = Polynomial::new(capital_f_coefficients);
        let mut capital_g_i32 = Polynomial::new(capital_g_coefficients);

        let f_bigint = widen(&f_i32);
        let g_bigint = widen(&g_i32);
        let mut capital_f_bigint = widen(&capital_f_i32);
        let mut capital_g_bigint = widen(&capital_g_i32);

        let small_int_result =
            babai_reduce_i32(&f_i32, &g_i32, &mut capital_f_i32, &mut capital_g_i32);
        let big_int_result = babai_reduce_bigint(
            &f_bigint,
            &g_bigint,
            &mut capital_f_bigint,
            &mut capital_g_bigint,
        );

        prop_assert_eq!(small_int_result.is_err(), big_int_result.is_err());
        prop_assert_eq!(widen(&capital_f_i32), capital_f_bigint);
        prop_assert_eq!(widen(&capital_g_i32), capital_g_bigint);
    }

    /// Key generation from a fixed seed must yield a basis satisfying
    /// f G - g F = 12289 modulo X^n + 1.
    #[test]
    fn test_ntru_gen() {
        let n = 512;
        let seed_bytes =
            hex::decode("deadbeef00000000deadbeef00000000deadbeef00000000deadbeef00000000")
                .unwrap();
        let seed: [u8; 32] = seed_bytes.try_into().unwrap();
        let mut rng: StdRng = SeedableRng::from_seed(seed);

        let (f, g, capital_f, capital_g) = ntru_gen(n, &mut rng);

        println!("f: {}", f);
        println!("g: {}", g);
        println!("capital f: {}", capital_f);
        println!("capital g: {}", capital_g);

        let f_times_capital_g = (f * capital_g).reduce_by_cyclotomic(n);
        let g_times_capital_f = (g * capital_f).reduce_by_cyclotomic(n);
        let difference = f_times_capital_g - g_times_capital_f;
        assert_eq!(Polynomial::constant(12289), difference);
    }
}