use openmls::group::GroupId;
use openmls_traits::OpenMlsCryptoProvider;
use tls_codec::{TlsByteVecU8, TlsDeserialize, TlsSerialize, TlsSize};
use crate::hash::Hash;
use crate::secret_group::lts::{
aead, LongTermSecretCiphersuite, LongTermSecretCiphertext, LongTermSecretEpoch,
LongTermSecretError,
};
/// Secret bound to a single group instance and a single long-term epoch.
///
/// The secret bytes are handed to the `aead` helpers by [`LongTermSecret::encrypt`] and
/// [`LongTermSecret::decrypt`].
///
/// NOTE(review): the TLS codec derives presumably encode the fields in declaration order,
/// so reordering them would change the serialized form — confirm before touching.
/// NOTE(review): `Debug` is derived, so formatting this struct prints `value` as well;
/// avoid sending it to logs.
#[derive(Debug, Clone, PartialEq, TlsDeserialize, TlsSerialize, TlsSize)]
pub struct LongTermSecret {
/// Group instance id, stored as an MLS `GroupId` built from the bytes of a `Hash`.
group_id: GroupId,
/// Ciphersuite passed to the `aead` helpers on every encrypt / decrypt call.
ciphersuite: LongTermSecretCiphersuite,
/// Epoch of this secret; `decrypt` refuses ciphertexts carrying a different epoch.
long_term_epoch: LongTermSecretEpoch,
/// Secret bytes passed to the `aead` helpers (the unit tests supply random bytes of the
/// ciphersuite's AEAD key length here).
value: TlsByteVecU8,
}
impl LongTermSecret {
    /// Creates a long-term secret for the given group instance and epoch.
    ///
    /// The bytes of `group_instance_id` are stored as an MLS `GroupId`; `value` holds the
    /// secret bytes that are later handed to the `aead` helpers.
    pub fn new(
        group_instance_id: Hash,
        ciphersuite: LongTermSecretCiphersuite,
        long_term_epoch: LongTermSecretEpoch,
        value: TlsByteVecU8,
    ) -> Self {
        Self {
            group_id: GroupId::from_slice(&group_instance_id.to_bytes()),
            ciphersuite,
            long_term_epoch,
            value,
        }
    }

    /// Returns the group instance id this secret belongs to.
    ///
    /// The stored group id bytes are hex-encoded and turned back into a `Hash`.
    ///
    /// # Errors
    ///
    /// Fails when `Hash::new` rejects the hex-encoded group id.
    pub fn group_instance_id(&self) -> Result<Hash, LongTermSecretError> {
        let hex_str = hex::encode(self.group_id.as_slice());
        Ok(Hash::new(&hex_str)?)
    }

    /// Returns the long-term epoch of this secret.
    pub fn long_term_epoch(&self) -> LongTermSecretEpoch {
        self.long_term_epoch
    }

    /// Returns a copy of the raw secret bytes (only compiled for tests).
    #[cfg(test)]
    pub(crate) fn value(&self) -> Vec<u8> {
        self.value.as_slice().to_vec()
    }

    /// Encrypts `data` with this secret using the given `nonce`.
    ///
    /// The group id bytes are passed as the last argument to `aead::encrypt`
    /// (presumably as associated data — confirm against the `aead` module). The returned
    /// ciphertext records the group instance id, the epoch, the output of `aead::encrypt`
    /// and a copy of the nonce.
    ///
    /// # Errors
    ///
    /// Propagates failures of `aead::encrypt` and of [`Self::group_instance_id`].
    pub fn encrypt(
        &self,
        provider: &impl OpenMlsCryptoProvider,
        nonce: &[u8],
        data: &[u8],
    ) -> Result<LongTermSecretCiphertext, LongTermSecretError> {
        let ciphertext_tag = aead::encrypt(
            provider,
            &self.ciphersuite,
            self.value.as_slice(),
            data,
            nonce,
            self.group_id.as_slice(),
        )?;

        Ok(LongTermSecretCiphertext::new(
            self.group_instance_id()?,
            self.long_term_epoch(),
            ciphertext_tag,
            nonce.to_vec(),
        ))
    }

