use crate::{
ciphersuite::{hash_ref::KeyPackageRef, signable::*, *},
credentials::*,
error::LibraryError,
extensions::{
errors::ExtensionError, CapabilitiesExtension, Extension, ExtensionType, LifetimeExtension,
ParentHashExtension, RequiredCapabilitiesExtension,
},
versions::ProtocolVersion,
};
use log::error;
use openmls_traits::{
crypto::OpenMlsCrypto,
types::{Ciphersuite, CryptoError, HpkeKeyPair, SignatureScheme},
OpenMlsCryptoProvider,
};
use serde::{Deserialize, Serialize};
use tls_codec::{
Deserialize as TlsDeserializeTrait, Serialize as TlsSerializeTrait, TlsSize, TlsVecU32,
};
mod codec;
use errors::*;
pub mod errors;
#[cfg(test)]
mod test_key_packages;
/// Unsigned contents of a [`KeyPackage`].
///
/// The signature of a key package is computed over the TLS encoding of this
/// struct (`Signable::unsigned_payload` returns `tls_serialize_detached()`).
#[derive(Debug, Clone, PartialEq, TlsSize, Serialize, Deserialize)]
struct KeyPackagePayload {
// Protocol version this key package was created for.
protocol_version: ProtocolVersion,
// Ciphersuite of the key package; also selects the HPKE config and signature scheme.
ciphersuite: Ciphersuite,
// Public HPKE key others use to encrypt to the owner of this key package.
hpke_init_key: HpkePublicKey,
// Credential whose public key verifies the key package signature.
credential: Credential,
// NOTE: the manual `tls_codec::Serialize` impl writes the fields in this
// declaration order; keep the two in sync when adding or moving fields.
extensions: TlsVecU32<Extension>,
}
impl tls_codec::Serialize for KeyPackagePayload {
    /// Writes the payload in wire order: protocol version, ciphersuite, HPKE
    /// init key, credential, and finally the extension list.
    ///
    /// Returns the total number of bytes written, or the first encoding error.
    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
        let version_len = self.protocol_version.tls_serialize(writer)?;
        let suite_len = self.ciphersuite.tls_serialize(writer)?;
        let init_key_len = self.hpke_init_key.tls_serialize(writer)?;
        let credential_len = self.credential.tls_serialize(writer)?;
        let extensions_len = self.extensions.tls_serialize(writer)?;
        Ok(version_len + suite_len + init_key_len + credential_len + extensions_len)
    }
}
impl Signable for KeyPackagePayload {
    type SignedOutput = KeyPackage;

    /// The bytes to sign are exactly the TLS encoding of the payload.
    fn unsigned_payload(&self) -> Result<Vec<u8>, tls_codec::Error> {
        let encoded = self.tls_serialize_detached()?;
        Ok(encoded)
    }
}
impl From<KeyPackage> for KeyPackagePayload {
fn from(kp: KeyPackage) -> Self {
kp.payload
}
}
impl KeyPackagePayload {
    /// Builds a payload that copies every field of `kp`'s payload except the
    /// HPKE init key, which is replaced by `hpke_init_key`.
    fn from_key_package(kp: &KeyPackage, hpke_init_key: HpkePublicKey) -> Self {
        let source = &kp.payload;
        Self {
            hpke_init_key,
            protocol_version: source.protocol_version,
            ciphersuite: source.ciphersuite,
            credential: source.credential.clone(),
            extensions: source.extensions.clone(),
        }
    }

    /// Drops every extension whose type equals `extension_type`.
    pub(crate) fn remove_extension(&mut self, extension_type: ExtensionType) {
        self.extensions
            .retain(|ext| ext.extension_type() != extension_type);
    }

    /// Test helper: replaces any existing extension of the same type with `extension`.
    #[cfg(any(feature = "test-utils", test))]
    fn add_extension(&mut self, extension: Extension) {
        let extension_type = extension.extension_type();
        self.remove_extension(extension_type);
        self.extensions.push(extension);
    }

    /// Test helper: overwrites the credential.
    #[cfg(any(feature = "test-utils", test))]
    pub fn set_credential(&mut self, credential: Credential) {
        self.credential = credential;
    }

