Skip to main content

revm_bytecode/legacy/
jump_map.rs

1use bitvec::vec::BitVec;
2use core::{cmp::Ordering, fmt, hash};
3use primitives::hex;
4use std::{borrow::Cow, vec::Vec};
5
/// A table of valid `jump` destinations.
///
/// The table is immutable and memory efficient: it stores one bit per byte of the
/// bytecode, so bit `pc` is set when `pc` is a valid jump destination.
pub struct JumpTable {
    /// Backing bitmap bytes, either borrowed from a `'static` slice or owned.
    /// May contain unused padding bits past `bit_len`.
    table: Cow<'static, [u8]>,
    /// Number of meaningful bits in `table`.
    bit_len: usize,
}
13
14impl Clone for JumpTable {
15    #[inline]
16    fn clone(&self) -> Self {
17        Self {
18            table: match &self.table {
19                Cow::Borrowed(b) => Cow::Borrowed(b),
20                Cow::Owned(o) => Cow::Owned(o.clone()),
21            },
22            bit_len: self.bit_len,
23        }
24    }
25}
26
27impl fmt::Debug for JumpTable {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        f.debug_struct("JumpTable")
30            .field("map", &hex::encode(self.as_slice()))
31            .finish()
32    }
33}
34
35impl PartialEq for JumpTable {
36    #[inline]
37    fn eq(&self, other: &Self) -> bool {
38        self.as_slice().eq(other.as_slice())
39    }
40}
41
42impl Eq for JumpTable {}
43
44impl PartialOrd for JumpTable {
45    #[inline]
46    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
47        Some(self.cmp(other))
48    }
49}
50
51impl Ord for JumpTable {
52    #[inline]
53    fn cmp(&self, other: &Self) -> Ordering {
54        self.as_slice().cmp(other.as_slice())
55    }
56}
57
58impl hash::Hash for JumpTable {
59    #[inline]
60    fn hash<H: hash::Hasher>(&self, state: &mut H) {
61        self.as_slice().hash(state);
62    }
63}
64
65impl Default for JumpTable {
66    #[inline]
67    fn default() -> Self {
68        Self::new(Default::default())
69    }
70}
71
#[cfg(feature = "serde")]
impl serde::Serialize for JumpTable {
    /// Serializes the table as a `BitVec<u8>` holding exactly `bit_len` bits.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let bytes = self.as_slice().to_vec();
        let mut bits = BitVec::<u8>::from_vec(bytes);
        // Trim the byte-granular vector down to the meaningful bit count.
        bits.resize(self.bit_len, false);
        serde::Serialize::serialize(&bits, serializer)
    }
}
83
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for JumpTable {
    /// Deserializes a `BitVec<u8>` and wraps it with `JumpTable::new`.
    ///
    /// The table's bit length comes from the decoded bit vector.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let bits = <BitVec<u8> as serde::Deserialize<'de>>::deserialize(deserializer)?;
        Ok(Self::new(bits))
    }
}
93
94impl JumpTable {
95    /// Create new JumpTable directly from an existing BitVec.
96    #[inline]
97    pub fn new(jumps: BitVec<u8>) -> Self {
98        let bit_len = jumps.len();
99        Self::from_vec(jumps.into_vec(), bit_len)
100    }
101
102    /// Constructs a jump map from raw bytes and length.
103    ///
104    /// Bit length represents number of used bits inside slice.
105    ///
106    /// # Panics
107    ///
108    /// Panics if number of bits in slice is less than bit_len.
109    #[inline]
110    pub fn from_slice(slice: &[u8], bit_len: usize) -> Self {
111        Self::size_assert(slice.len(), bit_len);
112        Self::from_vec(slice.to_vec(), bit_len)
113    }
114
115    #[inline]
116    fn from_vec(slice: Vec<u8>, bit_len: usize) -> Self {
117        #[cfg(debug_assertions)]
118        Self::size_assert(slice.len(), bit_len);
119        Self {
120            table: slice.into(),
121            bit_len,
122        }
123    }
124
125    /// Constructs a jump map from raw bytes and length.
126    ///
127    /// Bit length represents number of used bits inside slice.
128    ///
129    /// # Panics
130    ///
131    /// Panics if number of bits in slice is less than bit_len.
132    #[inline]
133    pub fn from_static_slice(slice: &'static [u8], bit_len: usize) -> Self {
134        Self::size_assert(slice.len(), bit_len);
135        Self {
136            table: Cow::Borrowed(slice),
137            bit_len,
138        }
139    }
140
141    #[inline]
142    fn size_assert(len: usize, bit_len: usize) {
143        const BYTE_LEN: usize = 8;
144        assert!(
145            len * BYTE_LEN >= bit_len,
146            "slice bit length {} is less than bit_len {}",
147            len * BYTE_LEN,
148            bit_len
149        );
150    }
151
152    /// Gets the raw bytes of the jump map.
153    #[inline]
154    pub fn as_slice(&self) -> &[u8] {
155        // SAFETY: always valid.
156        &self.table
157    }
158
159    /// Gets the bit length of the jump map.
160    #[inline]
161    pub fn len(&self) -> usize {
162        self.bit_len
163    }
164
165    /// Returns true if the jump map is empty.
166    #[inline]
167    pub fn is_empty(&self) -> bool {
168        self.len() == 0
169    }
170
171    /// Checks if `pc` is a valid jump destination.
172    #[inline]
173    pub fn is_valid(&self, pc: usize) -> bool {
174        pc < self.bit_len
175            && unsafe { *self.as_slice().as_ptr().add(pc >> 3) & (1 << (pc & 7)) != 0 }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic(expected = "slice bit length 8 is less than bit_len 10")]
    fn test_jump_table_from_slice_panic() {
        // A single byte only provides 8 bits, so a 10-bit table must be rejected.
        let _ = JumpTable::from_slice(&[0x00], 10);
    }

    #[test]
    fn test_jump_table_from_slice() {
        // `len()` reports the requested bit length, not the byte count.
        let table = JumpTable::from_slice(&[0x00], 3);
        assert_eq!(table.len(), 3);
    }

    #[test]
    fn test_is_valid() {
        // 0x0D = 0b0000_1101 -> pcs 0, 2, 3; 0x06 = 0b0000_0110 -> pcs 9, 10.
        let table = JumpTable::from_slice(&[0x0D, 0x06], 13);
        assert_eq!(table.len(), 13);

        let valid_pcs = [0, 2, 3, 9, 10];
        for pc in 0..13 {
            assert_eq!(table.is_valid(pc), valid_pcs.contains(&pc));
        }
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_serde_legacy_format() {
        // Legacy `bitvec` JSON layout: 4 bits backed by the byte 0b0101.
        let legacy_format = r#"
        {
            "order": "bitvec::order::Lsb0",
            "head": {
                "width": 8,
                "index": 0
            },
            "bits": 4,
            "data": [5]
        }"#;

        let table: JumpTable = serde_json::from_str(legacy_format).expect("Failed to deserialize");
        assert_eq!(table.len(), 4);
        assert_eq!(
            [0, 1, 2, 3].map(|pc| table.is_valid(pc)),
            [true, false, true, false]
        );
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_serde_roundtrip() {
        let original = JumpTable::from_slice(&[0x0D, 0x06], 13);

        let serialized = serde_json::to_string(&original).expect("Failed to serialize");
        let deserialized: JumpTable =
            serde_json::from_str(&serialized).expect("Failed to deserialize");

        // Bit length, backing bytes and equality must all survive the round trip.
        assert_eq!(original.len(), deserialized.len());
        assert_eq!(original.table, deserialized.table);
        assert_eq!(original, deserialized);

        // Every position must answer `is_valid` identically.
        (0..13).for_each(|i| {
            assert_eq!(
                original.is_valid(i),
                deserialized.is_valid(i),
                "Mismatch at index {i}"
            );
        });
    }
}