purecrypto 0.5.0

A pure-Rust cryptography toolkit with no foreign-code dependencies, from constant-time primitives up to keys, X.509 and TLS.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
//! UMAC (RFC 4418): a fast, AES-based message authentication code.
//!
//! Two output sizes are provided:
//!
//! - [`Umac64`] — 8-byte tag (`UMAC-AES-128`, 2 internal iterations).
//! - [`Umac128`] — 16-byte tag (`UMAC-AES-128`, 4 internal iterations).
//!
//! Both share the same construction (RFC 4418 §3):
//!
//! 1. **KDF** — derive internal keys by running AES-128 in counter mode over
//!    the master key.
//! 2. **L1-HASH (NH)** — break the message into 1024-byte chunks and reduce
//!    each to 8 bytes via the constant-time *non-cryptographic* hash NH.
//! 3. **L2-HASH (POLY)** — evaluate a polynomial over GF(p) on the chunked
//!    L1 output, with p = 2⁶⁴−59 (or, for messages of L1 output larger than
//!    2¹⁷ bytes, transitioning to p = 2¹²⁸−159).
//! 4. **L3-HASH** — fold the 128-bit polynomial value into a 32-bit tag word
//!    via an inner product modulo 2³⁶−5.
//! 5. **PDF** — XOR the tag with an AES-derived per-nonce pad.
//!
//! The whole pipeline runs without table lookups (NH uses arithmetic only,
//! AES uses [`Aes128`]'s constant-time GF(2⁸) S-box), so the only data-
//! dependent timing arises from the POLY marker branch (RFC 4418 §6, noted
//! by the RFC itself as a narrow timing concern; the leak is constrained to
//! detection of a 32-bit boundary in the L1 output).
//!
//! # Example
//!
//! ```
//! use purecrypto::mac::Umac64;
//!
//! let key   = [0u8; 16];
//! let nonce = b"01234567";
//! let tag   = Umac64::compute(&key, b"hello", nonce);
//! assert_eq!(tag.len(), 8);
//! ```

use crate::cipher::{Aes128, BlockCipher};

// ---------------------------------------------------------------------------
//  Constants
// ---------------------------------------------------------------------------

/// 2³⁶ − 5, the prime used by L3-HASH (RFC 4418 §5.4).
const P36: u64 = 0xf_fffffffb;

/// 2⁶⁴ − 59, the prime used by POLY-64 (RFC 4418 §5.3).
const P64: u64 = 0xffff_ffff_ffff_ffc5;

/// 2¹²⁸ − 159, the prime used by POLY-128 (RFC 4418 §5.3).
const P128: u128 = 0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ff61;

/// Threshold below which POLY-64 inputs are processed directly: 2⁶⁴ − 2³².
const MAXWR_64: u64 = 0xffff_ffff_0000_0000;

/// Threshold below which POLY-128 inputs are processed directly: 2¹²⁸ − 2⁹⁶.
const MAXWR_128: u128 = 0xffff_ffff_0000_0000_0000_0000_0000_0000;

/// `2^wordbits − p` for each prime: the additive correction used when an
/// input is larger than `maxwordrange`.
const OFFSET_64: u64 = 59;
const OFFSET_128: u128 = 159;

/// Per-chunk key/data window length (bytes).
const L1_KEY_LEN: usize = 1024;

/// L1-output bytes after which POLY-64 transitions to POLY-128: 2¹⁷ bytes
/// = 16384 NH outputs.
const POLY64_OUTPUTS: u32 = 1 << 14;

/// Per-32-bit-word mask applied to the L2 key (RFC 4418 §5.3): the high 7
/// bits of every 32-bit half are forced to zero so that intermediate POLY
/// products fit within 64-/128-bit accumulators.
const MASK64: u64 = 0x01ff_ffff_01ff_ffff;
const MASK128: u128 = 0x01ff_ffff_01ff_ffff_01ff_ffff_01ff_ffff;

// ---------------------------------------------------------------------------
//  Key derivation
// ---------------------------------------------------------------------------

/// RFC 4418 §3.2.1 KDF: AES-128 in counter mode with `(index, counter)` as
/// the input block. `index` selects the keystream (0 = PDF, 1 = L1, 2 = L2,
/// 3 = L3-K1, 4 = L3-K2); `out` receives `out.len()` bytes of keystream.
fn kdf(aes: &Aes128, index: u64, out: &mut [u8]) {
    let mut block = [0u8; 16];
    block[0..8].copy_from_slice(&index.to_be_bytes());
    let n = out.len().div_ceil(16);
    let mut written = 0;
    for i in 1..=n as u64 {
        block[8..16].copy_from_slice(&i.to_be_bytes());
        let mut ct = block;
        aes.encrypt_block(&mut ct);
        let take = (out.len() - written).min(16);
        out[written..written + take].copy_from_slice(&ct[..take]);
        written += take;
    }
}

// ---------------------------------------------------------------------------
//  NH and L1-HASH
// ---------------------------------------------------------------------------

