// crates/atomcode-tuix/src/think.rs
/// Cap, in bytes, on the text a `ThinkStripper` keeps buffered between
/// feeds, so memory stays bounded (no OOM) on pathological streams.
pub const THINK_BUF_MAX: usize = 1 << 16; // 64 KiB

/// Streaming stripper for `<think>...</think>` and `<thinking>...</thinking>`
/// blocks.
///
/// Text is pushed in incrementally with `feed`, which returns the visible
/// text found outside think blocks. Tag matching is ASCII case-insensitive,
/// open tags may carry attributes (`<think key="v">`), and either close-tag
/// spelling ends a block. A tag split across two feeds is still recognized,
/// because text that might be the start of a tag is held back until the
/// next feed can classify it. Call `reset` between independent streams.
#[derive(Debug, Clone)]
pub struct ThinkStripper {
    /// Text held back between feeds. Outside a block it is a suspected
    /// partial open tag (e.g. `<thi`). Inside a block it is the short tail
    /// that might be the beginning of a close tag. Feeds are `&str` and the
    /// tail is only trimmed at char boundaries, so it is always valid UTF-8.
    carry: String,
    /// `true` while inside a `<think…>` block, waiting for `</think>` or
    /// `</thinking>`.
    inside: bool,
}

15impl Default for ThinkStripper {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl ThinkStripper {
22    pub fn new() -> Self {
23        Self {
24            carry: String::new(),
25            inside: false,
26        }
27    }
28
29    pub fn buffered_bytes(&self) -> usize {
30        self.carry.len()
31    }
32
33    /// Reset to the pristine state. Call between turns — otherwise an
34    /// unclosed `<think>` from a previous turn (model got cancelled, got
35    /// an error mid-stream, switched provider, etc.) leaves `inside=true`
36    /// and silently swallows every TextDelta of the next turn. Symptom:
37    /// user sees blank assistant bubbles even though the provider returned
38    /// normal text.
39    pub fn reset(&mut self) {
40        self.carry.clear();
41        self.inside = false;
42    }
43
44    /// Feed a chunk, return the visible portion (outside of think blocks).
45    pub fn feed(&mut self, delta: &str) -> String {
46        // Enforce cap: if carry + delta would exceed THINK_BUF_MAX,
47        // we flush the carry as-is (giving up on partial-tag detection)
48        // so memory stays bounded.
49        if self.carry.len() + delta.len() > THINK_BUF_MAX {
50            let mut flushed = std::mem::take(&mut self.carry);
51            flushed.push_str(delta);
52            if self.inside {
53                return String::new(); // still in block, discard
54            }
55            return flushed;
56        }
57
58        self.carry.push_str(delta);
59        let mut out = String::new();
60        self.drain_into(&mut out);
61        out
62    }
63
64    fn drain_into(&mut self, out: &mut String) {
65        loop {
66            if self.inside {
67                // Look for closing tag.
68                match find_close_tag(&self.carry) {
69                    Some((_close_start, close_end)) => {
70                        self.carry.drain(..close_end);
71                        self.inside = false;
72                        // continue loop to scan rest
73                    }
74                    None => {
75                        // keep everything buffered; we might still see </...>
76                        // but DO drop all but trailing 11 chars (len of "</thinking>")
77                        // to keep memory bounded under streaming.
78                        let keep = 11.min(self.carry.len());
79                        let drop_end = self.carry.len() - keep;
80                        let safe = prev_boundary(&self.carry, drop_end);
81                        self.carry.drain(..safe);
82                        return;
83                    }
84                }
85            } else {
86                match find_open_tag(&self.carry) {
87                    TagScan::None => {
88                        out.push_str(&self.carry);
89                        self.carry.clear();
90                        return;
91                    }
92                    TagScan::Complete { start, end } => {
93                        out.push_str(&self.carry[..start]);
94                        self.carry.drain(..end);
95                        self.inside = true;
96                    }
97                    TagScan::PartialAt(pos) => {
98                        out.push_str(&self.carry[..pos]);
99                        self.carry.drain(..pos);
100                        return;
101                    }
102                }
103            }
104        }
105    }
106}
107
/// Outcome of scanning text for a think open tag (see `find_open_tag`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TagScan {
    /// No `<` in the scanned text can begin an open tag.
    None,
    /// A complete open tag occupies bytes `start..end`.
    Complete { start: usize, end: usize },
    /// The `<` at this byte offset might begin an open tag once more input
    /// arrives. Everything before it is safe to emit.
    PartialAt(usize),
}

