aes_keywrap/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3
4use std::error::Error;
5use std::fmt;
6
7use aes::cipher::{Array, BlockCipherDecrypt, BlockCipherEncrypt, KeyInit};
8use aes::{Aes128, Aes256};
9use byteorder::{BigEndian, ByteOrder};
10
/// Number of passes the wrapping process makes over every 64-bit data block
/// (the `j = 0..=5` outer loop of RFC 3394 / RFC 5649).
const FEISTEL_ROUNDS: usize = 6;
/// Default initial value used by plain AES-KW (RFC 3394): eight `0xA6` bytes.
const KW_IV: [u8; 8] = [0xA6; 8];
13
/// Errors returned by the key wrapping (`encapsulate`) and unwrapping (`decapsulate`)
/// operations in this crate.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum KeywrapError {
    /// Input is too big.
    TooBig,
    /// Input is too small.
    TooSmall,
    /// Ciphertext length is not a multiple of 8 bytes (padded key wrap).
    Unpadded,
    /// Input length is not a multiple of 8 bytes (required for AES-KW).
    NotAligned,
    /// The expected plaintext length is not consistent with the ciphertext length.
    InvalidExpectedLen,
    /// The ciphertext couldn't be authenticated.
    AuthenticationFailed,
}

impl Error for KeywrapError {}

