use alloc::{vec, vec::Vec};
use der::asn1::BmpString;
use digest::{core_api::BlockSizeUser, Digest, FixedOutputReset, OutputSizeUser, Update};
use zeroize::{Zeroize, Zeroizing};
/// Selects which kind of key material the PKCS#12 KDF produces.
///
/// The discriminant of each variant is the diversifier ("ID") byte that the KDF repeats to
/// fill one digest block. Using a different ID gives independent outputs from the same
/// password and salt.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Pkcs12KeyType {
    /// Key material for encryption (ID byte `1`).
    EncryptionKey = 1,
    /// Initialization vector (ID byte `2`).
    Iv = 2,
    /// Key for message authentication (ID byte `3`).
    Mac = 3,
}
/// Derive `key_len` bytes of PKCS#12 key material from a UTF-8 `password`.
///
/// The password is converted to a [`BmpString`] and passed on to [`derive_key_bmp`]. That
/// function appends the two-byte NUL terminator and runs the key derivation.
///
/// # Errors
///
/// Propagates the error returned by [`BmpString::from_utf8`] when `password` cannot be
/// converted.
pub fn derive_key_utf8<D>(
    password: &str,
    salt: &[u8],
    id: Pkcs12KeyType,
    rounds: i32,
    key_len: usize,
) -> der::Result<Vec<u8>>
where
    D: Digest + FixedOutputReset + BlockSizeUser,
{
    let bmp = BmpString::from_utf8(password)?;
    let derived = derive_key_bmp::<D>(bmp, salt, id, rounds, key_len);
    Ok(derived)
}
pub fn derive_key_bmp<D>(
password: BmpString,
salt: &[u8],
id: Pkcs12KeyType,
rounds: i32,
key_len: usize,
) -> Vec<u8>
where
D: Digest + FixedOutputReset + BlockSizeUser,
{
let mut password = Zeroizing::new(Vec::from(password.into_bytes()));
password.extend([0u8; 2]);
derive_key::<D>(&password, salt, id, rounds, key_len)
}
pub fn derive_key<D>(
pass: &[u8],
salt: &[u8],
id: Pkcs12KeyType,
rounds: i32,
key_len: usize,
) -> Vec<u8>
where
D: Digest + FixedOutputReset + BlockSizeUser,
{
let mut digest = D::new();
let output_size = <D as OutputSizeUser>::output_size();
let block_size = D::block_size();
let id_block = match id {
Pkcs12KeyType::EncryptionKey => vec![1u8; block_size],
Pkcs12KeyType::Iv => vec![2u8; block_size],
Pkcs12KeyType::Mac => vec![3u8; block_size],
};
let slen = block_size * ((salt.len() + block_size - 1) / block_size);
let plen = block_size * ((pass.len() + block_size - 1) / block_size);
let ilen = slen + plen;
let mut init_key = vec![0u8; ilen];
for i in 0..slen {
init_key[i] = salt[i % salt.len()];
}
for i in 0..plen {
init_key[slen + i] = pass[i % pass.len()];
}
let mut m = key_len;
let mut n = 0;
let mut out = vec![0u8; key_len];
loop {
<D as Update>::update(&mut digest, &id_block);
<D as Update>::update(&mut digest, &init_key);
let mut result = digest.finalize_fixed_reset();
for _ in 1..rounds {
<D as Update>::update(&mut digest, &result[0..output_size]);
result = digest.finalize_fixed_reset();
}
let new_bytes_num = m.min(output_size);
out[n..n + new_bytes_num].copy_from_slice(&result[0..new_bytes_num]);
n += new_bytes_num;
if m <= new_bytes_num {
break;
}
m -= new_bytes_num;
let mut j = 0;
while j < ilen {
let mut c = 1_u16;
let mut k = block_size - 1;
loop {
c += init_key[k + j] as u16 + result[k % output_size] as u16;
init_key[j + k] = (c & 0x00ff) as u8;
c >>= 8;
if k == 0 {
break;
}
k -= 1;
}
j += block_size;
}
}
init_key.zeroize();
out
}