1#![allow(deprecated)]
2#[allow(deprecated)]
29use aes::cipher::{Array, BlockCipherDecrypt, KeyInit};
30use aes::{Aes128, Aes256};
31use polyval::{Polyval, universal_hash::UniversalHash};
32use sha2::{Digest, Sha256};
33
34use crate::common::{
35 BLOCK_LENGTH, Direction, Error, absorb, lfsr_next_128, xor_block, xor_blocks_3,
36};
37use crate::hctr2::AesCipher;
38use crate::hctr2fp::{
39 bits_per_digit, decode_base_radix, encode_base_radix, first_block_length, is_power_of_two,
40};
41
42pub struct Hctr3Fp<Aes: AesCipher, const RADIX: u16> {
44 ks_enc: Aes,
45 ks_dec: Aes::Dec,
46 ke_enc: Aes,
47 h: [u8; BLOCK_LENGTH],
48 l: [u8; BLOCK_LENGTH],
49}
50
// The aliases below deliberately use `Name_KeyBits_Alphabet` spelling, which
// is why `non_camel_case_types` is allowed on each of them.

/// HCTR3-FP with AES-128 over decimal digits (`RADIX = 10`).
#[allow(non_camel_case_types)]
pub type Hctr3Fp_128_Decimal = Hctr3Fp<Aes128, 10>;

/// HCTR3-FP with AES-256 over decimal digits (`RADIX = 10`).
#[allow(non_camel_case_types)]
pub type Hctr3Fp_256_Decimal = Hctr3Fp<Aes256, 10>;

/// HCTR3-FP with AES-128 over hexadecimal digits (`RADIX = 16`).
#[allow(non_camel_case_types)]
pub type Hctr3Fp_128_Hex = Hctr3Fp<Aes128, 16>;

/// HCTR3-FP with AES-256 over hexadecimal digits (`RADIX = 16`).
#[allow(non_camel_case_types)]
pub type Hctr3Fp_256_Hex = Hctr3Fp<Aes256, 16>;

/// HCTR3-FP with AES-128 over base-64 digit values (`RADIX = 64`); digits are
/// the values `0..64`, not encoded characters.
#[allow(non_camel_case_types)]
pub type Hctr3Fp_128_Base64 = Hctr3Fp<Aes128, 64>;

