mls_spec/tree/leaf_node.rs

1use crate::{
2    SensitiveBytes,
3    credential::Credential,
4    crypto::{HpkePublicKey, HpkePublicKeyRef, SignaturePublicKey, SignaturePublicKeyRef},
5    defs::{Capabilities, LeafIndex},
6    group::{KeyPackageLifetime, extensions::Extension},
7};
8
9#[derive(
10    Debug,
11    Clone,
12    Copy,
13    PartialEq,
14    Eq,
15    tls_codec::TlsSize,
16    tls_codec::TlsDeserialize,
17    tls_codec::TlsSerialize,
18    strum::Display,
19)]
20#[repr(u8)]
21pub enum LeafNodeSourceType {
22    Reserved = 0x00,
23    KeyPackage = 0x01,
24    Update = 0x02,
25    Commit = 0x03,
26}
27
28#[derive(
29    Debug,
30    Clone,
31    PartialEq,
32    Eq,
33    Hash,
34    tls_codec::TlsSerialize,
35    tls_codec::TlsDeserialize,
36    tls_codec::TlsSize,
37)]
38#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
39#[repr(u8)]
40pub enum LeafNodeSource {
41    #[tls_codec(discriminant = "LeafNodeSourceType::KeyPackage")]
42    KeyPackage { lifetime: KeyPackageLifetime },
43    #[tls_codec(discriminant = "LeafNodeSourceType::Update")]
44    Update,
45    #[tls_codec(discriminant = "LeafNodeSourceType::Commit")]
46    Commit { parent_hash: SensitiveBytes },
47}
48
49impl From<&LeafNodeSource> for LeafNodeSourceType {
50    fn from(value: &LeafNodeSource) -> Self {
51        match value {
52            LeafNodeSource::KeyPackage { .. } => Self::KeyPackage,
53            LeafNodeSource::Update => Self::Update,
54            LeafNodeSource::Commit { .. } => Self::Commit,
55        }
56    }
57}
58
59#[derive(Debug, Copy, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
60#[cfg_attr(feature = "serde", derive(serde::Serialize))]
61pub struct LeafNodeMemberInfo<'a> {
62    #[tls_codec(with = "crate::tlspl::bytes")]
63    pub group_id: &'a [u8],
64    pub leaf_index: LeafIndex,
65}
66
67#[derive(
68    Debug,
69    Clone,
70    PartialEq,
71    Eq,
72    Hash,
73    tls_codec::TlsSerialize,
74    tls_codec::TlsDeserialize,
75    tls_codec::TlsSize,
76)]
77#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
78pub struct LeafNode {
79    pub encryption_key: HpkePublicKey,
80    pub signature_key: SignaturePublicKey,
81    pub credential: Credential,
82    pub capabilities: Capabilities,
83    pub source: LeafNodeSource,
84    pub extensions: Vec<Extension>,
85    pub signature: SensitiveBytes,
86}
87
88impl LeafNode {
89    #[inline]
90    pub fn requires_member_info(&self) -> bool {
91        matches!(
92            self.source,
93            LeafNodeSource::Update | LeafNodeSource::Commit { .. }
94        )
95    }
96
97    pub fn parent_hash(&self) -> Option<&[u8]> {
98        match &self.source {
99            LeafNodeSource::Commit { parent_hash } => Some(parent_hash),
100            _ => None,
101        }
102    }
103
104    pub fn to_tbs<'a>(
105        &'a self,
106        member_info: Option<LeafNodeMemberInfo<'a>>,
107    ) -> Option<LeafNodeTBS<'a>> {
108        Some(LeafNodeTBS {
109            encryption_key: &self.encryption_key,
110            signature_key: &self.signature_key,
111            credential: &self.credential,
112            capabilities: &self.capabilities,
113            source: &self.source,
114            extensions: &self.extensions,
115            member_info: if self.requires_member_info() {
116                // Invalid because in those context we should have a valid member_info
117                Some(member_info?)
118            } else {
119                None
120            },
121        })
122    }
123
124    pub fn application_id(&self) -> Option<&[u8]> {
125        self.extensions.iter().find_map(|ext| {
126            if let Extension::ApplicationId(app_id) = ext {
127                Some(app_id.as_slice())
128            } else {
129                None
130            }
131        })
132    }
133}
134
135#[derive(Debug, PartialEq, Eq)]
136pub struct LeafNodeTBS<'a> {
137    pub encryption_key: HpkePublicKeyRef<'a>,
138    pub signature_key: SignaturePublicKeyRef<'a>,
139    pub credential: &'a Credential,
140    pub capabilities: &'a Capabilities,
141    pub source: &'a LeafNodeSource,
142    pub extensions: &'a Vec<Extension>,
143    pub member_info: Option<LeafNodeMemberInfo<'a>>,
144}
145
146impl tls_codec::Size for LeafNodeTBS<'_> {
147    fn tls_serialized_len(&self) -> usize {
148        self.encryption_key.tls_serialized_len()
149            + self.signature_key.tls_serialized_len()
150            + self.credential.tls_serialized_len()
151            + self.capabilities.tls_serialized_len()
152            + self.source.tls_serialized_len()
153            + self.extensions.tls_serialized_len()
154            + self.member_info.map_or(0, |mi| mi.tls_serialized_len())
155    }
156}
157
158impl tls_codec::Serialize for LeafNodeTBS<'_> {
159    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
160        let mut written = 0;
161        written += crate::tlspl::bytes::tls_serialize(self.encryption_key, writer)?;
162        written += crate::tlspl::bytes::tls_serialize(self.signature_key, writer)?;
163        written += self.credential.tls_serialize(writer)?;
164        written += self.capabilities.tls_serialize(writer)?;
165        written += self.source.tls_serialize(writer)?;
166        written += self.extensions.tls_serialize(writer)?;
167        if let Some(member_info) = self.member_info {
168            written += member_info.tls_serialize(writer)?;
169        }
170
171        Ok(written)
172    }
173}