    /// Decrypts `ciphertext` with this secret and returns the payload produced by
    /// `aead::decrypt`.
    ///
    /// The epoch and group instance id recorded in the ciphertext are checked against this
    /// secret before `aead::decrypt` is called.
    ///
    /// # Errors
    ///
    /// * `LongTermSecretError::EpochNotMatching` when the epochs differ.
    /// * `LongTermSecretError::GroupNotMatching` when the group instance ids differ.
    /// * Any error from reading either group instance id or from `aead::decrypt`.
    pub fn decrypt(
        &self,
        provider: &impl OpenMlsCryptoProvider,
        ciphertext: &LongTermSecretCiphertext,
    ) -> Result<Vec<u8>, LongTermSecretError> {
        let ciphertext_epoch = ciphertext.long_term_epoch();
        if ciphertext_epoch != self.long_term_epoch {
            return Err(LongTermSecretError::EpochNotMatching(
                self.long_term_epoch.0,
                ciphertext_epoch.0,
            ));
        }

        // Resolve each id exactly once: every lookup hex-encodes and re-validates a hash.
        // The ciphertext id is read first so a failure there is reported first.
        let ciphertext_group = ciphertext.group_instance_id()?;
        let own_group = self.group_instance_id()?;
        if ciphertext_group != own_group {
            return Err(LongTermSecretError::GroupNotMatching(
                own_group.as_str().into(),
                ciphertext_group.as_str().into(),
            ));
        }

        let payload = aead::decrypt(
            provider,
            &self.ciphersuite,
            self.value.as_slice(),
            &ciphertext.ciphertext_with_tag(),
            &ciphertext.nonce(),
            self.group_id.as_slice(),
        )?;
        Ok(payload)
    }
}
#[cfg(test)]
mod tests {
    use openmls_traits::random::OpenMlsRand;
    use openmls_traits::OpenMlsCryptoProvider;

    use crate::hash::Hash;
    use crate::secret_group::lts::{
        LongTermSecret, LongTermSecretCiphersuite, LongTermSecretEpoch, LongTermSecretError,
    };
    use crate::secret_group::MlsProvider;

    /// The group instance id given to `new` must come back unchanged from the accessor.
    #[test]
    fn group_id_hash_encoding() {
        let expected_id = Hash::new_from_bytes(vec![1, 2, 3]).unwrap();

        let secret = LongTermSecret::new(
            expected_id.clone(),
            LongTermSecretCiphersuite::PANDA10_AES256GCM,
            LongTermSecretEpoch(0),
            vec![1, 2, 3].into(),
        );

        let recovered_id = secret.group_instance_id().unwrap();
        assert_eq!(expected_id.as_str(), recovered_id.as_str());
    }

    /// Decryption succeeds with the matching secret and is rejected when the epoch or the
    /// group of the secret differs from the one recorded in the ciphertext.
    #[test]
    fn invalid_ciphertext() {
        let provider = MlsProvider::new();

        for ciphersuite in LongTermSecretCiphersuite::supported_ciphersuites() {
            let key = provider
                .rand()
                .random_vec(ciphersuite.aead_key_length())
                .unwrap();

            let home_group = Hash::new_from_bytes(vec![1, 2, 3]).unwrap();
            let other_group = Hash::new_from_bytes(vec![4, 5, 6]).unwrap();

            // All three secrets share the key and ciphersuite; only group / epoch vary.
            let build = |group: Hash, epoch: LongTermSecretEpoch| {
                LongTermSecret::new(group, ciphersuite, epoch, key.clone().into())
            };
            let secret = build(home_group.clone(), LongTermSecretEpoch(0));
            let wrong_group_secret = build(other_group, LongTermSecretEpoch(0));
            let wrong_epoch_secret = build(home_group, LongTermSecretEpoch(2));

            let nonce = provider
                .rand()
                .random_vec(ciphersuite.aead_nonce_length())
                .unwrap();
            let sealed = secret
                .encrypt(&provider, &nonce, b"Secret Message")
                .unwrap();

            let round_trip = secret.decrypt(&provider, &sealed);
            assert!(round_trip.is_ok());

            let epoch_outcome = wrong_epoch_secret.decrypt(&provider, &sealed);
            assert!(matches!(
                epoch_outcome,
                Err(LongTermSecretError::EpochNotMatching(_, _))
            ));

            let group_outcome = wrong_group_secret.decrypt(&provider, &sealed);
            assert!(matches!(
                group_outcome,
                Err(LongTermSecretError::GroupNotMatching(_, _))
            ));
        }
    }
}