use openmls_traits::crypto::OpenMlsCrypto;
use serde::{Deserialize, Serialize};
use std::{
convert::TryFrom,
fmt::Debug,
io::{Read, Write},
};
use tls_codec::*;
mod capabilities_extension;
mod external_key_id_extension;
mod life_time_extension;
mod parent_hash_extension;
mod ratchet_tree_extension;
mod required_capabilities;
use errors::*;
pub mod errors;
pub use capabilities_extension::CapabilitiesExtension;
pub use external_key_id_extension::ExternalKeyIdExtension;
pub use life_time_extension::LifetimeExtension;
pub use parent_hash_extension::ParentHashExtension;
pub use ratchet_tree_extension::RatchetTreeExtension;
pub use required_capabilities::RequiredCapabilitiesExtension;
use crate::treesync::node::Node;
#[cfg(test)]
mod test_extensions;
/// Numeric identifier for each kind of extension handled in this module.
///
/// The enum is `#[repr(u16)]` and every variant carries an explicit
/// discriminant. Apart from `Reserved`, each variant corresponds to one
/// variant of the `Extension` enum (see `Extension::extension_type`).
#[derive(
Debug,
Copy,
Clone,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
Ord,
PartialOrd,
TlsSerialize,
TlsDeserialize,
TlsSize,
)]
#[repr(u16)]
#[allow(missing_docs)]
pub enum ExtensionType {
// Has no matching `Extension` variant; `Extension::tls_deserialize`
// rejects it with a decoding error.
Reserved = 0,
// Identifies `Extension::Capabilities`.
Capabilities = 1,
// Identifies `Extension::LifeTime`.
Lifetime = 2,
// Identifies `Extension::ExternalKeyId`.
ExternalKeyId = 3,
// Identifies `Extension::ParentHash`.
ParentHash = 4,
// Identifies `Extension::RatchetTree`.
RatchetTree = 5,
// Identifies `Extension::RequiredCapabilities`.
RequiredCapabilities = 6,
}
impl TryFrom<u16> for ExtensionType {
    type Error = tls_codec::Error;

    /// Converts a raw `u16` into the [`ExtensionType`] that has that
    /// discriminant.
    ///
    /// Every discriminant declared on the enum (`0` through `6`) is accepted,
    /// including `6` for [`ExtensionType::RequiredCapabilities`].
    ///
    /// # Errors
    ///
    /// Returns [`tls_codec::Error::DecodingError`] when `a` does not match any
    /// declared discriminant.
    fn try_from(a: u16) -> Result<Self, Self::Error> {
        match a {
            0 => Ok(ExtensionType::Reserved),
            1 => Ok(ExtensionType::Capabilities),
            2 => Ok(ExtensionType::Lifetime),
            3 => Ok(ExtensionType::ExternalKeyId),
            4 => Ok(ExtensionType::ParentHash),
            5 => Ok(ExtensionType::RatchetTree),
            // Keep this list in sync with the discriminants on the enum:
            // a missing arm makes a valid type look unknown.
            6 => Ok(ExtensionType::RequiredCapabilities),
            _ => Err(tls_codec::Error::DecodingError(format!(
                "{} is an unkown extension type",
                a
            ))),
        }
    }
}
impl ExtensionType {
pub fn is_supported(&self) -> bool {
match self {
ExtensionType::Reserved
| ExtensionType::Capabilities
| ExtensionType::Lifetime
| ExtensionType::ExternalKeyId
| ExtensionType::ParentHash
| ExtensionType::RatchetTree
| ExtensionType::RequiredCapabilities => true,
}
}
}
/// An extension together with its payload.
///
/// Each variant wraps the payload type re-exported from the matching
/// submodule. `Extension::extension_type` maps a variant to its
/// `ExtensionType`. Equality is derived and compares payloads, whereas the
/// `Ord`/`PartialOrd` impls in this file compare only the extension type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Extension {
/// Payload of type `CapabilitiesExtension`.
Capabilities(CapabilitiesExtension),
/// Payload of type `ExternalKeyIdExtension`.
ExternalKeyId(ExternalKeyIdExtension),
/// Payload of type `LifetimeExtension`.
LifeTime(LifetimeExtension),
/// Payload of type `ParentHashExtension`.
ParentHash(ParentHashExtension),
/// Payload of type `RatchetTreeExtension`.
RatchetTree(RatchetTreeExtension),
/// Payload of type `RequiredCapabilitiesExtension`.
RequiredCapabilities(RequiredCapabilitiesExtension),
}
impl tls_codec::Size for Extension {
    /// Returns the number of bytes `tls_serialize` produces for this
    /// extension: a fixed 6-byte header plus the serialized payload.
    ///
    /// The header is counted as 2 + 4 bytes, matching the extension type and
    /// the `TlsSliceU32` length prefix written by the `Serialize` impl.
    #[inline]
    fn tls_serialized_len(&self) -> usize {
        let payload_len = match self {
            Self::Capabilities(ext) => ext.tls_serialized_len(),
            Self::ExternalKeyId(ext) => ext.tls_serialized_len(),
            Self::LifeTime(ext) => ext.tls_serialized_len(),
            Self::ParentHash(ext) => ext.tls_serialized_len(),
            Self::RatchetTree(ext) => ext.tls_serialized_len(),
            Self::RequiredCapabilities(ext) => ext.tls_serialized_len(),
        };
        2 + 4 + payload_len
    }
}
impl tls_codec::Serialize for Extension {
    /// Writes the extension type, then the payload wrapped in a
    /// `TlsSliceU32`, and returns the total number of bytes written.
    ///
    /// The payload is first serialized into a temporary buffer so that it can
    /// be emitted as a single length-prefixed slice.
    ///
    /// # Errors
    ///
    /// Propagates any `tls_codec::Error` raised while writing the type, the
    /// payload, or the framed slice.
    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
        // Header part one: the extension type.
        let type_len = self.extension_type().tls_serialize(writer)?;

