Skip to main content

llm_tokenizer/
sequence.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4
5use crate::traits::{TokenIdType, Tokenizer as TokenizerTrait};
6
7/// Maintains state for an ongoing sequence of tokens and their decoded text.
8///
9/// Mirrors the design of the native `DecodeStream` in the HuggingFace `tokenizers`
10/// crate but works through the `dyn Tokenizer` trait so it supports all backends
11/// (HuggingFace, Tiktoken, Mock).
12///
13/// Key design decisions (matching native `DecodeStream`):
14/// - **Token draining**: consumed tokens are drained from the buffer after each
15///   successful step, keeping memory bounded regardless of generation length.
16/// - **Prefix caching**: the decoded prefix string is cached between calls,
17///   avoiding a redundant `decode()` on the next step.
18pub struct Sequence {
19    /// The tokenizer used for encoding/decoding
20    tokenizer: Arc<dyn TokenizerTrait>,
21
22    /// Sliding window of token ids needed for correct incremental decoding.
23    /// Consumed tokens are drained after each successful step.
24    token_ids: Vec<TokenIdType>,
25
26    /// Total number of tokens ever appended (does NOT reset on drain).
27    /// Used by callers that need the logical sequence length.
28    total_tokens: usize,
29
30    /// Index within `token_ids` that marks the start of the "prefix" window.
31    /// Everything before this has already been decoded and can be drained.
32    prefix_index: usize,
33
34    /// Cached decoded prefix string from the previous successful step.
35    /// On the next call we skip one `decode()` by reusing this.
36    cached_prefix: String,
37
38    /// Whether to skip special tokens when decoding
39    skip_special_tokens: bool,
40}
41
42impl std::fmt::Debug for Sequence {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("Sequence")
45            .field("tokenizer", &"Arc<dyn Tokenizer>")
46            .field(
47                "token_ids",
48                &format_args!("{}", {
49                    let token_ids = &self.token_ids;
50                    if token_ids.len() <= 20 {
51                        format!("{token_ids:?}")
52                    } else {
53                        let first_ten = &token_ids[..10];
54                        let last_ten = &token_ids[token_ids.len() - 10..];
55                        format!("{first_ten:?} ... {last_ten:?}")
56                    }
57                }),
58            )
59            .field("prefix_index", &self.prefix_index)
60            .field("buffer_len", &self.token_ids.len())
61            .field("total_tokens", &self.total_tokens)
62            .finish()
63    }
64}
65
66impl Sequence {
67    /// Create a new empty sequence
68    pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
69        Self::new_with_options(tokenizer, false)
70    }
71
72    /// Create a new empty sequence with skip_special_tokens option
73    pub fn new_with_options(tokenizer: Arc<dyn TokenizerTrait>, skip_special_tokens: bool) -> Self {
74        Self {
75            tokenizer,
76            token_ids: Vec::new(),
77            total_tokens: 0,
78            prefix_index: 0,
79            cached_prefix: String::new(),
80            skip_special_tokens,
81        }
82    }
83
84    /// Create a sequence with initial tokens
85    pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
86        Self::with_tokens_and_options(tokenizer, token_ids, false)
87    }
88
89    /// Create a sequence with initial tokens and skip_special_tokens option
90    pub fn with_tokens_and_options(
91        tokenizer: Arc<dyn TokenizerTrait>,
92        token_ids: Vec<TokenIdType>,
93        skip_special_tokens: bool,
94    ) -> Self {
95        let len = token_ids.len();
96        Self {
97            tokenizer,
98            token_ids,
99            total_tokens: len,
100            prefix_index: 0,
101            cached_prefix: String::new(),
102            skip_special_tokens,
103        }
104    }
105
106    /// Check if the sequence is empty
107    #[inline]
108    pub fn is_empty(&self) -> bool {
109        self.total_tokens == 0
110    }
111
112    /// Get the total number of tokens appended (logical length, not buffer size)
113    #[inline]
114    pub fn len(&self) -> usize {
115        self.total_tokens
116    }
117
118    /// Clear the sequence
119    pub fn clear(&mut self) {
120        self.token_ids.clear();
121        self.total_tokens = 0;
122        self.prefix_index = 0;
123        self.cached_prefix.clear();
124    }
125
126    /// Append text to the sequence by encoding it.
127    ///
128    /// WARNING: Do not mix `append_text()` and `append_token()` on the same
129    /// instance. `append_text()` does not invalidate the incremental decode
130    /// cache (`cached_prefix`/`prefix_index`), so subsequent `append_token()`
131    /// calls would diff against stale state.
132    ///
133    /// Set `add_special_tokens` to `true` for embeddings, or `false` for chat completion
134    /// where the chat template already handles special tokens.
135    pub fn append_text(&mut self, input: &str, add_special_tokens: bool) -> Result<()> {
136        let encoding = self.tokenizer.encode(input, add_special_tokens)?;
137        let ids = encoding.token_ids();
138        self.token_ids.extend(ids);
139        self.total_tokens += ids.len();
140        Ok(())
141    }
142
143    /// Append a single token to the sequence and return newly decoded text.
144    ///
145    /// Delegates to `Decoder::decode_step` on the tokenizer trait. For HuggingFace
146    /// tokenizers this uses the native `step_decode_stream`; other backends use the
147    /// default double-decode fallback. Both paths handle token draining and prefix
148    /// caching internally.
149    #[inline]
150    pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
151        let result = self.tokenizer.decode_step(
152            token_id,
153            &mut self.token_ids,
154            &mut self.cached_prefix,
155            &mut self.prefix_index,
156            self.skip_special_tokens,
157        )?;
158        self.total_tokens += 1;
159        match result {
160            Some(text) => Ok(text),
161            None => Ok(String::new()),
162        }
163    }
164
165    /// Get a reference to the tokenizer
166    #[inline]
167    pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
168        &self.tokenizer
169    }
170
171    /// Get the current token ids in the buffer (sliding window, not full history)
172    #[inline]
173    pub fn token_ids(&self) -> &[TokenIdType] {
174        &self.token_ids
175    }
176
177    /// Decode the current buffer to text.
178    ///
179    /// WARNING: after `append_token()` calls, this only decodes the sliding
180    /// window (retained tokens), not the full sequence history. Use the
181    /// incremental return values from `append_token()` to build the full text.
182    pub fn text(&self) -> Result<String> {
183        self.tokenizer
184            .decode(&self.token_ids, self.skip_special_tokens)
185    }
186
187    /// Get whether special tokens are skipped during decoding
188    #[inline]
189    pub fn skip_special_tokens(&self) -> bool {
190        self.skip_special_tokens
191    }
192}
193
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// Build an empty sequence backed by a fresh mock tokenizer.
    fn mock_sequence() -> Sequence {
        let tokenizer = Arc::new(MockTokenizer::new());
        Sequence::new(tokenizer)
    }

    #[test]
    fn test_sequence_new() {
        let sequence = mock_sequence();

        assert!(sequence.is_empty());
        assert_eq!(sequence.len(), 0);
    }

    #[test]
    fn test_sequence_append_text() {
        let mut sequence = mock_sequence();

        sequence.append_text("Hello", false).unwrap();

        assert!(!sequence.is_empty());
        assert_eq!(sequence.text().unwrap(), "Hello");
    }

    #[test]
    fn test_sequence_append_token() {
        let mut sequence = mock_sequence();

        // Token 1 ("Hello") appended to an empty sequence decodes on its own.
        let first = sequence.append_token(1).unwrap();
        assert_eq!(first, "Hello");

        // Token 2 ("world"): the mock decodes [1, 2] as "Hello world", so the
        // incremental piece carries the separating space.
        let second = sequence.append_token(2).unwrap();
        assert_eq!(second, " world");
    }

    #[test]
    fn test_sequence_clear() {
        let mut sequence = mock_sequence();
        sequence.append_text("Hello world", false).unwrap();
        assert!(!sequence.is_empty());

        sequence.clear();

        assert!(sequence.is_empty());
        assert_eq!(sequence.len(), 0);
    }

    #[test]
    fn test_sequence_debug() {
        let mut sequence = mock_sequence();
        sequence.append_text("Test", false).unwrap();

        let rendered = format!("{sequence:?}");

        assert!(rendered.contains("Sequence"));
        assert!(rendered.contains("total_tokens"));
    }

    #[test]
    fn test_sequence_token_drain() {
        // The token buffer must stay bounded even after many appends.
        let mut sequence = mock_sequence();

        // Feed 100 tokens, remembering both the ids and the streamed text.
        let mut fed_ids = Vec::new();
        let mut streamed = String::new();
        for step in 0..100 {
            // Cycle through the mock tokenizer's token ids.
            let id = (step % 5) + 1;
            fed_ids.push(id);
            streamed.push_str(&sequence.append_token(id).unwrap());
        }

        // The logical length counts every appended token.
        assert_eq!(sequence.len(), 100);

        // Draining should leave far fewer than 100 ids in the buffer.
        let buffered = sequence.token_ids().len();
        assert!(
            buffered < 100,
            "Token buffer should be drained, but has {} entries",
            buffered
        );

        // The streamed pieces must add up to a one-shot decode of all ids.
        let full_decode = sequence.tokenizer().decode(&fed_ids, false).unwrap();
        assert_eq!(
            streamed, full_decode,
            "Drained incremental output must match full decode"
        );
    }
}