use std::{
collections::HashMap,
io::{self, BufReader, Read},
};
use aes::{
Aes128, Aes192, Aes256,
cipher::{
BlockDecryptMut, BlockEncryptMut, KeyIvInit, block_padding::Pkcs7,
generic_array::GenericArray,
},
};
use aes_kw::Kek;
use pbkdf2::pbkdf2_hmac;
use sha1::Sha1;
use sha2::Sha256;
use crate::{
backup::models::{
file::WrappedKey,
keyring::{ClassKeyData, EncryptionKey, ProtectionClassKey},
manifest::manifest_plist::ManifestData,
},
error::{BackupError, Result},
};
/// AES-256 in CBC mode, decryption direction.
type Aes256CbcDec = cbc::Decryptor<Aes256>;
/// AES-256 in CBC mode, encryption direction.
type Aes256CbcEnc = cbc::Encryptor<Aes256>;
/// Chunk size in bytes (8 KiB) used by the buffered streaming decryptor.
pub const STREAM_BUFFER_SIZE: usize = 8192;
/// Derives a 32-byte encryption key from a user-supplied password.
///
/// The derivation chains two PBKDF2 passes:
/// 1. PBKDF2-HMAC-SHA256 over `password`, salted with `dpsl`, for `dpic`
///    rounds, producing a 32-byte intermediate value;
/// 2. PBKDF2-HMAC-SHA1 over that intermediate, salted with `salt`, for
///    `iter` rounds, producing the 32-byte key that is returned.
///
/// # Errors
/// The current implementation never fails; the `Result` is part of the
/// public signature.
pub fn derive_key_from_password(
    password: &[u8],
    dpsl: &[u8],
    dpic: u32,
    salt: &[u8],
    iter: u32,
) -> Result<EncryptionKey> {
    // Stage one: stretch the raw password with SHA-256.
    let mut stretched = [0u8; 32];
    pbkdf2_hmac::<Sha256>(password, dpsl, dpic, &mut stretched);

    // Stage two: derive the final key from the stretched value with SHA-1.
    let mut output = vec![0u8; 32];
    pbkdf2_hmac::<Sha1>(&stretched, salt, iter, &mut output);
    Ok(output.into())
}
/// Unwraps the per-protection-class keys held in the manifest's key ring.
///
/// `master_key` must be 32 bytes long. For each class entry that carries
/// both a wrapped key (`wpky`) and wrap flags (`wrap`), the flags are
/// decoded as a big-endian `u32`. Only entries with bit `0x02` set are
/// AES-key-unwrapped under `master_key`. Entries missing either field, or
/// lacking that bit, are skipped.
/// NOTE(review): bit `0x02` presumably marks passcode-wrapped keys — confirm.
///
/// Returns the unwrapped keys indexed by class id.
///
/// # Errors
/// - [`BackupError::Crypto`] if `master_key` is not 32 bytes, or if the
///   manifest has no key ring.
/// - [`BackupError::KeyUnwrapFailed`] if a class's `wrap` field is not
///   exactly 4 bytes, or its wrapped key fails to unwrap.
pub(crate) fn unlock_keys_from_manifest(
    master_key: &EncryptionKey,
    manifest: &ManifestData,
) -> Result<HashMap<u32, ProtectionClassKey>> {
    let master_len = master_key.len();
    if master_len != 32 {
        return Err(BackupError::Crypto(format!(
            "Main key for unlocking class keys must be 32 bytes for AES-256, got {}",
            master_len
        )));
    }

    let Some(key_ring) = manifest.key_ring.as_ref() else {
        return Err(BackupError::Crypto(
            "BackupKeyBag not found in PlistInfo".to_string(),
        ));
    };

    let mut unlocked = HashMap::new();
    for (&class_id, entry) in &key_ring.class_keys {
        let ClassKeyData {
            wpky: Some(wpky),
            wrap: Some(wrap_bytes),
            ..
        } = entry
        else {
            continue;
        };

        let flag_bytes: [u8; 4] = wrap_bytes
            .as_slice()
            .try_into()
            .map_err(|_| BackupError::KeyUnwrapFailed(class_id))?;
        if u32::from_be_bytes(flag_bytes) & 0x02 == 0 {
            continue;
        }

        let key = aes_kw_unwrap(master_key, &WrappedKey::from(wpky.clone()))
            .map_err(|_| BackupError::KeyUnwrapFailed(class_id))?;
        unlocked.insert(class_id, ProtectionClassKey { class_id, key });
    }
    Ok(unlocked)
}
/// Decrypts AES-256-CBC ciphertext (all-zero IV) and strips PKCS#7 padding.
///
/// Any trailing bytes beyond the last complete 16-byte block are ignored
/// before decryption.
///
/// # Errors
/// - [`BackupError::InvalidCryptoDataLength`] if `key` is not 32 bytes.
/// - [`BackupError::Crypto`] if the padding is invalid. This includes
///   input shorter than one block.
pub fn aes_decrypt_cbc_with_padding(data: &[u8], key: &EncryptionKey) -> Result<Vec<u8>> {
    if key.len() != 32 {
        return Err(BackupError::InvalidCryptoDataLength {
            expected: 32,
            actual: key.len(),
        });
    }

    // Keep only whole 16-byte blocks; a partial trailing block is dropped.
    let whole_blocks_len = data.len() - data.len() % 16;
    let mut plaintext = data[..whole_blocks_len].to_vec();

    let zero_iv = [0u8; 16];
    let decryptor =
        Aes256CbcDec::new(GenericArray::from_slice(key), GenericArray::from_slice(&zero_iv));
    let unpadded_len = decryptor
        .decrypt_padded_mut::<Pkcs7>(&mut plaintext)
        .map_err(|e| BackupError::Crypto(format!("AES CBC decryption error (padding): {e:?}")))?
        .len();

    // Decryption happens in place; the unpadded plaintext is a prefix of the buffer.
    plaintext.truncate(unpadded_len);
    Ok(plaintext)
}
/// Encrypts `data` with AES-256-CBC (all-zero IV) and PKCS#7 padding.
///
/// The output length is the input length rounded up to the next multiple
/// of 16. An input that is already block-aligned gains one full padding
/// block.
///
/// # Errors
/// - [`BackupError::InvalidCryptoDataLength`] if `key` is not 32 bytes.
/// - [`BackupError::Crypto`] if padding fails.
pub fn aes_encrypt_cbc_with_padding(data: &[u8], key: &EncryptionKey) -> Result<Vec<u8>> {
    if key.len() != 32 {
        return Err(BackupError::InvalidCryptoDataLength {
            expected: 32,
            actual: key.len(),
        });
    }

    // PKCS#7 adds between 1 and 16 bytes, so one extra block of headroom
    // always suffices for in-place encryption.
    let msg_len = data.len();
    let mut ciphertext = Vec::with_capacity(msg_len + 16);
    ciphertext.extend_from_slice(data);
    ciphertext.resize(msg_len + 16, 0);

    let zero_iv = [0u8; 16];
    let encryptor =
        Aes256CbcEnc::new(GenericArray::from_slice(key), GenericArray::from_slice(&zero_iv));
    let padded_len = encryptor
        .encrypt_padded_mut::<Pkcs7>(&mut ciphertext, msg_len)
        .map_err(|e| BackupError::Crypto(format!("AES CBC encryption error (padding): {e:?}")))?
        .len();

    ciphertext.truncate(padded_len);
    Ok(ciphertext)
}
/// Unwraps `wrapped_data` with AES key wrap, using `kek_bytes` as the
/// key-encryption key.
///
/// The KEK may be 16, 24 or 32 bytes (AES-128/192/256). The unwrapped key
/// is 8 bytes shorter than the wrapped input.
///
/// # Errors
/// [`BackupError::Crypto`] in any of these cases:
/// - the wrapped data is 8 bytes or shorter (checked first);
/// - the KEK length is unsupported;
/// - the unwrap itself fails.
pub(crate) fn aes_kw_unwrap(
    kek_bytes: &EncryptionKey,
    wrapped_data: &WrappedKey,
) -> Result<EncryptionKey> {
    let wrapped_len = wrapped_data.len();
    if wrapped_len <= 8 {
        return Err(BackupError::Crypto(format!(
            "Wrapped data is too short ({wrapped_len} bytes)"
        )));
    }

    let mut unwrapped = vec![0u8; wrapped_len - 8];
    // Each supported KEK size yields either success or its own failure message.
    let outcome = match kek_bytes.len() {
        16 => Kek::<Aes128>::new(GenericArray::from_slice(kek_bytes))
            .unwrap(wrapped_data, &mut unwrapped)
            .map_err(|_| "AES 128 Key Unwrap failed"),
        24 => Kek::<Aes192>::new(GenericArray::from_slice(kek_bytes))
            .unwrap(wrapped_data, &mut unwrapped)
            .map_err(|_| "AES 192 Key Unwrap failed"),
        32 => Kek::<Aes256>::new(GenericArray::from_slice(kek_bytes))
            .unwrap(wrapped_data, &mut unwrapped)
            .map_err(|_| "AES 256 Key Unwrap failed"),
        other => {
            return Err(BackupError::Crypto(format!(
                "Invalid KEK length: {other} bytes (must be 16, 24, or 32)"
            )));
        }
    };
    outcome.map_err(|msg| BackupError::Crypto(msg.to_string()))?;

    Ok(unwrapped.into())
}
pub struct AesCbcDecryptReader<R: Read> {
inner: R,
cipher: Aes256CbcDec,
lookahead: [u8; 16],
buf: Vec<u8>,
buf_pos: usize,
eof: bool,
}
impl<R: Read> AesCbcDecryptReader<R> {
pub fn from(reader: R, key: &EncryptionKey) -> Result<AesCbcDecryptReader<BufReader<R>>> {
if key.len() != 32 {
return Err(BackupError::InvalidCryptoDataLength {
expected: 32,
actual: key.len(),
});
}
let mut buf_reader = BufReader::with_capacity(STREAM_BUFFER_SIZE, reader);
let mut lookahead = [0u8; 16];
let n = buf_reader
.read(&mut lookahead)
.map_err(|e| BackupError::Crypto(format!("I/O error: {e}")))?;
if n == 0 {
return Err(BackupError::Crypto("Ciphertext empty".into()));
}
if n != 16 {
return Err(BackupError::Crypto(format!(
"Unexpected ciphertext length: {n}"
)));
}
let iv = GenericArray::from_slice(&[0u8; 16]);
let key_ga = GenericArray::from_slice(key);
let cipher = Aes256CbcDec::new(key_ga, iv);
Ok(AesCbcDecryptReader {
inner: buf_reader,
cipher,
lookahead,
buf: Vec::new(),
buf_pos: 0,
eof: false,
})
}
}
impl<R: Read> Read for AesCbcDecryptReader<R> {
fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
let mut written = 0;
while written < out.len() {
if self.buf_pos < self.buf.len() {
let to_copy = (self.buf.len() - self.buf_pos).min(out.len() - written);
out[written..written + to_copy]
.copy_from_slice(&self.buf[self.buf_pos..self.buf_pos + to_copy]);
self.buf_pos += to_copy;
written += to_copy;
continue;
}
if self.eof {
break;
}
let mut chunk = vec![0u8; STREAM_BUFFER_SIZE];
let n = self.inner.read(&mut chunk)?;
if n == 0 {
let mut tail = self.lookahead.to_vec();
let pt = self
.cipher
.clone()
.decrypt_padded_mut::<Pkcs7>(&mut tail)
.map_err(|e| {
io::Error::new(io::ErrorKind::InvalidData, format!("Padding error: {e:?}"))
})?;
self.buf.clear();
self.buf.extend_from_slice(pt);
self.buf_pos = 0;
self.eof = true;
} else {
let mut data = self.lookahead.to_vec();
data.extend_from_slice(&chunk[..n]);
let total = data.len();
let num_blocks = total / 16;
if num_blocks == 0 {
} else {
self.buf.clear();
for i in 0..(num_blocks - 1) {
let start = i * 16;
let mut arr = GenericArray::clone_from_slice(&data[start..start + 16]);
self.cipher.decrypt_block_mut(&mut arr);
self.buf.extend_from_slice(&arr);
}
let last_start = (num_blocks - 1) * 16;
self.lookahead
.copy_from_slice(&data[last_start..last_start + 16]);
self.buf_pos = 0;
}
}
}
Ok(written)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use aes::cipher::generic_array::GenericArray;
    use aes::{Aes128, Aes192, Aes256};
    use aes_kw::Kek;