114/// Scan `s` for a `<think>` or `<thinking>` open tag (case-insensitive,
115/// tolerant of attributes).
116fn find_open_tag(s: &str) -> TagScan {
117    let mut search_start = 0;
118    while let Some(lt) = s[search_start..].find('<') {
119        let abs = search_start + lt;
120        let rest = &s[abs..];
121        if let Some(end) = parse_open_tag(rest) {
122            return TagScan::Complete {
123                start: abs,
124                end: abs + end,
125            };
126        }
127        // Could it be a partial prefix of an open tag?
128        let lower: String = rest.chars().map(|c| c.to_ascii_lowercase()).collect();
129        let could_be_partial = lower.len() < 9 && "<thinking".starts_with(lower.as_str())
130            || lower.len() < 6 && "<think".starts_with(lower.as_str())
131            || lower.starts_with("<think") && !lower.contains('>')
132            || lower.starts_with("<thinking") && !lower.contains('>');
133        if could_be_partial {
134            return TagScan::PartialAt(abs);
135        }
136        search_start = abs + 1;
137    }
138    TagScan::None
139}
140
/// If `s` begins with a complete `<think…>` or `<thinking…>` open tag
/// (ASCII case-insensitive), returns the byte offset just past its `>`.
///
/// The tag name must be followed directly by `>` or by ASCII whitespace. In
/// the whitespace case, everything up to the first `>` is accepted as
/// attributes. Anything else returns `None`, including a tag whose `>` has
/// not arrived yet.
///
/// NOTE: attribute values are matched by finding the next literal `>`.
/// A quoted attribute containing a literal `>` (e.g. `<think foo="a>b">`)
/// would confuse this parser. LLM-generated think tags in practice carry
/// no meaningful attributes, so this trade-off is intentional.
fn parse_open_tag(s: &str) -> Option<usize> {
    let head = s.as_bytes();
    // Try the longer spelling first so `<thinking` is not read as `<think`.
    let name_len = [&b"<thinking"[..], &b"<think"[..]]
        .iter()
        .find(|name| matches!(head.get(..name.len()), Some(h) if h.eq_ignore_ascii_case(name)))?
        .len();
    // The matched name is pure ASCII, so `name_len` is a char boundary.
    let tail = &s[name_len..];
    match tail.bytes().next()? {
        b'>' => Some(name_len + 1),
        b if b.is_ascii_whitespace() => tail.find('>').map(|gt| name_len + gt + 1),
        _ => None,
    }
}

/// Locate the first `</think>` or `</thinking>` close tag in `s` (ASCII
/// case-insensitive) and return its byte span as `(start, end)`.
/// Returns `None` when neither spelling occurs.
fn find_close_tag(s: &str) -> Option<(usize, usize)> {
    const CLOSERS: [&str; 2] = ["</thinking>", "</think>"];
    let haystack = s.as_bytes();
    // Every close tag starts with "</". Check candidates left to right, so
    // the first hit is the earliest close tag of either spelling.
    s.match_indices("</").find_map(|(start, _)| {
        CLOSERS.iter().find_map(|tag| {
            let end = start + tag.len();
            match haystack.get(start..end) {
                Some(window) if window.eq_ignore_ascii_case(tag.as_bytes()) => Some((start, end)),
                _ => None,
            }
        })
    })
}

