dyld-trie 0.1.0

A little library for parsing Dyld tries.
Documentation
//! A little `no_std` library for parsing Dyld tries.
//!
//! This library provides two functions:
//!
//! 1. [`walk`], which "walks" the trie (i.e. finds the terminal for a given symbol);
//! 2. [`iter`], which iterates over all entries of the trie.

#![no_std]
#![deny(missing_docs)]
#![deny(missing_debug_implementations)]
#![deny(rust_2018_idioms)]
#![deny(unreachable_pub)]

extern crate alloc;
#[cfg(any(test, feature = "std"))]
extern crate std;

use alloc::vec::Vec;
use core::{convert::TryInto, fmt, iter::FusedIterator, mem, num::NonZeroUsize};

/// An error returned when a parsed trie has invalid format.
///
/// The type carries no payload: every parsing failure is reported as the same value. It
/// implements the common comparison and hashing traits so that results containing it can be
/// compared directly (e.g. with `assert_eq!`) or stored in hashed collections.
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct InvalidTrieError;

impl fmt::Display for InvalidTrieError {
    /// Writes the fixed human-readable description of the error.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "invalid trie")
    }
}

// `std::error::Error` is only implemented when the `std` feature is enabled, because the crate
// is otherwise `no_std`. The default trait methods are sufficient: the error has no source.
#[cfg(feature = "std")]
impl std::error::Error for InvalidTrieError {}

/// Crate-local shorthand for a result whose error is always [`InvalidTrieError`].
type Result<T> = core::result::Result<T, InvalidTrieError>;

/// Per-node bookkeeping used by [`TrieIter`] while it enumerates the children of a node.
#[derive(Clone, Debug)]
struct NodeIterState {
    /// Current offset in trie: the position of the next child edge record to read for this
    /// node. A value of `0` marks a state whose root node header has not been parsed yet.
    offset: usize,
    /// Number of children nodes left to visit.
    children_left: usize,
    /// How many characters current node added to the symbol (removed again from the symbol
    /// buffer once the node has been fully visited).
    sym_delta: usize,
}

impl Default for NodeIterState {
    #[inline(always)]
    fn default() -> Self {
        NodeIterState {
            offset: 0,
            children_left: 1,
            sym_delta: 0,
        }
    }
}

/// Decodes one unsigned LEB128 value from the front of `slice` and returns it as a `usize`.
///
/// Fails with [`InvalidTrieError`] when the decoder reports an error or when the decoded value
/// does not fit into `usize`.
fn read_uleb128_size(slice: &mut &[u8]) -> Result<usize> {
    leb128::read::unsigned(slice)
        .ok()
        .and_then(|decoded| decoded.try_into().ok())
        .ok_or(InvalidTrieError)
}

/// Splits the first `len` bytes off the front of `slice` and returns them.
///
/// On success `slice` is advanced past the returned bytes. When fewer than `len` bytes are
/// available, `slice` is left untouched and [`InvalidTrieError`] is returned.
fn read_slice<'a>(slice: &mut &'a [u8], len: usize) -> Result<&'a [u8]> {
    if len > slice.len() {
        return Err(InvalidTrieError);
    }

    let (taken, rest) = slice.split_at(len);
    *slice = rest;
    Ok(taken)
}

/// Reads a ULEB128 length prefix followed by that many bytes, returning the bytes.
///
/// Fails with [`InvalidTrieError`] when either the prefix or the payload cannot be read.
fn read_size_and_slice<'a>(slice: &mut &'a [u8]) -> Result<&'a [u8]> {
    read_uleb128_size(slice).and_then(|len| read_slice(slice, len))
}

/// Reads a NUL-terminated byte string from the front of `slice`.
///
/// Returns the bytes before the terminator and advances `slice` past the terminator. When no
/// NUL byte is present, `slice` is left untouched and [`InvalidTrieError`] is returned.
fn read_string<'a>(slice: &mut &'a [u8]) -> Result<&'a [u8]> {
    let bytes: &'a [u8] = *slice;
    match bytes.iter().position(|&byte| byte == 0) {
        Some(nul) => {
            *slice = &bytes[nul + 1..];
            Ok(&bytes[..nul])
        }
        None => Err(InvalidTrieError),
    }
}

