1use core::fmt::Display;
2
3use super::{Felt, MerkleError, Word};
4use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
5
/// Address of a node in a binary tree, expressed as a (depth, position) pair.
///
/// Invariant (enforced by `NodeIndex::new`): `depth <= 64` and `position`
/// fits into `depth` bits, i.e. `position < 2^depth`.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
    // Distance from the root; the root itself is at depth 0.
    depth: u8,
    // Zero-based index within the depth; validated to use at most `depth` bits.
    // NOTE(review): the serde `Deserialize` derive bypasses `NodeIndex::new`
    // validation — confirm whether serde inputs are trusted.
    position: u64,
}
30
31impl NodeIndex {
32 pub const fn new(depth: u8, position: u64) -> Result<Self, MerkleError> {
42 if depth > 64 {
43 Err(MerkleError::DepthTooBig(depth as u64))
44 } else if (64 - position.leading_zeros()) > depth as u32 {
45 Err(MerkleError::InvalidNodeIndex { depth, position })
46 } else {
47 Ok(Self { depth, position })
48 }
49 }
50
51 pub const fn new_unchecked(depth: u8, position: u64) -> Self {
53 debug_assert!(depth <= 64);
54 debug_assert!((64 - position.leading_zeros()) <= depth as u32);
55 Self { depth, position }
56 }
57
58 #[cfg(test)]
63 pub fn make(depth: u8, position: u64) -> Self {
64 Self::new(depth, position).unwrap()
65 }
66
67 pub fn from_elements(depth: &Felt, position: &Felt) -> Result<Self, MerkleError> {
74 let depth = depth.as_canonical_u64();
75 let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
76 let position = position.as_canonical_u64();
77 Self::new(depth, position)
78 }
79
80 pub const fn root() -> Self {
82 Self { depth: 0, position: 0 }
83 }
84
85 pub const fn sibling(mut self) -> Self {
87 self.position ^= 1;
88 self
89 }
90
91 pub const fn left_child(mut self) -> Self {
93 self.depth += 1;
94 self.position <<= 1;
95 self
96 }
97
98 pub const fn right_child(mut self) -> Self {
100 self.depth += 1;
101 self.position = (self.position << 1) + 1;
102 self
103 }
104
105 pub const fn parent(mut self) -> Self {
108 self.depth = self.depth.saturating_sub(1);
109 self.position >>= 1;
110 self
111 }
112
113 pub const fn build_node(&self, slf: Word, sibling: Word) -> [Word; 2] {
120 if self.is_position_odd() {
121 [sibling, slf]
122 } else {
123 [slf, sibling]
124 }
125 }
126
127 pub const fn to_scalar_index(&self) -> Result<u64, MerkleError> {
136 if self.depth >= 64 {
137 return Err(MerkleError::DepthTooBig(self.depth as u64));
138 }
139 Ok((1u64 << self.depth as u64) + self.position)
140 }
141
142 pub const fn depth(&self) -> u8 {
144 self.depth
145 }
146
147 pub const fn position(&self) -> u64 {
149 self.position
150 }
151
152 pub const fn is_position_odd(&self) -> bool {
154 (self.position & 1) == 1
155 }
156
157 pub const fn is_nth_bit_odd(&self, n: u8) -> bool {
159 (self.position >> n) & 1 == 1
160 }
161
162 pub const fn is_root(&self) -> bool {
164 self.depth == 0
165 }
166
167 pub fn move_up(&mut self) {
172 self.depth = self.depth.saturating_sub(1);
173 self.position >>= 1;
174 }
175
176 pub fn move_up_to(&mut self, depth: u8) {
180 debug_assert!(depth < self.depth);
181 let delta = self.depth.saturating_sub(depth);
182 self.depth = self.depth.saturating_sub(delta);
183 self.position >>= delta as u32;
184 }
185
186 pub fn proof_indices(&self) -> impl ExactSizeIterator<Item = NodeIndex> + use<> {
195 ProofIter { next_index: self.sibling() }
196 }
197}
198
199impl Display for NodeIndex {
200 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
201 write!(f, "depth={}, position={}", self.depth, self.position)
202 }
203}
204
impl Serializable for NodeIndex {
    // Writes `depth` as 1 byte followed by `position` as 8 bytes; must stay in
    // sync with `Deserializable::read_from` and `min_serialized_size` below.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_u8(self.depth);
        target.write_u64(self.position);
    }
}
211
212impl Deserializable for NodeIndex {
213 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
214 let depth = source.read_u8()?;
215 let position = source.read_u64()?;
216 NodeIndex::new(depth, position)
217 .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
218 }
219
220 fn min_serialized_size() -> usize {
221 9
223 }
224}
225
// Iterator backing `NodeIndex::proof_indices`: walks from a node's sibling up
// the tree, yielding the sibling at each level until the root is reached (the
// root itself is never yielded).
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
struct ProofIter {
    // The index to yield next; iteration stops once this is the root.
    next_index: NodeIndex,
}
231
232impl Iterator for ProofIter {
233 type Item = NodeIndex;
234
235 fn next(&mut self) -> Option<NodeIndex> {
236 if self.next_index.is_root() {
237 return None;
238 }
239
240 let index = self.next_index;
241 self.next_index = index.parent().sibling();
242
243 Some(index)
244 }
245
246 fn size_hint(&self) -> (usize, Option<usize>) {
247 let remaining = ExactSizeIterator::len(self);
248
249 (remaining, Some(remaining))
250 }
251}
252
253impl ExactSizeIterator for ProofIter {
254 fn len(&self) -> usize {
255 self.next_index.depth() as usize
256 }
257}
258
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    #[test]
    fn test_node_index_position_too_high() {
        // at each depth, the largest valid position is 2^depth - 1
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, position: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, position: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, position: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, position: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, position: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, position: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, position: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, position: 8 });
    }

    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        // Generates an arbitrary valid index. The depth is the bit width of
        // `position`, which is the minimum depth at which `position` is valid
        // and is exactly the quantity `NodeIndex::new` validates against.
        //
        // Fix: the previous formulation called `position.ilog2()`, which
        // panics for `position == 0` (0 is in the sampled range), and its
        // round-up condition `position > (1 << depth)` under-computed the
        // depth for exact powers of two (e.g. position 2 produced the invalid
        // pair depth=1/position=2), making the `unwrap` below panic.
        fn node_index()(position in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            let depth = (u64::BITS - position.leading_zeros()) as u8;
            NodeIndex::new(depth, position).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            // moving up past the root must saturate, never panic
            for _ in 0..count {
                index.move_up();
            }
        }

        #[test]
        fn to_scalar_index_succeeds_for_depth_lt_64(depth in 0u8..64, position_bits in 0u64..u64::MAX) {
            // clamp the position into the valid range for the depth
            let position = if depth == 0 { 0 } else { position_bits % (1u64 << depth) };
            let index = NodeIndex::new(depth, position).unwrap();
            assert!(index.to_scalar_index().is_ok());
        }
    }

    #[test]
    fn test_to_scalar_index_depth_64_returns_error() {
        // depth 64 is a representable index but has no scalar form in u64
        let index = NodeIndex::new(64, 0).unwrap();
        assert_matches!(index.to_scalar_index(), Err(MerkleError::DepthTooBig(64)));

        let index = NodeIndex::new(64, u64::MAX).unwrap();
        assert_matches!(index.to_scalar_index(), Err(MerkleError::DepthTooBig(64)));
    }

    #[test]
    fn test_to_scalar_index_known_values() {
        // scalar index is 2^depth + position (breadth-first, root = 1)
        assert_eq!(NodeIndex::make(1, 0).to_scalar_index().unwrap(), 2);
        assert_eq!(NodeIndex::make(1, 1).to_scalar_index().unwrap(), 3);

        assert_eq!(NodeIndex::make(2, 0).to_scalar_index().unwrap(), 4);
        assert_eq!(NodeIndex::make(2, 3).to_scalar_index().unwrap(), 7);

        assert_eq!(NodeIndex::make(3, 0).to_scalar_index().unwrap(), 8);
        assert_eq!(NodeIndex::make(3, 7).to_scalar_index().unwrap(), 15);
    }

    #[test]
    fn test_to_scalar_index_depth_63_max_position() {
        // the largest representable scalar: 2^63 + (2^63 - 1) == u64::MAX
        let index = NodeIndex::new(63, (1u64 << 63) - 1).unwrap();
        assert_eq!(index.to_scalar_index().unwrap(), u64::MAX);
    }

    #[test]
    fn test_to_scalar_index_boundary_depths() {
        assert_eq!(NodeIndex::make(0, 0).to_scalar_index().unwrap(), 1);

        assert_eq!(NodeIndex::make(62, 0).to_scalar_index().unwrap(), 1u64 << 62);

        assert_eq!(NodeIndex::make(63, 0).to_scalar_index().unwrap(), 1u64 << 63);
    }
}