// candle_examples/token_output_stream.rs
use candle::Result;
2
3pub struct TokenOutputStream {
6 tokenizer: tokenizers::Tokenizer,
7 tokens: Vec<u32>,
8 prev_index: usize,
9 current_index: usize,
10}
11
12impl TokenOutputStream {
13 pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
14 Self {
15 tokenizer,
16 tokens: Vec::new(),
17 prev_index: 0,
18 current_index: 0,
19 }
20 }
21
22 pub fn into_inner(self) -> tokenizers::Tokenizer {
23 self.tokenizer
24 }
25
26 fn decode(&self, tokens: &[u32]) -> Result<String> {
27 match self.tokenizer.decode(tokens, true) {
28 Ok(str) => Ok(str),
29 Err(err) => candle::bail!("cannot decode: {err}"),
30 }
31 }
32
33 pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
35 let prev_text = if self.tokens.is_empty() {
36 String::new()
37 } else {
38 let tokens = &self.tokens[self.prev_index..self.current_index];
39 self.decode(tokens)?
40 };
41 self.tokens.push(token);
42 let text = self.decode(&self.tokens[self.prev_index..])?;
43 if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
44 let text = text.split_at(prev_text.len());
45 self.prev_index = self.current_index;
46 self.current_index = self.tokens.len();
47 Ok(Some(text.1.to_string()))
48 } else {
49 Ok(None)
50 }
51 }
52
53 pub fn decode_rest(&self) -> Result<Option<String>> {
54 let prev_text = if self.tokens.is_empty() {
55 String::new()
56 } else {
57 let tokens = &self.tokens[self.prev_index..self.current_index];
58 self.decode(tokens)?
59 };
60 let text = self.decode(&self.tokens[self.prev_index..])?;
61 if text.len() > prev_text.len() {
62 let text = text.split_at(prev_text.len());
63 Ok(Some(text.1.to_string()))
64 } else {
65 Ok(None)
66 }
67 }
68
69 pub fn decode_all(&self) -> Result<String> {
70 self.decode(&self.tokens)
71 }
72
73 pub fn get_token(&self, token_s: &str) -> Option<u32> {
74 self.tokenizer.get_vocab(true).get(token_s).copied()
75 }
76
77 pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
78 &self.tokenizer
79 }
80
81 pub fn clear(&mut self) {
82 self.tokens.clear();
83 self.prev_index = 0;
84 self.current_index = 0;
85 }
86}