use core::ops::Add;
use derive_where::derive_where;
use digest::{FixedOutput, Output, Update};
use generic_array::sequence::Concat;
use generic_array::typenum::Sum;
use generic_array::{ArrayLength, GenericArray};
use zeroize::Zeroize;
use crate::ciphersuite::{CipherSuite, KeGroup, KeHash, OprfGroup};
use crate::errors::ProtocolError;
use crate::hash::OutputSize;
use crate::key_exchange::group::Group;
use crate::key_exchange::shared::{Ke1MessageIter, Ke1MessageIterLen, NonceLen};
use crate::key_exchange::{
Deserialize, Serialize, SerializedContext, SerializedCredentialRequest,
SerializedCredentialRequestLen, SerializedCredentialResponse, SerializedCredentialResponseLen,
SerializedIdentifier, SerializedIdentifiers,
};
use crate::opaque::MaskedResponseLen;
use crate::serialization::{SliceExt, UpdateExt};
/// Transcript data from which both the message to sign and the message to
/// verify are derived (see `sign_message` and `hash`).
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(bound(deserialize = "'de: 'a", serialize = ""))
)]
#[derive_where(Clone, Debug, Eq, Hash, PartialEq, ZeroizeOnDrop)]
pub struct Message<'a, CS: CipherSuite, KE: Group> {
/// Party this message belongs to; together with the stage (sign or verify)
/// it selects which transcript variant is produced.
pub(super) role: Role,
/// Serialized context, emitted at the start of every transcript.
pub(super) context: SerializedContext<'a>,
/// Serialized client and server identifiers; exactly one of them is used
/// per transcript.
pub(super) identifiers: SerializedIdentifiers<'a, KeGroup<CS>>,
/// Remaining transcript components; this type carries no `'a` borrow.
pub(super) cache: CachedMessage<CS, KE>,
}
/// Transcript data for producing only the message to verify; unlike
/// `Message` it stores a single identifier instead of both.
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(bound(deserialize = "'de: 'a", serialize = ""))
)]
#[derive_where(Clone, Debug, Eq, Hash, PartialEq, ZeroizeOnDrop)]
pub struct VerifyMessage<'a, CS: CipherSuite, KE: Group> {
/// Party holding this message; the verify transcript is built for the
/// opposite role.
role: Role,
/// Serialized context, emitted at the start of the transcript.
context: SerializedContext<'a>,
/// The one serialized identifier placed into the verify transcript.
identifier: SerializedIdentifier<'a, KeGroup<CS>>,
/// Remaining transcript components; this type carries no `'a` borrow.
pub(super) cache: CachedMessage<CS, KE>,
}
/// Holds the role, context and a single identifier until a `CachedMessage`
/// is supplied to `build`, which then yields a `VerifyMessage`.
#[derive_where(Debug, Eq, Hash, PartialEq, ZeroizeOnDrop)]
pub struct MessageBuilder<'a, CS: CipherSuite> {
/// Role copied into the built `VerifyMessage`.
pub(super) role: Role,
/// Serialized context cloned into the built `VerifyMessage`.
pub(super) context: SerializedContext<'a>,
/// Serialized identifier cloned into the built `VerifyMessage`.
pub(super) identifier: SerializedIdentifier<'a, KeGroup<CS>>,
}
/// Lifetime-free transcript components. Fields are declared in the same
/// order in which they are serialized and deserialized.
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(bound = "")
)]
#[derive_where(Clone, Debug, Eq, Hash, PartialEq, Zeroize, ZeroizeOnDrop)]
pub struct CachedMessage<CS: CipherSuite, KE: Group> {
/// Serialized credential request; first fixed component of a transcript.
pub(super) credential_request: SerializedCredentialRequest<CS>,
/// Key-exchange message 1 parts; follows the credential request.
pub(super) ke1_message: Ke1MessageIter<KE>,
/// Serialized credential response.
pub(super) credential_response: SerializedCredentialResponse<CS>,
/// Server nonce bytes (`NonceLen` long).
pub(super) server_nonce: GenericArray<u8, NonceLen>,
/// Serialized server ephemeral public key (`KE::PkLen` long).
pub(super) server_e_pk: GenericArray<u8, KE::PkLen>,
/// Server MAC; only appended to the `Role::Client` transcript.
pub(super) server_mac: Output<KeHash<CS>>,
}
impl<CS: CipherSuite, KE: Group> Message<'_, CS, KE> {
    /// Yields the byte slices of the message to sign: the context parts
    /// followed by the sign-stage transcript.
    pub fn sign_message(&self) -> impl Clone + Iterator<Item = &[u8]> {
        let prefix = self.context.iter();
        prefix.chain(self.post_message(Stage::Sign))
    }

    /// Hashes the sign and verify transcripts with `KEH`.
    ///
    /// The context is absorbed once into a fresh hasher; that state is then
    /// cloned so both transcripts share the same prefix. The returned
    /// [`HashOutput`] holds the two resulting hasher states.
    pub fn hash<KEH: Default + Clone + FixedOutput + Update>(&self) -> HashOutput<KEH> {
        let mut prefix = KEH::default();
        prefix.update_iter(self.context.iter());

        // Fields are evaluated in written order: `sign` clones the prefix
        // state before `verify` consumes it.
        HashOutput {
            sign: prefix.clone().chain_iter(self.post_message(Stage::Sign)),
            verify: prefix.chain_iter(self.post_message(Stage::Verify)),
        }
    }

    /// Yields the transcript parts that follow the context for `stage`.
    ///
    /// Signing uses this message's own role; verifying uses the opposite
    /// role. The identifier matching the chosen role is passed along.
    fn post_message(&self, stage: Stage) -> impl Clone + Iterator<Item = &[u8]> {
        let transcript = match stage {
            Stage::Sign => self.role,
            Stage::Verify => match self.role {
                Role::Server => Role::Client,
                Role::Client => Role::Server,
            },
        };

        match transcript {
            Role::Server => self.cache.post_message(transcript, &self.identifiers.server),
            Role::Client => self.cache.post_message(transcript, &self.identifiers.client),
        }
    }

    /// Returns an owned copy of the cached transcript components.
    pub fn to_cached(&self) -> CachedMessage<CS, KE> {
        Clone::clone(&self.cache)
    }
}
impl<CS: CipherSuite, KE: Group> VerifyMessage<'_, CS, KE> {
    /// Yields the byte slices of the message to verify: the context parts
    /// followed by the transcript built for the role opposite to this
    /// message's own role.
    pub fn verify_message(&self) -> impl Clone + Iterator<Item = &[u8]> {
        let prefix = self.context.iter();
        let peer = match self.role {
            Role::Client => Role::Server,
            Role::Server => Role::Client,
        };
        prefix.chain(self.cache.post_message(peer, &self.identifier))
    }
}
impl<CS: CipherSuite, KE: Group> CachedMessage<CS, KE> {
    /// Yields the transcript parts that follow the context.
    ///
    /// Layout by `transcript` role:
    /// - `Role::Client`: identifier, credential request, KE1 message,
    ///   credential response, server nonce, server ephemeral key, server MAC.
    /// - `Role::Server`: credential request, KE1 message, identifier,
    ///   credential response, server nonce, server ephemeral key.
    fn post_message<'a>(
        &'a self,
        transcript: Role,
        identifier: &'a SerializedIdentifier<'_, KeGroup<CS>>,
    ) -> impl Clone + Iterator<Item = &'a [u8]> {
        let for_client = matches!(transcript, Role::Client);
        let for_server = matches!(transcript, Role::Server);

        // The identifier goes first for the client transcript and after the
        // KE1 message for the server transcript; the unused slot is empty.
        let leading_identifier = for_client
            .then_some(identifier.iter())
            .into_iter()
            .flatten();
        let trailing_identifier = for_server
            .then_some(identifier.iter())
            .into_iter()
            .flatten();
        // The server MAC is only part of the client transcript.
        let trailing_mac = for_client.then_some(self.server_mac.as_slice());

        leading_identifier
            .chain(self.credential_request.iter())
            .chain(self.ke1_message.iter())
            .chain(trailing_identifier)
            .chain(self.credential_response.iter())
            .chain([self.server_nonce.as_slice(), &self.server_e_pk])
            .chain(trailing_mac)
    }
}
/// Identifies a party. Used both to record whose message this is and to
/// select which party's transcript variant is produced.
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(bound = "")
)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(super) enum Role {
/// The server side; also the value a `Role` is reset to by `zeroize`.
Server,
/// The client side.
Client,
}
impl Zeroize for Role {
fn zeroize(&mut self) {
*self = Self::Server;
}
}
/// Selects whether a transcript is produced for signing or for verifying.
enum Stage {
/// Transcript for the message's own role.
Sign,
/// Transcript for the opposite role.
Verify,
}
/// Pair of values of type `H` returned by `Message::hash`, where `H` is the
/// hasher type after absorbing a transcript.
pub struct HashOutput<H> {
/// Hasher fed with the context followed by the sign transcript.
pub sign: H,
/// Hasher fed with the context followed by the verify transcript.
pub verify: H,
}
impl<'a, CS: CipherSuite> MessageBuilder<'a, CS> {
    /// Combines this builder's role, context and identifier with `cache`
    /// into a [`VerifyMessage`].
    ///
    /// The context and identifier are cloned rather than moved because the
    /// builder is zeroized on drop and so its fields cannot be moved out.
    pub fn build<KE: Group>(self, cache: CachedMessage<CS, KE>) -> VerifyMessage<'a, CS, KE> {
        let Self {
            role,
            context,
            identifier,
        } = &self;

