Skip to main content

llm_tokenizer/
sequence.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4
5use crate::traits::{TokenIdType, Tokenizer as TokenizerTrait};
6
7/// Maintains state for an ongoing sequence of tokens and their decoded text
8/// This provides a cleaner abstraction for managing token sequences
9pub struct Sequence {
10    /// The tokenizer used for encoding/decoding
11    tokenizer: Arc<dyn TokenizerTrait>,
12
13    /// The current sequence of token ids
14    token_ids: Vec<TokenIdType>,
15
16    /// The position in the current sequence the last decoded token completed
17    prefix_offset: usize,
18
19    /// Current position in the sequence
20    read_offset: usize,
21
22    /// Whether to skip special tokens when decoding
23    skip_special_tokens: bool,
24}
25
26impl std::fmt::Debug for Sequence {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("Sequence")
29            .field("tokenizer", &"Arc<dyn Tokenizer>")
30            .field(
31                "token_ids",
32                &format_args!("{}", {
33                    let token_ids = self.token_ids();
34                    if token_ids.len() <= 20 {
35                        format!("{token_ids:?}")
36                    } else {
37                        let first_ten = &token_ids[..10];
38                        let last_ten = &token_ids[token_ids.len() - 10..];
39                        format!("{first_ten:?} ... {last_ten:?}")
40                    }
41                }),
42            )
43            .field("prefix_offset", &self.prefix_offset)
44            .field("read_offset", &self.read_offset)
45            .field("token count", &self.token_ids.len())
46            .finish()
47    }
48}
49
50impl Sequence {
51    /// Create a new empty sequence
52    pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
53        Self::new_with_options(tokenizer, false)
54    }
55
56    /// Create a new empty sequence with skip_special_tokens option
57    pub fn new_with_options(tokenizer: Arc<dyn TokenizerTrait>, skip_special_tokens: bool) -> Self {
58        Self {
59            tokenizer,
60            token_ids: Vec::new(),
61            prefix_offset: 0,
62            read_offset: 0,
63            skip_special_tokens,
64        }
65    }
66
67    /// Create a sequence with initial tokens
68    pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
69        Self::with_tokens_and_options(tokenizer, token_ids, false)
70    }
71
72    /// Create a sequence with initial tokens and skip_special_tokens option
73    pub fn with_tokens_and_options(
74        tokenizer: Arc<dyn TokenizerTrait>,
75        token_ids: Vec<TokenIdType>,
76        skip_special_tokens: bool,
77    ) -> Self {
78        let len = token_ids.len();
79        Self {
80            tokenizer,
81            token_ids,
82            prefix_offset: 0,
83            read_offset: len,
84            skip_special_tokens,
85        }
86    }
87
88    /// Check if the sequence is empty
89    #[inline]
90    pub fn is_empty(&self) -> bool {
91        self.token_ids.is_empty()
92    }
93
94    /// Get the length of the sequence
95    #[inline]
96    pub fn len(&self) -> usize {
97        self.token_ids.len()
98    }
99
100    /// Clear the sequence
101    pub fn clear(&mut self) {
102        self.token_ids.clear();
103        self.prefix_offset = 0;
104        self.read_offset = 0;
105    }
106
107    /// Append text to the sequence by encoding it
108    ///
109    /// Set `add_special_tokens` to `true` for embeddings, or `false` for chat completion
110    /// where the chat template already handles special tokens.
111    pub fn append_text(&mut self, input: &str, add_special_tokens: bool) -> Result<()> {
112        let encoding = self.tokenizer.encode(input, add_special_tokens)?;
113        self.token_ids.extend(encoding.token_ids());
114        Ok(())
115    }
116
117    /// Append a single token to the sequence and return newly decoded text
118    /// Based on HuggingFace TGI incremental decoding
119    #[inline]
120    pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
121        // Store the old read offset before adding the new token
122        let old_read_offset = self.read_offset;
123
124        self.token_ids.push(token_id);
125        self.read_offset = self.token_ids.len();
126
127        // If this is the first token or we're at the beginning, decode everything
128        if self.prefix_offset == 0 && old_read_offset == 0 {
129            let text = self
130                .tokenizer
131                .decode(&self.token_ids, self.skip_special_tokens)?;
132            if text.ends_with("�") {
133                // Incomplete UTF-8 sequence, wait for more tokens
134                return Ok(String::new());
135            }
136            self.prefix_offset = 0;
137            return Ok(text);
138        }
139
140        // Decode the text up to the previous position
141        let prefix_text = self.tokenizer.decode(
142            &self.token_ids[self.prefix_offset..old_read_offset],
143            self.skip_special_tokens,
144        )?;
145
146        // Decode the text including the new token
147        let new_text = self.tokenizer.decode(
148            &self.token_ids[self.prefix_offset..],
149            self.skip_special_tokens,
150        )?;
151
152        // Handle multi-byte character boundaries
153        let mut prefix_text_len = prefix_text.len();
154        while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
155            prefix_text_len -= 1;
156        }
157
158        if new_text.len() > prefix_text.len() {
159            if new_text.ends_with("�") {
160                // Incomplete UTF-8 sequence, wait for more tokens
161                return Ok(String::new());
162            } else {
163                // Return the new text portion
164                let incremental_text = new_text[prefix_text_len..].to_string().replace("�", "");
165                self.prefix_offset = old_read_offset;
166                return Ok(incremental_text);
167            }
168        }
169
170        Ok(String::new())
171    }
172
173    /// Get a reference to the tokenizer
174    #[inline]
175    pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
176        &self.tokenizer
177    }
178
179    /// Get the current token ids
180    #[inline]
181    pub fn token_ids(&self) -> &[TokenIdType] {
182        &self.token_ids
183    }
184
185    /// Decode the entire sequence to text
186    pub fn text(&self) -> Result<String> {
187        self.tokenizer
188            .decode(&self.token_ids, self.skip_special_tokens)
189    }
190
191    /// Get the prefix offset
192    #[inline]
193    pub fn prefix_offset(&self) -> usize {
194        self.prefix_offset
195    }
196
197    /// Get the read offset
198    #[inline]
199    pub fn read_offset(&self) -> usize {
200        self.read_offset
201    }
202
203    /// Get whether special tokens are skipped during decoding
204    #[inline]
205    pub fn skip_special_tokens(&self) -> bool {
206        self.skip_special_tokens
207    }
208}
209
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// A freshly constructed sequence holds no tokens.
    #[test]
    fn test_sequence_new() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let seq = Sequence::new(tokenizer);
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
    }

    /// The options constructor stores the `skip_special_tokens` flag, and
    /// `new` defaults it to `false`.
    #[test]
    fn test_sequence_new_with_options() {
        let keeping = Sequence::new(Arc::new(MockTokenizer::new()));
        assert!(!keeping.skip_special_tokens());

        let skipping = Sequence::new_with_options(Arc::new(MockTokenizer::new()), true);
        assert!(skipping.skip_special_tokens());
        assert!(skipping.is_empty());
    }

    /// Initial tokens are stored as given and count as already read.
    #[test]
    fn test_sequence_with_tokens() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let seq = Sequence::with_tokens(tokenizer, vec![1, 2]);

        assert!(!seq.is_empty());
        assert_eq!(seq.len(), 2);
        assert_eq!(seq.token_ids(), &[1, 2][..]);
        assert_eq!(seq.prefix_offset(), 0);
        assert_eq!(seq.read_offset(), 2);
    }

    /// Appended text is encoded into tokens and decodes back unchanged.
    #[test]
    fn test_sequence_append_text() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Hello", false).unwrap();
        assert!(!seq.is_empty());

        let text = seq.text().unwrap();
        assert_eq!(text, "Hello");
    }

    /// Each appended token returns only the text it adds.
    #[test]
    fn test_sequence_append_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        // Start with an empty sequence and append token 1 ("Hello")
        let text1 = seq.append_token(1).unwrap();
        assert_eq!(text1, "Hello");

        // Now append token 2 ("world")
        // The mock tokenizer will decode [1, 2] as "Hello world" (with a space)
        let text2 = seq.append_token(2).unwrap();
        // The incremental text should be " world" (with the space that the mock tokenizer adds)
        assert_eq!(text2, " world");

        assert_eq!(seq.text().unwrap(), "Hello world");
    }

    /// `clear` drops all tokens and resets both offsets.
    #[test]
    fn test_sequence_clear() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Hello world", false).unwrap();
        assert!(!seq.is_empty());

        seq.clear();
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
        assert_eq!(seq.prefix_offset(), 0);
        assert_eq!(seq.read_offset(), 0);
    }

    /// The `Debug` output names the type and includes the token count field.
    #[test]
    fn test_sequence_debug() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Test", false).unwrap();
        let debug_str = format!("{seq:?}");
        assert!(debug_str.contains("Sequence"));
        assert!(debug_str.contains("token count"));
    }
}