1use crate::hash::sha1::Sha1;
37use crate::hash::sha2::Sha256;
38use crate::Csprng;
39
40#[derive(Clone, Copy, Debug, Eq, PartialEq)]
44pub enum HashKind {
45 Sha1,
46 Sha256,
47}
48
49impl HashKind {
50 pub const fn output_len(self) -> usize {
52 match self {
53 HashKind::Sha1 => 20,
54 HashKind::Sha256 => 32,
55 }
56 }
57
58 fn digest_into(self, input: &[u8], out: &mut [u8]) {
59 match self {
60 HashKind::Sha1 => out.copy_from_slice(Sha1::digest(input).as_slice()),
61 HashKind::Sha256 => out.copy_from_slice(Sha256::digest(input).as_slice()),
62 }
63 }
64
65 fn digest_two_into(self, prefix: &[u8], suffix: &[u8], out: &mut [u8]) {
70 match self {
71 HashKind::Sha1 => {
72 let mut h = Sha1::new();
73 h.update(prefix);
74 h.update(suffix);
75 out.copy_from_slice(h.finalize().as_slice());
76 }
77 HashKind::Sha256 => {
78 let mut h = Sha256::new();
79 h.update(prefix);
80 h.update(suffix);
81 out.copy_from_slice(h.finalize().as_slice());
82 }
83 }
84 }
85}
86
/// Shape of the private trapdoor polynomial `T` (key generation uses
/// `f = 1 + 3T`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrapdoorKind {
    /// A single ternary polynomial.
    Dense {
        /// Number of coefficients equal to +1, and also the number equal to -1.
        df: usize,
    },
    /// Product form `T = f1 * f2 + f3` of three sparse ternary polynomials.
    ProductForm {
        /// Number of +1 (and of -1) coefficients in `f1`.
        df1: usize,
        /// Number of +1 (and of -1) coefficients in `f2`.
        df2: usize,
        /// Number of +1 (and of -1) coefficients in `f3`.
        df3: usize,
    },
}
102
103#[derive(Clone, Copy, Debug, Eq, PartialEq)]
105pub struct EesParams {
106 pub n: usize,
107 pub logq: usize,
109 pub trapdoor: TrapdoorKind,
110 pub dg: usize,
111 pub dm0: usize,
112 pub db_bits: usize,
114 pub c_bits: usize,
116 pub min_calls_r: usize,
117 pub min_calls_mask: usize,
118 pub pklen_bits: usize,
119 pub oid: [u8; 3],
120 pub hash: HashKind,
121}
122
123impl EesParams {
124 pub const fn db_bytes(&self) -> usize {
125 self.db_bits / 8
126 }
127 pub const fn pklen_bytes(&self) -> usize {
128 self.pklen_bits.div_ceil(8)
129 }
130 pub const fn q(&self) -> u32 {
131 1u32 << self.logq
132 }
133 pub const fn q_mask(&self) -> u16 {
134 ((1u32 << self.logq) - 1) as u16
135 }
136 pub const fn pk_wire_bytes(&self) -> usize {
138 (self.n * self.logq).div_ceil(8)
139 }
140 pub const fn trapdoor_wire_bytes(&self) -> usize {
143 match self.trapdoor {
144 TrapdoorKind::Dense { .. } => (self.n * 2).div_ceil(8),
145 TrapdoorKind::ProductForm { df1, df2, df3 } => {
146 let indices = 2 * (df1 + df2 + df3);
147 let bits = indices * Self::index_bits(self.n);
148 bits.div_ceil(8)
149 }
150 }
151 }
152 const fn index_bits(n: usize) -> usize {
155 let mut bits = 0usize;
157 let mut v = n.saturating_sub(1);
158 while v > 0 {
159 bits += 1;
160 v >>= 1;
161 }
162 bits
163 }
164 pub const fn ciphertext_wire_bytes(&self) -> usize {
165 self.pk_wire_bytes()
166 }
167 pub const fn max_message_bytes(&self) -> usize {
168 self.n / 2 * 3 / 8 - 1 - self.db_bytes()
169 }
170}
171
/// Dense polynomial with `N` coefficients stored as `u16`.
///
/// Coefficients are not kept reduced automatically; callers reduce modulo
/// `q` with an explicit mask.
#[derive(Clone, Copy)]
pub struct Poly<const N: usize> {
    pub coeffs: [u16; N],
}

