/// Built-in stop patterns for chat-style output.
///
/// Intended to be passed as the `extras` argument of `merge_stop_patterns`
/// (the unit tests in this file do exactly that) so they are appended to any
/// user-supplied stop sequences.
///
/// NOTE(review): these look like the turn / end-of-sequence markers of common
/// chat templates (ChatML `<|im_*|>`, `<|endoftext|>`, `</s>`, `<end_of_turn>`);
/// confirm against the models actually served.
pub(crate) const CHAT_STOP_PATTERNS: &[&str] = &[
"<|im_end|>",
"<|im_start|>",
"<|endoftext|>",
"</s>",
"<end_of_turn>",
];
/// Variants of the chat stop patterns with the leading `<` missing.
///
/// Each entry equals an entry of `CHAT_STOP_PATTERNS` with its first `<`
/// removed; `</s>` has no counterpart in this list.
///
/// NOTE(review): presumably used to catch markers whose opening `<` was lost
/// or emitted separately by the model/tokenizer; confirm with the call sites,
/// which are not visible in this file section.
pub(crate) const CHAT_STOP_PATTERNS_BROKEN: &[&str] =
&["|im_end|>", "|im_start|>", "|endoftext|>", "end_of_turn>"];
/// Combines user-supplied stop patterns with additional built-in ones.
///
/// The result starts with every entry of `user`, in order and unmodified
/// (duplicates inside `user` are kept). Each entry of `extras` is then
/// appended in order, unless an equal string is already present in the
/// result at that point.
pub(crate) fn merge_stop_patterns<S: AsRef<str>>(user: &[S], extras: &[&str]) -> Vec<String> {
    let mut merged: Vec<String> = Vec::with_capacity(user.len() + extras.len());
    merged.extend(user.iter().map(|entry| entry.as_ref().to_owned()));

    for &extra in extras {
        let already_present = merged.iter().any(|existing| existing.as_str() == extra);
        if !already_present {
            merged.push(extra.to_owned());
        }
    }

    merged
}
/// Returns `text` with every `<think>…</think>` block removed.
///
/// Blocks are removed one at a time, always restarting the search from the
/// beginning of the partially cleaned string. If an opening `<think>` has no
/// `</think>` after it, everything from that opener to the end is discarded.
/// A `</think>` without a preceding opener is left untouched.
pub(crate) fn strip_thinking_tags(text: &str) -> String {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    let mut cleaned = String::from(text);
    loop {
        let open_at = match cleaned.find(OPEN_TAG) {
            Some(idx) => idx,
            // No opener left: nothing more to strip.
            None => return cleaned,
        };
        match cleaned[open_at..].find(CLOSE_TAG) {
            Some(rel_close) => {
                let close_end = open_at + rel_close + CLOSE_TAG.len();
                cleaned.replace_range(open_at..close_end, "");
            }
            None => {
                // Unterminated block: drop the opener and everything after it.
                cleaned.truncate(open_at);
                return cleaned;
            }
        }
    }
}
/// Cuts `text` at the earliest occurrence of any stop pattern.
///
/// Every pattern is searched for in `text`; the smallest match offset wins
/// and `text` is truncated to end just before that match.
///
/// Empty patterns are ignored: an empty string matches at offset 0 of any
/// text, so honouring it would erase the whole output and report a stop that
/// never happened.
///
/// Returns `true` when a pattern was found (and `text` was truncated),
/// `false` when `text` was left unchanged.
pub(crate) fn truncate_at_first_stop<S: AsRef<str>>(text: &mut String, patterns: &[S]) -> bool {
    let earliest = patterns
        .iter()
        .map(|pattern| pattern.as_ref())
        .filter(|pattern| !pattern.is_empty())
        .filter_map(|pattern| text.find(pattern))
        .min();

    match earliest {
        Some(cut_at) => {
            text.truncate(cut_at);
            true
        }
        None => false,
    }
}
/// Removes a trailing, incomplete stop pattern from `text`.
///
/// A "partial" is a non-empty proper prefix of a pattern (the complete pattern
/// is deliberately not considered here). Among all patterns, the longest
/// partial that `text` ends with is removed, so that e.g. `"##"` is trimmed
/// completely for the pattern `"###"` instead of leaving a stray `#` behind.
///
/// Prefixes are only taken at `char` boundaries, so patterns containing
/// multi-byte characters are handled without panicking.
///
/// Returns `true` when something was trimmed, `false` when `text` is unchanged.
pub(crate) fn trim_partial_stop_suffix<S: AsRef<str>>(text: &mut String, patterns: &[S]) -> bool {
    let longest_partial = patterns
        .iter()
        .filter_map(|pattern| {
            let p = pattern.as_ref();
            // Walk the char boundaries from the back so the first hit is the
            // longest proper prefix of this pattern; offset 0 is the empty prefix.
            p.char_indices()
                .map(|(offset, _)| offset)
                .rev()
                .find(|&offset| offset > 0 && text.ends_with(&p[..offset]))
        })
        .max();

    match longest_partial {
        Some(partial_len) => {
            let keep = text.len() - partial_len;
            text.truncate(keep);
            true
        }
        None => false,
    }
}
/// Incremental filter that turns raw streamed model output into text deltas
/// that are safe to forward.
///
/// Each call to [`StreamingTextFilter::push`] appends one raw chunk and returns
/// the text that may be forwarded right away. The filter
///
/// * suppresses `<think>…</think>` blocks,
/// * stops permanently at the first configured stop pattern, and
/// * withholds a trailing fragment that could still grow into a stop pattern
///   or into a `<think>` opener once more chunks arrive.
pub(crate) struct StreamingTextFilter {
    /// Stop sequences that end the stream. Empty strings are ignored.
    stop_patterns: Vec<String>,
    /// Text received so far with closed think blocks removed. While a think
    /// block is open, its raw text is kept here starting at `last_emitted_len`.
    cumulative_text: String,
    /// Byte length of the prefix of `cumulative_text` that counts as emitted.
    /// Always a `char` boundary and never larger than `cumulative_text.len()`.
    last_emitted_len: usize,
    /// `true` between an opening `<think>` and its closing `</think>`.
    inside_think_block: bool,
    /// Latched once a stop pattern has been seen; later chunks are ignored.
    hit_stop_pattern: bool,
}

