1use anyhow::Result;
2use kbnf::{
3 engine_like::AcceptTokenError, AcceptTokenResult, Engine, EngineLike, Token, Vocabulary,
4};
5use web_rwkv::tokenizer::Tokenizer;
6
7use super::Formatter;
8
/// Newtype wrapper around a `kbnf` [`Engine`].
///
/// Built by [`BnfSampler::new`] from a tokenizer vocabulary and a schema string, and used
/// through this file's [`Formatter`] implementation.
#[derive(Debug)]
pub struct BnfSampler(Engine);
11
12impl BnfSampler {
13 pub fn new(tokenizer: &Tokenizer, schema: &str) -> Result<Self> {
14 let tokens = tokenizer
15 .token_index_to_bytes()
16 .iter()
17 .enumerate()
18 .filter(|(_, v)| !v.is_empty())
19 .map(|(k, v)| (k as u32, Token(v.clone().into_boxed_slice())))
20 .collect();
21 let strings = tokenizer
22 .token_index_to_bytes()
23 .iter()
24 .enumerate()
25 .filter(|(_, v)| !v.is_empty())
26 .map(|(k, v)| (k as u32, String::from_utf8_lossy(v).to_string()))
27 .collect();
28 let vocab = Vocabulary::new(tokens, strings)?;
29 let engine = Engine::new(schema, vocab)?;
30 Ok(Self(engine))
31 }
32}
33
34impl Formatter for BnfSampler {
35 fn transform(&self, output: &mut [f32]) {
36 let output = &mut output[..self.0.vocab().vocab_size()];
37 self.0.mask_logits(output).expect("bnf transform error")
38 }
39
40 fn update(&mut self, token: u32) -> bool {
41 let halt = match self.0.try_accept_new_token(token) {
42 Ok(AcceptTokenResult::Finished) | Err(AcceptTokenError::Finished) => true,
43 Ok(AcceptTokenResult::Ongoing) => false,
44 Err(_) => self.0.is_finished(),
45 };
46 self.0.compute_allowed_token_ids();
47 halt
48 }
49}