/// Looks up `sym` in the encoded `trie` and returns the terminal stored for it.
///
/// Starting at the root (offset 0), each node's edges are compared against the not yet matched
/// part of `sym`; the first edge whose whole label matches is followed. The walk returns
/// `Ok(Some(terminal))` once the symbol is fully consumed at a node with a non-empty terminal,
/// and `Ok(None)` when no edge of the current node matches.
///
/// # Errors
///
/// Returns [`InvalidTrieError`] when either trie is corrupted (truncated data, a child offset
/// that is zero or lies past the end of the trie, a node visited twice) or trie depth exceeds
/// 128. It is also returned when `sym` runs out while an edge label is still matching it.
pub fn walk<'a>(trie: &'a [u8], sym: &[u8]) -> Result<Option<&'a [u8]>> {
    // Offsets of the nodes on the current path; the root is always at offset 0.
    let mut seen = Vec::new();
    seen.push(0);

    let mut cursor = trie;
    let mut remaining = sym;
    loop {
        let terminal = read_size_and_slice(&mut cursor)?;
        if remaining.is_empty() && !terminal.is_empty() {
            return Ok(Some(terminal));
        }

        let child_count = read_uleb128_size(&mut cursor)?;

        let mut next_node: Option<NonZeroUsize> = None;
        for _ in 0..child_count {
            let mut candidate = remaining;
            let mut mismatch = false;

            // Consume the whole NUL-terminated edge label, comparing it against the symbol only
            // until the first difference; the rest of the label is just skipped.
            loop {
                let (&edge_byte, rest) = cursor.split_first().ok_or(InvalidTrieError)?;
                cursor = rest;
                if edge_byte == 0 {
                    break;
                }
                if mismatch {
                    continue;
                }

                let (&sym_byte, rest) = candidate.split_first().ok_or(InvalidTrieError)?;
                mismatch = edge_byte != sym_byte;
                candidate = rest;
            }

            // Every edge record ends with the child's offset, whether or not the edge matched.
            let child_offset = read_uleb128_size(&mut cursor)?;
            if mismatch {
                continue;
            }

            // The symbol so far matches this edge: descend into the child node.
            if child_offset > trie.len() {
                return Err(InvalidTrieError);
            }
            next_node = Some(NonZeroUsize::new(child_offset).ok_or(InvalidTrieError)?);
            remaining = candidate;
            break;
        }

        let Some(next_node) = next_node else {
            return Ok(None);
        };
        let next_offset = next_node.get();

        // Reject cycles and unreasonably deep paths.
        if seen.contains(&next_offset) || seen.len() > 128 {
            return Err(InvalidTrieError);
        }
        seen.push(next_offset);

        cursor = &trie[next_offset..];
    }
}

/// Returns an iterator over the entries of a Dyld trie.
///
/// # Handling invalid trie
///
/// Invalid trie may be handled by using the [`TrieIter::next_no_copy`] method for iteration instead
/// of the [`Iterator`] trait. See docs for [`TrieIter`] for more info.
pub fn iter(trie: &[u8]) -> TrieIter<'_> {
    TrieIter {
        trie,
        sym_buf: Vec::new(),
        stack: Vec::new(),
        node_state: Default::default(),
        visited_offsets: Vec::new(),
    }
}

/// An iterator over encoded Mach-O trie.
///
/// Entries are produced as `(symbol, terminal)` pairs for every node that has a non-empty
/// terminal.
///
/// # Iteration and copying
///
/// This structure implements the [`Iterator`] trait, however due to the requirement to have an
/// instance of an associated type as a return value a reference to the internal buffer may not
/// be returned from `next()`. The `next()` implementation returns a clone of the internal symbol
/// buffer on each invocation (a `Vec<u8>`). In case you don't need this it's better to use the
/// [`TrieIter::next_no_copy`] method which returns a reference to the internal buffer instead of
/// cloning it.
#[derive(Debug)]
pub struct TrieIter<'trie> {
    /// The encoded trie being iterated.
    trie: &'trie [u8],
    /// Symbol accumulated from the edge labels on the path to the current node.
    sym_buf: Vec<u8>,
    /// Saved states of the nodes above the current one, innermost last.
    stack: Vec<NodeIterState>,
    /// State of the node whose children are currently being enumerated.
    node_state: NodeIterState,
    /// Offsets of all nodes entered so far; used to reject a node referenced twice.
    visited_offsets: Vec<usize>,
}

