use super::algorithm::{Hash, KeyDerivation, RawKeyAgreement};
use super::key::Id;
#[cfg(feature = "operations")]
use super::status::{Error, Result, Status};
#[cfg(feature = "operations")]
use core::convert::{From, TryFrom};
/// A complete description of a key derivation: the algorithm with all of its inputs,
/// plus an optional capacity for the resulting PSA operation.
#[derive(Clone, Copy, Debug)]
pub struct Operation<'data> {
    /// The derivation algorithm together with the inputs fed into it.
    pub inputs: Inputs<'data>,
    /// When `Some`, this value is handed to `psa_key_derivation_set_capacity` while the
    /// PSA operation is being built; when `None`, the capacity is left untouched.
    pub capacity: Option<usize>,
}
/// A key derivation algorithm bundled with every input it consumes.
///
/// Each variant carries the hash algorithm used by the derivation plus the data that is
/// supplied to the PSA operation as individual derivation steps.
#[derive(Clone, Copy, Debug)]
pub enum Inputs<'data> {
    /// HKDF derivation.
    Hkdf {
        /// Hash function the derivation is parameterised with.
        hash_alg: Hash,
        /// Optional salt; when `None`, no salt step is supplied at all.
        salt: Option<Input<'data>>,
        /// The secret input (raw input or the output of a key agreement).
        secret: InputSecret<'data>,
        /// The info input.
        info: Input<'data>,
    },
    /// TLS 1.2 PRF derivation.
    Tls12Prf {
        /// Hash function the derivation is parameterised with.
        hash_alg: Hash,
        /// The seed input.
        seed: Input<'data>,
        /// The secret input (raw input or the output of a key agreement).
        secret: InputSecret<'data>,
        /// The label input.
        label: Input<'data>,
    },
    /// TLS 1.2 PSK-to-master-secret derivation.
    Tls12PskToMs {
        /// Hash function the derivation is parameterised with.
        hash_alg: Hash,
        /// The seed input.
        seed: Input<'data>,
        /// The secret input (raw input or the output of a key agreement).
        secret: InputSecret<'data>,
        /// The label input.
        label: Input<'data>,
    },
}
/// Identifies which kind of input a derivation step supplies; converted into the
/// matching `PSA_KEY_DERIVATION_INPUT_*` constant when talking to the PSA API.
#[cfg(feature = "operations")]
#[derive(Clone, Copy, Debug)]
enum DerivationStep {
    /// The secret input.
    Secret,
    /// The label input.
    Label,
    /// The salt input.
    Salt,
    /// The info input.
    Info,
    /// The seed input.
    Seed,
}
#[cfg(feature = "operations")]
impl From<DerivationStep> for psa_crypto_sys::psa_key_derivation_step_t {
    /// Translates a [`DerivationStep`] into the PSA step constant of the same name.
    fn from(step: DerivationStep) -> Self {
        match step {
            DerivationStep::Salt => psa_crypto_sys::PSA_KEY_DERIVATION_INPUT_SALT,
            DerivationStep::Seed => psa_crypto_sys::PSA_KEY_DERIVATION_INPUT_SEED,
            DerivationStep::Secret => psa_crypto_sys::PSA_KEY_DERIVATION_INPUT_SECRET,
            DerivationStep::Info => psa_crypto_sys::PSA_KEY_DERIVATION_INPUT_INFO,
            DerivationStep::Label => psa_crypto_sys::PSA_KEY_DERIVATION_INPUT_LABEL,
        }
    }
}
#[cfg(feature = "interface")]
impl From<Inputs<'_>> for psa_crypto_sys::psa_algorithm_t {
    /// Returns the PSA algorithm identifier of the derivation described by `inputs`.
    ///
    /// The inputs themselves are ignored; only the algorithm (and its hash) matter.
    fn from(inputs: Inputs) -> Self {
        let derivation: KeyDerivation = inputs.key_derivation();
        derivation.into()
    }
}
impl Inputs<'_> {
    /// Returns the [`KeyDerivation`] algorithm these inputs belong to.
    ///
    /// Only the algorithm and its hash are carried over; the actual input data (salt, seed,
    /// secret, label, info) is not part of the result.
    pub fn key_derivation(&self) -> KeyDerivation {
        match self {
            Inputs::Hkdf { hash_alg, .. } => KeyDerivation::Hkdf {
                hash_alg: *hash_alg,
            },
            Inputs::Tls12Prf { hash_alg, .. } => KeyDerivation::Tls12Prf {
                hash_alg: *hash_alg,
            },
            Inputs::Tls12PskToMs { hash_alg, .. } => KeyDerivation::Tls12PskToMs {
                hash_alg: *hash_alg,
            },
        }
    }

    /// Supplies every input held by `self` to the operation `op`.
    ///
    /// The caller is expected to have already set `op` up with `psa_key_derivation_setup`.
    /// Inputs are passed in a fixed order:
    /// * HKDF: salt (only if present), secret, info;
    /// * TLS 1.2 PRF / PSK-to-MS: seed, secret, label.
    ///
    /// The PSA Crypto API constrains the order in which derivation inputs may be provided,
    /// so this sequence should not be rearranged.
    ///
    /// # Errors
    /// Returns the first error reported by an input step; later steps are not attempted.
    #[cfg(feature = "operations")]
    pub(crate) fn apply_inputs_to_op(
        &self,
        op: &mut psa_crypto_sys::psa_key_derivation_operation_t,
    ) -> Result<()> {
        match self {
            Inputs::Hkdf {
                salt, secret, info, ..
            } => {
                if let Some(salt) = salt {
                    Inputs::apply_input_step_to_op(op, DerivationStep::Salt, salt)?;
                }
                Inputs::apply_input_secret_step_to_op(op, secret)?;
                Inputs::apply_input_step_to_op(op, DerivationStep::Info, info)
            }
            Inputs::Tls12Prf {
                seed,
                secret,
                label,
                ..
            }
            | Inputs::Tls12PskToMs {
                seed,
                secret,
                label,
                ..
            } => {
                Inputs::apply_input_step_to_op(op, DerivationStep::Seed, seed)?;
                Inputs::apply_input_secret_step_to_op(op, secret)?;
                Inputs::apply_input_step_to_op(op, DerivationStep::Label, label)
            }
        }
    }

    /// Provides a single [`Input`] to `op` as the given derivation `step`.
    ///
    /// Raw bytes go through `psa_key_derivation_input_bytes`. A key goes through
    /// `psa_key_derivation_input_key`, using a handle that is opened just for this call.
    ///
    /// # Errors
    /// Fails if opening the key handle, the input call itself, or closing the handle fails.
    /// The handle is closed even when the input call fails; in that case the input error is
    /// the one returned.
    #[cfg(feature = "operations")]
    fn apply_input_step_to_op(
        op: &mut psa_crypto_sys::psa_key_derivation_operation_t,
        step: DerivationStep,
        input: &Input,
    ) -> Result<()> {
        match input {
            Input::Bytes(bytes) => Status::from(unsafe {
                // SAFETY: `op` is an exclusive reference valid for the whole call, and
                // `bytes` is a live slice, so the pointer/length pair describes readable
                // memory.
                psa_crypto_sys::psa_key_derivation_input_bytes(
                    op,
                    step.into(),
                    bytes.as_ptr(),
                    bytes.len(),
                )
            })
            .to_result(),
            Input::Key(key_id) => {
                let handle = key_id.handle()?;
                let input_res = Status::from(unsafe {
                    // SAFETY: `op` is an exclusive reference valid for the whole call and
                    // `handle` was just obtained from `key_id.handle()`.
                    psa_crypto_sys::psa_key_derivation_input_key(op, step.into(), handle)
                })
                .to_result();
                // Always close the handle, even if the input step failed; returning early
                // on the input error would leak the handle. The input error takes
                // precedence over a close error.
                let close_handle_res = key_id.close_handle(handle);
                input_res?;
                close_handle_res
            }
        }
    }

    /// Provides the secret input of a derivation to `op`.
    ///
    /// A plain [`Input`] is forwarded as the `Secret` step. A key agreement is performed
    /// through `psa_key_derivation_key_agreement` with `private_key` and the raw `peer_key`
    /// bytes, and its output becomes the secret.
    ///
    /// # Errors
    /// Fails if the underlying input / key agreement call fails, or if opening or closing
    /// the private key handle fails. The handle is closed even when the key agreement fails;
    /// in that case the key agreement error is the one returned.
    #[cfg(feature = "operations")]
    fn apply_input_secret_step_to_op(
        op: &mut psa_crypto_sys::psa_key_derivation_operation_t,
        secret: &InputSecret,
    ) -> Result<()> {
        match secret {
            InputSecret::Input(input) => {
                Inputs::apply_input_step_to_op(op, DerivationStep::Secret, input)
            }
            InputSecret::KeyAgreement {
                private_key,
                peer_key,
                ..
            } => {
                let handle = private_key.handle()?;
                let key_agreement_res = Status::from(unsafe {
                    // SAFETY: `op` is an exclusive reference valid for the whole call,
                    // `handle` was just opened, and `peer_key` is a live slice so the
                    // pointer/length pair describes readable memory.
                    psa_crypto_sys::psa_key_derivation_key_agreement(
                        op,
                        DerivationStep::Secret.into(),
                        handle,
                        peer_key.as_ptr(),
                        peer_key.len(),
                    )
                })
                .to_result();
                // Close the handle before propagating any key agreement error so it is
                // never leaked.
                let close_handle_res = private_key.close_handle(handle);
                key_agreement_res?;
                close_handle_res
            }
        }
    }
}
/// A single (non key-agreement) input to a key derivation.
#[derive(Clone, Copy, Debug)]
pub enum Input<'buf> {
    /// Raw bytes, supplied through `psa_key_derivation_input_bytes`.
    Bytes(&'buf [u8]),
    /// A stored key, referenced by its identifier. It is supplied through
    /// `psa_key_derivation_input_key` via a handle that is opened and closed around the call.
    Key(Id),
}
/// Where the secret input of a key derivation comes from.
#[derive(Clone, Copy, Debug)]
pub enum InputSecret<'buf> {
    /// The secret is given directly, as bytes or as a stored key.
    Input(Input<'buf>),
    /// The secret is the output of a key agreement performed during the derivation.
    KeyAgreement {
        /// Raw key agreement algorithm. It is combined with the derivation algorithm via
        /// `PSA_ALG_KEY_AGREEMENT` when the operation is set up.
        alg: RawKeyAgreement,
        /// Identifier of our private key taking part in the agreement.
        private_key: Id,
        /// The peer's key, passed as raw bytes to `psa_key_derivation_key_agreement`
        /// (expected encoding is defined by the PSA API — not checked here).
        peer_key: &'buf [u8],
    },
}
impl<'a> From<Input<'a>> for InputSecret<'a> {
fn from(input: Input<'a>) -> Self {
InputSecret::<'a>::Input(input)
}
}
#[cfg(feature = "operations")]
impl TryFrom<Operation<'_>> for psa_crypto_sys::psa_key_derivation_operation_t {
    type Error = Error;

    /// Builds a fully configured PSA key derivation operation from `operation`.
    ///
    /// Steps, in order:
    /// 1. set up the operation with the derivation algorithm, wrapped in
    ///    `PSA_ALG_KEY_AGREEMENT` when the secret comes from a key agreement;
    /// 2. supply all inputs (see `Inputs::apply_inputs_to_op`);
    /// 3. set the capacity, if one was requested.
    ///
    /// # Errors
    /// If any step fails, the partially configured operation is aborted and the step's error
    /// is returned. If the abort itself fails, the abort error is returned instead.
    fn try_from(operation: Operation) -> Result<Self> {
        let mut op = psa_crypto_sys::psa_key_derivation_operation_init();
        let mut setup_deriv_op = || -> Result<()> {
            let mut key_derivation_alg: psa_crypto_sys::psa_algorithm_t =
                operation.inputs.key_derivation().into();
            let secret = match operation.inputs {
                Inputs::Hkdf { secret, .. }
                | Inputs::Tls12Prf { secret, .. }
                | Inputs::Tls12PskToMs { secret, .. } => secret,
            };
            if let InputSecret::KeyAgreement { alg, .. } = secret {
                key_derivation_alg = unsafe {
                    // SAFETY: only plain integer algorithm identifiers are passed; no memory
                    // is shared with the callee.
                    psa_crypto_sys::PSA_ALG_KEY_AGREEMENT(alg.into(), key_derivation_alg)
                };
            }
            Status::from(unsafe {
                // SAFETY: `op` was initialised above and is exclusively borrowed here.
                psa_crypto_sys::psa_key_derivation_setup(&mut op, key_derivation_alg)
            })
            .to_result()?;
            operation.inputs.apply_inputs_to_op(&mut op)?;
            // Setting the capacity can fail too. It is done inside this closure so that a
            // failure here also aborts the already set-up operation instead of leaking it.
            if let Some(capacity) = operation.capacity {
                Status::from(unsafe {
                    // SAFETY: `op` is initialised, set up, and exclusively borrowed here.
                    psa_crypto_sys::psa_key_derivation_set_capacity(&mut op, capacity)
                })
                .to_result()?;
            }
            Ok(())
        };
        if let Err(error) = setup_deriv_op() {
            Operation::abort(op)?;
            return Err(error);
        }
        Ok(op)
    }
}
impl Operation<'_> {
    /// Aborts `op` by passing it to `psa_key_derivation_abort`.
    ///
    /// `op` is taken by value; the abort is applied to this function's own copy.
    ///
    /// # Errors
    /// Returns the error reported by `psa_key_derivation_abort`, if any.
    #[cfg(feature = "operations")]
    pub(crate) fn abort(mut op: psa_crypto_sys::psa_key_derivation_operation_t) -> Result<()> {
        // SAFETY: `op` is an owned operation object, exclusively borrowed for the call.
        let status = unsafe { psa_crypto_sys::psa_key_derivation_abort(&mut op) };
        Status::from(status).to_result()
    }
}