use crate::error::{Error, Result};
use crate::operation::{Operation, WithOutputLength, WithData};
use std::marker::PhantomData;
/// Compile-time description of a key-derivation algorithm, used as the type
/// parameter of `KdfBuilder`.
pub trait KdfOperation {
    /// Returns a static name for the algorithm.
    fn algorithm_name() -> &'static str;

    /// Output length, in bytes, that `KdfBuilder` uses until the caller
    /// requests a different one.
    const DEFAULT_OUTPUT_SIZE: usize;

    /// Intended minimum salt length.
    ///
    /// NOTE(review): presumably measured in bytes (see the `Salt` bound);
    /// confirm which code path enforces it.
    const MIN_SALT_SIZE: usize;

    /// Salt representation accepted by the builder; must be viewable as bytes.
    type Salt: AsRef<[u8]>;

    /// Info representation accepted by the builder; must be viewable as bytes.
    type Info: AsRef<[u8]>;
}
/// Builder that gathers the inputs for a key derivation using algorithm `T`.
///
/// Every byte input is held by reference for `'a`; the builder itself never
/// copies key material. Configure it with `new`/`Default` and the `with_*`
/// methods, then call `derive` or `derive_array`.
pub struct KdfBuilder<'a, T: KdfOperation> {
    /// Number of output bytes to produce.
    output_length: usize,
    /// Input keying material; must be set before deriving.
    ikm: Option<&'a [u8]>,
    /// Optional salt, in the algorithm's salt representation.
    salt: Option<&'a T::Salt>,
    /// Optional info input, in the algorithm's info representation.
    info: Option<&'a T::Info>,
    /// Zero-sized marker tying the builder to `T` without storing a `T`.
    _phantom: PhantomData<T>,
}
impl<'a, T: KdfOperation> KdfBuilder<'a, T> {
pub fn new() -> Self {
Self {
ikm: None,
salt: None,
info: None,
output_length: T::DEFAULT_OUTPUT_SIZE,
_phantom: PhantomData,
}
}
pub fn with_salt(mut self, salt: &'a T::Salt) -> Self {
self.salt = Some(salt);
self
}
pub fn with_info(mut self, info: &'a T::Info) -> Self {
self.info = Some(info);
self
}
pub fn derive(self) -> Result<Vec<u8>> {
let ikm = self.ikm.ok_or_else(|| Error::InvalidParameter(
"Input keying material is required for key derivation"
))?;
if self.output_length == 0 {
return Err(Error::InvalidParameter(
"Output length must be greater than zero"
));
}
Err(Error::NotImplemented(
"KDF implementation"
))
}
pub fn derive_array<const N: usize>(self) -> Result<[u8; N]> {
if self.output_length != N {
return Err(Error::InvalidLength {
context: "KDF output",
needed: N,
got: self.output_length,
});
}
let key_vec = self.derive()?;
let mut result = [0u8; N];
result.copy_from_slice(&key_vec);
Ok(result)
}
}
impl<'a, T: KdfOperation> Default for KdfBuilder<'a, T> {
fn default() -> Self {
Self::new()
}
}
impl<'a, T: KdfOperation> Operation<Vec<u8>> for KdfBuilder<'a, T> {
fn execute(self) -> Result<Vec<u8>> {
self.derive()
}
fn reset(&mut self) {
self.ikm = None;
self.salt = None;
self.info = None;
self.output_length = T::DEFAULT_OUTPUT_SIZE;
}
}
impl<'a, T: KdfOperation> WithOutputLength<Self> for KdfBuilder<'a, T> {
    /// Sets the number of bytes `derive` will produce.
    ///
    /// No validation happens here; a zero length is rejected later, at
    /// derivation time.
    fn with_output_length(self, length: usize) -> Self {
        Self {
            output_length: length,
            ..self
        }
    }
}
impl<'a, T: KdfOperation> WithData<'a, Self> for KdfBuilder<'a, T> {
    /// Supplies the input keying material, replacing any previous value.
    ///
    /// The slice is borrowed for `'a`, not copied.
    fn with_data(self, data: &'a [u8]) -> Self {
        Self {
            ikm: Some(data),
            ..self
        }
    }
}