// src/tokenizer/stream.rs

use std::ops::Range;
use std::sync::Arc;

use anyhow::Result;

use crate::traits::{self, TokenIdType};

/// Number of trailing prompt tokens that [`DecodeStream::new`] places in the
/// initial decode window (`prefix_offset`), so the first generated token is
/// decoded together with some preceding prompt context instead of on its own.
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;

11/// DecodeStream will keep the state necessary to produce individual chunks of
12/// strings given an input stream of token_ids
13pub struct DecodeStream {
14    /// The tokenizer used to decode token_ids
15    tokenizer: Arc<dyn traits::Tokenizer>,
16
17    skip_special_tokens: bool,
18
19    /// A temporary buffer of the necessary token_ids needed
20    /// to produce valid string chunks
21    all_token_ids: Vec<TokenIdType>,
22
23    prefix_offset: usize,
24    read_offset: usize,
25}
26
27impl DecodeStream {
28    pub fn new(
29        tokenizer: Arc<dyn traits::Tokenizer>,
30        prompt_token_ids: &[TokenIdType],
31        skip_special_tokens: bool,
32    ) -> Self {
33        let num_input_tokens = prompt_token_ids.len();
34        let prompt_token_ids = prompt_token_ids.to_vec();
35        Self {
36            tokenizer,
37            skip_special_tokens,
38            all_token_ids: prompt_token_ids,
39            prefix_offset: num_input_tokens
40                .saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
41            read_offset: num_input_tokens,
42        }
43    }
44
45    /// Step appends a token_id to the internal state and tries to produce a text chunk.
46    /// Returning `None` means the given id is not enough to produce a chunk.
47    #[inline]
48    pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
49        self.all_token_ids.push(id);
50
51        let prefix_text = self.tokenizer.decode(
52            &self.all_token_ids[self.prefix_offset..self.read_offset],
53            self.skip_special_tokens,
54        )?;
55
56        let new_text = self.tokenizer.decode(
57            &self.all_token_ids[self.prefix_offset..],
58            self.skip_special_tokens,
59        )?;
60
61        if new_text.len() > prefix_text.len() && !new_text.ends_with("�") {
62            let new_text = new_text[prefix_text.len()..].to_string();
63
64            self.prefix_offset = self.read_offset;
65            self.read_offset = self.all_token_ids.len();
66
67            Ok(Some(new_text))
68        } else {
69            Ok(None)
70        }
71    }
72
73    /// Process multiple tokens at once
74    pub fn step_batch(&mut self, token_ids: &[u32]) -> Result<Vec<String>> {
75        // Pre-allocate with capacity - most tokens produce output
76        let mut chunks = Vec::with_capacity(token_ids.len());
77
78        for &token_id in token_ids {
79            if let Some(text) = self.step(token_id)? {
80                chunks.push(text);
81            }
82        }
83
84        Ok(chunks)
85    }
86
87    /// Force flush any remaining text
88    pub fn flush(&mut self) -> Result<Option<String>> {
89        if self.read_offset < self.all_token_ids.len() {
90            let remaining = self.tokenizer.decode(
91                &self.all_token_ids[self.read_offset..],
92                self.skip_special_tokens,
93            )?;
94
95            self.read_offset = self.all_token_ids.len();
96
97            if !remaining.is_empty() {
98                return Ok(Some(remaining));
99            }
100        }
101
102        Ok(None)
103    }
104
105    /// Get all tokens processed so far
106    pub fn tokens(&self) -> &[u32] {
107        &self.all_token_ids
108    }
109}