impl<const N: usize> Poly<N> {
    /// The all-zero polynomial.
    pub fn zero() -> Self {
        let coeffs = [0; N];
        Poly { coeffs }
    }
}
185
/// Reduces `x` modulo a power-of-two `q`, given `q_mask = q - 1`.
#[inline(always)]
fn modq(x: u16, q_mask: u16) -> u16 {
    q_mask & x
}
190
191pub fn poly_mul<const N: usize>(r: &mut Poly<N>, a: &Poly<N>, b: &Poly<N>) {
192 crate::public_key::ntru_poly_mul::poly_mul_cyclic(&mut r.coeffs, &a.coeffs, &b.coeffs);
193}
194
195pub fn poly_add<const N: usize>(a: &mut Poly<N>, b: &Poly<N>) {
196 for i in 0..N {
197 a.coeffs[i] = a.coeffs[i].wrapping_add(b.coeffs[i]);
198 }
199}
200
201pub fn poly_sub<const N: usize>(a: &mut Poly<N>, b: &Poly<N>) {
202 for i in 0..N {
203 a.coeffs[i] = a.coeffs[i].wrapping_sub(b.coeffs[i]);
204 }
205}
206
207pub fn poly_mod3<const N: usize>(a: &mut Poly<N>, params: &EesParams) {
210 let q = params.q();
211 let q_mask = params.q_mask();
212 for c in a.coeffs.iter_mut() {
213 let m = modq(*c, q_mask);
214 let centred = if (m as u32) > q / 2 {
215 m as i32 - q as i32
216 } else {
217 m as i32
218 };
219 let r = centred.rem_euclid(3);
220 *c = r as u16;
221 }
222}
223
224pub fn poly_scalar_mul<const N: usize>(a: &mut Poly<N>, k: u16, q_mask: u16) {
225 for c in a.coeffs.iter_mut() {
226 *c = c.wrapping_mul(k) & q_mask;
227 }
228}
229
230pub fn poly_mod_q<const N: usize>(a: &mut Poly<N>, q_mask: u16) {
231 for c in a.coeffs.iter_mut() {
232 *c = modq(*c, q_mask);
233 }
234}
235
236#[derive(Clone, Eq, PartialEq)]
239pub struct TernaryPoly {
240 pub ones: Vec<u16>,
241 pub neg_ones: Vec<u16>,
242}
243
244impl TernaryPoly {
245 pub fn to_dense<const N: usize>(&self, q_mask: u16) -> Poly<N> {
246 let mut p = Poly::<N>::zero();
247 for &i in &self.ones {
248 p.coeffs[i as usize] = 1;
249 }
250 for &i in &self.neg_ones {
251 p.coeffs[i as usize] = q_mask;
252 }
253 p
254 }
255
256 pub fn mul_dense<const N: usize>(&self, b: &Poly<N>, out: &mut Poly<N>) {
257 for c in out.coeffs.iter_mut() {
258 *c = 0;
259 }
260 for &idx in &self.ones {
261 let s = idx as usize;
262 for j in 0..N {
263 let k = if s + j >= N { s + j - N } else { s + j };
264 out.coeffs[k] = out.coeffs[k].wrapping_add(b.coeffs[j]);
265 }
266 }
267 for &idx in &self.neg_ones {
268 let s = idx as usize;
269 for j in 0..N {
270 let k = if s + j >= N { s + j - N } else { s + j };
271 out.coeffs[k] = out.coeffs[k].wrapping_sub(b.coeffs[j]);
272 }
273 }
274 }
275}
276
277#[derive(Clone, Eq, PartialEq)]
278pub struct ProductPoly {
279 pub f1: TernaryPoly,
280 pub f2: TernaryPoly,
281 pub f3: TernaryPoly,
282}
283
284impl ProductPoly {
285 pub fn mul_dense<const N: usize>(&self, a: &Poly<N>, out: &mut Poly<N>) {
286 let mut t1 = Poly::<N>::zero();
287 self.f1.mul_dense::<N>(a, &mut t1);
288 self.f2.mul_dense::<N>(&t1, out);
289 let mut t3 = Poly::<N>::zero();
290 self.f3.mul_dense::<N>(a, &mut t3);
291 poly_add::<N>(out, &t3);
292 }
293
294 pub fn to_dense<const N: usize>(&self, q_mask: u16) -> Poly<N> {
295 let f2_dense = self.f2.to_dense::<N>(q_mask);
296 let mut out = Poly::<N>::zero();
297 self.f1.mul_dense::<N>(&f2_dense, &mut out);
298 let f3_dense = self.f3.to_dense::<N>(q_mask);
299 poly_add::<N>(&mut out, &f3_dense);
300 out
301 }
302}
303
/// Private trapdoor polynomial `T`, stored sparsely as index lists.
#[derive(Clone, Eq, PartialEq)]
pub enum Trapdoor {
    /// A single ternary polynomial.
    Dense(TernaryPoly),
    /// Product form `f1 * f2 + f3`.
    Product(ProductPoly),
}
310
311impl Trapdoor {
312 fn mul_dense<const N: usize>(&self, a: &Poly<N>, out: &mut Poly<N>) {
313 match self {
314 Trapdoor::Dense(t) => t.mul_dense::<N>(a, out),
315 Trapdoor::Product(p) => p.mul_dense::<N>(a, out),
316 }
317 }
318
319 fn to_dense<const N: usize>(&self, q_mask: u16) -> Poly<N> {
320 match self {
321 Trapdoor::Dense(t) => t.to_dense::<N>(q_mask),
322 Trapdoor::Product(p) => p.to_dense::<N>(q_mask),
323 }
324 }
325
326 pub fn to_wire(&self, params: &EesParams, out: &mut [u8]) {
329 debug_assert_eq!(out.len(), params.trapdoor_wire_bytes());
330 for b in out.iter_mut() {
331 *b = 0;
332 }
333 match self {
334 Trapdoor::Dense(t) => {
335 for &i in &t.ones {
337 let bit_pos = 2 * (i as usize);
338 out[bit_pos / 8] |= 1 << (bit_pos % 8);
339 }
340 for &i in &t.neg_ones {
341 let bit_pos = 2 * (i as usize);
342 out[bit_pos / 8] |= 3 << (bit_pos % 8);
343 }
344 }
345 Trapdoor::Product(p) => {
346 let mut bit_offset = 0usize;
347 let index_bits = EesParams::index_bits(params.n);
348 for poly in &[&p.f1, &p.f2, &p.f3] {
349 pack_indices(&poly.ones, out, &mut bit_offset, index_bits)
350 .expect("ones fit");
351 pack_indices(&poly.neg_ones, out, &mut bit_offset, index_bits)
352 .expect("neg_ones fit");
353 }
354 }
355 }
356 }
357
358 pub fn from_wire(bytes: &[u8], params: &EesParams) -> Option<Self> {
360 if bytes.len() != params.trapdoor_wire_bytes() {
361 return None;
362 }
363 match params.trapdoor {
364 TrapdoorKind::Dense { df } => {
365 let n = params.n;
366 let mut bit_pos = 0usize;
367 let mut ones = Vec::new();
368 let mut neg_ones = Vec::new();
369 for i in 0..n {
370 let code = (bytes[bit_pos / 8] >> (bit_pos % 8)) & 0x3;
371 bit_pos += 2;
372 match code {
373 0 => {}
374 1 => ones.push(i as u16),
375 3 => neg_ones.push(i as u16),
376 _ => return None,
377 }
378 }
379 if ones.len() != df || neg_ones.len() != df {
380 return None;
381 }
382 if !padding_bits_clear(bytes, n * 2) {
383 return None;
384 }
385 Some(Trapdoor::Dense(TernaryPoly { ones, neg_ones }))
386 }
387 TrapdoorKind::ProductForm { df1, df2, df3 } => {
388 let mut bit_offset = 0usize;
389 let index_bits = EesParams::index_bits(params.n);
390 let n = params.n;
391 let f1_ones = unpack_indices(bytes, df1, &mut bit_offset, index_bits, n)?;
392 let f1_neg = unpack_indices(bytes, df1, &mut bit_offset, index_bits, n)?;
393 let f2_ones = unpack_indices(bytes, df2, &mut bit_offset, index_bits, n)?;
394 let f2_neg = unpack_indices(bytes, df2, &mut bit_offset, index_bits, n)?;
395 let f3_ones = unpack_indices(bytes, df3, &mut bit_offset, index_bits, n)?;
396 let f3_neg = unpack_indices(bytes, df3, &mut bit_offset, index_bits, n)?;
397 if !padding_bits_clear(bytes, bit_offset) {
398 return None;
399 }
400 Some(Trapdoor::Product(ProductPoly {
401 f1: TernaryPoly { ones: f1_ones, neg_ones: f1_neg },
402 f2: TernaryPoly { ones: f2_ones, neg_ones: f2_neg },
403 f3: TernaryPoly { ones: f3_ones, neg_ones: f3_neg },
404 }))
405 }
406 }
407 }
408
409 fn sample_iid<R: Csprng>(rng: &mut R, params: &EesParams) -> Self {
413 match params.trapdoor {
414 TrapdoorKind::Dense { df } => {
415 Trapdoor::Dense(sample_trinary(rng, params.n, df, df))
416 }
417 TrapdoorKind::ProductForm { df1, df2, df3 } => Trapdoor::Product(ProductPoly {
418 f1: sample_trinary(rng, params.n, df1, df1),
419 f2: sample_trinary(rng, params.n, df2, df2),
420 f3: sample_trinary(rng, params.n, df3, df3),
421 }),
422 }
423 }
424
425 fn sample_via_igf(state: &mut IgfState<'_>) -> Self {
428 match state.params.trapdoor {
429 TrapdoorKind::Dense { df } => Trapdoor::Dense(igf_gen_ternary(state, df)),
430 TrapdoorKind::ProductForm { df1, df2, df3 } => Trapdoor::Product(ProductPoly {
431 f1: igf_gen_ternary(state, df1),
432 f2: igf_gen_ternary(state, df2),
433 f3: igf_gen_ternary(state, df3),
434 }),
435 }
436 }
437}
438
/// Drops trailing (highest-degree) zero coefficients, keeping at least one
/// entry. An empty vector is left untouched.
fn poly_trim(p: &mut Vec<u8>) {
    let keep = p.iter().rposition(|&c| c != 0).map_or(1, |last| last + 1);
    p.truncate(keep);
}
446
/// Degree of a little-endian coefficient vector, or `None` for the zero
/// polynomial.
fn poly_deg(p: &[u8]) -> Option<usize> {
    p.iter().rposition(|&c| c != 0)
}
455
/// Inverts a polynomial in `GF(2)[x] / (x^n - 1)`, where `n = a_coeffs.len()`.
///
/// Coefficients are little-endian (index `i` holds the coefficient of `x^i`),
/// and only the low bit of each input byte is used. The function runs an
/// extended Euclidean algorithm on `r0 = x^n + 1` and `r1 = a`. Throughout,
/// each slot keeps the invariant `t_k * a ≡ r_k (mod x^n - 1, mod 2)`.
///
/// Returns `Some(inv)` with `n` coefficients in `{0, 1}` such that
/// `a * inv ≡ 1`. Returns `None` when the gcd is not 1, i.e. `a` is not
/// invertible.
fn poly_inverse_mod2_cyclic(a_coeffs: &[u8]) -> Option<Vec<u8>> {
    let n = a_coeffs.len();
    // r0 = x^n + 1, which equals x^n - 1 over GF(2); its cofactor t0 is 0.
    let mut r0 = vec![0u8; n + 1];
    r0[0] = 1;
    r0[n] = 1;
    // r1 = a mod 2, with cofactor t1 = 1.
    let mut r1: Vec<u8> = a_coeffs.iter().map(|&c| c & 1).collect();
    poly_trim(&mut r1);
    let mut t0 = vec![0u8; 1];
    let mut t1 = vec![1u8; 1];

    loop {
        // r1 is zero: nothing left to reduce by.
        let d1 = match poly_deg(&r1) {
            Some(d) => d,
            None => break,
        };
        let d0 = match poly_deg(&r0) {
            Some(d) => d,
            None => {
                // r0 reached zero, so r1 holds the gcd. Move it (and its
                // cofactor) into slot 0 for the final check below.
                std::mem::swap(&mut r0, &mut r1);
                std::mem::swap(&mut t0, &mut t1);
                break;
            }
        };
        // Keep deg(r0) >= deg(r1) so the step below cancels r0's leading term.
        if d0 < d1 {
            std::mem::swap(&mut r0, &mut r1);
            std::mem::swap(&mut t0, &mut t1);
            continue;
        }
        // r0 ^= x^shift * r1 (subtraction over GF(2)).
        let shift = d0 - d1;
        for i in 0..=d1 {
            r0[shift + i] ^= r1[i];
        }
        poly_trim(&mut r0);
        // Apply the same step to the cofactors: t0 ^= x^shift * t1.
        let new_t0_len = t0.len().max(t1.len() + shift);
        if t0.len() < new_t0_len {
            t0.resize(new_t0_len, 0);
        }
        for i in 0..t1.len() {
            t0[shift + i] ^= t1[i];
        }
    }

    // Invertible iff the gcd is the constant polynomial 1.
    if !(r0.len() == 1 && r0[0] == 1) {
        return None;
    }
    // Reduce the cofactor modulo x^n - 1 by folding x^i onto x^(i mod n).
    let mut out = vec![0u8; n];
    for (i, &c) in t0.iter().enumerate() {
        if c & 1 == 1 {
            out[i % n] ^= 1;
        }
    }
    Some(out)
}
509
510fn poly_inverse_mod_q_cyclic<const N: usize>(
511 a: &Poly<N>,
512 params: &EesParams,
513) -> Option<Poly<N>> {
514 let q = params.q();
515 let q_mask = params.q_mask();
516 let a_mod2: Vec<u8> = a.coeffs.iter().map(|&c| (c & 1) as u8).collect();
517 let inv2 = poly_inverse_mod2_cyclic(&a_mod2)?;
518
519 let mut b = Poly::<N>::zero();
520 for i in 0..N {
521 b.coeffs[i] = inv2[i] as u16;
522 }
523
524 let mut precision: u32 = 2;
531 while precision < q {
532 let mut ab = Poly::<N>::zero();
533 poly_mul::<N>(&mut ab, a, &b);
534 poly_mod_q::<N>(&mut ab, q_mask);
535 let mut two_minus_ab = Poly::<N>::zero();
536 two_minus_ab.coeffs[0] = 2u16.wrapping_sub(ab.coeffs[0]) & q_mask;
537 for i in 1..N {
538 two_minus_ab.coeffs[i] = 0u16.wrapping_sub(ab.coeffs[i]) & q_mask;
539 }
540 let mut new_b = Poly::<N>::zero();
541 poly_mul::<N>(&mut new_b, &b, &two_minus_ab);
542 poly_mod_q::<N>(&mut new_b, q_mask);
543 b = new_b;
544 precision = precision.saturating_mul(precision);
545 }
546 Some(b)
547}
548
/// Growable little-endian bit string: bit `i` lives in `buf[i / 8]` at
/// position `i % 8`.
///
/// Bits of the last byte beyond `bit_len` are kept at zero, so that
/// `append_byte` can OR new bits into them.
#[derive(Clone)]
struct BitStr {
    buf: Vec<u8>,
    bit_len: usize,
}

