llm_tokenizer/
sequence.rs1use std::sync::Arc;
2
3use anyhow::Result;
4
5use crate::traits::{TokenIdType, Tokenizer as TokenizerTrait};
6
7pub struct Sequence {
10 tokenizer: Arc<dyn TokenizerTrait>,
12
13 token_ids: Vec<TokenIdType>,
15
16 prefix_offset: usize,
18
19 read_offset: usize,
21
22 skip_special_tokens: bool,
24}
25
26impl std::fmt::Debug for Sequence {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("Sequence")
29 .field("tokenizer", &"Arc<dyn Tokenizer>")
30 .field(
31 "token_ids",
32 &format_args!("{}", {
33 let token_ids = self.token_ids();
34 if token_ids.len() <= 20 {
35 format!("{token_ids:?}")
36 } else {
37 let first_ten = &token_ids[..10];
38 let last_ten = &token_ids[token_ids.len() - 10..];
39 format!("{first_ten:?} ... {last_ten:?}")
40 }
41 }),
42 )
43 .field("prefix_offset", &self.prefix_offset)
44 .field("read_offset", &self.read_offset)
45 .field("token count", &self.token_ids.len())
46 .finish()
47 }
48}
49
50impl Sequence {
51 pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
53 Self::new_with_options(tokenizer, false)
54 }
55
56 pub fn new_with_options(tokenizer: Arc<dyn TokenizerTrait>, skip_special_tokens: bool) -> Self {
58 Self {
59 tokenizer,
60 token_ids: Vec::new(),
61 prefix_offset: 0,
62 read_offset: 0,
63 skip_special_tokens,
64 }
65 }
66
67 pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
69 Self::with_tokens_and_options(tokenizer, token_ids, false)
70 }
71
72 pub fn with_tokens_and_options(
74 tokenizer: Arc<dyn TokenizerTrait>,
75 token_ids: Vec<TokenIdType>,
76 skip_special_tokens: bool,
77 ) -> Self {
78 let len = token_ids.len();
79 Self {
80 tokenizer,
81 token_ids,
82 prefix_offset: 0,
83 read_offset: len,
84 skip_special_tokens,
85 }
86 }
87
88 #[inline]
90 pub fn is_empty(&self) -> bool {
91 self.token_ids.is_empty()
92 }
93
94 #[inline]
96 pub fn len(&self) -> usize {
97 self.token_ids.len()
98 }
99
100 pub fn clear(&mut self) {
102 self.token_ids.clear();
103 self.prefix_offset = 0;
104 self.read_offset = 0;
105 }
106
107 pub fn append_text(&mut self, input: &str, add_special_tokens: bool) -> Result<()> {
112 let encoding = self.tokenizer.encode(input, add_special_tokens)?;
113 self.token_ids.extend(encoding.token_ids());
114 Ok(())
115 }
116
117 #[inline]
120 pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
121 let old_read_offset = self.read_offset;
123
124 self.token_ids.push(token_id);
125 self.read_offset = self.token_ids.len();
126
127 if self.prefix_offset == 0 && old_read_offset == 0 {
129 let text = self
130 .tokenizer
131 .decode(&self.token_ids, self.skip_special_tokens)?;
132 if text.ends_with("�") {
133 return Ok(String::new());
135 }
136 self.prefix_offset = 0;
137 return Ok(text);
138 }
139
140 let prefix_text = self.tokenizer.decode(
142 &self.token_ids[self.prefix_offset..old_read_offset],
143 self.skip_special_tokens,
144 )?;
145
146 let new_text = self.tokenizer.decode(
148 &self.token_ids[self.prefix_offset..],
149 self.skip_special_tokens,
150 )?;
151
152 let mut prefix_text_len = prefix_text.len();
154 while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
155 prefix_text_len -= 1;
156 }
157
158 if new_text.len() > prefix_text.len() {
159 if new_text.ends_with("�") {
160 return Ok(String::new());
162 } else {
163 let incremental_text = new_text[prefix_text_len..].to_string().replace("�", "");
165 self.prefix_offset = old_read_offset;
166 return Ok(incremental_text);
167 }
168 }
169
170 Ok(String::new())
171 }
172
173 #[inline]
175 pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
176 &self.tokenizer
177 }
178
179 #[inline]
181 pub fn token_ids(&self) -> &[TokenIdType] {
182 &self.token_ids
183 }
184
185 pub fn text(&self) -> Result<String> {
187 self.tokenizer
188 .decode(&self.token_ids, self.skip_special_tokens)
189 }
190
191 #[inline]
193 pub fn prefix_offset(&self) -> usize {
194 self.prefix_offset
195 }
196
197 #[inline]
199 pub fn read_offset(&self) -> usize {
200 self.read_offset
201 }
202
203 #[inline]
205 pub fn skip_special_tokens(&self) -> bool {
206 self.skip_special_tokens
207 }
208}
209
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    #[test]
    fn test_sequence_new() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let seq = Sequence::new(tokenizer);
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
        assert!(!seq.skip_special_tokens());
    }

    #[test]
    fn test_sequence_new_with_options() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let seq = Sequence::new_with_options(tokenizer, true);
        assert!(seq.is_empty());
        assert!(seq.skip_special_tokens());
    }

    #[test]
    fn test_sequence_append_text() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Hello", false).unwrap();
        assert!(!seq.is_empty());
        // Encoding text does not advance the incremental-decoding offsets.
        assert_eq!(seq.prefix_offset(), 0);
        assert_eq!(seq.read_offset(), 0);

        let text = seq.text().unwrap();
        assert_eq!(text, "Hello");
    }

    #[test]
    fn test_sequence_append_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer.clone());

        let text1 = seq.append_token(1).unwrap();
        assert_eq!(text1, "Hello");
        assert_eq!(seq.read_offset(), 1);

        let text2 = seq.append_token(2).unwrap();
        assert_eq!(text2, " world");
        assert_eq!(seq.prefix_offset(), 1);
        assert_eq!(seq.read_offset(), 2);

        assert_eq!(seq.text().unwrap(), "Hello world");
    }

    #[test]
    fn test_sequence_with_tokens() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let seq = Sequence::with_tokens(tokenizer, vec![1, 2]);

        assert_eq!(seq.len(), 2);
        assert_eq!(seq.token_ids().to_vec(), vec![1, 2]);
        assert_eq!(seq.prefix_offset(), 0);
        assert_eq!(seq.read_offset(), 2);
        assert!(!seq.skip_special_tokens());
        assert_eq!(seq.text().unwrap(), "Hello world");
    }

    #[test]
    fn test_sequence_with_tokens_returns_only_new_text() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::with_tokens(tokenizer, vec![1]);

        // The pre-filled token counts as already read.
        assert_eq!(seq.append_token(2).unwrap(), " world");
        assert_eq!(seq.text().unwrap(), "Hello world");
    }

    #[test]
    fn test_sequence_with_tokens_and_options() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let seq = Sequence::with_tokens_and_options(tokenizer, vec![1], true);
        assert_eq!(seq.len(), 1);
        assert!(seq.skip_special_tokens());
    }

    #[test]
    fn test_sequence_clear() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Hello world", false).unwrap();
        assert!(!seq.is_empty());

        seq.clear();
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
        assert_eq!(seq.prefix_offset(), 0);
        assert_eq!(seq.read_offset(), 0);
    }

    #[test]
    fn test_sequence_debug() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Test", false).unwrap();
        let debug_str = format!("{seq:?}");
        assert!(debug_str.contains("Sequence"));
        assert!(debug_str.contains("token count"));
    }
}