    #[test]
    fn test_derive_key_consistency() {
        let salt = b"saltsalt";
        let key1 = derive_key_from_password(b"password", &[], 0, salt, 1000).unwrap();
        let key2 = derive_key_from_password(b"password", &[], 0, salt, 1000).unwrap();
        assert_eq!(key1, key2);
        assert_eq!(key1.len(), 32);
    }

    #[test]
    fn test_aes_cbc_roundtrip() {
        let key = vec![0x42; 32].into();
        let data = b"The quick brown fox jumps over the lazy dog";
        let ciphertext = aes_encrypt_cbc_with_padding(data, &key).unwrap();
        assert_ne!(ciphertext, data);
        let plaintext = aes_decrypt_cbc_with_padding(&ciphertext, &key).unwrap();
        assert_eq!(plaintext, data);
    }

    /// Wraps `plain` under `kek_bytes` with the `aes_kw` crate, then checks
    /// that `aes_kw_unwrap` recovers it exactly.
    fn wrap_and_unwrap(kek_bytes: &EncryptionKey, plain: &[u8]) {
        let mut wrapped = vec![0u8; plain.len() + 8];
        match kek_bytes.len() {
            16 => {
                let kek = Kek::<Aes128>::new(GenericArray::from_slice(kek_bytes));
                kek.wrap(plain, &mut wrapped).unwrap();
            }
            24 => {
                let kek = Kek::<Aes192>::new(GenericArray::from_slice(kek_bytes));
                kek.wrap(plain, &mut wrapped).unwrap();
            }
            32 => {
                let kek = Kek::<Aes256>::new(GenericArray::from_slice(kek_bytes));
                kek.wrap(plain, &mut wrapped).unwrap();
            }
            _ => panic!("Invalid KEK length"),
        }
        let unwrapped = aes_kw_unwrap(kek_bytes, &wrapped.into()).unwrap();
        assert_eq!(unwrapped, plain.to_vec().into());
    }

