use std::borrow::Borrow;
use openmls_traits::{random::OpenMlsRand, storage::StorageProvider as StorageProviderTrait};
use serde::{Deserialize, Serialize};
use tls_codec::{Serialize as TlsSerializeTrait, VLBytes};
use super::*;
use crate::{
group::{GroupEpoch, GroupId},
schedule::psk::store::ResumptionPskStore,
storage::{OpenMlsProvider, StorageProvider},
};
/// How a resumption PSK is intended to be used.
///
/// The explicit discriminants (1, 2, 3) are part of the encoded form produced
/// by the TLS/serde derives below, so they must not be renumbered.
/// NOTE(review): these values appear to mirror the `ResumptionPSKUsage`
/// codepoints of RFC 9420 — confirm against the spec before changing.
#[derive(
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Deserialize,
    Serialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
#[repr(u8)]
pub enum ResumptionPskUsage {
    /// The only usage accepted by `PreSharedKeyId::validate_in_proposal`.
    Application = 1,
    /// Accepted in a Welcome; may not be combined with `Branch` or repeated.
    Reinit = 2,
    /// Accepted in a Welcome; may not be combined with `Reinit` or repeated.
    Branch = 3,
}
/// Identifier of an externally provisioned pre-shared key.
///
/// Only the identifier is held here; the secret itself is stored separately
/// (see `PreSharedKeyId::store`).
#[derive(
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Clone,
    Hash,
    Deserialize,
    Serialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
)]
pub struct ExternalPsk {
    // Raw identifier bytes, encoded as variable-length bytes by the TLS derive.
    psk_id: VLBytes,
}
impl ExternalPsk {
pub fn new(psk_id: Vec<u8>) -> Self {
Self {
psk_id: psk_id.into(),
}
}
pub fn psk_id(&self) -> &[u8] {
self.psk_id.as_slice()
}
}
/// Storage record for an external PSK secret.
///
/// Written by `PreSharedKeyId::store` and read back in `load_psks`; its
/// serialized form is what the storage provider persists.
#[derive(Serialize, Deserialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)]
pub(crate) struct PskBundle {
    secret: Secret,
}
/// Reference to a resumption PSK: the secret of group `psk_group_id` at
/// epoch `psk_epoch`, used for the purpose given by `usage`.
///
/// Field order is significant: it defines the TLS-serialized layout.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Deserialize,
    Serialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
    Hash,
)]
pub struct ResumptionPsk {
    pub(crate) usage: ResumptionPskUsage,
    pub(crate) psk_group_id: GroupId,
    pub(crate) psk_epoch: GroupEpoch,
}
impl ResumptionPsk {
pub fn new(usage: ResumptionPskUsage, psk_group_id: GroupId, psk_epoch: GroupEpoch) -> Self {
Self {
usage,
psk_group_id,
psk_epoch,
}
}
pub fn usage(&self) -> ResumptionPskUsage {
self.usage
}
pub fn psk_group_id(&self) -> &GroupId {
&self.psk_group_id
}
pub fn psk_epoch(&self) -> GroupEpoch {
self.psk_epoch
}
}
/// A pre-shared key reference: either an external PSK or a resumption PSK.
///
/// The `tls_codec` discriminants (1 = external, 2 = resumption) select the
/// variant on the wire and must stay stable.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Deserialize,
    Serialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
    Hash,
)]
#[repr(u8)]
pub enum Psk {
    /// A PSK provisioned out of band; its secret is held by the storage provider.
    #[tls_codec(discriminant = 1)]
    External(ExternalPsk),
    /// A PSK derived from a group's epoch; its secret is held in a `ResumptionPskStore`.
    #[tls_codec(discriminant = 2)]
    Resumption(ResumptionPsk),
}
/// Plain tag for the kind of a PSK, using the same numeric values as the
/// `Psk` variant discriminants (1 = external, 2 = resumption).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum PskType {
    External = 1,
    Resumption = 2,
}
/// Identifies a PSK together with a nonce.
///
/// The validation methods require `psk_nonce` to be exactly
/// `ciphersuite.hash_length()` bytes long. Field order defines the
/// TLS-serialized layout, which is also embedded in `PskLabel`.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Deserialize,
    Serialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSerialize,
    TlsSize,
    Hash,
)]
pub struct PreSharedKeyId {
    pub(crate) psk: Psk,
    pub(crate) psk_nonce: VLBytes,
}
impl PreSharedKeyId {
pub fn new(
ciphersuite: Ciphersuite,
rand: &impl OpenMlsRand,
psk: Psk,
) -> Result<Self, CryptoError> {
let psk_nonce = rand
.random_vec(ciphersuite.hash_length())
.map_err(|_| CryptoError::InsufficientRandomness)?
.into();
Ok(Self { psk, psk_nonce })
}
pub fn external(psk_id: Vec<u8>, psk_nonce: Vec<u8>) -> Self {
let psk = Psk::External(ExternalPsk::new(psk_id));
Self {
psk,
psk_nonce: psk_nonce.into(),
}
}
pub fn resumption(
usage: ResumptionPskUsage,
psk_group_id: GroupId,
psk_epoch: GroupEpoch,
psk_nonce: Vec<u8>,
) -> Self {
let psk = Psk::Resumption(ResumptionPsk::new(usage, psk_group_id, psk_epoch));
Self {
psk,
psk_nonce: psk_nonce.into(),
}
}
pub fn psk(&self) -> &Psk {
&self.psk
}
pub fn psk_nonce(&self) -> &[u8] {
self.psk_nonce.as_slice()
}
pub fn store<Provider: OpenMlsProvider>(
&self,
provider: &Provider,
psk: &[u8],
) -> Result<(), PskError> {
let psk_bundle = {
let secret = Secret::from_slice(psk);
PskBundle { secret }
};
provider
.storage()
.write_psk(&self.psk, &psk_bundle)
.map_err(|_| PskError::Storage)
}
pub(crate) fn validate_in_proposal(self, ciphersuite: Ciphersuite) -> Result<(), PskError> {
match self.psk() {
Psk::Resumption(resumption_psk) => {
if resumption_psk.usage != ResumptionPskUsage::Application {
return Err(PskError::UsageMismatch {
allowed: vec![ResumptionPskUsage::Application],
got: resumption_psk.usage,
});
}
}
Psk::External(_) => {}
};
{
let expected_nonce_length = ciphersuite.hash_length();
let got_nonce_length = self.psk_nonce().len();
if expected_nonce_length != got_nonce_length {
return Err(PskError::NonceLengthMismatch {
expected: expected_nonce_length,
got: got_nonce_length,
});
}
}
Ok(())
}
pub(crate) fn validate_in_welcome(
psk_ids: &[PreSharedKeyId],
ciphersuite: Ciphersuite,
) -> Result<(), PskError> {
let mut contains_branch_psk = false;
let mut contains_reinit_psk = false;
for id in psk_ids {
match id.psk() {
Psk::Resumption(resumption_psk) => match resumption_psk.usage {
ResumptionPskUsage::Application => {
return Err(PskError::UsageMismatch {
allowed: vec![ResumptionPskUsage::Reinit, ResumptionPskUsage::Branch],
got: resumption_psk.usage,
});
}
ResumptionPskUsage::Reinit => {
if contains_reinit_psk {
return Err(PskError::UsageDuplicate {
usage: ResumptionPskUsage::Reinit,
});
}
if contains_branch_psk {
return Err(PskError::UsageConflict {
first: ResumptionPskUsage::Reinit,
second: ResumptionPskUsage::Branch,
});
}
contains_reinit_psk = true;
}
ResumptionPskUsage::Branch => {
if contains_branch_psk {
return Err(PskError::UsageDuplicate {
usage: ResumptionPskUsage::Branch,
});
}
if contains_reinit_psk {
return Err(PskError::UsageConflict {
first: ResumptionPskUsage::Branch,
second: ResumptionPskUsage::Reinit,
});
}
contains_branch_psk = true;
}
},
Psk::External(_) => {}
};
{
let expected_nonce_length = ciphersuite.hash_length();
let got_nonce_length = id.psk_nonce().len();
if expected_nonce_length != got_nonce_length {
return Err(PskError::NonceLengthMismatch {
expected: expected_nonce_length,
got: got_nonce_length,
});
}
}
}
Ok(())
}
}
#[cfg(test)]
impl PreSharedKeyId {
    /// Test-only constructor that takes an explicit nonce instead of drawing one.
    pub(crate) fn new_with_nonce(psk: Psk, psk_nonce: Vec<u8>) -> Self {
        let psk_nonce = VLBytes::from(psk_nonce);
        PreSharedKeyId { psk, psk_nonce }
    }
}
/// Label input for deriving one PSK's contribution in `PskSecret::new`.
///
/// It is TLS-serialized and passed to `kdf_expand_label` with the label
/// "derived psk". The fields are serialized in declaration order, so they
/// must not be reordered.
#[derive(TlsSerialize, TlsSize)]
pub(crate) struct PskLabel<'a> {
    pub(crate) id: &'a PreSharedKeyId,
    // Zero-based position of this PSK in the list being combined.
    pub(crate) index: u16,
    // Total number of PSKs being combined.
    pub(crate) count: u16,
}
impl<'a> PskLabel<'a> {
fn new(id: &'a PreSharedKeyId, index: u16, count: u16) -> Self {
Self { id, index, count }
}
}
/// The combined secret derived from a list of PSKs by `PskSecret::new`.
///
/// With no PSKs this is the all-zero secret for the ciphersuite.
#[derive(Clone)]
pub struct PskSecret {
    secret: Secret,
}
impl PskSecret {
pub(crate) fn new(
crypto: &impl OpenMlsCrypto,
ciphersuite: Ciphersuite,
psks: Vec<(impl Borrow<PreSharedKeyId>, Secret)>,
) -> Result<Self, PskError> {
let num_psks = u16::try_from(psks.len()).map_err(|_| PskError::TooManyKeys)?;
let mut psk_secret = Secret::zero(ciphersuite);
for (index, (psk_id, psk)) in psks.into_iter().enumerate() {
let psk_extracted = {
let zero_secret = Secret::zero(ciphersuite);
zero_secret
.hkdf_extract(crypto, ciphersuite, &psk)
.map_err(LibraryError::unexpected_crypto_error)?
};
let psk_input = {
let psk_label = PskLabel::new(psk_id.borrow(), index as u16, num_psks)
.tls_serialize_detached()
.map_err(LibraryError::missing_bound_check)?;
psk_extracted
.kdf_expand_label(
crypto,
ciphersuite,
"derived psk",
&psk_label,
ciphersuite.hash_length(),
)
.map_err(LibraryError::unexpected_crypto_error)?
};
psk_secret = psk_input
.hkdf_extract(crypto, ciphersuite, &psk_secret)
.map_err(LibraryError::unexpected_crypto_error)?;
}
Ok(Self { secret: psk_secret })
}
pub(crate) fn secret(&self) -> &Secret {
&self.secret
}
#[cfg(any(feature = "test-utils", feature = "crypto-debug", test))]
pub(crate) fn as_slice(&self) -> &[u8] {
self.secret.as_slice()
}
}
#[cfg(any(feature = "test-utils", test))]
impl From<Secret> for PskSecret {
    /// Wraps an already-derived secret directly (test and test-utils builds only).
    fn from(secret: Secret) -> PskSecret {
        PskSecret { secret }
    }
}
/// Looks up the secret for each PSK ID, preserving input order.
///
/// Resumption PSKs are looked up in `resumption_psk_store` by epoch only.
/// External PSKs are read from `storage`.
/// NOTE(review): the resumption lookup ignores `psk_group_id` and `usage`;
/// presumably the store only ever holds this group's secrets — confirm.
///
/// # Errors
/// Returns [`PskError::KeyNotFound`] for the first ID whose secret is missing.
/// A failed storage read is also reported as `KeyNotFound`.
pub(crate) fn load_psks<'p, Storage: StorageProvider>(
    storage: &Storage,
    resumption_psk_store: &ResumptionPskStore,
    psk_ids: &'p [PreSharedKeyId],
) -> Result<Vec<(&'p PreSharedKeyId, Secret)>, PskError> {
    psk_ids
        .iter()
        .map(|psk_id| -> Result<(&'p PreSharedKeyId, Secret), PskError> {
            log_crypto!(trace, "PSK store {:?}", resumption_psk_store);
            let secret = match psk_id.psk() {
                Psk::Resumption(resumption) => resumption_psk_store
                    .get(resumption.psk_epoch())
                    .map(|resumption_secret| resumption_secret.secret.clone())
                    .ok_or(PskError::KeyNotFound)?,
                Psk::External(_) => {
                    let stored: Option<PskBundle> = storage
                        .psk(psk_id.psk())
                        .map_err(|_| PskError::KeyNotFound)?;
                    stored
                        .map(|bundle| bundle.secret)
                        .ok_or(PskError::KeyNotFound)?
                }
            };
            Ok((psk_id, secret))
        })
        // Stops at the first `Err`, like the early return of a manual loop.
        .collect()
}
pub mod store {
    use serde::{Deserialize, Serialize};

