use std::io::{Read, Write};
use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tls_codec::{
Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
};
use crate::{
binary_tree::array_representation::LeafNodeIndex,
ciphersuite::hash_ref::{make_proposal_ref, KeyPackageRef, ProposalRef},
error::LibraryError,
extensions::Extensions,
framing::{
mls_auth_content::AuthenticatedContent, mls_content::FramedContentBody, ContentType,
},
group::{GroupContext, GroupId},
key_packages::*,
schedule::psk::*,
treesync::LeafNode,
versions::ProtocolVersion,
};
#[cfg(feature = "extensions-draft-08")]
use crate::component::ComponentId;
/// Identifies the kind of a proposal.
///
/// Converts to and from `u16` via the `From` impls in this file and is
/// TLS-encoded as that `u16` in big-endian byte order (always 2 bytes).
///
/// Note that the derived `PartialOrd`/`Ord` follow declaration order, which is
/// not the numeric order of the wire values (e.g. `SelfRemove` is `0x000a` but
/// is declared before `AppEphemeral`, which is `0x0009`).
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub enum ProposalType {
/// Wire value 1.
Add,
/// Wire value 2.
Update,
/// Wire value 3.
Remove,
/// Wire value 4.
PreSharedKey,
/// Wire value 5.
Reinit,
/// Wire value 6.
ExternalInit,
/// Wire value 7.
GroupContextExtensions,
/// Wire value `0x000a`.
SelfRemove,
/// Wire value `0x0009`; only available with the `extensions-draft-08` feature.
#[cfg(feature = "extensions-draft-08")]
AppEphemeral,
/// Wire value 8; only available with the `extensions-draft-08` feature.
#[cfg(feature = "extensions-draft-08")]
AppDataUpdate,
/// Produced by `From<u16>` for values that `crate::grease::is_grease_value`
/// accepts; the raw value is preserved.
Grease(u16),
/// Produced by `From<u16>` for every other value without a dedicated
/// variant; the raw value is preserved.
Custom(u16),
}
impl ProposalType {
pub(crate) fn is_default(self) -> bool {
match self {
ProposalType::Add
| ProposalType::Update
| ProposalType::Remove
| ProposalType::PreSharedKey
| ProposalType::Reinit
| ProposalType::ExternalInit
| ProposalType::GroupContextExtensions => true,
ProposalType::SelfRemove | ProposalType::Grease(_) | ProposalType::Custom(_) => false,
#[cfg(feature = "extensions-draft-08")]
ProposalType::AppEphemeral | ProposalType::AppDataUpdate => false,
}
}
pub fn is_grease(&self) -> bool {
matches!(self, ProposalType::Grease(_))
}
}
impl Size for ProposalType {
    /// A proposal type is always encoded as a single `u16`, i.e. two bytes.
    fn tls_serialized_len(&self) -> usize {
        std::mem::size_of::<u16>()
    }
}
impl TlsDeserializeTrait for ProposalType {
    /// Reads exactly two bytes from `bytes`, interprets them as a big-endian
    /// `u16` and maps that value to a `ProposalType` via `From<u16>`.
    ///
    /// Fails only if the two bytes cannot be read.
    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
    where
        Self: Sized,
    {
        let mut raw = [0u8; 2];
        bytes.read_exact(&mut raw)?;
        let value = u16::from_be_bytes(raw);
        Ok(Self::from(value))
    }
}
impl TlsSerializeTrait for ProposalType {
    /// Writes the `u16` value of this proposal type in big-endian byte order
    /// and returns the number of bytes written (always 2).
    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
        let encoded = u16::from(*self).to_be_bytes();
        writer.write_all(&encoded)?;
        Ok(encoded.len())
    }
}
impl DeserializeBytes for ProposalType {
    /// Decodes a `ProposalType` from the front of `bytes` and returns it
    /// together with the bytes that follow it.
    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
    where
        Self: Sized,
    {
        // `Read for &[u8]` advances the slice past whatever was consumed, so
        // after a successful decode `rest` is exactly the input minus the two
        // bytes of the encoded proposal type.
        let mut rest = bytes;
        let decoded = Self::tls_deserialize(&mut rest)?;
        Ok((decoded, rest))
    }
}
impl ProposalType {
    /// Returns `true` for `Update`, `Remove`, `ExternalInit`,
    /// `GroupContextExtensions` and `SelfRemove`; `false` for every other
    /// proposal type.
    pub fn is_path_required(&self) -> bool {
        match self {
            Self::Update
            | Self::Remove
            | Self::ExternalInit
            | Self::GroupContextExtensions
            | Self::SelfRemove => true,
            Self::Add
            | Self::PreSharedKey
            | Self::Reinit
            | Self::Grease(_)
            | Self::Custom(_) => false,
            #[cfg(feature = "extensions-draft-08")]
            Self::AppEphemeral | Self::AppDataUpdate => false,
        }
    }
}
impl From<u16> for ProposalType {
    /// Maps a raw wire value to its `ProposalType`.
    ///
    /// Values without a dedicated variant become `Grease` when
    /// `crate::grease::is_grease_value` accepts them and `Custom` otherwise,
    /// so the conversion never fails and never loses the raw value. Without
    /// the `extensions-draft-08` feature, 8 and 9 take that fallback path too.
    fn from(value: u16) -> Self {
        match value {
            0x0001 => Self::Add,
            0x0002 => Self::Update,
            0x0003 => Self::Remove,
            0x0004 => Self::PreSharedKey,
            0x0005 => Self::Reinit,
            0x0006 => Self::ExternalInit,
            0x0007 => Self::GroupContextExtensions,
            #[cfg(feature = "extensions-draft-08")]
            0x0008 => Self::AppDataUpdate,
            #[cfg(feature = "extensions-draft-08")]
            0x0009 => Self::AppEphemeral,
            0x000a => Self::SelfRemove,
            unassigned => {
                if crate::grease::is_grease_value(unassigned) {
                    Self::Grease(unassigned)
                } else {
                    Self::Custom(unassigned)
                }
            }
        }
    }
}
impl From<ProposalType> for u16 {
    /// Returns the raw wire value of a `ProposalType`; the inverse of
    /// `ProposalType::from(u16)` for the dedicated variants. `Grease` and
    /// `Custom` yield the value they carry.
    fn from(value: ProposalType) -> Self {
        match value {
            ProposalType::Add => 0x0001,
            ProposalType::Update => 0x0002,
            ProposalType::Remove => 0x0003,
            ProposalType::PreSharedKey => 0x0004,
            ProposalType::Reinit => 0x0005,
            ProposalType::ExternalInit => 0x0006,
            ProposalType::GroupContextExtensions => 0x0007,
            #[cfg(feature = "extensions-draft-08")]
            ProposalType::AppDataUpdate => 0x0008,
            #[cfg(feature = "extensions-draft-08")]
            ProposalType::AppEphemeral => 0x0009,
            ProposalType::SelfRemove => 0x000a,
            ProposalType::Grease(raw) | ProposalType::Custom(raw) => raw,
        }
    }
}
/// A proposal together with its payload.
///
/// Every variant except `SelfRemove` carries a boxed payload struct. The
/// matching [`ProposalType`] of a value is returned by
/// `Proposal::proposal_type`.
///
/// No TLS (de)serialization is derived on this enum.
// NOTE(review): `ProposalRef::from_raw_proposal` calls `tls_serialize_detached`
// on a `Proposal`, so the codec impls presumably live in another module — confirm.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
#[repr(u16)]
pub enum Proposal {
/// Payload: [`AddProposal`].
Add(Box<AddProposal>),
/// Payload: [`UpdateProposal`].
Update(Box<UpdateProposal>),
/// Payload: [`RemoveProposal`].
Remove(Box<RemoveProposal>),
/// Payload: [`PreSharedKeyProposal`].
PreSharedKey(Box<PreSharedKeyProposal>),
/// Payload: [`ReInitProposal`].
ReInit(Box<ReInitProposal>),
/// Payload: [`ExternalInitProposal`].
ExternalInit(Box<ExternalInitProposal>),
/// Payload: [`GroupContextExtensionProposal`].
GroupContextExtensions(Box<GroupContextExtensionProposal>),
/// Only available with the `extensions-draft-08` feature.
#[cfg(feature = "extensions-draft-08")]
AppDataUpdate(Box<AppDataUpdateProposal>),
/// Carries no payload.
SelfRemove,
/// Payload: `AppEphemeralProposal`; only available with the
/// `extensions-draft-08` feature.
#[cfg(feature = "extensions-draft-08")]
AppEphemeral(Box<AppEphemeralProposal>),
/// Payload: [`CustomProposal`], which carries its own raw proposal type.
Custom(Box<CustomProposal>),
}
impl Proposal {
pub(crate) fn remove(r: RemoveProposal) -> Self {
Self::Remove(Box::new(r))
}
pub(crate) fn add(a: AddProposal) -> Self {
Self::Add(Box::new(a))
}
pub(crate) fn custom(c: CustomProposal) -> Self {
Self::Custom(Box::new(c))
}
pub(crate) fn psk(p: PreSharedKeyProposal) -> Self {
Self::PreSharedKey(Box::new(p))
}
pub(crate) fn update(p: UpdateProposal) -> Self {
Self::Update(Box::new(p))
}
pub(crate) fn group_context_extensions(p: GroupContextExtensionProposal) -> Self {
Self::GroupContextExtensions(Box::new(p))
}
pub(crate) fn external_init(p: ExternalInitProposal) -> Self {
Self::ExternalInit(Box::new(p))
}
#[cfg(test)]
pub(crate) fn re_init(p: ReInitProposal) -> Self {
Self::ReInit(Box::new(p))
}
pub fn proposal_type(&self) -> ProposalType {
match self {
Proposal::Add(_) => ProposalType::Add,
Proposal::Update(_) => ProposalType::Update,
Proposal::Remove(_) => ProposalType::Remove,
Proposal::PreSharedKey(_) => ProposalType::PreSharedKey,
Proposal::ReInit(_) => ProposalType::Reinit,
Proposal::ExternalInit(_) => ProposalType::ExternalInit,
Proposal::GroupContextExtensions(_) => ProposalType::GroupContextExtensions,
#[cfg(feature = "extensions-draft-08")]
Proposal::AppDataUpdate(_) => ProposalType::AppDataUpdate,
Proposal::SelfRemove => ProposalType::SelfRemove,
#[cfg(feature = "extensions-draft-08")]
Proposal::AppEphemeral(_) => ProposalType::AppEphemeral,
Proposal::Custom(custom) => ProposalType::Custom(custom.proposal_type.to_owned()),
}
}
pub(crate) fn is_type(&self, proposal_type: ProposalType) -> bool {
self.proposal_type() == proposal_type
}
pub fn is_path_required(&self) -> bool {
self.proposal_type().is_path_required()
}
pub(crate) fn has_lower_priority_than(&self, new_proposal: &Proposal) -> bool {
match (self, new_proposal) {
(Proposal::Update(_), _) => true,
(Proposal::Remove(_), Proposal::Update(_)) => false,
(Proposal::Remove(_), Proposal::Remove(_)) => true,
(_, Proposal::SelfRemove) => true,
_ => {
debug_assert!(false);
false
}
}
}
pub(crate) fn as_remove(&self) -> Option<&RemoveProposal> {
if let Self::Remove(v) = self {
Some(v)
} else {
None
}
}
#[must_use]
pub fn is_remove(&self) -> bool {
matches!(self, Self::Remove(..))
}
}
/// Payload of [`Proposal::Add`]: a single [`KeyPackage`].
///
/// TLS serialization is derived; TLS deserialization is not derived here.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct AddProposal {
// The key package carried by this proposal.
pub(crate) key_package: KeyPackage,
}
impl AddProposal {
    /// Returns a reference to the [`KeyPackage`] carried by this proposal.
    pub fn key_package(&self) -> &KeyPackage {
        let Self { key_package } = self;
        key_package
    }
}
/// Payload of [`Proposal::Update`]: a single [`LeafNode`].
///
/// TLS serialization is derived; TLS deserialization is not derived here.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
pub struct UpdateProposal {
// The leaf node carried by this proposal.
pub(crate) leaf_node: LeafNode,
}
impl UpdateProposal {
    /// Returns a reference to the [`LeafNode`] carried by this proposal.
    pub fn leaf_node(&self) -> &LeafNode {
        let Self { leaf_node } = self;
        leaf_node
    }
}
/// Payload of [`Proposal::Remove`]: the [`LeafNodeIndex`] stored in `removed`.
#[derive(
Debug,
PartialEq,
Eq,
Clone,
Serialize,
Deserialize,
TlsDeserialize,
TlsDeserializeBytes,
TlsSerialize,
TlsSize,
)]
pub struct RemoveProposal {
// NOTE(review): presumably the leaf index of the member to remove — confirm.
pub(crate) removed: LeafNodeIndex,
}
impl RemoveProposal {
    /// Returns a copy of the [`LeafNodeIndex`] stored in this proposal.
    pub fn removed(&self) -> LeafNodeIndex {
        let Self { removed } = *self;
        removed
    }
}
/// Payload of [`Proposal::PreSharedKey`]: a single [`PreSharedKeyId`].
#[derive(
Debug,
PartialEq,
Eq,
Clone,
Serialize,
Deserialize,
TlsDeserialize,
TlsDeserializeBytes,
TlsSerialize,
TlsSize,
)]
pub struct PreSharedKeyProposal {
// The PSK identifier carried by this proposal.
psk: PreSharedKeyId,
}
impl PreSharedKeyProposal {
    /// Consumes the proposal and returns the [`PreSharedKeyId`] it carried.
    pub(crate) fn into_psk_id(self) -> PreSharedKeyId {
        let Self { psk } = self;
        psk
    }
}
impl PreSharedKeyProposal {
pub fn new(psk: PreSharedKeyId) -> Self {
Self { psk }
}
}
/// Payload of [`Proposal::ReInit`]: a group id, protocol version, ciphersuite
/// and group-context extensions.
// NOTE(review): presumably these are the parameters of the group that replaces
// the current one — confirm against the MLS specification.
#[derive(
Debug,
PartialEq,
Eq,
Clone,
Serialize,
Deserialize,
TlsDeserialize,
TlsDeserializeBytes,
TlsSerialize,
TlsSize,
)]
pub struct ReInitProposal {
pub(crate) group_id: GroupId,
pub(crate) version: ProtocolVersion,
pub(crate) ciphersuite: Ciphersuite,
pub(crate) extensions: Extensions<GroupContext>,
}
/// Payload of [`Proposal::ExternalInit`]: an opaque, variable-length byte
/// string named `kem_output`.
///
/// Can be built from a `Vec<u8>` via `From`.
#[derive(
Debug,
PartialEq,
Eq,
Clone,
Serialize,
Deserialize,
TlsDeserialize,
TlsDeserializeBytes,
TlsSerialize,
TlsSize,
)]
pub struct ExternalInitProposal {
// NOTE(review): presumably the KEM ciphertext of the external init — confirm.
kem_output: VLBytes,
}
impl ExternalInitProposal {
    /// Returns the stored `kem_output` bytes as a slice.
    pub(crate) fn kem_output(&self) -> &[u8] {
        let Self { kem_output } = self;
        kem_output.as_slice()
    }
}
impl From<Vec<u8>> for ExternalInitProposal {
fn from(kem_output: Vec<u8>) -> Self {
ExternalInitProposal {
kem_output: kem_output.into(),
}
}
}
/// A list of [`MessageRange`]s; only available with the `extensions-draft-08`
/// feature.
// NOTE(review): nothing else in this file constructs or consumes this type and
// `Proposal` has no variant for it — confirm whether it is still needed.
#[cfg(feature = "extensions-draft-08")]
#[derive(
Debug,
PartialEq,
Clone,
Serialize,
Deserialize,
TlsDeserialize,
TlsDeserializeBytes,
TlsSerialize,
TlsSize,
)]
pub struct AppAck {
received_ranges: Vec<MessageRange>,
}
/// Payload of `Proposal::AppEphemeral`: a [`ComponentId`] plus opaque,
/// variable-length data. Only available with the `extensions-draft-08`
/// feature.
#[cfg(feature = "extensions-draft-08")]
#[derive(
Debug,
PartialEq,
Clone,
Serialize,
Deserialize,
TlsDeserialize,
TlsDeserializeBytes,
TlsSerialize,
TlsSize,
)]
pub struct AppEphemeralProposal {
// NOTE(review): presumably identifies the component that `data` belongs to — confirm.
component_id: ComponentId,
// Opaque payload bytes.
data: VLBytes,
}
#[cfg(feature = "extensions-draft-08")]
impl AppEphemeralProposal {
    /// Creates a proposal from a component id and its payload bytes.
    pub fn new(component_id: ComponentId, data: Vec<u8>) -> Self {
        let data = data.into();
        Self { component_id, data }
    }