impl BitStr {
    /// An empty bit string.
    fn new() -> Self {
        BitStr {
            buf: Vec::new(),
            bit_len: 0,
        }
    }

    /// Appends the 8 bits of `b` (least-significant first) at the end.
    fn append_byte(&mut self, b: u8) {
        let used = self.bit_len % 8;
        if used == 0 {
            self.buf.push(b);
        } else {
            let tail = self
                .buf
                .last_mut()
                .expect("non-empty by `bit_len > 0`");
            *tail |= b << used;
            self.buf.push(b >> (8 - used));
        }
        self.bit_len += 8;
    }

    /// Appends every byte of `bytes`, in order.
    fn append(&mut self, bytes: &[u8]) {
        bytes.iter().for_each(|&b| self.append_byte(b));
    }

    /// Returns the last `num_bits` bits (at most 32) as an integer. The first
    /// of those bits lands in bit 0 of the result.
    fn leading(&self, num_bits: u8) -> u32 {
        let count = usize::from(num_bits);
        debug_assert!(count <= 32 && count <= self.bit_len);
        let first = self.bit_len - count;
        (first..self.bit_len).rev().fold(0u32, |acc, pos| {
            (acc << 1) | u32::from((self.buf[pos / 8] >> (pos % 8)) & 1)
        })
    }

    /// Removes the last `num_bits` bits and re-zeroes the unused high bits of
    /// the new last byte.
    fn truncate(&mut self, num_bits: u8) {
        let count = usize::from(num_bits);
        debug_assert!(count <= self.bit_len);
        self.bit_len -= count;
        self.buf.truncate(self.bit_len.div_ceil(8));
        let partial = self.bit_len % 8;
        if partial != 0 {
            let last = self.buf.last_mut().expect("non-empty by needed > 0");
            *last &= u8::MAX >> (8 - partial);
        }
    }