    /// Test helper: overwrites the HPKE init key.
    #[cfg(any(feature = "test-utils", test))]
    pub fn set_public_key(&mut self, public_key: HpkePublicKey) {
        self.hpke_init_key = public_key;
    }

    /// Test helper: overwrites the protocol version.
    #[cfg(any(feature = "test-utils", test))]
    pub fn set_version(&mut self, version: ProtocolVersion) {
        self.protocol_version = version;
    }

    /// Test helper: overwrites the ciphersuite.
    #[cfg(any(feature = "test-utils", test))]
    pub fn set_ciphersuite(&mut self, ciphersuite: Ciphersuite) {
        self.ciphersuite = ciphersuite;
    }
}
/// A signed key package: a [`KeyPackagePayload`] plus the signature over the
/// payload's TLS encoding (see the `Verifiable` impl).
///
/// Note that `PartialEq` for this type compares only the payload, not the signature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyPackage {
// The signed content.
payload: KeyPackagePayload,
// Signature over `payload.tls_serialize_detached()`.
signature: Signature,
}
impl TryFrom<&[u8]> for KeyPackage {
    type Error = tls_codec::Error;

    /// Decodes a key package from its TLS wire encoding.
    ///
    /// This function does not itself check whether bytes remain after the
    /// decoded key package.
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        let mut reader: &[u8] = bytes;
        Self::tls_deserialize(&mut reader)
    }
}
impl PartialEq for KeyPackage {
    /// Two key packages are equal when their payloads are equal; the
    /// signatures are intentionally not compared.
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.payload, &other.payload)
    }
}
impl SignedStruct<KeyPackagePayload> for KeyPackage {
fn from_payload(payload: KeyPackagePayload, signature: Signature) -> Self {
Self { payload, signature }
}
}
impl Verifiable for KeyPackage {
    /// The signature stored alongside the payload.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The bytes the signature must cover: the payload's TLS encoding.
    fn unsigned_payload(&self) -> Result<Vec<u8>, tls_codec::Error> {
        let payload = &self.payload;
        payload.tls_serialize_detached()
    }
}
/// Extension types that [`KeyPackage::verify`] requires every key package to carry.
const MANDATORY_EXTENSIONS: [ExtensionType; 2] = [
    ExtensionType::Capabilities,
    ExtensionType::Lifetime,
];
impl KeyPackage {
pub fn verify(
&self,
backend: &impl OpenMlsCryptoProvider,
) -> Result<(), KeyPackageVerifyError> {
let mut mandatory_extensions_found = MANDATORY_EXTENSIONS.to_vec();
for extension in self.payload.extensions.iter() {
if let Some(p) = mandatory_extensions_found
.iter()
.position(|&e| e == extension.extension_type())
{
let _ = mandatory_extensions_found.remove(p);
}
if extension.extension_type() == ExtensionType::Lifetime {
match extension.as_lifetime_extension() {
Ok(e) => {
if !e.is_valid() {
log::error!("Invalid lifetime extension in key package.");
return Err(KeyPackageVerifyError::InvalidLifetimeExtension);
}
}
Err(_) => {
log::error!("as_lifetime_extension failed while verifying a key package.");
return Err(LibraryError::custom("Expected a lifetime extension").into());
}
}
}
}
if !mandatory_extensions_found.is_empty() {
log::error!("This key package is missing mandatory extensions.");
return Err(KeyPackageVerifyError::MandatoryExtensionsMissing);
}
<Self as Verifiable>::verify_no_out(self, backend, &self.payload.credential).map_err(|_| {
log::error!("Key package signature is invalid.");
KeyPackageVerifyError::InvalidSignature
})
}
pub fn external_key_id(&self) -> Result<&[u8], ExtensionError> {
if let Some(key_id_ext) = self.extension_with_type(ExtensionType::ExternalKeyId) {
return Ok(key_id_ext.as_external_key_id_extension()?.as_slice());
} else {
Err(ExtensionError::InvalidExtensionType(
"Tried to get a key ID extension".into(),
))
}
}
pub fn extensions(&self) -> &[Extension] {
self.payload.extensions.as_slice()
}
pub fn check_extension_support(
&self,
required_extensions: &[ExtensionType],
) -> Result<(), KeyPackageExtensionSupportError> {
let my_extension_types = self.extensions().iter().map(|ext| ext.extension_type());
for required in required_extensions.iter() {
if !my_extension_types.clone().any(|e| &e == required) {
return Err(KeyPackageExtensionSupportError::UnsupportedExtension);
}
}
Ok(())
}
pub fn credential(&self) -> &Credential {
&self.payload.credential
}
pub(crate) fn validate_required_capabilities<'a>(
&self,
required_capabilities: impl Into<Option<&'a RequiredCapabilitiesExtension>>,
) -> Result<(), KeyPackageExtensionSupportError> {
if let Some(required_capabilities) = required_capabilities.into() {
let my_extension_types = self.extensions().iter().map(|e| e.extension_type());
for required_extension in required_capabilities.extensions() {
if !my_extension_types.clone().any(|e| &e == required_extension) {
return Err(KeyPackageExtensionSupportError::UnsupportedExtension);
}
}
}
Ok(())
}
pub fn hash_ref(&self, backend: &impl OpenMlsCrypto) -> Result<KeyPackageRef, LibraryError> {
KeyPackageRef::new(
&self
.tls_serialize_detached()
.map_err(LibraryError::missing_bound_check)?,
self.payload.ciphersuite,
backend,
)
.map_err(LibraryError::unexpected_crypto_error)
}
pub fn ciphersuite(&self) -> Ciphersuite {
self.payload.ciphersuite
}
}
impl KeyPackage {
    /// Builds a [`KeyPackagePayload`] for `ciphersuite` at the default protocol
    /// version and signs it with `credential_bundle`.
    ///
    /// # Errors
    /// - `CiphersuiteSignatureSchemeMismatch` if the ciphersuite's signature
    ///   scheme differs from the credential's.
    /// - Signing errors, converted into [`KeyPackageNewError`] via `From`.
    fn new(
        ciphersuite: Ciphersuite,
        backend: &impl OpenMlsCryptoProvider,
        hpke_init_key: HpkePublicKey,
        credential_bundle: &CredentialBundle,
        extensions: Vec<Extension>,
    ) -> Result<Self, KeyPackageNewError> {
        let credential = credential_bundle.credential();
        let scheme_matches = SignatureScheme::from(ciphersuite) == credential.signature_scheme();
        if !scheme_matches {
            return Err(KeyPackageNewError::CiphersuiteSignatureSchemeMismatch);
        }
        let payload = KeyPackagePayload {
            protocol_version: ProtocolVersion::default(),
            ciphersuite,
            hpke_init_key,
            credential: credential.clone(),
            extensions: extensions.into(),
        };
        let signed = payload.sign(backend, credential_bundle)?;
        Ok(signed)
    }
}
impl KeyPackage {
    /// Returns the first extension of type `extension_type`, if present.
    pub(crate) fn extension_with_type(&self, extension_type: ExtensionType) -> Option<&Extension> {
        self.payload
            .extensions
            .as_slice()
            .iter()
            .find(|ext| ext.extension_type() == extension_type)
    }

