bnf_sampler 0.3.8

A crate that uses recursive descent algorithm to ensure tokens produced by a large language model follow a Backus Naur Form schema.
Documentation
use itertools::Itertools;
use nohash_hasher::BuildNoHashHasher;
use std::{collections::HashMap, hash::Hash};

use crate::utils::NonterminalID;
#[derive(Clone, Debug)]
pub(crate) struct TerminalsTrie {
    pub roots: HashMap<NonterminalID, TrieNodeID, BuildNoHashHasher<NonterminalID>>,
    arena: Vec<TrieNode>,
}
#[derive(Clone, Debug)]
pub(crate) struct TerminalsTrieIter<'a> {
    initial_index: usize,
    pub stack: Vec<std::collections::hash_map::Iter<'a, u8, TrieNodeID>>,
    trie: &'a TerminalsTrie,
}

impl<'a> Iterator for TerminalsTrieIter<'a> {
    type Item = &'a [u8];

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            match self.stack.last_mut() {
                None => {
                    return None;
                }
                Some(x) => match x.next() {
                    None => {
                        self.stack.pop();
                    }
                    Some((_, v)) => {
                        self.stack.push(self.trie.get(*v).children.iter());
                        if let Some(value) = &self.trie.get(*v).value {
                            return Some(&value[self.initial_index..]);
                        }
                    }
                },
            }
        }
    }
}

impl TerminalsTrie {
    pub fn new() -> Self {
        let arena = Vec::new();
        TerminalsTrie {
            roots: HashMap::default(),
            arena,
        }
    }

    fn new_node(arena: &mut Vec<TrieNode>, node: TrieNode) -> TrieNodeID {
        arena.push(node);
        TrieNodeID {
            id: arena.len() - 1,
        }
    }

    pub fn get(&self, node_id: TrieNodeID) -> &TrieNode {
        &self.arena[node_id.id]
    }

    fn get_mut(&mut self, node_id: TrieNodeID) -> &mut TrieNode {
        &mut self.arena[node_id.id]
    }

    pub fn add(&mut self, terminal: &[u8], nonterminal_id: NonterminalID, can_stop: bool) {
        let mut current_node_id = *self.roots.entry(nonterminal_id).or_insert(Self::new_node(
            &mut self.arena,
            TrieNode {
                index: 0,
                negative_bytes_index: None,
                value: None,
                children: HashMap::default(),
                can_stop,
            },
        ));
        for i in terminal {
            let matched_child_node = self.get(current_node_id).children.get(i);
            match matched_child_node {
                None => {
                    let index = self.get(current_node_id).index + 1;
                    let new_node_id = Self::new_node(
                        &mut self.arena,
                        TrieNode {
                            index,
                            negative_bytes_index: None,
                            value: None,
                            children: HashMap::default(),
                            can_stop,
                        },
                    );
                    self.get_mut(current_node_id).append(*i, new_node_id);
                    current_node_id = new_node_id;
                }
                Some(id) => {
                    current_node_id = *id;
                }
            }
        }
        let mut temp = Vec::with_capacity(terminal.len());
        temp.extend_from_slice(terminal);
        self.get_mut(current_node_id).value = Some(temp.into_boxed_slice());
    }

    pub fn except_literal(&mut self, terminal: &[u8], nonterminal_id: NonterminalID) {
        fn _except_literal(
            this: &mut TerminalsTrie,
            current_node_id: TrieNodeID,
            terminal: &[u8],
            mut index: usize,
        ) {
            if index == terminal.len() {
                this.get_mut(current_node_id).negative_bytes_index = Some(index as u16);
                index = 0;
            }
            let current_node = this.get(current_node_id);
            for (k, v) in current_node
                .children
                .iter()
                .map(|(k, v)| (*k, *v))
                .collect_vec()
            {
                if terminal[index] == k {
                    _except_literal(this, v, terminal, index + 1);
                } else if terminal[0] == k {
                    _except_literal(this, v, terminal, 1);
                } else {
                    _except_literal(this, v, terminal, 0);
                }
            }
        }
        _except_literal(self, self.roots[&nonterminal_id], terminal, 0);
    }

    pub fn iter(&self, start_node_id: TrieNodeID) -> TerminalsTrieIter {
        let stack = vec![self.get(start_node_id).children.iter()];
        return TerminalsTrieIter {
            trie: self,
            initial_index: self.get(start_node_id).index as usize,
            stack,
        };
    }
}
#[derive(PartialEq, Clone, Debug, Copy, Eq, Hash)]
pub struct TrieNodeID {
    pub id: usize,
}
#[derive(Clone, Debug)]
pub(crate) struct TrieNode {
    pub index: u16,
    pub can_stop: bool,
    pub negative_bytes_index: Option<u16>,
    pub value: Option<Box<[u8]>>,
    pub children: HashMap<u8, TrieNodeID, BuildNoHashHasher<u8>>,
}

impl TrieNode {
    pub fn append(&mut self, byte: u8, node_id: TrieNodeID) {
        self.children.insert(byte, node_id);
    }
}