        // Total length minus the 6 header bytes is the expected payload size.
        let expected_payload_len = self.tls_serialized_len() - 6;
        let mut payload = Vec::with_capacity(expected_payload_len);
        let payload_written = match self {
            Self::Capabilities(ext) => ext.tls_serialize(&mut payload),
            Self::ExternalKeyId(ext) => ext.tls_serialize(&mut payload),
            Self::LifeTime(ext) => ext.tls_serialize(&mut payload),
            Self::ParentHash(ext) => ext.tls_serialize(&mut payload),
            Self::RatchetTree(ext) => ext.tls_serialize(&mut payload),
            Self::RequiredCapabilities(ext) => ext.tls_serialize(&mut payload),
        }?;

        // Sanity checks, active in debug builds only.
        debug_assert_eq!(payload_written, expected_payload_len);
        debug_assert_eq!(payload_written, payload.len());

        // Header part two (length prefix) followed by the payload itself.
        let framed_len = TlsSliceU32(&payload).tls_serialize(writer)?;
        Ok(framed_len + type_len)
    }
}
impl tls_codec::Deserialize for Extension {
fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
let extension_type = ExtensionType::tls_deserialize(bytes)?;
let extension_data = TlsByteVecU32::tls_deserialize(bytes)?;
let mut extension_data = extension_data.as_slice();
Ok(match extension_type {
ExtensionType::Capabilities => Extension::Capabilities(
CapabilitiesExtension::tls_deserialize(&mut extension_data)?,
),
ExtensionType::ExternalKeyId => Extension::ExternalKeyId(
ExternalKeyIdExtension::tls_deserialize(&mut extension_data)?,
),
ExtensionType::Lifetime => {
Extension::LifeTime(LifetimeExtension::tls_deserialize(&mut extension_data)?)
}
ExtensionType::ParentHash => {
Extension::ParentHash(ParentHashExtension::tls_deserialize(&mut extension_data)?)
}
ExtensionType::RatchetTree => {
Extension::RatchetTree(RatchetTreeExtension::tls_deserialize(&mut extension_data)?)
}
ExtensionType::RequiredCapabilities => Extension::RequiredCapabilities(
RequiredCapabilitiesExtension::tls_deserialize(&mut extension_data)?,
),
ExtensionType::Reserved => {
return Err(tls_codec::Error::DecodingError(format!(
"{:?} is not a valid extension type",
extension_type
)))
}
})
}
}
impl Extension {
    /// Borrows the payload as a [`RatchetTreeExtension`].
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionError::InvalidExtensionType`] for any other variant.
    pub fn as_ratchet_tree_extension(&self) -> Result<&RatchetTreeExtension, ExtensionError> {
        if let Self::RatchetTree(inner) = self {
            Ok(inner)
        } else {
            Err(ExtensionError::InvalidExtensionType(
                "This is not a RatchetTreeExtension".into(),
            ))
        }
    }

    /// Borrows the payload as a [`LifetimeExtension`].
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionError::InvalidExtensionType`] for any other variant.
    pub fn as_lifetime_extension(&self) -> Result<&LifetimeExtension, ExtensionError> {
        if let Self::LifeTime(inner) = self {
            Ok(inner)
        } else {
            Err(ExtensionError::InvalidExtensionType(
                "This is not a LifetimeExtension".into(),
            ))
        }
    }

    /// Borrows the payload as an [`ExternalKeyIdExtension`].
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionError::InvalidExtensionType`] for any other variant.
    pub fn as_external_key_id_extension(&self) -> Result<&ExternalKeyIdExtension, ExtensionError> {
        if let Self::ExternalKeyId(inner) = self {
            Ok(inner)
        } else {
            Err(ExtensionError::InvalidExtensionType(
                "This is not an ExternalKeyIdExtension".into(),
            ))
        }
    }

    /// Borrows the payload as a [`CapabilitiesExtension`].
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionError::InvalidExtensionType`] for any other variant.
    pub fn as_capabilities_extension(&self) -> Result<&CapabilitiesExtension, ExtensionError> {
        if let Self::Capabilities(inner) = self {
            Ok(inner)
        } else {
            Err(ExtensionError::InvalidExtensionType(
                "This is not a CapabilitiesExtension".into(),
            ))
        }
    }

    /// Borrows the payload as a [`ParentHashExtension`].
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionError::InvalidExtensionType`] for any other variant.
    pub fn as_parent_hash_extension(&self) -> Result<&ParentHashExtension, ExtensionError> {
        if let Self::ParentHash(inner) = self {
            Ok(inner)
        } else {
            Err(ExtensionError::InvalidExtensionType(
                "This is not a ParentHashExtension".into(),
            ))
        }
    }

    /// Borrows the payload as a [`RequiredCapabilitiesExtension`].
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionError::InvalidExtensionType`] for any other variant.
    pub fn as_required_capabilities_extension(
        &self,
    ) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
        if let Self::RequiredCapabilities(inner) = self {
            Ok(inner)
        } else {
            Err(ExtensionError::InvalidExtensionType(
                "This is not a RequiredCapabilitiesExtension".into(),
            ))
        }
    }

    /// Returns the [`ExtensionType`] that identifies this variant.
    #[inline]
    pub const fn extension_type(&self) -> ExtensionType {
        match self {
            Self::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
            Self::RatchetTree(_) => ExtensionType::RatchetTree,
            Self::ParentHash(_) => ExtensionType::ParentHash,
            Self::LifeTime(_) => ExtensionType::Lifetime,
            Self::ExternalKeyId(_) => ExtensionType::ExternalKeyId,
            Self::Capabilities(_) => ExtensionType::Capabilities,
        }
    }
}
// Marker impl: `Eq` adds no methods. The comparison itself comes from the
// `PartialEq` derived on `Extension`.
impl Eq for Extension {}
impl PartialOrd for Extension {
    /// Orders extensions by their [`ExtensionType`] only; payloads are not
    /// compared. Always returns `Some`, agreeing with the `Ord` impl.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for Extension {
    /// Total order over extensions based solely on their [`ExtensionType`].
    ///
    /// Two extensions of the same type compare as `Equal` here even when
    /// their payloads differ.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let (lhs, rhs) = (self.extension_type(), other.extension_type());
        lhs.cmp(&rhs)
    }
}
pub(crate) fn try_nodes_from_extensions(
other_extensions: &[Extension],
crypto_backend: &impl OpenMlsCrypto,
) -> Result<Option<Vec<Option<Node>>>, ExtensionError> {
let mut ratchet_tree_extensions = other_extensions
.iter()
.filter(|e| e.extension_type() == ExtensionType::RatchetTree);
let nodes = match ratchet_tree_extensions.next() {
Some(e) => {
let mut nodes: Vec<Option<Node>> = e.as_ratchet_tree_extension()?.as_slice().into();
for node in nodes.iter_mut().flatten() {
if let Node::LeafNode(leaf) = node {
leaf.set_key_package_ref(crypto_backend)?;
}
}
Some(nodes)
}
None => None,
};
if ratchet_tree_extensions.next().is_some() {
return Err(ExtensionError::DuplicateRatchetTreeExtension);
};
Ok(nodes)
}