use crate::hash::{Digest, Sha1, Sha256};
use alloc::vec;
use alloc::vec::Vec;
/// Hash function selector for the password-based derivation in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PkcsHash {
    /// SHA-1 (20-byte digest).
    Sha1,
    /// SHA-256 (32-byte digest).
    Sha256,
}

impl PkcsHash {
    /// Digest output length in bytes (`u` in the derivation routine).
    fn u(self) -> usize {
        match self {
            Self::Sha1 => 20,
            Self::Sha256 => 32,
        }
    }

    /// Block length in bytes (`v` in the derivation routine).
    ///
    /// Both supported hashes use 64; the match is spelled out so that adding
    /// a variant forces this value to be revisited.
    fn v(self) -> usize {
        match self {
            Self::Sha1 | Self::Sha256 => 64,
        }
    }
}
/// Diversifier byte for deriving encryption-key material (RFC 7292, ID = 1).
pub(crate) const ID_KEY: u8 = 1;
/// Diversifier byte for deriving an initialisation vector (RFC 7292, ID = 2).
pub(crate) const ID_IV: u8 = 2;
/// Diversifier byte for deriving MAC-key material (RFC 7292, ID = 3).
pub(crate) const ID_MAC: u8 = 3;
/// Encodes `password` as big-endian UTF-16 code units followed by a two-byte
/// NUL terminator.
///
/// An empty password therefore yields `[0x00, 0x00]`. Characters outside the
/// Basic Multilingual Plane are emitted as surrogate pairs, because that is
/// what `str::encode_utf16` produces.
pub(crate) fn password_to_bmp(password: &str) -> Vec<u8> {
    // A UTF-8 string never has more UTF-16 code units than bytes, so this
    // capacity is an upper bound and the buffer is never reallocated.
    let mut bmp = Vec::with_capacity(password.len() * 2 + 2);
    bmp.extend(password.encode_utf16().flat_map(u16::to_be_bytes));
    bmp.extend_from_slice(&[0x00, 0x00]);
    bmp
}
/// Fills `out` with key material derived from `password_bmp` and `salt`.
///
/// This is a thin dispatcher: it selects the digest implementation that
/// corresponds to `hash` and forwards every argument to `derive_with`.
///
/// * `hash` - which digest to use.
/// * `password_bmp` - password bytes, e.g. as produced by [`password_to_bmp`].
/// * `salt` - salt bytes.
/// * `iterations` - number of hash iterations; `0` is treated as `1`.
/// * `id` - diversifier byte, e.g. [`ID_KEY`], [`ID_IV`] or [`ID_MAC`].
/// * `out` - destination buffer; its length determines how many bytes are
///   derived.
pub(crate) fn derive(
    hash: PkcsHash,
    password_bmp: &[u8],
    salt: &[u8],
    iterations: u32,
    id: u8,
    out: &mut [u8],
) {
    match hash {
        PkcsHash::Sha256 => {
            derive_with::<Sha256>(hash, password_bmp, salt, iterations, id, out)
        }
        PkcsHash::Sha1 => {
            derive_with::<Sha1>(hash, password_bmp, salt, iterations, id, out)
        }
    }
}
/// Runs the block-wise derivation loop using digest `D`, filling `out`.
///
/// The structure follows the PKCS#12 key-derivation procedure (RFC 7292,
/// Appendix B.2):
///
/// 1. `D` = `v` copies of `id`.
/// 2. `I` = salt and password, each repeated up to a whole number of
///    `v`-byte blocks, concatenated.
/// 3. For each output chunk: `A = H^iterations(D || I)`, copy up to `u`
///    bytes of `A` into `out`, then set every `v`-byte block of `I` to
///    `I_j + B + 1` where `B` is `A` repeated to `v` bytes.
///
/// `iterations == 0` is treated as `1`. An empty `out` performs no hashing.
///
/// Before returning, every password-derived scratch buffer is overwritten
/// with zeros: `I`, the expanded password copy, and the `B` block. This is
/// best-effort scrubbing; the digest output value itself is not wiped
/// because only read access to it is available here.
fn derive_with<D: Digest>(
    hash: PkcsHash,
    password_bmp: &[u8],
    salt: &[u8],
    iterations: u32,
    id: u8,
    out: &mut [u8],
) {
    let u = hash.u();
    let v = hash.v();
    debug_assert_eq!(u, D::OUTPUT_LEN);

    // Diversifier block and the expanded salt / password blocks.
    let d = vec![id; v];
    let s = fill_blocks(salt, v);
    let mut p = fill_blocks(password_bmp, v);

    // Exact capacity up front so the secret-bearing buffer is never
    // reallocated (which would leave a stale copy behind).
    let mut i_buf = Vec::with_capacity(s.len() + p.len());
    i_buf.extend_from_slice(&s);
    i_buf.extend_from_slice(&p);

    // The expanded password now lives in `i_buf`; scrub the temporary copy
    // right away instead of leaving it for the allocator to hand out again.
    p.fill(0);
    let _ = core::hint::black_box(&p);

    // Loop-invariant: a zero iteration count still hashes once.
    let rounds = iterations.max(1);
    // Scratch space for the `B` block, reused across output chunks.
    let mut b = vec![0u8; v];

    let n = out.len();
    let mut produced = 0;
    while produced < n {
        // A = H(D || I), then re-hashed `rounds - 1` more times.
        let mut hasher = D::new();
        hasher.update(&d);
        hasher.update(&i_buf);
        let mut a = hasher.finalize();
        for _ in 1..rounds {
            let mut h = D::new();
            h.update(a.as_ref());
            a = h.finalize();
        }

        let a_bytes = a.as_ref();
        let take = core::cmp::min(u, n - produced);
        out[produced..produced + take].copy_from_slice(&a_bytes[..take]);
        produced += take;
        if produced >= n {
            break;
        }

        // B = A repeated cyclically to exactly `v` bytes.
        for (j, slot) in b.iter_mut().enumerate() {
            *slot = a_bytes[j % u];
        }
        // `i_buf` is always a whole number of `v`-byte blocks, so
        // `chunks_exact_mut` visits every byte.
        for block in i_buf.chunks_exact_mut(v) {
            add_block(block, &b);
        }
    }

    // Best-effort scrubbing of everything derived from the password.
    i_buf.fill(0);
    b.fill(0);
    let _ = core::hint::black_box(&i_buf);
    let _ = core::hint::black_box(&b);
}
/// Repeats `data` cyclically until the result is a whole number of `v`-byte
/// blocks.
///
/// The result length is `ceil(data.len() / v) * v`. Empty input yields an
/// empty vector (checked before `v` is used).
fn fill_blocks(data: &[u8], v: usize) -> Vec<u8> {
    if data.is_empty() {
        return Vec::new();
    }
    let padded_len = data.len().div_ceil(v) * v;
    // Allocate the exact size once, then stream the repeated bytes in.
    let mut filled = Vec::with_capacity(padded_len);
    filled.extend(data.iter().copied().cycle().take(padded_len));
    filled
}
/// Computes `block = block + addend + 1` in place, treating both slices as
/// big-endian unsigned integers and discarding any carry out of the most
/// significant byte (i.e. arithmetic modulo `2^(8 * block.len())`).
///
/// Both slices are expected to have the same length.
fn add_block(block: &mut [u8], addend: &[u8]) {
    debug_assert_eq!(block.len(), addend.len());
    // Only the first `block.len()` addend bytes take part; a shorter addend
    // is a caller bug and panics here.
    let addend = &addend[..block.len()];

    // Seeding the carry with 1 supplies the "+ 1" term of the sum.
    let mut carry = 1u16;
    // Walk from the least significant (last) byte towards the first.
    for (dst, &src) in block.iter_mut().zip(addend).rev() {
        let sum = u16::from(*dst) + u16::from(src) + carry;
        let [high, low] = sum.to_be_bytes();
        *dst = low;
        carry = u16::from(high);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The encoded password always ends with a two-byte NUL, even when empty.
    #[test]
    fn password_bmp_has_trailing_nul() {
        let empty = password_to_bmp("");
        assert_eq!(empty, vec![0x00, 0x00]);

        let ab = password_to_bmp("ab");
        assert_eq!(ab, vec![0x00, 0x61, 0x00, 0x62, 0x00, 0x00]);
    }

    /// Input is repeated cyclically up to the next block boundary; empty
    /// input stays empty and already-aligned input is unchanged.
    #[test]
    fn fill_blocks_repeats_and_pads() {
        let from_empty = fill_blocks(&[], 4);
        assert_eq!(from_empty, Vec::<u8>::new());

        let padded = fill_blocks(&[1, 2, 3], 4);
        assert_eq!(padded, vec![1, 2, 3, 1]);

        let aligned = fill_blocks(&[1, 2, 3, 4], 4);
        assert_eq!(aligned, vec![1, 2, 3, 4]);
    }

    /// The implicit "+ 1" wraps a single byte and carries into the next
    /// more significant byte.
    #[test]
    fn add_block_carries() {
        let mut single = [0xffu8];
        add_block(&mut single, &[0x00]);
        assert_eq!(single, [0x00]);

        let mut pair = [0x00u8, 0xff];
        add_block(&mut pair, &[0x00, 0x00]);
        assert_eq!(pair, [0x01, 0x00]);
    }
}