    #[test]
    fn test_key_wrap_unwrap_128() {
        let kek = vec![0x0b; 16].into();
        let data = b"12345678ABCDEFGH";
        wrap_and_unwrap(&kek, data);
    }

    #[test]
    fn test_key_wrap_unwrap_192() {
        let kek = vec![0x0c; 24].into();
        let data = b"12345678ABCDEFGH";
        wrap_and_unwrap(&kek, data);
    }

    #[test]
    fn test_key_wrap_unwrap_256() {
        let kek = vec![0x0d; 32].into();
        let data = b"12345678ABCDEFGH";
        wrap_and_unwrap(&kek, data);
    }

    #[test]
    fn test_aes_kw_unwrap_errors() {
        let kek = vec![0u8; 16].into();
        let short_data = vec![0u8; 8];
        let err = aes_kw_unwrap(&kek, &WrappedKey::from(short_data)).unwrap_err();
        match err {
            BackupError::Crypto(msg) => assert!(msg.contains("too short")),
            _ => panic!("Expected Crypto error for short data"),
        }
        let invalid_kek = vec![0u8; 10].into();
        let wrapped = vec![0u8; 16];
        let err2 = aes_kw_unwrap(&invalid_kek, &WrappedKey::from(wrapped)).unwrap_err();
        match err2 {
            BackupError::Crypto(msg) => assert!(msg.contains("Invalid KEK length")),
            _ => panic!("Expected Crypto error for invalid KEK length"),
        }
    }

