miden_crypto/merkle/
index.rs1use core::fmt::Display;
2
3use super::{Felt, MerkleError, RpoDigest};
4use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
5
/// Address of a node in a binary tree, expressed as a `(depth, value)` pair.
///
/// The root is `depth = 0, value = 0`; the children of `(d, v)` are `(d + 1, 2v)` and
/// `(d + 1, 2v + 1)`.
///
/// The derived ordering compares `depth` first and `value` second, so the field order below is
/// significant and must not be changed.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeIndex {
    /// Distance of the node from the root (the root itself has depth 0).
    depth: u8,
    /// Index of the node among the nodes located at `depth`.
    value: u64,
}
30
31impl NodeIndex {
32 pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
40 if (64 - value.leading_zeros()) > depth as u32 {
41 Err(MerkleError::InvalidNodeIndex { depth, value })
42 } else {
43 Ok(Self { depth, value })
44 }
45 }
46
47 pub const fn new_unchecked(depth: u8, value: u64) -> Self {
49 debug_assert!((64 - value.leading_zeros()) <= depth as u32);
50 Self { depth, value }
51 }
52
53 #[cfg(test)]
58 pub fn make(depth: u8, value: u64) -> Self {
59 Self::new(depth, value).unwrap()
60 }
61
62 pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
69 let depth = depth.as_int();
70 let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
71 let value = value.as_int();
72 Self::new(depth, value)
73 }
74
75 pub const fn root() -> Self {
77 Self { depth: 0, value: 0 }
78 }
79
80 pub const fn sibling(mut self) -> Self {
82 self.value ^= 1;
83 self
84 }
85
86 pub const fn left_child(mut self) -> Self {
88 self.depth += 1;
89 self.value <<= 1;
90 self
91 }
92
93 pub const fn right_child(mut self) -> Self {
95 self.depth += 1;
96 self.value = (self.value << 1) + 1;
97 self
98 }
99
100 pub const fn parent(mut self) -> Self {
103 self.depth = self.depth.saturating_sub(1);
104 self.value >>= 1;
105 self
106 }
107
108 pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] {
115 if self.is_value_odd() {
116 [sibling, slf]
117 } else {
118 [slf, sibling]
119 }
120 }
121
122 pub const fn to_scalar_index(&self) -> u64 {
126 (1 << self.depth as u64) + self.value
127 }
128
129 pub const fn depth(&self) -> u8 {
131 self.depth
132 }
133
134 pub const fn value(&self) -> u64 {
136 self.value
137 }
138
139 pub const fn is_value_odd(&self) -> bool {
141 (self.value & 1) == 1
142 }
143
144 pub const fn is_root(&self) -> bool {
146 self.depth == 0
147 }
148
149 pub fn move_up(&mut self) {
154 self.depth = self.depth.saturating_sub(1);
155 self.value >>= 1;
156 }
157
158 pub fn move_up_to(&mut self, depth: u8) {
162 debug_assert!(depth < self.depth);
163 let delta = self.depth.saturating_sub(depth);
164 self.depth = self.depth.saturating_sub(delta);
165 self.value >>= delta as u32;
166 }
167}
168
169impl Display for NodeIndex {
170 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
171 write!(f, "depth={}, value={}", self.depth, self.value)
172 }
173}
174
175impl Serializable for NodeIndex {
176 fn write_into<W: ByteWriter>(&self, target: &mut W) {
177 target.write_u8(self.depth);
178 target.write_u64(self.value);
179 }
180}
181
182impl Deserializable for NodeIndex {
183 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
184 let depth = source.read_u8()?;
185 let value = source.read_u64()?;
186 NodeIndex::new(depth, value)
187 .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
188 }
189}
190
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    // For each depth, the largest value that fits is accepted and the next one is rejected.
    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // The smallest valid depth is the bit length of `value`. Computing it from
            // `leading_zeros` also covers `value == 0` (where `ilog2` panics) and exact powers
            // of two (which need one more bit than `ilog2` reports).
            // The cast cannot truncate: values below 2^63 have a bit length of at most 63.
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            // The unwrap never panics because `depth` is exactly the bit length `new` checks.
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }
    }
}