use super::*;
use crate::group::{GroupEpoch, GroupId};
use openmls_traits::{key_store::OpenMlsKeyStore, random::OpenMlsRand, OpenMlsCryptoProvider};
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use tls_codec::{Serialize as TlsSerializeTrait, TlsByteVecU8, TlsVecU16};
/// Discriminates the kinds of pre-shared key (PSK) handled in this module.
///
/// The enum is `#[repr(u8)]` with explicit discriminants; the same numeric
/// values are the ones accepted by the `TryFrom<u8>` conversion for this type.
// NOTE(review): the derived TLS codec presumably encodes a variant as its
// `u8` discriminant — confirm against the `tls_codec` derive documentation.
#[derive(
Debug,
PartialEq,
Clone,
Copy,
Hash,
Serialize,
Deserialize,
TlsDeserialize,
TlsSerialize,
TlsSize,
)]
#[repr(u8)]
#[allow(missing_docs)]
pub enum PskType {
// Discriminant 1; corresponds to `Psk::External`.
External = 1,
// Discriminant 2; corresponds to `Psk::Reinit`.
Reinit = 2,
// Discriminant 3; corresponds to `Psk::Branch`.
Branch = 3,
}
impl TryFrom<u8> for PskType {
type Error = &'static str;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
1 => Ok(PskType::External),
2 => Ok(PskType::Reinit),
3 => Ok(PskType::Branch),
_ => Err("Unknown PSK type."),
}
}
}
/// Determines the [`PskType`] tag that matches a given [`Psk`] payload.
impl From<&Psk> for PskType {
    /// Returns the type tag with the same name as the variant of `psk`; the
    /// variant's payload is not inspected.
    fn from(psk: &Psk) -> Self {
        // Only the variant matters, so every payload is ignored with `_`.
        match *psk {
            Psk::External(_) => Self::External,
            Psk::Reinit(_) => Self::Reinit,
            Psk::Branch(_) => Self::Branch,
        }
    }
}
/// Payload of a `Psk::External`: a PSK identified by a byte string.
#[derive(
Debug, PartialEq, Clone, Hash, Serialize, Deserialize, TlsDeserialize, TlsSerialize, TlsSize,
)]
pub struct ExternalPsk {
// Identifier bytes of the PSK.
// NOTE(review): `TlsByteVecU8` presumably uses a one-byte length prefix,
// which would cap the identifier at 255 bytes — confirm.
psk_id: TlsByteVecU8,
}
impl ExternalPsk {
pub fn new(psk_id: Vec<u8>) -> Self {
Self {
psk_id: psk_id.into(),
}
}
pub fn psk_id(&self) -> &[u8] {
self.psk_id.as_slice()
}
}
/// Crate-internal container for the secret value of a PSK.
///
/// `PskSecret::new` reads values of this type from the backend's key store.
#[derive(Serialize, Deserialize)]
pub(crate) struct PskBundle {
// The PSK's secret value.
secret: Secret,
}
impl PskBundle {
    /// Wraps `secret` in a bundle.
    ///
    /// Only compiled for tests or with the `test-utils` feature. The body
    /// never fails; the `Result` return type is part of the existing
    /// signature.
    #[cfg(any(feature = "test-utils", test))]
    pub(crate) fn new(secret: Secret) -> Result<Self, CryptoError> {
        let bundle = Self { secret };
        Ok(bundle)
    }

    /// Borrows the secret stored in this bundle.
    pub(crate) fn secret(&self) -> &Secret {
        let Self { secret } = self;
        secret
    }
}
/// Payload of a `Psk::Reinit`: a group identifier together with an epoch.
#[derive(
Debug, PartialEq, Clone, Hash, Serialize, Deserialize, TlsDeserialize, TlsSerialize, TlsSize,
)]
pub struct ReinitPsk {
// Group identifier carried by this PSK.
pub(crate) psk_group_id: GroupId,
// Epoch carried by this PSK.
pub(crate) psk_epoch: GroupEpoch,
}
impl ReinitPsk {
    /// Borrows the group identifier stored in this PSK.
    pub fn psk_group_id(&self) -> &GroupId {
        let Self { psk_group_id, .. } = self;
        psk_group_id
    }