    #[test]
    fn test_aes_encrypt_invalid_key_length() {
        let data = b"hello";
        let short_key = vec![0u8; 16].into();
        let err = aes_encrypt_cbc_with_padding(data, &short_key).unwrap_err();
        match err {
            BackupError::InvalidCryptoDataLength {
                actual,
                expected: _,
            } => assert_eq!(actual, 16),
            _ => panic!("Expected InvalidCryptoDataLength for short key"),
        }
        let long_key = vec![0u8; 64].into();
        let err2 = aes_encrypt_cbc_with_padding(data, &long_key).unwrap_err();
        match err2 {
            BackupError::InvalidCryptoDataLength {
                actual,
                expected: _,
            } => assert_eq!(actual, 64),
            _ => panic!("Expected InvalidCryptoDataLength for long key"),
        }
    }

    #[test]
    fn test_aes_decrypt_invalid_key_length() {
        let cipher = vec![0u8; 16];
        let short_key = vec![0u8; 24].into();
        let err = aes_decrypt_cbc_with_padding(&cipher, &short_key).unwrap_err();
        match err {
            BackupError::InvalidCryptoDataLength { actual, expected } => {
                assert_eq!(actual, 24);
                assert_eq!(expected, 32);
            }
            _ => panic!("Expected InvalidCryptoDataLength with actual=24, expected=32"),
        }
    }