/// HCTR3-FP with AES-256 over base-64 digit values (`RADIX = 64`); digits are
/// the values `0..64`, not encoded characters.
#[allow(non_camel_case_types)]
pub type Hctr3Fp_256_Base64 = Hctr3Fp<Aes256, 64>;
74
75impl<Aes: AesCipher, const RADIX: u16> Hctr3Fp<Aes, RADIX> {
76 pub const FIRST_BLOCK_LENGTH: usize = first_block_length(RADIX);
78
79 pub const MIN_MESSAGE_LENGTH: usize = Self::FIRST_BLOCK_LENGTH;
81
82 pub const BLOCK_LENGTH: usize = BLOCK_LENGTH;
84
85 pub fn new(key: &[u8]) -> Self {
89 debug_assert_eq!(key.len(), Aes::KEY_LEN);
90
91 let ks_enc = Aes::new(Array::from_slice(key));
92 let ks_dec = Aes::new_dec(key);
93
94 let mut ke_block0 = Array::clone_from_slice(&[0u8; 16]);
96 ks_enc.encrypt_block(&mut ke_block0);
97
98 let ke_key: Vec<u8> = if Aes::KEY_LEN <= 16 {
99 ke_block0[..Aes::KEY_LEN].to_vec()
100 } else {
101 let mut ke_block1 = Array::clone_from_slice(&[0x01u8; 16]);
103 ks_enc.encrypt_block(&mut ke_block1);
104 let mut ke = vec![0u8; Aes::KEY_LEN];
105 ke[..16].copy_from_slice(ke_block0.as_slice());
106 ke[16..].copy_from_slice(&ke_block1[..(Aes::KEY_LEN - 16)]);
107 ke
108 };
109
110 let ke_enc = Aes::new(Array::from_slice(&ke_key));
111
112 let mut h_block = Array::clone_from_slice(&[0u8; 16]);
113 let mut l_block = Array::clone_from_slice(&{
114 let mut b = [0u8; 16];
115 b[15] = 1;
116 b
117 });
118 ke_enc.encrypt_block(&mut h_block);
119 ke_enc.encrypt_block(&mut l_block);
120
121 let h: [u8; 16] = h_block.as_slice().try_into().unwrap();
122 let l: [u8; 16] = l_block.as_slice().try_into().unwrap();
123 Self {
124 ks_enc,
125 ks_dec,
126 ke_enc,
127 h,
128 l,
129 }
130 }
131
132 pub fn encrypt(
136 &self,
137 plaintext: &[u8],
138 tweak: &[u8],
139 ciphertext: &mut [u8],
140 ) -> Result<(), Error> {
141 self.hctr3fp(plaintext, tweak, ciphertext, Direction::Encrypt)
142 }
143
144 pub fn decrypt(
146 &self,
147 ciphertext: &[u8],
148 tweak: &[u8],
149 plaintext: &mut [u8],
150 ) -> Result<(), Error> {
151 self.hctr3fp(ciphertext, tweak, plaintext, Direction::Decrypt)
152 }
153
154 fn hctr3fp(
155 &self,
156 src: &[u8],
157 tweak: &[u8],
158 dst: &mut [u8],
159 direction: Direction,
160 ) -> Result<(), Error> {
161 debug_assert_eq!(dst.len(), src.len());
162
163 let first_block_len = Self::FIRST_BLOCK_LENGTH;
164 if src.len() < first_block_len {
165 return Err(Error::InputTooShort);
166 }
167
168 for &digit in src {
169 if digit >= RADIX as u8 {
170 return Err(Error::InvalidDigit);
171 }
172 }
173
174 let first_part = &src[..first_block_len];
175 let tail = &src[first_block_len..];
176
177 let mut hasher = Sha256::new();
178 hasher.update(tweak);
179 let hash_out = hasher.finalize();
180 let mut t = [0u8; BLOCK_LENGTH];
181 t.copy_from_slice(&hash_out[..BLOCK_LENGTH]);
182
183 let mut block_bytes = [0u8; BLOCK_LENGTH];
184 let tweak_len_bits = tweak.len() * 8;
185 let tweak_len_bytes: u128 = if tail.len() % BLOCK_LENGTH == 0 {
186 (2 * tweak_len_bits + 2) as u128
187 } else {
188 (2 * tweak_len_bits + 3) as u128
189 };
190 block_bytes.copy_from_slice(&tweak_len_bytes.to_le_bytes());
191
192 let mut poly = Polyval::new(Array::from_slice(&self.h));
193 poly.update(&[Array::clone_from_slice(&block_bytes)]);
194
195 poly.update(&[Array::clone_from_slice(&t)]);
196
197 let poly_after_tweak = poly.clone();
198
199 match direction {
200 Direction::Encrypt => {
201 let hh = absorb(&mut poly, tail);
202 let m_bits = decode_base_radix(first_part, RADIX)?;
203 let mut mm: [u8; BLOCK_LENGTH] = m_bits.to_le_bytes();
204 xor_block(&mut mm, &hh);
205
206 let mut uu_block = Array::clone_from_slice(&mm);
207 self.ks_enc.encrypt_block(&mut uu_block);
208 let uu: [u8; BLOCK_LENGTH] = uu_block.as_slice().try_into().unwrap();
209
210 let s = xor_blocks_3(&mm, &uu, &self.l);
211 self.fp_elk(&mut dst[first_block_len..], tail, &s, Direction::Encrypt);
212
213 let mut poly = poly_after_tweak;
214 let hh2 = absorb(&mut poly, &dst[first_block_len..]);
215 let mut u_bytes = uu;
216 xor_block(&mut u_bytes, &hh2);
217 encode_base_radix(
218 u128::from_le_bytes(u_bytes),
219 RADIX,
220 &mut dst[..first_block_len],
221 );
222 }
223 Direction::Decrypt => {
224 let hh2 = absorb(&mut poly, tail);
225 let u_bits = decode_base_radix(first_part, RADIX)?;
226 let mut uu: [u8; BLOCK_LENGTH] = u_bits.to_le_bytes();
227 xor_block(&mut uu, &hh2);
228
229 let mut mm_block = Array::clone_from_slice(&uu);
230 self.ks_dec.decrypt_block(&mut mm_block);
231 let mm: [u8; BLOCK_LENGTH] = mm_block.as_slice().try_into().unwrap();
232
233 let s = xor_blocks_3(&mm, &uu, &self.l);
234 self.fp_elk(&mut dst[first_block_len..], tail, &s, Direction::Decrypt);
235
236 let mut poly = poly_after_tweak;
237 let hh = absorb(&mut poly, &dst[first_block_len..]);
238 let mut m_bytes = mm;
239 xor_block(&mut m_bytes, &hh);
240 encode_base_radix(
241 u128::from_le_bytes(m_bytes),
242 RADIX,
243 &mut dst[..first_block_len],
244 );
245 }
246 }
247
248 Ok(())
249 }
250
251 fn fp_elk(&self, dst: &mut [u8], src: &[u8], seed: &[u8; BLOCK_LENGTH], dir: Direction) {
252 debug_assert_eq!(dst.len(), src.len());
253
254 let mut lfsr = *seed;
255 let mut i = 0;
256
257 if is_power_of_two(RADIX) {
258 let bpd = bits_per_digit(RADIX);
259 let digits_per_block = 128 / bpd as usize;
260 let mask: u128 = (RADIX as u128) - 1;
261
262 let mut block = [0u8; BLOCK_LENGTH];
263
264 while i + digits_per_block <= src.len() {
265 block.copy_from_slice(&lfsr);
266 lfsr = lfsr_next_128(&lfsr);
267 let mut ga_block = Array::clone_from_slice(&block);
268 self.ke_enc.encrypt_block(&mut ga_block);
269 let mut ks_bytes = [0u8; 16];
270 ks_bytes.copy_from_slice(ga_block.as_slice());
271 let keystream = u128::from_le_bytes(ks_bytes);
272
273 let mut ks = keystream;
274 for j in 0..digits_per_block {
275 let ks_digit = (ks & mask) as u8;
276 let adjustment = match dir {
277 Direction::Encrypt => ks_digit,
278 Direction::Decrypt => {
279 (RADIX as u8).wrapping_sub(ks_digit) & ((RADIX as u8) - 1)
280 }
281 };
282 dst[i + j] = ((src[i + j] as u16 + adjustment as u16) & (RADIX - 1)) as u8;
283 ks >>= bpd;
284 }
285
286 i += digits_per_block;
287 }
288
289 if i < src.len() {
290 block.copy_from_slice(&lfsr);
291 let mut ga_block = Array::clone_from_slice(&block);
292 self.ke_enc.encrypt_block(&mut ga_block);
293 let mut ks_bytes = [0u8; 16];
294 ks_bytes.copy_from_slice(ga_block.as_slice());
295 let keystream = u128::from_le_bytes(ks_bytes);
296
297 let mut ks = keystream;
298 while i < src.len() {
299 let ks_digit = (ks & mask) as u8;
300 let adjustment = match dir {
301 Direction::Encrypt => ks_digit,
302 Direction::Decrypt => {
303 (RADIX as u8).wrapping_sub(ks_digit) & ((RADIX as u8) - 1)
304 }
305 };
306 dst[i] = ((src[i] as u16 + adjustment as u16) & (RADIX - 1)) as u8;
307 ks >>= bpd;
308 i += 1;
309 }
310 }
311
312 return;
313 }
314
315 let mut block = [0u8; BLOCK_LENGTH];
316
317 while i < src.len() {
318 block.copy_from_slice(&lfsr);
319 lfsr = lfsr_next_128(&lfsr);
320
321 let mut ga_block = Array::clone_from_slice(&block);
322 self.ke_enc.encrypt_block(&mut ga_block);
323 let mut ks_bytes = [0u8; 16];
324 ks_bytes.copy_from_slice(ga_block.as_slice());
325 let keystream = u128::from_le_bytes(ks_bytes);
326
327 let ks_digit = (keystream % (RADIX as u128)) as u8;
328 match dir {
329 Direction::Encrypt => {
330 dst[i] = ((src[i] as u16 + ks_digit as u16) % RADIX) as u8;
331 }
332 Direction::Decrypt => {
333 dst[i] = ((src[i] as u16 + RADIX - ks_digit as u16) % RADIX) as u8;
334 }
335 }
336
337 i += 1;
338 }
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips messages of every length from `MIN_MESSAGE_LENGTH` up to
    /// `MIN_MESSAGE_LENGTH + extra` digits.
    ///
    /// Each case checks that ciphertext digits stay below `R` and that
    /// decryption restores the plaintext. `top_digit` is written to the last
    /// head digit, mirroring the existing tests, to keep the head value in
    /// 128-bit range.
    fn roundtrip_lengths<A: AesCipher, const R: u16>(
        cipher: &Hctr3Fp<A, R>,
        top_digit: u8,
        extra: usize,
    ) {
        let min = Hctr3Fp::<A, R>::MIN_MESSAGE_LENGTH;
        for len in min..=min + extra {
            let mut plaintext: Vec<u8> = (0..len)
                .map(|i| ((i * 7 + 3) % usize::from(R)) as u8)
                .collect();
            plaintext[min - 1] = top_digit;
            let mut ciphertext = vec![0u8; len];
            let mut decrypted = vec![0u8; len];

            cipher
                .encrypt(&plaintext, b"lengths", &mut ciphertext)
                .unwrap();
            assert!(
                ciphertext.iter().all(|&d| u16::from(d) < R),
                "digit out of range at len {len}"
            );
            cipher
                .decrypt(&ciphertext, b"lengths", &mut decrypted)
                .unwrap();
            assert_eq!(plaintext, decrypted, "round trip failed at len {len}");
        }
    }

    #[test]
    fn test_hctr3fp_decimal_roundtrip() {
        let key = [0u8; 16];
        let cipher = Hctr3Fp_128_Decimal::new(&key);

        let mut plaintext = vec![0u8; 40];
        for i in 0..38 {
            plaintext[i] = (i % 10) as u8;
        }
        plaintext[38] = 2;
        plaintext[39] = 5;

        let mut ciphertext = vec![0u8; plaintext.len()];
        let mut decrypted = vec![0u8; plaintext.len()];

        cipher
            .encrypt(&plaintext, b"tweak", &mut ciphertext)
            .unwrap();

        for &d in &ciphertext {
            assert!(d < 10);
        }

        cipher
            .decrypt(&ciphertext, b"tweak", &mut decrypted)
            .unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_hctr3fp_hex_roundtrip() {
        let key = [0u8; 16];
        let cipher = Hctr3Fp_128_Hex::new(&key);

        let plaintext: Vec<u8> = (0..33).map(|i| (i % 16) as u8).collect();
        let mut ciphertext = vec![0u8; plaintext.len()];
        let mut decrypted = vec![0u8; plaintext.len()];

        cipher
            .encrypt(&plaintext, b"tweak", &mut ciphertext)
            .unwrap();

        for &d in &ciphertext {
            assert!(d < 16);
        }

        cipher
            .decrypt(&ciphertext, b"tweak", &mut decrypted)
            .unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_hctr3fp_decimal_nonzero_key() {
        let key: [u8; 16] = core::array::from_fn(|i| (i + 1) as u8);
        let cipher = Hctr3Fp_128_Decimal::new(&key);

        let mut plaintext = vec![0u8; 40];
        for i in 0..38 {
            plaintext[i] = (i % 10) as u8;
        }
        plaintext[38] = 1;
        plaintext[39] = 7;

        let mut ciphertext = vec![0u8; plaintext.len()];
        let mut decrypted = vec![0u8; plaintext.len()];

        cipher
            .encrypt(&plaintext, b"tweak", &mut ciphertext)
            .unwrap();
        assert_ne!(plaintext, ciphertext);

        cipher
            .decrypt(&ciphertext, b"tweak", &mut decrypted)
            .unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_hctr3fp_decimal_minimum_length() {
        let key = [0u8; 16];
        let cipher = Hctr3Fp_128_Decimal::new(&key);

        let mut plaintext = [5u8; 39];
        plaintext[38] = 2;

        let mut ciphertext = [0u8; 39];
        let mut decrypted = [0u8; 39];

        cipher.encrypt(&plaintext, b"", &mut ciphertext).unwrap();
        cipher.decrypt(&ciphertext, b"", &mut decrypted).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_hctr3fp_decimal_too_short() {
        let key = [0u8; 16];
        let cipher = Hctr3Fp_128_Decimal::new(&key);

        let plaintext = [5u8; 38];
        let mut ciphertext = [0u8; 38];

        assert_eq!(
            cipher.encrypt(&plaintext, b"", &mut ciphertext),
            Err(Error::InputTooShort)
        );
    }

    #[test]
    fn test_hctr3fp_decimal_invalid_digit() {
        let key = [0u8; 16];
        let cipher = Hctr3Fp_128_Decimal::new(&key);

        let mut plaintext = [5u8; 40];
        plaintext[0] = 10;
        let mut ciphertext = [0u8; 40];

        assert_eq!(
            cipher.encrypt(&plaintext, b"", &mut ciphertext),
            Err(Error::InvalidDigit)
        );
    }

    #[test]
    fn test_hctr3fp_decrypt_rejects_invalid_digit() {
        let cipher = Hctr3Fp_128_Hex::new(&[0u8; 16]);

        let mut ciphertext = [0u8; 33];
        ciphertext[32] = 16;
        let mut plaintext = [0u8; 33];

        assert_eq!(
            cipher.decrypt(&ciphertext, b"", &mut plaintext),
            Err(Error::InvalidDigit)
        );
    }

    #[test]
    fn test_hctr3fp_different_tweaks() {
        let key = [0u8; 16];
        let cipher = Hctr3Fp_128_Decimal::new(&key);

        let plaintext = [5u8; 40];
        let mut ciphertext1 = [0u8; 40];
        let mut ciphertext2 = [0u8; 40];

        cipher
            .encrypt(&plaintext, b"tweak1", &mut ciphertext1)
            .unwrap();
        cipher
            .encrypt(&plaintext, b"tweak2", &mut ciphertext2)
            .unwrap();

        assert_ne!(ciphertext1, ciphertext2);
    }

    #[test]
    fn test_hctr3fp_tail_change_propagates_to_head() {
        // Wide-block property: changing only a tail digit must change the head.
        let cipher = Hctr3Fp_128_Decimal::new(&[3u8; 16]);
        let head = Hctr3Fp_128_Decimal::FIRST_BLOCK_LENGTH;

        let mut plaintext = [1u8; 45];
        plaintext[head - 1] = 2;
        let mut ciphertext_a = [0u8; 45];
        cipher.encrypt(&plaintext, b"t", &mut ciphertext_a).unwrap();

        plaintext[44] = 9;
        let mut ciphertext_b = [0u8; 45];
        cipher.encrypt(&plaintext, b"t", &mut ciphertext_b).unwrap();

        assert_ne!(&ciphertext_a[..head], &ciphertext_b[..head]);
    }

    #[test]
    fn test_hctr3fp_roundtrip_across_tail_lengths() {
        // Crosses the 32-digit (hex) and 21-digit (base64) keystream-block
        // boundaries, including empty and block-aligned tails.
        roundtrip_lengths(&Hctr3Fp_128_Hex::new(&[9u8; 16]), 15, 70);
        roundtrip_lengths(&Hctr3Fp_256_Hex::new(&[7u8; 32]), 15, 40);
        roundtrip_lengths(&Hctr3Fp_128_Base64::new(&[1u8; 16]), 3, 50);
        roundtrip_lengths(&Hctr3Fp_256_Base64::new(&[2u8; 32]), 3, 25);
        roundtrip_lengths(&Hctr3Fp_128_Decimal::new(&[4u8; 16]), 2, 20);
    }

    #[test]
    fn test_hctr3fp_256_decimal_roundtrip() {
        let key = [0u8; 32];
        let cipher = Hctr3Fp_256_Decimal::new(&key);

        let mut plaintext = vec![0u8; 50];
        plaintext[0] = 5;
        plaintext[1] = 7;
        plaintext[2] = 9;
        plaintext[38] = 3;
        for i in 39..50 {
            plaintext[i] = ((i - 39) % 10) as u8;
        }

        let mut ciphertext = vec![0u8; plaintext.len()];
        let mut decrypted = vec![0u8; plaintext.len()];

        cipher
            .encrypt(&plaintext, b"tweak", &mut ciphertext)
            .unwrap();
        cipher
            .decrypt(&ciphertext, b"tweak", &mut decrypted)
            .unwrap();

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_hctr3fp_decimal_zeros() {
        let key = [0u8; 16];
        let cipher = Hctr3Fp_128_Decimal::new(&key);

        let plaintext = [0u8; 39];
        let mut ciphertext = [0u8; 39];
        let mut decrypted = [0u8; 39];

        cipher.encrypt(&plaintext, b"", &mut ciphertext).unwrap();
        cipher.decrypt(&ciphertext, b"", &mut decrypted).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_hctr3fp_base64_roundtrip() {
        let key = [0u8; 16];
        let cipher = Hctr3Fp_128_Base64::new(&key);

        let mut plaintext = vec![0u8; 23];
        for i in 0..21 {
            plaintext[i] = (i % 64) as u8;
        }
        plaintext[21] = 3;
        plaintext[22] = 42;

        let mut ciphertext = vec![0u8; plaintext.len()];
        let mut decrypted = vec![0u8; plaintext.len()];

        cipher
            .encrypt(&plaintext, b"tweak", &mut ciphertext)
            .unwrap();

        for &d in &ciphertext {
            assert!(d < 64);
        }

        cipher
            .decrypt(&ciphertext, b"tweak", &mut decrypted)
            .unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_hctr3fp_vs_hctr2fp_different_output() {
        use crate::hctr2fp::Hctr2Fp_128_Decimal;

        let key = [0u8; 16];
        let hctr2fp = Hctr2Fp_128_Decimal::new(&key);
        let hctr3fp = Hctr3Fp_128_Decimal::new(&key);

        let plaintext = [0u8; 40];
        let mut ciphertext2 = [0u8; 40];
        let mut ciphertext3 = [0u8; 40];

        hctr2fp
            .encrypt(&plaintext, b"tweak", &mut ciphertext2)
            .unwrap();
        hctr3fp
            .encrypt(&plaintext, b"tweak", &mut ciphertext3)
            .unwrap();
        assert_ne!(ciphertext2, ciphertext3);
    }

    #[test]
    fn test_lfsr_next_produces_unique_states() {
        let initial = [0x01u8; 16];
        let mut state = initial;
        let mut seen = std::collections::HashSet::new();
        seen.insert(state);

        for _ in 0..1000 {
            state = lfsr_next_128(&state);
            assert!(seen.insert(state), "LFSR produced duplicate state");
        }
    }

    #[test]
    fn test_hctr3fp_large_message() {
        let key = [0u8; 16];
        let cipher = Hctr3Fp_128_Hex::new(&key);

        let plaintext: Vec<u8> = (0..256).map(|i| (i % 16) as u8).collect();
        let mut ciphertext = vec![0u8; plaintext.len()];
        let mut decrypted = vec![0u8; plaintext.len()];

        cipher
            .encrypt(&plaintext, b"large tweak", &mut ciphertext)
            .unwrap();
        cipher
            .decrypt(&ciphertext, b"large tweak", &mut decrypted)
            .unwrap();

        assert_eq!(plaintext, decrypted);
    }
}