impl StreamingTextFilter {
    /// Creates a filter that stops at any of `stop_patterns`.
    pub fn new(stop_patterns: Vec<String>) -> Self {
        Self {
            stop_patterns,
            cumulative_text: String::new(),
            last_emitted_len: 0,
            inside_think_block: false,
            hit_stop_pattern: false,
        }
    }

    /// Returns `true` once a stop pattern has been encountered.
    pub fn is_stopped(&self) -> bool {
        self.hit_stop_pattern
    }

    /// Returns the filtered text that counts as emitted so far.
    pub fn cumulative_emitted(&self) -> &str {
        &self.cumulative_text[..self.last_emitted_len]
    }

    /// Feeds one raw chunk into the filter.
    ///
    /// Returns `Some(delta)` with the newly releasable text, or `None` when
    /// nothing can be released yet (inside a think block, only a possible
    /// partial stop/think marker pending, or the stream already stopped).
    ///
    /// When a stop pattern completes, the not-yet-emitted text in front of it
    /// is returned as the final delta and [`is_stopped`](Self::is_stopped)
    /// becomes `true`; the pattern itself and everything after it is dropped.
    ///
    /// Text that arrives in the same chunk in front of a `<think>` opener and
    /// has not been returned yet is counted as emitted without being returned;
    /// the existing unit tests pin this behaviour, so it is kept as is.
    pub fn push(&mut self, chunk: &str) -> Option<String> {
        const THINK_OPEN: &str = "<think>";
        const THINK_CLOSE: &str = "</think>";

        if self.hit_stop_pattern {
            return None;
        }
        self.cumulative_text.push_str(chunk);

        // Resolve think blocks. While a block is open, `last_emitted_len` is the
        // offset of its opener, so both searches start there; tags lying in
        // already emitted text can therefore never confuse the state machine.
        loop {
            if self.inside_think_block {
                let block_start = self.last_emitted_len;
                match self.cumulative_text[block_start..].find(THINK_CLOSE) {
                    Some(rel_close) => {
                        let block_end = block_start + rel_close + THINK_CLOSE.len();
                        self.cumulative_text.replace_range(block_start..block_end, "");
                        self.inside_think_block = false;
                    }
                    // Still reasoning: keep everything suppressed.
                    None => return None,
                }
            }
            match self.cumulative_text[self.last_emitted_len..].find(THINK_OPEN) {
                Some(rel_open) => {
                    self.last_emitted_len += rel_open;
                    self.inside_think_block = true;
                }
                None => break,
            }
        }

        // Earliest complete stop pattern, if any. Empty patterns would match at
        // offset 0 of any text and are skipped.
        let text = &self.cumulative_text;
        let stop_at = self
            .stop_patterns
            .iter()
            .filter(|pattern| !pattern.is_empty())
            .filter_map(|pattern| text.find(pattern.as_str()))
            .min();
        if let Some(stop_pos) = stop_at {
            self.hit_stop_pattern = true;
            self.cumulative_text.truncate(stop_pos);
            // Text that already went out cannot be recalled, so clamp instead
            // of leaving `last_emitted_len` beyond the truncated buffer.
            let pending_from = self.last_emitted_len.min(stop_pos);
            self.last_emitted_len = stop_pos;
            let tail = &self.cumulative_text[pending_from..];
            return if tail.is_empty() {
                None
            } else {
                Some(tail.to_string())
            };
        }

        // Withhold a trailing fragment that may still become a stop pattern or
        // a `<think>` opener.
        let total_len = self.cumulative_text.len();
        let stop_hold = find_potential_stop_start(&self.cumulative_text, &self.stop_patterns);
        let think_hold = (1..THINK_OPEN.len())
            .rev()
            .find(|&len| self.cumulative_text.ends_with(&THINK_OPEN[..len]))
            .map(|len| total_len - len);
        let safe_end = stop_hold
            .unwrap_or(total_len)
            .min(think_hold.unwrap_or(total_len));

        if safe_end > self.last_emitted_len {
            let delta = self.cumulative_text[self.last_emitted_len..safe_end].to_string();
            self.last_emitted_len = safe_end;
            Some(delta)
        } else {
            None
        }
    }
}

