use std::{collections::HashSet, sync::Arc};
use aho_corasick::AhoCorasick;
use anyhow::Result;
use crate::{
sequence::Sequence,
traits::{self, TokenIdType},
};
/// Outcome of feeding a token to `StopSequenceDecoder::process_token`, or of
/// calling `StopSequenceDecoder::flush`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SequenceDecoderOutput {
    /// Decoded text that can be passed on to the consumer.
    ///
    /// `process_token` only produces this variant with non-empty text; `flush`
    /// may return it with an empty string when nothing was being held back.
    Text(String),
    /// Nothing to emit for this token: either no text is available or all of
    /// it is being held back because it may be the start of a stop sequence.
    Held,
    /// A stop condition has been reached and there is no text left to emit.
    Stopped,
    /// A stop condition has been reached; the payload is the final text to
    /// emit before stopping.
    StoppedWithText(String),
}
/// Stop conditions consumed by `StopSequenceDecoder`.
///
/// "Hidden" conditions end decoding without emitting the stop token/sequence
/// itself; "visible" conditions end decoding and include it in the final
/// output. The default value has no stop conditions at all.
#[derive(Debug, Clone, Default)]
pub struct StopSequenceConfig {
/// Token ids that stop decoding; the stop token's own text is not emitted.
pub stop_tokens: HashSet<TokenIdType>,
/// Strings that stop decoding; only the text before the match is emitted.
/// Empty strings are ignored by the decoder.
pub stop_sequences: Vec<String>,
/// Token ids that stop decoding; the decoded stop token is appended to the
/// final output.
pub visible_stop_tokens: HashSet<TokenIdType>,
/// Strings that stop decoding; the text up to and including the match is
/// emitted. Empty strings are ignored by the decoder.
pub visible_stop_sequences: Vec<String>,
}
impl StopSequenceConfig {
    /// Returns the config with `token_id` added as a hidden stop token.
    ///
    /// The method consumes `self`, so the returned value must be used;
    /// dropping it discards the whole configuration.
    #[must_use]
    pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
        self.stop_tokens.insert(token_id);
        self
    }

    /// Returns the config with `sequence` appended to the hidden stop
    /// sequences.
    #[must_use]
    pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.stop_sequences.push(sequence.into());
        self
    }

    /// Returns the config with `token_id` added as a visible stop token.
    #[must_use]
    pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
        self.visible_stop_tokens.insert(token_id);
        self
    }

    /// Returns the config with `sequence` appended to the visible stop
    /// sequences.
    #[must_use]
    pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.visible_stop_sequences.push(sequence.into());
        self
    }
}
/// Streaming decoder that turns tokens into text one at a time and stops when
/// a configured stop token or stop sequence is encountered.
///
/// Text that could be the beginning of a stop sequence is kept in an internal
/// "jail" buffer until it either completes the sequence or turns out not to
/// match.
pub struct StopSequenceDecoder {
/// Token sequence that newly decoded text is obtained from (`append_token`);
/// also provides the tokenizer and the `skip_special_tokens` setting.
sequence: Sequence,
/// Stop tokens and stop sequences this decoder reacts to.
config: StopSequenceConfig,
/// Matcher over all non-empty stop sequences (hidden patterns first, then
/// visible ones); `None` when there are no such sequences.
aho_corasick: Option<AhoCorasick>,
/// Number of hidden patterns in the matcher: a match whose pattern index is
/// greater than or equal to this value belongs to a visible stop sequence.
visible_boundary_idx: usize,
/// Decoded text that has not been emitted yet because it may be the start
/// of a stop sequence.
jail_buffer: String,
/// Set once a stop condition has been hit; cleared again by `reset`.
stopped: bool,
}
impl StopSequenceDecoder {
    /// Creates a decoder that decodes with `tokenizer` and stops according to
    /// `config`.
    ///
    /// All non-empty stop sequences are compiled into one Aho-Corasick
    /// automaton, hidden sequences first and visible ones after them, so a
    /// match can be classified by its pattern index alone. Empty sequences are
    /// ignored; if none remain, no automaton is built.
    ///
    /// # Panics
    ///
    /// Panics if the Aho-Corasick automaton cannot be built from the patterns.
    pub fn new(
        tokenizer: Arc<dyn traits::Tokenizer>,
        config: StopSequenceConfig,
        skip_special_tokens: bool,
    ) -> Self {
        let hidden = config
            .stop_sequences
            .iter()
            .map(String::as_str)
            .filter(|pattern| !pattern.is_empty());
        let visible = config
            .visible_stop_sequences
            .iter()
            .map(String::as_str)
            .filter(|pattern| !pattern.is_empty());

        // Hidden patterns occupy indices `0..visible_boundary_idx`; visible
        // patterns follow.
        let mut patterns: Vec<&str> = hidden.collect();
        let visible_boundary_idx = patterns.len();
        patterns.extend(visible);

        let aho_corasick = if patterns.is_empty() {
            None
        } else {
            #[expect(
                clippy::expect_used,
                reason = "AhoCorasick::new with pre-filtered non-empty &str patterns is practically infallible"
            )]
            Some(AhoCorasick::new(patterns).expect("Failed to build Aho-Corasick automaton"))
        };

        Self {
            sequence: Sequence::new_with_options(tokenizer, skip_special_tokens),
            config,
            aho_corasick,
            visible_boundary_idx,
            jail_buffer: String::new(),
            stopped: false,
        }
    }

    /// Feeds one token to the decoder and reports what may be emitted.
    ///
    /// * Once the decoder has stopped, every call returns `Stopped`.
    /// * A hidden stop token stops the decoder; any held text is returned
    ///   without the stop token's own text.
    /// * A visible stop token stops the decoder; the held text followed by the
    ///   decoded stop token is returned.
    /// * Otherwise the token is appended to the sequence and its text is added
    ///   to the jail buffer. If the buffer now contains a stop sequence, the
    ///   decoder stops and returns the text before the match (hidden) or up to
    ///   the end of the match (visible).
    /// * If there is no full match, the longest buffer suffix that is a proper
    ///   prefix of some stop sequence stays held and everything before it is
    ///   returned as `Text`; `Held` is returned when nothing can be released.
    ///
    /// # Errors
    ///
    /// Propagates errors from decoding a visible stop token and from appending
    /// the token to the underlying sequence.
    pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
        if self.stopped {
            return Ok(SequenceDecoderOutput::Stopped);
        }

        // Hidden stop token: release whatever was held, never the token itself.
        if self.config.stop_tokens.contains(&token_id) {
            self.stopped = true;
            let pending = std::mem::take(&mut self.jail_buffer);
            return Ok(if pending.is_empty() {
                SequenceDecoderOutput::Stopped
            } else {
                SequenceDecoderOutput::StoppedWithText(pending)
            });
        }

        // Visible stop token: release the held text plus the token's own text.
        if self.config.visible_stop_tokens.contains(&token_id) {
            self.stopped = true;
            let skip_special = self.sequence.skip_special_tokens();
            let stop_text = self.sequence.tokenizer().decode(&[token_id], skip_special)?;
            let output = format!("{}{}", self.jail_buffer, stop_text);
            self.jail_buffer.clear();
            return Ok(SequenceDecoderOutput::StoppedWithText(output));
        }

        let new_text = self.sequence.append_token(token_id)?;
        self.jail_buffer.push_str(&new_text);

        // A complete stop sequence inside the buffer ends decoding.
        let full_match = self
            .aho_corasick
            .as_ref()
            .and_then(|automaton| automaton.find(&self.jail_buffer));
        if let Some(found) = full_match {
            self.stopped = true;
            let buffered = std::mem::take(&mut self.jail_buffer);
            let match_is_visible = found.pattern().as_usize() >= self.visible_boundary_idx;
            return Ok(if match_is_visible {
                SequenceDecoderOutput::StoppedWithText(buffered[..found.end()].to_string())
            } else {
                match &buffered[..found.start()] {
                    "" => SequenceDecoderOutput::Stopped,
                    before => SequenceDecoderOutput::StoppedWithText(before.to_string()),
                }
            });
        }

        // No full match: keep only the part that might still grow into a stop
        // sequence and release everything in front of it.
        let released = match self.partial_match_start() {
            Some(hold_from) => {
                let held = self.jail_buffer.split_off(hold_from);
                std::mem::replace(&mut self.jail_buffer, held)
            }
            None => std::mem::take(&mut self.jail_buffer),
        };
        Ok(if released.is_empty() {
            SequenceDecoderOutput::Held
        } else {
            SequenceDecoderOutput::Text(released)
        })
    }

    /// Returns the smallest byte offset in the jail buffer at which a suffix
    /// begins that is a proper prefix of some stop sequence (hidden or
    /// visible), or `None` if no suffix qualifies.
    fn partial_match_start(&self) -> Option<usize> {
        let buffer = self.jail_buffer.as_str();
        if buffer.is_empty() {
            return None;
        }

        self.config
            .stop_sequences
            .iter()
            .chain(&self.config.visible_stop_sequences)
            // A sequence of at most one byte has no non-empty proper prefix.
            .filter(|stop_seq| stop_seq.len() > 1)
            .filter_map(|stop_seq| {
                let longest = buffer.len().min(stop_seq.len() - 1);
                // Try the longest candidate suffix first, so the first hit is
                // the earliest start for this stop sequence.
                (1..=longest)
                    .rev()
                    .map(|suffix_len| buffer.len() - suffix_len)
                    .find(|&start| {
                        // Offsets inside a multi-byte character cannot be sliced.
                        buffer.is_char_boundary(start) && stop_seq.starts_with(&buffer[start..])
                    })
            })
            .min()
    }

    /// Processes `token_ids` in order with [`Self::process_token`] and collects
    /// one output per token.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `process_token`; later tokens are
    /// not processed.
    pub fn process_tokens(
        &mut self,
        token_ids: &[TokenIdType],
    ) -> Result<Vec<SequenceDecoderOutput>> {
        token_ids
            .iter()
            .map(|&token_id| self.process_token(token_id))
            .collect()
    }

    /// Drains the jail buffer and returns its contents as `Text`; the text is
    /// empty when nothing was being held.
    pub fn flush(&mut self) -> SequenceDecoderOutput {
        SequenceDecoderOutput::Text(std::mem::take(&mut self.jail_buffer))
    }

    /// Reports whether a stop condition has been reached since construction or
    /// the last [`Self::reset`].
    pub fn is_stopped(&self) -> bool {
        self.stopped
    }

    /// Returns the decoder to its initial state: clears the stopped flag, the
    /// underlying sequence and any held text. The configuration is kept.
    pub fn reset(&mut self) {
        self.stopped = false;
        self.sequence.clear();
        self.jail_buffer.clear();
    }
}
/// Builder for `StopSequenceDecoder`: collects stop conditions and decoding
/// options, then creates the decoder with `build`.
pub struct StopSequenceDecoderBuilder {
/// Tokenizer handed to the decoder on `build`.
tokenizer: Arc<dyn traits::Tokenizer>,
/// Stop tokens and stop sequences accumulated so far.
config: StopSequenceConfig,
/// Passed through to `StopSequenceDecoder::new`; `new` initialises it to `true`.
skip_special_tokens: bool,
}
impl StopSequenceDecoderBuilder {
    /// Starts a builder for `tokenizer` with no stop conditions and
    /// `skip_special_tokens` set to `true`.
    #[must_use]
    pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        Self {
            tokenizer,
            config: StopSequenceConfig::default(),
            skip_special_tokens: true,
        }
    }

    /// Adds `token_id` as a hidden stop token.
    ///
    /// Like every builder method this consumes `self`; the returned builder
    /// must be used or the configuration is lost.
    #[must_use]
    pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
        self.config.stop_tokens.insert(token_id);
        self
    }

    /// Adds `sequence` as a hidden stop sequence.
    #[must_use]
    pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.config.stop_sequences.push(sequence.into());
        self
    }

    /// Adds `token_id` as a visible stop token.
    #[must_use]
    pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
        self.config.visible_stop_tokens.insert(token_id);
        self
    }

    /// Adds `sequence` as a visible stop sequence.
    #[must_use]
    pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.config.visible_stop_sequences.push(sequence.into());
        self
    }

    /// Sets the `skip_special_tokens` flag that is passed to the decoder.
    #[must_use]
    pub fn skip_special_tokens(mut self, skip: bool) -> Self {
        self.skip_special_tokens = skip;
        self
    }

    /// Consumes the builder and creates the decoder via
    /// `StopSequenceDecoder::new`.
    #[must_use]
    pub fn build(self) -> StopSequenceDecoder {
        StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::StopSequenceDecoderBuilder;
    use crate::{
        mock::MockTokenizer, SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder,
    };

    /// Builds a decoder over a fresh `MockTokenizer` with
    /// `skip_special_tokens = false`, the setup shared by the tests below.
    fn decoder_with(config: StopSequenceConfig) -> StopSequenceDecoder {
        StopSequenceDecoder::new(Arc::new(MockTokenizer::new()), config, false)
    }

    #[test]
    fn test_stop_token_detection() {
        let mut decoder = decoder_with(StopSequenceConfig::default().with_stop_token(999));

        let result = decoder.process_token(1).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));

        let result = decoder.process_token(999).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);

        // Once stopped, further tokens keep reporting `Stopped`.
        let result = decoder.process_token(2).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_token() {
        let mut decoder = decoder_with(StopSequenceConfig::default().with_visible_stop_token(999));

        let result = decoder.process_token(999).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
    }

    #[test]
    fn test_builder_pattern() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let decoder = StopSequenceDecoderBuilder::new(tokenizer)
            .stop_token(999)
            .stop_sequence("STOP")
            .visible_stop_token(1000)
            .skip_special_tokens(true)
            .build();

        assert!(!decoder.is_stopped());
    }

    #[test]
    fn test_incremental_decoding_no_repetition() {
        let mut decoder = decoder_with(StopSequenceConfig::default());

        let mut outputs = Vec::new();
        for token_id in [1, 2, 3] {
            if let SequenceDecoderOutput::Text(text) = decoder.process_token(token_id).unwrap() {
                outputs.push(text);
            }
        }

        assert_eq!(outputs.len(), 3);
        // No later chunk may repeat an earlier chunk.
        for (i, earlier) in outputs.iter().enumerate() {
            for later in &outputs[i + 1..] {
                assert!(!later.contains(earlier.as_str()));
            }
        }
    }

    #[test]
    fn test_stop_sequence_detection() {
        let mut decoder = decoder_with(StopSequenceConfig::default().with_stop_sequence("test"));

        decoder.process_token(1).unwrap();
        decoder.process_token(2).unwrap();
        let result = decoder.process_token(3).unwrap();

        assert!(matches!(
            result,
            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
        ));
    }

    #[test]
    fn test_flush_after_partial() {
        let mut decoder =
            decoder_with(StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH"));

        decoder.process_token(1).unwrap();
        let result = decoder.flush();

        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_reset_functionality() {
        let mut decoder = decoder_with(StopSequenceConfig::default().with_stop_token(999));

        decoder.process_token(1).unwrap();
        decoder.process_token(999).unwrap();
        assert!(decoder.is_stopped());

        decoder.reset();
        assert!(!decoder.is_stopped());

        let result = decoder.process_token(2).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_visible_stop_sequence() {
        let mut decoder =
            decoder_with(StopSequenceConfig::default().with_visible_stop_sequence("world"));

        decoder.process_token(1).unwrap();
        let result = decoder.process_token(2).unwrap();

        let SequenceDecoderOutput::StoppedWithText(text) = result else {
            panic!("Expected StoppedWithText with visible stop sequence");
        };
        assert!(text.contains("world"));
    }

    #[test]
    fn test_multiple_tokens_processing() {
        let mut decoder = decoder_with(StopSequenceConfig::default());

        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();

        assert_eq!(results.len(), 3);
        for result in results {
            assert!(matches!(
                result,
                SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
            ));
        }
    }

    #[test]
    fn test_stop_sequence_spanning_multiple_tokens() {
        let mut decoder =
            decoder_with(StopSequenceConfig::default().with_stop_sequence("Hello world"));

        let result1 = decoder.process_token(1).unwrap();
        assert!(
            matches!(result1, SequenceDecoderOutput::Held),
            "Expected Held while jail buffer is a prefix of the stop sequence, got {result1:?}"
        );
        assert!(
            !decoder.is_stopped(),
            "Decoder should not be stopped after a partial match"
        );

        let result2 = decoder.process_token(2).unwrap();
        assert_eq!(
            result2,
            SequenceDecoderOutput::Stopped,
            "Expected Stopped when jail buffer matches the hidden stop sequence"
        );
        assert!(
            decoder.is_stopped(),
            "Decoder should be stopped after the full stop sequence match"
        );

        let result3 = decoder.process_token(3).unwrap();
        assert_eq!(result3, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_sequence_spanning_multiple_tokens() {
        let mut decoder =
            decoder_with(StopSequenceConfig::default().with_visible_stop_sequence("Hello world"));

        let result1 = decoder.process_token(1).unwrap();
        assert!(
            matches!(result1, SequenceDecoderOutput::Held),
            "Expected Held for partial visible stop sequence match, got {result1:?}"
        );

        let result2 = decoder.process_token(2).unwrap();
        match &result2 {
            SequenceDecoderOutput::StoppedWithText(text) => {
                assert!(
                    text.contains("Hello world"),
                    "Visible stop output should contain the full stop sequence, got: {text:?}"
                );
            }
            other => panic!("Expected StoppedWithText for visible stop sequence, got {other:?}"),
        }
        assert!(decoder.is_stopped());
    }

    #[test]
    fn test_stop_sequence_spanning_tokens_with_preceding_text() {
        let mut decoder =
            decoder_with(StopSequenceConfig::default().with_stop_sequence("Hello world"));

        let result1 = decoder.process_token(3).unwrap();
        assert!(
            matches!(result1, SequenceDecoderOutput::Text(_)),
            "Expected Text for token with no stop sequence overlap, got {result1:?}"
        );

        let result2 = decoder.process_token(1).unwrap();
        match &result2 {
            SequenceDecoderOutput::Text(text) => {
                assert!(
                    !text.contains("Hello"),
                    "Partially-matched 'Hello' should be jailed, not emitted. Got: {text:?}"
                );
            }
            // Holding everything back is equally acceptable here.
            SequenceDecoderOutput::Held => {}
            other => panic!("Expected Text (prefix before partial match) or Held, got {other:?}"),
        }

        let result3 = decoder.process_token(2).unwrap();
        assert!(
            matches!(
                result3,
                SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
            ),
            "Expected Stopped or StoppedWithText when stop sequence completes, got {result3:?}"
        );
        assert!(decoder.is_stopped());
    }

    // The UTF-8 tests below only require that processing never fails (for
    // example by slicing inside a multi-byte character) when the stop sequence
    // contains multi-byte characters.

    #[test]
    fn test_utf8_multibyte_character_boundaries() {
        let mut decoder = decoder_with(StopSequenceConfig::default().with_stop_sequence(" ×"));

        for token_id in [1, 2] {
            assert!(decoder.process_token(token_id).is_ok());
        }
    }

    #[test]
    fn test_utf8_multibyte_delta_character() {
        let mut decoder = decoder_with(StopSequenceConfig::default().with_stop_sequence("Δ"));

        for token_id in [1, 2] {
            assert!(decoder.process_token(token_id).is_ok());
        }
    }

    #[test]
    fn test_utf8_multibyte_degree_character() {
        let mut decoder = decoder_with(StopSequenceConfig::default().with_stop_sequence("°"));

        for token_id in [1, 2] {
            assert!(decoder.process_token(token_id).is_ok());
        }
    }

    #[test]
    fn test_utf8_multibyte_triangle_character() {
        let mut decoder = decoder_with(StopSequenceConfig::default().with_stop_sequence(" (∆"));

        for token_id in [1, 2, 3] {
            assert!(decoder.process_token(token_id).is_ok());
        }
    }

    #[test]
    fn test_utf8_multibyte_en_dash_character() {
        let mut decoder = decoder_with(StopSequenceConfig::default().with_stop_sequence(" –"));

        for token_id in [1, 2, 3] {
            assert!(decoder.process_token(token_id).is_ok());
        }
    }

    #[test]
    fn test_utf8_multibyte_various_characters() {
        let test_cases = [
            ("×", "multiplication sign - 2 bytes"),
            ("Δ", "Greek Delta - 2 bytes"),
            ("°", "degree sign - 2 bytes"),
            ("∆", "increment - 3 bytes"),
            ("–", "en dash - 3 bytes"),
            ("€", "euro sign - 3 bytes"),
            ("中", "Chinese character - 3 bytes"),
            ("🚀", "rocket emoji - 4 bytes"),
            ("💡", "lightbulb emoji - 4 bytes"),
        ];

        for (stop_char, description) in test_cases {
            let mut decoder =
                decoder_with(StopSequenceConfig::default().with_stop_sequence(stop_char));

            for token_id in 1..=5 {
                let result = decoder.process_token(token_id);
                assert!(
                    result.is_ok(),
                    "Failed on {description} with token {token_id}"
                );
            }
        }
    }
}