    #[test]
    fn test_derive_key_length_and_determinism() {
        let password = b"password";
        let dpsl = b"salt1";
        let dpic = 2;
        let salt = b"salt2";
        let iter = 3;
        let key1 = derive_key_from_password(password, dpsl, dpic, salt, iter).unwrap();
        let key2 = derive_key_from_password(password, dpsl, dpic, salt, iter).unwrap();
        assert_eq!(key1.len(), 32);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_aes_encrypt_decrypt_empty_data() {
        let key = vec![0u8; 32].into();
        let ciphertext = aes_encrypt_cbc_with_padding(&[], &key).unwrap();
        assert_eq!(ciphertext.len(), 16);
        let plaintext = aes_decrypt_cbc_with_padding(&ciphertext, &key).unwrap();
        assert_eq!(plaintext.len(), 0);
    }

    #[test]
    fn test_aes_decrypt_trims_non_multiple_of_block_size() {
        let key = vec![0u8; 32].into();
        let original = b"hello";
        let mut ciphertext = aes_encrypt_cbc_with_padding(original, &key).unwrap();
        ciphertext.extend(&[0u8; 5]);
        let plaintext = aes_decrypt_cbc_with_padding(&ciphertext, &key).unwrap();
        assert_eq!(plaintext, original);
    }

    /// The streaming decryptor must reproduce the one-shot encryptor's input
    /// for sizes around the 16-byte block and the stream-buffer boundaries.
    #[test]
    fn test_stream_reader_roundtrip() {
        let key: EncryptionKey = vec![0x5a; 32].into();
        for len in [0usize, 1, 15, 16, 17, 31, 32, 33, 8191, 8192, 8193, 20000] {
            let plain: Vec<u8> = (0..len).map(|i| (i % 251) as u8).collect();
            let ciphertext = aes_encrypt_cbc_with_padding(&plain, &key).unwrap();
            let mut reader = AesCbcDecryptReader::from(&ciphertext[..], &key).unwrap();
            let mut out = Vec::new();
            reader.read_to_end(&mut out).unwrap();
            assert_eq!(out, plain, "plaintext length {len}");
        }
    }

    /// Construction fails for empty input and for a key of the wrong size.
    #[test]
    fn test_stream_reader_rejects_bad_construction() {
        let key: EncryptionKey = vec![0x5a; 32].into();
        let empty = AesCbcDecryptReader::from(io::empty(), &key);
        assert!(matches!(empty, Err(BackupError::Crypto(_))));

        let short_key: EncryptionKey = vec![0u8; 16].into();
        let bad_key = AesCbcDecryptReader::from(&[0u8; 32][..], &short_key);
        assert!(matches!(
            bad_key,
            Err(BackupError::InvalidCryptoDataLength {
                expected: 32,
                actual: 16
            })
        ));
    }
}