/// NH (RFC 4418 §5.2.2). `data` must be a multiple of 32 bytes; `key` is the
/// same length or longer. Returns the 64-bit NH value (before the L1 length
/// adjustment).
///
/// Message words are read little-endian (per the ENDIAN-SWAP step of L1-HASH);
/// key words are read big-endian (`str2uint` semantics, no ENDIAN-SWAP).
fn nh(key: &[u8], data: &[u8]) -> u64 {
    let mut y: u64 = 0;
    let mut i = 0;
    while i < data.len() {
        // Each 32-byte step processes 8 32-bit words; pairs are (j, j+4).
        let mut sums = [0u32; 8];
        for (j, slot) in sums.iter_mut().enumerate() {
            let off = i + j * 4;
            let m = u32::from_le_bytes([data[off], data[off + 1], data[off + 2], data[off + 3]]);
            let k = u32::from_be_bytes([key[off], key[off + 1], key[off + 2], key[off + 3]]);
            *slot = m.wrapping_add(k);
        }
        for j in 0..4 {
            y = y.wrapping_add((sums[j] as u64).wrapping_mul(sums[j + 4] as u64));
        }
        i += 32;
    }
    y
}

/// Computes NH on a chunk (zero-padded to the next 32-byte boundary) and adds
/// the original bit length, as RFC 4418 §5.2.1 prescribes.
fn l1_chunk(key: &[u8; L1_KEY_LEN], data: &[u8], bit_length: u64) -> u64 {
    debug_assert!(data.len() <= L1_KEY_LEN);
    let pad_len = data.len().div_ceil(32) * 32;
    // For an empty chunk (used when the entire message is empty) we still
    // need one 32-byte zero block.
    let pad_len = pad_len.max(32);
    let mut buf = [0u8; L1_KEY_LEN];
    buf[..data.len()].copy_from_slice(data);
    let nh_out = nh(&key[..pad_len], &buf[..pad_len]);
    nh_out.wrapping_add(bit_length)
}

// ---------------------------------------------------------------------------
//  Modular arithmetic for POLY
// ---------------------------------------------------------------------------

#[inline]
fn add_mod_p64(a: u64, b: u64) -> u64 {
    let (s, carry) = a.overflowing_add(b);
    let s = if carry { s.wrapping_add(OFFSET_64) } else { s };
    if s >= P64 { s - P64 } else { s }
}

#[inline]
fn add_mod_p128(a: u128, b: u128) -> u128 {
    let (s, carry) = a.overflowing_add(b);
    let s = if carry { s.wrapping_add(OFFSET_128) } else { s };
    if s >= P128 { s - P128 } else { s }
}

/// `(a * b) mod (2⁶⁴ − 59)`. Uses the identity `2⁶⁴ ≡ 59 (mod p)` to fold
/// the high half of the 128-bit product back into 64 bits.
fn mul_mod_p64(a: u64, b: u64) -> u64 {
    let prod = (a as u128) * (b as u128);
    let lo = prod as u64;
    let hi = (prod >> 64) as u64;
    // First reduction: prod ≡ lo + 59·hi (mod p64).
    let lift1 = (lo as u128) + (OFFSET_64 as u128) * (hi as u128); // < 60·2⁶⁴
    let lift1_lo = lift1 as u64;
    let lift1_hi = (lift1 >> 64) as u64; // < 60
    // Second reduction: lift1 ≡ lift1_lo + 59·lift1_hi (mod p64).
    let (r, carry) = lift1_lo.overflowing_add(OFFSET_64.wrapping_mul(lift1_hi));
    let r = if carry { r.wrapping_add(OFFSET_64) } else { r };
    if r >= P64 { r - P64 } else { r }
}

/// `(a * b) mod (2¹²⁸ − 159)`. Schoolbook-multiplies into four 64-bit limbs,
/// then collapses with `2¹²⁸ ≡ 159 (mod p)`.
fn mul_mod_p128(a: u128, b: u128) -> u128 {
    let a_lo = a as u64;
    let a_hi = (a >> 64) as u64;
    let b_lo = b as u64;
    let b_hi = (b >> 64) as u64;

    let ll = (a_lo as u128) * (b_lo as u128);
    let lh = (a_lo as u128) * (b_hi as u128);
    let hl = (a_hi as u128) * (b_lo as u128);
    let hh = (a_hi as u128) * (b_hi as u128);

    let ll_lo = ll as u64;
    let ll_hi = (ll >> 64) as u64;
    let lh_lo = lh as u64;
    let lh_hi = (lh >> 64) as u64;
    let hl_lo = hl as u64;
    let hl_hi = (hl >> 64) as u64;
    let hh_lo = hh as u64;
    let hh_hi = (hh >> 64) as u64;

    // a·b = w3·2¹⁹² + w2·2¹²⁸ + w1·2⁶⁴ + w0
    let w0 = ll_lo;
    let s1 = (ll_hi as u128) + (lh_lo as u128) + (hl_lo as u128);
    let w1 = s1 as u64;
    let c1 = (s1 >> 64) as u64;
    let s2 = (lh_hi as u128) + (hl_hi as u128) + (hh_lo as u128) + (c1 as u128);
    let w2 = s2 as u64;
    let c2 = (s2 >> 64) as u64;
    let w3 = hh_hi.wrapping_add(c2);

    // a·b ≡ (w0 + 159·w2) + (w1 + 159·w3)·2⁶⁴  (mod p128).
    let r0 = (w0 as u128) + OFFSET_128 * (w2 as u128); // < 2⁷²
    let r1 = (w1 as u128) + OFFSET_128 * (w3 as u128); // < 2⁷²

    let r0_lo = r0 as u64;
    let r0_hi = (r0 >> 64) as u64; // < 2⁸
    let r1_lo = r1 as u64;
    let r1_hi = (r1 >> 64) as u64; // < 2⁸

    let mid = (r0_hi as u128) + (r1_lo as u128); // < 2⁶⁵
    let mid_lo = mid as u64;
    let mid_hi = (mid >> 64) as u64; // 0 or 1

    let low128: u128 = (r0_lo as u128) | ((mid_lo as u128) << 64);
    let high = (mid_hi as u128) + (r1_hi as u128); // < 2⁹
    let extra = high * OFFSET_128;
    let (sum, overflow) = low128.overflowing_add(extra);
    let sum = if overflow {
        sum.wrapping_add(OFFSET_128)
    } else {
        sum
    };
    if sum >= P128 { sum - P128 } else { sum }
}