    /// Returns the public HPKE init key.
    pub(crate) fn hpke_init_key(&self) -> &HpkePublicKey {
        let payload = &self.payload;
        &payload.hpke_init_key
    }

    /// Returns the protocol version this key package was created for.
    pub(crate) fn protocol_version(&self) -> ProtocolVersion {
        let payload = &self.payload;
        payload.protocol_version
    }
}
/// Unsigned form of a [`KeyPackageBundle`]: the key package payload together
/// with the HPKE private key and the leaf secret.
///
/// This variant is `pub` and is compiled under the `test-utils` feature or in tests.
#[cfg(any(feature = "test-utils", test))]
pub struct KeyPackageBundlePayload {
// Payload that is signed to produce the bundle's key package.
key_package_payload: KeyPackagePayload,
// HPKE private key matching `key_package_payload.hpke_init_key`.
private_key: HpkePrivateKey,
// Leaf secret carried into the resulting bundle.
leaf_secret: Secret,
}
/// Unsigned form of a [`KeyPackageBundle`]: the key package payload together
/// with the HPKE private key and the leaf secret.
///
/// Crate-private variant used when neither `test-utils` nor `test` is enabled;
/// its fields must stay identical to the test-visible variant.
#[cfg(not(any(feature = "test-utils", test)))]
pub(crate) struct KeyPackageBundlePayload {
// Payload that is signed to produce the bundle's key package.
key_package_payload: KeyPackagePayload,
// HPKE private key matching `key_package_payload.hpke_init_key`.
private_key: HpkePrivateKey,
// Leaf secret carried into the resulting bundle.
leaf_secret: Secret,
}
impl KeyPackageBundlePayload {
    /// Re-keys `key_package`: samples a fresh leaf secret for its ciphersuite
    /// and protocol version, then derives a new HPKE key pair from it.
    ///
    /// # Errors
    /// Any [`CryptoError`] from sampling or derivation.
    pub(crate) fn from_rekeyed_key_package(
        key_package: &KeyPackage,
        backend: &impl OpenMlsCryptoProvider,
    ) -> Result<Self, CryptoError> {
        let fresh_leaf_secret = Secret::random(
            key_package.ciphersuite(),
            backend,
            key_package.protocol_version(),
        )?;
        Self::from_key_package_and_leaf_secret(fresh_leaf_secret, key_package, backend)
    }

