1use std::{collections::HashSet, sync::Arc};
2
3use anyhow::Result;
4
5use crate::{
6 sequence::Sequence,
7 traits::{self, TokenIdType},
8};
9
/// Outcome of feeding a single token to a [`StopSequenceDecoder`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SequenceDecoderOutput {
    /// Text that has been released and can be forwarded to the caller.
    Text(String),
    /// No text was released for this token (it produced none, or all of it is
    /// being held back as a possible stop-sequence prefix).
    Held,
    /// Decoding has stopped and there is no remaining text to emit.
    Stopped,
    /// Decoding has stopped; the payload is the final text to emit.
    StoppedWithText(String),
}
22
23#[derive(Debug, Clone, Default)]
25pub struct StopSequenceConfig {
26 pub stop_tokens: HashSet<TokenIdType>,
28 pub stop_sequences: Vec<String>,
30 pub visible_stop_tokens: HashSet<TokenIdType>,
32 pub visible_stop_sequences: Vec<String>,
34}
35
36impl StopSequenceConfig {
37 pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
39 self.stop_tokens.insert(token_id);
40 self
41 }
42
43 pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
45 self.stop_sequences.push(sequence.into());
46 self
47 }
48
49 pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
51 self.visible_stop_tokens.insert(token_id);
52 self
53 }
54
55 pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
57 self.visible_stop_sequences.push(sequence.into());
58 self
59 }
60}
61
62pub struct StopSequenceDecoder {
64 sequence: Sequence,
66 config: StopSequenceConfig,
67 jail_buffer: String,
69 stopped: bool,
71}
72
73impl StopSequenceDecoder {
74 pub fn new(
76 tokenizer: Arc<dyn traits::Tokenizer>,
77 config: StopSequenceConfig,
78 skip_special_tokens: bool,
79 ) -> Self {
80 StopSequenceDecoder {
81 sequence: Sequence::new_with_options(tokenizer, skip_special_tokens),
82 config,
83 jail_buffer: String::new(),
84 stopped: false,
85 }
86 }
87
88 pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
90 if self.stopped {
91 return Ok(SequenceDecoderOutput::Stopped);
92 }
93
94 if self.config.stop_tokens.contains(&token_id) {
96 self.stopped = true;
97
98 if !self.jail_buffer.is_empty() {
100 return Ok(SequenceDecoderOutput::StoppedWithText(std::mem::take(
101 &mut self.jail_buffer,
102 )));
103 }
104 return Ok(SequenceDecoderOutput::Stopped);
105 }
106
107 if self.config.visible_stop_tokens.contains(&token_id) {
108 self.stopped = true;
109
110 let stop_text = self
112 .sequence
113 .tokenizer()
114 .decode(&[token_id], self.sequence.skip_special_tokens())?;
115 let output = format!("{}{}", self.jail_buffer, stop_text);
116 self.jail_buffer.clear();
117 return Ok(SequenceDecoderOutput::StoppedWithText(output));
118 }
119
120 let new_text = self.sequence.append_token(token_id)?;
122
123 self.jail_buffer.push_str(&new_text);
124
125 for stop_seq in &self.config.stop_sequences {
127 if let Some(pos) = self.jail_buffer.find(stop_seq) {
128 self.stopped = true;
129 let output = self.jail_buffer[..pos].to_string();
130 self.jail_buffer.clear();
131 return Ok(if output.is_empty() {
132 SequenceDecoderOutput::Stopped
133 } else {
134 SequenceDecoderOutput::StoppedWithText(output)
135 });
136 }
137 }
138
139 for stop_seq in &self.config.visible_stop_sequences {
141 if let Some(pos) = self.jail_buffer.find(stop_seq) {
142 self.stopped = true;
143 let end_pos = pos + stop_seq.len();
144 let output = self.jail_buffer[..end_pos].to_string();
145 self.jail_buffer.clear();
146 return Ok(SequenceDecoderOutput::StoppedWithText(output));
147 }
148 }
149
150 let buffer_len = self.jail_buffer.len();
153 let mut best_split_pos: Option<usize> = None;
154
155 for stop_seq in self
156 .config
157 .stop_sequences
158 .iter()
159 .chain(&self.config.visible_stop_sequences)
160 {
161 let stop_len = stop_seq.len();
162
163 if stop_len <= 1 || buffer_len == 0 {
164 continue;
165 }
166
167 let max_len = buffer_len.min(stop_len - 1);
168
169 for len in (1..=max_len).rev() {
170 let suffix_start = buffer_len - len;
171
172 if !self.jail_buffer.is_char_boundary(suffix_start) {
173 continue;
174 }
175
176 let suffix = &self.jail_buffer[suffix_start..];
177
178 if stop_seq.starts_with(suffix)
179 && best_split_pos.is_none_or(|current| suffix_start < current)
180 {
181 best_split_pos = Some(suffix_start);
182 break;
183 }
184 }
185 }
186
187 if let Some(split_pos) = best_split_pos {
188 let suffix = self.jail_buffer.split_off(split_pos);
192 let to_output = std::mem::replace(&mut self.jail_buffer, suffix);
193
194 if to_output.is_empty() {
195 Ok(SequenceDecoderOutput::Held)
196 } else {
197 Ok(SequenceDecoderOutput::Text(to_output))
198 }
199 } else {
200 let output = std::mem::take(&mut self.jail_buffer);
202 if output.is_empty() {
203 Ok(SequenceDecoderOutput::Held)
204 } else {
205 Ok(SequenceDecoderOutput::Text(output))
206 }
207 }
208 }
209
210 pub fn process_tokens(
212 &mut self,
213 token_ids: &[TokenIdType],
214 ) -> Result<Vec<SequenceDecoderOutput>> {
215 let mut outputs = Vec::with_capacity(token_ids.len());
217 for &token_id in token_ids {
218 outputs.push(self.process_token(token_id)?);
219 }
220 Ok(outputs)
221 }
222
223 pub fn flush(&mut self) -> SequenceDecoderOutput {
225 if !self.jail_buffer.is_empty() {
226 SequenceDecoderOutput::Text(std::mem::take(&mut self.jail_buffer))
228 } else {
229 SequenceDecoderOutput::Text(String::new())
230 }
231 }
232
233 pub fn is_stopped(&self) -> bool {
235 self.stopped
236 }
237
238 pub fn reset(&mut self) {
240 self.jail_buffer.clear();
241 self.sequence.clear();
242 self.stopped = false;
243 }
244}
245
246pub struct StopSequenceDecoderBuilder {
248 tokenizer: Arc<dyn traits::Tokenizer>,
249 config: StopSequenceConfig,
250 skip_special_tokens: bool,
251}
252
253impl StopSequenceDecoderBuilder {
254 pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
255 StopSequenceDecoderBuilder {
256 tokenizer,
257 config: StopSequenceConfig::default(),
258 skip_special_tokens: true,
259 }
260 }
261
262 pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
263 self.config.stop_tokens.insert(token_id);
264 self
265 }
266
267 pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
268 self.config.stop_sequences.push(sequence.into());
269 self
270 }
271
272 pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
273 self.config.visible_stop_tokens.insert(token_id);
274 self
275 }
276
277 pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
278 self.config.visible_stop_sequences.push(sequence.into());
279 self
280 }
281
282 pub fn skip_special_tokens(mut self, skip: bool) -> Self {
283 self.skip_special_tokens = skip;
284 self
285 }
286
287 pub fn build(self) -> StopSequenceDecoder {
288 StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
289 }
290}
291
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::StopSequenceDecoderBuilder;
    use crate::{
        mock::MockTokenizer, SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder,
    };

    /// Builds a decoder over a fresh `MockTokenizer` with `skip_special_tokens`
    /// disabled, which is the setup shared by most tests below.
    fn decoder_with(config: StopSequenceConfig) -> StopSequenceDecoder {
        let tokenizer = Arc::new(MockTokenizer::new());
        StopSequenceDecoder::new(tokenizer, config, false)
    }

    /// Shorthand for a decoder configured with a single hidden stop sequence.
    fn decoder_stopping_on(stop_sequence: &str) -> StopSequenceDecoder {
        decoder_with(StopSequenceConfig::default().with_stop_sequence(stop_sequence))
    }

    #[test]
    fn test_stop_token_detection() {
        let mut decoder = decoder_with(StopSequenceConfig::default().with_stop_token(999));

        let before_stop = decoder.process_token(1).unwrap();
        assert!(matches!(before_stop, SequenceDecoderOutput::Text(_)));

        let at_stop = decoder.process_token(999).unwrap();
        assert_eq!(at_stop, SequenceDecoderOutput::Stopped);

        // Once stopped, further tokens keep reporting `Stopped`.
        let after_stop = decoder.process_token(2).unwrap();
        assert_eq!(after_stop, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_token() {
        let mut decoder =
            decoder_with(StopSequenceConfig::default().with_visible_stop_token(999));

        let result = decoder.process_token(999).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
    }

    #[test]
    fn test_builder_pattern() {
        let tokenizer = Arc::new(MockTokenizer::new());

        let decoder = StopSequenceDecoderBuilder::new(tokenizer)
            .stop_token(999)
            .stop_sequence("STOP")
            .visible_stop_token(1000)
            .skip_special_tokens(true)
            .build();

        assert!(!decoder.is_stopped());
    }

    #[test]
    fn test_incremental_decoding_no_repetition() {
        let mut decoder = decoder_with(StopSequenceConfig::default());

        // Collect only the `Text` payloads produced for tokens 1, 2 and 3.
        let outputs: Vec<String> = (1..=3)
            .filter_map(|token_id| match decoder.process_token(token_id).unwrap() {
                SequenceDecoderOutput::Text(text) => Some(text),
                _ => None,
            })
            .collect();

        assert_eq!(outputs.len(), 3);

        // No later chunk may contain an earlier chunk.
        for (index, earlier) in outputs.iter().enumerate() {
            for later in &outputs[index + 1..] {
                assert!(!later.contains(earlier.as_str()));
            }
        }
    }

    #[test]
    fn test_stop_sequence_detection() {
        let mut decoder = decoder_stopping_on("test");

        decoder.process_token(1).unwrap();
        decoder.process_token(2).unwrap();
        let result = decoder.process_token(3).unwrap();

        assert!(matches!(
            result,
            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
        ));
    }

    #[test]
    fn test_flush_after_partial() {
        let mut decoder = decoder_stopping_on("NEVER_MATCH");

        decoder.process_token(1).unwrap();
        let flushed = decoder.flush();

        assert!(matches!(flushed, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_reset_functionality() {
        let mut decoder = decoder_with(StopSequenceConfig::default().with_stop_token(999));

        decoder.process_token(1).unwrap();
        decoder.process_token(999).unwrap();
        assert!(decoder.is_stopped());

        decoder.reset();
        assert!(!decoder.is_stopped());

        // After a reset the decoder emits text again.
        let result = decoder.process_token(2).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_visible_stop_sequence() {
        let mut decoder =
            decoder_with(StopSequenceConfig::default().with_visible_stop_sequence("world"));

        decoder.process_token(1).unwrap();
        let result = decoder.process_token(2).unwrap();

        let SequenceDecoderOutput::StoppedWithText(text) = result else {
            panic!("Expected StoppedWithText with visible stop sequence");
        };
        assert!(text.contains("world"));
    }

    #[test]
    fn test_multiple_tokens_processing() {
        let mut decoder = decoder_with(StopSequenceConfig::default());

        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();
        assert_eq!(results.len(), 3);

        for result in results {
            assert!(matches!(
                result,
                SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
            ));
        }
    }

    // The UTF-8 tests below only require that processing does not fail (or
    // panic on a char boundary) when the stop sequence has multi-byte characters.

    #[test]
    fn test_utf8_multibyte_character_boundaries() {
        let mut decoder = decoder_stopping_on(" ×");

        for token_id in 1..=2 {
            assert!(decoder.process_token(token_id).is_ok());
        }
    }

    #[test]
    fn test_utf8_multibyte_delta_character() {
        let mut decoder = decoder_stopping_on("Δ");

        for token_id in 1..=2 {
            assert!(decoder.process_token(token_id).is_ok());
        }
    }

    #[test]
    fn test_utf8_multibyte_degree_character() {
        let mut decoder = decoder_stopping_on("°");

        for token_id in 1..=2 {
            assert!(decoder.process_token(token_id).is_ok());
        }
    }

    #[test]
    fn test_utf8_multibyte_triangle_character() {
        let mut decoder = decoder_stopping_on(" (∆");

        for token_id in 1..=3 {
            assert!(decoder.process_token(token_id).is_ok());
        }
    }

    #[test]
    fn test_utf8_multibyte_en_dash_character() {
        let mut decoder = decoder_stopping_on(" –");

        for token_id in 1..=3 {
            assert!(decoder.process_token(token_id).is_ok());
        }
    }

    #[test]
    fn test_utf8_multibyte_various_characters() {
        const CASES: &[(&str, &str)] = &[
            ("×", "multiplication sign - 2 bytes"),
            ("Δ", "Greek Delta - 2 bytes"),
            ("°", "degree sign - 2 bytes"),
            ("∆", "increment - 3 bytes"),
            ("–", "en dash - 3 bytes"),
            ("€", "euro sign - 3 bytes"),
            ("中", "Chinese character - 3 bytes"),
            ("🚀", "rocket emoji - 4 bytes"),
            ("💡", "lightbulb emoji - 4 bytes"),
        ];

        for &(stop_char, description) in CASES {
            let mut decoder = decoder_stopping_on(stop_char);

            for token_id in 1..=5 {
                let result = decoder.process_token(token_id);
                assert!(
                    result.is_ok(),
                    "Failed on {} with token {}",
                    description,
                    token_id
                );
            }
        }
    }
}