use crate::error::Unspecified;
use native_ossl::digest::DigestAlg;
use native_ossl::kdf::{HkdfBuilder, HkdfMode};
use native_ossl::util::SecretBuf;
/// An HKDF algorithm, identified by the digest that its underlying HMAC uses.
#[derive(Clone, Copy, Debug)]
pub struct Algorithm(pub(crate) &'static crate::digest::Algorithm);

/// HKDF built on the crate's SHA-256 digest.
pub static HKDF_SHA256: Algorithm = Algorithm(&crate::digest::SHA256);
/// HKDF built on the crate's SHA-384 digest.
pub static HKDF_SHA384: Algorithm = Algorithm(&crate::digest::SHA384);
/// HKDF built on the crate's SHA-512 digest.
pub static HKDF_SHA512: Algorithm = Algorithm(&crate::digest::SHA512);

impl Algorithm {
    /// Returns the HMAC algorithm that shares this HKDF algorithm's digest.
    #[must_use]
    pub fn hmac_algorithm(&self) -> crate::hmac::Algorithm {
        let &Self(digest) = self;
        crate::hmac::Algorithm(digest)
    }
}
/// A type that describes how many bytes of key material to produce.
pub trait KeyType {
    /// Number of bytes of output this key type requires.
    fn len(&self) -> usize;

    /// Returns `true` when [`KeyType::len`] reports zero bytes.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
/// An HKDF salt bound to a specific [`Algorithm`].
#[derive(Debug)]
pub struct Salt {
    alg: Algorithm,
    salt: Vec<u8>,
}

impl Salt {
    /// Creates a salt for `algorithm`, taking an owned copy of `value`.
    #[must_use]
    pub fn new(algorithm: Algorithm, value: &[u8]) -> Self {
        let salt = Vec::from(value);
        Self {
            alg: algorithm,
            salt,
        }
    }

    /// Runs HKDF-Extract over `secret` with this salt and returns the resulting PRK.
    ///
    /// The salt is consumed. The PRK is `output_len` bytes long, taken from the digest
    /// that backs this salt's algorithm.
    ///
    /// # Panics
    ///
    /// Panics if OpenSSL cannot fetch the digest, or if the extract operation fails.
    #[must_use]
    pub fn extract(self, secret: &[u8]) -> Prk {
        let Self { alg, salt } = self;

        let md = match DigestAlg::fetch(alg.0.name, None) {
            Ok(md) => md,
            Err(e) => panic!("OpenSSL digest unavailable: {e}"),
        };

        let extracted = match HkdfBuilder::new(&md)
            .mode(HkdfMode::ExtractOnly)
            .key(secret)
            .salt(&salt)
            .derive_to_vec(alg.0.output_len)
        {
            Ok(bytes) => bytes,
            Err(e) => panic!("HKDF extract failed: {e}"),
        };

        Prk {
            alg,
            prk: SecretBuf::new(extracted),
        }
    }

    /// The algorithm this salt was created for.
    #[must_use]
    pub fn algorithm(&self) -> Algorithm {
        let Self { alg, .. } = self;
        *alg
    }
}
pub struct Prk {
alg: Algorithm,
prk: SecretBuf,
}
impl std::fmt::Debug for Prk {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Prk")
.field("alg", &self.alg)
.finish_non_exhaustive()
}
}
impl Prk {
#[must_use]
pub fn new_less_safe(algorithm: Algorithm, value: &[u8]) -> Self {
Self {
alg: algorithm,
prk: SecretBuf::from_slice(value),
}
}
pub fn expand<'a, L: KeyType>(
&'a self,
info: &'a [&'a [u8]],
len: L,
) -> Result<Okm<'a, L>, Unspecified> {
Ok(Okm {
prk: self,
info,
len,
})
}
}
/// Pending HKDF-Expand output keying material (OKM).
///
/// Created by `Prk::expand`. The actual derivation runs when [`Okm::fill`] is called.
pub struct Okm<'a, L: KeyType> {
    prk: &'a Prk,
    info: &'a [&'a [u8]],
    len: L,
}

impl<L: KeyType> Okm<'_, L> {
    /// The requested output length descriptor passed to `Prk::expand`.
    #[must_use]
    pub fn len(&self) -> &L {
        &self.len
    }

    /// Runs HKDF-Expand and writes the OKM into `out`.
    ///
    /// The `info` slices given to `Prk::expand` are concatenated in order to form
    /// the HKDF info parameter.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` in three cases:
    /// - `out.len()` differs from the requested length.
    /// - OpenSSL cannot fetch the digest.
    /// - The OpenSSL HKDF derivation fails. NOTE(review): OpenSSL's HKDF provider
    ///   may cap the total `info` length, so confirm the limit for the OpenSSL
    ///   version in use.
    pub fn fill(self, out: &mut [u8]) -> Result<(), Unspecified> {
        if out.len() != self.len.len() {
            return Err(Unspecified);
        }
        // An empty OKM is a valid, trivial request. OpenSSL 3's HKDF derive rejects
        // a zero output length, so short-circuit instead of reporting an error.
        if out.is_empty() {
            return Ok(());
        }
        let digest_alg = DigestAlg::fetch(self.prk.alg.0.name, None).map_err(|_| Unspecified)?;
        // `concat` preallocates the exact total length, unlike `flat_map` + `collect`.
        let info_concat: Vec<u8> = self.info.concat();
        HkdfBuilder::new(&digest_alg)
            .mode(HkdfMode::ExpandOnly)
            .key(self.prk.prk.as_ref())
            .info(&info_concat)
            .derive(out)
            .map_err(|_| Unspecified)
    }
}