miden_crypto/merkle/
index.rs1use core::fmt::Display;
2
3use super::{Felt, MerkleError, Word};
4use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
5
/// Position of a node in a binary tree, stored as a `(depth, value)` pair.
///
/// The root sits at depth `0`; `value` is the node's position within its level, so a valid
/// index satisfies `value < 2^depth` (this is what [`NodeIndex::new`] enforces).
///
/// The derived ordering compares `depth` first and `value` second.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
    /// Level of the node, counted from the root (`0`).
    depth: u8,
    /// Position of the node within its level.
    value: u64,
}
30
31impl NodeIndex {
32 pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
42 if depth > 64 {
43 Err(MerkleError::DepthTooBig(depth as u64))
44 } else if (64 - value.leading_zeros()) > depth as u32 {
45 Err(MerkleError::InvalidNodeIndex { depth, value })
46 } else {
47 Ok(Self { depth, value })
48 }
49 }
50
51 pub const fn new_unchecked(depth: u8, value: u64) -> Self {
53 debug_assert!(depth <= 64);
54 debug_assert!((64 - value.leading_zeros()) <= depth as u32);
55 Self { depth, value }
56 }
57
58 #[cfg(test)]
63 pub fn make(depth: u8, value: u64) -> Self {
64 Self::new(depth, value).unwrap()
65 }
66
67 pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
74 let depth = depth.as_int();
75 let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
76 let value = value.as_int();
77 Self::new(depth, value)
78 }
79
80 pub const fn root() -> Self {
82 Self { depth: 0, value: 0 }
83 }
84
85 pub const fn sibling(mut self) -> Self {
87 self.value ^= 1;
88 self
89 }
90
91 pub const fn left_child(mut self) -> Self {
93 self.depth += 1;
94 self.value <<= 1;
95 self
96 }
97
98 pub const fn right_child(mut self) -> Self {
100 self.depth += 1;
101 self.value = (self.value << 1) + 1;
102 self
103 }
104
105 pub const fn parent(mut self) -> Self {
108 self.depth = self.depth.saturating_sub(1);
109 self.value >>= 1;
110 self
111 }
112
113 pub const fn build_node(&self, slf: Word, sibling: Word) -> [Word; 2] {
120 if self.is_value_odd() {
121 [sibling, slf]
122 } else {
123 [slf, sibling]
124 }
125 }
126
127 pub const fn to_scalar_index(&self) -> u64 {
131 (1 << self.depth as u64) + self.value
132 }
133
134 pub const fn depth(&self) -> u8 {
136 self.depth
137 }
138
139 pub const fn value(&self) -> u64 {
141 self.value
142 }
143
144 pub const fn is_value_odd(&self) -> bool {
146 (self.value & 1) == 1
147 }
148
149 pub const fn is_nth_bit_odd(&self, n: u8) -> bool {
151 (self.value >> n) & 1 == 1
152 }
153
154 pub const fn is_root(&self) -> bool {
156 self.depth == 0
157 }
158
159 pub fn move_up(&mut self) {
164 self.depth = self.depth.saturating_sub(1);
165 self.value >>= 1;
166 }
167
168 pub fn move_up_to(&mut self, depth: u8) {
172 debug_assert!(depth < self.depth);
173 let delta = self.depth.saturating_sub(depth);
174 self.depth = self.depth.saturating_sub(delta);
175 self.value >>= delta as u32;
176 }
177
178 pub fn proof_indices(&self) -> impl ExactSizeIterator<Item = NodeIndex> + use<> {
187 ProofIter { next_index: self.sibling() }
188 }
189}
190
191impl Display for NodeIndex {
192 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
193 write!(f, "depth={}, value={}", self.depth, self.value)
194 }
195}
196
197impl Serializable for NodeIndex {
198 fn write_into<W: ByteWriter>(&self, target: &mut W) {
199 target.write_u8(self.depth);
200 target.write_u64(self.value);
201 }
202}
203
204impl Deserializable for NodeIndex {
205 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
206 let depth = source.read_u8()?;
207 let value = source.read_u64()?;
208 NodeIndex::new(depth, value)
209 .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
210 }
211}
212
213#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
215struct ProofIter {
216 next_index: NodeIndex,
217}
218
219impl Iterator for ProofIter {
220 type Item = NodeIndex;
221
222 fn next(&mut self) -> Option<NodeIndex> {
223 if self.next_index.is_root() {
224 return None;
225 }
226
227 let index = self.next_index;
228 self.next_index = index.parent().sibling();
229
230 Some(index)
231 }
232
233 fn size_hint(&self) -> (usize, Option<usize>) {
234 let remaining = ExactSizeIterator::len(self);
235
236 (remaining, Some(remaining))
237 }
238}
239
240impl ExactSizeIterator for ProofIter {
241 fn len(&self) -> usize {
242 self.next_index.depth() as usize
243 }
244}
245
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    /// For each depth, the largest value that fits is accepted and the next one is rejected.
    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    /// The full `u64` range is usable at the maximum depth.
    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        /// Strategy producing a valid `NodeIndex` at the smallest depth that can hold `value`.
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // The smallest valid depth is the bit length of `value`. The previous
            // `ilog2`-based computation panicked for `value == 0` and chose a depth one too
            // small for exact powers of two (e.g. depth 2 for value 4), making the `unwrap`
            // below fail.
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            // `value < 2^63`, so `depth <= 63` and the index is always valid.
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        /// Moving up more levels than the index has must saturate at the root, not panic.
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }
    }
}