    /// Returns a new bit string holding the first `num_bits` bits.
    fn trailing(&self, num_bits: u32) -> Self {
        let count = num_bits as usize;
        debug_assert!(count <= self.bit_len);
        let mut prefix = self.buf[..count.div_ceil(8)].to_vec();
        let partial = count % 8;
        if partial != 0 {
            let last = prefix.last_mut().expect("needed > 0");
            *last &= u8::MAX >> (8 - partial);
        }
        BitStr {
            buf: prefix,
            bit_len: count,
        }
    }
}
637
638struct IgfState<'a> {
639 z: Vec<u8>,
640 counter: u16,
641 buf: BitStr,
642 rem_bits: u32,
643 params: &'a EesParams,
644}
645
646impl<'a> IgfState<'a> {
647 fn new(seed: &[u8], params: &'a EesParams) -> Self {
648 debug_assert!(
654 params.c_bits <= u8::MAX as usize,
655 "IGF c_bits must fit in a u8"
656 );
657 let hlen = params.hash.output_len();
658 let mut s = Self {
659 z: seed.to_vec(),
660 counter: 0,
661 buf: BitStr::new(),
662 rem_bits: (params.min_calls_r * 8 * hlen) as u32,
663 params,
664 };
665 while (s.counter as usize) < params.min_calls_r {
666 s.absorb_one();
667 }
668 s
669 }
670 fn absorb_one(&mut self) {
671 let hlen = self.params.hash.output_len();
672 let mut out = [0u8; 64];
673 self.params
674 .hash
675 .digest_two_into(&self.z, &self.counter.to_le_bytes(), &mut out[..hlen]);
676 self.buf.append(&out[..hlen]);
677 self.counter = self.counter.wrapping_add(1);
678 }
679 fn next_index(&mut self) -> u16 {
680 let n = self.params.n as u32;
681 let c = self.params.c_bits as u8;
682 let hlen = self.params.hash.output_len();
683 let rnd_thresh: u32 = (1u32 << c) - (1u32 << c) % n;
686 loop {
687 if self.rem_bits < c as u32 {
688 let mut tail = self.buf.trailing(self.rem_bits);
689 let need = (c as u32) - self.rem_bits;
690 let extra_calls = need.div_ceil((hlen as u32) * 8);
691 let mut out = [0u8; 64];
692 for _ in 0..extra_calls {
693 self.params.hash.digest_two_into(
694 &self.z,
695 &self.counter.to_le_bytes(),
696 &mut out[..hlen],
697 );
698 tail.append(&out[..hlen]);
699 self.counter = self.counter.wrapping_add(1);
700 self.rem_bits += 8 * hlen as u32;
701 }
702 self.buf = tail;
703 }
704 let v = self.buf.leading(c);
705 self.buf.truncate(c);
706 self.rem_bits -= c as u32;
707 if v < rnd_thresh {
708 return (v % n) as u16;
709 }
710 }
711 }
712}
713
714fn igf_gen_ternary(state: &mut IgfState<'_>, num_each: usize) -> TernaryPoly {
715 let n = state.params.n;
716 let mut occupied = vec![false; n];
717 let mut neg_ones = Vec::with_capacity(num_each);
718 let mut ones = Vec::with_capacity(num_each);
719 while neg_ones.len() < num_each {
720 let idx = state.next_index();
721 if !occupied[idx as usize] {
722 occupied[idx as usize] = true;
723 neg_ones.push(idx);
724 }
725 }
726 while ones.len() < num_each {
727 let idx = state.next_index();
728 if !occupied[idx as usize] {
729 occupied[idx as usize] = true;
730 ones.push(idx);
731 }
732 }
733 neg_ones.sort_unstable();
734 ones.sort_unstable();
735 TernaryPoly { ones, neg_ones }
736}
737
/// Derives the blinding polynomial `r` from the IGF stream.
///
/// `r` has the same shape (dense or product form) as the parameter set's
/// trapdoor.
fn igf_gen_blinding(state: &mut IgfState<'_>) -> Trapdoor {
    Trapdoor::sample_via_igf(state)
}
741
/// Base-3 expansion of every byte value below 243 (= 3^5) into 5 trits,
/// least-significant digit first.
///
/// Digit 0 maps to trit 0, digit 1 to +1 and digit 2 to -1.
const MGF_TRIT_TABLE: [[i8; 5]; 243] = {
    const DIGIT_TO_TRIT: [i8; 3] = [0, 1, -1];
    let mut table = [[0i8; 5]; 243];
    let mut value = 0usize;
    while value < table.len() {
        let mut rest = value;
        let mut digit = 0usize;
        while digit < 5 {
            table[value][digit] = DIGIT_TO_TRIT[rest % 3];
            rest /= 3;
            digit += 1;
        }
        value += 1;
    }
    table
};
763
764fn mgf<const N: usize>(seed: &[u8], params: &EesParams) -> Poly<N> {
765 let hlen = params.hash.output_len();
766 let q_mask = params.q_mask();
767 let mut z = [0u8; 64];
768 params.hash.digest_into(seed, &mut z[..hlen]);
769
770 let mut buf: Vec<u8> = Vec::with_capacity(params.min_calls_mask * hlen);
771 let mut counter: u16 = 0;
772 let mut h = [0u8; 64];
773 while (counter as usize) < params.min_calls_mask {
774 params
775 .hash
776 .digest_two_into(&z[..hlen], &counter.to_be_bytes(), &mut h[..hlen]);
777 for &b in &h[..hlen] {
778 if b < 243 {
779 buf.push(b);
780 }
781 }
782 counter = counter.wrapping_add(1);
783 }
784
785 let mut out = Poly::<N>::zero();
786 let mut cur = 0usize;
787 let counter_ceiling = (params.min_calls_mask as u16).saturating_add(1024);
795 'outer: loop {
796 for &b in &buf {
797 for &t in &MGF_TRIT_TABLE[b as usize] {
798 out.coeffs[cur] = match t {
799 -1 => q_mask,
800 0 => 0,
801 1 => 1,
802 _ => unreachable!(),
803 };
804 cur += 1;
805 if cur >= N {
806 break 'outer;
807 }
808 }
809 }
810 assert!(
811 counter < counter_ceiling,
812 "MGF rejection sampler exceeded counter ceiling โ hash output is pathologically biased"
813 );
814 params
815 .hash
816 .digest_two_into(&z[..hlen], &counter.to_be_bytes(), &mut h[..hlen]);
817 buf.clear();
818 for &b in &h[..hlen] {
819 if b < 243 {
820 buf.push(b);
821 }
822 }
823 counter = counter.wrapping_add(1);
824 }
825 out
826}
827
/// First trit of each 3-bit SVES code, where `code = 3*c1 + c2` and -1 is
/// counted as 2.
const SVES_C1: [i8; 8] = [0, 0, 0, 1, 1, 1, -1, -1];
/// Second trit of each 3-bit SVES code. The pair (-1, -1) would be code 8 and
/// has no 3-bit encoding.
const SVES_C2: [i8; 8] = [0, 1, -1, 0, 1, -1, 0, 1];
832
/// Encodes a trit in `{-1, 0, 1}` as a mod-q coefficient, with -1 written as
/// `q_mask`.
///
/// # Panics
/// On any other input.
fn trit_to_u16(t: i8, q_mask: u16) -> u16 {
    match t {
        0 | 1 => t as u16,
        -1 => q_mask,
        _ => unreachable!(),
    }
}
841
842fn sves_from_bytes<const N: usize>(m: &[u8], q_mask: u16) -> Poly<N> {
843 let mut out = Poly::<N>::zero();
844 let mut coeff_idx: usize = 0;
845 let mut i = 0usize;
846 while i + 3 <= ((m.len() + 2) / 3) * 3 && coeff_idx < N - 1 {
847 let b0 = if i < m.len() { m[i] } else { 0 } as u32;
848 let b1 = if i + 1 < m.len() { m[i + 1] } else { 0 } as u32;
849 let b2 = if i + 2 < m.len() { m[i + 2] } else { 0 } as u32;
850 let mut chunk = (b2 << 16) | (b1 << 8) | b0;
851 i += 3;
852 for _ in 0..8 {
853 if coeff_idx >= N - 1 {
854 break;
855 }
856 let tbl = (chunk & 7) as usize;
857 out.coeffs[coeff_idx] = trit_to_u16(SVES_C1[tbl], q_mask);
858 out.coeffs[coeff_idx + 1] = trit_to_u16(SVES_C2[tbl], q_mask);
859 coeff_idx += 2;
860 chunk >>= 3;
861 }
862 }
863 out
864}
865
/// Packs a mod-3 polynomial back into bytes, reversing `sves_from_bytes`.
///
/// Coefficients are expected in `{0, 1, 2}`, with 2 meaning -1. Each pair
/// `(c1, c2)` becomes the 3-bit code `3*c1 + c2`, emitted least-significant
/// bit first. Only the first `N / 2 * 2` coefficients are read; for odd `N`
/// the last one is ignored.
///
/// Returns `None` if any pair is `(2, 2)`, which has no 3-bit code.
/// Otherwise returns `ceil(((3N + 1) / 2) / 8)` bytes.
fn sves_to_bytes<const N: usize>(p: &Poly<N>) -> Option<Vec<u8>> {
    let num_bits = (N * 3 + 1) / 2;
    let num_bytes = num_bits.div_ceil(8);
    // Three bytes of slack; the result is truncated to `num_bytes` at the end.
    let mut out = vec![0u8; num_bytes + 3];
    let end = N / 2 * 2;
    let mut d_idx = 0usize;
    let mut i = 0usize;
    while i < end {
        // Up to 8 pairs per group = 24 bits = 3 whole bytes, so every group
        // starts on a fresh byte boundary.
        let mut acc: u32 = 0;
        let mut bits_in_acc: u32 = 0;
        for _ in 0..8 {
            if i >= end {
                break;
            }
            let c1 = p.coeffs[i] as i32;
            let c2 = p.coeffs[i + 1] as i32;
            i += 2;
            if c1 == 2 && c2 == 2 {
                return None;
            }
            let c = (c1 * 3 + c2) as u32;
            acc |= c << bits_in_acc;
            bits_in_acc += 3;
            // Flush every completed byte.
            while bits_in_acc >= 8 && d_idx < out.len() {
                out[d_idx] = (acc & 0xff) as u8;
                d_idx += 1;
                acc >>= 8;
                bits_in_acc -= 8;
            }
        }
        // Only the final, partial group can leave bits behind.
        if bits_in_acc > 0 && d_idx < out.len() {
            out[d_idx] |= acc as u8;
        }
    }
    out.truncate(num_bytes);
    Some(out)
}
903
904fn poly_to_arr<const N: usize>(p: &Poly<N>, out: &mut [u8], params: &EesParams) {
907 let logq = params.logq;
908 let q_mask = params.q_mask();
909 debug_assert_eq!(out.len(), params.pk_wire_bytes());
910 for b in out.iter_mut() {
911 *b = 0;
912 }
913 let mut bit_pos = 0usize;
914 for i in 0..N {
915 let v = (p.coeffs[i] & q_mask) as u32;
916 for b in 0..logq {
917 let bit = ((v >> b) & 1) as u8;
918 out[bit_pos / 8] |= bit << (bit_pos % 8);
919 bit_pos += 1;
920 }
921 }
922}
923
924fn poly_from_arr<const N: usize>(input: &[u8], params: &EesParams) -> Poly<N> {
925 let logq = params.logq;
926 debug_assert!(input.len() >= params.pk_wire_bytes());
927 let mut p = Poly::<N>::zero();
928 let mut bit_pos = 0usize;
929 for i in 0..N {
930 let mut v: u32 = 0;
931 for b in 0..logq {
932 let bit = ((input[bit_pos / 8] >> (bit_pos % 8)) & 1) as u32;
933 v |= bit << b;
934 bit_pos += 1;
935 }
936 p.coeffs[i] = v as u16;
937 }
938 p
939}
940
941fn poly_to_arr4<const N: usize>(p: &Poly<N>, params: &EesParams) -> Vec<u8> {
942 let q = params.q();
943 let q_mask = params.q_mask();
944 let nbits = N * 2;
945 let mut out = vec![0u8; nbits.div_ceil(8)];
946 let mut bit_pos = 0usize;
947 for i in 0..N {
948 let centred = {
949 let m = p.coeffs[i] & q_mask;
950 let centred = if (m as u32) > q / 2 {
951 m as i32 - q as i32
952 } else {
953 m as i32
954 };
955 (centred & 3) as u8
956 };
957 for b in 0..2 {
958 let bit = (centred >> b) & 1;
959 out[bit_pos / 8] |= bit << (bit_pos % 8);
960 bit_pos += 1;
961 }
962 }
963 out
964}
965
/// Writes each index in `index_bits` bits (little-endian) into `out`,
/// starting at `*bit_offset` and advancing it.
///
/// Bits are ORed in, so `out` should start zeroed. Returns `None` as soon as
/// an index does not fit; indices before it have already been written.
fn pack_indices(
    indices: &[u16],
    out: &mut [u8],
    bit_offset: &mut usize,
    index_bits: usize,
) -> Option<()> {
    for &idx in indices {
        if usize::from(idx) >= 1usize << index_bits {
            return None;
        }
        for bit in 0..index_bits {
            let pos = *bit_offset;
            out[pos / 8] |= (((idx >> bit) & 1) as u8) << (pos % 8);
            *bit_offset = pos + 1;
        }
    }
    Some(())
}
986
/// Reads `n` indices of `index_bits` bits each (little-endian) from `bytes`,
/// starting at `*bit_offset` and advancing it.
///
/// Returns `None` at the first index `>= n_max`, after advancing past that
/// index.
fn unpack_indices(
    bytes: &[u8],
    n: usize,
    bit_offset: &mut usize,
    index_bits: usize,
    n_max: usize,
) -> Option<Vec<u16>> {
    let mut decoded = Vec::with_capacity(n);
    for _ in 0..n {
        let start = *bit_offset;
        let value = (0..index_bits).fold(0u32, |acc, i| {
            let pos = start + i;
            acc | u32::from((bytes[pos / 8] >> (pos % 8)) & 1) << i
        });
        *bit_offset = start + index_bits;
        if value as usize >= n_max {
            return None;
        }
        decoded.push(value as u16);
    }
    Some(decoded)
}
1009
/// Returns true when every bit of `bytes` at position `>= used_bits` is zero.
///
/// Bit positions are little-endian: bit `i` is bit `i % 8` of `bytes[i / 8]`.
/// The whole padding region is checked, including any whole bytes after the
/// byte that holds the last used bit. If `used_bits` exceeds the buffer,
/// the buffer is too short for the claimed content, and the function
/// returns false.
#[doc(hidden)]
pub fn padding_bits_clear(bytes: &[u8], used_bits: usize) -> bool {
    if used_bits > bytes.len() * 8 {
        return false;
    }
    let mut padding = &bytes[used_bits / 8..];
    let used_in_partial = used_bits % 8;
    if used_in_partial != 0 {
        // Only the high bits of the byte holding the last used bit are padding.
        if padding[0] >> used_in_partial != 0 {
            return false;
        }
        padding = &padding[1..];
    }
    padding.iter().all(|&b| b == 0)
}
1024
/// Free-function form of `Trapdoor::to_wire`: serialises `t` into `out`, which
/// must be `params.trapdoor_wire_bytes()` long.
pub fn trapdoor_to_wire(t: &Trapdoor, params: &EesParams, out: &mut [u8]) {
    t.to_wire(params, out);
}
1030
/// Free-function form of `Trapdoor::from_wire`: parses a serialised trapdoor,
/// returning `None` on any malformed input.
pub fn trapdoor_from_wire(bytes: &[u8], params: &EesParams) -> Option<Trapdoor> {
    Trapdoor::from_wire(bytes, params)
}
1036
1037fn next_index_below<R: Csprng>(rng: &mut R, modulus: u32) -> u32 {
1040 let threshold = u32::MAX - (u32::MAX % modulus);
1041 loop {
1042 let mut buf = [0u8; 4];
1043 rng.fill_bytes(&mut buf);
1044 let v = u32::from_le_bytes(buf);
1045 if v < threshold {
1046 return v % modulus;
1047 }
1048 }
1049}
1050
1051fn sample_trinary<R: Csprng>(
1052 rng: &mut R,
1053 n: usize,
1054 num_ones: usize,
1055 num_neg_ones: usize,
1056) -> TernaryPoly {
1057 debug_assert!(num_ones + num_neg_ones <= n);
1058 let mut idx: Vec<u16> = (0..n as u16).collect();
1059 let take = num_ones + num_neg_ones;
1060 for i in 0..take {
1061 let j = i + next_index_below(rng, (n - i) as u32) as usize;
1062 idx.swap(i, j);
1063 }
1064 let mut ones = idx[..num_ones].to_vec();
1065 let mut neg_ones = idx[num_ones..take].to_vec();
1066 ones.sort_unstable();
1067 neg_ones.sort_unstable();
1068 TernaryPoly { ones, neg_ones }
1069}
1070
/// Samples a fresh private trapdoor `T` from `rng`, shaped by `params.trapdoor`.
fn sample_trapdoor<R: Csprng>(rng: &mut R, params: &EesParams) -> Trapdoor {
    Trapdoor::sample_iid(rng, params)
}
1074
1075fn check_rep_weight<const N: usize>(p: &Poly<N>, params: &EesParams) -> bool {
1076 let mut w = [0usize; 3];
1077 for i in 0..N {
1078 let v = p.coeffs[i] as usize;
1079 if v < 3 {
1080 w[v] += 1;
1081 }
1082 }
1083 w[0] >= params.dm0 && w[1] >= params.dm0 && w[2] >= params.dm0
1084}
1085
/// Errors reported by `encrypt` and `decrypt`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NtruEesError {
    /// The plaintext exceeds the parameter set's `max_message_bytes()`.
    MessageTooLong,
    /// Decryption failed. Every decryption failure maps to this single variant.
    InvalidCiphertext,
}
1093
1094pub fn keygen<const N: usize, R: Csprng>(
1098 params: &EesParams,
1099 rng: &mut R,
1100) -> (Vec<u8>, Trapdoor) {
1101 debug_assert_eq!(params.n, N);
1102 let q_mask = params.q_mask();
1103 loop {
1104 let t = sample_trapdoor(rng, params);
1105 let mut f = t.to_dense::<N>(q_mask);
1107 poly_scalar_mul::<N>(&mut f, 3, q_mask);
1108 f.coeffs[0] = f.coeffs[0].wrapping_add(1) & q_mask;
1109 let f_inv = match poly_inverse_mod_q_cyclic::<N>(&f, params) {
1110 Some(inv) => inv,
1111 None => continue,
1112 };
1113
1114 let g = sample_trinary(rng, params.n, params.dg, params.dg);
1115 let mut g_dense = g.to_dense::<N>(q_mask);
1116 poly_mod_q::<N>(&mut g_dense, q_mask);
1117 let mut h = Poly::<N>::zero();
1118 poly_mul::<N>(&mut h, &g_dense, &f_inv);
1119 poly_scalar_mul::<N>(&mut h, 3, q_mask);
1120
1121 let mut pk_bytes = vec![0u8; params.pk_wire_bytes()];
1122 poly_to_arr::<N>(&h, &mut pk_bytes, params);
1123 return (pk_bytes, t);
1124 }
1125}
1126
1127pub fn encrypt<const N: usize, R: Csprng>(
1128 pk_bytes: &[u8],
1129 msg: &[u8],
1130 rng: &mut R,
1131 params: &EesParams,
1132) -> Result<Vec<u8>, NtruEesError> {
1133 debug_assert_eq!(params.n, N);
1134 if msg.len() > params.max_message_bytes() {
1135 return Err(NtruEesError::MessageTooLong);
1136 }
1137 let q_mask = params.q_mask();
1138 let mut h = poly_from_arr::<N>(pk_bytes, params);
1139 poly_mod_q::<N>(&mut h, q_mask);
1140
1141 let pklen_bytes = params.pklen_bytes();
1142 let htrunc = &pk_bytes[..pklen_bytes];
1143 let db_bytes = params.db_bytes();
1144 let max_msg = params.max_message_bytes();
1145
1146 loop {
1147 let mut b = vec![0u8; db_bytes];
1148 rng.fill_bytes(&mut b);
1149
1150 let m_len = db_bytes + 1 + max_msg + 1;
1151 let mut m = vec![0u8; m_len];
1152 m[..db_bytes].copy_from_slice(&b);
1153 m[db_bytes] = msg.len() as u8;
1154 m[db_bytes + 1..db_bytes + 1 + msg.len()].copy_from_slice(msg);
1155
1156 let mtrin = sves_from_bytes::<N>(&m, q_mask);
1157
1158 let mut sdata =
1159 Vec::with_capacity(params.oid.len() + msg.len() + b.len() + htrunc.len());
1160 sdata.extend_from_slice(¶ms.oid);
1161 sdata.extend_from_slice(msg);
1162 sdata.extend_from_slice(&b);
1163 sdata.extend_from_slice(htrunc);
1164
1165 let mut igf = IgfState::new(&sdata, params);
1166 let r = igf_gen_blinding(&mut igf);
1167
1168 let mut bigr = Poly::<N>::zero();
1169 r.mul_dense::<N>(&h, &mut bigr);
1170 poly_mod_q::<N>(&mut bigr, q_mask);
1171
1172 let or4 = poly_to_arr4::<N>(&bigr, params);
1173 let mask = mgf::<N>(&or4, params);
1174
1175 let mut mtrin_plus_mask = mtrin;
1176 poly_add::<N>(&mut mtrin_plus_mask, &mask);
1177 poly_mod3::<N>(&mut mtrin_plus_mask, params);
1178
1179 if !check_rep_weight::<N>(&mtrin_plus_mask, params) {
1180 continue;
1181 }
1182
1183 let mut e = bigr;
1184 for i in 0..N {
1185 let v = mtrin_plus_mask.coeffs[i];
1186 let signed: u16 = match v {
1187 0 => 0,
1188 1 => 1,
1189 2 => q_mask,
1190 _ => unreachable!(),
1191 };
1192 e.coeffs[i] = e.coeffs[i].wrapping_add(signed);
1193 }
1194 poly_mod_q::<N>(&mut e, q_mask);
1195
1196 let mut out = vec![0u8; params.ciphertext_wire_bytes()];
1197 poly_to_arr::<N>(&e, &mut out, params);
1198 return Ok(out);
1199 }
1200}
1201
1202pub fn decrypt<const N: usize>(
1203 sk_trapdoor: &Trapdoor,
1204 pk_bytes: &[u8],
1205 ct_bytes: &[u8],
1206 params: &EesParams,
1207) -> Result<Vec<u8>, NtruEesError> {
1208 debug_assert_eq!(params.n, N);
1209 let q_mask = params.q_mask();
1210 let e = poly_from_arr::<N>(ct_bytes, params);
1211
1212 let mut te = Poly::<N>::zero();
1213 sk_trapdoor.mul_dense::<N>(&e, &mut te);
1214 let mut ci = te;
1215 poly_scalar_mul::<N>(&mut ci, 3, q_mask);
1216 poly_add::<N>(&mut ci, &e);
1217 poly_mod_q::<N>(&mut ci, q_mask);
1218 poly_mod3::<N>(&mut ci, params);
1219
1220 let mut retcode_ok = check_rep_weight::<N>(&ci, params);
1221
1222 let mut c_r = e;
1223 let mut ci_modq = Poly::<N>::zero();
1224 for i in 0..N {
1225 ci_modq.coeffs[i] = match ci.coeffs[i] {
1226 0 => 0,
1227 1 => 1,
1228 2 => q_mask,
1229 _ => unreachable!(),
1230 };
1231 }
1232 poly_sub::<N>(&mut c_r, &ci_modq);
1233 poly_mod_q::<N>(&mut c_r, q_mask);
1234
1235 let or4 = poly_to_arr4::<N>(&c_r, params);
1236 let mask = mgf::<N>(&or4, params);
1237
1238 let mut cmtrin = ci;
1239 poly_sub::<N>(&mut cmtrin, &mask);
1240 poly_mod3::<N>(&mut cmtrin, params);
1241
1242 let cm = sves_to_bytes::<N>(&cmtrin).ok_or(NtruEesError::InvalidCiphertext)?;
1243
1244 let db_bytes = params.db_bytes();
1245 let max_msg = params.max_message_bytes();
1246 let cb = &cm[..db_bytes];
1247 let cl = cm[db_bytes] as usize;
1248 if cl > max_msg {
1249 return Err(NtruEesError::InvalidCiphertext);
1250 }
1251 let msg = cm[db_bytes + 1..db_bytes + 1 + cl].to_vec();
1252
1253 let pad_start = db_bytes + 1 + cl;
1254 let pad_end = (params.n * 3 + 1) / 2;
1255 let pad_end_bytes = pad_end.div_ceil(8);
1256 for &p in &cm[pad_start..pad_end_bytes.min(cm.len())] {
1257 if p != 0 {
1258 retcode_ok = false;
1259 }
1260 }
1261
1262 let pklen_bytes = params.pklen_bytes();
1263 let htrunc = &pk_bytes[..pklen_bytes];
1264 let mut sdata = Vec::with_capacity(params.oid.len() + cl + db_bytes + db_bytes);
1265 sdata.extend_from_slice(¶ms.oid);
1266 sdata.extend_from_slice(&msg);
1267 sdata.extend_from_slice(cb);
1268 sdata.extend_from_slice(htrunc);
1269 let mut igf = IgfState::new(&sdata, params);
1270 let cr_priv = igf_gen_blinding(&mut igf);
1271
1272 let h = poly_from_arr::<N>(pk_bytes, params);
1273 let mut bigr_prime = Poly::<N>::zero();
1274 cr_priv.mul_dense::<N>(&h, &mut bigr_prime);
1275 poly_mod_q::<N>(&mut bigr_prime, q_mask);
1276
1277 for i in 0..N {
1278 if bigr_prime.coeffs[i] != c_r.coeffs[i] {
1279 retcode_ok = false;
1280 break;
1281 }
1282 }
1283
1284 if !retcode_ok {
1285 return Err(NtruEesError::InvalidCiphertext);
1286 }
1287 Ok(msg)
1288}
1289
/// Generates the public API for one concrete NTRUEncrypt (EES) parameter set.
///
/// Each invocation expands, inside the calling module, to:
/// - a private `PARAMS: EesParams` built from the arguments (with `logq` fixed at 11) and a
///   private `N` const equal to `n`, used as the const-generic ring degree;
/// - public byte-length consts `PUBLIC_KEY_BYTES`, `PRIVATE_KEY_BYTES`, `CIPHERTEXT_BYTES`
///   and `MAX_MESSAGE_BYTES`, all derived from `PARAMS`;
/// - public-key, private-key and ciphertext newtypes (named by `public_key`, `private_key`,
///   `ciphertext`) with wire-format (de)serialisation;
/// - a unit struct named by `namespace` exposing `keygen`, `encrypt` and `decrypt`, which
///   forward to the shared `ntru_ees_core` functions;
/// - a `#[cfg(test)]` module that checks the byte lengths against `pk_bytes`,
///   `sk_packed_bytes` and `ct_bytes`, exercises encrypt/decrypt and wire round-trips and
///   rejection paths, and compares a SHA-256 digest (lowercase hex) of deterministic outputs
///   against `regression_digest`.
///
/// The expansion emits `use` items and fixed-name consts (`PARAMS`, `N`, ...), so invoke it
/// at most once per module.
macro_rules! define_ees_set {
    (
        namespace = $type_name:ident,
        public_key = $pk_ty:ident,
        private_key = $sk_ty:ident,
        ciphertext = $ct_ty:ident,
        n = $n:expr,
        trapdoor = $trapdoor:expr,
        dg = $dg:expr,
        dm0 = $dm0:expr,
        db_bits = $db_bits:expr,
        c_bits = $c_bits:expr,
        min_calls_r = $min_calls_r:expr,
        min_calls_mask = $min_calls_mask:expr,
        pklen_bits = $pklen_bits:expr,
        oid = $oid:expr,
        hash = $hash:expr,
        pk_bytes = $pk_bytes:expr,
        sk_packed_bytes = $sk_packed_bytes:expr,
        ct_bytes = $ct_bytes:expr,
        regression_digest = $regression_digest:expr $(,)?
    ) => {
        // Core functions are aliased with a `__ees_` prefix to avoid clashing with the
        // `keygen` / `encrypt` / `decrypt` names exposed on `$type_name`.
        // `HashKind` and `TrapdoorKind` are not referenced by the expansion itself;
        // presumably they are imported so `$hash` / `$trapdoor` arguments can name their
        // variants without full paths — confirm before removing.
        use $crate::public_key::ntru_ees_core::{
            decrypt as __ees_core_decrypt, encrypt as __ees_core_encrypt,
            keygen as __ees_core_keygen, padding_bits_clear as __ees_padding_bits_clear,
            trapdoor_from_wire as __ees_trapdoor_from_wire,
            trapdoor_to_wire as __ees_trapdoor_to_wire, EesParams, HashKind, NtruEesError,
            Trapdoor, TrapdoorKind,
        };
        use $crate::Csprng;

        /// Full parameter set for this instantiation; `logq` is fixed at 11 (q = 2^11 = 2048).
        const PARAMS: EesParams = EesParams {
            n: $n,
            logq: 11,
            trapdoor: $trapdoor,
            dg: $dg,
            dm0: $dm0,
            db_bits: $db_bits,
            c_bits: $c_bits,
            min_calls_r: $min_calls_r,
            min_calls_mask: $min_calls_mask,
            pklen_bits: $pklen_bits,
            oid: $oid,
            hash: $hash,
        };

        /// Ring degree as a plain const so it can be passed as a const-generic argument.
        const N: usize = $n;

        /// Length in bytes of an encoded public key.
        pub const PUBLIC_KEY_BYTES: usize = PARAMS.pk_wire_bytes();
        /// Length in bytes of the encoded trapdoor, i.e. the private-key wire encoding
        /// *without* the appended public key.
        pub const PRIVATE_KEY_BYTES: usize = PARAMS.trapdoor_wire_bytes();
        /// Length in bytes of an encoded ciphertext.
        pub const CIPHERTEXT_BYTES: usize = PARAMS.ciphertext_wire_bytes();
        /// Largest plaintext, in bytes, accepted by `encrypt`.
        pub const MAX_MESSAGE_BYTES: usize = PARAMS.max_message_bytes();

        /// Public key, stored in its packed wire encoding.
        #[derive(Clone, Eq, PartialEq)]
        pub struct $pk_ty {
            bytes: Vec<u8>,
        }

        /// Private key: the trapdoor together with its public key.
        // NOTE(review): the derived `PartialEq` compares secret material in variable time,
        // and no `Drop`/zeroization is implemented here — check whether `Trapdoor` wipes
        // itself and whether constant-time comparison is required by callers.
        #[derive(Clone, Eq, PartialEq)]
        pub struct $sk_ty {
            t: Trapdoor,
            pk: $pk_ty,
        }

        /// Ciphertext, stored in its packed wire encoding.
        #[derive(Clone, Eq, PartialEq)]
        pub struct $ct_ty {
            bytes: Vec<u8>,
        }

        impl $pk_ty {
            /// Parses a packed public key.
            ///
            /// Returns `None` unless `bytes` is exactly `PUBLIC_KEY_BYTES` long and
            /// `padding_bits_clear` accepts it for `N * logq` payload bits (presumably: every
            /// bit past the packed coefficients is zero).
            #[must_use]
            pub fn from_wire_bytes(bytes: &[u8]) -> Option<Self> {
                if bytes.len() != PUBLIC_KEY_BYTES { return None; }
                if !__ees_padding_bits_clear(bytes, N * PARAMS.logq) {
                    return None;
                }
                Some(Self { bytes: bytes.to_vec() })
            }

            /// Returns an owned copy of the wire encoding.
            #[must_use]
            pub fn to_wire_bytes(&self) -> Vec<u8> { self.bytes.clone() }

            /// Borrows the wire encoding without copying.
            #[must_use]
            pub fn as_bytes(&self) -> &[u8] { &self.bytes }
        }

        impl $sk_ty {
            /// Serialises the key as `trapdoor || public key`.
            ///
            /// The first `PRIVATE_KEY_BYTES` bytes hold the trapdoor encoding and the
            /// remaining `PUBLIC_KEY_BYTES` bytes hold the public key.
            #[must_use]
            pub fn to_wire_bytes(&self) -> Vec<u8> {
                let mut out = vec![0u8; PRIVATE_KEY_BYTES + PUBLIC_KEY_BYTES];
                __ees_trapdoor_to_wire(&self.t, &PARAMS, &mut out[..PRIVATE_KEY_BYTES]);
                out[PRIVATE_KEY_BYTES..].copy_from_slice(&self.pk.bytes);
                out
            }

            /// Parses the `trapdoor || public key` layout produced by `to_wire_bytes`.
            ///
            /// Returns `None` if the total length is wrong, if `trapdoor_from_wire` rejects
            /// the trapdoor part, or if the public-key part fails `from_wire_bytes`.
            // NOTE(review): the public-key half is not checked here to actually correspond
            // to the trapdoor; a mismatched pair would only surface at decryption time.
            #[must_use]
            pub fn from_wire_bytes(bytes: &[u8]) -> Option<Self> {
                if bytes.len() != PRIVATE_KEY_BYTES + PUBLIC_KEY_BYTES { return None; }
                let t = __ees_trapdoor_from_wire(&bytes[..PRIVATE_KEY_BYTES], &PARAMS)?;
                let pk = $pk_ty::from_wire_bytes(&bytes[PRIVATE_KEY_BYTES..])?;
                Some(Self { t, pk })
            }

            /// Returns the public key stored alongside this private key.
            #[must_use]
            pub fn public_key(&self) -> &$pk_ty { &self.pk }
        }

        impl $ct_ty {
            /// Parses a packed ciphertext.
            ///
            /// Performs only structural checks: exact length `CIPHERTEXT_BYTES` and
            /// `padding_bits_clear` for `N * logq` payload bits. Cryptographic validity is
            /// checked later by `decrypt`.
            #[must_use]
            pub fn from_wire_bytes(bytes: &[u8]) -> Option<Self> {
                if bytes.len() != CIPHERTEXT_BYTES { return None; }
                if !__ees_padding_bits_clear(bytes, N * PARAMS.logq) {
                    return None;
                }
                Some(Self { bytes: bytes.to_vec() })
            }

            /// Returns an owned copy of the wire encoding.
            #[must_use]
            pub fn to_wire_bytes(&self) -> Vec<u8> { self.bytes.clone() }

            /// Borrows the wire encoding without copying.
            #[must_use]
            pub fn as_bytes(&self) -> &[u8] { &self.bytes }
        }

        // Debug output for the private key never includes key material.
        impl ::core::fmt::Debug for $sk_ty {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.write_str(concat!(stringify!($sk_ty), "(<redacted>)"))
            }
        }

        // Public key and ciphertext print only their type name, not their bytes.
        impl ::core::fmt::Debug for $pk_ty {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.debug_struct(stringify!($pk_ty)).finish()
            }
        }

        impl ::core::fmt::Debug for $ct_ty {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.debug_struct(stringify!($ct_ty)).finish()
            }
        }

        /// Entry points (key generation, encryption, decryption) for this parameter set.
        pub struct $type_name;

        impl $type_name {
            /// Length in bytes of an encoded public key.
            pub const PUBLIC_KEY_BYTES: usize = PUBLIC_KEY_BYTES;
            /// Length in bytes of the encoded trapdoor (private key without its public key).
            pub const PRIVATE_KEY_BYTES: usize = PRIVATE_KEY_BYTES;
            /// Length in bytes of an encoded ciphertext.
            pub const CIPHERTEXT_BYTES: usize = CIPHERTEXT_BYTES;
            /// Largest plaintext, in bytes, accepted by `encrypt`.
            pub const MAX_MESSAGE_BYTES: usize = MAX_MESSAGE_BYTES;

            /// Generates a fresh key pair using `rng`.
            ///
            /// The returned private key embeds its own copy of the returned public key.
            #[must_use]
            pub fn keygen<R: Csprng>(rng: &mut R) -> ($pk_ty, $sk_ty) {
                let (pk_bytes, t) = __ees_core_keygen::<N, R>(&PARAMS, rng);
                // `pk_bytes` is not used again, so move it rather than cloning it.
                let pk = $pk_ty { bytes: pk_bytes };
                let sk = $sk_ty { t, pk: pk.clone() };
                (pk, sk)
            }

            /// Encrypts `msg` to `pk`, drawing randomness from `rng`.
            ///
            /// # Errors
            /// Propagates errors from the core `encrypt`; e.g.
            /// `NtruEesError::MessageTooLong` when `msg` exceeds `MAX_MESSAGE_BYTES`.
            pub fn encrypt<R: Csprng>(
                pk: &$pk_ty,
                msg: &[u8],
                rng: &mut R,
            ) -> Result<$ct_ty, NtruEesError> {
                let bytes = __ees_core_encrypt::<N, R>(&pk.bytes, msg, rng, &PARAMS)?;
                Ok($ct_ty { bytes })
            }

            /// Decrypts `ct` with `sk` and returns the recovered plaintext.
            ///
            /// # Errors
            /// Propagates errors from the core `decrypt`; e.g.
            /// `NtruEesError::InvalidCiphertext` for a tampered ciphertext.
            pub fn decrypt(sk: &$sk_ty, ct: &$ct_ty) -> Result<Vec<u8>, NtruEesError> {
                __ees_core_decrypt::<N>(&sk.t, &sk.pk.bytes, &ct.bytes, &PARAMS)
            }
        }

        #[cfg(test)]
        mod tests {
            use super::*;
            use $crate::CtrDrbgAes256;

            /// Derived wire lengths must match the values declared at the invocation site.
            #[test]
            fn parameter_byte_lengths() {
                assert_eq!(PUBLIC_KEY_BYTES, $pk_bytes);
                assert_eq!(PRIVATE_KEY_BYTES, $sk_packed_bytes);
                assert_eq!(CIPHERTEXT_BYTES, $ct_bytes);
                assert!(MAX_MESSAGE_BYTES > 0);
            }

            /// Encrypt/decrypt round-trips for empty, small and maximum-length messages.
            #[test]
            fn round_trip_empty_and_full_messages() {
                let mut drbg = CtrDrbgAes256::new(&[0x42u8; 48]);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                for &len in &[0usize, 1, 16, 32, MAX_MESSAGE_BYTES] {
                    let mut msg = vec![0u8; len];
                    drbg.fill_bytes(&mut msg);
                    let ct = $type_name::encrypt(&pk, &msg, &mut drbg).expect("encrypt");
                    let dec = $type_name::decrypt(&sk, &ct).expect("decrypt");
                    assert_eq!(dec, msg, "round-trip at len={}", len);
                }
            }

            /// A message one byte over the limit is rejected with `MessageTooLong`.
            #[test]
            fn rejects_oversize_message() {
                let mut drbg = CtrDrbgAes256::new(&[0x77u8; 48]);
                let (pk, _) = $type_name::keygen(&mut drbg);
                let too_big = vec![0u8; MAX_MESSAGE_BYTES + 1];
                let err = $type_name::encrypt(&pk, &too_big, &mut drbg).unwrap_err();
                assert_eq!(err, NtruEesError::MessageTooLong);
            }

            /// A structurally valid but tampered ciphertext decrypts to `InvalidCiphertext`.
            #[test]
            fn corrupted_ciphertext_rejected() {
                let mut drbg = CtrDrbgAes256::new(&[0x99u8; 48]);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                let msg = b"hello ntru";
                let ct = $type_name::encrypt(&pk, msg, &mut drbg).expect("encrypt");
                let mut bad_bytes = ct.to_wire_bytes();
                bad_bytes[10] ^= 0xff;
                let bad_ct = $ct_ty::from_wire_bytes(&bad_bytes).expect("structural decode");
                match $type_name::decrypt(&sk, &bad_ct) {
                    Err(NtruEesError::InvalidCiphertext) => {}
                    other => panic!("expected InvalidCiphertext, got {:?}", other),
                }
            }

            /// Pins the exact bytes of a DRBG-seeded keygen + encrypt run.
            ///
            /// The SHA-256 of `pk || sk || ct` wire encodings must equal `$regression_digest`;
            /// a mismatch means the wire format or the randomness consumption changed.
            #[test]
            fn byte_format_regression_digest() {
                use $crate::hash::sha2::Sha256;
                let mut drbg = CtrDrbgAes256::new(&[0xC0u8; 48]);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                let ct = $type_name::encrypt(&pk, &[0xA5u8; 8], &mut drbg)
                    .expect("encrypt");
                let mut h = Sha256::new();
                // Borrow the stored encodings; the private key must be serialised.
                h.update(pk.as_bytes());
                h.update(&sk.to_wire_bytes());
                h.update(ct.as_bytes());
                let digest = h.finalize();
                let mut hex = String::with_capacity(64);
                for b in digest.iter() {
                    use ::core::fmt::Write;
                    write!(&mut hex, "{:02x}", b).unwrap();
                }
                assert_eq!(hex, $regression_digest, "byte-format regression");
            }

            /// Every wire type survives an encode/decode round-trip and still decrypts.
            #[test]
            fn wire_format_roundtrip_keys_and_ct() {
                let mut drbg = CtrDrbgAes256::new(&[0xa0u8; 48]);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                let msg = b"wire-format-roundtrip";
                let ct = $type_name::encrypt(&pk, msg, &mut drbg).expect("encrypt");

                let pk_round = $pk_ty::from_wire_bytes(&pk.to_wire_bytes()).expect("pk decode");
                let sk_round = $sk_ty::from_wire_bytes(&sk.to_wire_bytes()).expect("sk decode");
                let ct_round = $ct_ty::from_wire_bytes(&ct.to_wire_bytes()).expect("ct decode");

                assert_eq!(pk_round, pk);
                assert_eq!(sk_round, sk);
                assert_eq!(ct_round, ct);

                let dec = $type_name::decrypt(&sk_round, &ct_round).expect("decrypt");
                assert_eq!(dec, msg);
            }
        }
    };
}
1572
1573pub(crate) use define_ees_set;