/// Finds where a possible, still incomplete stop pattern starts at the end of
/// `text`.
///
/// For every pattern the longest prefix (taken at `char` boundaries, up to the
/// whole pattern) that `text` ends with is determined; the smallest resulting
/// start offset over all patterns is returned. `None` means the end of `text`
/// cannot be the beginning of any pattern. Empty patterns never match.
fn find_potential_stop_start(text: &str, patterns: &[String]) -> Option<usize> {
    patterns
        .iter()
        .filter_map(|pattern| {
            pattern
                .char_indices()
                // End offset of the prefix that finishes with this char.
                .map(|(offset, ch)| offset + ch.len_utf8())
                .rev()
                .find(|&prefix_end| text.ends_with(&pattern[..prefix_end]))
                .map(|prefix_end| text.len() - prefix_end)
        })
        .min()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a streaming filter from string-literal stop patterns.
    fn filter_with(stops: &[&str]) -> StreamingTextFilter {
        StreamingTextFilter::new(stops.iter().map(|s| s.to_string()).collect())
    }

    // ---- strip_thinking_tags -------------------------------------------

    #[test]
    fn strip_thinking_tags_removes_closed_blocks() {
        let cleaned = strip_thinking_tags("before<think>hidden</think>after");
        assert_eq!(cleaned, "beforeafter");
    }

    #[test]
    fn strip_thinking_tags_removes_multiple_blocks() {
        let cleaned = strip_thinking_tags("a<think>x</think>b<think>y</think>c");
        assert_eq!(cleaned, "abc");
    }

    #[test]
    fn strip_thinking_tags_truncates_unclosed_block() {
        let cleaned = strip_thinking_tags("visible<think>still reasoning");
        assert_eq!(cleaned, "visible");
    }

    #[test]
    fn strip_thinking_tags_passthrough_no_tags() {
        let cleaned = strip_thinking_tags("nothing to see");
        assert_eq!(cleaned, "nothing to see");
    }

    // ---- merge_stop_patterns -------------------------------------------

    #[test]
    fn merge_stop_patterns_deduplicates() {
        let user = [String::from("<|im_end|>"), String::from("CUSTOM")];
        let merged = merge_stop_patterns(&user, CHAT_STOP_PATTERNS);

        // User entries come first and keep their order.
        assert_eq!(merged[0], "<|im_end|>");
        assert_eq!(merged[1], "CUSTOM");
        // Built-ins are appended ...
        assert!(merged.iter().any(|p| p.as_str() == "<end_of_turn>"));
        // ... but never duplicate an entry the user already supplied.
        let im_end_count = merged
            .iter()
            .filter(|p| p.as_str() == "<|im_end|>")
            .count();
        assert_eq!(im_end_count, 1);
    }

    // ---- truncate_at_first_stop ----------------------------------------

    #[test]
    fn truncate_at_first_stop_picks_earliest() {
        let stops = ["<|im_end|>", "<end_of_turn>"];
        let mut output = "hello <end_of_turn> world <|im_end|>".to_string();
        let truncated = truncate_at_first_stop(&mut output, &stops);
        assert!(truncated);
        assert_eq!(output, "hello ");
    }

    #[test]
    fn truncate_at_first_stop_no_match() {
        let mut output = "no stops here".to_string();
        let truncated = truncate_at_first_stop(&mut output, &["<|im_end|>"]);
        assert!(!truncated);
        assert_eq!(output, "no stops here");
    }

    // ---- trim_partial_stop_suffix --------------------------------------

    #[test]
    fn trim_partial_stop_suffix_removes_partial_prefix() {
        let stops = ["<|im_end|>"];
        let mut output = "response tail <|im_".to_string();
        let trimmed = trim_partial_stop_suffix(&mut output, &stops);
        assert!(trimmed);
        assert_eq!(output, "response tail ");
    }

    #[test]
    fn trim_partial_stop_suffix_ignores_clean_end() {
        let mut output = "clean response".to_string();
        let trimmed = trim_partial_stop_suffix(&mut output, &["<|im_end|>"]);
        assert!(!trimmed);
        assert_eq!(output, "clean response");
    }

    // ---- StreamingTextFilter -------------------------------------------

    #[test]
    fn streaming_filter_emits_safe_chunks() {
        let mut filter = filter_with(&["<|im_end|>"]);
        assert_eq!(filter.push("Hello ").as_deref(), Some("Hello "));
        assert_eq!(filter.push("world").as_deref(), Some("world"));
        assert_eq!(filter.cumulative_emitted(), "Hello world");
        assert!(!filter.is_stopped());
    }

    #[test]
    fn streaming_filter_holds_back_partial_stop_prefix() {
        let mut filter = filter_with(&["<|im_end|>"]);
        assert_eq!(filter.push("Hello ").as_deref(), Some("Hello "));
        // Could still become `<|im_end|>`, so it is withheld ...
        assert_eq!(filter.push("<|im_").as_deref(), None);
        // ... and released once it turns out to be ordinary text.
        assert_eq!(filter.push("portant!").as_deref(), Some("<|im_portant!"));
    }

    #[test]
    fn streaming_filter_stops_on_complete_pattern() {
        let mut filter = filter_with(&["<|im_end|>"]);
        filter.push("hello ");
        filter.push("<|im_end|>");
        assert!(filter.is_stopped());
        assert_eq!(filter.push(" ignored").as_deref(), None);
    }

    #[test]
    fn streaming_filter_suppresses_think_block() {
        let mut filter = filter_with(&[]);
        assert_eq!(filter.push("visible ").as_deref(), Some("visible "));
        for hidden in ["<think>", "reasoning", "</think>"] {
            assert_eq!(filter.push(hidden).as_deref(), None);
        }
        assert_eq!(filter.push("answer").as_deref(), Some("answer"));
    }

    #[test]
    fn streaming_filter_unclosed_think_stays_suppressed() {
        let mut filter = filter_with(&[]);
        for hidden in ["<think>", "still reasoning", " forever"] {
            assert_eq!(filter.push(hidden).as_deref(), None);
        }
        assert!(!filter.is_stopped());
        assert_eq!(filter.cumulative_emitted(), "");
    }

    #[test]
    fn streaming_filter_think_block_swallows_preceding_unemitted_text() {
        let mut filter = filter_with(&[]);
        assert_eq!(filter.push("scratch<think>").as_deref(), None);
        assert_eq!(filter.push("hidden</think>").as_deref(), None);
        assert_eq!(filter.push("final").as_deref(), Some("final"));
        // "scratch" counts as emitted although it was never returned by `push`.
        assert_eq!(filter.cumulative_emitted(), "scratchfinal");
    }

    #[test]
    fn streaming_filter_utf8_text_does_not_panic_on_ascii_stop_patterns() {
        let mut filter = filter_with(&["<|im_end|>"]);
        for chunk in ["héllo ", "wörld", "<|im_", "end|>"] {
            let _ = filter.push(chunk);
        }
        assert!(filter.is_stopped());
    }
}