    /// Returns a copy of the epoch stored in this PSK.
    pub fn psk_epoch(&self) -> GroupEpoch {
        let Self { psk_epoch, .. } = self;
        *psk_epoch
    }
}
/// Payload of a `Psk::Branch`: a group identifier together with an epoch.
///
/// Structurally identical to `ReinitPsk`; the two are kept as distinct types.
#[derive(
Debug, PartialEq, Clone, Hash, Serialize, Deserialize, TlsDeserialize, TlsSerialize, TlsSize,
)]
pub struct BranchPsk {
// Group identifier carried by this PSK.
pub(crate) psk_group_id: GroupId,
// Epoch carried by this PSK.
pub(crate) psk_epoch: GroupEpoch,
}
impl BranchPsk {
    /// Borrows the group identifier stored in this PSK.
    pub fn psk_group_id(&self) -> &GroupId {
        let group_id = &self.psk_group_id;
        group_id
    }

    /// Returns a copy of the epoch stored in this PSK.
    pub fn psk_epoch(&self) -> GroupEpoch {
        let epoch = self.psk_epoch;
        epoch
    }
}
/// The type-specific payload of a PSK identifier.
///
/// There is one variant per `PskType`; `PskType::from(&psk)` yields the
/// matching type tag.
#[derive(Debug, PartialEq, Clone, Hash, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Psk {
// Payload for `PskType::External`.
External(ExternalPsk),
// Payload for `PskType::Reinit`.
Reinit(ReinitPsk),
// Payload for `PskType::Branch`.
Branch(BranchPsk),
}
/// Identifier of a pre-shared key: a type tag, the type-specific payload and
/// a nonce.
// NOTE(review): unlike the payload structs, this type has no TLS codec
// derives, yet `PskSecret::new` calls `tls_serialize_detached` on it — the
// codec implementation presumably lives elsewhere in the crate; confirm.
#[derive(Debug, PartialEq, Clone, Hash, Serialize, Deserialize)]
pub struct PreSharedKeyId {
// Type tag; `PreSharedKeyId::new` derives it from `psk`.
pub(crate) psk_type: PskType,
// Type-specific payload.
pub(crate) psk: Psk,
// Nonce bytes; `PreSharedKeyId::new` fills this with random bytes.
pub(crate) psk_nonce: TlsByteVecU8,
}
impl PreSharedKeyId {
    /// Builds an identifier for `psk` with a freshly generated nonce.
    ///
    /// The type tag is derived from the variant of `psk`. The nonce consists
    /// of `ciphersuite.hash_length()` bytes obtained from `rand`.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::InsufficientRandomness`] if `rand` fails to
    /// produce the nonce bytes; the underlying error is discarded.
    pub fn new(
        ciphersuite: Ciphersuite,
        rand: &impl OpenMlsRand,
        psk: Psk,
    ) -> Result<Self, CryptoError> {
        // The tag must be computed before `psk` is moved into the struct.
        let psk_type = PskType::from(&psk);
        let nonce_bytes = rand
            .random_vec(ciphersuite.hash_length())
            .map_err(|_| CryptoError::InsufficientRandomness)?;
        Ok(Self {
            psk_type,
            psk,
            psk_nonce: nonce_bytes.into(),
        })
    }

    /// Borrows the type tag of this identifier.
    pub fn psktype(&self) -> &PskType {
        let Self { psk_type, .. } = self;
        psk_type
    }

    /// Borrows the type-specific payload of this identifier.
    pub fn psk(&self) -> &Psk {
        let Self { psk, .. } = self;
        psk
    }

