//! Key-based key derivation functions (KBKDF) in the style of NIST SP 800-108.
//!
//! [`kbkdf`] derives key material in counter, feedback, or double-pipeline
//! iteration mode (selected with [`KDFMode`]) using a caller-supplied
//! pseudorandom function ([`PseudoRandomFunction`]) and key
//! ([`PseudoRandomFunctionKey`]). The fixed input is either supplied
//! pre-formatted ([`InputType::FixedInput`]) or built from a label and a
//! context ([`InputType::SpecifiedInput`]).
#![deny(
bad_style,
dead_code,
improper_ctypes,
rustdoc::broken_intra_doc_links,
non_shorthand_field_patterns,
no_mangle_generic_items,
overflowing_literals,
path_statements,
patterns_in_fns_without_body,
private_bounds,
private_interfaces,
unconditional_recursion,
unused,
unused_allocation,
unused_comparisons,
unused_parens,
while_true,
missing_debug_implementations,
missing_copy_implementations,
missing_docs,
trivial_casts,
trivial_numeric_casts,
unnameable_types,
unused_extern_crates,
unused_import_braces,
unused_qualifications,
unused_results
)]
#![no_std]
use generic_array::{ArrayLength, GenericArray};
use typenum::ToInt;
use zeroize::Zeroize;
/// A key that can be passed to [`PseudoRandomFunction::init`] as the
/// key-derivation key.
///
/// The KBKDF routines never inspect the key themselves; they only hand it to
/// the PRF. The meaning of the handle is defined by the PRF implementation.
pub trait PseudoRandomFunctionKey {
    /// Implementation-defined handle through which a PRF reaches the key.
    type KeyHandle;

