miden_crypto/merkle/
index.rs1use core::fmt::Display;
2
3use p3_field::PrimeField64;
4
5use super::{Felt, MerkleError, Word};
6use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
7
/// Address of a node in a binary Merkle tree, given as a `(depth, value)` pair.
///
/// `depth` is the level of the node (the root is at depth 0) and `value` is the node's position
/// within that level. Instances built through [`NodeIndex::new`] satisfy `depth <= 64` and
/// `value < 2^depth`.
///
/// The derived ordering compares `depth` first and then `value` (field declaration order).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
    /// Level of the node; 0 is the root.
    depth: u8,
    /// Position of the node within its level.
    value: u64,
}
32
33impl NodeIndex {
34 pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
44 if depth > 64 {
45 Err(MerkleError::DepthTooBig(depth as u64))
46 } else if (64 - value.leading_zeros()) > depth as u32 {
47 Err(MerkleError::InvalidNodeIndex { depth, value })
48 } else {
49 Ok(Self { depth, value })
50 }
51 }
52
53 pub const fn new_unchecked(depth: u8, value: u64) -> Self {
55 debug_assert!(depth <= 64);
56 debug_assert!((64 - value.leading_zeros()) <= depth as u32);
57 Self { depth, value }
58 }
59
60 #[cfg(test)]
65 pub fn make(depth: u8, value: u64) -> Self {
66 Self::new(depth, value).unwrap()
67 }
68
69 pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
76 let depth = depth.as_canonical_u64();
77 let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
78 let value = value.as_canonical_u64();
79 Self::new(depth, value)
80 }
81
82 pub const fn root() -> Self {
84 Self { depth: 0, value: 0 }
85 }
86
87 pub const fn sibling(mut self) -> Self {
89 self.value ^= 1;
90 self
91 }
92
93 pub const fn left_child(mut self) -> Self {
95 self.depth += 1;
96 self.value <<= 1;
97 self
98 }
99
100 pub const fn right_child(mut self) -> Self {
102 self.depth += 1;
103 self.value = (self.value << 1) + 1;
104 self
105 }
106
107 pub const fn parent(mut self) -> Self {
110 self.depth = self.depth.saturating_sub(1);
111 self.value >>= 1;
112 self
113 }
114
115 pub const fn build_node(&self, slf: Word, sibling: Word) -> [Word; 2] {
122 if self.is_value_odd() {
123 [sibling, slf]
124 } else {
125 [slf, sibling]
126 }
127 }
128
129 pub const fn to_scalar_index(&self) -> u64 {
133 (1 << self.depth as u64) + self.value
134 }
135
136 pub const fn depth(&self) -> u8 {
138 self.depth
139 }
140
141 pub const fn value(&self) -> u64 {
143 self.value
144 }
145
146 pub const fn is_value_odd(&self) -> bool {
148 (self.value & 1) == 1
149 }
150
151 pub const fn is_nth_bit_odd(&self, n: u8) -> bool {
153 (self.value >> n) & 1 == 1
154 }
155
156 pub const fn is_root(&self) -> bool {
158 self.depth == 0
159 }
160
161 pub fn move_up(&mut self) {
166 self.depth = self.depth.saturating_sub(1);
167 self.value >>= 1;
168 }
169
170 pub fn move_up_to(&mut self, depth: u8) {
174 debug_assert!(depth < self.depth);
175 let delta = self.depth.saturating_sub(depth);
176 self.depth = self.depth.saturating_sub(delta);
177 self.value >>= delta as u32;
178 }
179
180 pub fn proof_indices(&self) -> impl ExactSizeIterator<Item = NodeIndex> + use<> {
189 ProofIter { next_index: self.sibling() }
190 }
191}
192
193impl Display for NodeIndex {
194 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
195 write!(f, "depth={}, value={}", self.depth, self.value)
196 }
197}
198
199impl Serializable for NodeIndex {
200 fn write_into<W: ByteWriter>(&self, target: &mut W) {
201 target.write_u8(self.depth);
202 target.write_u64(self.value);
203 }
204}
205
206impl Deserializable for NodeIndex {
207 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
208 let depth = source.read_u8()?;
209 let value = source.read_u64()?;
210 NodeIndex::new(depth, value)
211 .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
212 }
213}
214
215#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
217struct ProofIter {
218 next_index: NodeIndex,
219}
220
221impl Iterator for ProofIter {
222 type Item = NodeIndex;
223
224 fn next(&mut self) -> Option<NodeIndex> {
225 if self.next_index.is_root() {
226 return None;
227 }
228
229 let index = self.next_index;
230 self.next_index = index.parent().sibling();
231
232 Some(index)
233 }
234
235 fn size_hint(&self) -> (usize, Option<usize>) {
236 let remaining = ExactSizeIterator::len(self);
237
238 (remaining, Some(remaining))
239 }
240}
241
242impl ExactSizeIterator for ProofIter {
243 fn len(&self) -> usize {
244 self.next_index.depth() as usize
245 }
246}
247
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    #[test]
    fn test_proof_indices_walk_siblings_to_root() {
        // (3, 5) -> own sibling (3, 4), then the parents' siblings (2, 3) and (1, 0).
        let index = NodeIndex::make(3, 5);
        let expected = [NodeIndex::make(3, 4), NodeIndex::make(2, 3), NodeIndex::make(1, 0)];
        assert_eq!(index.proof_indices().len(), expected.len());
        assert!(index.proof_indices().eq(expected));
    }

    #[test]
    fn test_root_has_no_proof_indices() {
        assert_eq!(NodeIndex::root().proof_indices().len(), 0);
        assert_eq!(NodeIndex::root().proof_indices().count(), 0);
    }

    prop_compose! {
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // The smallest valid depth is the bit length of `value` (0 when `value == 0`).
            // `ilog2` cannot be used here: it panics on 0, and rounding it up yields a depth
            // one too small for powers of two, making `new` fail.
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }
    }
}