candle_examples/token_output_stream.rs

use candle::Result;

/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
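///
/// # Example
///
/// A minimal usage sketch, not part of the original file: `sampled_ids` stands
/// in for token ids produced by a model's sampling loop, and `tokenizer.json`
/// is an assumed local tokenizer file.
///
/// ```no_run
/// # fn main() -> candle::Result<()> {
/// use candle_examples::token_output_stream::TokenOutputStream;
///
/// let tokenizer = match tokenizers::Tokenizer::from_file("tokenizer.json") {
///     Ok(tokenizer) => tokenizer,
///     Err(err) => candle::bail!("cannot load tokenizer: {err}"),
/// };
/// let mut stream = TokenOutputStream::new(tokenizer);
/// let sampled_ids: Vec<u32> = vec![]; // hypothetical: ids from a sampling loop
/// for id in sampled_ids {
///     // Only yields text once the buffered tokens decode to complete characters.
///     if let Some(chunk) = stream.next_token(id)? {
///         print!("{chunk}");
///     }
/// }
/// // Flush whatever is still buffered once generation is done.
/// if let Some(rest) = stream.decode_rest()? {
///     print!("{rest}");
/// }
/// # Ok(())
/// # }
/// ```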
pub struct TokenOutputStream {
    tokenizer: tokenizers::Tokenizer,
    tokens: Vec<u32>,
    /// Start of the most recently emitted chunk, kept around as decoding context.
    prev_index: usize,
    /// End of the most recently emitted chunk; tokens past this point are pending.
    current_index: usize,
}

impl TokenOutputStream {
    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
        Self {
            tokenizer,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
        }
    }

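    /// Consumes the wrapper and returns the underlying tokenizer.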
    pub fn into_inner(self) -> tokenizers::Tokenizer {
        self.tokenizer
    }

    // Route the tokenizer's error type through candle's `Result` so callers can use `?`.
    fn decode(&self, tokens: &[u32]) -> Result<String> {
        match self.tokenizer.decode(tokens, true) {
            Ok(str) => Ok(str),
            Err(err) => candle::bail!("cannot decode: {err}"),
        }
    }

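    /// Feeds one freshly sampled token id into the stream. Returns `Some(text)`
    /// once the newly decoded text extends the previously emitted one and ends
    /// on an alphanumeric character, and `None` while the token stays buffered,
    /// e.g. in the middle of a multi-byte UTF-8 sequence from byte-level BPE.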
    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        // Re-decode the last emitted chunk to get a baseline for the diff below.
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        self.tokens.push(token);
        // Decode the last chunk together with all pending tokens, and only emit
        // the grown suffix once it ends on an alphanumeric character, i.e. once
        // the pending bytes are likely to form complete characters.
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
            let text = text.split_at(prev_text.len());
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

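    /// Decodes whatever is still buffered, i.e. everything `next_token` has not
    /// emitted yet; typically called once after the generation loop finishes.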
    pub fn decode_rest(&self) -> Result<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() {
            let text = text.split_at(prev_text.len());
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

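    /// Decodes the full token history accumulated so far.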
    pub fn decode_all(&self) -> Result<String> {
        self.decode(&self.tokens)
    }

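    /// Looks up the id of `token_s` in the vocabulary, including added tokens.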
    pub fn get_token(&self, token_s: &str) -> Option<u32> {
        self.tokenizer.get_vocab(true).get(token_s).copied()
    }

    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
        &self.tokenizer
    }

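    /// Resets the stream so the wrapper can be reused for a new sequence.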
    pub fn clear(&mut self) {
        self.tokens.clear();
        self.prev_index = 0;
        self.current_index = 0;
    }
}