    use crate::{group::GroupEpoch, schedule::ResumptionPskSecret};

    /// Bounded store of resumption PSK secrets, tagged with their epoch.
    ///
    /// Entries are appended until `max_number_of_secrets` are held. After that,
    /// each new entry overwrites the oldest one, so the store always holds the
    /// most recently added `max_number_of_secrets` entries.
    #[derive(Debug, Serialize, Deserialize)]
    #[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
    pub(crate) struct ResumptionPskStore {
        /// Capacity of the store; `0` disables storing entirely.
        max_number_of_secrets: usize,
        /// Stored `(epoch, secret)` pairs.
        resumption_psk: Vec<(GroupEpoch, ResumptionPskSecret)>,
        /// Slot (taken modulo the number of entries) that the next `add`
        /// overwrites once the store is full; i.e. the oldest entry.
        cursor: usize,
    }

    impl ResumptionPskStore {
        /// Creates an empty store holding at most `max_number_of_secrets` entries.
        pub(crate) fn new(max_number_of_secrets: usize) -> Self {
            Self {
                max_number_of_secrets,
                resumption_psk: vec![],
                cursor: 0,
            }
        }

        /// Records the secret for `epoch`, evicting the oldest entry when full.
        ///
        /// This is a no-op when the capacity is `0`.
        pub(crate) fn add(&mut self, epoch: GroupEpoch, resumption_psk: ResumptionPskSecret) {
            if self.max_number_of_secrets == 0 {
                return;
            }
            let item = (epoch, resumption_psk);
            if self.resumption_psk.len() < self.max_number_of_secrets {
                self.resumption_psk.push(item);
                // Stays equal to the next free index, which becomes slot 0
                // (the oldest entry) once the store wraps.
                self.cursor += 1;
            } else {
                // `len >= max_number_of_secrets >= 1` here, so the modulo is safe.
                let len = self.resumption_psk.len();
                // Overwrite the slot the cursor points at (the oldest entry),
                // then advance. The old code advanced first, so the first wrap
                // evicted slot 1 and kept the oldest entry in slot 0.
                let slot = self.cursor % len;
                self.resumption_psk[slot] = item;
                self.cursor = (slot + 1) % len;
            }
        }

        /// Returns the first stored secret recorded for `epoch`, if any.
        pub(crate) fn get(&self, epoch: GroupEpoch) -> Option<&ResumptionPskSecret> {
            self.resumption_psk
                .iter()
                .find(|&(e, _s)| e == &epoch)
                .map(|(_e, s)| s)
        }
    }

    #[cfg(test)]
    impl ResumptionPskStore {
        /// Exposes the eviction cursor to tests.
        pub(crate) fn cursor(&self) -> usize {
            self.cursor
        }
    }
}