Skip to main content

cryptography/public_key/
ntru_ees_core.rs

1//! Shared core for the IEEE Std 1363.1-2008 NTRUEncrypt parameter sets in
2//! this crate. Each per-set module
3//! ([`crate::public_key::ntru_ees401ep1`], etc.) is a thin wrapper that
4//! plugs an [`EesParams`] constant and an `N` const generic into the
5//! routines defined here.
6//!
7//! The construction is exactly the one described in the per-set module
8//! docstrings; this file just hoists the algorithm out of nine duplicates.
9//! Two structural variants are supported:
10//!
11//! - **Dense trapdoor** (`prod_flag = 0` in the IEEE tables): `t` is a
12//!   single trinary polynomial with `df` ones and `df` minus-ones; the
13//!   private key encodes `t` as 2-bit signed trinary.
14//! - **Product-form trapdoor** (`prod_flag = 1`): `t = f_1 \cdot f_2 + f_3`
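//!   with `(df_1, df_2, df_3)` ones-counts; the private key encodes the
//!   three sparse trinary polynomials as lists of non-zero-coefficient
//!   indices, each index packed into `ceil(log2 N)` bits (9 bits for
//!   `256 < N <= 512`, 10 bits for larger `N`).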
17//!
18//! Side channels: variable-time arithmetic. This module is only used from
19//! types under [`crate::vt`].
20//!
21//! Storage strategy: hot polynomial buffers (`Poly<N>::coeffs`) are
22//! inline `[u16; N]` arrays via the `const N: usize` parameter, so the
23//! Karatsuba multiplier inner loop avoids heap traffic on those.
24//! Several other inputs and intermediates remain heap-resident:
25//! wire-format byte buffers (`pk`, `sk`, `ct`); the IGF state's `BitStr`
26//! `Vec<u8>` and the seed `Vec<u8>` it copies; `mgf`'s working buffer
27//! of accepted hash bytes; the trial-and-error trinary samplers in
28//! `sample_trinary` / `sample_trapdoor`; and the F_2 extended-Euclidean
29//! inverter that backs `poly_inverse_mod_q_cyclic`. None of these is
30//! removable without either `generic_const_exprs` (the wire buffers'
31//! lengths are derived from `N` and `logq`) or a redesign of the
32//! inverter / sampler. The IEEE 1363.1 EES keygen and encrypt paths
33//! therefore allocate; this is intentional but worth naming so a
34//! profiler reading the heap pattern is not surprised.
35
36use crate::hash::sha1::Sha1;
37use crate::hash::sha2::Sha256;
38use crate::Csprng;
39
40// ---- parameter definitions --------------------------------------------------
41
42/// Hash function selected by the IEEE 1363.1 OID for a given parameter set.
43#[derive(Clone, Copy, Debug, Eq, PartialEq)]
44pub enum HashKind {
45    Sha1,
46    Sha256,
47}
48
49impl HashKind {
50    /// Output length in bytes (`hlen` in IEEE 1363.1).
51    pub const fn output_len(self) -> usize {
52        match self {
53            HashKind::Sha1 => 20,
54            HashKind::Sha256 => 32,
55        }
56    }
57
58    fn digest_into(self, input: &[u8], out: &mut [u8]) {
59        match self {
60            HashKind::Sha1 => out.copy_from_slice(Sha1::digest(input).as_slice()),
61            HashKind::Sha256 => out.copy_from_slice(Sha256::digest(input).as_slice()),
62        }
63    }
64
65    /// Hash the concatenation of `prefix || suffix` into `out` incrementally,
66    /// without materialising the joined buffer. Used by IGF / MGF callers
67    /// where the prefix is the IGF seed `z` and the suffix is a small
68    /// counter encoding.
69    fn digest_two_into(self, prefix: &[u8], suffix: &[u8], out: &mut [u8]) {
70        match self {
71            HashKind::Sha1 => {
72                let mut h = Sha1::new();
73                h.update(prefix);
74                h.update(suffix);
75                out.copy_from_slice(h.finalize().as_slice());
76            }
77            HashKind::Sha256 => {
78                let mut h = Sha256::new();
79                h.update(prefix);
80                h.update(suffix);
81                out.copy_from_slice(h.finalize().as_slice());
82            }
83        }
84    }
85}
86
87/// Trapdoor structure for the private key.
88#[derive(Clone, Copy, Debug, Eq, PartialEq)]
89pub enum TrapdoorKind {
90    /// `t` is a single trinary polynomial with `df` ones and `df` minus-ones.
91    Dense {
92        df: usize,
93    },
94    /// `t = f_1 \cdot f_2 + f_3`. `df1`, `df2`, `df3` are the per-component
95    /// ones-counts.
96    ProductForm {
97        df1: usize,
98        df2: usize,
99        df3: usize,
100    },
101}
102
103/// All scalar parameters for an EES NTRUEncrypt set.
104#[derive(Clone, Copy, Debug, Eq, PartialEq)]
105pub struct EesParams {
106    pub n: usize,
107    /// `q = 2^logq`. Every IEEE 1363.1 set in this crate uses `logq = 11`.
108    pub logq: usize,
109    pub trapdoor: TrapdoorKind,
110    pub dg: usize,
111    pub dm0: usize,
112    /// Number of random *bits* prepended to the message.
113    pub db_bits: usize,
114    /// Bits per IGF rejection sample.
115    pub c_bits: usize,
116    pub min_calls_r: usize,
117    pub min_calls_mask: usize,
118    pub pklen_bits: usize,
119    pub oid: [u8; 3],
120    pub hash: HashKind,
121}
122
123impl EesParams {
124    pub const fn db_bytes(&self) -> usize {
125        self.db_bits / 8
126    }
127    pub const fn pklen_bytes(&self) -> usize {
128        self.pklen_bits.div_ceil(8)
129    }
130    pub const fn q(&self) -> u32 {
131        1u32 << self.logq
132    }
133    pub const fn q_mask(&self) -> u16 {
134        ((1u32 << self.logq) - 1) as u16
135    }
136    /// Wire length of a public key in bytes: `ceil(N * logq / 8)`.
137    pub const fn pk_wire_bytes(&self) -> usize {
138        (self.n * self.logq).div_ceil(8)
139    }
140    /// Trapdoor section of the private key (does not include the embedded
141    /// public key bytes).
142    pub const fn trapdoor_wire_bytes(&self) -> usize {
143        match self.trapdoor {
144            TrapdoorKind::Dense { .. } => (self.n * 2).div_ceil(8),
145            TrapdoorKind::ProductForm { df1, df2, df3 } => {
146                let indices = 2 * (df1 + df2 + df3);
147                let bits = indices * Self::index_bits(self.n);
148                bits.div_ceil(8)
149            }
150        }
151    }
152    /// Number of bits needed per index in the product-form sparse encoding.
153    /// Returns `ceil(log2(N))` rounded up to the next bit.
154    const fn index_bits(n: usize) -> usize {
155        // log2(n) ceiling, computed with const-friendly arithmetic
156        let mut bits = 0usize;
157        let mut v = n.saturating_sub(1);
158        while v > 0 {
159            bits += 1;
160            v >>= 1;
161        }
162        bits
163    }
164    pub const fn ciphertext_wire_bytes(&self) -> usize {
165        self.pk_wire_bytes()
166    }
167    pub const fn max_message_bytes(&self) -> usize {
168        self.n / 2 * 3 / 8 - 1 - self.db_bytes()
169    }
170}
171
172// ---- polynomial type --------------------------------------------------------
173
/// Dense polynomial in `Z_q[x] / (x^N - 1)`, one `u16` per coefficient.
///
/// The helpers in this module do wrapping `u16` arithmetic on the
/// coefficients and reduce mod `q` separately (see `poly_mod_q`).
#[derive(Clone, Copy)]
pub struct Poly<const N: usize> {
    pub coeffs: [u16; N],
}

impl<const N: usize> Poly<N> {
    /// The all-zero polynomial.
    pub fn zero() -> Self {
        Poly { coeffs: [0; N] }
    }
}
185
/// Reduce `x` modulo `q` given `q_mask = q - 1`; valid because `q` is a
/// power of two.
#[inline(always)]
fn modq(x: u16, q_mask: u16) -> u16 {
    q_mask & x
}
190
191pub fn poly_mul<const N: usize>(r: &mut Poly<N>, a: &Poly<N>, b: &Poly<N>) {
192    crate::public_key::ntru_poly_mul::poly_mul_cyclic(&mut r.coeffs, &a.coeffs, &b.coeffs);
193}
194
195pub fn poly_add<const N: usize>(a: &mut Poly<N>, b: &Poly<N>) {
196    for i in 0..N {
197        a.coeffs[i] = a.coeffs[i].wrapping_add(b.coeffs[i]);
198    }
199}
200
201pub fn poly_sub<const N: usize>(a: &mut Poly<N>, b: &Poly<N>) {
202    for i in 0..N {
203        a.coeffs[i] = a.coeffs[i].wrapping_sub(b.coeffs[i]);
204    }
205}
206
207/// Reduce coefficients into `{0, 1, 2}` (canonical trinary representation
208/// after centred-residue mod `q`).
209pub fn poly_mod3<const N: usize>(a: &mut Poly<N>, params: &EesParams) {
210    let q = params.q();
211    let q_mask = params.q_mask();
212    for c in a.coeffs.iter_mut() {
213        let m = modq(*c, q_mask);
214        let centred = if (m as u32) > q / 2 {
215            m as i32 - q as i32
216        } else {
217            m as i32
218        };
219        let r = centred.rem_euclid(3);
220        *c = r as u16;
221    }
222}
223
224pub fn poly_scalar_mul<const N: usize>(a: &mut Poly<N>, k: u16, q_mask: u16) {
225    for c in a.coeffs.iter_mut() {
226        *c = c.wrapping_mul(k) & q_mask;
227    }
228}
229
230pub fn poly_mod_q<const N: usize>(a: &mut Poly<N>, q_mask: u16) {
231    for c in a.coeffs.iter_mut() {
232        *c = modq(*c, q_mask);
233    }
234}
235
236// ---- trinary and product-form polynomial representations -------------------
237
238#[derive(Clone, Eq, PartialEq)]
239pub struct TernaryPoly {
240    pub ones: Vec<u16>,
241    pub neg_ones: Vec<u16>,
242}
243
244impl TernaryPoly {
245    pub fn to_dense<const N: usize>(&self, q_mask: u16) -> Poly<N> {
246        let mut p = Poly::<N>::zero();
247        for &i in &self.ones {
248            p.coeffs[i as usize] = 1;
249        }
250        for &i in &self.neg_ones {
251            p.coeffs[i as usize] = q_mask;
252        }
253        p
254    }
255
256    pub fn mul_dense<const N: usize>(&self, b: &Poly<N>, out: &mut Poly<N>) {
257        for c in out.coeffs.iter_mut() {
258            *c = 0;
259        }
260        for &idx in &self.ones {
261            let s = idx as usize;
262            for j in 0..N {
263                let k = if s + j >= N { s + j - N } else { s + j };
264                out.coeffs[k] = out.coeffs[k].wrapping_add(b.coeffs[j]);
265            }
266        }
267        for &idx in &self.neg_ones {
268            let s = idx as usize;
269            for j in 0..N {
270                let k = if s + j >= N { s + j - N } else { s + j };
271                out.coeffs[k] = out.coeffs[k].wrapping_sub(b.coeffs[j]);
272            }
273        }
274    }
275}
276
277#[derive(Clone, Eq, PartialEq)]
278pub struct ProductPoly {
279    pub f1: TernaryPoly,
280    pub f2: TernaryPoly,
281    pub f3: TernaryPoly,
282}
283
284impl ProductPoly {
285    pub fn mul_dense<const N: usize>(&self, a: &Poly<N>, out: &mut Poly<N>) {
286        let mut t1 = Poly::<N>::zero();
287        self.f1.mul_dense::<N>(a, &mut t1);
288        self.f2.mul_dense::<N>(&t1, out);
289        let mut t3 = Poly::<N>::zero();
290        self.f3.mul_dense::<N>(a, &mut t3);
291        poly_add::<N>(out, &t3);
292    }
293
294    pub fn to_dense<const N: usize>(&self, q_mask: u16) -> Poly<N> {
295        let f2_dense = self.f2.to_dense::<N>(q_mask);
296        let mut out = Poly::<N>::zero();
297        self.f1.mul_dense::<N>(&f2_dense, &mut out);
298        let f3_dense = self.f3.to_dense::<N>(q_mask);
299        poly_add::<N>(&mut out, &f3_dense);
300        out
301    }
302}
303
304/// Trapdoor used in the private key.
305#[derive(Clone, Eq, PartialEq)]
306pub enum Trapdoor {
307    Dense(TernaryPoly),
308    Product(ProductPoly),
309}
310
311impl Trapdoor {
312    fn mul_dense<const N: usize>(&self, a: &Poly<N>, out: &mut Poly<N>) {
313        match self {
314            Trapdoor::Dense(t) => t.mul_dense::<N>(a, out),
315            Trapdoor::Product(p) => p.mul_dense::<N>(a, out),
316        }
317    }
318
319    fn to_dense<const N: usize>(&self, q_mask: u16) -> Poly<N> {
320        match self {
321            Trapdoor::Dense(t) => t.to_dense::<N>(q_mask),
322            Trapdoor::Product(p) => p.to_dense::<N>(q_mask),
323        }
324    }
325
326    /// Pack the trapdoor into the canonical IEEE 1363.1 wire bytes. The
327    /// output length must equal [`EesParams::trapdoor_wire_bytes`].
328    pub fn to_wire(&self, params: &EesParams, out: &mut [u8]) {
329        debug_assert_eq!(out.len(), params.trapdoor_wire_bytes());
330        for b in out.iter_mut() {
331            *b = 0;
332        }
333        match self {
334            Trapdoor::Dense(t) => {
335                // O(df) write at each non-zero index.
336                for &i in &t.ones {
337                    let bit_pos = 2 * (i as usize);
338                    out[bit_pos / 8] |= 1 << (bit_pos % 8);
339                }
340                for &i in &t.neg_ones {
341                    let bit_pos = 2 * (i as usize);
342                    out[bit_pos / 8] |= 3 << (bit_pos % 8);
343                }
344            }
345            Trapdoor::Product(p) => {
346                let mut bit_offset = 0usize;
347                let index_bits = EesParams::index_bits(params.n);
348                for poly in &[&p.f1, &p.f2, &p.f3] {
349                    pack_indices(&poly.ones, out, &mut bit_offset, index_bits)
350                        .expect("ones fit");
351                    pack_indices(&poly.neg_ones, out, &mut bit_offset, index_bits)
352                        .expect("neg_ones fit");
353                }
354            }
355        }
356    }
357
358    /// Inverse of [`Trapdoor::to_wire`]. Validates count and index ranges.
359    pub fn from_wire(bytes: &[u8], params: &EesParams) -> Option<Self> {
360        if bytes.len() != params.trapdoor_wire_bytes() {
361            return None;
362        }
363        match params.trapdoor {
364            TrapdoorKind::Dense { df } => {
365                let n = params.n;
366                let mut bit_pos = 0usize;
367                let mut ones = Vec::new();
368                let mut neg_ones = Vec::new();
369                for i in 0..n {
370                    let code = (bytes[bit_pos / 8] >> (bit_pos % 8)) & 0x3;
371                    bit_pos += 2;
372                    match code {
373                        0 => {}
374                        1 => ones.push(i as u16),
375                        3 => neg_ones.push(i as u16),
376                        _ => return None,
377                    }
378                }
379                if ones.len() != df || neg_ones.len() != df {
380                    return None;
381                }
382                if !padding_bits_clear(bytes, n * 2) {
383                    return None;
384                }
385                Some(Trapdoor::Dense(TernaryPoly { ones, neg_ones }))
386            }
387            TrapdoorKind::ProductForm { df1, df2, df3 } => {
388                let mut bit_offset = 0usize;
389                let index_bits = EesParams::index_bits(params.n);
390                let n = params.n;
391                let f1_ones = unpack_indices(bytes, df1, &mut bit_offset, index_bits, n)?;
392                let f1_neg = unpack_indices(bytes, df1, &mut bit_offset, index_bits, n)?;
393                let f2_ones = unpack_indices(bytes, df2, &mut bit_offset, index_bits, n)?;
394                let f2_neg = unpack_indices(bytes, df2, &mut bit_offset, index_bits, n)?;
395                let f3_ones = unpack_indices(bytes, df3, &mut bit_offset, index_bits, n)?;
396                let f3_neg = unpack_indices(bytes, df3, &mut bit_offset, index_bits, n)?;
397                if !padding_bits_clear(bytes, bit_offset) {
398                    return None;
399                }
400                Some(Trapdoor::Product(ProductPoly {
401                    f1: TernaryPoly { ones: f1_ones, neg_ones: f1_neg },
402                    f2: TernaryPoly { ones: f2_ones, neg_ones: f2_neg },
403                    f3: TernaryPoly { ones: f3_ones, neg_ones: f3_neg },
404                }))
405            }
406        }
407    }
408
409    /// Sample a trapdoor at IID-uniform via the rejection sampler in
410    /// [`sample_trinary`]; the variant is chosen by `params.trapdoor`.
411    /// This is the ordinary-keygen entry point.
412    fn sample_iid<R: Csprng>(rng: &mut R, params: &EesParams) -> Self {
413        match params.trapdoor {
414            TrapdoorKind::Dense { df } => {
415                Trapdoor::Dense(sample_trinary(rng, params.n, df, df))
416            }
417            TrapdoorKind::ProductForm { df1, df2, df3 } => Trapdoor::Product(ProductPoly {
418                f1: sample_trinary(rng, params.n, df1, df1),
419                f2: sample_trinary(rng, params.n, df2, df2),
420                f3: sample_trinary(rng, params.n, df3, df3),
421            }),
422        }
423    }
424
425    /// Sample a blinding trapdoor via the IGF state, used by SVES-3
426    /// encryption. The variant is chosen by the IGF's parameter set.
427    fn sample_via_igf(state: &mut IgfState<'_>) -> Self {
428        match state.params.trapdoor {
429            TrapdoorKind::Dense { df } => Trapdoor::Dense(igf_gen_ternary(state, df)),
430            TrapdoorKind::ProductForm { df1, df2, df3 } => Trapdoor::Product(ProductPoly {
431                f1: igf_gen_ternary(state, df1),
432                f2: igf_gen_ternary(state, df2),
433                f3: igf_gen_ternary(state, df3),
434            }),
435        }
436    }
437}
438
439// ---- inversion mod 2 in F_2[x] / (x^N - 1) ---------------------------------
440
/// Drop trailing zero coefficients from `p`, always keeping at least one
/// entry so the zero polynomial is represented as `[0]`. An empty vector is
/// left empty.
fn poly_trim(p: &mut Vec<u8>) {
    let keep = p.iter().rposition(|&c| c != 0).map_or(1, |top| top + 1);
    p.truncate(keep);
}
446
/// Degree of `p` (index of its highest non-zero coefficient), or `None`
/// for the zero / empty polynomial.
fn poly_deg(p: &[u8]) -> Option<usize> {
    p.iter().rposition(|&c| c != 0)
}
455
456fn poly_inverse_mod2_cyclic(a_coeffs: &[u8]) -> Option<Vec<u8>> {
457    let n = a_coeffs.len();
458    let mut r0 = vec![0u8; n + 1];
459    r0[0] = 1;
460    r0[n] = 1;
461    let mut r1: Vec<u8> = a_coeffs.iter().map(|&c| c & 1).collect();
462    poly_trim(&mut r1);
463    let mut t0 = vec![0u8; 1];
464    let mut t1 = vec![1u8; 1];
465
466    loop {
467        let d1 = match poly_deg(&r1) {
468            Some(d) => d,
469            None => break,
470        };
471        let d0 = match poly_deg(&r0) {
472            Some(d) => d,
473            None => {
474                std::mem::swap(&mut r0, &mut r1);
475                std::mem::swap(&mut t0, &mut t1);
476                break;
477            }
478        };
479        if d0 < d1 {
480            std::mem::swap(&mut r0, &mut r1);
481            std::mem::swap(&mut t0, &mut t1);
482            continue;
483        }
484        let shift = d0 - d1;
485        for i in 0..=d1 {
486            r0[shift + i] ^= r1[i];
487        }
488        poly_trim(&mut r0);
489        let new_t0_len = t0.len().max(t1.len() + shift);
490        if t0.len() < new_t0_len {
491            t0.resize(new_t0_len, 0);
492        }
493        for i in 0..t1.len() {
494            t0[shift + i] ^= t1[i];
495        }
496    }
497
498    if !(r0.len() == 1 && r0[0] == 1) {
499        return None;
500    }
501    let mut out = vec![0u8; n];
502    for (i, &c) in t0.iter().enumerate() {
503        if c & 1 == 1 {
504            out[i % n] ^= 1;
505        }
506    }
507    Some(out)
508}
509
510fn poly_inverse_mod_q_cyclic<const N: usize>(
511    a: &Poly<N>,
512    params: &EesParams,
513) -> Option<Poly<N>> {
514    let q = params.q();
515    let q_mask = params.q_mask();
516    let a_mod2: Vec<u8> = a.coeffs.iter().map(|&c| (c & 1) as u8).collect();
517    let inv2 = poly_inverse_mod2_cyclic(&a_mod2)?;
518
519    let mut b = Poly::<N>::zero();
520    for i in 0..N {
521        b.coeffs[i] = inv2[i] as u16;
522    }
523
524    // Newton-style Hensel lift: each pass squares `precision` (2 โ†’ 4
525    // โ†’ 16 โ†’ 256 โ†’ 65536). Four iterations suffice for every $q \le
526    // 2^{16}$, which covers every IEEE 1363.1 EES parameter set in
527    // this crate ($q = 2048$). `saturating_mul` caps the final pass at
528    // `u32::MAX` so the loop terminates cleanly even if a future
529    // parameter set raised `q` past 2^{16}.
530    let mut precision: u32 = 2;
531    while precision < q {
532        let mut ab = Poly::<N>::zero();
533        poly_mul::<N>(&mut ab, a, &b);
534        poly_mod_q::<N>(&mut ab, q_mask);
535        let mut two_minus_ab = Poly::<N>::zero();
536        two_minus_ab.coeffs[0] = 2u16.wrapping_sub(ab.coeffs[0]) & q_mask;
537        for i in 1..N {
538            two_minus_ab.coeffs[i] = 0u16.wrapping_sub(ab.coeffs[i]) & q_mask;
539        }
540        let mut new_b = Poly::<N>::zero();
541        poly_mul::<N>(&mut new_b, &b, &two_minus_ab);
542        poly_mod_q::<N>(&mut new_b, q_mask);
543        b = new_b;
544        precision = precision.saturating_mul(precision);
545    }
546    Some(b)
547}
548
// ---- bit-string accumulator (IEEE 1363.1 §9 BPGM3 / IGF helpers) -----------
550//
551// Bits are packed LSB-first within each byte; byte 0 holds the oldest 8 bits
552// (the "bottom") and the partial byte at `buf.len() - 1` holds the newest
553// (the "top"). The IGF only ever appends at the top, reads the top `c` bits,
554// and (when refilling) saves the unconsumed bottom slice into a fresh
555// `BitStr` that gets new hash output appended above. There are exactly four
556// supported operations; everything else is a hidden invariant. Track the
557// total bit count explicitly and derive the byte index / partial-bit count
558// from it on the fly so the invariants are obvious.
559
/// LSB-first bit stack used by the IGF. Appends go on top; `leading` and
/// `truncate` read and drop from the top; `trailing` copies out the bottom.
#[derive(Clone)]
struct BitStr {
    /// Packed bits, LSB-first within each byte; byte 0 is the bottom.
    buf: Vec<u8>,
    /// Number of valid bits. Bits at or above this position are kept zero
    /// so appends can OR into the partial top byte.
    bit_len: usize,
}

impl BitStr {
    /// An empty bit string.
    fn new() -> Self {
        BitStr { bit_len: 0, buf: Vec::new() }
    }

    /// Push the 8 bits of `b` onto the top of the stack.
    fn append_byte(&mut self, b: u8) {
        match self.bit_len % 8 {
            0 => self.buf.push(b),
            used => {
                let top = self.buf.last_mut().expect("non-empty by `bit_len > 0`");
                *top |= b << used;
                self.buf.push(b >> (8 - used));
            }
        }
        self.bit_len += 8;
    }

    /// Push every byte of `bytes`, in order.
    fn append(&mut self, bytes: &[u8]) {
        bytes.iter().for_each(|&b| self.append_byte(b));
    }

    /// Return the top `num_bits` bits (the most recently appended) as a
    /// `u32`. Bit `k` of the result is stack bit `bit_len - num_bits + k`.
    /// Requires `num_bits <= 32` and `num_bits <= bit_len`.
    fn leading(&self, num_bits: u8) -> u32 {
        let count = usize::from(num_bits);
        debug_assert!(count <= 32 && count <= self.bit_len);
        let base = self.bit_len - count;
        (0..count).fold(0u32, |acc, k| {
            let pos = base + k;
            let bit = (self.buf[pos / 8] >> (pos % 8)) & 1;
            acc | (u32::from(bit) << k)
        })
    }

    /// Discard the top `num_bits` bits. Whole bytes above the new length are
    /// released, and stale bits in the new partial top byte are cleared.
    fn truncate(&mut self, num_bits: u8) {
        let dropped = usize::from(num_bits);
        debug_assert!(dropped <= self.bit_len);
        self.bit_len -= dropped;
        self.buf.truncate(self.bit_len.div_ceil(8));
        let partial = self.bit_len % 8;
        if partial != 0 {
            let top = self.buf.last_mut().expect("non-empty by needed > 0");
            *top &= (1u8 << partial) - 1;
        }
    }

    /// Copy the bottom `num_bits` bits (the oldest, still-unconsumed
    /// remainder) into a new `BitStr`. The IGF uses this at refill time to
    /// keep the residual before stacking fresh hash output on top of it.
    fn trailing(&self, num_bits: u32) -> Self {
        let keep = num_bits as usize;
        debug_assert!(keep <= self.bit_len);
        let mut buf = self.buf[..keep.div_ceil(8)].to_vec();
        let partial = keep % 8;
        if partial != 0 {
            *buf.last_mut().expect("needed > 0") &= (1u8 << partial) - 1;
        }
        Self { buf, bit_len: keep }
    }
}
637
638struct IgfState<'a> {
639    z: Vec<u8>,
640    counter: u16,
641    buf: BitStr,
642    rem_bits: u32,
643    params: &'a EesParams,
644}
645
646impl<'a> IgfState<'a> {
647    fn new(seed: &[u8], params: &'a EesParams) -> Self {
648        // The IGF reads `c_bits` per index sample, accumulating into
649        // `BitStr::leading(num_bits: u8)` โ€” so `c_bits` must fit in a
650        // u8. Every IEEE 1363.1 EES set this crate ships uses
651        // `c_bits โˆˆ {9, 11, 12, 13}`; the guard catches a future
652        // parameter set that violates the assumption.
653        debug_assert!(
654            params.c_bits <= u8::MAX as usize,
655            "IGF c_bits must fit in a u8"
656        );
657        let hlen = params.hash.output_len();
658        let mut s = Self {
659            z: seed.to_vec(),
660            counter: 0,
661            buf: BitStr::new(),
662            rem_bits: (params.min_calls_r * 8 * hlen) as u32,
663            params,
664        };
665        while (s.counter as usize) < params.min_calls_r {
666            s.absorb_one();
667        }
668        s
669    }
670    fn absorb_one(&mut self) {
671        let hlen = self.params.hash.output_len();
672        let mut out = [0u8; 64];
673        self.params
674            .hash
675            .digest_two_into(&self.z, &self.counter.to_le_bytes(), &mut out[..hlen]);
676        self.buf.append(&out[..hlen]);
677        self.counter = self.counter.wrapping_add(1);
678    }
679    fn next_index(&mut self) -> u16 {
680        let n = self.params.n as u32;
681        let c = self.params.c_bits as u8;
682        let hlen = self.params.hash.output_len();
683        // Largest multiple of n that fits in c bits. `v < rnd_thresh` โ‡’ `v %
684        // n` is uniformly distributed; otherwise resample.
685        let rnd_thresh: u32 = (1u32 << c) - (1u32 << c) % n;
686        loop {
687            if self.rem_bits < c as u32 {
688                let mut tail = self.buf.trailing(self.rem_bits);
689                let need = (c as u32) - self.rem_bits;
690                let extra_calls = need.div_ceil((hlen as u32) * 8);
691                let mut out = [0u8; 64];
692                for _ in 0..extra_calls {
693                    self.params.hash.digest_two_into(
694                        &self.z,
695                        &self.counter.to_le_bytes(),
696                        &mut out[..hlen],
697                    );
698                    tail.append(&out[..hlen]);
699                    self.counter = self.counter.wrapping_add(1);
700                    self.rem_bits += 8 * hlen as u32;
701                }
702                self.buf = tail;
703            }
704            let v = self.buf.leading(c);
705            self.buf.truncate(c);
706            self.rem_bits -= c as u32;
707            if v < rnd_thresh {
708                return (v % n) as u16;
709            }
710        }
711    }
712}
713
714fn igf_gen_ternary(state: &mut IgfState<'_>, num_each: usize) -> TernaryPoly {
715    let n = state.params.n;
716    let mut occupied = vec![false; n];
717    let mut neg_ones = Vec::with_capacity(num_each);
718    let mut ones = Vec::with_capacity(num_each);
719    while neg_ones.len() < num_each {
720        let idx = state.next_index();
721        if !occupied[idx as usize] {
722            occupied[idx as usize] = true;
723            neg_ones.push(idx);
724        }
725    }
726    while ones.len() < num_each {
727        let idx = state.next_index();
728        if !occupied[idx as usize] {
729            occupied[idx as usize] = true;
730            ones.push(idx);
731        }
732    }
733    neg_ones.sort_unstable();
734    ones.sort_unstable();
735    TernaryPoly { ones, neg_ones }
736}
737
738fn igf_gen_blinding(state: &mut IgfState<'_>) -> Trapdoor {
739    Trapdoor::sample_via_igf(state)
740}
741
742// ---- MGF -------------------------------------------------------------------
743
/// IEEE 1363.1 §9.2.4 trit-decomposition table. Entry `b`, for every byte
/// value `b < 243 = 3^5`, holds the five base-3 digits of `b`, least
/// significant first. The digit 2 is written as -1, so every trit is in
/// `{0, 1, -1}`.
///
/// The table is evaluated at compile time; the MGF indexes it once per
/// accepted hash byte instead of doing the base-3 division at runtime.
const MGF_TRIT_TABLE: [[i8; 5]; 243] = {
    let digit_to_trit = [0i8, 1, -1];
    let mut table = [[0i8; 5]; 243];
    let mut entry = 0usize;
    while entry < 243 {
        let mut rest = entry;
        let mut pos = 0usize;
        while pos < 5 {
            table[entry][pos] = digit_to_trit[rest % 3];
            rest /= 3;
            pos += 1;
        }
        entry += 1;
    }
    table
};
763
764fn mgf<const N: usize>(seed: &[u8], params: &EesParams) -> Poly<N> {
765    let hlen = params.hash.output_len();
766    let q_mask = params.q_mask();
767    let mut z = [0u8; 64];
768    params.hash.digest_into(seed, &mut z[..hlen]);
769
770    let mut buf: Vec<u8> = Vec::with_capacity(params.min_calls_mask * hlen);
771    let mut counter: u16 = 0;
772    let mut h = [0u8; 64];
773    while (counter as usize) < params.min_calls_mask {
774        params
775            .hash
776            .digest_two_into(&z[..hlen], &counter.to_be_bytes(), &mut h[..hlen]);
777        for &b in &h[..hlen] {
778            if b < 243 {
779                buf.push(b);
780            }
781        }
782        counter = counter.wrapping_add(1);
783    }
784
785    let mut out = Poly::<N>::zero();
786    let mut cur = 0usize;
787    // The IEEE 1363.1 IGF + MGF rejection-samples bytes < 243 at a
788    // density of 243/256 โ‰ˆ 95%, so the expected number of hash
789    // calls is at most $\lceil N / (5 \cdot 0.95 \cdot \text{hlen})
790    // \rceil$ โ€” about 4 for $N = 1499, \text{hlen} = 32$. The bound
791    // below is a defensive ceiling against a pathological hash
792    // distribution; for the round-3-style SHA-1 / SHA-256 hashes
793    // shipped in this crate it is never reached.
794    let counter_ceiling = (params.min_calls_mask as u16).saturating_add(1024);
795    'outer: loop {
796        for &b in &buf {
797            for &t in &MGF_TRIT_TABLE[b as usize] {
798                out.coeffs[cur] = match t {
799                    -1 => q_mask,
800                    0 => 0,
801                    1 => 1,
802                    _ => unreachable!(),
803                };
804                cur += 1;
805                if cur >= N {
806                    break 'outer;
807                }
808            }
809        }
810        assert!(
811            counter < counter_ceiling,
812            "MGF rejection sampler exceeded counter ceiling โ€” hash output is pathologically biased"
813        );
814        params
815            .hash
816            .digest_two_into(&z[..hlen], &counter.to_be_bytes(), &mut h[..hlen]);
817        buf.clear();
818        for &b in &h[..hlen] {
819            if b < 243 {
820                buf.push(b);
821            }
822        }
823        counter = counter.wrapping_add(1);
824    }
825    out
826}
827
// ---- SVES encoding (IEEE 1363.1 §9.2.2 / §9.2.3) ---------------------------
829
/// First trit of the SVES 3-bit code `v`, where `v = 3*c1 + c2` and the
/// digit 2 is written as -1. The pair (-1, -1) would need `v = 8` and has no
/// code.
const SVES_C1: [i8; 8] = [0, 0, 0, 1, 1, 1, -1, -1];
/// Second trit of the SVES 3-bit code `v` (see `SVES_C1`).
const SVES_C2: [i8; 8] = [0, 1, -1, 0, 1, -1, 0, 1];
832
/// Map a trit in `{-1, 0, 1}` to its mod-`q` coefficient, with -1 becoming
/// `q_mask`. Any other input is a caller bug and panics via
/// `unreachable!`.
fn trit_to_u16(t: i8, q_mask: u16) -> u16 {
    if t == 0 {
        0
    } else if t == 1 {
        1
    } else if t == -1 {
        q_mask
    } else {
        unreachable!()
    }
}
841
842fn sves_from_bytes<const N: usize>(m: &[u8], q_mask: u16) -> Poly<N> {
843    let mut out = Poly::<N>::zero();
844    let mut coeff_idx: usize = 0;
845    let mut i = 0usize;
846    while i + 3 <= ((m.len() + 2) / 3) * 3 && coeff_idx < N - 1 {
847        let b0 = if i < m.len() { m[i] } else { 0 } as u32;
848        let b1 = if i + 1 < m.len() { m[i + 1] } else { 0 } as u32;
849        let b2 = if i + 2 < m.len() { m[i + 2] } else { 0 } as u32;
850        let mut chunk = (b2 << 16) | (b1 << 8) | b0;
851        i += 3;
852        for _ in 0..8 {
853            if coeff_idx >= N - 1 {
854                break;
855            }
856            let tbl = (chunk & 7) as usize;
857            out.coeffs[coeff_idx] = trit_to_u16(SVES_C1[tbl], q_mask);
858            out.coeffs[coeff_idx + 1] = trit_to_u16(SVES_C2[tbl], q_mask);
859            coeff_idx += 2;
860            chunk >>= 3;
861        }
862    }
863    out
864}
865
866fn sves_to_bytes<const N: usize>(p: &Poly<N>) -> Option<Vec<u8>> {
867    let num_bits = (N * 3 + 1) / 2;
868    let num_bytes = num_bits.div_ceil(8);
869    let mut out = vec![0u8; num_bytes + 3];
870    let end = N / 2 * 2;
871    let mut d_idx = 0usize;
872    let mut i = 0usize;
873    while i < end {
874        let mut acc: u32 = 0;
875        let mut bits_in_acc: u32 = 0;
876        for _ in 0..8 {
877            if i >= end {
878                break;
879            }
880            let c1 = p.coeffs[i] as i32;
881            let c2 = p.coeffs[i + 1] as i32;
882            i += 2;
883            if c1 == 2 && c2 == 2 {
884                return None;
885            }
886            let c = (c1 * 3 + c2) as u32;
887            acc |= c << bits_in_acc;
888            bits_in_acc += 3;
889            while bits_in_acc >= 8 && d_idx < out.len() {
890                out[d_idx] = (acc & 0xff) as u8;
891                d_idx += 1;
892                acc >>= 8;
893                bits_in_acc -= 8;
894            }
895        }
896        if bits_in_acc > 0 && d_idx < out.len() {
897            out[d_idx] |= acc as u8;
898        }
899    }
900    out.truncate(num_bytes);
901    Some(out)
902}
903
904// ---- byte encodings of polynomials -----------------------------------------
905
906fn poly_to_arr<const N: usize>(p: &Poly<N>, out: &mut [u8], params: &EesParams) {
907    let logq = params.logq;
908    let q_mask = params.q_mask();
909    debug_assert_eq!(out.len(), params.pk_wire_bytes());
910    for b in out.iter_mut() {
911        *b = 0;
912    }
913    let mut bit_pos = 0usize;
914    for i in 0..N {
915        let v = (p.coeffs[i] & q_mask) as u32;
916        for b in 0..logq {
917            let bit = ((v >> b) & 1) as u8;
918            out[bit_pos / 8] |= bit << (bit_pos % 8);
919            bit_pos += 1;
920        }
921    }
922}
923
924fn poly_from_arr<const N: usize>(input: &[u8], params: &EesParams) -> Poly<N> {
925    let logq = params.logq;
926    debug_assert!(input.len() >= params.pk_wire_bytes());
927    let mut p = Poly::<N>::zero();
928    let mut bit_pos = 0usize;
929    for i in 0..N {
930        let mut v: u32 = 0;
931        for b in 0..logq {
932            let bit = ((input[bit_pos / 8] >> (bit_pos % 8)) & 1) as u32;
933            v |= bit << b;
934            bit_pos += 1;
935        }
936        p.coeffs[i] = v as u16;
937    }
938    p
939}
940
941fn poly_to_arr4<const N: usize>(p: &Poly<N>, params: &EesParams) -> Vec<u8> {
942    let q = params.q();
943    let q_mask = params.q_mask();
944    let nbits = N * 2;
945    let mut out = vec![0u8; nbits.div_ceil(8)];
946    let mut bit_pos = 0usize;
947    for i in 0..N {
948        let centred = {
949            let m = p.coeffs[i] & q_mask;
950            let centred = if (m as u32) > q / 2 {
951                m as i32 - q as i32
952            } else {
953                m as i32
954            };
955            (centred & 3) as u8
956        };
957        for b in 0..2 {
958            let bit = (centred >> b) & 1;
959            out[bit_pos / 8] |= bit << (bit_pos % 8);
960            bit_pos += 1;
961        }
962    }
963    out
964}
965
966// ---- private-key wire encoding (dense vs product form) ---------------------
967
/// Append each entry of `indices` to `out` as an `index_bits`-wide LSB-first
/// field, starting at bit `*bit_offset` and advancing it.
///
/// Returns `None` as soon as an index does not fit in `index_bits` bits.
/// Fields written before the offending index stay in `out`, and
/// `*bit_offset` stays advanced past them.
fn pack_indices(
    indices: &[u16],
    out: &mut [u8],
    bit_offset: &mut usize,
    index_bits: usize,
) -> Option<()> {
    for &idx in indices {
        if usize::from(idx) >= (1usize << index_bits) {
            return None;
        }
        for bit in 0..index_bits {
            let pos = *bit_offset;
            out[pos / 8] |= (((idx >> bit) & 1) as u8) << (pos % 8);
            *bit_offset = pos + 1;
        }
    }
    Some(())
}
986
/// Read `n` LSB-first fields of `index_bits` bits each from `bytes`, starting
/// at bit `*bit_offset` and advancing it.
///
/// Returns `None` at the first decoded value `>= n_max`. `*bit_offset` is left
/// just past that rejected field, and later fields are not read.
fn unpack_indices(
    bytes: &[u8],
    n: usize,
    bit_offset: &mut usize,
    index_bits: usize,
    n_max: usize,
) -> Option<Vec<u16>> {
    (0..n)
        .map(|_| {
            let mut value: u32 = 0;
            for bit in 0..index_bits {
                let pos = *bit_offset;
                value |= u32::from((bytes[pos / 8] >> (pos % 8)) & 1) << bit;
                *bit_offset = pos + 1;
            }
            ((value as usize) < n_max).then(|| value as u16)
        })
        .collect()
}
1009
/// Returns `true` iff every bit of `bytes` after the first `used_bits` bits is
/// zero. Bits are numbered LSB-first: bit `k` is bit `k % 8` of byte `k / 8`.
///
/// Used both for trapdoor and pk/ct wire decoding, so the malleability check
/// is implemented once. The caller passes the encoded slice and the number of
/// meaningful bits, and any set padding bit marks the encoding as
/// non-canonical.
///
/// The padding may span more than the final byte: the partially used byte is
/// checked above `used_bits`, and every byte after it must be zero.
/// `used_bits` must not exceed `bytes.len() * 8` (asserted in debug builds). In
/// release builds an oversized count leaves no padding to inspect and yields
/// `true`.
#[doc(hidden)]
pub fn padding_bits_clear(bytes: &[u8], used_bits: usize) -> bool {
    debug_assert!(used_bits <= bytes.len() * 8);
    let full_bytes = used_bits / 8;
    let used_in_partial = used_bits % 8;
    // Everything from the first byte that is not entirely payload onwards.
    let mut tail = bytes.get(full_bytes..).unwrap_or(&[]);
    if used_in_partial != 0 {
        if let Some((&partial, rest)) = tail.split_first() {
            // Only the bits above the payload bits of this byte are padding.
            if partial >> used_in_partial != 0 {
                return false;
            }
            tail = rest;
        }
    }
    tail.iter().all(|&b| b == 0)
}
1024
1025/// Pack the trapdoor portion of a private key into the `params`-defined
1026/// trapdoor wire bytes. Thin alias for [`Trapdoor::to_wire`].
1027pub fn trapdoor_to_wire(t: &Trapdoor, params: &EesParams, out: &mut [u8]) {
1028    t.to_wire(params, out);
1029}
1030
1031/// Inverse of [`trapdoor_to_wire`]. Thin alias for
1032/// [`Trapdoor::from_wire`].
1033pub fn trapdoor_from_wire(bytes: &[u8], params: &EesParams) -> Option<Trapdoor> {
1034    Trapdoor::from_wire(bytes, params)
1035}
1036
1037// ---- helpers ---------------------------------------------------------------
1038
1039fn next_index_below<R: Csprng>(rng: &mut R, modulus: u32) -> u32 {
1040    let threshold = u32::MAX - (u32::MAX % modulus);
1041    loop {
1042        let mut buf = [0u8; 4];
1043        rng.fill_bytes(&mut buf);
1044        let v = u32::from_le_bytes(buf);
1045        if v < threshold {
1046            return v % modulus;
1047        }
1048    }
1049}
1050
1051fn sample_trinary<R: Csprng>(
1052    rng: &mut R,
1053    n: usize,
1054    num_ones: usize,
1055    num_neg_ones: usize,
1056) -> TernaryPoly {
1057    debug_assert!(num_ones + num_neg_ones <= n);
1058    let mut idx: Vec<u16> = (0..n as u16).collect();
1059    let take = num_ones + num_neg_ones;
1060    for i in 0..take {
1061        let j = i + next_index_below(rng, (n - i) as u32) as usize;
1062        idx.swap(i, j);
1063    }
1064    let mut ones = idx[..num_ones].to_vec();
1065    let mut neg_ones = idx[num_ones..take].to_vec();
1066    ones.sort_unstable();
1067    neg_ones.sort_unstable();
1068    TernaryPoly { ones, neg_ones }
1069}
1070
1071fn sample_trapdoor<R: Csprng>(rng: &mut R, params: &EesParams) -> Trapdoor {
1072    Trapdoor::sample_iid(rng, params)
1073}
1074
1075fn check_rep_weight<const N: usize>(p: &Poly<N>, params: &EesParams) -> bool {
1076    let mut w = [0usize; 3];
1077    for i in 0..N {
1078        let v = p.coeffs[i] as usize;
1079        if v < 3 {
1080            w[v] += 1;
1081        }
1082    }
1083    w[0] >= params.dm0 && w[1] >= params.dm0 && w[2] >= params.dm0
1084}
1085
1086// ---- top-level keygen / encrypt / decrypt ---------------------------------
1087
/// Failure modes of the EES encryption and decryption routines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NtruEesError {
    /// The plaintext exceeds the parameter set's maximum message length.
    MessageTooLong,
    /// The ciphertext failed decoding or one of decryption's validity checks.
    InvalidCiphertext,
}
1093
1094/// Generate a fresh key pair. The returned `pk_bytes` has length
1095/// `params.pk_wire_bytes()`; the trapdoor is returned as a [`Trapdoor`] so
1096/// the caller can pack it with [`trapdoor_to_wire`].
1097pub fn keygen<const N: usize, R: Csprng>(
1098    params: &EesParams,
1099    rng: &mut R,
1100) -> (Vec<u8>, Trapdoor) {
1101    debug_assert_eq!(params.n, N);
1102    let q_mask = params.q_mask();
1103    loop {
1104        let t = sample_trapdoor(rng, params);
1105        // f = 1 + 3 ยท t expanded into a dense polynomial.
1106        let mut f = t.to_dense::<N>(q_mask);
1107        poly_scalar_mul::<N>(&mut f, 3, q_mask);
1108        f.coeffs[0] = f.coeffs[0].wrapping_add(1) & q_mask;
1109        let f_inv = match poly_inverse_mod_q_cyclic::<N>(&f, params) {
1110            Some(inv) => inv,
1111            None => continue,
1112        };
1113
1114        let g = sample_trinary(rng, params.n, params.dg, params.dg);
1115        let mut g_dense = g.to_dense::<N>(q_mask);
1116        poly_mod_q::<N>(&mut g_dense, q_mask);
1117        let mut h = Poly::<N>::zero();
1118        poly_mul::<N>(&mut h, &g_dense, &f_inv);
1119        poly_scalar_mul::<N>(&mut h, 3, q_mask);
1120
1121        let mut pk_bytes = vec![0u8; params.pk_wire_bytes()];
1122        poly_to_arr::<N>(&h, &mut pk_bytes, params);
1123        return (pk_bytes, t);
1124    }
1125}
1126
1127pub fn encrypt<const N: usize, R: Csprng>(
1128    pk_bytes: &[u8],
1129    msg: &[u8],
1130    rng: &mut R,
1131    params: &EesParams,
1132) -> Result<Vec<u8>, NtruEesError> {
1133    debug_assert_eq!(params.n, N);
1134    if msg.len() > params.max_message_bytes() {
1135        return Err(NtruEesError::MessageTooLong);
1136    }
1137    let q_mask = params.q_mask();
1138    let mut h = poly_from_arr::<N>(pk_bytes, params);
1139    poly_mod_q::<N>(&mut h, q_mask);
1140
1141    let pklen_bytes = params.pklen_bytes();
1142    let htrunc = &pk_bytes[..pklen_bytes];
1143    let db_bytes = params.db_bytes();
1144    let max_msg = params.max_message_bytes();
1145
1146    loop {
1147        let mut b = vec![0u8; db_bytes];
1148        rng.fill_bytes(&mut b);
1149
1150        let m_len = db_bytes + 1 + max_msg + 1;
1151        let mut m = vec![0u8; m_len];
1152        m[..db_bytes].copy_from_slice(&b);
1153        m[db_bytes] = msg.len() as u8;
1154        m[db_bytes + 1..db_bytes + 1 + msg.len()].copy_from_slice(msg);
1155
1156        let mtrin = sves_from_bytes::<N>(&m, q_mask);
1157
1158        let mut sdata =
1159            Vec::with_capacity(params.oid.len() + msg.len() + b.len() + htrunc.len());
1160        sdata.extend_from_slice(&params.oid);
1161        sdata.extend_from_slice(msg);
1162        sdata.extend_from_slice(&b);
1163        sdata.extend_from_slice(htrunc);
1164
1165        let mut igf = IgfState::new(&sdata, params);
1166        let r = igf_gen_blinding(&mut igf);
1167
1168        let mut bigr = Poly::<N>::zero();
1169        r.mul_dense::<N>(&h, &mut bigr);
1170        poly_mod_q::<N>(&mut bigr, q_mask);
1171
1172        let or4 = poly_to_arr4::<N>(&bigr, params);
1173        let mask = mgf::<N>(&or4, params);
1174
1175        let mut mtrin_plus_mask = mtrin;
1176        poly_add::<N>(&mut mtrin_plus_mask, &mask);
1177        poly_mod3::<N>(&mut mtrin_plus_mask, params);
1178
1179        if !check_rep_weight::<N>(&mtrin_plus_mask, params) {
1180            continue;
1181        }
1182
1183        let mut e = bigr;
1184        for i in 0..N {
1185            let v = mtrin_plus_mask.coeffs[i];
1186            let signed: u16 = match v {
1187                0 => 0,
1188                1 => 1,
1189                2 => q_mask,
1190                _ => unreachable!(),
1191            };
1192            e.coeffs[i] = e.coeffs[i].wrapping_add(signed);
1193        }
1194        poly_mod_q::<N>(&mut e, q_mask);
1195
1196        let mut out = vec![0u8; params.ciphertext_wire_bytes()];
1197        poly_to_arr::<N>(&e, &mut out, params);
1198        return Ok(out);
1199    }
1200}
1201
1202pub fn decrypt<const N: usize>(
1203    sk_trapdoor: &Trapdoor,
1204    pk_bytes: &[u8],
1205    ct_bytes: &[u8],
1206    params: &EesParams,
1207) -> Result<Vec<u8>, NtruEesError> {
1208    debug_assert_eq!(params.n, N);
1209    let q_mask = params.q_mask();
1210    let e = poly_from_arr::<N>(ct_bytes, params);
1211
1212    let mut te = Poly::<N>::zero();
1213    sk_trapdoor.mul_dense::<N>(&e, &mut te);
1214    let mut ci = te;
1215    poly_scalar_mul::<N>(&mut ci, 3, q_mask);
1216    poly_add::<N>(&mut ci, &e);
1217    poly_mod_q::<N>(&mut ci, q_mask);
1218    poly_mod3::<N>(&mut ci, params);
1219
1220    let mut retcode_ok = check_rep_weight::<N>(&ci, params);
1221
1222    let mut c_r = e;
1223    let mut ci_modq = Poly::<N>::zero();
1224    for i in 0..N {
1225        ci_modq.coeffs[i] = match ci.coeffs[i] {
1226            0 => 0,
1227            1 => 1,
1228            2 => q_mask,
1229            _ => unreachable!(),
1230        };
1231    }
1232    poly_sub::<N>(&mut c_r, &ci_modq);
1233    poly_mod_q::<N>(&mut c_r, q_mask);
1234
1235    let or4 = poly_to_arr4::<N>(&c_r, params);
1236    let mask = mgf::<N>(&or4, params);
1237
1238    let mut cmtrin = ci;
1239    poly_sub::<N>(&mut cmtrin, &mask);
1240    poly_mod3::<N>(&mut cmtrin, params);
1241
1242    let cm = sves_to_bytes::<N>(&cmtrin).ok_or(NtruEesError::InvalidCiphertext)?;
1243
1244    let db_bytes = params.db_bytes();
1245    let max_msg = params.max_message_bytes();
1246    let cb = &cm[..db_bytes];
1247    let cl = cm[db_bytes] as usize;
1248    if cl > max_msg {
1249        return Err(NtruEesError::InvalidCiphertext);
1250    }
1251    let msg = cm[db_bytes + 1..db_bytes + 1 + cl].to_vec();
1252
1253    let pad_start = db_bytes + 1 + cl;
1254    let pad_end = (params.n * 3 + 1) / 2;
1255    let pad_end_bytes = pad_end.div_ceil(8);
1256    for &p in &cm[pad_start..pad_end_bytes.min(cm.len())] {
1257        if p != 0 {
1258            retcode_ok = false;
1259        }
1260    }
1261
1262    let pklen_bytes = params.pklen_bytes();
1263    let htrunc = &pk_bytes[..pklen_bytes];
1264    let mut sdata = Vec::with_capacity(params.oid.len() + cl + db_bytes + db_bytes);
1265    sdata.extend_from_slice(&params.oid);
1266    sdata.extend_from_slice(&msg);
1267    sdata.extend_from_slice(cb);
1268    sdata.extend_from_slice(htrunc);
1269    let mut igf = IgfState::new(&sdata, params);
1270    let cr_priv = igf_gen_blinding(&mut igf);
1271
1272    let h = poly_from_arr::<N>(pk_bytes, params);
1273    let mut bigr_prime = Poly::<N>::zero();
1274    cr_priv.mul_dense::<N>(&h, &mut bigr_prime);
1275    poly_mod_q::<N>(&mut bigr_prime, q_mask);
1276
1277    for i in 0..N {
1278        if bigr_prime.coeffs[i] != c_r.coeffs[i] {
1279            retcode_ok = false;
1280            break;
1281        }
1282    }
1283
1284    if !retcode_ok {
1285        return Err(NtruEesError::InvalidCiphertext);
1286    }
1287    Ok(msg)
1288}
1289
1290// ---- per-set wrapper macro --------------------------------------------------
1291//
1292// Each IEEE 1363.1 parameter set turns into a thin wrapper module: typed
1293// `*PublicKey`, `*PrivateKey`, `*Ciphertext` newtypes plus a `keygen` /
1294// `encrypt` / `decrypt` namespace, all delegating to the generic routines
1295// above. The macro takes the four wrapper type idents explicitly (avoiding a
1296// paste! dependency) plus the parameter values; each per-set source file is
1297// then a single macro invocation.
1298
/// Instantiate one IEEE 1363.1 EES parameter set at module scope.
///
/// The expansion contains:
/// - the set's `EesParams` constant `PARAMS`, with `logq` fixed at 11;
/// - the size constants `PUBLIC_KEY_BYTES`, `PRIVATE_KEY_BYTES`,
///   `CIPHERTEXT_BYTES` and `MAX_MESSAGE_BYTES`;
/// - the wire-format newtypes `$pk_ty`, `$sk_ty` and `$ct_ty`;
/// - the `$type_name` namespace with `keygen`, `encrypt` and `decrypt`;
/// - a `#[cfg(test)]` suite. It checks the byte lengths against `pk_bytes`,
///   `sk_packed_bytes` and `ct_bytes`, and pins the wire format to
///   `regression_digest`.
///
/// The expansion defines `use` items, unqualified constants and a `tests`
/// module, so invoke it at most once per module.
macro_rules! define_ees_set {
    (
        namespace = $type_name:ident,
        public_key = $pk_ty:ident,
        private_key = $sk_ty:ident,
        ciphertext = $ct_ty:ident,
        n = $n:expr,
        trapdoor = $trapdoor:expr,
        dg = $dg:expr,
        dm0 = $dm0:expr,
        db_bits = $db_bits:expr,
        c_bits = $c_bits:expr,
        min_calls_r = $min_calls_r:expr,
        min_calls_mask = $min_calls_mask:expr,
        pklen_bits = $pklen_bits:expr,
        oid = $oid:expr,
        hash = $hash:expr,
        pk_bytes = $pk_bytes:expr,
        sk_packed_bytes = $sk_packed_bytes:expr,
        ct_bytes = $ct_bytes:expr,
        regression_digest = $regression_digest:expr $(,)?
    ) => {
        use $crate::public_key::ntru_ees_core::{
            decrypt as __ees_core_decrypt, encrypt as __ees_core_encrypt,
            keygen as __ees_core_keygen, padding_bits_clear as __ees_padding_bits_clear,
            trapdoor_from_wire as __ees_trapdoor_from_wire,
            trapdoor_to_wire as __ees_trapdoor_to_wire, EesParams, HashKind, NtruEesError,
            Trapdoor, TrapdoorKind,
        };
        use $crate::Csprng;

        // All sets built by this macro share logq = 11.
        const PARAMS: EesParams = EesParams {
            n: $n,
            logq: 11,
            trapdoor: $trapdoor,
            dg: $dg,
            dm0: $dm0,
            db_bits: $db_bits,
            c_bits: $c_bits,
            min_calls_r: $min_calls_r,
            min_calls_mask: $min_calls_mask,
            pklen_bits: $pklen_bits,
            oid: $oid,
            hash: $hash,
        };

        const N: usize = $n;

        /// Wire-format public-key length in bytes for this set.
        pub const PUBLIC_KEY_BYTES: usize = PARAMS.pk_wire_bytes();
        /// Length in bytes of the packed trapdoor alone. A serialized private
        /// key is `PRIVATE_KEY_BYTES + PUBLIC_KEY_BYTES` bytes, because it
        /// appends the public key.
        pub const PRIVATE_KEY_BYTES: usize = PARAMS.trapdoor_wire_bytes();
        /// Wire-format ciphertext length in bytes for this set.
        pub const CIPHERTEXT_BYTES: usize = PARAMS.ciphertext_wire_bytes();
        /// Longest message, in bytes, accepted by encryption for this set.
        pub const MAX_MESSAGE_BYTES: usize = PARAMS.max_message_bytes();

        /// Public key, held in its `PUBLIC_KEY_BYTES`-byte wire encoding.
        #[derive(Clone, Eq, PartialEq)]
        pub struct $pk_ty {
            bytes: Vec<u8>,
        }

        /// Private key: the trapdoor plus a copy of the matching public key,
        /// which decryption needs for its re-encryption check. `Debug` output
        /// is redacted.
        #[derive(Clone, Eq, PartialEq)]
        pub struct $sk_ty {
            t: Trapdoor,
            pk: $pk_ty,
        }

        /// Ciphertext, held in its `CIPHERTEXT_BYTES`-byte wire encoding.
        #[derive(Clone, Eq, PartialEq)]
        pub struct $ct_ty {
            bytes: Vec<u8>,
        }

        impl $pk_ty {
            /// Parse a wire-encoded public key. Returns `None` unless `bytes`
            /// is exactly `PUBLIC_KEY_BYTES` long and every bit past the
            /// `N * logq` coefficient bits is zero (canonical encoding).
            #[must_use]
            pub fn from_wire_bytes(bytes: &[u8]) -> Option<Self> {
                if bytes.len() != PUBLIC_KEY_BYTES { return None; }
                if !__ees_padding_bits_clear(bytes, N * PARAMS.logq) {
                    return None;
                }
                Some(Self { bytes: bytes.to_vec() })
            }

            /// Owned copy of the wire encoding.
            #[must_use]
            pub fn to_wire_bytes(&self) -> Vec<u8> { self.bytes.clone() }

            /// Borrow the wire encoding.
            #[must_use]
            pub fn as_bytes(&self) -> &[u8] { &self.bytes }
        }

        impl $sk_ty {
            /// Serialize as the packed trapdoor (`PRIVATE_KEY_BYTES`) followed
            /// by the public key (`PUBLIC_KEY_BYTES`).
            #[must_use]
            pub fn to_wire_bytes(&self) -> Vec<u8> {
                let mut out = vec![0u8; PRIVATE_KEY_BYTES + PUBLIC_KEY_BYTES];
                __ees_trapdoor_to_wire(&self.t, &PARAMS, &mut out[..PRIVATE_KEY_BYTES]);
                out[PRIVATE_KEY_BYTES..].copy_from_slice(&self.pk.bytes);
                out
            }

            /// Parse the layout written by `to_wire_bytes`. Returns `None` if
            /// the length is wrong, the trapdoor fails to decode, or the
            /// embedded public key is invalid. Does not verify that the
            /// trapdoor and public key belong together.
            #[must_use]
            pub fn from_wire_bytes(bytes: &[u8]) -> Option<Self> {
                if bytes.len() != PRIVATE_KEY_BYTES + PUBLIC_KEY_BYTES { return None; }
                let t = __ees_trapdoor_from_wire(&bytes[..PRIVATE_KEY_BYTES], &PARAMS)?;
                let pk = $pk_ty::from_wire_bytes(&bytes[PRIVATE_KEY_BYTES..])?;
                Some(Self { t, pk })
            }

            /// The public key stored alongside the trapdoor.
            #[must_use]
            pub fn public_key(&self) -> &$pk_ty { &self.pk }
        }

        impl $ct_ty {
            /// Parse a wire-encoded ciphertext. Returns `None` unless `bytes`
            /// is exactly `CIPHERTEXT_BYTES` long and every bit past the
            /// `N * logq` coefficient bits is zero (canonical encoding).
            #[must_use]
            pub fn from_wire_bytes(bytes: &[u8]) -> Option<Self> {
                if bytes.len() != CIPHERTEXT_BYTES { return None; }
                if !__ees_padding_bits_clear(bytes, N * PARAMS.logq) {
                    return None;
                }
                Some(Self { bytes: bytes.to_vec() })
            }

            /// Owned copy of the wire encoding.
            #[must_use]
            pub fn to_wire_bytes(&self) -> Vec<u8> { self.bytes.clone() }

            /// Borrow the wire encoding.
            #[must_use]
            pub fn as_bytes(&self) -> &[u8] { &self.bytes }
        }

        impl ::core::fmt::Debug for $sk_ty {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.write_str(concat!(stringify!($sk_ty), "(<redacted>)"))
            }
        }

        impl ::core::fmt::Debug for $pk_ty {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.debug_struct(stringify!($pk_ty)).finish()
            }
        }

        impl ::core::fmt::Debug for $ct_ty {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.debug_struct(stringify!($ct_ty)).finish()
            }
        }

        /// Namespace for this parameter set's key generation, encryption and
        /// decryption.
        pub struct $type_name;

        impl $type_name {
            /// Wire-format public-key length in bytes for this set.
            pub const PUBLIC_KEY_BYTES: usize = PUBLIC_KEY_BYTES;
            /// Length in bytes of the packed trapdoor alone. The serialized
            /// private key is `PRIVATE_KEY_BYTES + PUBLIC_KEY_BYTES` bytes long
            /// because it appends the public key.
            pub const PRIVATE_KEY_BYTES: usize = PRIVATE_KEY_BYTES;
            /// Wire-format ciphertext length in bytes for this set.
            pub const CIPHERTEXT_BYTES: usize = CIPHERTEXT_BYTES;
            /// Maximum byte length of a message that
            /// [`Self::encrypt`] will accept; longer inputs return
            /// [`NtruEesError::MessageTooLong`].
            pub const MAX_MESSAGE_BYTES: usize = MAX_MESSAGE_BYTES;

            /// Generate a fresh key pair. The private key keeps its own copy
            /// of the public key.
            pub fn keygen<R: Csprng>(rng: &mut R) -> ($pk_ty, $sk_ty) {
                let (pk_bytes, t) = __ees_core_keygen::<N, R>(&PARAMS, rng);
                let pk = $pk_ty { bytes: pk_bytes };
                let sk = $sk_ty { t, pk: pk.clone() };
                (pk, sk)
            }

            /// Encrypt `msg` under `pk`.
            ///
            /// # Errors
            /// [`NtruEesError::MessageTooLong`] if `msg` is longer than
            /// `MAX_MESSAGE_BYTES`.
            pub fn encrypt<R: Csprng>(
                pk: &$pk_ty,
                msg: &[u8],
                rng: &mut R,
            ) -> Result<$ct_ty, NtruEesError> {
                let bytes = __ees_core_encrypt::<N, R>(&pk.bytes, msg, rng, &PARAMS)?;
                Ok($ct_ty { bytes })
            }

            /// Decrypt `ct` with `sk`.
            ///
            /// # Errors
            /// [`NtruEesError::InvalidCiphertext`] if the ciphertext fails
            /// decoding or any validity check.
            pub fn decrypt(sk: &$sk_ty, ct: &$ct_ty) -> Result<Vec<u8>, NtruEesError> {
                __ees_core_decrypt::<N>(&sk.t, &sk.pk.bytes, &ct.bytes, &PARAMS)
            }
        }

        #[cfg(test)]
        mod tests {
            use super::*;
            use $crate::CtrDrbgAes256;

            #[test]
            fn parameter_byte_lengths() {
                assert_eq!(PUBLIC_KEY_BYTES, $pk_bytes);
                assert_eq!(PRIVATE_KEY_BYTES, $sk_packed_bytes);
                assert_eq!(CIPHERTEXT_BYTES, $ct_bytes);
                assert!(MAX_MESSAGE_BYTES > 0);
            }

            #[test]
            fn round_trip_empty_and_full_messages() {
                let mut drbg = CtrDrbgAes256::new(&[0x42u8; 48]);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                for &len in &[0usize, 1, 16, 32, MAX_MESSAGE_BYTES] {
                    let mut msg = vec![0u8; len];
                    drbg.fill_bytes(&mut msg);
                    let ct = $type_name::encrypt(&pk, &msg, &mut drbg).expect("encrypt");
                    let dec = $type_name::decrypt(&sk, &ct).expect("decrypt");
                    assert_eq!(dec, msg, "round-trip at len={}", len);
                }
            }

            #[test]
            fn rejects_oversize_message() {
                let mut drbg = CtrDrbgAes256::new(&[0x77u8; 48]);
                let (pk, _) = $type_name::keygen(&mut drbg);
                let too_big = vec![0u8; MAX_MESSAGE_BYTES + 1];
                let err = $type_name::encrypt(&pk, &too_big, &mut drbg).unwrap_err();
                assert_eq!(err, NtruEesError::MessageTooLong);
            }

            #[test]
            fn corrupted_ciphertext_rejected() {
                let mut drbg = CtrDrbgAes256::new(&[0x99u8; 48]);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                let msg = b"hello ntru";
                let ct = $type_name::encrypt(&pk, msg, &mut drbg).expect("encrypt");
                let mut bad_bytes = ct.to_wire_bytes();
                bad_bytes[10] ^= 0xff;
                let bad_ct = $ct_ty::from_wire_bytes(&bad_bytes).expect("structural decode");
                match $type_name::decrypt(&sk, &bad_ct) {
                    Err(NtruEesError::InvalidCiphertext) => {}
                    other => panic!("expected InvalidCiphertext, got {:?}", other),
                }
            }

            /// Regression vector: locks in the byte-level encoding of pk,
            /// sk, and ct under a fixed DRBG seed and message. Computed
            /// once via `cargo run --bin ees_regression_gen`; a future
            /// refactor that silently changes wire-format byte order or
            /// padding will fail this digest check.
            #[test]
            fn byte_format_regression_digest() {
                use $crate::hash::sha2::Sha256;
                let mut drbg = CtrDrbgAes256::new(&[0xC0u8; 48]);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                let ct = $type_name::encrypt(&pk, &[0xA5u8; 8], &mut drbg)
                    .expect("encrypt");
                let mut h = Sha256::new();
                h.update(&pk.to_wire_bytes());
                h.update(&sk.to_wire_bytes());
                h.update(&ct.to_wire_bytes());
                let digest = h.finalize();
                let mut hex = String::with_capacity(64);
                for b in digest.iter() {
                    use ::core::fmt::Write;
                    write!(&mut hex, "{:02x}", b).unwrap();
                }
                assert_eq!(hex, $regression_digest, "byte-format regression");
            }

            #[test]
            fn wire_format_roundtrip_keys_and_ct() {
                let mut drbg = CtrDrbgAes256::new(&[0xa0u8; 48]);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                let msg = b"wire-format-roundtrip";
                let ct = $type_name::encrypt(&pk, msg, &mut drbg).expect("encrypt");

                let pk_round = $pk_ty::from_wire_bytes(&pk.to_wire_bytes()).expect("pk decode");
                let sk_round = $sk_ty::from_wire_bytes(&sk.to_wire_bytes()).expect("sk decode");
                let ct_round = $ct_ty::from_wire_bytes(&ct.to_wire_bytes()).expect("ct decode");

                assert_eq!(pk_round, pk);
                assert_eq!(sk_round, sk);
                assert_eq!(ct_round, ct);

                let dec = $type_name::decrypt(&sk_round, &ct_round).expect("decrypt");
                assert_eq!(dec, msg);
            }
        }
    };
}
1572
1573pub(crate) use define_ees_set;