#![no_std]
//! Reader for a compact, byte-encoded prefix trie of symbols.
//!
//! Every node starts with a ULEB128 length followed by that many "terminal"
//! payload bytes, then a ULEB128 child count and one entry per child edge: a
//! NUL-terminated edge label followed by the ULEB128 offset of the child node,
//! measured from the start of the trie. The root node sits at offset 0.
//! NOTE(review): this matches the Mach-O dyld export-trie layout, which is
//! presumably the intended input — confirm.
//!
//! [`walk`] looks up a single symbol; [`iter`] enumerates the stored symbols.
#![deny(missing_docs)]
#![deny(missing_debug_implementations)]
#![deny(rust_2018_idioms)]
#![deny(unreachable_pub)]
13
14extern crate alloc;
15#[cfg(any(test, feature = "std"))]
16extern crate std;
17
18use alloc::vec::Vec;
19use core::{convert::TryInto, fmt, iter::FusedIterator, mem, num::NonZeroUsize};
20
/// Error returned when the encoded trie is malformed.
///
/// The error carries no further detail: every defect the parser detects
/// (truncated data, an unreadable ULEB128 value, a missing NUL terminator, a
/// cyclic node reference, ...) is reported with this single unit value.
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct InvalidTrieError;

impl fmt::Display for InvalidTrieError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid trie")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for InvalidTrieError {}

/// Crate-internal shorthand for results that fail with [`InvalidTrieError`].
type Result<T> = core::result::Result<T, InvalidTrieError>;
35
/// Cursor over the child edges of one trie node, used by `TrieIter`.
#[derive(Clone, Debug)]
struct NodeIterState {
    /// Absolute position in the trie of the next unread child edge.
    /// `TrieIter::next_no_copy` treats 0 as "root not parsed yet".
    offset: usize,
    /// How many of this node's child edges have not been read yet.
    children_left: usize,
    /// Length of the edge label that led into this node, i.e. how many bytes
    /// to truncate from the symbol buffer when leaving it.
    sym_delta: usize,
}

