use std::{collections::HashSet, sync::Arc};
use anyhow::Result;
use crate::{
sequence::Sequence,
traits::{self, TokenIdType},
};
/// Result of feeding one token to a [`StopSequenceDecoder`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SequenceDecoderOutput {
    /// Text that is safe to emit now.
    Text(String),
    /// Nothing to emit for this step: the newly decoded text was empty, or it is being
    /// held back because it may be the beginning of a stop sequence.
    Held,
    /// Decoding has stopped and there is no further text to emit. This is also returned
    /// for every token processed after the decoder has stopped.
    Stopped,
    /// Decoding has stopped; the payload is the final text to emit.
    StoppedWithText(String),
}
/// Stop conditions applied by a [`StopSequenceDecoder`].
///
/// "Hidden" stops (`stop_tokens`, `stop_sequences`) end the stream without emitting the
/// stop text itself; "visible" stops (`visible_stop_tokens`, `visible_stop_sequences`)
/// end the stream and include the stop text in the final output.
#[derive(Clone, Debug, Default)]
pub struct StopSequenceConfig {
    /// Token ids that end decoding; the token's own text is not emitted.
    pub stop_tokens: HashSet<TokenIdType>,
    /// Strings that end decoding; output is cut just before the match.
    pub stop_sequences: Vec<String>,
    /// Token ids that end decoding; the token's decoded text is emitted.
    pub visible_stop_tokens: HashSet<TokenIdType>,
    /// Strings that end decoding; output is cut just after the match.
    pub visible_stop_sequences: Vec<String>,
}
impl StopSequenceConfig {
    /// Returns the config with `token_id` added as a hidden stop token
    /// (decoding stops and the token's text is not emitted).
    #[must_use]
    pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
        self.stop_tokens.insert(token_id);
        self
    }

    /// Returns the config with `sequence` appended to the hidden stop sequences
    /// (output is cut just before the match).
    #[must_use]
    pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.stop_sequences.push(sequence.into());
        self
    }

    /// Returns the config with `token_id` added as a visible stop token
    /// (decoding stops and the token's decoded text is emitted).
    #[must_use]
    pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
        self.visible_stop_tokens.insert(token_id);
        self
    }

    /// Returns the config with `sequence` appended to the visible stop sequences
    /// (output is cut just after the match).
    #[must_use]
    pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.visible_stop_sequences.push(sequence.into());
        self
    }
}
/// Incremental detokenizer that applies a [`StopSequenceConfig`] to a token stream.
///
/// Decoded text that could be the beginning of a stop sequence is held back ("jailed")
/// until later tokens either complete the stop sequence or rule it out.
pub struct StopSequenceDecoder {
    /// Receives every non-stop token and yields its newly decoded text.
    sequence: Sequence,
    /// Stop conditions checked for every token.
    config: StopSequenceConfig,
    /// Decoded text that has not been emitted yet.
    jail_buffer: String,
    /// Set once any stop condition fires; cleared only by `reset`.
    stopped: bool,
}
impl StopSequenceDecoder {
    /// Creates a decoder that detokenizes with `tokenizer` and stops according to `config`.
    ///
    /// `skip_special_tokens` is passed to the underlying [`Sequence`]; visible stop tokens
    /// are decoded with that sequence's `skip_special_tokens()` setting as well.
    pub fn new(
        tokenizer: Arc<dyn traits::Tokenizer>,
        config: StopSequenceConfig,
        skip_special_tokens: bool,
    ) -> Self {
        StopSequenceDecoder {
            sequence: Sequence::new_with_options(tokenizer, skip_special_tokens),
            config,
            jail_buffer: String::new(),
            stopped: false,
        }
    }

    /// Feeds one token and returns the text that may be emitted for it.
    ///
    /// Stop tokens are checked before the token is decoded. Otherwise the token's text is
    /// appended to the jail buffer and searched for complete stop sequences. The match that
    /// starts earliest in the buffer wins. Ties keep the configured priority: hidden before
    /// visible, then list order. Empty stop sequences are ignored. Trailing text that could
    /// still grow into a stop sequence stays jailed, and everything before it is returned.
    ///
    /// Once a stop condition has fired, every later call returns
    /// [`SequenceDecoderOutput::Stopped`] until [`reset`](Self::reset) is called.
    ///
    /// # Errors
    ///
    /// Propagates errors from `Sequence::append_token` and from decoding a visible stop
    /// token with the tokenizer.
    pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
        if self.stopped {
            return Ok(SequenceDecoderOutput::Stopped);
        }

