miden_crypto/merkle/
index.rs1use core::fmt::Display;
2
3use super::{Felt, MerkleError, Word};
4use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
5
/// Address of a node in a binary Merkle tree.
///
/// A node is identified by its `depth` (the root sits at depth 0) and its `value`, the node's
/// position from the left within that depth. A valid index satisfies `value < 2^depth`, which
/// `NodeIndex::new` enforces; depths up to and including 64 are representable.
///
/// The derived ordering follows field order: indices compare by `depth` first, then by `value`.
#[derive(
    Debug,
    Default,
    Copy,
    Clone,
    Eq,
    PartialEq,
    PartialOrd,
    Ord,
    Hash
)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
    /// Distance from the root (root = 0).
    depth: u8,
    /// Left-to-right position of the node at `depth`.
    value: u64,
}

31impl NodeIndex {
32 pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
40 if (64 - value.leading_zeros()) > depth as u32 {
41 Err(MerkleError::InvalidNodeIndex { depth, value })
42 } else {
43 Ok(Self { depth, value })
44 }
45 }
46
47 pub const fn new_unchecked(depth: u8, value: u64) -> Self {
49 debug_assert!((64 - value.leading_zeros()) <= depth as u32);
50 Self { depth, value }
51 }
52
53 #[cfg(test)]
58 pub fn make(depth: u8, value: u64) -> Self {
59 Self::new(depth, value).unwrap()
60 }
61
62 pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
69 let depth = depth.as_int();
70 let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
71 let value = value.as_int();
72 Self::new(depth, value)
73 }
74
75 pub const fn root() -> Self {
77 Self { depth: 0, value: 0 }
78 }
79
80 pub const fn sibling(mut self) -> Self {
82 self.value ^= 1;
83 self
84 }
85
86 pub const fn left_child(mut self) -> Self {
88 self.depth += 1;
89 self.value <<= 1;
90 self
91 }
92
93 pub const fn right_child(mut self) -> Self {
95 self.depth += 1;
96 self.value = (self.value << 1) + 1;
97 self
98 }
99
100 pub const fn parent(mut self) -> Self {
103 self.depth = self.depth.saturating_sub(1);
104 self.value >>= 1;
105 self
106 }
107
108 pub const fn build_node(&self, slf: Word, sibling: Word) -> [Word; 2] {
115 if self.is_value_odd() {
116 [sibling, slf]
117 } else {
118 [slf, sibling]
119 }
120 }
121
122 pub const fn to_scalar_index(&self) -> u64 {
126 (1 << self.depth as u64) + self.value
127 }
128
129 pub const fn depth(&self) -> u8 {
131 self.depth
132 }
133
134 pub const fn value(&self) -> u64 {
136 self.value
137 }
138
139 pub const fn is_value_odd(&self) -> bool {
141 (self.value & 1) == 1
142 }
143
144 pub const fn is_root(&self) -> bool {
146 self.depth == 0
147 }
148
149 pub fn move_up(&mut self) {
154 self.depth = self.depth.saturating_sub(1);
155 self.value >>= 1;
156 }
157
158 pub fn move_up_to(&mut self, depth: u8) {
162 debug_assert!(depth < self.depth);
163 let delta = self.depth.saturating_sub(depth);
164 self.depth = self.depth.saturating_sub(delta);
165 self.value >>= delta as u32;
166 }
167
168 pub fn proof_indices(&self) -> impl ExactSizeIterator<Item = NodeIndex> + use<> {
177 ProofIter { next_index: self.sibling() }
178 }
179}
180
181impl Display for NodeIndex {
182 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
183 write!(f, "depth={}, value={}", self.depth, self.value)
184 }
185}
186
187impl Serializable for NodeIndex {
188 fn write_into<W: ByteWriter>(&self, target: &mut W) {
189 target.write_u8(self.depth);
190 target.write_u64(self.value);
191 }
192}
193
194impl Deserializable for NodeIndex {
195 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
196 let depth = source.read_u8()?;
197 let value = source.read_u64()?;
198 NodeIndex::new(depth, value)
199 .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
200 }
201}
202
203#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
205struct ProofIter {
206 next_index: NodeIndex,
207}
208
209impl Iterator for ProofIter {
210 type Item = NodeIndex;
211
212 fn next(&mut self) -> Option<NodeIndex> {
213 if self.next_index.is_root() {
214 return None;
215 }
216
217 let index = self.next_index;
218 self.next_index = index.parent().sibling();
219
220 Some(index)
221 }
222
223 fn size_hint(&self) -> (usize, Option<usize>) {
224 let remaining = ExactSizeIterator::len(self);
225
226 (remaining, Some(remaining))
227 }
228}
229
230impl ExactSizeIterator for ProofIter {
231 fn len(&self) -> usize {
232 self.next_index.depth() as usize
233 }
234}
235
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    #[test]
    fn test_proof_indices_yields_one_sibling_per_level() {
        let index = NodeIndex::make(3, 5);
        let expected = [NodeIndex::make(3, 4), NodeIndex::make(2, 3), NodeIndex::make(1, 0)];
        assert!(index.proof_indices().eq(expected));
        assert_eq!(index.proof_indices().len(), 3);
        assert_eq!(NodeIndex::root().proof_indices().count(), 0);
    }

    prop_compose! {
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // The bit length of `value` is the smallest depth that can hold it, so `new` never
            // fails here. (`ilog2` would panic on 0 and is one too small for exact powers of two.)
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }

        #[test]
        fn arbitrary_index_proof_len_matches_depth(index in node_index()) {
            prop_assert_eq!(index.proof_indices().len(), index.depth() as usize);
        }
    }
}