Skip to main content

llm_tokenizer/
stream.rs

// src/tokenizer/stream.rs
2
3use std::sync::Arc;
4
5use anyhow::Result;
6
7use crate::traits::{self, TokenIdType};
8
/// Number of trailing prompt tokens kept as decoding context for a new stream.
///
/// `DecodeStream::new` sets its initial `prefix_offset` to the prompt length
/// minus this value, saturating at zero for prompts shorter than the window.
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
10
/// DecodeStream keeps the state necessary to produce individual chunks of
/// strings given an input stream of token_ids.
///
/// Text is produced incrementally: each step decodes a window of tokens
/// starting at `prefix_offset` and returns only the part that extends past
/// the text of the tokens up to `read_offset`.
pub struct DecodeStream {
    /// The tokenizer used to decode token_ids
    tokenizer: Arc<dyn traits::Tokenizer>,

    /// Forwarded unchanged to every `decode` call made by the stream.
    skip_special_tokens: bool,

    /// The prompt token_ids followed by every token_id appended by `step`.
    /// Nothing in this file truncates it, so it grows for the lifetime of
    /// the stream.
    all_token_ids: Vec<TokenIdType>,

    /// Index into `all_token_ids` where the decoding window starts.
    prefix_offset: usize,
    /// Index into `all_token_ids` one past the last token whose text has
    /// already been returned to the caller (or that belongs to the prompt).
    read_offset: usize,
}
26
27impl DecodeStream {
28    pub fn new(
29        tokenizer: Arc<dyn traits::Tokenizer>,
30        prompt_token_ids: &[TokenIdType],
31        skip_special_tokens: bool,
32    ) -> Self {
33        let num_input_tokens = prompt_token_ids.len();
34        let prompt_token_ids = prompt_token_ids.to_vec();
35        Self {
36            tokenizer,
37            skip_special_tokens,
38            all_token_ids: prompt_token_ids,
39            prefix_offset: num_input_tokens
40                .saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
41            read_offset: num_input_tokens,
42        }
43    }
44
45    /// Step appends a token_id to the internal state and tries to produce a text chunk.
46    /// Returning `None` means the given id is not enough to produce a chunk.
47    #[inline]
48    pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
49        self.all_token_ids.push(id);
50
51        let prefix_text = self.tokenizer.decode(
52            &self.all_token_ids[self.prefix_offset..self.read_offset],
53            self.skip_special_tokens,
54        )?;
55
56        let new_text = self.tokenizer.decode(
57            &self.all_token_ids[self.prefix_offset..],
58            self.skip_special_tokens,
59        )?;
60
61        if new_text.len() > prefix_text.len() && !new_text.ends_with("�") {
62            // Find the nearest char boundary at or before prefix_text.len()
63            // to avoid panicking on multi-byte UTF-8 sequences
64            let mut split_at = prefix_text.len();
65            while !new_text.is_char_boundary(split_at) && split_at > 0 {
66                split_at -= 1;
67            }
68
69            let new_text = new_text[split_at..].to_string();
70
71            self.prefix_offset = self.read_offset;
72            self.read_offset = self.all_token_ids.len();
73
74            Ok(Some(new_text))
75        } else {
76            Ok(None)
77        }
78    }
79
80    /// Process multiple tokens at once
81    pub fn step_batch(&mut self, token_ids: &[u32]) -> Result<Vec<String>> {
82        // Pre-allocate with capacity - most tokens produce output
83        let mut chunks = Vec::with_capacity(token_ids.len());
84
85        for &token_id in token_ids {
86            if let Some(text) = self.step(token_id)? {
87                chunks.push(text);
88            }
89        }
90
91        Ok(chunks)
92    }
93
94    /// Force flush any remaining text
95    pub fn flush(&mut self) -> Result<Option<String>> {
96        if self.read_offset < self.all_token_ids.len() {
97            let remaining = self.tokenizer.decode(
98                &self.all_token_ids[self.read_offset..],
99                self.skip_special_tokens,
100            )?;
101
102            self.read_offset = self.all_token_ids.len();
103
104            if !remaining.is_empty() {
105                return Ok(Some(remaining));
106            }
107        }
108
109        Ok(None)
110    }
111
112    /// Get all tokens processed so far
113    pub fn tokens(&self) -> &[u32] {
114        &self.all_token_ids
115    }
116}