use crate::{Algorithm, Error, KeyMetadata, Result};
use std::fmt;
use subtle::ConstantTimeEq;
use zeroize::{Zeroize, ZeroizeOnDrop};
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct SecretKey {
bytes: Vec<u8>,
algorithm: Algorithm,
}
impl SecretKey {
pub fn from_bytes(bytes: Vec<u8>, algorithm: Algorithm) -> Result<Self> {
if bytes.len() != algorithm.key_size() {
return Err(Error::crypto(
"key_validation",
&format!(
"invalid key size: expected {}, got {}",
algorithm.key_size(),
bytes.len()
),
));
}
Ok(Self { bytes, algorithm })
}
pub fn generate(algorithm: Algorithm) -> Result<Self> {
use crate::crypto::{KeyGenerator, SimpleSymmetricKeyGenerator};
use rand_chacha::ChaCha20Rng;
use rand_core::SeedableRng;
let mut rng = ChaCha20Rng::from_entropy();
let generator = SimpleSymmetricKeyGenerator;
let params = crate::crypto::KeyGenParams {
algorithm,
seed: None,
key_size: None,
};
generator.generate_with_params(&mut rng, params)
}
pub fn algorithm(&self) -> Algorithm {
self.algorithm
}
pub fn expose_secret(&self) -> &[u8] {
&self.bytes
}
pub fn ct_eq(&self, other: &Self) -> bool {
if self.algorithm != other.algorithm {
return false;
}
self.bytes.ct_eq(&other.bytes).into()
}
}
/// Debug output that shows the algorithm but never the key bytes.
impl fmt::Debug for SecretKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Placeholder rendered in place of the secret material.
        const REDACTED: &str = "[REDACTED]";
        let mut builder = f.debug_struct("SecretKey");
        builder.field("algorithm", &self.algorithm);
        builder.field("bytes", &REDACTED);
        builder.finish()
    }
}
/// A [`SecretKey`] paired with its lifecycle [`KeyMetadata`].
#[derive(Clone, Debug)]
pub struct VersionedKey {
    pub key: SecretKey,
    pub metadata: KeyMetadata,
}

impl VersionedKey {
    /// Returns `true` once the current system time is past
    /// `metadata.expires_at`. A key with no expiry never expires.
    pub fn is_expired(&self) -> bool {
        match self.metadata.expires_at {
            Some(deadline) => deadline < std::time::SystemTime::now(),
            None => false,
        }
    }

    /// Returns `true` when the key may be used to encrypt new data.
    ///
    /// Requires the state to be `Active` or `Rotating` and the key not to be
    /// expired.
    pub fn can_encrypt(&self) -> bool {
        let state_allows = matches!(
            self.metadata.state,
            crate::KeyState::Active | crate::KeyState::Rotating
        );
        state_allows && !self.is_expired()
    }

    /// Returns `true` when the key may be used to decrypt.
    ///
    /// Any state other than `Revoked` is accepted, provided the key has not
    /// expired.
    pub fn can_decrypt(&self) -> bool {
        if let crate::KeyState::Revoked = self.metadata.state {
            return false;
        }
        !self.is_expired()
    }
}
/// Derives a [`SecretKey`] from input keying material.
///
/// The implementations in this file (`HkdfSha256`, `HkdfSha512`) perform HKDF
/// extract-and-expand with `salt` and `info`.
pub trait KeyDerivation {
    /// Derives a key from `input` using `salt` and the context string `info`.
    ///
    /// # Errors
    ///
    /// Returns an error if derivation or key construction fails.
    fn derive(&self, input: &[u8], salt: &[u8], info: &[u8]) -> Result<SecretKey>;
}
/// Wraps and unwraps keys under a second key (`kek`).
///
/// NOTE(review): no implementation is visible in this file. The names suggest
/// key-encryption-key wrapping (e.g. RFC 3394-style); confirm with implementors.
pub trait KeyWrap {
    /// Produces an opaque wrapped form of `key` using `kek`.
    fn wrap(&self, key: &SecretKey, kek: &SecretKey) -> Result<Vec<u8>>;
    /// Recovers a key for `algorithm` from `wrapped` using `kek`.
    fn unwrap(&self, wrapped: &[u8], kek: &SecretKey, algorithm: Algorithm) -> Result<SecretKey>;
}
/// HKDF (extract-and-expand) over SHA-256.
///
/// Derived keys are tagged [`Algorithm::ChaCha20Poly1305`].
pub struct HkdfSha256;