/// Largest char boundary of `s` that is `<= idx`. An `idx` past the end
/// yields `s.len()`; 0 is always a boundary, so the result is at least 0.
fn prev_boundary(s: &str, idx: usize) -> usize {
    (0..=idx)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_tags_passes_through() {
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("hello world"), "hello world");
    }

    #[test]
    fn complete_block_in_one_feed() {
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("a<think>secret</think>b"), "ab");
    }

    #[test]
    fn tag_split_across_feeds() {
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("hello <thi"), "hello ");
        assert_eq!(s.feed("nk>secret</think> world"), " world");
    }

    #[test]
    fn utf8_boundary_at_feed_edge_no_panic() {
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("abc<thi"), "abc");
        assert_eq!(s.feed("nk>密</think>你好"), "你好");
    }

    #[test]
    fn case_insensitive_tag() {
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("<THINK>a</THINK>b"), "b");
        let mut s2 = ThinkStripper::new();
        assert_eq!(s2.feed("<Think>a</Think>b"), "b");
    }

    #[test]
    fn thinking_tag_also_stripped() {
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("<thinking>a</thinking>b"), "b");
    }

    #[test]
    fn tag_with_attributes() {
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("<think key=\"v\">a</think>b"), "b");
    }

    #[test]
    fn unclosed_block_capped_at_buf_limit() {
        let mut s = ThinkStripper::new();
        let junk = "x".repeat(100_000);
        let input = format!("<think>{}", junk);
        let _ = s.feed(&input);
        assert!(s.buffered_bytes() <= THINK_BUF_MAX);
    }

    #[test]
    fn literal_angle_bracket_outside_tag_preserved() {
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("a < b > c"), "a < b > c");
    }

    #[test]
    fn multiple_blocks() {
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("a<think>x</think>b<think>y</think>c"), "abc");
    }

    #[test]
    fn close_tag_split_across_feeds() {
        // The tail of the block that might start `</think>` must survive
        // until the next feed completes it.
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("<think>abc</thi"), "");
        assert_eq!(s.feed("nk>tail"), "tail");
    }

    #[test]
    fn open_and_close_names_may_differ() {
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("<think>x</thinking>y"), "y");
        let mut s2 = ThinkStripper::new();
        assert_eq!(s2.feed("<thinking>x</think>y"), "y");
    }

    #[test]
    fn tag_lookalike_is_released_once_disproved() {
        // "<thinker" could still be `<think …` until the `>` shows up and
        // proves it is not an open tag.
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("<thinker"), "");
        assert_eq!(s.feed(" is here>"), "<thinker is here>");
    }

    #[test]
    fn lone_angle_bracket_is_held_until_next_feed() {
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("a <"), "a ");
        assert_eq!(s.buffered_bytes(), 1);
        assert_eq!(s.feed("b"), "<b");
    }

    #[test]
    fn oversized_partial_open_tag_is_flushed_as_text() {
        // A `<think …` that never gets its `>` cannot be buffered forever;
        // past the cap it is emitted as plain text.
        let mut s = ThinkStripper::new();
        let input = format!("<think {}", "x".repeat(THINK_BUF_MAX));
        assert_eq!(s.feed(&input), input);
        assert_eq!(s.buffered_bytes(), 0);
    }

    /// Regression: an unclosed `<think>` from a previous turn would leave
    /// `inside=true` and swallow the entire next turn's text. The real-world
    /// trigger is a provider switch mid-turn (e.g. GLM → Kimi): GLM embeds
    /// thinking as `<think>…</think>` in content, Kimi routes it through
    /// `reasoning_content` (plain content with no `<think>` tag). If the
    /// GLM turn cancels with an open tag and no one calls `reset()`, every
    /// Kimi TextDelta afterward disappears — user sees blank assistant
    /// bubbles while datalog shows the LLM actually returned text.
    #[test]
    fn reset_clears_stuck_inside_state() {
        let mut s = ThinkStripper::new();
        // Turn 1: opens a think block but never closes it (stream ended /
        // got cancelled).
        let _ = s.feed("prefix <think>still thinking when we got cut");
        // Turn 2 from a different provider that doesn't use <think> tags.
        // Without reset, this text gets swallowed.
        assert_eq!(
            s.feed("hello from the next model"),
            "",
            "without reset, the stuck inside=true state swallows the text",
        );
        // Apply the fix.
        s.reset();
        // Turn 3: now the stripper is pristine and plain text passes.
        assert_eq!(
            s.feed("hello from the next model"),
            "hello from the next model"
        );
    }

    #[test]
    fn reset_from_pristine_state_is_a_noop() {
        // Calling reset on a fresh stripper shouldn't break subsequent feeds.
        let mut s = ThinkStripper::new();
        s.reset();
        assert_eq!(s.feed("plain text"), "plain text");
    }

    #[test]
    fn reset_clears_partial_carry_at_feed_boundary() {
        // A feed that ends mid-tag leaves bytes in `carry` awaiting the
        // rest of the tag. Reset should flush that too so the next turn
        // starts clean.
        let mut s = ThinkStripper::new();
        assert_eq!(s.feed("hello <thi"), "hello "); // carry now holds "<thi"
        assert!(s.buffered_bytes() > 0);
        s.reset();
        assert_eq!(s.buffered_bytes(), 0);
        // Without reset, the held "<thi" would be glued onto the next feed
        // ("<thinot a tag: <3"). With reset, the next turn's text passes
        // through untouched.
        assert_eq!(s.feed("not a tag: <3"), "not a tag: <3");
    }
}