bnf_sampler/
vocabulary.rs

1use bit_set::BitSet;
2use qp_trie::Trie;
3use rustc_hash::FxHashMap;
4
5use crate::utils::U8ArrayWrapper;
6#[derive(Debug, Clone)]
7/// The struct represents a language model's vocabulary.
8pub struct Vocabulary {
9    pub token_to_id: Trie<U8ArrayWrapper, u32>,
10    /// This field represents a map from token id to the token in bytes.
11    pub id_to_token: FxHashMap<u32, Vec<u8>>,
12    /// This field represents a map from token id to the token in UTF-8 String representation.
13    pub id_to_token_string: FxHashMap<u32, String>,
14}
15
16impl Vocabulary {
17    pub fn get_token_strings_from_token_ids<'a>(
18        &'a self,
19        token_ids: &'a BitSet,
20    ) -> impl Iterator<Item = &'a str> {
21        token_ids
22            .iter()
23            .map(|x| self.id_to_token_string[&(x as u32)].as_str())
24    }
25
26    pub fn get_token_from_token_ids<'a>(
27        &'a self,
28        token_ids: &'a BitSet,
29    ) -> impl Iterator<Item = &'a [u8]> {
30        token_ids
31            .iter()
32            .map(|x| self.id_to_token[&(x as u32)].as_slice())
33    }
34}