miden_crypto/merkle/
index.rs1use core::fmt::Display;
2
3use super::{Felt, MerkleError, RpoDigest};
4use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
5
/// Address of a node in a binary Merkle tree, expressed as a `(depth, value)` pair.
///
/// `depth` is the distance from the root (the root sits at depth 0) and `value` is the node's
/// position among the nodes at that depth.
///
/// Field order is significant: the derived `Ord` compares `depth` first and `value` second.
#[derive(Debug, Default, Copy, Clone)]
#[derive(Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
    depth: u8,
    value: u64,
}
30
31impl NodeIndex {
32 pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
40 if (64 - value.leading_zeros()) > depth as u32 {
41 Err(MerkleError::InvalidNodeIndex { depth, value })
42 } else {
43 Ok(Self { depth, value })
44 }
45 }
46
47 pub const fn new_unchecked(depth: u8, value: u64) -> Self {
49 debug_assert!((64 - value.leading_zeros()) <= depth as u32);
50 Self { depth, value }
51 }
52
53 #[cfg(test)]
58 pub fn make(depth: u8, value: u64) -> Self {
59 Self::new(depth, value).unwrap()
60 }
61
62 pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
69 let depth = depth.as_int();
70 let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
71 let value = value.as_int();
72 Self::new(depth, value)
73 }
74
75 pub const fn root() -> Self {
77 Self { depth: 0, value: 0 }
78 }
79
80 pub const fn sibling(mut self) -> Self {
82 self.value ^= 1;
83 self
84 }
85
86 pub const fn left_child(mut self) -> Self {
88 self.depth += 1;
89 self.value <<= 1;
90 self
91 }
92
93 pub const fn right_child(mut self) -> Self {
95 self.depth += 1;
96 self.value = (self.value << 1) + 1;
97 self
98 }
99
100 pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] {
107 if self.is_value_odd() {
108 [sibling, slf]
109 } else {
110 [slf, sibling]
111 }
112 }
113
114 pub const fn to_scalar_index(&self) -> u64 {
118 (1 << self.depth as u64) + self.value
119 }
120
121 pub const fn depth(&self) -> u8 {
123 self.depth
124 }
125
126 pub const fn value(&self) -> u64 {
128 self.value
129 }
130
131 pub const fn is_value_odd(&self) -> bool {
133 (self.value & 1) == 1
134 }
135
136 pub const fn is_root(&self) -> bool {
138 self.depth == 0
139 }
140
141 pub fn move_up(&mut self) {
146 self.depth = self.depth.saturating_sub(1);
147 self.value >>= 1;
148 }
149
150 pub fn move_up_to(&mut self, depth: u8) {
154 debug_assert!(depth < self.depth);
155 let delta = self.depth.saturating_sub(depth);
156 self.depth = self.depth.saturating_sub(delta);
157 self.value >>= delta as u32;
158 }
159}
160
161impl Display for NodeIndex {
162 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
163 write!(f, "depth={}, value={}", self.depth, self.value)
164 }
165}
166
167impl Serializable for NodeIndex {
168 fn write_into<W: ByteWriter>(&self, target: &mut W) {
169 target.write_u8(self.depth);
170 target.write_u64(self.value);
171 }
172}
173
174impl Deserializable for NodeIndex {
175 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
176 let depth = source.read_u8()?;
177 let value = source.read_u64()?;
178 NodeIndex::new(depth, value)
179 .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
180 }
181}
182
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        // Generates a valid node index placed at the smallest depth that can hold its value.
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // A value is valid at `depth` iff it has at most `depth` significant bits, so the
            // minimal depth is its bit length. (`ilog2` panics for 0 and is one too small for
            // exact powers of two, both of which made the `unwrap` below panic.) The bit length
            // is at most 64, so the cast cannot truncate.
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }
    }
}