impl KeyDerivation for HkdfSha256 {
    /// Runs HKDF-SHA256 with `salt` and `input`, then expands with `info`.
    /// The output length is the target algorithm's key size.
    ///
    /// # Errors
    ///
    /// Returns `Error::crypto("hkdf_expand", ..)` if expansion fails, or the
    /// `SecretKey::from_bytes` error if the output is rejected.
    fn derive(&self, input: &[u8], salt: &[u8], info: &[u8]) -> Result<SecretKey> {
        use hkdf::Hkdf;
        use sha2::Sha256;

        let algorithm = Algorithm::ChaCha20Poly1305;
        let hkdf = Hkdf::<Sha256>::new(Some(salt), input);
        // Size the output from the target algorithm, so the buffer length
        // cannot drift from the check in `SecretKey::from_bytes`.
        let mut okm = vec![0u8; algorithm.key_size()];
        hkdf.expand(info, &mut okm)
            .map_err(|e| Error::crypto("hkdf_expand", &format!("HKDF expansion failed: {}", e)))?;
        SecretKey::from_bytes(okm, algorithm)
    }
}
/// HKDF (extract-and-expand) over SHA-512.
///
/// Derived keys are tagged [`Algorithm::ChaCha20Poly1305`].
pub struct HkdfSha512;

impl KeyDerivation for HkdfSha512 {
    /// Runs HKDF-SHA512 with `salt` and `input`, then expands with `info`.
    /// The output length is the target algorithm's key size.
    ///
    /// # Errors
    ///
    /// Returns `Error::crypto("hkdf_expand", ..)` if expansion fails, or the
    /// `SecretKey::from_bytes` error if the output is rejected.
    fn derive(&self, input: &[u8], salt: &[u8], info: &[u8]) -> Result<SecretKey> {
        use hkdf::Hkdf;
        use sha2::Sha512;

        let algorithm = Algorithm::ChaCha20Poly1305;
        let hkdf = Hkdf::<Sha512>::new(Some(salt), input);
        // Size the output from the target algorithm, so the buffer length
        // cannot drift from the check in `SecretKey::from_bytes`.
        let mut okm = vec![0u8; algorithm.key_size()];
        hkdf.expand(info, &mut okm)
            .map_err(|e| Error::crypto("hkdf_expand", &format!("HKDF expansion failed: {}", e)))?;
        SecretKey::from_bytes(okm, algorithm)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction accepts exactly-sized keys and rejects every other length.
    #[test]
    fn test_secret_key_zeroize() {
        let chacha = SecretKey::from_bytes(vec![0x42; 32], Algorithm::ChaCha20Poly1305).unwrap();
        assert_eq!(chacha.expose_secret()[0], 0x42);
        drop(chacha);

        let aes = SecretKey::from_bytes(vec![0x33; 32], Algorithm::Aes256Gcm).unwrap();
        assert_eq!(aes.expose_secret().len(), 32);
        assert_eq!(aes.algorithm(), Algorithm::Aes256Gcm);

        for &bad_len in &[1usize, 16, 64] {
            let outcome = SecretKey::from_bytes(vec![0x11; bad_len], Algorithm::ChaCha20Poly1305);
            assert!(outcome.is_err(), "length {} should be rejected", bad_len);
        }
    }

    /// HKDF-SHA256 is deterministic and sensitive to both salt and info.
    #[test]
    fn test_hkdf_sha256_derivation() {
        let kdf = HkdfSha256;
        let ikm: &[u8] = b"input key material for testing";
        let salt: &[u8] = b"unique random salt";
        let info: &[u8] = b"application context info";

        let reference = kdf.derive(ikm, salt, info).unwrap();
        assert_eq!(reference.expose_secret().len(), 32);
        assert_eq!(reference.algorithm(), Algorithm::ChaCha20Poly1305);

        let repeat = kdf.derive(ikm, salt, info).unwrap();
        assert!(reference.ct_eq(&repeat));

        let other_info = kdf.derive(ikm, salt, b"different context").unwrap();
        assert!(!reference.ct_eq(&other_info));

        let other_salt = kdf.derive(ikm, b"different salt", info).unwrap();
        assert!(!reference.ct_eq(&other_salt));
    }

    /// HKDF-SHA512 yields 32-byte keys that differ from HKDF-SHA256 output.
    #[test]
    fn test_hkdf_sha512_derivation() {
        let (ikm, salt, info): (&[u8], &[u8], &[u8]) =
            (b"test input material", b"test salt", b"test info");

        let wide = HkdfSha512.derive(ikm, salt, info).unwrap();
        assert_eq!(wide.expose_secret().len(), 32);

        let narrow = HkdfSha256.derive(ikm, salt, info).unwrap();
        assert!(!wide.ct_eq(&narrow));
    }

    /// Distinct `info` labels give independent subkeys from one master secret.
    #[test]
    fn test_hkdf_use_case_session_key() {
        let kdf = HkdfSha256;
        let master_secret: &[u8] = b"shared master secret from ECDH";
        let salt: &[u8] = b"session-2024-01-01";
        let labels: [&[u8]; 3] = [b"encryption-key", b"mac-key", b"iv-key"];

        let subkeys: Vec<SecretKey> = labels
            .iter()
            .copied()
            .map(|label| kdf.derive(master_secret, salt, label).unwrap())
            .collect();

        for (i, left) in subkeys.iter().enumerate() {
            for right in &subkeys[i + 1..] {
                assert!(!left.ct_eq(right));
            }
        }
    }
}