    /// Returns a copy of the stored [`ComponentId`].
    pub fn component_id(&self) -> ComponentId {
        let Self { component_id, .. } = *self;
        component_id
    }

    /// Returns the payload bytes as a slice.
    pub fn data(&self) -> &[u8] {
        let Self { data, .. } = self;
        data.as_slice()
    }
}
/// Payload of [`Proposal::GroupContextExtensions`]: a set of
/// `Extensions<GroupContext>`.
///
/// `Size` and TLS serialization are implemented by hand in this file and
/// forward directly to the wrapped extensions; TLS deserialization is not
/// implemented here.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct GroupContextExtensionProposal {
// The extensions carried by this proposal.
extensions: Extensions<GroupContext>,
}
impl Size for GroupContextExtensionProposal {
    /// The encoded length is exactly that of the wrapped extensions; the
    /// proposal adds no framing of its own.
    fn tls_serialized_len(&self) -> usize {
        let Self { extensions } = self;
        extensions.tls_serialized_len()
    }
}
impl TlsSerializeTrait for GroupContextExtensionProposal {
    /// Serializes the wrapped extensions to `writer` and returns whatever
    /// byte count that serialization reports.
    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
        let Self { extensions } = self;
        extensions.tls_serialize(writer)
    }
}
impl GroupContextExtensionProposal {
pub(crate) fn new(extensions: Extensions<GroupContext>) -> Self {
Self { extensions }
}
pub fn extensions(&self) -> &Extensions<GroupContext> {
&self.extensions
}
pub fn into_extensions(self) -> Extensions<GroupContext> {
self.extensions
}
}
/// Tag that distinguishes an inline proposal from a proposal reference.
///
/// The discriminants are `Proposal = 1` and `Reference = 2`, stored as a `u8`.
#[derive(
PartialEq,
Clone,
Copy,
Debug,
TlsSerialize,
TlsDeserialize,
TlsDeserializeBytes,
TlsSize,
Serialize,
Deserialize,
)]
#[repr(u8)]
pub enum ProposalOrRefType {
/// An inline proposal.
Proposal = 1,
/// A reference to a proposal.
Reference = 2,
}
/// Either a full [`Proposal`] or a [`ProposalRef`] pointing at one.
///
/// TLS serialization is derived; TLS deserialization is not derived here.
// NOTE(review): only `Proposal` has an explicit TLS discriminant (1).
// `Reference` presumably receives the next value (2), matching
// `ProposalOrRefType::Reference` — confirm against the tls_codec derive docs.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSerialize, TlsSize)]
#[repr(u8)]
#[allow(missing_docs)]
pub enum ProposalOrRef {
/// An inline proposal.
#[tls_codec(discriminant = 1)]
Proposal(Box<Proposal>),
/// A reference to a proposal.
Reference(Box<ProposalRef>),
}
impl ProposalOrRef {
    /// Wraps a [`Proposal`] in [`ProposalOrRef::Proposal`].
    pub(crate) fn proposal(p: Proposal) -> Self {
        Self::Proposal(Box::from(p))
    }