        VerifyMessage {
            role: *role,
            context: context.clone(),
            identifier: identifier.clone(),
            cache,
        }
    }
}
impl<CS: CipherSuite, KE: Group> Deserialize for CachedMessage<CS, KE> {
fn deserialize_take(input: &mut &[u8]) -> Result<Self, ProtocolError> {
Ok(Self {
credential_request: SerializedCredentialRequest::deserialize_take(input)?,
ke1_message: Ke1MessageIter::deserialize_take(input)?,
credential_response: SerializedCredentialResponse::deserialize_take(input)?,
server_nonce: input.take_array("server nonce")?,
server_e_pk: input.take_array("serialized server ephemeral key")?,
server_mac: input.take_array("server mac")?,
})
}
}
/// Type-level serialized length of `CachedMessage`: the left-nested sum of
/// its six component lengths, in field order (credential request, KE1
/// message, credential response, nonce, ephemeral public key, hash output).
type CachedMessageLen<CS: CipherSuite, KE: Group> = Sum<
Sum<
Sum<
Sum<
Sum<SerializedCredentialRequestLen<CS>, Ke1MessageIterLen<KE>>,
SerializedCredentialResponseLen<CS>,
>,
NonceLen,
>,
KE::PkLen,
>,
OutputSize<KeHash<CS>>,
>;
impl<CS: CipherSuite, KE: Group> Serialize for CachedMessage<CS, KE>
where
    // Each partial sum must itself be an array length that can be extended by
    // the next component, mirroring the concatenation steps in `serialize`.
    SerializedCredentialRequestLen<CS>: ArrayLength<u8> + Add<Ke1MessageIterLen<KE>>,
    Sum<SerializedCredentialRequestLen<CS>, Ke1MessageIterLen<KE>>:
        ArrayLength<u8> + Add<SerializedCredentialResponseLen<CS>>,
    Sum<
        Sum<SerializedCredentialRequestLen<CS>, Ke1MessageIterLen<KE>>,
        SerializedCredentialResponseLen<CS>,
    >: ArrayLength<u8> + Add<NonceLen>,
    Sum<
        Sum<
            Sum<SerializedCredentialRequestLen<CS>, Ke1MessageIterLen<KE>>,
            SerializedCredentialResponseLen<CS>,
        >,
        NonceLen,
    >: ArrayLength<u8> + Add<KE::PkLen>,
    Sum<
        Sum<
            Sum<
                Sum<SerializedCredentialRequestLen<CS>, Ke1MessageIterLen<KE>>,
                SerializedCredentialResponseLen<CS>,
            >,
            NonceLen,
        >,
        KE::PkLen,
    >: ArrayLength<u8> + Add<OutputSize<KeHash<CS>>>,
    CachedMessageLen<CS, KE>: ArrayLength<u8>,
    NonceLen: Add<KE::PkLen>,
    Ke1MessageIterLen<KE>: ArrayLength<u8>,
    <OprfGroup<CS> as voprf::Group>::ElemLen: Add<NonceLen>,
    Sum<<OprfGroup<CS> as voprf::Group>::ElemLen, NonceLen>:
        ArrayLength<u8> + Add<MaskedResponseLen<CS>>,
    SerializedCredentialResponseLen<CS>: ArrayLength<u8>,
{
    type Len = CachedMessageLen<CS, KE>;

    /// Serializes all components, in field order, into one fixed-length
    /// array; this is the layout `deserialize_take` reads back.
    fn serialize(&self) -> GenericArray<u8, Self::Len> {
        // Grow the buffer one component at a time; every step yields a
        // longer array type.
        let bytes = self.credential_request.serialize();
        let bytes = bytes.concat(self.ke1_message.serialize());
        let bytes = bytes.concat(self.credential_response.serialize());
        let bytes = bytes.concat(self.server_nonce);
        let bytes = bytes.concat(self.server_e_pk.clone());
        bytes.concat(self.server_mac.clone())
    }
}