use crate::{kem::Kem as KemTrait, HpkeError};
/// A preshared key (PSK) paired with its identifier, as carried by the `Psk` and `AuthPsk`
/// operation modes.
///
/// The bundle only borrows both byte slices for the lifetime `'a`; it never copies them.
#[derive(Clone, Copy)]
pub struct PskBundle<'a> {
    /// The preshared key bytes.
    psk: &'a [u8],
    /// The identifier associated with `psk`.
    psk_id: &'a [u8],
}

impl<'a> PskBundle<'a> {
    /// Builds a bundle from a preshared key and its identifier.
    ///
    /// # Errors
    ///
    /// Returns [`HpkeError::InvalidPskBundle`] when exactly one of `psk` and `psk_id` is empty.
    /// A pair that is both empty or both non-empty is accepted.
    pub fn new(psk: &'a [u8], psk_id: &'a [u8]) -> Result<Self, HpkeError> {
        // The two halves must agree on emptiness: either both present or both absent.
        let halves_agree = psk.is_empty() == psk_id.is_empty();
        if !halves_agree {
            return Err(HpkeError::InvalidPskBundle);
        }
        Ok(Self { psk, psk_id })
    }
}
/// Operation mode for the receiving side.
///
/// The authenticated variants hold only a public key (the sender's identity key, as exposed by
/// `get_pk_sender_id`), never a private key.
pub enum OpModeR<'a, Kem: KemTrait> {
    /// No preshared key and no sender key.
    Base,
    /// A preshared key bundle and no sender key.
    Psk(PskBundle<'a>),
    /// The sender's identity public key and no preshared key.
    Auth(Kem::PublicKey),
    /// Both the sender's identity public key and a preshared key bundle.
    AuthPsk(Kem::PublicKey, PskBundle<'a>),
}

impl<Kem: KemTrait> OpModeR<'_, Kem> {
    /// Returns the sender's identity public key for `Auth` and `AuthPsk`, or `None` for `Base`
    /// and `Psk`.
    pub(crate) fn get_pk_sender_id(&self) -> Option<&Kem::PublicKey> {
        // Exhaustive on purpose (no `_` arm): adding a mode must force a decision here rather
        // than silently reporting "no sender key".
        match self {
            OpModeR::Auth(pk) | OpModeR::AuthPsk(pk, _) => Some(pk),
            OpModeR::Base | OpModeR::Psk(_) => None,
        }
    }
}
/// Operation mode for the sending side.
///
/// The authenticated variants carry the sender's identity key pair as
/// `(private key, public key)`.
pub enum OpModeS<'a, Kem: KemTrait> {
    /// No preshared key and no sender key pair.
    Base,
    /// A preshared key bundle and no sender key pair.
    Psk(PskBundle<'a>),
    /// The sender's identity key pair and no preshared key.
    Auth((Kem::PrivateKey, Kem::PublicKey)),
    /// Both the sender's identity key pair and a preshared key bundle.
    AuthPsk((Kem::PrivateKey, Kem::PublicKey), PskBundle<'a>),
}

impl<Kem: KemTrait> OpModeS<'_, Kem> {
    /// Returns borrowed `(private key, public key)` of the sender's identity key pair for `Auth`
    /// and `AuthPsk`, or `None` for `Base` and `Psk`.
    pub(crate) fn get_sender_id_keypair(&self) -> Option<(&Kem::PrivateKey, &Kem::PublicKey)> {
        // Exhaustive on purpose (no `_` arm): adding a mode must force a decision here rather
        // than silently reporting "no sender key pair".
        match self {
            OpModeS::Auth((sk, pk)) | OpModeS::AuthPsk((sk, pk), _) => Some((sk, pk)),
            OpModeS::Base | OpModeS::Psk(_) => None,
        }
    }
}
/// Common, read-only view over an operation mode (sending or receiving side).
pub(crate) trait OpMode<Kem>
where
    Kem: KemTrait,
{
    /// The one-byte identifier of this mode.
    fn mode_id(&self) -> u8;

    /// The preshared key bytes; the implementations in this module return an empty slice
    /// when the mode carries no preshared key.
    fn get_psk_bytes(&self) -> &[u8];

    /// The preshared key identifier; the implementations in this module return an empty
    /// slice when the mode carries no preshared key.
    fn get_psk_id(&self) -> &[u8];
}
impl<Kem: KemTrait> OpMode<Kem> for OpModeR<'_, Kem> {
    /// Returns 0x00 for `Base`, 0x01 for `Psk`, 0x02 for `Auth`, 0x03 for `AuthPsk`.
    fn mode_id(&self) -> u8 {
        match self {
            OpModeR::Base => 0x00,
            OpModeR::Psk(..) => 0x01,
            OpModeR::Auth(..) => 0x02,
            OpModeR::AuthPsk(..) => 0x03,
        }
    }

    /// Returns the preshared key for `Psk` and `AuthPsk`, or an empty slice otherwise.
    fn get_psk_bytes(&self) -> &[u8] {
        match self {
            OpModeR::Psk(bundle) | OpModeR::AuthPsk(_, bundle) => bundle.psk,
            // Variants listed explicitly (no `_`) so a newly added mode cannot silently be
            // treated as carrying an empty PSK.
            OpModeR::Base | OpModeR::Auth(..) => &[],
        }
    }

    /// Returns the preshared key identifier for `Psk` and `AuthPsk`, or an empty slice
    /// otherwise.
    fn get_psk_id(&self) -> &[u8] {
        match self {
            OpModeR::Psk(bundle) | OpModeR::AuthPsk(_, bundle) => bundle.psk_id,
            // See `get_psk_bytes`: exhaustive on purpose.
            OpModeR::Base | OpModeR::Auth(..) => &[],
        }
    }
}
impl<Kem: KemTrait> OpMode<Kem> for OpModeS<'_, Kem> {
    /// Returns 0x00 for `Base`, 0x01 for `Psk`, 0x02 for `Auth`, 0x03 for `AuthPsk`.
    fn mode_id(&self) -> u8 {
        match self {
            OpModeS::Base => 0x00,
            OpModeS::Psk(..) => 0x01,
            OpModeS::Auth(..) => 0x02,
            OpModeS::AuthPsk(..) => 0x03,
        }
    }

    /// Returns the preshared key for `Psk` and `AuthPsk`, or an empty slice otherwise.
    fn get_psk_bytes(&self) -> &[u8] {
        match self {
            OpModeS::Psk(bundle) | OpModeS::AuthPsk(_, bundle) => bundle.psk,
            // Variants listed explicitly (no `_`) so a newly added mode cannot silently be
            // treated as carrying an empty PSK.
            OpModeS::Base | OpModeS::Auth(..) => &[],
        }
    }

    /// Returns the preshared key identifier for `Psk` and `AuthPsk`, or an empty slice
    /// otherwise.
    fn get_psk_id(&self) -> &[u8] {
        match self {
            OpModeS::Psk(bundle) | OpModeS::AuthPsk(_, bundle) => bundle.psk_id,
            // See `get_psk_bytes`: exhaustive on purpose.
            OpModeS::Base | OpModeS::Auth(..) => &[],
        }
    }
}
/// `PskBundle::new` accepts a PSK and PSK ID only when both are empty or both are non-empty,
/// stores the given slices unchanged, and reports mismatches as `InvalidPskBundle`.
#[test]
fn psk_bundle_validation() {
    // Both present: accepted, and the slices are stored exactly as given.
    assert!(matches!(
        PskBundle::new(b"hello", b"world"),
        Ok(bundle) if bundle.psk == b"hello" && bundle.psk_id == b"world"
    ));

    // Both absent: accepted.
    assert!(matches!(
        PskBundle::new(b"", b""),
        Ok(bundle) if bundle.psk.is_empty() && bundle.psk_id.is_empty()
    ));

    // Exactly one present: rejected with the dedicated error variant.
    assert!(matches!(
        PskBundle::new(b"hello", b""),
        Err(HpkeError::InvalidPskBundle)
    ));
    assert!(matches!(
        PskBundle::new(b"", b"world"),
        Err(HpkeError::InvalidPskBundle)
    ));
}