bnf_sampler/
vocabulary.rs1use bit_set::BitSet;
2use qp_trie::Trie;
3use rustc_hash::FxHashMap;
4
5use crate::utils::U8ArrayWrapper;
6#[derive(Debug, Clone)]
7pub struct Vocabulary {
9 pub token_to_id: Trie<U8ArrayWrapper, u32>,
10 pub id_to_token: FxHashMap<u32, Vec<u8>>,
12 pub id_to_token_string: FxHashMap<u32, String>,
14}
15
16impl Vocabulary {
17 pub fn get_token_strings_from_token_ids<'a>(
18 &'a self,
19 token_ids: &'a BitSet,
20 ) -> impl Iterator<Item = &'a str> {
21 token_ids
22 .iter()
23 .map(|x| self.id_to_token_string[&(x as u32)].as_str())
24 }
25
26 pub fn get_token_from_token_ids<'a>(
27 &'a self,
28 token_ids: &'a BitSet,
29 ) -> impl Iterator<Item = &'a [u8]> {
30 token_ids
31 .iter()
32 .map(|x| self.id_to_token[&(x as u32)].as_slice())
33 }
34}