    /// Builds an unsigned bundle from `key_package`.
    ///
    /// The HPKE key pair is derived from `leaf_secret` (via
    /// [`derive_leaf_node_secret`]), and its public half replaces the key
    /// package's init key.
    ///
    /// # Errors
    /// A [`CryptoError`] if deriving the node secret fails.
    pub(crate) fn from_key_package_and_leaf_secret(
        leaf_secret: Secret,
        key_package: &KeyPackage,
        backend: &impl OpenMlsCryptoProvider,
    ) -> Result<Self, CryptoError> {
        let node_secret = derive_leaf_node_secret(&leaf_secret, backend)?;
        let hpke_config = key_package.ciphersuite().hpke_config();
        let key_pair = backend
            .crypto()
            .derive_hpke_keypair(hpke_config, node_secret.as_slice());
        Ok(Self {
            key_package_payload: KeyPackagePayload::from_key_package(
                key_package,
                key_pair.public.into(),
            ),
            private_key: key_pair.private.into(),
            leaf_secret,
        })
    }

    /// Replaces any parent hash extension with one holding `parent_hash`.
    pub(crate) fn update_parent_hash(&mut self, parent_hash: &[u8]) {
        let payload = &mut self.key_package_payload;
        payload.remove_extension(ExtensionType::ParentHash);
        payload
            .extensions
            .push(Extension::ParentHash(ParentHashExtension::new(parent_hash)));
    }

    /// Test helper: replaces any extension of the same type with `extension`.
    #[cfg(any(feature = "test-utils", test))]
    pub fn add_extension(&mut self, extension: Extension) {
        let payload = &mut self.key_package_payload;
        payload.add_extension(extension)
    }

    /// Returns the leaf secret.
    pub(crate) fn leaf_secret(&self) -> &Secret {
        let secret = &self.leaf_secret;
        secret
    }

    /// Test helper: overwrites the credential in the payload.
    #[cfg(any(feature = "test-utils", test))]
    pub fn set_credential(&mut self, credential: Credential) {
        let payload = &mut self.key_package_payload;
        payload.set_credential(credential)
    }

    /// Test helper: overwrites the HPKE init key in the payload.
    #[cfg(any(feature = "test-utils", test))]
    pub fn set_public_key(&mut self, public_key: HpkePublicKey) {
        let payload = &mut self.key_package_payload;
        payload.set_public_key(public_key)
    }

    /// Test helper: overwrites the protocol version in the payload.
    #[cfg(any(feature = "test-utils", test))]
    pub fn set_version(&mut self, version: ProtocolVersion) {
        let payload = &mut self.key_package_payload;
        payload.set_version(version)
    }

    /// Test helper: overwrites the ciphersuite in the payload.
    #[cfg(any(feature = "test-utils", test))]
    pub fn set_ciphersuite(&mut self, ciphersuite: Ciphersuite) {
        let payload = &mut self.key_package_payload;
        payload.set_ciphersuite(ciphersuite)
    }
}
impl Signable for KeyPackageBundlePayload {
    type SignedOutput = KeyPackageBundle;