        // Hidden stop token: flush whatever was jailed, but not the token's own text.
        if self.config.stop_tokens.contains(&token_id) {
            self.stopped = true;
            let held = std::mem::take(&mut self.jail_buffer);
            return Ok(Self::stopped_with(held));
        }

        // Visible stop token: flush the jail followed by the token's decoded text.
        if self.config.visible_stop_tokens.contains(&token_id) {
            self.stopped = true;
            let stop_text = self
                .sequence
                .tokenizer()
                .decode(&[token_id], self.sequence.skip_special_tokens())?;
            let output = format!("{}{}", self.jail_buffer, stop_text);
            self.jail_buffer.clear();
            return Ok(SequenceDecoderOutput::StoppedWithText(output));
        }

        let new_text = self.sequence.append_token(token_id)?;
        self.jail_buffer.push_str(&new_text);

        // A complete stop sequence ends the stream; emit only the text up to the cut point.
        if let Some(cut) = self.earliest_stop_cut() {
            self.stopped = true;
            let mut output = std::mem::take(&mut self.jail_buffer);
            output.truncate(cut);
            return Ok(Self::stopped_with(output));
        }

        // Keep any suffix that may still become a stop sequence; release the rest.
        let output = match self.partial_stop_start() {
            Some(split) => {
                let held = self.jail_buffer.split_off(split);
                std::mem::replace(&mut self.jail_buffer, held)
            }
            None => std::mem::take(&mut self.jail_buffer),
        };
        Ok(if output.is_empty() {
            SequenceDecoderOutput::Held
        } else {
            SequenceDecoderOutput::Text(output)
        })
    }

    /// Finds the earliest complete stop-sequence match in the jail buffer and returns the
    /// byte offset at which the output must be cut. For a hidden sequence that is the start
    /// of the match; for a visible one it is the end of the match.
    ///
    /// Matches are ranked by start offset. `min_by_key` keeps the first of equal minima, so
    /// ties resolve hidden-before-visible and then in configured order. Empty sequences
    /// would match everywhere and are skipped.
    fn earliest_stop_cut(&self) -> Option<usize> {
        let hidden = self.config.stop_sequences.iter().map(|seq| (seq, false));
        let visible = self.config.visible_stop_sequences.iter().map(|seq| (seq, true));
        hidden
            .chain(visible)
            .filter(|(seq, _)| !seq.is_empty())
            .filter_map(|(seq, keep_match)| {
                self.jail_buffer.find(seq.as_str()).map(|start| {
                    let cut = if keep_match { start + seq.len() } else { start };
                    (start, cut)
                })
            })
            .min_by_key(|&(start, _)| start)
            .map(|(_, cut)| cut)
    }

    /// Returns the smallest byte offset at which a suffix of the jail buffer is a proper
    /// prefix of some stop sequence (hidden or visible). Text from that offset on must stay
    /// jailed because later tokens could still complete the stop sequence.
    fn partial_stop_start(&self) -> Option<usize> {
        let buffer = self.jail_buffer.as_str();
        self.config
            .stop_sequences
            .iter()
            .chain(&self.config.visible_stop_sequences)
            .filter_map(|seq| {
                // A proper prefix is at most `len - 1` bytes; try the longest candidate first
                // so the first hit is the earliest start for this sequence.
                let max_len = buffer.len().min(seq.len().saturating_sub(1));
                (1..=max_len)
                    .rev()
                    .map(|len| buffer.len() - len)
                    .filter(|&start| buffer.is_char_boundary(start))
                    .find(|&start| seq.starts_with(&buffer[start..]))
            })
            .min()
    }

    /// Wraps the final text of a stopped stream, using `Stopped` when there is nothing left.
    fn stopped_with(text: String) -> SequenceDecoderOutput {
        if text.is_empty() {
            SequenceDecoderOutput::Stopped
        } else {
            SequenceDecoderOutput::StoppedWithText(text)
        }
    }

    /// Feeds each token in order and collects one output per token.
    ///
    /// # Errors
    ///
    /// Returns the first error from [`process_token`](Self::process_token); tokens after
    /// the failing one are not processed.
    pub fn process_tokens(
        &mut self,
        token_ids: &[TokenIdType],
    ) -> Result<Vec<SequenceDecoderOutput>> {
        let mut outputs = Vec::with_capacity(token_ids.len());
        for &token_id in token_ids {
            outputs.push(self.process_token(token_id)?);
        }
        Ok(outputs)
    }

    /// Releases any jailed text as [`SequenceDecoderOutput::Text`] (empty if nothing was
    /// held). The stopped state is not changed.
    pub fn flush(&mut self) -> SequenceDecoderOutput {
        SequenceDecoderOutput::Text(std::mem::take(&mut self.jail_buffer))
    }

    /// Returns `true` once a stop condition has fired and until [`reset`](Self::reset).
    pub fn is_stopped(&self) -> bool {
        self.stopped
    }

    /// Clears the jail buffer, the underlying sequence and the stopped flag so the decoder
    /// can be reused for a new stream.
    pub fn reset(&mut self) {
        self.jail_buffer.clear();
        self.sequence.clear();
        self.stopped = false;
    }
}
/// Fluent builder for a [`StopSequenceDecoder`].
pub struct StopSequenceDecoderBuilder {
    /// Tokenizer handed to the decoder on `build`.
    tokenizer: Arc<dyn traits::Tokenizer>,
    /// Stop conditions accumulated so far.
    config: StopSequenceConfig,
    /// Forwarded to the decoder; `new` initializes it to `true`.
    skip_special_tokens: bool,
}
impl StopSequenceDecoderBuilder {
    /// Starts a builder with an empty [`StopSequenceConfig`] and special tokens skipped.
    #[must_use]
    pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        StopSequenceDecoderBuilder {
            tokenizer,
            config: StopSequenceConfig::default(),
            skip_special_tokens: true,
        }
    }

    /// Adds a hidden stop token; see [`StopSequenceConfig::with_stop_token`].
    #[must_use]
    pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
        self.config = self.config.with_stop_token(token_id);
        self
    }

    /// Adds a hidden stop sequence; see [`StopSequenceConfig::with_stop_sequence`].
    #[must_use]
    pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.config = self.config.with_stop_sequence(sequence);
        self
    }

    /// Adds a visible stop token; see [`StopSequenceConfig::with_visible_stop_token`].
    #[must_use]
    pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
        self.config = self.config.with_visible_stop_token(token_id);
        self
    }

    /// Adds a visible stop sequence; see [`StopSequenceConfig::with_visible_stop_sequence`].
    #[must_use]
    pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.config = self.config.with_visible_stop_sequence(sequence);
        self
    }

    /// Sets whether special tokens are skipped when decoding.
    #[must_use]
    pub fn skip_special_tokens(mut self, skip: bool) -> Self {
        self.skip_special_tokens = skip;
        self
    }

    /// Consumes the builder and creates the configured [`StopSequenceDecoder`].
    #[must_use]
    pub fn build(self) -> StopSequenceDecoder {
        StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::{
        SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder,
        StopSequenceDecoderBuilder,
    };
    use crate::{mock::MockTokenizer, traits::TokenIdType};

    /// Builds a decoder over a fresh mock tokenizer, keeping special tokens.
    fn new_decoder(config: StopSequenceConfig) -> StopSequenceDecoder {
        StopSequenceDecoder::new(Arc::new(MockTokenizer::new()), config, false)
    }

    /// Streams `token_ids` through a decoder that stops on `stop_sequence` and asserts
    /// every step succeeds; `label` identifies the case in failure messages.
    fn assert_stream_ok(stop_sequence: &str, token_ids: &[TokenIdType], label: &str) {
        let mut decoder =
            new_decoder(StopSequenceConfig::default().with_stop_sequence(stop_sequence));
        for &token_id in token_ids {
            let result = decoder.process_token(token_id);
            assert!(result.is_ok(), "Failed on {} with token {}", label, token_id);
        }
    }

    #[test]
    fn test_stop_token_detection() {
        let mut decoder = new_decoder(StopSequenceConfig::default().with_stop_token(999));
        let result = decoder.process_token(1).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
        let result = decoder.process_token(999).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);
        // Tokens after a stop are ignored.
        let result = decoder.process_token(2).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_token() {
        let mut decoder =
            new_decoder(StopSequenceConfig::default().with_visible_stop_token(999));
        let result = decoder.process_token(999).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
    }

    #[test]
    fn test_builder_pattern() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let decoder = StopSequenceDecoderBuilder::new(tokenizer)
            .stop_token(999)
            .stop_sequence("STOP")
            .visible_stop_token(1000)
            .skip_special_tokens(true)
            .build();
        assert!(!decoder.is_stopped());
    }

    #[test]
    fn test_incremental_decoding_no_repetition() {
        let mut decoder = new_decoder(StopSequenceConfig::default());
        let mut outputs = Vec::new();
        for token_id in 1..=3 {
            if let SequenceDecoderOutput::Text(text) = decoder.process_token(token_id).unwrap() {
                outputs.push(text);
            }
        }
        assert_eq!(outputs.len(), 3);
        // No later chunk may repeat an earlier one (each step emits only new text).
        for (i, earlier) in outputs.iter().enumerate() {
            for later in &outputs[i + 1..] {
                assert!(!later.contains(earlier.as_str()));
            }
        }
    }

    #[test]
    fn test_stop_sequence_detection() {
        let mut decoder = new_decoder(StopSequenceConfig::default().with_stop_sequence("test"));
        decoder.process_token(1).unwrap();
        decoder.process_token(2).unwrap();
        let result = decoder.process_token(3).unwrap();
        assert!(matches!(
            result,
            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
        ));
    }

    #[test]
    fn test_flush_after_partial() {
        let mut decoder =
            new_decoder(StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH"));
        decoder.process_token(1).unwrap();
        let result = decoder.flush();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_reset_functionality() {
        let mut decoder = new_decoder(StopSequenceConfig::default().with_stop_token(999));
        decoder.process_token(1).unwrap();
        decoder.process_token(999).unwrap();
        assert!(decoder.is_stopped());
        decoder.reset();
        assert!(!decoder.is_stopped());
        let result = decoder.process_token(2).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_visible_stop_sequence() {
        let mut decoder =
            new_decoder(StopSequenceConfig::default().with_visible_stop_sequence("world"));
        decoder.process_token(1).unwrap();
        let result = decoder.process_token(2).unwrap();
        let SequenceDecoderOutput::StoppedWithText(text) = result else {
            panic!("Expected StoppedWithText with visible stop sequence");
        };
        assert!(text.contains("world"));
    }

    #[test]
    fn test_multiple_tokens_processing() {
        let mut decoder = new_decoder(StopSequenceConfig::default());
        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();
        assert_eq!(results.len(), 3);
        for result in results {
            assert!(matches!(
                result,
                SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
            ));
        }
    }

    #[test]
    fn test_utf8_multibyte_character_boundaries() {
        assert_stream_ok(" ×", &[1, 2], "space + multiplication sign");
    }

    #[test]
    fn test_utf8_multibyte_delta_character() {
        assert_stream_ok("Δ", &[1, 2], "Greek Delta");
    }

    #[test]
    fn test_utf8_multibyte_degree_character() {
        assert_stream_ok("°", &[1, 2], "degree sign");
    }

    #[test]
    fn test_utf8_multibyte_triangle_character() {
        assert_stream_ok(" (∆", &[1, 2, 3], "space + paren + increment");
    }

    #[test]
    fn test_utf8_multibyte_en_dash_character() {
        assert_stream_ok(" –", &[1, 2, 3], "space + en dash");
    }

    #[test]
    fn test_utf8_multibyte_various_characters() {
        let test_cases = [
            ("×", "multiplication sign - 2 bytes"),
            ("Δ", "Greek Delta - 2 bytes"),
            ("°", "degree sign - 2 bytes"),
            ("∆", "increment - 3 bytes"),
            ("–", "en dash - 3 bytes"),
            ("€", "euro sign - 3 bytes"),
            ("中", "Chinese character - 3 bytes"),
            ("🚀", "rocket emoji - 4 bytes"),
            ("💡", "lightbulb emoji - 4 bytes"),
        ];
        for &(stop_char, description) in &test_cases {
            assert_stream_ok(stop_char, &[1, 2, 3, 4, 5], description);
        }
    }
}