nam_sparse_merkle_tree/
internal_key.rs

#[cfg(feature = "borsh")]
use borsh::{BorshDeserialize, BorshSerialize};
#[cfg(feature = "borsh")]
use core::convert::TryInto;
#[cfg(feature = "borsh")]
use std::io::{Read, Write};

/// The actual key value used in the tree
#[derive(Eq, PartialEq, Debug, Hash, Clone, Copy, PartialOrd, Ord)]
pub struct InternalKey<const N: usize>([u8; N]);

#[cfg(feature = "borsh")]
impl<const N: usize> BorshSerialize for InternalKey<N> {
    fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
        // Serialize the key as a length-prefixed `Vec<u8>`.
        let bytes = self.0.to_vec();
        BorshSerialize::serialize(&bytes, writer)
    }
}

#[cfg(feature = "borsh")]
impl<const N: usize> BorshDeserialize for InternalKey<N> {
    fn deserialize_reader<R: Read>(reader: &mut R) -> std::io::Result<Self> {
        use std::io::ErrorKind;
        // Read back the length-prefixed `Vec<u8>` written by `serialize`
        // and check that it holds exactly N bytes.
        let bytes: Vec<u8> = BorshDeserialize::deserialize_reader(reader)?;
        let bytes: [u8; N] = bytes.try_into().map_err(|_| {
            std::io::Error::new(ErrorKind::InvalidData, "Input byte vector has the wrong length")
        })?;
        Ok(InternalKey(bytes))
    }
}

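// A minimal round-trip sketch of the Borsh impls above; the module and
// test names here are illustrative, not part of the crate's API.
#[cfg(all(test, feature = "borsh"))]
mod borsh_roundtrip_example {
    use super::*;

    #[test]
    fn key_survives_a_borsh_roundtrip() {
        let key = InternalKey::new([0xABu8; 4]);
        let mut buf = Vec::new();
        key.serialize(&mut buf).expect("serialization should succeed");
        // Borsh encodes a `Vec<u8>` as a u32 length prefix followed by
        // the bytes, so the wire size is 4 + N.
        assert_eq!(buf.len(), 4 + 4);
        let decoded = InternalKey::<4>::deserialize_reader(&mut buf.as_slice())
            .expect("deserialization should succeed");
        assert_eq!(key, decoded);
    }
}
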
/// Number of bits in a byte.
const BYTE_SIZE: usize = 8;

impl<const N: usize> InternalKey<N> {
    pub fn new(array: [u8; N]) -> Self {
        Self(array)
    }

    pub fn as_slice(&self) -> &[u8] {
        &self.0[..]
    }

    pub const fn zero() -> Self {
        InternalKey([0u8; N])
    }

    /// Index of the last byte in the key's array.
    pub const fn max_index() -> usize {
        N - 1
    }

    /// Get bit `i`, where bit 0 is the least-significant bit of the
    /// last byte and higher bit indices run toward the first byte.
    #[inline]
    pub fn get_bit(&self, i: usize) -> bool {
        debug_assert!(i / BYTE_SIZE <= Self::max_index(), "bit index {} out of range", i);
        let byte_pos = Self::max_index() - i / BYTE_SIZE;
        let bit_pos = i % BYTE_SIZE;
        let bit = (self.0[byte_pos] >> bit_pos) & 1;
        bit != 0
    }

    #[inline]
    pub fn set_bit(&mut self, i: usize) {
        let byte_pos = Self::max_index() - i / BYTE_SIZE;
        let bit_pos = i % BYTE_SIZE;
        self.0[byte_pos] |= 1 << bit_pos;
    }

    #[inline]
    pub fn clear_bit(&mut self, i: usize) {
        let byte_pos = Self::max_index() - i / BYTE_SIZE;
        let bit_pos = i % BYTE_SIZE;
        self.0[byte_pos] &= !(1 << bit_pos);
    }

    /// Treat InternalKey as a path in a tree.
    /// The fork height is the position of the highest bit at which the
    /// two InternalKeys differ, or 0 if the keys are equal.
    pub fn fork_height(&self, key: &InternalKey<N>) -> usize {
        let max = BYTE_SIZE * N;
        for h in (0..max).rev() {
            if self.get_bit(h) != key.get_bit(h) {
                return h;
            }
        }
        0
    }

    /// Treat InternalKey as a path in a tree.
    /// Return the parent path of self at `height`, i.e. self with all
    /// bits at and below `height` cleared.
    pub fn parent_path(&self, height: usize) -> Self {
        height
            .checked_add(1)
            .map(|i| self.copy_bits(i..))
            .unwrap_or_else(InternalKey::zero)
    }

    /// Copy the bits in `range` into a new InternalKey, leaving all
    /// other bits zeroed.
    pub fn copy_bits(&self, range: impl core::ops::RangeBounds<usize>) -> Self {
        let array_size = N;
        let max = BYTE_SIZE * N;
        use core::ops::Bound;

        let mut target = InternalKey::zero();
        let start = match range.start_bound() {
            Bound::Included(&i) => i,
            Bound::Excluded(&i) => panic!("excluded start bound is not supported: {}", i),
            Bound::Unbounded => 0,
        };

        let mut end = match range.end_bound() {
            Bound::Included(&i) => i.saturating_add(1),
            Bound::Excluded(&i) => i,
            Bound::Unbounded => max,
        };

        if start >= max {
            return target;
        } else if end > max {
            end = max;
        }

        if end < start {
            panic!("end can't be less than start: start {} end {}", start, end);
        }

        // Bytes fully covered by [start, end) can be copied wholesale;
        // bit `start` lives in byte `array_size - 1 - start / BYTE_SIZE`.
        let end_byte = {
            let remain = if start % BYTE_SIZE != 0 { 1 } else { 0 };
            array_size - start / BYTE_SIZE - remain
        };
        let start_byte = array_size - end / BYTE_SIZE;
        // copy whole bytes
        if start_byte < self.0.len() && start_byte <= end_byte {
            target.0[start_byte..end_byte].copy_from_slice(&self.0[start_byte..end_byte]);
        }

        // Copy the leftover bits at each end of the range that do not
        // fall on a byte boundary.
        for i in (start..core::cmp::min((array_size - end_byte) * BYTE_SIZE, end))
            .chain(core::cmp::max((array_size - start_byte) * BYTE_SIZE, start)..end)
        {
            if self.get_bit(i) {
                target.set_bit(i)
            }
        }
        target
    }
}

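// A small sketch of the bit-indexing convention used above: bit 0 is the
// least-significant bit of the *last* byte, so the key reads as a
// big-endian bit string. Module name and values are illustrative only.
#[cfg(test)]
mod bit_ops_example {
    use super::*;

    #[test]
    fn bits_are_indexed_from_the_low_end_of_the_last_byte() {
        let mut key = InternalKey::<2>::zero();
        key.set_bit(0);
        assert_eq!(<[u8; 2]>::from(key), [0b0000_0000, 0b0000_0001]);
        key.set_bit(8);
        assert_eq!(<[u8; 2]>::from(key), [0b0000_0001, 0b0000_0001]);
        assert!(key.get_bit(0) && key.get_bit(8) && !key.get_bit(7));
        key.clear_bit(0);
        assert_eq!(<[u8; 2]>::from(key), [0b0000_0001, 0b0000_0000]);
    }
}
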
impl<const N: usize> From<[u8; N]> for InternalKey<N> {
    fn from(v: [u8; N]) -> Self {
        Self::new(v)
    }
}

impl<const N: usize> From<InternalKey<N>> for [u8; N] {
    fn from(v: InternalKey<N>) -> Self {
        v.0
    }
}
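
// A sketch of the tree-path operations, under the bit convention above:
// `fork_height` returns the highest differing bit, `parent_path` clears
// everything at and below a height, and `copy_bits` masks a bit range.
// Module name and values are illustrative only.
#[cfg(test)]
mod tree_path_example {
    use super::*;

    #[test]
    fn fork_height_is_the_highest_differing_bit() {
        let a = InternalKey::new([0b0000_0001, 0b0000_0000]);
        let b = InternalKey::new([0b0000_0000, 0b0000_0001]);
        // Bit 8 (LSB of the first byte) is the highest differing bit.
        assert_eq!(a.fork_height(&b), 8);
        assert_eq!(a.fork_height(&a), 0);
    }

    #[test]
    fn parent_path_clears_bits_at_and_below_the_height() {
        let key = InternalKey::new([0b0000_0011, 0b1111_1111]);
        // Keep only bits strictly above height 8: bit 9 survives.
        assert_eq!(<[u8; 2]>::from(key.parent_path(8)), [0b0000_0010, 0b0000_0000]);
    }

    #[test]
    fn copy_bits_masks_a_range_of_bits() {
        let key = InternalKey::new([0b1111_1111, 0b1111_1111]);
        // Keep bits 4..12; everything outside the range is zeroed.
        assert_eq!(<[u8; 2]>::from(key.copy_bits(4..12)), [0b0000_1111, 0b1111_0000]);
    }
}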