1use std::sync::Arc;
4
5use anyhow::Result;
6
7use crate::traits::{self, TokenIdType};
8
/// Maximum number of trailing prompt tokens that `DecodeStream::new` places between the
/// initial `prefix_offset` and `read_offset`.
///
/// These tokens are decoded together with newly generated ones so the first chunks are
/// produced with some preceding context. Prompts shorter than this contribute all of their
/// tokens.
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
10
11pub struct DecodeStream {
14 tokenizer: Arc<dyn traits::Tokenizer>,
16
17 skip_special_tokens: bool,
18
19 all_token_ids: Vec<TokenIdType>,
22
23 prefix_offset: usize,
24 read_offset: usize,
25}
26
27impl DecodeStream {
28 pub fn new(
29 tokenizer: Arc<dyn traits::Tokenizer>,
30 prompt_token_ids: &[TokenIdType],
31 skip_special_tokens: bool,
32 ) -> Self {
33 let num_input_tokens = prompt_token_ids.len();
34 let prompt_token_ids = prompt_token_ids.to_vec();
35 Self {
36 tokenizer,
37 skip_special_tokens,
38 all_token_ids: prompt_token_ids,
39 prefix_offset: num_input_tokens
40 .saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
41 read_offset: num_input_tokens,
42 }
43 }
44
45 #[inline]
48 pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
49 self.all_token_ids.push(id);
50
51 let prefix_text = self.tokenizer.decode(
52 &self.all_token_ids[self.prefix_offset..self.read_offset],
53 self.skip_special_tokens,
54 )?;
55
56 let new_text = self.tokenizer.decode(
57 &self.all_token_ids[self.prefix_offset..],
58 self.skip_special_tokens,
59 )?;
60
61 if new_text.len() > prefix_text.len() && !new_text.ends_with("�") {
62 let mut split_at = prefix_text.len();
65 while !new_text.is_char_boundary(split_at) && split_at > 0 {
66 split_at -= 1;
67 }
68
69 let new_text = new_text[split_at..].to_string();
70
71 self.prefix_offset = self.read_offset;
72 self.read_offset = self.all_token_ids.len();
73
74 Ok(Some(new_text))
75 } else {
76 Ok(None)
77 }
78 }
79
80 pub fn step_batch(&mut self, token_ids: &[u32]) -> Result<Vec<String>> {
82 let mut chunks = Vec::with_capacity(token_ids.len());
84
85 for &token_id in token_ids {
86 if let Some(text) = self.step(token_id)? {
87 chunks.push(text);
88 }
89 }
90
91 Ok(chunks)
92 }
93
94 pub fn flush(&mut self) -> Result<Option<String>> {
96 if self.read_offset < self.all_token_ids.len() {
97 let remaining = self.tokenizer.decode(
98 &self.all_token_ids[self.read_offset..],
99 self.skip_special_tokens,
100 )?;
101
102 self.read_offset = self.all_token_ids.len();
103
104 if !remaining.is_empty() {
105 return Ok(Some(remaining));
106 }
107 }
108
109 Ok(None)
110 }
111
112 pub fn tokens(&self) -> &[u32] {
114 &self.all_token_ids
115 }
116}