impl fmt::Display for KeywrapError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            KeywrapError::TooBig => "Input too big",
            KeywrapError::TooSmall => "Input too small",
            KeywrapError::Unpadded => "Padding error",
            KeywrapError::NotAligned => "Input length not a multiple of 8 bytes",
            KeywrapError::InvalidExpectedLen => "Invalid expected length",
            KeywrapError::AuthenticationFailed => "Authentication failed",
        };
        f.write_str(message)
    }
}
44
45#[derive(Debug)]
46pub struct Aes256KeyWrap {
47    aes: Aes256,
48}
49
50impl Aes256KeyWrap {
51    pub const KEY_BYTES: usize = 32;
52    pub const MAC_BYTES: usize = 8;
53
54    pub fn new(key: &[u8; Self::KEY_BYTES]) -> Self {
55        Aes256KeyWrap {
56            aes: Aes256::new(key.into()),
57        }
58    }
59
60    pub fn encapsulate(&self, input: &[u8]) -> Result<Vec<u8>, KeywrapError> {
61        if input.len() > u32::MAX as usize || input.len() as u64 >= u64::MAX / FEISTEL_ROUNDS as u64
62        {
63            return Err(KeywrapError::TooBig);
64        }
65        let mut aiv: [u8; 8] = [0xa6u8, 0x59, 0x59, 0xa6, 0, 0, 0, 0];
66        BigEndian::write_u32(&mut aiv[4..8], input.len() as u32);
67        let mut block = Array([0u8; 16]);
68        block[0..8].copy_from_slice(&aiv);
69
70        if input.len() == 8 {
71            block[8..16].copy_from_slice(input);
72            self.aes.encrypt_block(&mut block);
73            return Ok(block.to_vec());
74        }
75
76        let mut counter = 0u64;
77        let mut counter_bin = [0u8; 8];
78        let mut output = vec![0u8; ((input.len() + 7) & !7) + Self::MAC_BYTES];
79        output[8..][..input.len()].copy_from_slice(input);
80        for _ in 0..FEISTEL_ROUNDS {
81            let mut i = 8;
82            while i <= (input.len() + 7) & !7 {
83                block[8..16].copy_from_slice(&output[i..][0..8]);
84                self.aes.encrypt_block(&mut block);
85                counter += 1;
86                BigEndian::write_u64(&mut counter_bin, counter);
87                block[0..8]
88                    .iter_mut()
89                    .zip(counter_bin.iter())
90                    .for_each(|(a, b)| *a ^= b);
91                output[i..i + 8].copy_from_slice(&block[8..16]);
92                i += 8;
93            }
94        }
95        output[0..8].copy_from_slice(&block[0..8]);
96        Ok(output)
97    }
98
99    pub fn decapsulate(&self, input: &[u8], expected_len: usize) -> Result<Vec<u8>, KeywrapError> {
100        if !input.len().is_multiple_of(8) {
101            return Err(KeywrapError::Unpadded);
102        }
103        let output_len = input
104            .len()
105            .checked_sub(Self::MAC_BYTES)
106            .ok_or(KeywrapError::TooSmall)?;
107        if output_len > u32::MAX as usize || output_len as u64 >= u64::MAX / FEISTEL_ROUNDS as u64 {
108            return Err(KeywrapError::TooBig);
109        }
110        if expected_len > output_len || (expected_len & !7) > output_len {
111            return Err(KeywrapError::InvalidExpectedLen);
112        }
113        let mut output = vec![0u8; output_len];
114        let mut aiv: [u8; 8] = [0xa6u8, 0x59, 0x59, 0xa6, 0, 0, 0, 0];
115        BigEndian::write_u32(&mut aiv[4..8], expected_len as u32);
116
117        let mut block = Array([0u8; 16]);
118
119        if output.len() == 8 {
120            block.copy_from_slice(input);
121            self.aes.decrypt_block(&mut block);
122            let c = block[0..8]
123                .iter()
124                .zip(aiv.iter())
125                .fold(0, |acc, (a, b)| acc | (a ^ b));
126            if c != 0 {
127                return Err(KeywrapError::AuthenticationFailed);
128            }
129            output[0..8].copy_from_slice(&block[8..16]);
130            return Ok(output);
131        }
132
133        output.copy_from_slice(&input[8..]);
134        block[0..8].copy_from_slice(&input[0..8]);
135        let mut counter = (FEISTEL_ROUNDS * output.len() / 8) as u64;
136        let mut counter_bin = [0u8; 8];
137        for _ in 0..FEISTEL_ROUNDS {
138            let mut i = output.len();
139            while i >= 8 {
140                i -= 8;
141                block[8..16].copy_from_slice(&output[i..][0..8]);
142                BigEndian::write_u64(&mut counter_bin, counter);
143                counter -= 1;
144                block[0..8]
145                    .iter_mut()
146                    .zip(counter_bin.iter())
147                    .for_each(|(a, b)| *a ^= b);
148                self.aes.decrypt_block(&mut block);
149                output[i..][0..8].copy_from_slice(&block[8..16]);
150            }
151        }
152        let c = block[0..8]
153            .iter()
154            .zip(aiv.iter())
155            .fold(0, |acc, (a, b)| acc | (a ^ b));
156        if c != 0 {
157            return Err(KeywrapError::AuthenticationFailed);
158        }
159        Ok(output)
160    }
161}
162
163// --
164
165#[derive(Debug)]
166pub struct Aes128KeyWrap {
167    aes: Aes128,
168}
169
170impl Aes128KeyWrap {
171    pub const KEY_BYTES: usize = 16;
172    pub const MAC_BYTES: usize = 8;
173
174    pub fn new(key: &[u8; Self::KEY_BYTES]) -> Self {
175        Aes128KeyWrap {
176            aes: Aes128::new(key.into()),
177        }
178    }
179
180    pub fn encapsulate(&self, input: &[u8]) -> Result<Vec<u8>, KeywrapError> {
181        if input.len() > u32::MAX as usize || input.len() as u64 >= u64::MAX / FEISTEL_ROUNDS as u64
182        {
183            return Err(KeywrapError::TooBig);
184        }
185        let mut aiv: [u8; 8] = [0xa6u8, 0x59, 0x59, 0xa6, 0, 0, 0, 0];
186        BigEndian::write_u32(&mut aiv[4..8], input.len() as u32);
187        let mut block = Array([0u8; 16]);
188        block[0..8].copy_from_slice(&aiv);
189
190        if input.len() == 8 {
191            block[8..16].copy_from_slice(input);
192            self.aes.encrypt_block(&mut block);
193            return Ok(block.to_vec());
194        }
195
196        let mut counter = 0u64;
197        let mut counter_bin = [0u8; 8];
198        let mut output = vec![0u8; ((input.len() + 7) & !7) + Self::MAC_BYTES];
199        output[8..][..input.len()].copy_from_slice(input);
200        for _ in 0..FEISTEL_ROUNDS {
201            let mut i = 8;
202            while i <= (input.len() + 7) & !7 {
203                block[8..16].copy_from_slice(&output[i..][0..8]);
204                self.aes.encrypt_block(&mut block);
205                counter += 1;
206                BigEndian::write_u64(&mut counter_bin, counter);
207                block[0..8]
208                    .iter_mut()
209                    .zip(counter_bin.iter())
210                    .for_each(|(a, b)| *a ^= b);
211                output[i..i + 8].copy_from_slice(&block[8..16]);
212                i += 8;
213            }
214        }
215        output[0..8].copy_from_slice(&block[0..8]);
216        Ok(output)
217    }
218
219    pub fn decapsulate(&self, input: &[u8], expected_len: usize) -> Result<Vec<u8>, KeywrapError> {
220        if !input.len().is_multiple_of(8) {
221            return Err(KeywrapError::Unpadded);
222        }
223        let output_len = input
224            .len()
225            .checked_sub(Self::MAC_BYTES)
226            .ok_or(KeywrapError::TooSmall)?;
227        if output_len > u32::MAX as usize || output_len as u64 >= u64::MAX / FEISTEL_ROUNDS as u64 {
228            return Err(KeywrapError::TooBig);
229        }
230        if expected_len > output_len || (expected_len & !7) > output_len {
231            return Err(KeywrapError::InvalidExpectedLen);
232        }
233        let mut output = vec![0u8; output_len];
234        let mut aiv: [u8; 8] = [0xa6u8, 0x59, 0x59, 0xa6, 0, 0, 0, 0];
235        BigEndian::write_u32(&mut aiv[4..8], expected_len as u32);
236
237        let mut block = Array([0u8; 16]);
238
239        if output.len() == 8 {
240            block.copy_from_slice(input);
241            self.aes.decrypt_block(&mut block);
242            let c = block[0..8]
243                .iter()
244                .zip(aiv.iter())
245                .fold(0, |acc, (a, b)| acc | (a ^ b));
246            if c != 0 {
247                return Err(KeywrapError::AuthenticationFailed);
248            }
249            output[0..8].copy_from_slice(&block[8..16]);
250            return Ok(output);
251        }
252
253        output.copy_from_slice(&input[8..]);
254        block[0..8].copy_from_slice(&input[0..8]);
255        let mut counter = (FEISTEL_ROUNDS * output.len() / 8) as u64;
256        let mut counter_bin = [0u8; 8];
257        for _ in 0..FEISTEL_ROUNDS {
258            let mut i = output.len();
259            while i >= 8 {
260                i -= 8;
261                block[8..16].copy_from_slice(&output[i..][0..8]);
262                BigEndian::write_u64(&mut counter_bin, counter);
263                counter -= 1;
264                block[0..8]
265                    .iter_mut()
266                    .zip(counter_bin.iter())
267                    .for_each(|(a, b)| *a ^= b);
268                self.aes.decrypt_block(&mut block);
269                output[i..][0..8].copy_from_slice(&block[8..16]);
270            }
271        }
272        let c = block[0..8]
273            .iter()
274            .zip(aiv.iter())
275            .fold(0, |acc, (a, b)| acc | (a ^ b));
276        if c != 0 {
277            return Err(KeywrapError::AuthenticationFailed);
278        }
279        Ok(output)
280    }
281}
282
283// -- AES-KW (RFC 3394) - requires 8-byte aligned input
284
285#[derive(Debug)]
286pub struct Aes256KeyWrapAligned {
287    aes: Aes256,
288}
289
290impl Aes256KeyWrapAligned {
291    pub const KEY_BYTES: usize = 32;
292    pub const MAC_BYTES: usize = 8;
293
294    pub fn new(key: &[u8; Self::KEY_BYTES]) -> Self {
295        Aes256KeyWrapAligned {
296            aes: Aes256::new(key.into()),
297        }
298    }
299
300    pub fn encapsulate(&self, input: &[u8]) -> Result<Vec<u8>, KeywrapError> {
301        if !input.len().is_multiple_of(8) {
302            return Err(KeywrapError::NotAligned);
303        }
304        if input.len() < 16 {
305            return Err(KeywrapError::TooSmall);
306        }
307        if input.len() as u64 >= u64::MAX / FEISTEL_ROUNDS as u64 {
308            return Err(KeywrapError::TooBig);
309        }
310
311        let mut block = Array([0u8; 16]);
312        block[0..8].copy_from_slice(&KW_IV);
313
314        let mut counter = 0u64;
315        let mut counter_bin = [0u8; 8];
316        let mut output = vec![0u8; input.len() + Self::MAC_BYTES];
317        output[8..].copy_from_slice(input);
318        for _ in 0..FEISTEL_ROUNDS {
319            let mut i = 8;
320            while i < output.len() {
321                block[8..16].copy_from_slice(&output[i..][0..8]);
322                self.aes.encrypt_block(&mut block);
323                counter += 1;
324                BigEndian::write_u64(&mut counter_bin, counter);
325                block[0..8]
326                    .iter_mut()
327                    .zip(counter_bin.iter())
328                    .for_each(|(a, b)| *a ^= b);
329                output[i..i + 8].copy_from_slice(&block[8..16]);
330                i += 8;
331            }
332        }
333        output[0..8].copy_from_slice(&block[0..8]);
334        Ok(output)
335    }
336
337    pub fn decapsulate(&self, input: &[u8]) -> Result<Vec<u8>, KeywrapError> {
338        if !input.len().is_multiple_of(8) {
339            return Err(KeywrapError::NotAligned);
340        }
341        let output_len = input
342            .len()
343            .checked_sub(Self::MAC_BYTES)
344            .ok_or(KeywrapError::TooSmall)?;
345        if output_len < 16 {
346            return Err(KeywrapError::TooSmall);
347        }
348        if output_len as u64 >= u64::MAX / FEISTEL_ROUNDS as u64 {
349            return Err(KeywrapError::TooBig);
350        }
351
352        let mut output = vec![0u8; output_len];
353        let mut block = Array([0u8; 16]);
354
355        output.copy_from_slice(&input[8..]);
356        block[0..8].copy_from_slice(&input[0..8]);
357        let mut counter = (FEISTEL_ROUNDS * output.len() / 8) as u64;
358        let mut counter_bin = [0u8; 8];
359        for _ in 0..FEISTEL_ROUNDS {
360            let mut i = output.len();
361            while i >= 8 {
362                i -= 8;
363                block[8..16].copy_from_slice(&output[i..][0..8]);
364                BigEndian::write_u64(&mut counter_bin, counter);
365                counter -= 1;
366                block[0..8]
367                    .iter_mut()
368                    .zip(counter_bin.iter())
369                    .for_each(|(a, b)| *a ^= b);
370                self.aes.decrypt_block(&mut block);
371                output[i..][0..8].copy_from_slice(&block[8..16]);
372            }
373        }
374        let c = block[0..8]
375            .iter()
376            .zip(KW_IV.iter())
377            .fold(0, |acc, (a, b)| acc | (a ^ b));
378        if c != 0 {
379            return Err(KeywrapError::AuthenticationFailed);
380        }
381        Ok(output)
382    }
383}
384
385#[derive(Debug)]
386pub struct Aes128KeyWrapAligned {
387    aes: Aes128,
388}
389
390impl Aes128KeyWrapAligned {
391    pub const KEY_BYTES: usize = 16;
392    pub const MAC_BYTES: usize = 8;
393
394    pub fn new(key: &[u8; Self::KEY_BYTES]) -> Self {
395        Aes128KeyWrapAligned {
396            aes: Aes128::new(key.into()),
397        }
398    }
399
400    pub fn encapsulate(&self, input: &[u8]) -> Result<Vec<u8>, KeywrapError> {
401        if !input.len().is_multiple_of(8) {
402            return Err(KeywrapError::NotAligned);
403        }
404        if input.len() < 16 {
405            return Err(KeywrapError::TooSmall);
406        }
407        if input.len() as u64 >= u64::MAX / FEISTEL_ROUNDS as u64 {
408            return Err(KeywrapError::TooBig);
409        }
410
411        let mut block = Array([0u8; 16]);
412        block[0..8].copy_from_slice(&KW_IV);
413
414        let mut counter = 0u64;
415        let mut counter_bin = [0u8; 8];
416        let mut output = vec![0u8; input.len() + Self::MAC_BYTES];
417        output[8..].copy_from_slice(input);
418        for _ in 0..FEISTEL_ROUNDS {
419            let mut i = 8;
420            while i < output.len() {
421                block[8..16].copy_from_slice(&output[i..][0..8]);
422                self.aes.encrypt_block(&mut block);
423                counter += 1;
424                BigEndian::write_u64(&mut counter_bin, counter);
425                block[0..8]
426                    .iter_mut()
427                    .zip(counter_bin.iter())
428                    .for_each(|(a, b)| *a ^= b);
429                output[i..i + 8].copy_from_slice(&block[8..16]);
430                i += 8;
431            }
432        }
433        output[0..8].copy_from_slice(&block[0..8]);
434        Ok(output)
435    }
436
437    pub fn decapsulate(&self, input: &[u8]) -> Result<Vec<u8>, KeywrapError> {
438        if !input.len().is_multiple_of(8) {
439            return Err(KeywrapError::NotAligned);
440        }
441        let output_len = input
442            .len()
443            .checked_sub(Self::MAC_BYTES)
444            .ok_or(KeywrapError::TooSmall)?;
445        if output_len < 16 {
446            return Err(KeywrapError::TooSmall);
447        }
448        if output_len as u64 >= u64::MAX / FEISTEL_ROUNDS as u64 {
449            return Err(KeywrapError::TooBig);
450        }
451
452        let mut output = vec![0u8; output_len];
453        let mut block = Array([0u8; 16]);
454
455        output.copy_from_slice(&input[8..]);
456        block[0..8].copy_from_slice(&input[0..8]);
457        let mut counter = (FEISTEL_ROUNDS * output.len() / 8) as u64;
458        let mut counter_bin = [0u8; 8];
459        for _ in 0..FEISTEL_ROUNDS {
460            let mut i = output.len();
461            while i >= 8 {
462                i -= 8;
463                block[8..16].copy_from_slice(&output[i..][0..8]);
464                BigEndian::write_u64(&mut counter_bin, counter);
465                counter -= 1;
466                block[0..8]
467                    .iter_mut()
468                    .zip(counter_bin.iter())
469                    .for_each(|(a, b)| *a ^= b);
470                self.aes.decrypt_block(&mut block);
471                output[i..][0..8].copy_from_slice(&block[8..16]);
472            }
473        }
474        let c = block[0..8]
475            .iter()
476            .zip(KW_IV.iter())
477            .fold(0, |acc, (a, b)| acc | (a ^ b));
478        if c != 0 {
479            return Err(KeywrapError::AuthenticationFailed);
480        }
481        Ok(output)
482    }
483}
484
485// --
486
/// An aligned 16-byte secret survives an AES-256 KW wrap/unwrap round trip.
#[test]
fn kw_aligned_roundtrip() {
    let wrapper = Aes256KeyWrapAligned::new(&[42u8; 32]);
    let plaintext: &[u8] = b"1234567812345678";
    let ciphertext = wrapper.encapsulate(plaintext).unwrap();
    let recovered = wrapper.decapsulate(&ciphertext).unwrap();
    assert_eq!(recovered, plaintext);
}
496
/// Plain AES-KW refuses input whose length is not a multiple of 8.
#[test]
fn kw_aligned_rejects_unaligned() {
    let wrapper = Aes256KeyWrapAligned::new(&[42u8; 32]);
    // 17 bytes: one past a multiple of 8.
    let outcome = wrapper.encapsulate(b"12345678901234567");
    assert_eq!(outcome, Err(KeywrapError::NotAligned));
}
504
/// Plain AES-KW refuses a single 8-byte block; at least 16 bytes are required.
#[test]
fn kw_aligned_rejects_small() {
    let wrapper = Aes256KeyWrapAligned::new(&[42u8; 32]);
    let outcome = wrapper.encapsulate(b"12345678");
    assert_eq!(outcome, Err(KeywrapError::TooSmall));
}
512
/// RFC 3394 Section 4.1 - 128-bit KEK, 128-bit Key Data.
#[test]
fn kw_rfc3394_test_vector() {
    // KEK = 00 01 02 ... 0F
    let kek: [u8; 16] = std::array::from_fn(|i| i as u8);
    // Key data = 00 11 22 ... FF
    let key_data: [u8; 16] = std::array::from_fn(|i| i as u8 * 0x11);
    let expected: [u8; 24] = [
        0x1F, 0xA6, 0x8B, 0x0A, 0x81, 0x12, 0xB4, 0x47, 0xAE, 0xF3, 0x4B, 0xD8, 0xFB, 0x5A, 0x7B,
        0x82, 0x9D, 0x3E, 0x86, 0x23, 0x71, 0xD2, 0xCF, 0xE5,
    ];

    let wrapper = Aes128KeyWrapAligned::new(&kek);
    let ciphertext = wrapper.encapsulate(&key_data).unwrap();
    assert_eq!(ciphertext, expected);
    assert_eq!(wrapper.decapsulate(&ciphertext).unwrap(), key_data);
}
536
/// A 16-byte secret round-trips through the padded AES-256 key wrap.
#[test]
fn aligned() {
    let wrapper = Aes256KeyWrap::new(&[42u8; 32]);
    let secret = b"1234567812345678";
    let recovered = wrapper
        .decapsulate(&wrapper.encapsulate(secret).unwrap(), secret.len())
        .unwrap();
    assert_eq!(secret, recovered.as_slice());
}
546
/// A 13-byte secret round-trips; the unwrapped buffer still carries the padding bytes.
#[test]
fn not_aligned() {
    let wrapper = Aes256KeyWrap::new(&[42u8; 32]);
    let secret = b"1234567812345";
    let ciphertext = wrapper.encapsulate(secret).unwrap();
    let recovered = wrapper.decapsulate(&ciphertext, secret.len()).unwrap();
    let (prefix, _padding) = recovered.split_at(secret.len());
    assert_eq!(secret, prefix);
}
556
/// An exactly-8-byte secret uses the single-block path and round-trips.
#[test]
fn singleblock() {
    let secret = b"12345678";
    let wrapper = Aes256KeyWrap::new(&[42u8; 32]);
    let ciphertext = wrapper.encapsulate(secret).unwrap();
    let recovered = wrapper.decapsulate(&ciphertext, secret.len()).unwrap();
    assert_eq!(recovered.as_slice(), secret);
}