/// POLY-64 step (RFC 4418 §5.3.2). Inputs at or above `MAXWR_64` are split
/// into a marker followed by `m − OFFSET_64` so the input always reduces
/// modulo `p64`.
#[inline]
fn poly64_step(k64: u64, acc: &mut u64, m: u64) {
    if m >= MAXWR_64 {
        *acc = add_mod_p64(mul_mod_p64(k64, *acc), P64 - 1);
        *acc = add_mod_p64(mul_mod_p64(k64, *acc), m - OFFSET_64);
    } else {
        *acc = add_mod_p64(mul_mod_p64(k64, *acc), m);
    }
}

/// POLY-128 step (RFC 4418 §5.3.2). Same shape as [`poly64_step`].
#[inline]
fn poly128_step(k128: u128, acc: &mut u128, m: u128) {
    if m >= MAXWR_128 {
        *acc = add_mod_p128(mul_mod_p128(k128, *acc), P128 - 1);
        *acc = add_mod_p128(mul_mod_p128(k128, *acc), m - OFFSET_128);
    } else {
        *acc = add_mod_p128(mul_mod_p128(k128, *acc), m);
    }
}

// ---------------------------------------------------------------------------
//  L3-HASH
// ---------------------------------------------------------------------------

/// L3-HASH (RFC 4418 §5.4.1). `k1_reduced` is the 8-word K1 already reduced
/// modulo `p36`, `k2` is the 4-byte K2 (raw, XORed at the end), and `b` is
/// the 16-byte input from L2-HASH.
fn l3_hash(k1_reduced: &[u64; 8], k2: &[u8; 4], b: &[u8; 16]) -> [u8; 4] {
    // m_i < 2¹⁶, k_i < 2³⁶ → m·k < 2⁵². Sum of 8 < 2⁵⁵, fits in u64.
    let mut sum: u64 = 0;
    for i in 0..8 {
        let m_i = u16::from_be_bytes([b[i * 2], b[i * 2 + 1]]) as u64;
        sum = sum.wrapping_add(m_i.wrapping_mul(k1_reduced[i]));
    }
    let y = (sum % P36) as u32; // y mod 2³² takes the low 32 bits
    let k2_u32 = u32::from_be_bytes(*k2);
    (y ^ k2_u32).to_be_bytes()
}

// ---------------------------------------------------------------------------
//  PDF
// ---------------------------------------------------------------------------

/// PDF for an 8-byte tag (RFC 4418 §3.3.1). `nonce.len()` must be 1..=16.
///
/// The spec specifies right-padding (`Nonce || zeroes(...)`), so the nonce
/// occupies the LOW bytes of the AES input block and trailing bytes are
/// zero; the index XOR therefore touches the last byte of the nonce, not
/// of the full block.
fn pdf_8(aes_pdf: &Aes128, nonce: &[u8]) -> [u8; 8] {
    debug_assert!(!nonce.is_empty() && nonce.len() <= 16);
    let mut t = [0u8; 16];
    // index = str2uint(nonce) mod 2 → low bit of last nonce byte
    let last = nonce.len() - 1;
    let index = nonce[last] & 1;
    t[..nonce.len()].copy_from_slice(nonce);
    t[last] ^= index; // clears the low bit of the last nonce byte
    aes_pdf.encrypt_block(&mut t);
    let start = (index as usize) * 8;
    let mut out = [0u8; 8];
    out.copy_from_slice(&t[start..start + 8]);
    out
}

/// PDF for a 16-byte tag (RFC 4418 §3.3.1). `nonce.len()` must be 1..=16.
/// Unlike [`pdf_8`], no index masking is needed: each nonce already maps to
/// a unique AES block.
fn pdf_16(aes_pdf: &Aes128, nonce: &[u8]) -> [u8; 16] {
    debug_assert!(!nonce.is_empty() && nonce.len() <= 16);
    let mut t = [0u8; 16];
    t[..nonce.len()].copy_from_slice(nonce);
    aes_pdf.encrypt_block(&mut t);
    t
}

// ---------------------------------------------------------------------------
//  Per-iteration L2 state machine
// ---------------------------------------------------------------------------