    /// Delegates to the inner key package payload. The private key and leaf
    /// secret are never part of the signed bytes.
    fn unsigned_payload(&self) -> Result<Vec<u8>, tls_codec::Error> {
        let inner = &self.key_package_payload;
        inner.unsigned_payload()
    }
}
impl SignedStruct<KeyPackageBundlePayload> for KeyPackageBundle {
fn from_payload(payload: KeyPackageBundlePayload, signature: Signature) -> Self {
let key_package = KeyPackage::from_payload(payload.key_package_payload, signature);
Self {
key_package,
private_key: payload.private_key,
leaf_secret: payload.leaf_secret,
}
}
}
/// A signed [`KeyPackage`] together with its HPKE private key and leaf secret.
///
/// When built via `new_from_leaf_secret`, the HPKE key pair is derived from
/// `leaf_secret`; `new_with_keypair` accepts the key pair and secret separately.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(test, derive(PartialEq))]
pub struct KeyPackageBundle {
// The public, signed key package.
pub(crate) key_package: KeyPackage,
// HPKE private key for `key_package`'s init key.
pub(crate) private_key: HpkePrivateKey,
// Leaf secret associated with this bundle.
pub(crate) leaf_secret: Secret,
}
impl From<KeyPackageBundle> for KeyPackageBundlePayload {
fn from(kpb: KeyPackageBundle) -> Self {
Self {
key_package_payload: kpb.key_package.into(),
private_key: kpb.private_key,
leaf_secret: kpb.leaf_secret,
}
}
}
impl KeyPackageBundle {
pub fn new(
ciphersuites: &[Ciphersuite],
credential_bundle: &CredentialBundle,
backend: &impl OpenMlsCryptoProvider,
extensions: Vec<Extension>,
) -> Result<Self, KeyPackageBundleNewError> {
Self::new_with_version(
ProtocolVersion::default(),
ciphersuites,
backend,
credential_bundle,
extensions,
)
}
pub fn new_with_version(
version: ProtocolVersion,
ciphersuites: &[Ciphersuite],
backend: &impl OpenMlsCryptoProvider,
credential_bundle: &CredentialBundle,
extensions: Vec<Extension>,
) -> Result<Self, KeyPackageBundleNewError> {
if ciphersuites.is_empty() {
let error = KeyPackageBundleNewError::NoCiphersuitesSupplied;
error!(
"Error creating new KeyPackageBundle: No Ciphersuites specified {:?}",
error
);
return Err(error);
}
if SignatureScheme::from(ciphersuites[0])
!= credential_bundle.credential().signature_scheme()
{
return Err(KeyPackageBundleNewError::CiphersuiteSignatureSchemeMismatch);
}
let ciphersuite = ciphersuites[0];
let leaf_secret = Secret::random(ciphersuite, backend, version)
.map_err(LibraryError::unexpected_crypto_error)?;
Self::new_from_leaf_secret(
ciphersuites,
backend,
credential_bundle,
extensions,
leaf_secret,
)
}
pub(crate) fn new_with_keypair(
ciphersuites: &[Ciphersuite],
backend: &impl OpenMlsCryptoProvider,
credential_bundle: &CredentialBundle,
mut extensions: Vec<Extension>,
key_pair: HpkeKeyPair,
leaf_secret: Secret,
) -> Result<Self, KeyPackageBundleNewError> {
if ciphersuites.is_empty() {
let error = KeyPackageBundleNewError::NoCiphersuitesSupplied;
error!(
"Error creating new KeyPackageBundle: No Ciphersuites specified {:?}",
error
);
return Err(error);
}
let extensions_length = extensions.len();
extensions.sort();
extensions.dedup();
if extensions_length != extensions.len() {
let error = KeyPackageBundleNewError::DuplicateExtension;
error!(
"Error creating new KeyPackageBundle: Duplicate Extension {:?}",
error
);
return Err(error);
}
match extensions
.iter()
.find(|e| e.extension_type() == ExtensionType::Capabilities)
{
Some(extension) => {
let capabilities_extension = extension.as_capabilities_extension()?;
if capabilities_extension.ciphersuites() != ciphersuites {
let error = KeyPackageBundleNewError::CiphersuiteMismatch;
error!(
"Error creating new KeyPackageBundle: Invalid Capabilities Extensions {:?}",
error
);
return Err(error);
}
}
None => extensions.push(Extension::Capabilities(CapabilitiesExtension::new(
None,
Some(ciphersuites),
None,
None,
))),
};
if !extensions
.iter()
.any(|e| e.extension_type() == ExtensionType::Lifetime)
{
extensions.push(Extension::LifeTime(LifetimeExtension::default()));
}
let key_package = KeyPackage::new(
ciphersuites[0],
backend,
key_pair.public.into(),
credential_bundle,
extensions,
)
.map_err(|e| match e {
KeyPackageNewError::LibraryError(e) => e.into(),
KeyPackageNewError::CiphersuiteSignatureSchemeMismatch => {
KeyPackageBundleNewError::CiphersuiteSignatureSchemeMismatch
}
})?;
Ok(KeyPackageBundle {
key_package,
private_key: key_pair.private.into(),
leaf_secret,
})
}
pub fn key_package(&self) -> &KeyPackage {
&self.key_package
}
pub fn into_parts(self) -> (KeyPackage, (Vec<u8>, Vec<u8>)) {
(
self.key_package,
(
self.private_key.as_slice().to_vec(),
self.leaf_secret.as_slice().to_vec(),
),
)
}
#[cfg(feature = "test-utils")]
pub fn unsigned(self) -> KeyPackageBundlePayload {
self.into()
}
}
impl KeyPackageBundle {
    /// Creates a [`KeyPackageBundle`] whose HPKE key pair is derived from
    /// `leaf_secret` (via [`derive_leaf_node_secret`]) for `ciphersuites[0]`.
    ///
    /// # Errors
    /// - `NoCiphersuitesSupplied` if `ciphersuites` is empty.
    /// - A library error if deriving the node secret fails.
    /// - Otherwise, whatever `new_with_keypair` returns.
    pub(crate) fn new_from_leaf_secret(
        ciphersuites: &[Ciphersuite],
        backend: &impl OpenMlsCryptoProvider,
        credential_bundle: &CredentialBundle,
        extensions: Vec<Extension>,
        leaf_secret: Secret,
    ) -> Result<Self, KeyPackageBundleNewError> {
        let ciphersuite = match ciphersuites.first() {
            Some(&first) => first,
            None => {
                let error = KeyPackageBundleNewError::NoCiphersuitesSupplied;
                error!(
                    "Error creating new KeyPackageBundle: No Ciphersuites specified {:?}",
                    error
                );
                return Err(error);
            }
        };
        let node_secret = derive_leaf_node_secret(&leaf_secret, backend)
            .map_err(LibraryError::unexpected_crypto_error)?;
        let hpke_key_pair = backend
            .crypto()
            .derive_hpke_keypair(ciphersuite.hpke_config(), node_secret.as_slice());
        Self::new_with_keypair(
            ciphersuites,
            backend,
            credential_bundle,
            extensions,
            hpke_key_pair,
            leaf_secret,
        )
    }

    /// Overwrites the stored HPKE private key.
    pub(crate) fn _set_private_key(&mut self, private_key: HpkePrivateKey) {
        self.private_key = private_key;
    }

    /// Returns the HPKE private key.
    pub(crate) fn private_key(&self) -> &HpkePrivateKey {
        let key = &self.private_key;
        key
    }

    /// Returns the leaf secret.
    pub(crate) fn leaf_secret(&self) -> &Secret {
        let secret = &self.leaf_secret;
        secret
    }
}
/// Derives the leaf node secret from a leaf secret by calling `derive_secret`
/// with the label `"node"`.
///
/// # Errors
/// Propagates the [`CryptoError`] returned by `derive_secret`.
pub(crate) fn derive_leaf_node_secret(
    leaf_secret: &Secret,
    backend: &impl OpenMlsCryptoProvider,
) -> Result<Secret, CryptoError> {
    const NODE_LABEL: &str = "node";
    leaf_secret.derive_secret(backend, NODE_LABEL)
}