impl Default for NodeIterState {
    /// State of a fresh iterator, positioned before the root node.
    ///
    /// `children_left` is overwritten with the root's real child count as
    /// soon as the root is parsed.
    #[inline]
    fn default() -> Self {
        Self {
            sym_delta: 0,
            children_left: 1,
            offset: 0,
        }
    }
}
56
57fn read_uleb128_size(slice: &mut &[u8]) -> Result<usize> {
58 let value = leb128::read::unsigned(slice).map_err(|_| InvalidTrieError)?;
59 value.try_into().map_err(|_| InvalidTrieError)
60}
61
62fn read_slice<'a>(slice: &mut &'a [u8], len: usize) -> Result<&'a [u8]> {
63 if slice.len() >= len {
64 let (head, tail) = slice.split_at(len);
65 *slice = tail;
66 Ok(head)
67 } else {
68 Err(InvalidTrieError)
69 }
70}
71
72fn read_size_and_slice<'a>(slice: &mut &'a [u8]) -> Result<&'a [u8]> {
73 let size = read_uleb128_size(slice)?;
74 read_slice(slice, size)
75}
76
77fn read_string<'a>(slice: &mut &'a [u8]) -> Result<&'a [u8]> {
78 let pos = slice.iter().position(|&b| b == 0).ok_or(InvalidTrieError)?;
79 let head = &slice[..pos];
80 *slice = &slice[pos + 1..];
81 Ok(head)
82}
83
/// Looks up `sym` in `trie` and returns its terminal payload, if any.
///
/// The lookup starts at the root node (offset 0). At each node, the child
/// edges are compared in stored order against the remaining bytes of `sym`,
/// and the first edge whose label matches is followed.
///
/// Returns `Ok(Some(terminal))` when `sym` has been fully consumed at a node
/// whose terminal payload is non-empty. Returns `Ok(None)` when no child edge
/// of the current node matches the rest of `sym`, including when the node has
/// no children.
///
/// # Errors
///
/// Returns [`InvalidTrieError`] if the trie is malformed:
/// - truncated data or an unreadable ULEB128 value;
/// - a child offset of 0 or past the end of `trie`;
/// - a node entered twice;
/// - a path of more than 128 edges.
///
/// An error is also returned when `sym` runs out while an edge label still
/// has bytes left to compare, for example when `sym` is a proper prefix of a
/// stored symbol. This happens even if the trie itself is valid.
/// NOTE(review): callers may expect `Ok(None)` for that case; the in-crate
/// `test_walk` currently asserts the error, so the behavior is kept.
pub fn walk<'a>(trie: &'a [u8], sym: &[u8]) -> Result<Option<&'a [u8]>> {
    // Offsets of all nodes entered so far (the root lives at offset 0); used to
    // reject cyclic tries and to cap the lookup depth.
    let mut visited_offsets = Vec::new();
    visited_offsets.push(0);

    let mut data = trie;
    let mut sym = sym;
    loop {
        // Node header: ULEB128 length followed by that many terminal bytes.
        let terminal = read_size_and_slice(&mut data)?;

        // The whole symbol has been matched and this node carries a payload.
        if sym.is_empty() && !terminal.is_empty() {
            break Ok(Some(terminal));
        }

        let children_count = read_uleb128_size(&mut data)?;

        let mut node_offset: Option<NonZeroUsize> = None;
        for _ in 0..children_count {
            let mut cur_sym = sym;
            let mut wrong_edge = false;

            loop {
                // Consume the NUL-terminated edge label byte by byte. The whole
                // label is always consumed, even after a mismatch, so that
                // `data` ends up at this edge's child offset.
                let (&c, tail) = data.split_first().ok_or(InvalidTrieError)?;
                data = tail;
                if c == 0 {
                    break;
                }

                if !wrong_edge {
                    // Running out of `sym` while the label continues is
                    // reported as an error rather than as "not found".
                    let (&cs, tail) = cur_sym.split_first().ok_or(InvalidTrieError)?;
                    wrong_edge = c != cs;
                    cur_sym = tail;
                }
            }

            if wrong_edge {
                // Skip the child offset of an edge that does not match.
                read_uleb128_size(&mut data)?;
            } else {
                let offset = read_uleb128_size(&mut data)?;

                // Child offsets are relative to the start of the trie; offsets
                // past its end, and offset 0 (the root), are rejected.
                if offset > trie.len() {
                    return Err(InvalidTrieError);
                }
                node_offset = Some(NonZeroUsize::new(offset).ok_or(InvalidTrieError)?);
                sym = cur_sym;
                break;
            }
        }

        if let Some(offset) = node_offset {
            let offset = offset.get();

            // Reject cycles and allow at most 128 followed edges; the cap also
            // bounds the cost of the linear `contains` scan.
            if visited_offsets.contains(&offset) || visited_offsets.len() > 128 {
                return Err(InvalidTrieError);
            }
            visited_offsets.push(offset);

            data = &trie[offset..];
        } else {
            // No child edge matches the rest of the symbol.
            return Ok(None);
        }
    }
}
157
158pub fn iter(trie: &[u8]) -> TrieIter<'_> {
165 TrieIter {
166 trie,
167 sym_buf: Vec::new(),
168 stack: Vec::new(),
169 node_state: Default::default(),
170 visited_offsets: Vec::new(),
171 }
172}
173
174#[derive(Debug)]
185pub struct TrieIter<'trie> {
186 trie: &'trie [u8],
187 sym_buf: Vec<u8>,
188 stack: Vec<NodeIterState>,
189 node_state: NodeIterState,
190 visited_offsets: Vec<usize>,
191}
192
193impl<'trie> TrieIter<'trie> {
194 pub fn next_no_copy<'iter>(&'iter mut self) -> Result<Option<(&'iter [u8], &'trie [u8])>> {
200 loop {
201 let mut data = if self.node_state.offset == 0 {
203 let mut data = self.trie;
204
205 read_size_and_slice(&mut data)?;
207 let children_count = read_uleb128_size(&mut data)?;
208
209 self.node_state.offset = self.trie.len() - data.len();
210 self.node_state.children_left = children_count;
211
212 if self.node_state.children_left == 0 {
213 return Ok(None);
214 }
215
216 data
217 } else {
218 while self.node_state.children_left == 0 {
219 let Some(next_state) = self.stack.pop() else {
220 return Ok(None);
221 };
222
223 let node_state = mem::replace(&mut self.node_state, next_state);
224 self.sym_buf
225 .truncate(self.sym_buf.len() - node_state.sym_delta);
226 }
227
228 &self.trie[self.node_state.offset..]
229 };
230
231 let len_before = data.len();
232 let part = read_string(&mut data)?;
233 let offset = read_uleb128_size(&mut data)?;
234
235 self.node_state.offset += len_before - data.len();
237 self.node_state.children_left -= 1;
238
239 if self.visited_offsets.contains(&offset) {
240 break Err(InvalidTrieError);
242 }
243 self.visited_offsets.push(offset);
244
245 let mut data = &self.trie[offset..];
247 let terminal = read_size_and_slice(&mut data)?;
248 let children_count = read_uleb128_size(&mut data)?;
249
250 self.sym_buf.extend_from_slice(part);
251
252 self.stack.push(mem::replace(
253 &mut self.node_state,
254 NodeIterState {
255 offset: self.trie.len() - data.len(),
256 children_left: children_count,
257 sym_delta: part.len(),
258 },
259 ));
260
261 if !terminal.is_empty() {
262 break Ok(Some((self.sym_buf.as_slice(), terminal)));
263 }
264 }
265 }
266}
267
268impl<'trie> Iterator for TrieIter<'trie> {
269 type Item = (Vec<u8>, &'trie [u8]);
270
271 fn next(&mut self) -> Option<Self::Item> {
272 self.next_no_copy()
273 .unwrap()
274 .map(|(sym, terminal)| (sym.to_vec(), terminal))
275 }
276}
277
278impl FusedIterator for TrieIter<'_> {}
279
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use sha2::Digest;

    /// Encoded trie used as the fixture for all tests below.
    const TEST_TRIE: &[u8] = include_bytes!("../tests/test_trie.bin");
    /// SHA-256 of every `(symbol, terminal)` pair yielded for `TEST_TRIE`, in
    /// iteration order, with a NUL byte appended after each symbol and after
    /// each terminal.
    const TEST_TRIE_HASH: [u8; 32] =
        hex!("9829e0f5330988ef653dc534cde5998dd4fbd5e107dc6c92545253155d5f04ef");

    #[test]
    fn test_walk() {
        // A symbol stored in the fixture must actually be found, not merely
        // walked without error.
        let terminal = walk(
            TEST_TRIE,
            b"__ZN3JSC12RegExpObjectC1ERNS_2VMEPNS_9StructureEPNS_6RegExpE",
        )
        .unwrap();
        assert!(terminal.is_some());

        // With its last two bytes removed, the symbol runs out while the trie
        // still has edge bytes to compare; `walk` reports that as an error.
        assert!(walk(
            TEST_TRIE,
            b"__ZN3JSC12RegExpObjectC1ERNS_2VMEPNS_9StructureEPNS_6RegEx"
        )
        .is_err());
    }

    #[test]
    fn test_iter() {
        let mut trie_iter = iter(TEST_TRIE);

        let mut digest = sha2::Sha256::new();
        while let Some((sym, terminal)) = trie_iter.next_no_copy().unwrap() {
            digest.update(sym);
            digest.update(&[0]);
            digest.update(terminal);
            digest.update(&[0]);
        }

        assert_eq!(digest.finalize().as_slice(), &TEST_TRIE_HASH);
    }
}