/// One UHASH iteration's POLY state. The first L1 output is buffered (not
/// fed into POLY yet) so that, on a single-chunk message, [`finalize_l2`]
/// can shortcut to `B = zeros(8) || A` per RFC 4418 §5.1.
#[derive(Clone)]
struct UmacIter {
    poly64_acc: u64,
    poly128_acc: u128,
    /// Number of L1 outputs absorbed (including the buffered first one).
    l1_outputs: u32,
    /// The first L1 output, before POLY starts.
    pending_first: u64,
    /// `true` once we cross the 2¹⁷-byte threshold and switch to POLY-128.
    transitioned: bool,
    /// Half of a 16-byte POLY-128 input. `half_pending` tracks whether the
    /// low 8 bytes are populated.
    poly128_half: [u8; 8],
    half_pending: bool,
}

impl UmacIter {
    fn new() -> Self {
        Self {
            poly64_acc: 1, // POLY initial value
            poly128_acc: 1,
            l1_outputs: 0,
            pending_first: 0,
            transitioned: false,
            poly128_half: [0; 8],
            half_pending: false,
        }
    }

    /// Absorbs one L1 output (8-byte `nh + bitlen` value).
    fn absorb(&mut self, k64: u64, k128: u128, l1_out: u64) {
        self.l1_outputs = self.l1_outputs.saturating_add(1);
        if self.l1_outputs == 1 {
            // Hold the first output back; we may skip POLY altogether.
            self.pending_first = l1_out;
            return;
        }
        if self.l1_outputs == 2 {
            // The second output triggers L2-HASH; replay the first one.
            poly64_step(k64, &mut self.poly64_acc, self.pending_first);
        }

        if self.l1_outputs <= POLY64_OUTPUTS {
            poly64_step(k64, &mut self.poly64_acc, l1_out);
        } else {
            // POLY-128 mode. Lazily transition on the first input past the
            // boundary, then pair subsequent L1 outputs into 16-byte words.
            if !self.transitioned {
                let init = self.poly64_acc as u128; // uint2str(y, 16) is BE-padded
                poly128_step(k128, &mut self.poly128_acc, init);
                self.transitioned = true;
            }
            let bytes = l1_out.to_be_bytes();
            if !self.half_pending {
                self.poly128_half.copy_from_slice(&bytes);
                self.half_pending = true;
            } else {
                let mut paired = [0u8; 16];
                paired[..8].copy_from_slice(&self.poly128_half);
                paired[8..].copy_from_slice(&bytes);
                let m = u128::from_be_bytes(paired);
                poly128_step(k128, &mut self.poly128_acc, m);
                self.half_pending = false;
            }
        }
    }

    /// Produces the 16-byte L2-HASH output for this iteration.
    fn finalize(&mut self, k128: u128) -> [u8; 16] {
        // No L1 outputs at all: empty message edge case handled upstream by
        // forcing one zero-chunk L1 output; this branch is defensive.
        if self.l1_outputs == 0 {
            return [0u8; 16];
        }
        if self.l1_outputs == 1 {
            // B = zeros(8) || A
            let mut b = [0u8; 16];
            b[8..].copy_from_slice(&self.pending_first.to_be_bytes());
            return b;
        }
        if !self.transitioned {
            let mut b = [0u8; 16];
            b[8..].copy_from_slice(&self.poly64_acc.to_be_bytes());
            return b;
        }
        // POLY-128 close-out: append 0x80 marker, zero-pad to 16-byte
        // boundary (RFC 4418 §5.3.1). Always produces exactly one more
        // POLY-128 input.
        let mut final_block = [0u8; 16];
        if self.half_pending {
            final_block[..8].copy_from_slice(&self.poly128_half);
            final_block[8] = 0x80;
        } else {
            final_block[0] = 0x80;
        }
        let m = u128::from_be_bytes(final_block);
        poly128_step(k128, &mut self.poly128_acc, m);
        self.poly128_acc.to_be_bytes()
    }
}

// ---------------------------------------------------------------------------
//  Public types: Umac64, Umac128
// ---------------------------------------------------------------------------

/// Maximum number of UHASH iterations supported (UMAC-128 uses 4).
const MAX_ITER: usize = 4;

/// L1-key buffer size for the largest supported iteration count, in bytes.
const MAX_L1_KEY_BUF: usize = L1_KEY_LEN + (MAX_ITER - 1) * 16;

/// Shared streaming state, parameterized by the iteration count `ITER`.
/// `ITER = 2` produces an 8-byte tag; `ITER = 4` produces a 16-byte tag.
#[derive(Clone)]
struct UmacInner<const ITER: usize> {
    /// AES cipher pre-keyed with K' = KDF(K, 0).
    pdf_aes: Aes128,
    /// L1 keys, packed contiguously: iteration `i` uses
    /// `l1_key[i*16 .. i*16 + 1024]`.
    l1_key: [u8; MAX_L1_KEY_BUF],
    /// L2 keys, post-mask, per iteration.
    l2_k64: [u64; ITER],
    l2_k128: [u128; ITER],
    /// L3-K1, reduced modulo `p36` per word.
    l3_k1: [[u64; 8]; ITER],
    /// L3-K2, raw (XORed at the very end).
    l3_k2: [[u8; 4]; ITER],

    /// 1024-byte chunk buffer for streaming.
    chunk: [u8; L1_KEY_LEN],
    chunk_off: usize,
    /// Total bytes absorbed via [`update`](UmacInner::update).
    total_bytes: u64,
    /// Per-iteration POLY state.
    iter_state: [UmacIter; ITER],
}

