llm_tokenizer/
sequence.rs1use std::sync::Arc;
2
3use anyhow::Result;
4
5use crate::traits::{TokenIdType, Tokenizer as TokenizerTrait};
6
7pub struct Sequence {
19 tokenizer: Arc<dyn TokenizerTrait>,
21
22 token_ids: Vec<TokenIdType>,
25
26 total_tokens: usize,
29
30 prefix_index: usize,
33
34 cached_prefix: String,
37
38 skip_special_tokens: bool,
40}
41
42impl std::fmt::Debug for Sequence {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("Sequence")
45 .field("tokenizer", &"Arc<dyn Tokenizer>")
46 .field(
47 "token_ids",
48 &format_args!("{}", {
49 let token_ids = &self.token_ids;
50 if token_ids.len() <= 20 {
51 format!("{token_ids:?}")
52 } else {
53 let first_ten = &token_ids[..10];
54 let last_ten = &token_ids[token_ids.len() - 10..];
55 format!("{first_ten:?} ... {last_ten:?}")
56 }
57 }),
58 )
59 .field("prefix_index", &self.prefix_index)
60 .field("buffer_len", &self.token_ids.len())
61 .field("total_tokens", &self.total_tokens)
62 .finish()
63 }
64}
65
66impl Sequence {
67 pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
69 Self::new_with_options(tokenizer, false)
70 }
71
72 pub fn new_with_options(tokenizer: Arc<dyn TokenizerTrait>, skip_special_tokens: bool) -> Self {
74 Self {
75 tokenizer,
76 token_ids: Vec::new(),
77 total_tokens: 0,
78 prefix_index: 0,
79 cached_prefix: String::new(),
80 skip_special_tokens,
81 }
82 }
83
84 pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
86 Self::with_tokens_and_options(tokenizer, token_ids, false)
87 }
88
89 pub fn with_tokens_and_options(
91 tokenizer: Arc<dyn TokenizerTrait>,
92 token_ids: Vec<TokenIdType>,
93 skip_special_tokens: bool,
94 ) -> Self {
95 let len = token_ids.len();
96 Self {
97 tokenizer,
98 token_ids,
99 total_tokens: len,
100 prefix_index: 0,
101 cached_prefix: String::new(),
102 skip_special_tokens,
103 }
104 }
105
106 #[inline]
108 pub fn is_empty(&self) -> bool {
109 self.total_tokens == 0
110 }
111
112 #[inline]
114 pub fn len(&self) -> usize {
115 self.total_tokens
116 }
117
118 pub fn clear(&mut self) {
120 self.token_ids.clear();
121 self.total_tokens = 0;
122 self.prefix_index = 0;
123 self.cached_prefix.clear();
124 }
125
126 pub fn append_text(&mut self, input: &str, add_special_tokens: bool) -> Result<()> {
136 let encoding = self.tokenizer.encode(input, add_special_tokens)?;
137 let ids = encoding.token_ids();
138 self.token_ids.extend(ids);
139 self.total_tokens += ids.len();
140 Ok(())
141 }
142
143 #[inline]
150 pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
151 let result = self.tokenizer.decode_step(
152 token_id,
153 &mut self.token_ids,
154 &mut self.cached_prefix,
155 &mut self.prefix_index,
156 self.skip_special_tokens,
157 )?;
158 self.total_tokens += 1;
159 match result {
160 Some(text) => Ok(text),
161 None => Ok(String::new()),
162 }
163 }
164
165 #[inline]
167 pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
168 &self.tokenizer
169 }
170
171 #[inline]
173 pub fn token_ids(&self) -> &[TokenIdType] {
174 &self.token_ids
175 }
176
177 pub fn text(&self) -> Result<String> {
183 self.tokenizer
184 .decode(&self.token_ids, self.skip_special_tokens)
185 }
186
187 #[inline]
189 pub fn skip_special_tokens(&self) -> bool {
190 self.skip_special_tokens
191 }
192}
193
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    #[test]
    fn test_sequence_new() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let seq = Sequence::new(tokenizer);
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
        assert!(seq.token_ids().is_empty());
        // `new` defaults to keeping special tokens.
        assert!(!seq.skip_special_tokens());
    }

    #[test]
    fn test_sequence_with_tokens() {
        let tokenizer = Arc::new(MockTokenizer::new());

        let seq = Sequence::with_tokens(tokenizer.clone(), vec![1, 2]);
        assert!(!seq.is_empty());
        assert_eq!(seq.len(), 2);
        assert_eq!(seq.token_ids(), &[1, 2]);
        assert!(!seq.skip_special_tokens());

        let seq = Sequence::with_tokens_and_options(tokenizer, vec![1, 2], true);
        assert_eq!(seq.len(), 2);
        assert!(seq.skip_special_tokens());
    }

    #[test]
    fn test_sequence_append_text() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Hello", false).unwrap();
        assert!(!seq.is_empty());
        // Text appends never drain, so the buffer holds every counted token.
        assert_eq!(seq.len(), seq.token_ids().len());

        let text = seq.text().unwrap();
        assert_eq!(text, "Hello");
    }

    #[test]
    fn test_sequence_append_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        let text1 = seq.append_token(1).unwrap();
        assert_eq!(text1, "Hello");

        let text2 = seq.append_token(2).unwrap();
        assert_eq!(text2, " world");

        assert_eq!(seq.len(), 2);
    }

    #[test]
    fn test_sequence_clear() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Hello world", false).unwrap();
        assert!(!seq.is_empty());

        seq.clear();
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
        assert!(seq.token_ids().is_empty());
    }

    #[test]
    fn test_sequence_clear_resets_decode_state() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        assert_eq!(seq.append_token(1).unwrap(), "Hello");
        assert_eq!(seq.append_token(2).unwrap(), " world");

        seq.clear();
        assert!(seq.is_empty());
        assert!(seq.token_ids().is_empty());

        // A cleared sequence must decode exactly like a freshly created one.
        assert_eq!(seq.append_token(1).unwrap(), "Hello");
        assert_eq!(seq.len(), 1);
    }

    #[test]
    fn test_sequence_debug() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Test", false).unwrap();
        let debug_str = format!("{seq:?}");
        assert!(debug_str.contains("Sequence"));
        assert!(debug_str.contains("total_tokens"));
    }

    #[test]
    fn test_sequence_token_drain() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        let mut output = String::new();
        let mut all_token_ids = Vec::with_capacity(100);
        for i in 0..100 {
            let token_id = (i % 5) + 1;
            all_token_ids.push(token_id);
            let text = seq.append_token(token_id).unwrap();
            output.push_str(&text);
        }

        // Every appended token is counted, even those drained from the buffer.
        assert_eq!(seq.len(), 100);

        assert!(
            seq.token_ids().len() < 100,
            "Token buffer should be drained, but has {} entries",
            seq.token_ids().len()
        );

        let expected = seq.tokenizer().decode(&all_token_ids, false).unwrap();
        assert_eq!(
            output, expected,
            "Drained incremental output must match full decode"
        );
    }
}