// ai00_core/sampler/bnf.rs

1use anyhow::Result;
2use kbnf::{
3    engine_like::AcceptTokenError, AcceptTokenResult, Engine, EngineLike, Token, Vocabulary,
4};
5use web_rwkv::tokenizer::Tokenizer;
6
7use super::Formatter;
8
9#[derive(Debug)]
10pub struct BnfSampler(Engine);
11
12impl BnfSampler {
13    pub fn new(tokenizer: &Tokenizer, schema: &str) -> Result<Self> {
14        let tokens = tokenizer
15            .token_index_to_bytes()
16            .iter()
17            .enumerate()
18            .filter(|(_, v)| !v.is_empty())
19            .map(|(k, v)| (k as u32, Token(v.clone().into_boxed_slice())))
20            .collect();
21        let strings = tokenizer
22            .token_index_to_bytes()
23            .iter()
24            .enumerate()
25            .filter(|(_, v)| !v.is_empty())
26            .map(|(k, v)| (k as u32, String::from_utf8_lossy(v).to_string()))
27            .collect();
28        let vocab = Vocabulary::new(tokens, strings)?;
29        let engine = Engine::new(schema, vocab)?;
30        Ok(Self(engine))
31    }
32}
33
34impl Formatter for BnfSampler {
35    fn transform(&self, output: &mut [f32]) {
36        let output = &mut output[..self.0.vocab().vocab_size()];
37        self.0.mask_logits(output).expect("bnf transform error")
38    }
39
40    fn update(&mut self, token: u32) -> bool {
41        let halt = match self.0.try_accept_new_token(token) {
42            Ok(AcceptTokenResult::Finished) | Err(AcceptTokenError::Finished) => true,
43            Ok(AcceptTokenResult::Ongoing) => false,
44            Err(_) => self.0.is_finished(),
45        };
46        self.0.compute_allowed_token_ids();
47        halt
48    }
49}