    /// Returns the handle to this key.
    fn key_handle(&self) -> &Self::KeyHandle;
}
/// A pseudorandom function (PRF) used by [`kbkdf`] to derive key material.
///
/// For every PRF invocation the KBKDF routines call [`init`](Self::init) with
/// the key-derivation key, feed the input through one or more
/// [`update`](Self::update) calls, and then call [`finish`](Self::finish) with
/// a buffer of exactly [`PrfOutputSize`](Self::PrfOutputSize) bytes. The same
/// value is reused for successive invocations, so the routines rely on `init`
/// starting a fresh computation.
pub trait PseudoRandomFunction<'a> {
    /// Handle type of the keys accepted by [`init`](Self::init); it must match
    /// the key's [`PseudoRandomFunctionKey::KeyHandle`].
    type KeyHandle;
    /// Output length of the PRF in bytes, as a type-level number.
    type PrfOutputSize: ArrayLength<u8> + ToInt<usize>;
    /// Error reported by the PRF; the KBKDF routines return it unchanged.
    type Error;
    /// Starts a new PRF computation keyed with `key`.
    ///
    /// The `'a` lifetime allows the implementation to keep the key reference
    /// while the computation is in progress.
    fn init(
        &mut self,
        key: &'a dyn PseudoRandomFunctionKey<KeyHandle = Self::KeyHandle>,
    ) -> Result<(), Self::Error>;
    /// Appends `msg` to the input of the current computation.
    fn update(&mut self, msg: &[u8]) -> Result<(), Self::Error>;
    /// Completes the current computation and writes its output into `out`.
    ///
    /// The KBKDF routines always pass a buffer of `PrfOutputSize` bytes and
    /// ignore the returned `usize` (presumably the number of bytes written).
    fn finish(&mut self, out: &mut [u8]) -> Result<usize, Self::Error>;
}
/// Parameters for KBKDF in counter mode.
#[derive(Copy, Clone, Debug)]
pub struct CounterMode {
    /// Length `r` of the block counter in bits.
    ///
    /// The counter is written big-endian on `counter_length / 8` bytes, so the
    /// value should be a multiple of 8 and no wider than `usize`.
    pub counter_length: usize,
}
/// Parameters for KBKDF in feedback mode.
#[derive(Copy, Clone, Debug)]
pub struct FeedbackMode<'a> {
    /// Initial value used as `K(0)`, the feedback input of the first PRF
    /// invocation. When present it must be exactly `PrfOutputSize` bytes long;
    /// `None` means the first invocation has no feedback input.
    pub iv: Option<&'a [u8]>,
    /// Length of the optional block counter in bits; the counter is written
    /// big-endian on `counter_length / 8` bytes. `None` omits the counter,
    /// which is only valid with [`CounterLocation::NoCounter`] or
    /// [`InputType::SpecifiedInput`].
    pub counter_length: Option<usize>,
}
/// Parameters for KBKDF in double-pipeline iteration mode.
#[derive(Copy, Clone, Debug)]
pub struct DoublePipelineIterationMode {
    /// Length of the optional block counter in bits; the counter is written
    /// big-endian on `counter_length / 8` bytes. `None` omits the counter,
    /// which is only valid with [`CounterLocation::NoCounter`] or
    /// [`InputType::SpecifiedInput`].
    pub counter_length: Option<usize>,
}
/// Selects the KBKDF mode of operation, and its parameters, used by [`kbkdf`].
#[derive(Copy, Clone, Debug)]
pub enum KDFMode<'a> {
    /// Counter mode: every PRF block is computed over a block counter and the
    /// fixed input.
    CounterMode(CounterMode),
    /// Feedback mode: every PRF block is also computed over the previous
    /// block (or the IV for the first block).
    FeedbackMode(FeedbackMode<'a>),
    /// Double-pipeline iteration mode: a first PRF chain `A(i)` is computed
    /// over the fixed input, and every output block is computed over `A(i)`
    /// and the fixed input.
    DoublePipelineIterationMode(DoublePipelineIterationMode),
}
/// Where the block counter is placed in the PRF input for
/// [`InputType::FixedInput`].
///
/// Counter mode accepts `NoCounter`, `BeforeFixedInput`, `MiddleOfFixedInput`
/// and `AfterFixedInput`. Feedback and double-pipeline modes accept
/// `NoCounter`, `BeforeIter`, `AfterIter` and `AfterFixedInput`. Any other
/// combination panics.
#[derive(Copy, Clone, Debug)]
pub enum CounterLocation {
    /// No counter is included in the PRF input.
    NoCounter,
    /// Counter, then fixed input.
    BeforeFixedInput,
    /// Counter, then the iteration value (the previous block in feedback mode,
    /// `A(i)` in double-pipeline mode), then fixed input.
    BeforeIter,
    /// Fixed input split at the given byte offset, with the counter inserted
    /// between the two parts.
    MiddleOfFixedInput(usize),
    /// Fixed input followed by the counter (both after the iteration value,
    /// if the mode has one).
    AfterFixedInput,
    /// Iteration value, then counter, then fixed input.
    AfterIter,
}
/// Caller-formatted fixed input data, combined with the block counter
/// according to `counter_location`.
#[derive(Debug)]
pub struct FixedInput<'a> {
    /// The fixed input bytes, fed to the PRF as given apart from counter
    /// insertion.
    pub fixed_input: &'a [u8],
    /// Where the counter is placed relative to `fixed_input`.
    pub counter_location: CounterLocation,
}
/// Fixed input assembled by the KBKDF routines as
/// `label || 0x00 || context || [L]_32`, where `[L]_32` is the derived key
/// length in bits encoded as a big-endian `u32`.
///
/// When a counter is used, it is placed immediately before `label`.
#[derive(Debug)]
pub struct SpecifiedInput<'a> {
    /// Label bytes, placed before the `0x00` separator.
    pub label: &'a [u8],
    /// Context bytes, placed after the `0x00` separator.
    pub context: &'a [u8],
}
/// How the fixed input data for the PRF is provided.
#[derive(Debug)]
pub enum InputType<'a> {
    /// Caller-formatted fixed input with an explicit counter location.
    FixedInput(FixedInput<'a>),
    /// Fixed input assembled from a label and a context.
    SpecifiedInput(SpecifiedInput<'a>),
}
/// Derives `derived_key.len()` bytes of key material from `key` and writes
/// them into `derived_key`, using the mode selected by `kdf_mode`.
///
/// `prf` is initialised with `key` (via [`PseudoRandomFunction::init`])
/// before every PRF invocation.
///
/// # Errors
///
/// Returns the first error reported by `prf`. `derived_key` may then be
/// partially written.
///
/// # Panics
///
/// Panics on inconsistent parameters. Examples:
///
/// - more PRF blocks are needed than the counter can number;
/// - the counter location is not supported by the selected mode;
/// - a counter location needs a counter but no `counter_length` was given;
/// - a feedback-mode IV is not exactly `PrfOutputSize` bytes long.
pub fn kbkdf<'a, T: PseudoRandomFunction<'a>>(
    kdf_mode: &KDFMode,
    input_type: &InputType,
    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
    prf: &mut T,
    derived_key: &mut [u8],
) -> Result<(), T::Error> {
    match kdf_mode {
        KDFMode::CounterMode(counter_mode) => {
            kbkdf_counter::<T>(counter_mode, input_type, key, prf, derived_key)
        }
        KDFMode::FeedbackMode(feedback_mode) => {
            kbkdf_feedback::<T>(feedback_mode, input_type, key, prf, derived_key)
        }
        KDFMode::DoublePipelineIterationMode(double_pipeline) => {
            kbkdf_double_pipeline::<T>(double_pipeline, input_type, key, prf, derived_key)
        }
    }
}
/// Derives `derived_key` with KBKDF in counter mode.
///
/// Block `K(i)` (for `i` from 1) is the PRF output over one of two inputs:
///
/// * [`InputType::FixedInput`]: the fixed input, with the big-endian counter
///   `[i]` placed according to its [`CounterLocation`];
/// * [`InputType::SpecifiedInput`]: `[i] || label || 0x00 || context || [L]_32`,
///   where `L` is the derived key length in bits.
///
/// The counter `[i]` is written on `counter_length / 8` bytes. The blocks are
/// concatenated and truncated to `derived_key.len()` bytes. Each block buffer
/// is zeroized after it has been copied out.
///
/// # Panics
///
/// * `counter_length / 8` bytes cannot represent every block index. This also
///   covers a `counter_length` wider than `usize`.
/// * [`CounterLocation::NoCounter`] is requested but more than one block is
///   needed; every block would be identical.
/// * The counter location is not valid for counter mode.
/// * A `MiddleOfFixedInput` offset lies beyond the fixed input.
///
/// # Errors
///
/// Propagates the first error returned by the PRF.
fn kbkdf_counter<'a, T: PseudoRandomFunction<'a>>(
    counter_mode: &CounterMode,
    input_type: &InputType,
    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
    prf: &mut T,
    derived_key: &mut [u8],
) -> Result<(), T::Error> {
    let l = derived_key.len() * 8;
    let h = T::PrfOutputSize::to_int() * 8;
    let n = calculate_counter(l, h);
    let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();

    // Only whole bytes of the counter are emitted (see the slice in the loop),
    // so the capacity check must use that width. Checking against
    // 2^counter_length instead lets a non-byte-aligned length (e.g. 12) wrap
    // the encoded counter and repeat output blocks. Also, `2_usize.pow`
    // overflows for widths >= usize::BITS.
    let counter_bytes = counter_mode.counter_length / 8;
    let counter_bits = counter_bytes * 8;
    assert!(
        counter_bits <= usize::BITS as usize,
        "Counter length exceeds the width of usize"
    );
    // `n >> counter_bits == 0` is `n < 2^counter_bits`; a full-width counter
    // (shift == usize::BITS) can hold any `usize`.
    assert!(
        n.checked_shr(counter_bits as u32).unwrap_or(0) == 0,
        "Invalid derived key length"
    );

    if let InputType::FixedInput(fixed_input) = input_type {
        // Without a counter every PRF input is identical, so producing more
        // than one block would just repeat the first one.
        assert!(
            n <= 1 || !matches!(fixed_input.counter_location, CounterLocation::NoCounter),
            "CounterLocation::NoCounter cannot produce more than one block in KBKDF Counter Mode"
        );
    }

    for i in 1..=n {
        prf.init(key)?;
        let counter = i.to_be_bytes();
        // Keep the least-significant `counter_bytes` bytes of the big-endian index.
        let counter = &counter[(counter.len() - counter_bytes)..];
        match input_type {
            InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
                CounterLocation::NoCounter => prf.update(fixed_input.fixed_input)?,
                CounterLocation::BeforeFixedInput => {
                    prf.update(counter)?;
                    prf.update(fixed_input.fixed_input)?;
                }
                CounterLocation::MiddleOfFixedInput(position) => {
                    prf.update(&fixed_input.fixed_input[..position])?;
                    prf.update(counter)?;
                    prf.update(&fixed_input.fixed_input[position..])?;
                }
                CounterLocation::AfterFixedInput => {
                    prf.update(fixed_input.fixed_input)?;
                    prf.update(counter)?;
                }
                _ => panic!(
                    "Invalid counter location for KBKDF In Counter Mode: {:?}",
                    fixed_input.counter_location
                ),
            },
            InputType::SpecifiedInput(specified_input) => {
                prf.update(counter)?;
                prf.update(specified_input.label)?;
                prf.update(b"\0")?;
                prf.update(specified_input.context)?;
                let length = (l as u32).to_be_bytes();
                prf.update(&length)?;
            }
        }
        let _ = prf.finish(intermediate_key.as_mut_slice())?;
        insert_result(i, intermediate_key.as_slice(), derived_key);
        intermediate_key.zeroize();
    }
    Ok(())
}
/// Derives `derived_key` with KBKDF in double-pipeline iteration mode.
///
/// Each round `i` (from 1) runs the PRF twice:
///
/// 1. `A(i) = PRF(key, A(i-1))`. `A(0)` is the fixed input; for
///    [`InputType::SpecifiedInput`] that is `label || 0x00 || context || [L]_32`.
/// 2. `K(i) = PRF(key, ...)` over `A(i)`, the optional big-endian counter `[i]`
///    and the fixed input, arranged according to the counter location.
///
/// `K(1) || K(2) || ...` is truncated to `derived_key.len()` bytes. Both the
/// `A(i)` and the `K(i)` buffers are zeroized before returning, including when
/// the PRF reports an error.
///
/// # Panics
///
/// Panics in any of these cases:
///
/// - more than `2^32 - 1` PRF blocks are needed;
/// - the counter location is not valid for this mode;
/// - the location requires a counter but `counter_length` is `None`.
///
/// # Errors
///
/// Propagates the first error returned by the PRF.
fn kbkdf_double_pipeline<'a, T: PseudoRandomFunction<'a>>(
    double_feedback: &DoublePipelineIterationMode,
    input_type: &InputType,
    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
    prf: &mut T,
    derived_key: &mut [u8],
) -> Result<(), T::Error> {
    let l = derived_key.len() * 8;
    let h = T::PrfOutputSize::to_int() * 8;
    let n = calculate_counter(l, h);
    let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();
    let mut feedback = GenericArray::<u8, T::PrfOutputSize>::default();
    let length = (l as u32).to_be_bytes();
    // Equivalent to `n < 2^32`. It is written this way because
    // `2_usize.pow(32)` overflows when `usize` is 32 bits wide: it panics in
    // debug builds, and in release builds it wraps to 0 and rejects every request.
    assert!(
        (n as u64) <= u64::from(u32::MAX),
        "Invalid length provided for derived key"
    );

    // The rounds run inside a closure so that both buffers are wiped below
    // even when a PRF call fails part-way through.
    let mut run_rounds = || -> Result<(), T::Error> {
        for i in 1..=n {
            let counter = i.to_be_bytes();
            let counter = feedback_counter(double_feedback.counter_length, counter.as_slice());
            // First pipeline: A(i) = PRF(A(i-1)), with A(0) being the fixed input.
            prf.init(key)?;
            if i == 1 {
                match input_type {
                    InputType::FixedInput(fixed_input) => {
                        prf.update(fixed_input.fixed_input)?;
                    }
                    InputType::SpecifiedInput(specified_input) => {
                        prf.update(specified_input.label)?;
                        prf.update(b"\0")?;
                        prf.update(specified_input.context)?;
                        prf.update(length.as_slice())?;
                    }
                }
            } else {
                prf.update(feedback.as_slice())?;
            }
            let _ = prf.finish(feedback.as_mut_slice())?;
            // Second pipeline: K(i) over A(i), optional counter and fixed input.
            prf.init(key)?;
            match input_type {
                InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
                    CounterLocation::NoCounter => {
                        prf.update(feedback.as_slice())?;
                        prf.update(fixed_input.fixed_input)?;
                    }
                    CounterLocation::BeforeIter => {
                        prf.update(
                            counter.expect(
                                "Counter length not provided for BeforeIter counter location",
                            ),
                        )?;
                        prf.update(feedback.as_slice())?;
                        prf.update(fixed_input.fixed_input)?;
                    }
                    CounterLocation::AfterFixedInput => {
                        prf.update(feedback.as_slice())?;
                        prf.update(fixed_input.fixed_input)?;
                        prf.update(counter.expect(
                            "Counter length not provided for AfterFixedInput counter location",
                        ))?;
                    }
                    CounterLocation::AfterIter => {
                        prf.update(feedback.as_slice())?;
                        prf.update(
                            counter.expect(
                                "Counter length not provided for AfterIter counter location",
                            ),
                        )?;
                        prf.update(fixed_input.fixed_input)?;
                    }
                    _ => panic!(
                        "Invalid counter location for double feedback: {:?}",
                        fixed_input.counter_location
                    ),
                },
                InputType::SpecifiedInput(specified_input) => {
                    prf.update(feedback.as_slice())?;
                    if let Some(counter) = counter {
                        prf.update(counter)?;
                    }
                    prf.update(specified_input.label)?;
                    prf.update(b"\0")?;
                    prf.update(specified_input.context)?;
                    prf.update(&length)?;
                }
            }
            let _ = prf.finish(intermediate_key.as_mut_slice())?;
            insert_result(i, intermediate_key.as_slice(), derived_key);
            intermediate_key.zeroize();
        }
        Ok(())
    };
    let result = run_rounds();
    intermediate_key.zeroize();
    feedback.zeroize();
    result
}
/// Derives `derived_key` with KBKDF in feedback mode.
///
/// Block `K(i)` (for `i` from 1) is the PRF output over the previous block
/// `K(i-1)`, the optional big-endian counter `[i]`, and the fixed input. For
/// [`InputType::SpecifiedInput`] that is
/// `K(i-1) || [i] || label || 0x00 || context || [L]_32`. `K(0)` is
/// `feedback_mode.iv`; without an IV, the first invocation has no feedback
/// input.
///
/// The blocks are concatenated and truncated to `derived_key.len()` bytes.
/// The feedback buffer, which holds the last block, is zeroized before
/// returning, including when the PRF reports an error.
///
/// # Panics
///
/// Panics in any of these cases:
///
/// - the IV is not exactly `PrfOutputSize` bytes;
/// - more than `2^32 - 1` blocks are needed;
/// - the counter location is not valid for feedback mode;
/// - the location requires a counter but `counter_length` is `None`.
///
/// # Errors
///
/// Propagates the first error returned by the PRF.
fn kbkdf_feedback<'a, T: PseudoRandomFunction<'a>>(
    feedback_mode: &FeedbackMode,
    input_type: &InputType,
    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
    prf: &mut T,
    derived_key: &mut [u8],
) -> Result<(), T::Error> {
    let l = derived_key.len() * 8;
    let h = T::PrfOutputSize::to_int() * 8;
    let n = calculate_counter(l, h);
    // Holds K(i-1) between rounds, starting with the IV when one is given.
    let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();
    if let Some(iv) = feedback_mode.iv {
        assert_eq!(iv.len(), T::PrfOutputSize::to_int());
        intermediate_key.copy_from_slice(iv);
    }
    // Equivalent to `n < 2^32`. It is written this way because
    // `2_usize.pow(32)` overflows when `usize` is 32 bits wide: it panics in
    // debug builds, and in release builds it wraps to 0 and rejects every request.
    assert!(
        (n as u64) <= u64::from(u32::MAX),
        "Invalid derived_key length provided"
    );

    // The rounds run inside a closure so the feedback buffer is wiped below
    // even when a PRF call fails part-way through.
    let mut run_rounds = || -> Result<(), T::Error> {
        // Whether `intermediate_key` currently holds K(i-1) that must be fed back.
        let mut has_intermediate = feedback_mode.iv.is_some();
        for i in 1..=n {
            prf.init(key)?;
            let counter = i.to_be_bytes();
            let counter = feedback_counter(feedback_mode.counter_length, counter.as_slice());
            match input_type {
                InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
                    CounterLocation::NoCounter => {
                        if has_intermediate {
                            prf.update(intermediate_key.as_slice())?;
                        }
                        prf.update(fixed_input.fixed_input)?;
                    }
                    CounterLocation::BeforeIter => {
                        prf.update(
                            counter.expect(
                                "Counter length not provided for BeforeIter counter location",
                            ),
                        )?;
                        if has_intermediate {
                            prf.update(intermediate_key.as_slice())?;
                        }
                        prf.update(fixed_input.fixed_input)?;
                    }
                    CounterLocation::AfterIter => {
                        if has_intermediate {
                            prf.update(intermediate_key.as_slice())?;
                        }
                        prf.update(
                            counter.expect(
                                "Counter length not provided for AfterIter counter location",
                            ),
                        )?;
                        prf.update(fixed_input.fixed_input)?;
                    }
                    CounterLocation::AfterFixedInput => {
                        if has_intermediate {
                            prf.update(intermediate_key.as_slice())?;
                        }
                        prf.update(fixed_input.fixed_input)?;
                        prf.update(counter.expect(
                            "Counter length not provided for AfterFixedInput counter location",
                        ))?;
                    }
                    _ => panic!(
                        "Invalid counter location provided for KDF feedback mode: {:?}",
                        fixed_input.counter_location
                    ),
                },
                InputType::SpecifiedInput(specified_input) => {
                    if has_intermediate {
                        prf.update(intermediate_key.as_slice())?;
                    }
                    if let Some(counter) = counter {
                        prf.update(counter)?;
                    }
                    prf.update(specified_input.label)?;
                    prf.update(b"\0")?;
                    prf.update(specified_input.context)?;
                    let length = (l as u32).to_be_bytes();
                    prf.update(&length)?;
                }
            }
            let _ = prf.finish(intermediate_key.as_mut_slice())?;
            insert_result(i, intermediate_key.as_slice(), derived_key);
            has_intermediate = true;
        }
        Ok(())
    };
    let result = run_rounds();
    // The buffer still holds the final output block K(n); do not leave it behind.
    intermediate_key.zeroize();
    result
}
/// Returns how many PRF blocks are needed to cover the derived key, i.e.
/// `ceil(derived_key_len_bits / prf_output_size_in_bits)`.
///
/// # Panics
///
/// Panics if `prf_output_size_in_bits` is zero.
fn calculate_counter(derived_key_len_bits: usize, prf_output_size_in_bits: usize) -> usize {
    derived_key_len_bits.div_ceil(prf_output_size_in_bits)
}
/// Trims the big-endian block index `counter` to its last `counter_length / 8`
/// bytes, or returns `None` when no counter length is configured.
///
/// Panics if `counter_length / 8` exceeds `counter.len()`.
fn feedback_counter(counter_length: Option<usize>, counter: &[u8]) -> Option<&[u8]> {
    counter_length.map(move |bits| {
        let start = counter.len() - bits / 8;
        &counter[start..]
    })
}
/// Copies PRF output block number `counter` (1-based) into its slot in `result`.
///
/// Block `counter` starts at byte `(counter - 1) * intermediate.len()` of
/// `result`. The copy is truncated at the end of `result`, so only a prefix
/// of the final block is used when the derived key length is not a multiple
/// of the block size.
///
/// # Panics
///
/// Panics if the block would start at or beyond the end of `result`.
fn insert_result(counter: usize, intermediate: &[u8], result: &mut [u8]) {
    let low_index = (counter - 1) * intermediate.len();
    assert!(
        low_index < result.len(),
        "The starting insert index should not exceed bounds of result slice"
    );
    // `min` already clamps the end to `result.len()`, so no further bounds
    // check is needed (the former second assertion could never fail).
    let high_index = core::cmp::min(low_index + intermediate.len(), result.len());
    result[low_index..high_index].copy_from_slice(&intermediate[..high_index - low_index]);
}