// llm_tokenizer/stop.rs

1use std::{collections::HashSet, sync::Arc};
2
3use anyhow::Result;
4
5use crate::{
6    sequence::Sequence,
7    traits::{self, TokenIdType},
8};
9
/// Result of feeding one token through the stop-sequence decoder.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SequenceDecoderOutput {
    /// Text that is safe to emit to the caller.
    Text(String),
    /// Nothing can be emitted yet: the pending text may be the beginning of a
    /// stop sequence and is held back until it can be resolved.
    Held,
    /// Decoding has stopped and there is no final text to emit.
    Stopped,
    /// Decoding has stopped; carries the final text to emit. This is the text
    /// preceding a hidden stop, or the text up to and including a visible stop.
    StoppedWithText(String),
}
22
23/// Configuration for stop sequences
24#[derive(Debug, Clone, Default)]
25pub struct StopSequenceConfig {
26    /// Token IDs that trigger a stop
27    pub stop_tokens: HashSet<TokenIdType>,
28    /// String sequences that trigger a stop
29    pub stop_sequences: Vec<String>,
30    /// Token IDs for visible stops (included in output)
31    pub visible_stop_tokens: HashSet<TokenIdType>,
32    /// String sequences for visible stops (included in output)
33    pub visible_stop_sequences: Vec<String>,
34}
35
36impl StopSequenceConfig {
37    /// Builder pattern - add a stop token
38    pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
39        self.stop_tokens.insert(token_id);
40        self
41    }
42
43    /// Builder pattern - add a stop sequence
44    pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
45        self.stop_sequences.push(sequence.into());
46        self
47    }
48
49    /// Builder pattern - add a visible stop token
50    pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
51        self.visible_stop_tokens.insert(token_id);
52        self
53    }
54
55    /// Builder pattern - add a visible stop sequence
56    pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
57        self.visible_stop_sequences.push(sequence.into());
58        self
59    }
60}
61
62/// Decoder that handles stop sequences
63pub struct StopSequenceDecoder {
64    /// Sequence for incremental decoding (replaces token_buffer + offsets)
65    sequence: Sequence,
66    config: StopSequenceConfig,
67    /// Buffer for partial matches (the "jail")
68    jail_buffer: String,
69    /// Whether we've stopped
70    stopped: bool,
71}
72
73impl StopSequenceDecoder {
74    /// Create a new stop sequence decoder
75    pub fn new(
76        tokenizer: Arc<dyn traits::Tokenizer>,
77        config: StopSequenceConfig,
78        skip_special_tokens: bool,
79    ) -> Self {
80        StopSequenceDecoder {
81            sequence: Sequence::new_with_options(tokenizer, skip_special_tokens),
82            config,
83            jail_buffer: String::new(),
84            stopped: false,
85        }
86    }
87
88    /// Process a single token
89    pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
90        if self.stopped {
91            return Ok(SequenceDecoderOutput::Stopped);
92        }
93
94        // Check for token-level stops first
95        if self.config.stop_tokens.contains(&token_id) {
96            self.stopped = true;
97
98            // Flush any jailed text before stopping - use mem::take to avoid clone
99            if !self.jail_buffer.is_empty() {
100                return Ok(SequenceDecoderOutput::StoppedWithText(std::mem::take(
101                    &mut self.jail_buffer,
102                )));
103            }
104            return Ok(SequenceDecoderOutput::Stopped);
105        }
106
107        if self.config.visible_stop_tokens.contains(&token_id) {
108            self.stopped = true;
109
110            // Include jailed text plus the stop token
111            let stop_text = self
112                .sequence
113                .tokenizer()
114                .decode(&[token_id], self.sequence.skip_special_tokens())?;
115            let output = format!("{}{}", self.jail_buffer, stop_text);
116            self.jail_buffer.clear();
117            return Ok(SequenceDecoderOutput::StoppedWithText(output));
118        }
119
120        // Use Sequence for incremental decoding
121        let new_text = self.sequence.append_token(token_id)?;
122
123        self.jail_buffer.push_str(&new_text);
124
125        // Check for hidden stop sequences
126        for stop_seq in &self.config.stop_sequences {
127            if let Some(pos) = self.jail_buffer.find(stop_seq) {
128                self.stopped = true;
129                let output = self.jail_buffer[..pos].to_string();
130                self.jail_buffer.clear();
131                return Ok(if output.is_empty() {
132                    SequenceDecoderOutput::Stopped
133                } else {
134                    SequenceDecoderOutput::StoppedWithText(output)
135                });
136            }
137        }
138
139        // Check for visible stop sequences
140        for stop_seq in &self.config.visible_stop_sequences {
141            if let Some(pos) = self.jail_buffer.find(stop_seq) {
142                self.stopped = true;
143                let end_pos = pos + stop_seq.len();
144                let output = self.jail_buffer[..end_pos].to_string();
145                self.jail_buffer.clear();
146                return Ok(SequenceDecoderOutput::StoppedWithText(output));
147            }
148        }
149
150        // Check for partial matches: is the end of jail_buffer the start of any stop_seq?
151        // This handles stop sequences split across tokens
152        let buffer_len = self.jail_buffer.len();
153        let mut best_split_pos: Option<usize> = None;
154
155        for stop_seq in self
156            .config
157            .stop_sequences
158            .iter()
159            .chain(&self.config.visible_stop_sequences)
160        {
161            let stop_len = stop_seq.len();
162
163            if stop_len <= 1 || buffer_len == 0 {
164                continue;
165            }
166
167            let max_len = buffer_len.min(stop_len - 1);
168
169            for len in (1..=max_len).rev() {
170                let suffix_start = buffer_len - len;
171
172                if !self.jail_buffer.is_char_boundary(suffix_start) {
173                    continue;
174                }
175
176                let suffix = &self.jail_buffer[suffix_start..];
177
178                if stop_seq.starts_with(suffix)
179                    && best_split_pos.is_none_or(|current| suffix_start < current)
180                {
181                    best_split_pos = Some(suffix_start);
182                    break;
183                }
184            }
185        }
186
187        if let Some(split_pos) = best_split_pos {
188            // Hold the partial match, flush the rest
189            // Use split_off for zero-copy: keeps [0..split_pos] in place, returns [split_pos..]
190            // Then swap so we output the prefix and keep the suffix
191            let suffix = self.jail_buffer.split_off(split_pos);
192            let to_output = std::mem::replace(&mut self.jail_buffer, suffix);
193
194            if to_output.is_empty() {
195                Ok(SequenceDecoderOutput::Held)
196            } else {
197                Ok(SequenceDecoderOutput::Text(to_output))
198            }
199        } else {
200            // No partial matches - flush everything
201            let output = std::mem::take(&mut self.jail_buffer);
202            if output.is_empty() {
203                Ok(SequenceDecoderOutput::Held)
204            } else {
205                Ok(SequenceDecoderOutput::Text(output))
206            }
207        }
208    }
209
210    /// Process multiple tokens
211    pub fn process_tokens(
212        &mut self,
213        token_ids: &[TokenIdType],
214    ) -> Result<Vec<SequenceDecoderOutput>> {
215        // Pre-allocate with exact capacity to avoid reallocations
216        let mut outputs = Vec::with_capacity(token_ids.len());
217        for &token_id in token_ids {
218            outputs.push(self.process_token(token_id)?);
219        }
220        Ok(outputs)
221    }
222
223    /// Flush any held text
224    pub fn flush(&mut self) -> SequenceDecoderOutput {
225        if !self.jail_buffer.is_empty() {
226            // Use mem::take to avoid clone - transfers ownership and leaves empty string
227            SequenceDecoderOutput::Text(std::mem::take(&mut self.jail_buffer))
228        } else {
229            SequenceDecoderOutput::Text(String::new())
230        }
231    }
232
233    /// Check if decoding has stopped
234    pub fn is_stopped(&self) -> bool {
235        self.stopped
236    }
237
238    /// Reset the decoder state
239    pub fn reset(&mut self) {
240        self.jail_buffer.clear();
241        self.sequence.clear();
242        self.stopped = false;
243    }
244}
245
246/// Builder for StopSequenceDecoder
247pub struct StopSequenceDecoderBuilder {
248    tokenizer: Arc<dyn traits::Tokenizer>,
249    config: StopSequenceConfig,
250    skip_special_tokens: bool,
251}
252
253impl StopSequenceDecoderBuilder {
254    pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
255        StopSequenceDecoderBuilder {
256            tokenizer,
257            config: StopSequenceConfig::default(),
258            skip_special_tokens: true,
259        }
260    }
261
262    pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
263        self.config.stop_tokens.insert(token_id);
264        self
265    }
266
267    pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
268        self.config.stop_sequences.push(sequence.into());
269        self
270    }
271
272    pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
273        self.config.visible_stop_tokens.insert(token_id);
274        self
275    }
276
277    pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
278        self.config.visible_stop_sequences.push(sequence.into());
279        self
280    }
281
282    pub fn skip_special_tokens(mut self, skip: bool) -> Self {
283        self.skip_special_tokens = skip;
284        self
285    }
286
287    pub fn build(self) -> StopSequenceDecoder {
288        StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
289    }
290}
291
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::StopSequenceDecoderBuilder;
    use crate::{
        mock::MockTokenizer, SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder,
    };

    #[test]
    fn test_stop_token_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999); // <eos> token

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens before stop
        let result = decoder.process_token(1).unwrap(); // "Hello"
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));

        // Process stop token
        let result = decoder.process_token(999).unwrap(); // <eos>
        assert_eq!(result, SequenceDecoderOutput::Stopped);

        // Further tokens should also return Stopped
        let result = decoder.process_token(2).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_token(999);

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        let result = decoder.process_token(999).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
    }

    #[test]
    fn test_builder_pattern() {
        let tokenizer = Arc::new(MockTokenizer::new());

        let decoder = StopSequenceDecoderBuilder::new(tokenizer)
            .stop_token(999)
            .stop_sequence("STOP")
            .visible_stop_token(1000)
            .skip_special_tokens(true)
            .build();

        assert!(!decoder.is_stopped());
    }

    #[test]
    fn test_incremental_decoding_no_repetition() {
        // Verifies that each call emits only NEW text, never accumulated text.
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Tokens 1, 2, 3: "Hello", "world", "test"
        let mut outputs = Vec::new();
        for token_id in [1, 2, 3] {
            if let SequenceDecoderOutput::Text(text) = decoder.process_token(token_id).unwrap() {
                outputs.push(text);
            }
        }

        assert_eq!(outputs.len(), 3);

        // No later output may contain an earlier one (no accumulation).
        for (i, earlier) in outputs.iter().enumerate() {
            for later in &outputs[i + 1..] {
                assert!(!later.contains(earlier.as_str()));
            }
        }
    }

    #[test]
    fn test_stop_sequence_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("test");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello world"
        decoder.process_token(1).unwrap(); // "Hello"
        decoder.process_token(2).unwrap(); // "world"

        // Process "test" which should trigger stop
        let result = decoder.process_token(3).unwrap(); // "test"

        // Should stop when we hit "test"
        assert!(matches!(
            result,
            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
        ));
    }

    #[test]
    fn test_flush_after_partial() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process a token
        decoder.process_token(1).unwrap(); // "Hello"

        // Flush should return any remaining text in jail
        let result = decoder.flush();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_reset_functionality() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999);
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process and stop
        decoder.process_token(1).unwrap();
        decoder.process_token(999).unwrap();
        assert!(decoder.is_stopped());

        // Reset should clear everything
        decoder.reset();
        assert!(!decoder.is_stopped());

        // Should be able to process again
        let result = decoder.process_token(2).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_visible_stop_sequence() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_sequence("world");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello"
        decoder.process_token(1).unwrap();

        // Process "world" - should include it in output
        let result = decoder.process_token(2).unwrap();

        if let SequenceDecoderOutput::StoppedWithText(text) = result {
            assert!(text.contains("world"));
        } else {
            panic!("Expected StoppedWithText with visible stop sequence");
        }
    }

    #[test]
    fn test_multiple_tokens_processing() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process multiple tokens at once
        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();

        // Should get results for each token
        assert_eq!(results.len(), 3);

        // No stops configured, so every result is Text or Held
        for result in results {
            assert!(matches!(
                result,
                SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
            ));
        }
    }

    #[test]
    fn test_utf8_multibyte_character_boundaries() {
        // Regression test for a UTF-8 boundary panic: slicing jail_buffer at a
        // byte index inside a multi-byte character (e.g. '×').
        let tokenizer = Arc::new(MockTokenizer::new());

        // Stop sequence containing a multi-byte character
        let config = StopSequenceConfig::default().with_stop_sequence(" ×");

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // '×' is UTF-8 encoded as [0xC3, 0x97] (2 bytes); partial-match checks
        // must never slice between those bytes.
        let result = decoder.process_token(1);
        assert!(result.is_ok());

        let result = decoder.process_token(2);
        assert!(result.is_ok());
    }

    #[test]
    fn test_utf8_multibyte_delta_character() {
        // Test for: byte index 1 is not a char boundary; it is inside 'Δ' (bytes 0..2) of `Δ`
        // 'Δ' (U+0394 GREEK CAPITAL LETTER DELTA) is encoded as [0xCE, 0x94] (2 bytes)
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("Δ");

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens - should not panic when checking partial matches
        let result = decoder.process_token(1);
        assert!(result.is_ok());
        let result = decoder.process_token(2);
        assert!(result.is_ok());
    }

    #[test]
    fn test_utf8_multibyte_degree_character() {
        // Test for: byte index 1 is not a char boundary; it is inside '°' (bytes 0..2) of `°`
        // '°' (U+00B0 DEGREE SIGN) is encoded as [0xC2, 0xB0] (2 bytes)
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("°");

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens - should not panic when checking partial matches
        let result = decoder.process_token(1);
        assert!(result.is_ok());
        let result = decoder.process_token(2);
        assert!(result.is_ok());
    }

    #[test]
    fn test_utf8_multibyte_triangle_character() {
        // Test for: byte index 4 is not a char boundary; it is inside '∆' (bytes 2..5) of ` (∆`
        // '∆' (U+2206 INCREMENT) is encoded as [0xE2, 0x88, 0x86] (3 bytes)
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence(" (∆");

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens - should not panic when checking partial matches
        let result = decoder.process_token(1);
        assert!(result.is_ok());
        let result = decoder.process_token(2);
        assert!(result.is_ok());
        let result = decoder.process_token(3);
        assert!(result.is_ok());
    }

    #[test]
    fn test_utf8_multibyte_en_dash_character() {
        // Test for: byte index 3 is not a char boundary; it is inside '–' (bytes 1..4) of ` –`
        // '–' (U+2013 EN DASH) is encoded as [0xE2, 0x80, 0x93] (3 bytes)
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence(" –");

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens - should not panic when checking partial matches
        let result = decoder.process_token(1);
        assert!(result.is_ok());
        let result = decoder.process_token(2);
        assert!(result.is_ok());
        let result = decoder.process_token(3);
        assert!(result.is_ok());
    }

    #[test]
    fn test_utf8_multibyte_various_characters() {
        // Multiple multi-byte UTF-8 characters: 2-, 3- and 4-byte sequences.
        let test_cases = [
            ("×", "multiplication sign - 2 bytes"),
            ("Δ", "Greek Delta - 2 bytes"),
            ("°", "degree sign - 2 bytes"),
            ("∆", "increment - 3 bytes"),
            ("–", "en dash - 3 bytes"),
            ("€", "euro sign - 3 bytes"),
            ("中", "Chinese character - 3 bytes"),
            ("🚀", "rocket emoji - 4 bytes"),
            ("💡", "lightbulb emoji - 4 bytes"),
        ];

        for (stop_char, description) in test_cases {
            let tokenizer = Arc::new(MockTokenizer::new());
            let config = StopSequenceConfig::default().with_stop_sequence(stop_char);

            let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

            // Process multiple tokens - should not panic
            for token_id in 1..=5 {
                let result = decoder.process_token(token_id);
                assert!(
                    result.is_ok(),
                    "Failed on {} with token {}",
                    description,
                    token_id
                );
            }
        }
    }
}