impl<'trie> TrieIter<'trie> {
    /// Returns a tuple where the first element is the symbol name and the second element is the
    /// terminal corresponding to it or `None` if there are no more symbols.
    ///
    /// The symbol borrows the iterator's internal buffer and is therefore only valid until the
    /// next call; the terminal borrows the trie itself.
    ///
    /// This is the recommended way of trie iteration. In case you want to create a `Vec<u8>` for
    /// each iteration and don't care about error handling, you may use [`Iterator::next`] instead.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidTrieError`] when the trie data is truncated, when a child offset is zero
    /// or lies past the end of the trie, or when a node is referenced more than once.
    pub fn next_no_copy<'iter>(&'iter mut self) -> Result<Option<(&'iter [u8], &'trie [u8])>> {
        loop {
            // If offset is 0 we're at the beginning of the trie,
            let mut data = if self.node_state.offset == 0 {
                let mut data = self.trie;

                // skip terminal header and advance to children count
                read_size_and_slice(&mut data)?;
                let children_count = read_uleb128_size(&mut data)?;

                self.node_state.offset = self.trie.len() - data.len();
                self.node_state.children_left = children_count;

                if self.node_state.children_left == 0 {
                    return Ok(None);
                }

                data
            } else {
                // Unwind finished nodes: restore the parent's state and drop the characters the
                // finished node contributed to the symbol.
                while self.node_state.children_left == 0 {
                    let Some(next_state) = self.stack.pop() else {
                        return Ok(None);
                    };

                    let node_state = mem::replace(&mut self.node_state, next_state);
                    self.sym_buf
                        .truncate(self.sym_buf.len() - node_state.sym_delta);
                }

                &self.trie[self.node_state.offset..]
            };

            // Read the next edge record: NUL-terminated label followed by the child's offset.
            let len_before = data.len();
            let part = read_string(&mut data)?;
            let offset = read_uleb128_size(&mut data)?;

            // Update state for current node.
            self.node_state.offset += len_before - data.len();
            self.node_state.children_left -= 1;

            // Offset 0 is the root, which is an ancestor of every node, so an edge pointing
            // there is a cycle just like an edge to any other already visited node.
            if offset == 0 || self.visited_offsets.contains(&offset) {
                // Detected a cycle, this is bad.
                break Err(InvalidTrieError);
            }
            self.visited_offsets.push(offset);

            // Parse the node header. The offset comes straight from the trie data, so it must
            // be bounds-checked instead of indexed directly to avoid panicking on bad input.
            let mut data = self.trie.get(offset..).ok_or(InvalidTrieError)?;
            let terminal = read_size_and_slice(&mut data)?;
            let children_count = read_uleb128_size(&mut data)?;

            self.sym_buf.extend_from_slice(part);

            // Descend: the child becomes the current node, the parent waits on the stack.
            self.stack.push(mem::replace(
                &mut self.node_state,
                NodeIterState {
                    offset: self.trie.len() - data.len(),
                    children_left: children_count,
                    sym_delta: part.len(),
                },
            ));

            // Nodes without a terminal are only interior path nodes; keep going.
            if !terminal.is_empty() {
                break Ok(Some((self.sym_buf.as_slice(), terminal)));
            }
        }
    }
}

impl<'trie> Iterator for TrieIter<'trie> {
    type Item = (Vec<u8>, &'trie [u8]);

    /// Yields the next entry with the symbol copied into a fresh `Vec<u8>`.
    ///
    /// # Panics
    ///
    /// Panics when the underlying `next_no_copy` call reports an invalid trie, because its
    /// result is unwrapped here.
    fn next(&mut self) -> Option<Self::Item> {
        let (sym, terminal) = self.next_no_copy().unwrap()?;
        Some((sym.to_vec(), terminal))
    }
}

// Once `next_no_copy` has returned `Ok(None)` the current node has no children left and the
// stack is empty, so every further call returns `Ok(None)` again and `next` keeps yielding `None`.
impl FusedIterator for TrieIter<'_> {}

#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use sha2::Digest;

    /// Encoded trie fixture shared by the tests below.
    const TEST_TRIE: &[u8] = include_bytes!("../tests/test_trie.bin");
    /// SHA-256 over every `symbol NUL terminal NUL` record of the fixture, in iteration order.
    const TEST_TRIE_HASH: [u8; 32] =
        hex!("9829e0f5330988ef653dc534cde5998dd4fbd5e107dc6c92545253155d5f04ef");

    /// Walking with a full symbol must not fail; walking with the same symbol cut short must
    /// be reported as an error.
    #[test]
    fn test_walk() {
        let full_symbol = b"__ZN3JSC12RegExpObjectC1ERNS_2VMEPNS_9StructureEPNS_6RegExpE";
        let cut_symbol = b"__ZN3JSC12RegExpObjectC1ERNS_2VMEPNS_9StructureEPNS_6RegEx";

        // Only checks that the walk itself succeeds, not whether a terminal was found.
        walk(TEST_TRIE, full_symbol).unwrap();

        let cut_result = walk(TEST_TRIE, cut_symbol);
        assert!(cut_result.is_err());
    }

    /// Iterating the fixture must produce exactly the entries captured by `TEST_TRIE_HASH`.
    #[test]
    fn test_iter() {
        const SEPARATOR: &[u8] = &[0];

        let mut entries = iter(TEST_TRIE);
        let mut hasher = sha2::Sha256::new();
        while let Some((sym, terminal)) = entries.next_no_copy().unwrap() {
            hasher.update(sym);
            hasher.update(SEPARATOR);
            hasher.update(terminal);
            hasher.update(SEPARATOR);
        }

        assert_eq!(hasher.finalize().as_slice(), &TEST_TRIE_HASH);
    }
}