impl<const ITER: usize> UmacInner<ITER> {
    fn new(key: &[u8; 16]) -> Self {
        let master = Aes128::new(key);

        // K' = KDF(K, 0, 16)
        let mut pdf_key = [0u8; 16];
        kdf(&master, 0, &mut pdf_key);
        let pdf_aes = Aes128::new(&pdf_key);

        // L1-Key: 1024 + (ITER-1)*16 bytes
        let l1_len = L1_KEY_LEN + (ITER - 1) * 16;
        let mut l1_key = [0u8; MAX_L1_KEY_BUF];
        kdf(&master, 1, &mut l1_key[..l1_len]);

        // L2-Key: 24·ITER bytes; split into k64 || k128 per iteration.
        let mut l2_buf = [0u8; 24 * MAX_ITER];
        kdf(&master, 2, &mut l2_buf[..24 * ITER]);
        let mut l2_k64 = [0u64; ITER];
        let mut l2_k128 = [0u128; ITER];
        for i in 0..ITER {
            let base = i * 24;
            let k64_bytes: [u8; 8] = l2_buf[base..base + 8].try_into().unwrap();
            let k128_bytes: [u8; 16] = l2_buf[base + 8..base + 24].try_into().unwrap();
            l2_k64[i] = u64::from_be_bytes(k64_bytes) & MASK64;
            l2_k128[i] = u128::from_be_bytes(k128_bytes) & MASK128;
        }

        // L3-Key1: 64·ITER bytes; 8 64-bit words per iteration, pre-reduced.
        let mut l3k1_buf = [0u8; 64 * MAX_ITER];
        kdf(&master, 3, &mut l3k1_buf[..64 * ITER]);
        let mut l3_k1 = [[0u64; 8]; ITER];
        for (i, words) in l3_k1.iter_mut().enumerate() {
            for (j, slot) in words.iter_mut().enumerate() {
                let base = i * 64 + j * 8;
                let raw: [u8; 8] = l3k1_buf[base..base + 8].try_into().unwrap();
                *slot = u64::from_be_bytes(raw) % P36;
            }
        }

        // L3-Key2: 4·ITER bytes; one 32-bit word per iteration.
        let mut l3k2_buf = [0u8; 4 * MAX_ITER];
        kdf(&master, 4, &mut l3k2_buf[..4 * ITER]);
        let mut l3_k2 = [[0u8; 4]; ITER];
        for i in 0..ITER {
            l3_k2[i].copy_from_slice(&l3k2_buf[i * 4..i * 4 + 4]);
        }

        Self {
            pdf_aes,
            l1_key,
            l2_k64,
            l2_k128,
            l3_k1,
            l3_k2,
            chunk: [0u8; L1_KEY_LEN],
            chunk_off: 0,
            total_bytes: 0,
            iter_state: core::array::from_fn(|_| UmacIter::new()),
        }
    }

    /// Absorbs the chunk currently held in `self.chunk[..1024]` as one full
    /// L1 chunk (bit length = 8192).
    fn process_full_chunk(&mut self) {
        for i in 0..ITER {
            let key_slice = &self.l1_key[i * 16..i * 16 + L1_KEY_LEN];
            let key_arr: &[u8; L1_KEY_LEN] = key_slice.try_into().unwrap();
            let l1_out = l1_chunk(key_arr, &self.chunk, (L1_KEY_LEN as u64) * 8);
            self.iter_state[i].absorb(self.l2_k64[i], self.l2_k128[i], l1_out);
        }
    }

    fn update(&mut self, mut data: &[u8]) {
        while !data.is_empty() {
            let take = (L1_KEY_LEN - self.chunk_off).min(data.len());
            self.chunk[self.chunk_off..self.chunk_off + take].copy_from_slice(&data[..take]);
            self.chunk_off += take;
            self.total_bytes = self.total_bytes.wrapping_add(take as u64);
            data = &data[take..];
            if self.chunk_off == L1_KEY_LEN {
                self.process_full_chunk();
                self.chunk_off = 0;
            }
        }
    }

    /// Finalizes all L1 chunks, runs L2-HASH and L3-HASH per iteration, and
    /// writes the concatenated 4-byte tags into `out` (length = 4·ITER).
    fn finalize_uhash(&mut self, out: &mut [u8]) {
        debug_assert_eq!(out.len(), 4 * ITER);

        if self.total_bytes == 0 {
            // RFC 4418: t = max(ceil(0/8192), 1) = 1 → process one empty
            // chunk so each iter has an L1 output to fold.
            for i in 0..ITER {
                let key_slice = &self.l1_key[i * 16..i * 16 + L1_KEY_LEN];
                let key_arr: &[u8; L1_KEY_LEN] = key_slice.try_into().unwrap();
                let l1_out = l1_chunk(key_arr, &[], 0);
                self.iter_state[i].absorb(self.l2_k64[i], self.l2_k128[i], l1_out);
            }
        } else if self.chunk_off > 0 {
            // Partial final chunk; its bit length is the real (pre-pad)
            // length, not 8192.
            let bit_len = (self.chunk_off as u64) * 8;
            for i in 0..ITER {
                let key_slice = &self.l1_key[i * 16..i * 16 + L1_KEY_LEN];
                let key_arr: &[u8; L1_KEY_LEN] = key_slice.try_into().unwrap();
                let l1_out = l1_chunk(key_arr, &self.chunk[..self.chunk_off], bit_len);
                self.iter_state[i].absorb(self.l2_k64[i], self.l2_k128[i], l1_out);
            }
        }
        // else: total_bytes is a positive multiple of 1024; every chunk was
        // a full chunk already absorbed by process_full_chunk with bit
        // length 8192, which is the correct value for both the last full
        // chunk and any non-final chunk.

        for i in 0..ITER {
            let b = self.iter_state[i].finalize(self.l2_k128[i]);
            let c = l3_hash(&self.l3_k1[i], &self.l3_k2[i], &b);
            out[i * 4..i * 4 + 4].copy_from_slice(&c);
        }
    }
}

