use core::marker::PhantomData;
use digest::Output;
use rand_core::CryptoRng;
use subtle::ConstantTimeEq;
use zeroize::Zeroize;
use crate::ciphersuite::{CipherSuite, Kem};
use crate::commitment;
use crate::error::Error;
use crate::responder::MessageTwo;
use crate::sas::{compute_sas, derive_session_key};
use crate::verification::ProtocolOutput;
use crate::Nonce;
/// First protocol message, sent by the initiator to the responder.
///
/// Carries the initiator's freshly generated KEM encapsulation key together
/// with a commitment computed by `commitment::commit` over that encapsulation
/// key and the initiator's nonce. The nonce itself is only revealed later, in
/// [`MessageThree`]; presumably the responder checks it against this
/// commitment (not visible here — confirm in the responder).
///
/// NOTE(review): with the `serde` feature enabled, serde derive encodes the
/// fields in declaration order, so reordering them may change the wire format.
#[derive(Clone)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(bound(
        serialize = "<CS::Kem as Kem>::EncapsulationKey: serde::Serialize",
        deserialize = "<CS::Kem as Kem>::EncapsulationKey: serde::Deserialize<'de>",
    ))
)]
pub struct MessageOne<CS: CipherSuite> {
    /// The initiator's KEM encapsulation key for this run.
    pub(crate) ek: <CS::Kem as Kem>::EncapsulationKey,
    /// Commitment over `ek` and the initiator's nonce, produced by
    /// `commitment::commit` with the suite's hash.
    pub(crate) commitment: Output<CS::Hash>,
}
/// Third protocol message, produced by [`Initiator::finish`].
///
/// Reveals the nonce that the initiator committed to in [`MessageOne`].
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MessageThree {
    /// The initiator's nonce, previously hidden behind the commitment sent in
    /// `MessageOne`.
    pub(crate) initiator_nonce: Nonce,
}
/// Initiator-side protocol state, kept between [`Initiator::start`] and
/// [`Initiator::finish`].
///
/// Holds the KEM decapsulation key, the matching encapsulation key and the
/// initiator's nonce. The type is not `Clone`, and `finish` takes it by value,
/// so one state value can complete at most one run. Its `Drop` impl zeroizes
/// all three fields.
///
/// NOTE(review): with the `serde` feature enabled this type derives
/// `Serialize`, and the output includes the decapsulation key. Any caller that
/// persists or transmits it must protect those bytes.
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(bound(
        serialize = "<CS::Kem as Kem>::DecapsulationKey: serde::Serialize, <CS::Kem as Kem>::EncapsulationKey: serde::Serialize",
        deserialize = "<CS::Kem as Kem>::DecapsulationKey: serde::Deserialize<'de>, <CS::Kem as Kem>::EncapsulationKey: serde::Deserialize<'de>",
    ))
)]
pub struct Initiator<CS: CipherSuite> {
    /// KEM decapsulation key; `finish` uses it to recover the shared secret.
    dk: <CS::Kem as Kem>::DecapsulationKey,
    /// Copy of the encapsulation key sent in `MessageOne`. It feeds the session
    /// key derivation and the reflection check in `finish`.
    ek: <CS::Kem as Kem>::EncapsulationKey,
    /// 32 random bytes drawn in `start`, committed to in `MessageOne` and
    /// revealed in `MessageThree`.
    initiator_nonce: Nonce,
    /// Ties this state to the cipher suite `CS`.
    _marker: PhantomData<CS>,
}
/// Wipes the initiator's key material and nonce when the state is dropped.
///
/// This runs both after `finish` (which consumes the state) and when a run is
/// abandoned before completion.
impl<CS: CipherSuite> Drop for Initiator<CS> {
    fn drop(&mut self) {
        // The pattern lists every field with no `..`. Adding a field to
        // `Initiator` then fails to compile here until it is handled.
        let Self {
            dk,
            ek,
            initiator_nonce,
            _marker: _,
        } = self;
        initiator_nonce.zeroize();
        dk.zeroize();
        ek.zeroize();
    }
}
impl<CS: CipherSuite> Initiator<CS> {
pub fn start(rng: &mut impl CryptoRng) -> (Self, MessageOne<CS>) {
let (dk, ek) = CS::Kem::generate(rng);
let mut initiator_nonce = [0u8; 32];
rng.fill_bytes(&mut initiator_nonce);
let commitment = commitment::commit::<CS::Hash>(ek.as_ref(), &initiator_nonce);
let state = Self {
dk,
ek: ek.clone(),
initiator_nonce,
_marker: PhantomData,
};
let message = MessageOne { ek, commitment };
(state, message)
}
pub fn finish(self, msg2: MessageTwo<CS>) -> Result<(ProtocolOutput<CS>, MessageThree), Error> {
if self.ek.as_ref().ct_eq(msg2.ct.as_ref()).into() {
return Err(Error::ReflectionDetected);
}
let mut kem_ss =
CS::Kem::decaps(&self.dk, &msg2.ct).map_err(|_| Error::DecapsulationFailed)?;
let session_key = derive_session_key::<CS::Hash>(
self.ek.as_ref(),
msg2.ct.as_ref(),
&msg2.responder_nonce,
&self.initiator_nonce,
kem_ss.as_ref(),
);
kem_ss.zeroize();
let sas = compute_sas::<CS::Hash>(
&msg2.responder_nonce,
&self.initiator_nonce,
msg2.ct.as_ref(),
);
let output = ProtocolOutput {
sas,
session_key,
_marker: PhantomData,
};
let message = MessageThree {
initiator_nonce: self.initiator_nonce,
};
Ok((output, message))
}
}