1use std::sync::Arc;
4
5use anyhow::Result;
6
7use crate::traits::{self, TokenIdType};
8
9const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
10
11pub struct DecodeStream {
14 tokenizer: Arc<dyn traits::Tokenizer>,
16
17 skip_special_tokens: bool,
18
19 all_token_ids: Vec<TokenIdType>,
22
23 prefix_offset: usize,
24 read_offset: usize,
25}
26
27impl DecodeStream {
28 pub fn new(
29 tokenizer: Arc<dyn traits::Tokenizer>,
30 prompt_token_ids: &[TokenIdType],
31 skip_special_tokens: bool,
32 ) -> Self {
33 let num_input_tokens = prompt_token_ids.len();
34 let prompt_token_ids = prompt_token_ids.to_vec();
35 Self {
36 tokenizer,
37 skip_special_tokens,
38 all_token_ids: prompt_token_ids,
39 prefix_offset: num_input_tokens
40 .saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
41 read_offset: num_input_tokens,
42 }
43 }
44
45 #[inline]
48 pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
49 self.all_token_ids.push(id);
50
51 let prefix_text = self.tokenizer.decode(
52 &self.all_token_ids[self.prefix_offset..self.read_offset],
53 self.skip_special_tokens,
54 )?;
55
56 let new_text = self.tokenizer.decode(
57 &self.all_token_ids[self.prefix_offset..],
58 self.skip_special_tokens,
59 )?;
60
61 if new_text.len() > prefix_text.len() && !new_text.ends_with("�") {
62 let new_text = new_text[prefix_text.len()..].to_string();
63
64 self.prefix_offset = self.read_offset;
65 self.read_offset = self.all_token_ids.len();
66
67 Ok(Some(new_text))
68 } else {
69 Ok(None)
70 }
71 }
72
73 pub fn step_batch(&mut self, token_ids: &[u32]) -> Result<Vec<String>> {
75 let mut chunks = Vec::with_capacity(token_ids.len());
77
78 for &token_id in token_ids {
79 if let Some(text) = self.step(token_id)? {
80 chunks.push(text);
81 }
82 }
83
84 Ok(chunks)
85 }
86
87 pub fn flush(&mut self) -> Result<Option<String>> {
89 if self.read_offset < self.all_token_ids.len() {
90 let remaining = self.tokenizer.decode(
91 &self.all_token_ids[self.read_offset..],
92 self.skip_special_tokens,
93 )?;
94
95 self.read_offset = self.all_token_ids.len();
96
97 if !remaining.is_empty() {
98 return Ok(Some(remaining));
99 }
100 }
101
102 Ok(None)
103 }
104
105 pub fn tokens(&self) -> &[u32] {
107 &self.all_token_ids
108 }
109}