use aes::Aes256;
use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit};
use crate::error::CryptoError;
/// Initial value of the 64-bit integrity register for RFC 3394 AES key wrap.
/// `wrap` seeds the register with this value, and `unwrap` rejects any input
/// whose recovered register does not equal it.
const IV: u64 = 0xA6A6A6A6A6A6A6A6;
/// Wraps `plaintext_key` under the 256-bit key-encryption key `kek` using the
/// RFC 3394 AES key wrap algorithm.
///
/// The returned buffer is `plaintext_key.len() + 8` bytes long. It holds the
/// final 64-bit integrity register followed by the wrapped 64-bit semiblocks.
///
/// # Errors
///
/// Returns `CryptoError::KeyWrap` if `plaintext_key` is shorter than 16 bytes
/// or its length is not a multiple of 8.
pub fn wrap(kek: &[u8; 32], plaintext_key: &[u8]) -> Result<Vec<u8>, CryptoError> {
    // Reads a big-endian u64 from an 8-byte slice.
    fn be_u64(bytes: &[u8]) -> u64 {
        let mut word = [0u8; 8];
        word.copy_from_slice(bytes);
        u64::from_be_bytes(word)
    }

    let key_len = plaintext_key.len();
    // The algorithm needs at least two whole 64-bit semiblocks.
    if key_len < 16 || !key_len.is_multiple_of(8) {
        return Err(CryptoError::KeyWrap(
            "key to wrap must be >= 16 bytes and multiple of 8".into(),
        ));
    }

    let kek_cipher = Aes256::new(kek.into());
    let mut register = IV;
    let mut data: Vec<u64> = plaintext_key.chunks_exact(8).map(be_u64).collect();

    // Six passes over every semiblock. The step counter `t` runs from 1 up to
    // 6 * (number of semiblocks), in the same order as the nested loops.
    let mut t: u64 = 0;
    for _ in 0..6 {
        for slot in data.iter_mut() {
            t += 1;
            let mut buf = [0u8; 16];
            let (hi, lo) = buf.split_at_mut(8);
            hi.copy_from_slice(&register.to_be_bytes());
            lo.copy_from_slice(&slot.to_be_bytes());
            kek_cipher.encrypt_block(aes::Block::from_mut_slice(&mut buf));
            register = be_u64(&buf[..8]) ^ t;
            *slot = be_u64(&buf[8..]);
        }
    }

    let mut wrapped = Vec::with_capacity(8 + key_len);
    wrapped.extend_from_slice(&register.to_be_bytes());
    wrapped.extend(data.iter().flat_map(|word| word.to_be_bytes()));
    Ok(wrapped)
}
/// Unwraps an RFC 3394 AES-wrapped key using the 256-bit key-encryption key
/// `kek`.
///
/// `ciphertext` is the 64-bit integrity register followed by the wrapped
/// 64-bit semiblocks, which is the layout produced by `wrap`. On success the
/// recovered key, `ciphertext.len() - 8` bytes long, is returned.
///
/// # Errors
///
/// Returns `CryptoError::KeyWrap` in either of these cases:
/// - `ciphertext` is shorter than 24 bytes or not a multiple of 8 bytes long.
/// - The recovered integrity register does not equal the expected initial
///   value, which happens with a wrong `kek` or a corrupted input.
pub fn unwrap(kek: &[u8; 32], ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
    // Reads a big-endian u64 from an 8-byte slice.
    fn be_u64(bytes: &[u8]) -> u64 {
        let mut word = [0u8; 8];
        word.copy_from_slice(bytes);
        u64::from_be_bytes(word)
    }

    let total_len = ciphertext.len();
    // The input needs the integrity register plus at least two semiblocks.
    if total_len < 24 || !total_len.is_multiple_of(8) {
        return Err(CryptoError::KeyWrap(
            "wrapped key must be >= 24 bytes and multiple of 8".into(),
        ));
    }

    let kek_cipher = Aes256::new(kek.into());
    let (head, body) = ciphertext.split_at(8);
    let mut register = be_u64(head);
    let mut data: Vec<u64> = body.chunks_exact(8).map(be_u64).collect();
    let semiblocks = data.len() as u64;

    // Undo the six wrap passes in reverse order. The step counter `t` counts
    // down from 6 * semiblocks to 1, mirroring the values used when wrapping.
    let mut t = 6 * semiblocks;
    for _ in 0..6 {
        for slot in data.iter_mut().rev() {
            let mut buf = [0u8; 16];
            buf[..8].copy_from_slice(&(register ^ t).to_be_bytes());
            buf[8..].copy_from_slice(&slot.to_be_bytes());
            kek_cipher.decrypt_block(aes::Block::from_mut_slice(&mut buf));
            register = be_u64(&buf[..8]);
            *slot = be_u64(&buf[8..]);
            t -= 1;
        }
    }

    // After all passes, a register that differs from the initial value means
    // the wrong KEK or a modified ciphertext.
    if register != IV {
        return Err(CryptoError::KeyWrap(
            "key unwrap integrity check failed".into(),
        ));
    }

    let mut recovered = Vec::with_capacity(data.len() * 8);
    recovered.extend(data.iter().flat_map(|word| word.to_be_bytes()));
    Ok(recovered)
}