    /// Wraps a [`ProposalRef`] in [`ProposalOrRef::Reference`].
    pub(crate) fn reference(p: ProposalRef) -> Self {
        Self::Reference(Box::from(p))
    }

    /// Returns the inline proposal, or `None` if this is a reference.
    pub(crate) fn as_proposal(&self) -> Option<&Proposal> {
        match self {
            Self::Proposal(inner) => Some(inner),
            Self::Reference(_) => None,
        }
    }

    /// Returns the proposal reference, or `None` if this is an inline proposal.
    pub(crate) fn as_reference(&self) -> Option<&ProposalRef> {
        match self {
            Self::Reference(inner) => Some(inner),
            Self::Proposal(_) => None,
        }
    }
}
impl From<Proposal> for ProposalOrRef {
    /// Converts a proposal into the inline (`Proposal`) variant.
    fn from(value: Proposal) -> Self {
        Self::Proposal(Box::new(value))
    }
}
impl From<ProposalRef> for ProposalOrRef {
    /// Converts a proposal reference into the `Reference` variant.
    fn from(value: ProposalRef) -> Self {
        Self::Reference(Box::new(value))
    }
}
/// Errors returned by `ProposalRef::from_authenticated_content_by_ref`.
#[derive(Error, Debug)]
pub(crate) enum ProposalRefError {
/// The authenticated content's body was not a proposal; `wrong` holds the
/// content type that was found instead.
#[error("Expected `Proposal`, got `{wrong:?}`.")]
AuthenticatedContentHasWrongType { wrong: ContentType },
/// A wrapped [`LibraryError`]; `?` converts into this variant automatically.
#[error(transparent)]
Other(#[from] LibraryError),
}
impl ProposalRef {
    /// Computes the reference of a proposal sent as authenticated content.
    ///
    /// The whole `authenticated_content` is TLS-serialized and handed to
    /// `make_proposal_ref`.
    ///
    /// # Errors
    ///
    /// * [`ProposalRefError::AuthenticatedContentHasWrongType`] if the content
    ///   body is not a proposal.
    /// * [`ProposalRefError::Other`] if serialization or the reference
    ///   computation fails.
    pub(crate) fn from_authenticated_content_by_ref(
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
        authenticated_content: &AuthenticatedContent,
    ) -> Result<Self, ProposalRefError> {
        // Guard: only proposals can be referenced.
        let body = authenticated_content.content();
        if !matches!(body, FramedContentBody::Proposal(_)) {
            let wrong = body.content_type();
            return Err(ProposalRefError::AuthenticatedContentHasWrongType { wrong });
        }

        // `?` lifts the `LibraryError` into `ProposalRefError::Other`.
        let serialized = authenticated_content
            .tls_serialize_detached()
            .map_err(LibraryError::missing_bound_check)?;
        let reference = make_proposal_ref(&serialized, ciphersuite, crypto)
            .map_err(LibraryError::unexpected_crypto_error)?;
        Ok(reference)
    }

