mls_spec/
tree.rs

1pub mod hashes;
2pub mod leaf_node;
3
4use crate::{
5    SensitiveBytes,
6    crypto::{HpkeCiphertext, HpkePublicKey},
7    defs::LeafIndex,
8    tree::{hashes::ParentNodeHash, leaf_node::LeafNode},
9};
10
11#[derive(
12    Debug,
13    Clone,
14    PartialEq,
15    Eq,
16    Default,
17    tls_codec::TlsDeserialize,
18    tls_codec::TlsSerialize,
19    tls_codec::TlsSize,
20)]
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22pub struct RatchetTree(Vec<Option<TreeNode>>);
23
24impl RatchetTree {
25    pub fn into_inner(self) -> Vec<Option<TreeNode>> {
26        self.0
27    }
28}
29
30impl From<Vec<Option<TreeNode>>> for RatchetTree {
31    fn from(value: Vec<Option<TreeNode>>) -> Self {
32        Self(value)
33    }
34}
35
36impl std::ops::Deref for RatchetTree {
37    type Target = [Option<TreeNode>];
38
39    fn deref(&self) -> &Self::Target {
40        self.0.as_slice()
41    }
42}
43
44pub type TreeHash = SensitiveBytes;
45
46#[derive(
47    Debug,
48    Clone,
49    PartialEq,
50    Eq,
51    tls_codec::TlsSerialize,
52    tls_codec::TlsDeserialize,
53    tls_codec::TlsSize,
54)]
55#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
56pub struct ParentNode {
57    pub encryption_key: HpkePublicKey,
58    pub parent_hash: ParentNodeHash,
59    pub unmerged_leaves: Vec<LeafIndex>,
60}
61
62#[derive(
63    Debug,
64    Clone,
65    PartialEq,
66    Eq,
67    tls_codec::TlsSerialize,
68    tls_codec::TlsDeserialize,
69    tls_codec::TlsSize,
70)]
71#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
72#[repr(u8)]
73pub enum NodeType {
74    Reserved = 0x00,
75    Leaf = 0x01,
76    Parent = 0x02,
77}
78
79#[derive(
80    Debug,
81    Clone,
82    PartialEq,
83    Eq,
84    tls_codec::TlsSerialize,
85    tls_codec::TlsDeserialize,
86    tls_codec::TlsSize,
87)]
88#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
89#[repr(u8)]
90#[allow(clippy::large_enum_variant)]
91pub enum TreeNode {
92    #[tls_codec(discriminant = "NodeType::Leaf")]
93    LeafNode(LeafNode),
94    #[tls_codec(discriminant = "NodeType::Parent")]
95    ParentNode(ParentNode),
96}
97
98impl From<LeafNode> for TreeNode {
99    fn from(value: LeafNode) -> Self {
100        Self::LeafNode(value)
101    }
102}
103
104impl From<ParentNode> for TreeNode {
105    fn from(value: ParentNode) -> Self {
106        Self::ParentNode(value)
107    }
108}
109
110impl TreeNode {
111    pub fn as_leaf_node(&self) -> Option<&LeafNode> {
112        if let Self::LeafNode(leaf_node) = &self {
113            Some(leaf_node)
114        } else {
115            None
116        }
117    }
118
119    pub fn as_leaf_node_mut(&mut self) -> Option<&mut LeafNode> {
120        if let Self::LeafNode(leaf_node) = self {
121            Some(leaf_node)
122        } else {
123            None
124        }
125    }
126
127    pub fn as_parent_node(&self) -> Option<&ParentNode> {
128        if let Self::ParentNode(parent_node) = &self {
129            Some(parent_node)
130        } else {
131            None
132        }
133    }
134
135    pub fn as_parent_node_mut(&mut self) -> Option<&mut ParentNode> {
136        if let Self::ParentNode(parent_node) = self {
137            Some(parent_node)
138        } else {
139            None
140        }
141    }
142}
143
144#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
145#[repr(u8)]
146pub enum TreeNodeRef<'a> {
147    #[tls_codec(discriminant = "NodeType::Leaf")]
148    LeafNode(&'a LeafNode),
149    #[tls_codec(discriminant = "NodeType::Parent")]
150    ParentNode(&'a ParentNode),
151}
152
153#[derive(
154    Debug,
155    Clone,
156    PartialEq,
157    Eq,
158    tls_codec::TlsSerialize,
159    tls_codec::TlsDeserialize,
160    tls_codec::TlsSize,
161)]
162#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
163pub struct UpdatePathNode {
164    pub encryption_key: HpkePublicKey,
165    pub encrypted_path_secret: Vec<HpkeCiphertext>,
166}
167
168#[derive(
169    Debug,
170    Clone,
171    PartialEq,
172    Eq,
173    tls_codec::TlsSerialize,
174    tls_codec::TlsDeserialize,
175    tls_codec::TlsSize,
176)]
177#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
178pub struct UpdatePath {
179    pub leaf_node: LeafNode,
180    pub nodes: Vec<UpdatePathNode>,
181}