use crate::{
ciphersuite::{signable::*, *},
credentials::*,
extensions::Extensions,
treesync::node::leaf_node::{LeafNodeIn, VerifiableLeafNode},
versions::ProtocolVersion,
};
use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
use serde::{Deserialize, Serialize};
use tls_codec::{
Serialize as TlsSerializeTrait, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
};
use super::{
errors::KeyPackageVerifyError, InitKey, KeyPackage, KeyPackageTbs, SIGNATURE_KEY_PACKAGE_LABEL,
};
#[cfg(any(feature = "test-utils", test))]
use super::KeyPackageBundle;
/// A key package whose outer signature has not been checked yet.
///
/// It is built from an already-verified leaf node plus the remaining payload
/// fields, and only becomes a [`KeyPackage`] through its [`Verifiable`] impl.
struct VerifiableKeyPackage {
/// The to-be-signed content whose serialization the signature must cover.
payload: KeyPackageTbs,
/// The signature as received; not yet verified.
signature: Signature,
}
impl VerifiableKeyPackage {
fn new(payload: KeyPackageTbs, signature: Signature) -> Self {
Self { payload, signature }
}
}
impl Verifiable for VerifiableKeyPackage {
type VerifiedStruct = KeyPackage;
fn unsigned_payload(&self) -> Result<Vec<u8>, tls_codec::Error> {
self.payload.tls_serialize_detached()
}
fn signature(&self) -> &Signature {
&self.signature
}
fn label(&self) -> &str {
SIGNATURE_KEY_PACKAGE_LABEL
}
fn verify(
self,
crypto: &impl OpenMlsCrypto,
pk: &OpenMlsSignaturePublicKey,
) -> Result<Self::VerifiedStruct, SignatureError> {
self.verify_no_out(crypto, pk)?;
Ok(KeyPackage {
payload: self.payload,
signature: self.signature,
})
}
}
// `KeyPackage` is the `VerifiedStruct` produced by `VerifiableKeyPackage`'s
// `Verifiable` impl in this file; the trait needs no items here (empty impl).
impl VerifiedStruct for KeyPackage {}
/// Incoming (wire-format) counterpart of [`KeyPackageTbs`]: the to-be-signed
/// part of a key package as deserialized, before any verification.
///
/// The TLS codec derives encode fields in declaration order, so the field
/// order below is part of the wire format and must not be changed.
#[derive(
Debug,
Clone,
PartialEq,
TlsSize,
TlsSerialize,
TlsDeserialize,
TlsDeserializeBytes,
Serialize,
Deserialize,
)]
struct KeyPackageTbsIn {
/// Protocol version claimed by the sender; compared against the expected
/// version in `KeyPackageIn::validate`.
protocol_version: ProtocolVersion,
/// Ciphersuite of the key package; its signature algorithm is used to
/// interpret the leaf node's signature key during validation.
ciphersuite: Ciphersuite,
/// Init key; validation rejects it if it equals the leaf node's encryption key.
init_key: InitKey,
/// Leaf node as received; its signature has not been verified yet.
leaf_node: LeafNodeIn,
/// Key package extensions; validation requires the leaf node to support
/// each extension type.
extensions: Extensions,
}
/// A key package as deserialized from the wire, before validation.
///
/// Nothing in here has been verified. Call [`KeyPackageIn::validate`] to check
/// it and obtain a [`KeyPackage`]. The derived TLS encoding follows the field
/// declaration order, so the field order must not be changed.
#[derive(
Debug,
PartialEq,
Clone,
Serialize,
Deserialize,
TlsSerialize,
TlsDeserialize,
TlsDeserializeBytes,
TlsSize,
)]
pub struct KeyPackageIn {
/// Unverified to-be-signed content.
payload: KeyPackageTbsIn,
/// Signature over the payload as received; checked by `validate`.
signature: Signature,
}
impl KeyPackageIn {
pub fn unverified_credential(&self) -> CredentialWithKey {
let credential = self.payload.leaf_node.credential().clone();
let signature_key = self.payload.leaf_node.signature_key().clone();
CredentialWithKey {
credential,
signature_key,
}
}
pub fn validate(
self,
crypto: &impl OpenMlsCrypto,
protocol_version: ProtocolVersion,
) -> Result<KeyPackage, KeyPackageVerifyError> {
let leaf_node = self.payload.leaf_node.clone().into_verifiable_leaf_node();
let signature_key = &OpenMlsSignaturePublicKey::from_signature_key(
self.payload.leaf_node.signature_key().clone(),
self.payload.ciphersuite.signature_algorithm(),
);
let leaf_node = match leaf_node {
VerifiableLeafNode::KeyPackage(leaf_node) => leaf_node
.verify(crypto, signature_key)
.map_err(|_| KeyPackageVerifyError::InvalidLeafNodeSignature)?,
_ => return Err(KeyPackageVerifyError::InvalidLeafNodeSourceType),
};
if !self.version_is_supported(protocol_version) {
return Err(KeyPackageVerifyError::InvalidProtocolVersion);
}
if leaf_node.encryption_key().key() == self.payload.init_key.key() {
return Err(KeyPackageVerifyError::InitKeyEqualsEncryptionKey);
}
let key_package_tbs = KeyPackageTbs {
protocol_version: self.payload.protocol_version,
ciphersuite: self.payload.ciphersuite,
init_key: self.payload.init_key,
leaf_node,
extensions: self.payload.extensions,
};
let key_package = VerifiableKeyPackage::new(key_package_tbs, self.signature)
.verify(crypto, signature_key)
.map_err(|_| KeyPackageVerifyError::InvalidSignature)?;
for extension in key_package.payload.extensions.iter() {
if !key_package
.payload
.leaf_node
.supports_extension(&extension.extension_type())
{
return Err(KeyPackageVerifyError::UnsupportedExtension);
}
}
if let Some(life_time) = key_package.payload.leaf_node.life_time() {
if !life_time.is_valid() {
return Err(KeyPackageVerifyError::InvalidLifetime);
}
} else {
return Err(KeyPackageVerifyError::MissingLifetime);
}
Ok(key_package)
}
pub(crate) fn version_is_supported(&self, protocol_version: ProtocolVersion) -> bool {
self.payload.protocol_version == protocol_version
}
}
/// Test-only conversion from the incoming payload to [`KeyPackageTbs`].
///
/// This impl performs no signature verification itself. The leaf node
/// conversion is delegated to the `LeafNodeIn` -> `LeafNode` conversion
/// (presumably unchecked as well; confirm before relying on it).
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageTbsIn> for KeyPackageTbs {
    fn from(value: KeyPackageTbsIn) -> Self {
        let KeyPackageTbsIn {
            protocol_version,
            ciphersuite,
            init_key,
            leaf_node,
            extensions,
        } = value;
        Self {
            protocol_version,
            ciphersuite,
            init_key,
            leaf_node: leaf_node.into(),
            extensions,
        }
    }
}
impl From<KeyPackageTbs> for KeyPackageTbsIn {
fn from(value: KeyPackageTbs) -> Self {
Self {
protocol_version: value.protocol_version,
ciphersuite: value.ciphersuite,
init_key: value.init_key,
leaf_node: value.leaf_node.into(),
extensions: value.extensions,
}
}
}
impl From<KeyPackage> for KeyPackageIn {
fn from(value: KeyPackage) -> Self {
Self {
payload: value.payload.into(),
signature: value.signature,
}
}
}
/// Test-only conversion: takes the bundle's key package and converts it into
/// its incoming (wire-format) form. Only the `key_package` field of the bundle is used.
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageBundle> for KeyPackageIn {
    fn from(value: KeyPackageBundle) -> Self {
        // Identical to converting the bundled `KeyPackage` directly.
        KeyPackageIn::from(value.key_package)
    }
}
/// Test-only conversion from an incoming key package to a [`KeyPackage`].
///
/// Unlike [`KeyPackageIn::validate`], this impl checks no signatures, versions,
/// extensions or lifetimes. It only restructures the data.
#[cfg(any(feature = "test-utils", test))]
impl From<KeyPackageIn> for KeyPackage {
    fn from(value: KeyPackageIn) -> Self {
        let KeyPackageIn { payload, signature } = value;
        KeyPackage {
            payload: KeyPackageTbs::from(payload),
            signature,
        }
    }
}