    /// Computes a reference for a bare [`Proposal`].
    ///
    /// The input to `make_proposal_ref` is a fixed internal label followed by
    /// the TLS serialization of `proposal`.
    ///
    /// # Errors
    ///
    /// Returns a [`LibraryError`] if serialization or the reference
    /// computation fails.
    pub(crate) fn from_raw_proposal(
        ciphersuite: Ciphersuite,
        crypto: &impl OpenMlsCrypto,
        proposal: &Proposal,
    ) -> Result<Self, LibraryError> {
        const LABEL: &[u8] = b"Internal OpenMLS ProposalRef Label";

        let encoded = proposal
            .tls_serialize_detached()
            .map_err(LibraryError::missing_bound_check)?;

        // label || encoded proposal
        let mut data = Vec::with_capacity(LABEL.len() + encoded.len());
        data.extend_from_slice(LABEL);
        data.extend_from_slice(&encoded);

        make_proposal_ref(&data, ciphersuite, crypto).map_err(LibraryError::unexpected_crypto_error)
    }
}
/// A sender ([`KeyPackageRef`]) together with a first and last generation
/// number. Within this file it is used only as the element type of
/// `AppAck::received_ranges`.
// NOTE(review): presumably an inclusive range of message generations received
// from `sender` — confirm against the specification.
#[derive(
Debug,
PartialEq,
Clone,
Serialize,
Deserialize,
TlsDeserialize,
TlsDeserializeBytes,
TlsSerialize,
TlsSize,
)]
pub(crate) struct MessageRange {
sender: KeyPackageRef,
first_generation: u32,
last_generation: u32,
}
// Submodule compiled only with the `extensions-draft-08` feature; all of its
// public items are re-exported from this module.
// NOTE(review): presumably this is where `AppDataUpdateProposal` (used by
// `Proposal::AppDataUpdate`) is defined — confirm.
#[cfg(feature = "extensions-draft-08")]
mod app_data_update;
#[cfg(feature = "extensions-draft-08")]
pub use app_data_update::*;
/// Payload of [`Proposal::Custom`]: a raw `u16` proposal type plus an opaque
/// byte payload.
///
/// `Proposal::proposal_type` reports such a proposal as
/// `ProposalType::Custom(proposal_type)`.
#[derive(
Debug,
PartialEq,
Clone,
Serialize,
Deserialize,
TlsSize,
TlsSerialize,
TlsDeserialize,
TlsDeserializeBytes,
)]
pub struct CustomProposal {
// Raw proposal type value.
proposal_type: u16,
// Opaque payload bytes; not interpreted in this file.
payload: Vec<u8>,
}
impl CustomProposal {
pub fn new(proposal_type: u16, payload: Vec<u8>) -> Self {
Self {
proposal_type,
payload,
}
}
pub fn proposal_type(&self) -> u16 {
self.proposal_type
}
pub fn payload(&self) -> &[u8] {
&self.payload
}
}
#[cfg(test)]
mod tests {
    use tls_codec::{Deserialize, Serialize};

    use super::ProposalType;

    /// Values without a dedicated variant are expected to decode as
    /// `ProposalType::Custom` carrying the raw value, and to re-encode to the
    /// exact bytes they were decoded from.
    #[test]
    fn that_unknown_proposal_types_are_de_serialized_correctly() {
        let proposal_types = [0x0000u16, 0x0B0B, 0x7C7C, 0xF000, 0xFFFF];

        for proposal_type in proposal_types {
            // Construct an unknown proposal type.
            let test = proposal_type.to_be_bytes().to_vec();

            // Test deserialization.
            let got = ProposalType::tls_deserialize_exact(&test).unwrap();

            match got {
                ProposalType::Custom(got_proposal_type) => {
                    assert_eq!(proposal_type, got_proposal_type);
                }
                // The failure message names the variant that is actually
                // expected (`Custom`); there is no `Unknown` variant.
                other => panic!("Expected `ProposalType::Custom`, got `{other:?}`."),
            }

            // Test serialization.
            let got_serialized = got.tls_serialize_detached().unwrap();
            assert_eq!(test, got_serialized);
        }
    }
}