impl<const ITER: usize> Drop for UmacInner<ITER> {
    fn drop(&mut self) {
        // Best-effort wipe of all key-derived state. `pdf_aes` zeros its
        // expanded round keys in its own `Drop`. The remaining fields hold
        // L1/L2/L3 keys derived from the master via the AES-CTR KDF, the
        // streaming chunk buffer (plaintext), and the per-iteration POLY
        // accumulators — all of which leak information about the secret keys
        // or the message. Use the same `black_box` barrier the rest of the
        // crate relies on so LLVM cannot elide the writes as dead stores.
        for b in self.l1_key.iter_mut() {
            *b = 0;
        }
        for w in self.l2_k64.iter_mut() {
            *w = 0;
        }
        for w in self.l2_k128.iter_mut() {
            *w = 0;
        }
        for row in self.l3_k1.iter_mut() {
            for w in row.iter_mut() {
                *w = 0;
            }
        }
        for row in self.l3_k2.iter_mut() {
            for b in row.iter_mut() {
                *b = 0;
            }
        }
        for b in self.chunk.iter_mut() {
            *b = 0;
        }
        self.chunk_off = 0;
        self.total_bytes = 0;
        for it in self.iter_state.iter_mut() {
            it.poly64_acc = 0;
            it.poly128_acc = 0;
            it.l1_outputs = 0;
            it.pending_first = 0;
            it.transitioned = false;
            it.poly128_half = [0; 8];
            it.half_pending = false;
            let _ = core::hint::black_box(&it.poly64_acc);
            let _ = core::hint::black_box(&it.poly128_acc);
            let _ = core::hint::black_box(&it.pending_first);
            let _ = core::hint::black_box(&it.poly128_half);
        }
        let _ = core::hint::black_box(&self.l1_key);
        let _ = core::hint::black_box(&self.l2_k64);
        let _ = core::hint::black_box(&self.l2_k128);
        let _ = core::hint::black_box(&self.l3_k1);
        let _ = core::hint::black_box(&self.l3_k2);
        let _ = core::hint::black_box(&self.chunk);
    }
}

/// UMAC-AES-128 with an 8-byte tag (RFC 4418, 2 internal iterations).
///
/// Construct with [`Umac64::new`], absorb input via [`Umac64::update`], and
/// commit with [`Umac64::finalize`] given the per-message nonce. The nonce
/// must be 1 to 16 bytes and must not repeat under the same key (otherwise
/// tag forgeries become trivial).
#[derive(Clone)]
pub struct Umac64 {
    inner: UmacInner<2>,
}

impl Umac64 {
    /// Creates a new UMAC-64 state under the given 128-bit AES key.
    pub fn new(key: &[u8; 16]) -> Self {
        Self {
            inner: UmacInner::new(key),
        }
    }

    /// Absorbs `data` into the streaming state.
    pub fn update(&mut self, data: &[u8]) {
        self.inner.update(data);
    }

    /// Finalizes the MAC and returns the 8-byte tag. `nonce.len()` must be
    /// 1..=16 and must be unique per key.
    pub fn finalize(mut self, nonce: &[u8]) -> [u8; 8] {
        assert!(
            !nonce.is_empty() && nonce.len() <= 16,
            "UMAC nonce length must be 1..=16 bytes"
        );
        let mut tag = [0u8; 8];
        self.inner.finalize_uhash(&mut tag);
        let pad = pdf_8(&self.inner.pdf_aes, nonce);
        for i in 0..8 {
            tag[i] ^= pad[i];
        }
        tag
    }

    /// One-shot: compute the 8-byte UMAC of `data` with `nonce` under `key`.
    pub fn compute(key: &[u8; 16], data: &[u8], nonce: &[u8]) -> [u8; 8] {
        let mut s = Self::new(key);
        s.update(data);
        s.finalize(nonce)
    }
}

/// UMAC-AES-128 with a 16-byte tag (RFC 4418, 4 internal iterations).
///
/// See [`Umac64`] for the construction details; the only difference is the
/// number of UHASH iterations (and therefore the tag length).
#[derive(Clone)]
pub struct Umac128 {
    inner: UmacInner<4>,
}

impl Umac128 {
    /// Creates a new UMAC-128 state under the given 128-bit AES key.
    pub fn new(key: &[u8; 16]) -> Self {
        Self {
            inner: UmacInner::new(key),
        }
    }

    /// Absorbs `data` into the streaming state.
    pub fn update(&mut self, data: &[u8]) {
        self.inner.update(data);
    }

