mls_spec/
tree.rs

1pub mod hashes;
2pub mod leaf_node;
3
4use crate::{
5    SensitiveBytes,
6    crypto::{HpkeCiphertext, HpkePublicKey},
7    defs::LeafIndex,
8    tree::{hashes::ParentNodeHash, leaf_node::LeafNode},
9};
10
11#[derive(
12    Debug,
13    Clone,
14    PartialEq,
15    Eq,
16    Hash,
17    Default,
18    tls_codec::TlsDeserialize,
19    tls_codec::TlsSerialize,
20    tls_codec::TlsSize,
21)]
22#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
23pub struct RatchetTree(Vec<Option<TreeNode>>);
24
25impl RatchetTree {
26    pub fn into_inner(self) -> Vec<Option<TreeNode>> {
27        self.0
28    }
29}
30
31impl From<Vec<Option<TreeNode>>> for RatchetTree {
32    fn from(value: Vec<Option<TreeNode>>) -> Self {
33        Self(value)
34    }
35}
36
37impl std::ops::Deref for RatchetTree {
38    type Target = [Option<TreeNode>];
39
40    fn deref(&self) -> &Self::Target {
41        self.0.as_slice()
42    }
43}
44
45pub type TreeHash = SensitiveBytes;
46
47#[derive(
48    Debug,
49    Clone,
50    PartialEq,
51    Eq,
52    Hash,
53    tls_codec::TlsSerialize,
54    tls_codec::TlsDeserialize,
55    tls_codec::TlsSize,
56)]
57#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
58pub struct ParentNode {
59    pub encryption_key: HpkePublicKey,
60    pub parent_hash: ParentNodeHash,
61    pub unmerged_leaves: Vec<LeafIndex>,
62}
63
64#[derive(
65    Debug,
66    Clone,
67    PartialEq,
68    Eq,
69    Hash,
70    tls_codec::TlsSerialize,
71    tls_codec::TlsDeserialize,
72    tls_codec::TlsSize,
73)]
74#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
75#[repr(u8)]
76pub enum NodeType {
77    Reserved = 0x00,
78    Leaf = 0x01,
79    Parent = 0x02,
80}
81
82#[derive(
83    Debug,
84    Clone,
85    PartialEq,
86    Eq,
87    Hash,
88    tls_codec::TlsSerialize,
89    tls_codec::TlsDeserialize,
90    tls_codec::TlsSize,
91)]
92#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
93#[repr(u8)]
94#[allow(clippy::large_enum_variant)]
95pub enum TreeNode {
96    #[tls_codec(discriminant = "NodeType::Leaf")]
97    LeafNode(LeafNode),
98    #[tls_codec(discriminant = "NodeType::Parent")]
99    ParentNode(ParentNode),
100}
101
102impl From<LeafNode> for TreeNode {
103    fn from(value: LeafNode) -> Self {
104        Self::LeafNode(value)
105    }
106}
107
108impl From<ParentNode> for TreeNode {
109    fn from(value: ParentNode) -> Self {
110        Self::ParentNode(value)
111    }
112}
113
114impl TreeNode {
115    pub fn as_leaf_node(&self) -> Option<&LeafNode> {
116        if let Self::LeafNode(leaf_node) = &self {
117            Some(leaf_node)
118        } else {
119            None
120        }
121    }
122
123    pub fn as_leaf_node_mut(&mut self) -> Option<&mut LeafNode> {
124        if let Self::LeafNode(leaf_node) = self {
125            Some(leaf_node)
126        } else {
127            None
128        }
129    }
130
131    pub fn as_parent_node(&self) -> Option<&ParentNode> {
132        if let Self::ParentNode(parent_node) = &self {
133            Some(parent_node)
134        } else {
135            None
136        }
137    }
138
139    pub fn as_parent_node_mut(&mut self) -> Option<&mut ParentNode> {
140        if let Self::ParentNode(parent_node) = self {
141            Some(parent_node)
142        } else {
143            None
144        }
145    }
146}
147
148#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
149#[repr(u8)]
150pub enum TreeNodeRef<'a> {
151    #[tls_codec(discriminant = "NodeType::Leaf")]
152    LeafNode(&'a LeafNode),
153    #[tls_codec(discriminant = "NodeType::Parent")]
154    ParentNode(&'a ParentNode),
155}
156
157#[derive(
158    Debug,
159    Clone,
160    PartialEq,
161    Eq,
162    tls_codec::TlsSerialize,
163    tls_codec::TlsDeserialize,
164    tls_codec::TlsSize,
165)]
166#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
167pub struct UpdatePathNode {
168    pub encryption_key: HpkePublicKey,
169    pub encrypted_path_secret: Vec<HpkeCiphertext>,
170}
171
172#[derive(
173    Debug,
174    Clone,
175    PartialEq,
176    Eq,
177    tls_codec::TlsSerialize,
178    tls_codec::TlsDeserialize,
179    tls_codec::TlsSize,
180)]
181#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
182pub struct UpdatePath {
183    pub leaf_node: LeafNode,
184    pub nodes: Vec<UpdatePathNode>,
185}