    /// Returns the nonce bytes as a slice.
    pub fn psk_nonce(&self) -> &[u8] {
        let Self { psk_nonce, .. } = self;
        psk_nonce.as_slice()
    }
}
/// A TLS-serializable list of `PreSharedKeyId`s.
#[derive(TlsDeserialize, TlsSerialize, TlsSize)]
pub struct PreSharedKeys {
// The identifiers, held in a `TlsVecU16`.
pub(crate) psks: TlsVecU16<PreSharedKeyId>,
}
impl PreSharedKeys {
    /// Returns the contained PSK identifiers as a slice.
    pub fn psks(&self) -> &[PreSharedKeyId] {
        let Self { psks } = self;
        psks.as_slice()
    }
}
/// Label structure that `PskSecret::new` TLS-serializes and passes to
/// `kdf_expand_label` for each PSK.
///
/// It borrows the identifier, so it cannot outlive the `PreSharedKeyId` it
/// refers to.
#[derive(TlsSerialize, TlsSize)]
pub(crate) struct PskLabel<'a> {
// The identifier of the PSK being processed.
pub(crate) id: &'a PreSharedKeyId,
// Position of `id` in the list of PSKs (as passed by `PskSecret::new`).
pub(crate) index: u16,
// Total number of PSKs in that list (as passed by `PskSecret::new`).
pub(crate) count: u16,
}
impl<'a> PskLabel<'a> {
fn new(id: &'a PreSharedKeyId, index: u16, count: u16) -> Self {
Self { id, index, count }
}
}
/// The secret that `PskSecret::new` derives from a list of PSKs.
///
/// The type derives no traits; in particular it has no `Debug` or `Clone`
/// implementation outside of test builds.
pub struct PskSecret {
// The derived secret value.
secret: Secret,
}
impl PskSecret {
pub fn new(
ciphersuite: Ciphersuite,
backend: &impl OpenMlsCryptoProvider,
psk_ids: &[PreSharedKeyId],
) -> Result<Self, PskError> {
let num_psks = psk_ids.len();
if num_psks > u16::MAX as usize {
return Err(PskError::TooManyKeys);
}
let num_psks = num_psks as u16;
let mut psk_bundles: Vec<PskBundle> = Vec::new();
for psk_id in psk_ids {
if let Some(psk_bundle) = backend.key_store().read(
&psk_id
.tls_serialize_detached()
.map_err(LibraryError::missing_bound_check)?,
) {
psk_bundles.push(psk_bundle);
} else {
return Err(PskError::KeyNotFound);
}
}
let mls_version = ProtocolVersion::default();
let mut psk_secret = Secret::zero(ciphersuite, mls_version);
for ((index, psk_bundle), psk_id) in psk_bundles.iter().enumerate().zip(psk_ids) {
let zero_secret = Secret::zero(ciphersuite, mls_version);
let psk_extracted = zero_secret
.hkdf_extract(backend, psk_bundle.secret())
.map_err(LibraryError::unexpected_crypto_error)?;
let psk_label = PskLabel::new(psk_id, index as u16, num_psks)
.tls_serialize_detached()
.map_err(LibraryError::missing_bound_check)?;
let psk_input = psk_extracted
.kdf_expand_label(
backend,
"derived psk",
&psk_label,
ciphersuite.hash_length(),
)
.map_err(LibraryError::unexpected_crypto_error)?;
psk_secret = psk_input
.hkdf_extract(backend, &psk_secret)
.map_err(LibraryError::unexpected_crypto_error)?;
}
Ok(Self { secret: psk_secret })
}
pub(crate) fn secret(&self) -> &Secret {
&self.secret
}
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn random(ciphersuite: Ciphersuite, rng: &impl OpenMlsCryptoProvider) -> Self {
Self {
secret: Secret::random(ciphersuite, rng, None )
.expect("Not enough randomness."),
}
}
#[cfg(any(feature = "test-utils", test))]
pub(crate) fn clone(&self) -> Self {
Self {
secret: self.secret.clone(),
}
}
}