1use std::{collections::HashSet, sync::Arc};
2
3use aho_corasick::{AhoCorasick, Input};
4use anyhow::Result;
5
6use crate::{
7 sequence::Sequence,
8 traits::{self, TokenIdType},
9};
10
/// Outcome of feeding a token to (or flushing) a stop-aware decoder.
///
/// All payloads are plain `String`s, so the type is fully `Eq` in addition to
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SequenceDecoderOutput {
    /// Decoded text that can be emitted right away.
    Text(String),
    /// Nothing to emit for now; any decoded text is being held back.
    Held,
    /// A stop condition was reached and there is no text left to emit.
    Stopped,
    /// A stop condition was reached; the carried text is the final output.
    StoppedWithText(String),
}
23
24#[derive(Debug, Clone, Default)]
26pub struct StopSequenceConfig {
27 pub stop_tokens: HashSet<TokenIdType>,
29 pub stop_sequences: Vec<String>,
31 pub visible_stop_tokens: HashSet<TokenIdType>,
33 pub visible_stop_sequences: Vec<String>,
35}
36
37impl StopSequenceConfig {
38 pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
40 self.stop_tokens.insert(token_id);
41 self
42 }
43
44 pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
46 self.stop_sequences.push(sequence.into());
47 self
48 }
49
50 pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
52 self.visible_stop_tokens.insert(token_id);
53 self
54 }
55
56 pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
58 self.visible_stop_sequences.push(sequence.into());
59 self
60 }
61}
62
63pub struct StopSequenceDecoder {
65 sequence: Sequence,
67 config: StopSequenceConfig,
68 aho_corasick: Option<AhoCorasick>,
70 visible_boundary_idx: usize,
73 jail_buffer: String,
75 jail_max_bytes: usize,
78 stopped: bool,
80 token_only: bool,
83}
84
85impl StopSequenceDecoder {
86 pub fn new(
88 tokenizer: Arc<dyn traits::Tokenizer>,
89 config: StopSequenceConfig,
90 skip_special_tokens: bool,
91 ) -> Self {
92 let mut patterns: Vec<&str> = config
95 .stop_sequences
96 .iter()
97 .filter(|s| !s.is_empty())
98 .map(|s| s.as_str())
99 .collect();
100 let visible_boundary_idx = patterns.len();
101 patterns.extend(
102 config
103 .visible_stop_sequences
104 .iter()
105 .filter(|s| !s.is_empty())
106 .map(|s| s.as_str()),
107 );
108
109 let jail_max_bytes = config
113 .stop_sequences
114 .iter()
115 .chain(&config.visible_stop_sequences)
116 .map(|s| s.len())
117 .max()
118 .unwrap_or(0);
119
120 let aho_corasick = if patterns.is_empty() {
121 None
122 } else {
123 #[expect(
126 clippy::expect_used,
127 reason = "AhoCorasick::new with pre-filtered non-empty &str patterns is practically infallible"
128 )]
129 Some(AhoCorasick::new(patterns).expect("Failed to build Aho-Corasick automaton"))
130 };
131
132 let token_only = aho_corasick.is_none();
133
134 StopSequenceDecoder {
135 sequence: Sequence::new_with_options(tokenizer, skip_special_tokens),
136 config,
137 aho_corasick,
138 visible_boundary_idx,
139 jail_buffer: String::new(),
140 jail_max_bytes,
141 stopped: false,
142 token_only,
143 }
144 }
145
146 pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
148 if self.stopped {
149 return Ok(SequenceDecoderOutput::Stopped);
150 }
151
152 if self.config.stop_tokens.contains(&token_id) {
154 self.stopped = true;
155
156 if !self.jail_buffer.is_empty() {
158 return Ok(SequenceDecoderOutput::StoppedWithText(std::mem::take(
159 &mut self.jail_buffer,
160 )));
161 }
162 return Ok(SequenceDecoderOutput::Stopped);
163 }
164
165 if self.config.visible_stop_tokens.contains(&token_id) {
166 self.stopped = true;
167
168 let stop_text = self
170 .sequence
171 .tokenizer()
172 .decode(&[token_id], self.sequence.skip_special_tokens())?;
173 let output = format!("{}{}", self.jail_buffer, stop_text);
174 self.jail_buffer.clear();
175 return Ok(SequenceDecoderOutput::StoppedWithText(output));
176 }
177
178 let new_text = self.sequence.append_token(token_id)?;
180
181 if self.token_only {
184 if new_text.is_empty() {
185 return Ok(SequenceDecoderOutput::Held);
186 }
187 return Ok(SequenceDecoderOutput::Text(new_text));
188 }
189
190 let old_len = self.jail_buffer.len();
191 self.jail_buffer.push_str(&new_text);
192
193 if let Some(ac) = &self.aho_corasick {
198 let search_start = if old_len >= self.jail_max_bytes {
199 let raw = old_len + 1 - self.jail_max_bytes;
201 let mut start = raw;
202 while start < self.jail_buffer.len() && !self.jail_buffer.is_char_boundary(start) {
203 start += 1;
204 }
205 start
206 } else {
207 0
208 };
209
210 let input = Input::new(&self.jail_buffer).span(search_start..self.jail_buffer.len());
211 if let Some(mat) = ac.find(input) {
212 self.stopped = true;
213 let is_visible = mat.pattern().as_usize() >= self.visible_boundary_idx;
214
215 if is_visible {
216 let output = self.jail_buffer[..mat.end()].to_string();
218 self.jail_buffer.clear();
219 return Ok(SequenceDecoderOutput::StoppedWithText(output));
220 } else {
221 let output = self.jail_buffer[..mat.start()].to_string();
223 self.jail_buffer.clear();
224 return Ok(if output.is_empty() {
225 SequenceDecoderOutput::Stopped
226 } else {
227 SequenceDecoderOutput::StoppedWithText(output)
228 });
229 }
230 }
231 }
232
233 if self.jail_buffer.len() > self.jail_max_bytes {
236 let mut drain_to = self.jail_buffer.len() - self.jail_max_bytes;
239 while drain_to > 0 && !self.jail_buffer.is_char_boundary(drain_to) {
240 drain_to -= 1;
242 }
243
244 if drain_to > 0 {
245 let suffix = self.jail_buffer.split_off(drain_to);
246 let to_output = std::mem::replace(&mut self.jail_buffer, suffix);
247 return Ok(SequenceDecoderOutput::Text(to_output));
248 }
249 }
250
251 Ok(SequenceDecoderOutput::Held)
253 }
254
255 pub fn process_tokens(
259 &mut self,
260 token_ids: &[TokenIdType],
261 ) -> Result<Vec<SequenceDecoderOutput>> {
262 let mut outputs = Vec::with_capacity(token_ids.len());
263 for &token_id in token_ids {
264 let output = self.process_token(token_id)?;
265 let done = matches!(
266 output,
267 SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
268 );
269 outputs.push(output);
270 if done {
271 break;
272 }
273 }
274 Ok(outputs)
275 }
276
277 pub fn flush(&mut self) -> SequenceDecoderOutput {
279 if self.jail_buffer.is_empty() {
280 SequenceDecoderOutput::Held
281 } else {
282 SequenceDecoderOutput::Text(std::mem::take(&mut self.jail_buffer))
284 }
285 }
286
287 pub fn is_stopped(&self) -> bool {
289 self.stopped
290 }
291
292 pub fn reset(&mut self) {
294 self.jail_buffer.clear();
295 self.sequence.clear();
296 self.stopped = false;
297 }
298}
299
300pub struct StopSequenceDecoderBuilder {
302 tokenizer: Arc<dyn traits::Tokenizer>,
303 config: StopSequenceConfig,
304 skip_special_tokens: bool,
305}
306
307impl StopSequenceDecoderBuilder {
308 pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
309 StopSequenceDecoderBuilder {
310 tokenizer,
311 config: StopSequenceConfig::default(),
312 skip_special_tokens: true,
313 }
314 }
315
316 pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
317 self.config.stop_tokens.insert(token_id);
318 self
319 }
320
321 pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
322 self.config.stop_sequences.push(sequence.into());
323 self
324 }
325
326 pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
327 self.config.visible_stop_tokens.insert(token_id);
328 self
329 }
330
331 pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
332 self.config.visible_stop_sequences.push(sequence.into());
333 self
334 }
335
336 pub fn skip_special_tokens(mut self, skip: bool) -> Self {
337 self.skip_special_tokens = skip;
338 self
339 }
340
341 pub fn build(self) -> StopSequenceDecoder {
342 StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
343 }
344}
345
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::StopSequenceDecoderBuilder;
    use crate::{
        mock::MockTokenizer, SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder,
    };

    /// Builds a decoder over a fresh mock tokenizer with `skip_special_tokens` disabled.
    fn decoder_for(config: StopSequenceConfig) -> StopSequenceDecoder {
        let tokenizer = Arc::new(MockTokenizer::new());
        StopSequenceDecoder::new(tokenizer, config, false)
    }

    /// Runs `token_ids` through a decoder that stops on `stop_sequence` and
    /// requires every call to succeed.
    fn assert_tokens_process(stop_sequence: &str, token_ids: &[super::TokenIdType]) {
        let mut decoder =
            decoder_for(StopSequenceConfig::default().with_stop_sequence(stop_sequence));
        for &token_id in token_ids {
            assert!(decoder.process_token(token_id).is_ok());
        }
    }

    /// A hidden stop token stops the decoder, and it stays stopped afterwards.
    #[test]
    fn test_stop_token_detection() {
        let mut decoder = decoder_for(StopSequenceConfig::default().with_stop_token(999));

        let first = decoder.process_token(1).unwrap();
        assert!(matches!(first, SequenceDecoderOutput::Text(_)));

        // The stop token itself, then a token fed after the stop.
        for token_id in [999, 2] {
            let result = decoder.process_token(token_id).unwrap();
            assert_eq!(result, SequenceDecoderOutput::Stopped);
        }
    }

    /// A visible stop token stops the decoder and carries text with it.
    #[test]
    fn test_visible_stop_token() {
        let mut decoder = decoder_for(StopSequenceConfig::default().with_visible_stop_token(999));

        let result = decoder.process_token(999).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
    }

    /// The builder produces a decoder that has not stopped yet.
    #[test]
    fn test_builder_pattern() {
        let builder = StopSequenceDecoderBuilder::new(Arc::new(MockTokenizer::new()));

        let decoder = builder
            .stop_token(999)
            .stop_sequence("STOP")
            .visible_stop_token(1000)
            .skip_special_tokens(true)
            .build();

        assert!(!decoder.is_stopped());
    }

    /// Without stop conditions each token yields its own text, never a repeat
    /// of an earlier chunk.
    #[test]
    fn test_incremental_decoding_no_repetition() {
        let mut decoder = decoder_for(StopSequenceConfig::default());

        let mut outputs = Vec::new();
        for token_id in [1, 2, 3] {
            if let SequenceDecoderOutput::Text(text) = decoder.process_token(token_id).unwrap() {
                outputs.push(text);
            }
        }

        assert_eq!(outputs.len(), 3);

        for (idx, earlier) in outputs.iter().enumerate() {
            for later in &outputs[idx + 1..] {
                assert!(!later.contains(earlier));
            }
        }
    }

    /// A hidden stop sequence stops the decoder once its text has been produced.
    #[test]
    fn test_stop_sequence_detection() {
        let mut decoder = decoder_for(StopSequenceConfig::default().with_stop_sequence("test"));

        for token_id in [1, 2] {
            decoder.process_token(token_id).unwrap();
        }

        let last = decoder.process_token(3).unwrap();
        assert!(matches!(
            last,
            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
        ));
    }

    /// Text held back for a stop sequence that never completes is released by `flush`.
    #[test]
    fn test_flush_after_partial() {
        let mut decoder =
            decoder_for(StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH"));

        decoder.process_token(1).unwrap();

        let flushed = decoder.flush();
        assert!(matches!(flushed, SequenceDecoderOutput::Text(_)));
    }

    /// `reset` clears the stopped state so decoding can resume.
    #[test]
    fn test_reset_functionality() {
        let mut decoder = decoder_for(StopSequenceConfig::default().with_stop_token(999));

        for token_id in [1, 999] {
            decoder.process_token(token_id).unwrap();
        }
        assert!(decoder.is_stopped());

        decoder.reset();
        assert!(!decoder.is_stopped());

        let resumed = decoder.process_token(2).unwrap();
        assert!(matches!(resumed, SequenceDecoderOutput::Text(_)));
    }

    /// A visible stop sequence appears in the final output.
    #[test]
    fn test_visible_stop_sequence() {
        let mut decoder =
            decoder_for(StopSequenceConfig::default().with_visible_stop_sequence("world"));

        decoder.process_token(1).unwrap();

        match decoder.process_token(2).unwrap() {
            SequenceDecoderOutput::StoppedWithText(text) => assert!(text.contains("world")),
            _ => panic!("Expected StoppedWithText with visible stop sequence"),
        }
    }

    /// `process_tokens` returns one output per token when nothing stops the decoder.
    #[test]
    fn test_multiple_tokens_processing() {
        let mut decoder = decoder_for(StopSequenceConfig::default());

        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();
        assert_eq!(results.len(), 3);

        assert!(results.iter().all(|result| matches!(
            result,
            SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
        )));
    }

    /// A hidden stop sequence that is completed across two tokens is held, then stops.
    #[test]
    fn test_stop_sequence_spanning_multiple_tokens() {
        let mut decoder =
            decoder_for(StopSequenceConfig::default().with_stop_sequence("Hello world"));

        let result1 = decoder.process_token(1).unwrap();
        assert!(
            result1 == SequenceDecoderOutput::Held,
            "Expected Held while jail buffer is a prefix of the stop sequence, got {result1:?}"
        );
        assert!(
            !decoder.is_stopped(),
            "Decoder should not be stopped after a partial match"
        );

        let result2 = decoder.process_token(2).unwrap();
        assert_eq!(
            result2,
            SequenceDecoderOutput::Stopped,
            "Expected Stopped when jail buffer matches the hidden stop sequence"
        );
        assert!(
            decoder.is_stopped(),
            "Decoder should be stopped after the full stop sequence match"
        );

        // Tokens fed after the stop keep reporting `Stopped`.
        let result3 = decoder.process_token(3).unwrap();
        assert_eq!(result3, SequenceDecoderOutput::Stopped);
    }

    /// A visible stop sequence completed across two tokens is emitted in full.
    #[test]
    fn test_visible_stop_sequence_spanning_multiple_tokens() {
        let mut decoder =
            decoder_for(StopSequenceConfig::default().with_visible_stop_sequence("Hello world"));

        let result1 = decoder.process_token(1).unwrap();
        assert!(
            result1 == SequenceDecoderOutput::Held,
            "Expected Held for partial visible stop sequence match, got {result1:?}"
        );

        let result2 = decoder.process_token(2).unwrap();
        match &result2 {
            SequenceDecoderOutput::StoppedWithText(text) => {
                assert!(
                    text.contains("Hello world"),
                    "Visible stop output should contain the full stop sequence, got: {text:?}"
                );
            }
            other => panic!("Expected StoppedWithText for visible stop sequence, got {other:?}"),
        }
        assert!(decoder.is_stopped());
    }

    /// Text preceding a hidden stop sequence is held until the sequence
    /// completes, and the sequence itself never reaches the output.
    #[test]
    fn test_stop_sequence_spanning_tokens_with_preceding_text() {
        let mut decoder =
            decoder_for(StopSequenceConfig::default().with_stop_sequence("Hello world"));

        let result1 = decoder.process_token(3).unwrap();
        assert!(
            result1 == SequenceDecoderOutput::Held,
            "Expected Held for token within jail window, got {result1:?}"
        );

        let result2 = decoder.process_token(1).unwrap();
        assert!(
            result2 == SequenceDecoderOutput::Held,
            "Expected Held for token within jail window, got {result2:?}"
        );

        let result3 = decoder.process_token(2).unwrap();
        assert!(
            matches!(
                result3,
                SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
            ),
            "Expected Stopped or StoppedWithText when stop sequence completes, got {result3:?}"
        );
        assert!(decoder.is_stopped());

        if let SequenceDecoderOutput::StoppedWithText(text) = &result3 {
            assert!(
                !text.contains("Hello world"),
                "Hidden stop sequence should not appear in output, got: {text:?}"
            );
        }
    }

    /// Stop sequences containing multi-byte characters must not make processing fail.
    #[test]
    fn test_utf8_multibyte_character_boundaries() {
        assert_tokens_process(" ×", &[1, 2]);
    }

    #[test]
    fn test_utf8_multibyte_delta_character() {
        assert_tokens_process("Δ", &[1, 2]);
    }

    #[test]
    fn test_utf8_multibyte_degree_character() {
        assert_tokens_process("°", &[1, 2]);
    }

    #[test]
    fn test_utf8_multibyte_triangle_character() {
        assert_tokens_process(" (∆", &[1, 2, 3]);
    }

    #[test]
    fn test_utf8_multibyte_en_dash_character() {
        assert_tokens_process(" –", &[1, 2, 3]);
    }

    /// Sweeps stop characters of two, three and four UTF-8 bytes.
    #[test]
    fn test_utf8_multibyte_various_characters() {
        let test_cases = [
            ("×", "multiplication sign - 2 bytes"),
            ("Δ", "Greek Delta - 2 bytes"),
            ("°", "degree sign - 2 bytes"),
            ("∆", "increment - 3 bytes"),
            ("–", "en dash - 3 bytes"),
            ("€", "euro sign - 3 bytes"),
            ("中", "Chinese character - 3 bytes"),
            ("🚀", "rocket emoji - 4 bytes"),
            ("💡", "lightbulb emoji - 4 bytes"),
        ];

        for (stop_char, description) in test_cases {
            let mut decoder =
                decoder_for(StopSequenceConfig::default().with_stop_sequence(stop_char));

            for token_id in 1..=5 {
                let outcome = decoder.process_token(token_id);
                assert!(
                    outcome.is_ok(),
                    "Failed on {description} with token {token_id}"
                );
            }
        }
    }
}