use crate::{Error, Result};
use alloc::vec::Vec;
use rand::RngCore;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// A 256-bit (32-byte) symmetric key.
///
/// The underlying bytes are zeroized when a `Key` is dropped (`ZeroizeOnDrop`),
/// and every clone carries the same guarantee.
#[derive(Zeroize, ZeroizeOnDrop, Clone)]
pub struct Key([u8; 32]);
impl Key {
    /// Length of a key in bytes.
    pub const SIZE: usize = 32;

    /// Builds a key by copying `bytes`.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidKeyLength` when `bytes.len()` is not [`Key::SIZE`].
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        if bytes.len() != Self::SIZE {
            return Err(Error::InvalidKeyLength {
                expected: Self::SIZE,
                actual: bytes.len(),
            });
        }
        // Copy straight into the `Key` so its drop-time zeroization covers the
        // secret, instead of staging it in a plain array that is never wiped.
        let mut key = Key([0u8; Self::SIZE]);
        key.0.copy_from_slice(bytes);
        Ok(key)
    }

    /// Borrows the raw key bytes.
    pub fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    /// Encodes the key with base64's `STANDARD` engine.
    ///
    /// The returned `String` contains the secret and is not zeroized on drop;
    /// callers must handle it accordingly.
    pub fn to_base64(&self) -> String {
        use base64::{Engine, engine::general_purpose::STANDARD};
        STANDARD.encode(&self.0)
    }

    /// Decodes a key produced by [`Key::to_base64`].
    ///
    /// # Errors
    ///
    /// Fails if `encoded` is not valid input for base64's `STANDARD` engine (the
    /// decode error is converted via `?`), or if it does not decode to exactly
    /// [`Key::SIZE`] bytes.
    pub fn from_base64(encoded: &str) -> Result<Self> {
        use base64::{Engine, engine::general_purpose::STANDARD};
        let mut bytes = STANDARD.decode(encoded)?;
        let key = Self::from_bytes(&bytes);
        // `bytes` holds raw key material; wipe it before the allocation is freed.
        // Growing to the full capacity first (no reallocation) also covers any
        // bytes the decoder may have written past `len`.
        let capacity = bytes.capacity();
        bytes.resize(capacity, 0);
        bytes.as_mut_slice().zeroize();
        key
    }
}
impl AsRef<[u8]> for Key {
    /// Views all 32 key bytes as a slice.
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}
impl core::fmt::Debug for Key {
    /// Prints only the key length; the key bytes themselves are never formatted.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Key");
        out.field("length", &Self::SIZE);
        out.finish_non_exhaustive()
    }
}
pub fn generate_key() -> Key {
let mut key = [0u8; 32];
rand::thread_rng().fill_bytes(&mut key);
Key(key)
}
pub fn derive_key_hkdf(
input_key_material: &[u8],
salt: Option<&[u8]>,
info: &[u8],
) -> Result<Key> {
use hkdf::Hkdf;
use sha2::Sha256;
let hk = Hkdf::<Sha256>::new(salt, input_key_material);
let mut okm = [0u8; 32];
hk.expand(info, &mut okm)
.map_err(|e| Error::KeyDerivationFailed(e.to_string()))?;
Ok(Key(okm))
}
pub fn derive_key_pbkdf2(
password: &[u8],
salt: &[u8],
iterations: u32,
) -> Result<Key> {
use pbkdf2::pbkdf2_hmac;
use sha2::Sha256;
let mut key = [0u8; 32];
pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut key);
Ok(Key(key))
}
/// Returns `length` random bytes drawn from `rand::thread_rng()`.
///
/// A `length` of 0 yields an empty vector.
// NOTE(review): the reason for `allow(dead_code)` is not visible here; presumably
// the helper is unused in some build configurations.
#[allow(dead_code)]
pub fn generate_salt(length: usize) -> Vec<u8> {
    let mut rng = rand::thread_rng();
    let mut out = Vec::with_capacity(length);
    out.resize(length, 0u8);
    rng.fill_bytes(&mut out);
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_generation() {
        let key1 = generate_key();
        let key2 = generate_key();
        assert_ne!(key1.as_bytes(), key2.as_bytes());
        assert_eq!(key1.as_bytes().len(), 32);
    }

    #[test]
    fn test_key_base64_roundtrip() {
        let key = generate_key();
        let encoded = key.to_base64();
        let decoded = Key::from_base64(&encoded).unwrap();
        assert_eq!(key.as_bytes(), decoded.as_bytes());
    }

    #[test]
    fn test_from_bytes_rejects_wrong_length() {
        assert!(matches!(
            Key::from_bytes(&[0u8; 31]),
            Err(Error::InvalidKeyLength { expected: 32, actual: 31 })
        ));
        assert!(Key::from_bytes(&[]).is_err());
    }

    #[test]
    fn test_from_base64_rejects_wrong_length() {
        // "AAAA" is valid base64 for three zero bytes, which is too short for a key.
        assert!(matches!(
            Key::from_base64("AAAA"),
            Err(Error::InvalidKeyLength { expected: 32, actual: 3 })
        ));
    }

    #[test]
    fn test_pbkdf2_derivation() {
        let password = b"test password";
        let salt = b"random salt here";
        let iterations = 1000;
        let key1 = derive_key_pbkdf2(password, salt, iterations).unwrap();
        let key2 = derive_key_pbkdf2(password, salt, iterations).unwrap();
        assert_eq!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn test_pbkdf2_known_answer() {
        // RFC 7914 section 11: PBKDF2-HMAC-SHA256(P="passwd", S="salt", c=1).
        // The RFC lists dkLen=64; the first 32 bytes are the first output block,
        // which is identical to a 32-byte derivation.
        const EXPECTED: [u8; 32] = [
            0x55, 0xac, 0x04, 0x6e, 0x56, 0xe3, 0x08, 0x9f, 0xec, 0x16, 0x91, 0xc2, 0x25, 0x44,
            0xb6, 0x05, 0xf9, 0x41, 0x85, 0x21, 0x6d, 0xde, 0x04, 0x65, 0xe6, 0x8b, 0x9d, 0x57,
            0xc2, 0x0d, 0xac, 0xbc,
        ];
        let key = derive_key_pbkdf2(b"passwd", b"salt", 1).unwrap();
        assert_eq!(key.as_bytes(), &EXPECTED);
    }

    #[test]
    fn test_hkdf_derivation() {
        let ikm = b"input key material";
        let salt = b"optional salt";
        let info = b"context info";
        let key1 = derive_key_hkdf(ikm, Some(salt), info).unwrap();
        let key2 = derive_key_hkdf(ikm, Some(salt), info).unwrap();
        assert_eq!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn test_hkdf_known_answer() {
        // RFC 5869 Appendix A.1 (test case 1). The RFC lists a 42-byte OKM; its
        // first 32 bytes are T(1), which does not depend on the requested length.
        let ikm = [0x0bu8; 22];
        let salt: [u8; 13] = [
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c,
        ];
        let info: [u8; 10] = [0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9];
        const EXPECTED: [u8; 32] = [
            0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a, 0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36,
            0x2f, 0x2a, 0x2d, 0x2d, 0x0a, 0x90, 0xcf, 0x1a, 0x5a, 0x4c, 0x5d, 0xb0, 0x2d, 0x56,
            0xec, 0xc4, 0xc5, 0xbf,
        ];
        let key = derive_key_hkdf(&ikm, Some(&salt[..]), &info).unwrap();
        assert_eq!(key.as_bytes(), &EXPECTED);
    }
}