    /// Finalizes the MAC and returns the 16-byte tag. `nonce.len()` must be
    /// 1..=16 and must be unique per key.
    pub fn finalize(mut self, nonce: &[u8]) -> [u8; 16] {
        assert!(
            !nonce.is_empty() && nonce.len() <= 16,
            "UMAC nonce length must be 1..=16 bytes"
        );
        let mut tag = [0u8; 16];
        self.inner.finalize_uhash(&mut tag);
        let pad = pdf_16(&self.inner.pdf_aes, nonce);
        for i in 0..16 {
            tag[i] ^= pad[i];
        }
        tag
    }

    /// One-shot: compute the 16-byte UMAC of `data` with `nonce` under `key`.
    pub fn compute(key: &[u8; 16], data: &[u8], nonce: &[u8]) -> [u8; 16] {
        let mut s = Self::new(key);
        s.update(data);
        s.finalize(nonce)
    }
}

// ---------------------------------------------------------------------------
//  Tests
// ---------------------------------------------------------------------------

#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;
    extern crate alloc;

    fn from_hex<const N: usize>(s: &str) -> [u8; N] {
        let bytes = s.as_bytes();
        assert_eq!(bytes.len(), 2 * N, "hex string has wrong length");
        let mut out = [0u8; N];
        for i in 0..N {
            let hi = (bytes[2 * i] as char).to_digit(16).expect("hex") as u8;
            let lo = (bytes[2 * i + 1] as char).to_digit(16).expect("hex") as u8;
            out[i] = (hi << 4) | lo;
        }
        out
    }

    // RFC 4418 §A.1 test vectors. The key is the ASCII string
    // "abcdefghijklmnop", the nonce is the ASCII string "bcdefghi". The
    // eight messages are an empty string, the byte string `'a'` repeated
    // 3, 2^10, 2^15, 2^20, and 2^25 times, and `'abc'` repeated 1 and 500
    // times.

    const RFC_KEY: &[u8; 16] = b"abcdefghijklmnop";
    const RFC_NONCE: &[u8; 8] = b"bcdefghi";

    /// One RFC 4418 §A.1 test message: `pattern` repeated `reps` times.
    struct RfcMsg {
        pattern: &'static [u8],
        reps: usize,
    }

    const RFC_VECTORS: [RfcMsg; 8] = [
        RfcMsg {
            pattern: b"",
            reps: 0,
        },
        RfcMsg {
            pattern: b"a",
            reps: 3,
        },
        RfcMsg {
            pattern: b"a",
            reps: 1 << 10,
        },
        RfcMsg {
            pattern: b"a",
            reps: 1 << 15,
        },
        RfcMsg {
            pattern: b"a",
            reps: 1 << 20,
        },
        RfcMsg {
            pattern: b"a",
            reps: 1 << 25,
        },
        RfcMsg {
            pattern: b"abc",
            reps: 1,
        },
        RfcMsg {
            pattern: b"abc",
            reps: 500,
        },
    ];

    /// Streams `pattern` repeated `reps` times into `f`, in 8-KiB chunks
    /// so the largest 32-MiB case doesn't need a single allocation.
    fn feed_pattern_repeated<F: FnMut(&[u8])>(pattern: &[u8], reps: usize, mut f: F) {
        if pattern.is_empty() || reps == 0 {
            return;
        }
        // Fill an 8-KiB buffer with whole repetitions of `pattern` so each
        // emitted chunk is byte-identical to the underlying message stream.
        let chunk_unit = 8192 / pattern.len().max(1) * pattern.len();
        let chunk_unit = chunk_unit.max(pattern.len());
        let mut buf = alloc::vec::Vec::with_capacity(chunk_unit);
        while buf.len() < chunk_unit {
            buf.extend_from_slice(pattern);
        }
        let total = pattern.len() * reps;
        let mut remaining = total;
        while remaining > 0 {
            let take = remaining.min(buf.len());
            f(&buf[..take]);
            remaining -= take;
        }
    }

    fn compute_umac64_msg(m: &RfcMsg) -> [u8; 8] {
        let mut s = Umac64::new(RFC_KEY);
        feed_pattern_repeated(m.pattern, m.reps, |b| s.update(b));
        s.finalize(RFC_NONCE)
    }

    fn compute_umac_iter_msg<const ITER: usize>(m: &RfcMsg) -> [u8; 16] {
        let mut inner = UmacInner::<ITER>::new(RFC_KEY);
        feed_pattern_repeated(m.pattern, m.reps, |b| inner.update(b));
        let mut tag = [0u8; 16];
        let n = 4 * ITER;
        inner.finalize_uhash(&mut tag[..n]);
        match n {
            8 => {
                let pad = pdf_8(&inner.pdf_aes, RFC_NONCE);
                for i in 0..8 {
                    tag[i] ^= pad[i];
                }
            }
            12 | 16 => {
                let pad = pdf_16(&inner.pdf_aes, RFC_NONCE);
                for i in 0..n {
                    tag[i] ^= pad[i];
                }
            }
            _ => unreachable!(),
        }
        tag
    }

    /// True if this vector's message size is small enough to run under
    /// slow harnesses like miri.
    fn vector_fits_miri(m: &RfcMsg) -> bool {
        m.pattern.len() * m.reps <= 1 << 16
    }

    #[test]
    fn rfc4418_umac64_vectors() {
        // RFC 4418 §A.1: expected 8-byte tags. The 'a' * 2^25 vector applies
        // the published RFC 4418 errata (17 March 2006); the original RFC
        // value was generated by code that mishandled messages longer than
        // 2^24 bytes (the POLY-64 → POLY-128 transition).
        let expected: [[u8; 8]; 8] = [
            from_hex("6E155FAD26900BE1"),
            from_hex("44B5CB542F220104"),
            from_hex("26BF2F5D60118BD9"),
            from_hex("27F8EF643B0D118D"),
            from_hex("A4477E87E9F55853"),
            from_hex("FACA46F856E9B45F"), // errata: original was 2E2DBC36860A0A5F
            from_hex("D4D7B9F6BD4FBFCF"),
            from_hex("D4CF26DDEFD5C01A"),
        ];
        for (i, m) in RFC_VECTORS.iter().enumerate() {
            if cfg!(miri) && !vector_fits_miri(m) {
                continue;
            }
            let tag = compute_umac64_msg(m);
            assert_eq!(
                tag,
                expected[i],
                "UMAC-64 vector {} mismatch ({} x{})",
                i,
                core::str::from_utf8(m.pattern).unwrap_or("?"),
                m.reps,
            );
        }
    }

    #[test]
    fn rfc4418_umac96_vectors_via_iter3() {
        // RFC 4418 §A.1: expected 12-byte tags (UMAC-96 = ITER=3). The
        // 'a' * 2^25 entry uses the published errata value (the original
        // RFC vector had a known generation bug at the POLY-128 boundary).
        let expected: [[u8; 12]; 8] = [
            from_hex("32FEDB100C79AD58F07FF764"),
            from_hex("185E4FE905CBA7BD85E4C2DC"),
            from_hex("7A54ABE04AF82D60FB298C3C"),
            from_hex("7B136BD911E4B734286EF2BE"),
            from_hex("F8ACFA3AC31CFEEA047F7B11"),
            from_hex("A621C2457C0012E64F3FDAE9"), // errata
            from_hex("883C3D4B97A61976FFCF2323"),
            from_hex("8824A260C53C66A36C9260A6"),
        ];
        for (i, m) in RFC_VECTORS.iter().enumerate() {
            if cfg!(miri) && !vector_fits_miri(m) {
                continue;
            }
            let tag = compute_umac_iter_msg::<3>(m);
            for j in 0..12 {
                assert_eq!(
                    tag[j],
                    expected[i][j],
                    "UMAC-96 vector {} mismatch at byte {} ({} x{})",
                    i,
                    j,
                    core::str::from_utf8(m.pattern).unwrap_or("?"),
                    m.reps,
                );
            }
        }
    }

    // UMAC-128 has no RFC 4418 published test vectors, so we don't assert
    // byte-exact tags here. The iter=2 and iter=3 tests above lock down the
    // entire UHASH pipeline (NH / POLY-64 / POLY-128 / L3-HASH) against the
    // spec; UMAC-128 reuses that same machinery with one extra iteration
    // and the no-mask PDF, both of which are independently exercised below.

    #[test]
    fn streaming_matches_one_shot() {
        // Split a 2.5-MiB message at random byte offsets and confirm that
        // the streamed tag matches the one-shot tag. This exercises the
        // POLY-64 → POLY-128 transition (the input crosses 2¹⁷ bytes of L1
        // output, which corresponds to ~16 MiB of message — too large for
        // a unit test — so we only test the < 2¹⁷ side here; the RFC §A.1
        // 2³⁰ tag above does cover the >2¹⁷ side).
        let key = b"streamtestkey567";
        let nonce = b"abcdef01";
        let mut data = alloc::vec::Vec::with_capacity(2_500_000);
        for i in 0..2_500_000usize {
            data.push((i as u8).wrapping_mul(31).wrapping_add(7));
        }

        let one_shot = Umac64::compute(key, &data, nonce);

        let mut state = Umac64::new(key);
        // Irregular chunk sizes to flush partial buffers in many configs.
        let splits: [usize; 9] = [1, 7, 511, 1024, 1023, 4097, 17, 600_000, 999_999];
        let mut offset = 0;
        let mut idx = 0;
        while offset < data.len() {
            let take = splits[idx % splits.len()].min(data.len() - offset);
            state.update(&data[offset..offset + take]);
            offset += take;
            idx += 1;
        }
        let streamed = state.finalize(nonce);

        assert_eq!(one_shot, streamed);
    }

    #[test]
    fn umac128_streaming_matches_one_shot() {
        let key = b"128streamkey7890";
        let nonce = b"01234567";
        let mut data = alloc::vec::Vec::with_capacity(50_000);
        for i in 0..50_000usize {
            data.push((i as u8).wrapping_mul(13).wrapping_add(5));
        }
        let one_shot = Umac128::compute(key, &data, nonce);
        let mut state = Umac128::new(key);
        let splits: [usize; 6] = [1, 31, 1024, 17, 10_000, 4097];
        let mut offset = 0;
        let mut idx = 0;
        while offset < data.len() {
            let take = splits[idx % splits.len()].min(data.len() - offset);
            state.update(&data[offset..offset + take]);
            offset += take;
            idx += 1;
        }
        let streamed = state.finalize(nonce);
        assert_eq!(one_shot, streamed);
    }
}