mls_spec/tree/leaf_node.rs

1use crate::{
2    SensitiveBytes,
3    credential::Credential,
4    crypto::{HpkePublicKey, HpkePublicKeyRef, SignaturePublicKey, SignaturePublicKeyRef},
5    defs::{Capabilities, LeafIndex},
6    group::{KeyPackageLifetime, extensions::Extension},
7};
8
9#[derive(
10    Debug,
11    Clone,
12    Copy,
13    PartialEq,
14    Eq,
15    tls_codec::TlsSize,
16    tls_codec::TlsDeserialize,
17    tls_codec::TlsSerialize,
18    strum::Display,
19)]
20#[repr(u8)]
21pub enum LeafNodeSourceType {
22    Reserved = 0x00,
23    KeyPackage = 0x01,
24    Update = 0x02,
25    Commit = 0x03,
26}
27
28#[derive(
29    Debug,
30    Clone,
31    PartialEq,
32    Eq,
33    tls_codec::TlsSerialize,
34    tls_codec::TlsDeserialize,
35    tls_codec::TlsSize,
36)]
37#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
38#[repr(u8)]
39pub enum LeafNodeSource {
40    #[tls_codec(discriminant = "LeafNodeSourceType::KeyPackage")]
41    KeyPackage { lifetime: KeyPackageLifetime },
42    #[tls_codec(discriminant = "LeafNodeSourceType::Update")]
43    Update,
44    #[tls_codec(discriminant = "LeafNodeSourceType::Commit")]
45    Commit { parent_hash: SensitiveBytes },
46}
47
48impl From<&LeafNodeSource> for LeafNodeSourceType {
49    fn from(value: &LeafNodeSource) -> Self {
50        match value {
51            LeafNodeSource::KeyPackage { .. } => Self::KeyPackage,
52            LeafNodeSource::Update => Self::Update,
53            LeafNodeSource::Commit { .. } => Self::Commit,
54        }
55    }
56}
57
58#[derive(Debug, Copy, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
59#[cfg_attr(feature = "serde", derive(serde::Serialize))]
60pub struct LeafNodeMemberInfo<'a> {
61    #[tls_codec(with = "crate::tlspl::bytes")]
62    pub group_id: &'a [u8],
63    pub leaf_index: LeafIndex,
64}
65
66#[derive(
67    Debug,
68    Clone,
69    PartialEq,
70    Eq,
71    tls_codec::TlsSerialize,
72    tls_codec::TlsDeserialize,
73    tls_codec::TlsSize,
74)]
75#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
76pub struct LeafNode {
77    pub encryption_key: HpkePublicKey,
78    pub signature_key: SignaturePublicKey,
79    pub credential: Credential,
80    pub capabilities: Capabilities,
81    pub source: LeafNodeSource,
82    pub extensions: Vec<Extension>,
83    pub signature: SensitiveBytes,
84}
85
86impl LeafNode {
87    #[inline]
88    pub fn requires_member_info(&self) -> bool {
89        matches!(
90            self.source,
91            LeafNodeSource::Update | LeafNodeSource::Commit { .. }
92        )
93    }
94
95    pub fn parent_hash(&self) -> Option<&[u8]> {
96        match &self.source {
97            LeafNodeSource::Commit { parent_hash } => Some(parent_hash),
98            _ => None,
99        }
100    }
101
102    pub fn to_tbs<'a>(
103        &'a self,
104        member_info: Option<LeafNodeMemberInfo<'a>>,
105    ) -> Option<LeafNodeTBS<'a>> {
106        Some(LeafNodeTBS {
107            encryption_key: &self.encryption_key,
108            signature_key: &self.signature_key,
109            credential: &self.credential,
110            capabilities: &self.capabilities,
111            source: &self.source,
112            extensions: &self.extensions,
113            member_info: if self.requires_member_info() {
114                // Invalid because in those context we should have a valid member_info
115                Some(member_info?)
116            } else {
117                None
118            },
119        })
120    }
121
122    pub fn application_id(&self) -> Option<&[u8]> {
123        self.extensions.iter().find_map(|ext| {
124            if let Extension::ApplicationId(app_id) = ext {
125                Some(app_id.as_slice())
126            } else {
127                None
128            }
129        })
130    }
131}
132
133#[derive(Debug, PartialEq, Eq)]
134pub struct LeafNodeTBS<'a> {
135    pub encryption_key: HpkePublicKeyRef<'a>,
136    pub signature_key: SignaturePublicKeyRef<'a>,
137    pub credential: &'a Credential,
138    pub capabilities: &'a Capabilities,
139    pub source: &'a LeafNodeSource,
140    pub extensions: &'a Vec<Extension>,
141    pub member_info: Option<LeafNodeMemberInfo<'a>>,
142}
143
144impl tls_codec::Size for LeafNodeTBS<'_> {
145    fn tls_serialized_len(&self) -> usize {
146        self.encryption_key.tls_serialized_len()
147            + self.signature_key.tls_serialized_len()
148            + self.credential.tls_serialized_len()
149            + self.capabilities.tls_serialized_len()
150            + self.source.tls_serialized_len()
151            + self.extensions.tls_serialized_len()
152            + self.member_info.map_or(0, |mi| mi.tls_serialized_len())
153    }
154}
155
156impl tls_codec::Serialize for LeafNodeTBS<'_> {
157    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
158        let mut written = 0;
159        written += crate::tlspl::bytes::tls_serialize(self.encryption_key, writer)?;
160        written += crate::tlspl::bytes::tls_serialize(self.signature_key, writer)?;
161        written += self.credential.tls_serialize(writer)?;
162        written += self.capabilities.tls_serialize(writer)?;
163        written += self.source.tls_serialize(writer)?;
164        written += self.extensions.tls_serialize(writer)?;
165        if let Some(member_info) = self.member_info {
166            written += member_info.tls_serialize(writer)?;
167        }
168
169        Ok(written)
170    }
171}