1use std::{collections::HashSet, sync::Arc};
2
3use aho_corasick::AhoCorasick;
4use anyhow::Result;
5
6use crate::{
7 sequence::Sequence,
8 traits::{self, TokenIdType},
9};
10
/// Result of feeding one token to a [`StopSequenceDecoder`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SequenceDecoderOutput {
    /// Text that is safe to emit now.
    Text(String),
    /// Nothing can be emitted yet; any new text is being held because it may be the
    /// beginning of a stop sequence.
    Held,
    /// Decoding has stopped and there is no further text to emit.
    Stopped,
    /// Decoding has stopped; the payload is the final text to emit (for visible stop
    /// conditions this includes the text of the stop token/sequence itself).
    StoppedWithText(String),
}
23
24#[derive(Debug, Clone, Default)]
26pub struct StopSequenceConfig {
27 pub stop_tokens: HashSet<TokenIdType>,
29 pub stop_sequences: Vec<String>,
31 pub visible_stop_tokens: HashSet<TokenIdType>,
33 pub visible_stop_sequences: Vec<String>,
35}
36
37impl StopSequenceConfig {
38 pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
40 self.stop_tokens.insert(token_id);
41 self
42 }
43
44 pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
46 self.stop_sequences.push(sequence.into());
47 self
48 }
49
50 pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
52 self.visible_stop_tokens.insert(token_id);
53 self
54 }
55
56 pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
58 self.visible_stop_sequences.push(sequence.into());
59 self
60 }
61}
62
63pub struct StopSequenceDecoder {
65 sequence: Sequence,
67 config: StopSequenceConfig,
68 aho_corasick: Option<AhoCorasick>,
70 visible_boundary_idx: usize,
73 jail_buffer: String,
75 stopped: bool,
77}
78
79impl StopSequenceDecoder {
80 pub fn new(
82 tokenizer: Arc<dyn traits::Tokenizer>,
83 config: StopSequenceConfig,
84 skip_special_tokens: bool,
85 ) -> Self {
86 let mut patterns: Vec<&str> = config
89 .stop_sequences
90 .iter()
91 .filter(|s| !s.is_empty())
92 .map(|s| s.as_str())
93 .collect();
94 let visible_boundary_idx = patterns.len();
95 patterns.extend(
96 config
97 .visible_stop_sequences
98 .iter()
99 .filter(|s| !s.is_empty())
100 .map(|s| s.as_str()),
101 );
102
103 let aho_corasick = if patterns.is_empty() {
104 None
105 } else {
106 #[expect(
109 clippy::expect_used,
110 reason = "AhoCorasick::new with pre-filtered non-empty &str patterns is practically infallible"
111 )]
112 Some(AhoCorasick::new(patterns).expect("Failed to build Aho-Corasick automaton"))
113 };
114
115 StopSequenceDecoder {
116 sequence: Sequence::new_with_options(tokenizer, skip_special_tokens),
117 config,
118 aho_corasick,
119 visible_boundary_idx,
120 jail_buffer: String::new(),
121 stopped: false,
122 }
123 }
124
125 pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
127 if self.stopped {
128 return Ok(SequenceDecoderOutput::Stopped);
129 }
130
131 if self.config.stop_tokens.contains(&token_id) {
133 self.stopped = true;
134
135 if !self.jail_buffer.is_empty() {
137 return Ok(SequenceDecoderOutput::StoppedWithText(std::mem::take(
138 &mut self.jail_buffer,
139 )));
140 }
141 return Ok(SequenceDecoderOutput::Stopped);
142 }
143
144 if self.config.visible_stop_tokens.contains(&token_id) {
145 self.stopped = true;
146
147 let stop_text = self
149 .sequence
150 .tokenizer()
151 .decode(&[token_id], self.sequence.skip_special_tokens())?;
152 let output = format!("{}{}", self.jail_buffer, stop_text);
153 self.jail_buffer.clear();
154 return Ok(SequenceDecoderOutput::StoppedWithText(output));
155 }
156
157 let new_text = self.sequence.append_token(token_id)?;
159
160 self.jail_buffer.push_str(&new_text);
161
162 if let Some(ac) = &self.aho_corasick {
164 if let Some(mat) = ac.find(&self.jail_buffer) {
165 self.stopped = true;
166 let is_visible = mat.pattern().as_usize() >= self.visible_boundary_idx;
167
168 if is_visible {
169 let output = self.jail_buffer[..mat.end()].to_string();
171 self.jail_buffer.clear();
172 return Ok(SequenceDecoderOutput::StoppedWithText(output));
173 } else {
174 let output = self.jail_buffer[..mat.start()].to_string();
176 self.jail_buffer.clear();
177 return Ok(if output.is_empty() {
178 SequenceDecoderOutput::Stopped
179 } else {
180 SequenceDecoderOutput::StoppedWithText(output)
181 });
182 }
183 }
184 }
185
186 let buffer_len = self.jail_buffer.len();
189 let mut best_split_pos: Option<usize> = None;
190
191 for stop_seq in self
192 .config
193 .stop_sequences
194 .iter()
195 .chain(&self.config.visible_stop_sequences)
196 {
197 let stop_len = stop_seq.len();
198
199 if stop_len <= 1 || buffer_len == 0 {
200 continue;
201 }
202
203 let max_len = buffer_len.min(stop_len - 1);
204
205 for len in (1..=max_len).rev() {
206 let suffix_start = buffer_len - len;
207
208 if !self.jail_buffer.is_char_boundary(suffix_start) {
209 continue;
210 }
211
212 let suffix = &self.jail_buffer[suffix_start..];
213
214 if stop_seq.starts_with(suffix)
215 && best_split_pos.is_none_or(|current| suffix_start < current)
216 {
217 best_split_pos = Some(suffix_start);
218 break;
219 }
220 }
221 }
222
223 if let Some(split_pos) = best_split_pos {
224 let suffix = self.jail_buffer.split_off(split_pos);
228 let to_output = std::mem::replace(&mut self.jail_buffer, suffix);
229
230 if to_output.is_empty() {
231 Ok(SequenceDecoderOutput::Held)
232 } else {
233 Ok(SequenceDecoderOutput::Text(to_output))
234 }
235 } else {
236 let output = std::mem::take(&mut self.jail_buffer);
238 if output.is_empty() {
239 Ok(SequenceDecoderOutput::Held)
240 } else {
241 Ok(SequenceDecoderOutput::Text(output))
242 }
243 }
244 }
245
246 pub fn process_tokens(
248 &mut self,
249 token_ids: &[TokenIdType],
250 ) -> Result<Vec<SequenceDecoderOutput>> {
251 let mut outputs = Vec::with_capacity(token_ids.len());
253 for &token_id in token_ids {
254 outputs.push(self.process_token(token_id)?);
255 }
256 Ok(outputs)
257 }
258
259 pub fn flush(&mut self) -> SequenceDecoderOutput {
261 if self.jail_buffer.is_empty() {
262 SequenceDecoderOutput::Text(String::new())
263 } else {
264 SequenceDecoderOutput::Text(std::mem::take(&mut self.jail_buffer))
266 }
267 }
268
269 pub fn is_stopped(&self) -> bool {
271 self.stopped
272 }
273
274 pub fn reset(&mut self) {
276 self.jail_buffer.clear();
277 self.sequence.clear();
278 self.stopped = false;
279 }
280}
281
282pub struct StopSequenceDecoderBuilder {
284 tokenizer: Arc<dyn traits::Tokenizer>,
285 config: StopSequenceConfig,
286 skip_special_tokens: bool,
287}
288
289impl StopSequenceDecoderBuilder {
290 pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
291 StopSequenceDecoderBuilder {
292 tokenizer,
293 config: StopSequenceConfig::default(),
294 skip_special_tokens: true,
295 }
296 }
297
298 pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
299 self.config.stop_tokens.insert(token_id);
300 self
301 }
302
303 pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
304 self.config.stop_sequences.push(sequence.into());
305 self
306 }
307
308 pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
309 self.config.visible_stop_tokens.insert(token_id);
310 self
311 }
312
313 pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
314 self.config.visible_stop_sequences.push(sequence.into());
315 self
316 }
317
318 pub fn skip_special_tokens(mut self, skip: bool) -> Self {
319 self.skip_special_tokens = skip;
320 self
321 }
322
323 pub fn build(self) -> StopSequenceDecoder {
324 StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
325 }
326}
327
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::StopSequenceDecoderBuilder;
    use crate::{
        mock::MockTokenizer, traits::TokenIdType, SequenceDecoderOutput, StopSequenceConfig,
        StopSequenceDecoder,
    };

    /// Builds a decoder over a fresh `MockTokenizer` with a single hidden stop sequence.
    fn decoder_with_stop_sequence(stop_sequence: &str) -> StopSequenceDecoder {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence(stop_sequence);
        StopSequenceDecoder::new(tokenizer, config, false)
    }

    /// Feeds `token_ids` to a decoder watching for `stop_sequence`, failing the test with the
    /// underlying error (and `context`) if any token cannot be processed.
    fn assert_tokens_process_ok(stop_sequence: &str, token_ids: &[TokenIdType], context: &str) {
        let mut decoder = decoder_with_stop_sequence(stop_sequence);
        for &token_id in token_ids {
            if let Err(err) = decoder.process_token(token_id) {
                panic!("Failed on {context} with token {token_id}: {err:#}");
            }
        }
    }

    #[test]
    fn test_stop_token_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999);
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        let result = decoder.process_token(1).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));

        // The hidden stop token itself is swallowed.
        let result = decoder.process_token(999).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);

        // Once stopped, further tokens are ignored.
        let result = decoder.process_token(2).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_token(999);
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        let result = decoder.process_token(999).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
    }

    #[test]
    fn test_builder_pattern() {
        let tokenizer = Arc::new(MockTokenizer::new());

        let decoder = StopSequenceDecoderBuilder::new(tokenizer)
            .stop_token(999)
            .stop_sequence("STOP")
            .visible_stop_token(1000)
            .skip_special_tokens(true)
            .build();

        assert!(!decoder.is_stopped());
    }

    #[test]
    fn test_incremental_decoding_no_repetition() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        let outputs: Vec<String> = [1, 2, 3]
            .into_iter()
            .filter_map(|token_id| match decoder.process_token(token_id).unwrap() {
                SequenceDecoderOutput::Text(text) => Some(text),
                _ => None,
            })
            .collect();

        assert_eq!(outputs.len(), 3);

        // Each increment must be new text, not a repeat of an earlier increment.
        for (i, earlier) in outputs.iter().enumerate() {
            for later in &outputs[i + 1..] {
                assert!(!later.contains(earlier.as_str()));
            }
        }
    }

    #[test]
    fn test_stop_sequence_detection() {
        let mut decoder = decoder_with_stop_sequence("test");

        decoder.process_token(1).unwrap();
        decoder.process_token(2).unwrap();
        let result = decoder.process_token(3).unwrap();
        assert!(matches!(
            result,
            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
        ));
    }

    #[test]
    fn test_flush_after_partial() {
        let mut decoder = decoder_with_stop_sequence("NEVER_MATCH");

        decoder.process_token(1).unwrap();
        let result = decoder.flush();

        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_reset_functionality() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999);
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        decoder.process_token(1).unwrap();
        decoder.process_token(999).unwrap();
        assert!(decoder.is_stopped());

        decoder.reset();
        assert!(!decoder.is_stopped());

        let result = decoder.process_token(2).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_visible_stop_sequence() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_sequence("world");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        decoder.process_token(1).unwrap();

        match decoder.process_token(2).unwrap() {
            SequenceDecoderOutput::StoppedWithText(text) => assert!(text.contains("world")),
            other => panic!("Expected StoppedWithText with visible stop sequence, got {other:?}"),
        }
    }

    #[test]
    fn test_multiple_tokens_processing() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();

        assert_eq!(results.len(), 3);
        for result in results {
            assert!(matches!(
                result,
                SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
            ));
        }
    }

    #[test]
    fn test_stop_sequence_spanning_multiple_tokens() {
        let mut decoder = decoder_with_stop_sequence("Hello world");

        let result1 = decoder.process_token(1).unwrap();
        assert!(
            matches!(result1, SequenceDecoderOutput::Held),
            "Expected Held while jail buffer is a prefix of the stop sequence, got {result1:?}"
        );
        assert!(
            !decoder.is_stopped(),
            "Decoder should not be stopped after a partial match"
        );

        let result2 = decoder.process_token(2).unwrap();
        assert_eq!(
            result2,
            SequenceDecoderOutput::Stopped,
            "Expected Stopped when jail buffer matches the hidden stop sequence"
        );
        assert!(
            decoder.is_stopped(),
            "Decoder should be stopped after the full stop sequence match"
        );

        let result3 = decoder.process_token(3).unwrap();
        assert_eq!(result3, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_sequence_spanning_multiple_tokens() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_sequence("Hello world");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        let result1 = decoder.process_token(1).unwrap();
        assert!(
            matches!(result1, SequenceDecoderOutput::Held),
            "Expected Held for partial visible stop sequence match, got {result1:?}"
        );

        let result2 = decoder.process_token(2).unwrap();
        match &result2 {
            SequenceDecoderOutput::StoppedWithText(text) => {
                assert!(
                    text.contains("Hello world"),
                    "Visible stop output should contain the full stop sequence, got: {text:?}"
                );
            }
            other => panic!("Expected StoppedWithText for visible stop sequence, got {other:?}"),
        }
        assert!(decoder.is_stopped());
    }

    #[test]
    fn test_stop_sequence_spanning_tokens_with_preceding_text() {
        let mut decoder = decoder_with_stop_sequence("Hello world");

        let result1 = decoder.process_token(3).unwrap();
        assert!(
            matches!(result1, SequenceDecoderOutput::Text(_)),
            "Expected Text for token with no stop sequence overlap, got {result1:?}"
        );

        // Only text before the partial match may be released; the partial match is jailed.
        let result2 = decoder.process_token(1).unwrap();
        match &result2 {
            SequenceDecoderOutput::Text(text) => {
                assert!(
                    !text.contains("Hello"),
                    "Partially-matched 'Hello' should be jailed, not emitted. Got: {text:?}"
                );
            }
            SequenceDecoderOutput::Held => {}
            other => panic!("Expected Text (prefix before partial match) or Held, got {other:?}"),
        }

        let result3 = decoder.process_token(2).unwrap();
        assert!(
            matches!(
                result3,
                SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
            ),
            "Expected Stopped or StoppedWithText when stop sequence completes, got {result3:?}"
        );
        assert!(decoder.is_stopped());
    }

    #[test]
    fn test_utf8_multibyte_character_boundaries() {
        assert_tokens_process_ok(" ×", &[1, 2], "' ×' (space + multiplication sign)");
    }

    #[test]
    fn test_utf8_multibyte_delta_character() {
        assert_tokens_process_ok("Δ", &[1, 2], "'Δ' (Greek Delta)");
    }

    #[test]
    fn test_utf8_multibyte_degree_character() {
        assert_tokens_process_ok("°", &[1, 2], "'°' (degree sign)");
    }

    #[test]
    fn test_utf8_multibyte_triangle_character() {
        assert_tokens_process_ok(" (∆", &[1, 2, 3], "' (∆' (space, paren, increment)");
    }

    #[test]
    fn test_utf8_multibyte_en_dash_character() {
        assert_tokens_process_ok(" –", &[1, 2, 3], "' –' (space + en dash)");
    }

    #[test]
    fn test_utf8_multibyte_various_characters() {
        let test_cases = [
            ("×", "multiplication sign - 2 bytes"),
            ("Δ", "Greek Delta - 2 bytes"),
            ("°", "degree sign - 2 bytes"),
            ("∆", "increment - 3 bytes"),
            ("–", "en dash - 3 bytes"),
            ("€", "euro sign - 3 bytes"),
            ("中", "Chinese character - 3 bytes"),
            ("🚀", "rocket emoji - 4 bytes"),
            ("💡", "lightbulb emoji - 4 bytes"),
        ];

        for (stop_char, description) in test_cases {
            assert_tokens_process_ok(stop_char, &[1, 2, 3, 4, 5], description);
        }
    }
}