reasoning_parser/parsers/
base.rs

// Base implementation of the reasoning parser: the common logic for
// detecting and extracting reasoning blocks from text.
3
4use crate::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
5
6/// Base reasoning parser implementation.
7///
8/// This parser handles the common logic for detecting reasoning blocks
9/// delimited by start and end tokens (e.g., <think> and </think>).
10#[derive(Debug, Clone)]
11pub struct BaseReasoningParser {
12    config: ParserConfig,
13    in_reasoning: bool,
14    buffer: String,
15    stripped_think_start: bool,
16    model_type: String,
17}
18
19impl BaseReasoningParser {
20    /// Create a new BaseReasoningParser with the given configuration.
21    pub fn new(config: ParserConfig) -> Self {
22        let in_reasoning = config.initial_in_reasoning;
23        Self {
24            config,
25            in_reasoning,
26            buffer: String::new(),
27            stripped_think_start: false,
28            model_type: "base".to_string(),
29        }
30    }
31
32    /// Create with custom model type identifier.
33    pub fn with_model_type(mut self, model_type: String) -> Self {
34        self.model_type = model_type;
35        self
36    }
37
38    /// Check if the current buffer is a prefix of one of the tokens.
39    fn is_partial_token(&self, text: &str) -> bool {
40        (self.config.think_start_token.starts_with(text) && self.config.think_start_token != text)
41            || (self.config.think_end_token.starts_with(text)
42                && self.config.think_end_token != text)
43    }
44}
45
46impl ReasoningParser for BaseReasoningParser {
47    fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
48        // Check input size against buffer limit
49        if text.len() > self.config.max_buffer_size {
50            return Err(ParseError::BufferOverflow(text.len()));
51        }
52
53        let in_reasoning = self.in_reasoning || text.contains(&self.config.think_start_token);
54
55        if !in_reasoning {
56            return Ok(ParserResult::normal(text.to_string()));
57        }
58
59        // The text is considered to be in a reasoning block.
60        let processed_text = text
61            .replace(&self.config.think_start_token, "")
62            .trim()
63            .to_string();
64
65        if !processed_text.contains(&self.config.think_end_token) {
66            // Assume reasoning was truncated before end token
67            return Ok(ParserResult::reasoning(processed_text));
68        }
69
70        // Extract reasoning content
71        let splits: Vec<&str> = processed_text
72            .splitn(2, &self.config.think_end_token)
73            .collect();
74        let reasoning_text = splits.first().unwrap_or(&"").to_string();
75        let normal_text = splits
76            .get(1)
77            .map(|s| s.trim().to_string())
78            .unwrap_or_default();
79
80        Ok(ParserResult::new(normal_text, reasoning_text))
81    }
82
83    fn parse_reasoning_streaming_incremental(
84        &mut self,
85        text: &str,
86    ) -> Result<ParserResult, ParseError> {
87        // Check if adding this text would exceed buffer limit
88        if self.buffer.len() + text.len() > self.config.max_buffer_size {
89            return Err(ParseError::BufferOverflow(self.buffer.len() + text.len()));
90        }
91
92        // Incrementally parse the streaming text
93        self.buffer.push_str(text);
94        let mut current_text = self.buffer.clone();
95
96        // If the current text is a prefix of a token, keep buffering
97        if self.is_partial_token(&current_text) {
98            return Ok(ParserResult::default());
99        }
100
101        // Strip start token if present
102        if !self.stripped_think_start && current_text.contains(&self.config.think_start_token) {
103            current_text = current_text.replace(&self.config.think_start_token, "");
104            self.buffer = current_text.clone();
105            self.stripped_think_start = true;
106            self.in_reasoning = true;
107        }
108
109        // Handle end of reasoning block
110        let think_end_idx = if self.in_reasoning {
111            current_text
112                .find(&self.config.think_end_token)
113                .unwrap_or(current_text.len())
114        } else {
115            current_text.len()
116        };
117
118        if self.in_reasoning && think_end_idx < current_text.len() {
119            let reasoning_text = &current_text[..think_end_idx];
120            self.buffer.clear();
121            self.in_reasoning = false;
122            let start_idx = think_end_idx + self.config.think_end_token.len();
123            let normal_text = if start_idx < current_text.len() {
124                &current_text[start_idx..]
125            } else {
126                ""
127            };
128            return Ok(ParserResult::new(
129                normal_text.to_string(),
130                reasoning_text.trim().to_string(),
131            ));
132        }
133
134        // Continue with reasoning content
135        if self.in_reasoning && self.config.stream_reasoning {
136            // Stream the content immediately
137            let reasoning_text = current_text;
138            self.buffer.clear();
139            Ok(ParserResult::reasoning(reasoning_text))
140        } else if !self.in_reasoning {
141            // If we're not in a reasoning block, return as normal text
142            // CRITICAL FIX: Return current_text (with buffer) not just text
143            // This prevents buffer loss when partial tokens are followed by normal text
144            let normal_text = current_text;
145            self.buffer.clear();
146            Ok(ParserResult::normal(normal_text))
147        } else {
148            // If we are in a reasoning block but no end token is found, buffer it
149            Ok(ParserResult::default())
150        }
151    }
152
153    fn reset(&mut self) {
154        self.in_reasoning = self.config.initial_in_reasoning;
155        self.buffer.clear();
156        self.stripped_think_start = false;
157    }
158
159    fn model_type(&self) -> &str {
160        &self.model_type
161    }
162
163    fn is_in_reasoning(&self) -> bool {
164        self.in_reasoning
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a parser that uses `<think>` / `</think>` and a 65536-byte limit.
    fn make_parser(initial_in_reasoning: bool, stream_reasoning: bool) -> BaseReasoningParser {
        BaseReasoningParser::new(ParserConfig {
            think_start_token: "<think>".to_string(),
            think_end_token: "</think>".to_string(),
            stream_reasoning,
            max_buffer_size: 65536,
            initial_in_reasoning,
        })
    }

    /// Build a parser from the default configuration with a 10-byte limit.
    fn tiny_buffer_parser() -> BaseReasoningParser {
        BaseReasoningParser::new(ParserConfig {
            max_buffer_size: 10, // Deliberately very small
            ..Default::default()
        })
    }

    /// Check both halves of a parse result in one call.
    #[track_caller]
    fn assert_split(result: &ParserResult, normal: &str, reasoning: &str) {
        assert_eq!(result.normal_text, normal);
        assert_eq!(result.reasoning_text, reasoning);
    }

    #[test]
    fn test_detect_and_parse_reasoning() {
        let mut parser = make_parser(false, true);
        let input = "<think>with reasoning</think> and more text.";
        let parsed = parser.detect_and_parse_reasoning(input).unwrap();
        assert_split(&parsed, "and more text.", "with reasoning");
    }

    #[test]
    fn test_detect_and_parse_no_reasoning() {
        let mut parser = make_parser(false, true);
        let input = "This is a test without reasoning.";
        let parsed = parser.detect_and_parse_reasoning(input).unwrap();
        assert_split(&parsed, "This is a test without reasoning.", "");
    }

    #[test]
    fn test_detect_and_parse_truncated_reasoning() {
        let mut parser = make_parser(false, true);
        let input = "<think>with truncated reasoning";
        let parsed = parser.detect_and_parse_reasoning(input).unwrap();
        assert_split(&parsed, "", "with truncated reasoning");
    }

    #[test]
    fn test_parse_streaming_partial_token() {
        let mut parser = make_parser(false, true);
        let parsed = parser.parse_reasoning_streaming_incremental("<thi").unwrap();
        assert_split(&parsed, "", "");
    }

    #[test]
    fn test_parse_streaming_complete() {
        let mut parser = make_parser(false, true);
        let input = "<think>with reasoning</think> and more text.";
        let parsed = parser.parse_reasoning_streaming_incremental(input).unwrap();
        assert_split(&parsed, " and more text.", "with reasoning");
    }

    #[test]
    fn test_parse_streaming_no_end_token() {
        let mut parser = make_parser(true, true);
        let input = "<think>with reasoning";
        let parsed = parser.parse_reasoning_streaming_incremental(input).unwrap();
        assert_split(&parsed, "", "with reasoning");
    }

    #[test]
    fn test_initial_in_reasoning_true() {
        // The parser begins with in_reasoning=true (like DeepSeek-R1)
        let mut parser = make_parser(true, true);
        let parsed = parser
            .detect_and_parse_reasoning("no think tags here")
            .unwrap();
        assert_split(&parsed, "", "no think tags here");
    }

    #[test]
    fn test_buffer_loss_bug_fix() {
        // Buffered text must survive when a partial token turns out not to be one.
        let mut parser = make_parser(false, true);

        // First chunk: a partial end tag while outside reasoning mode is held.
        let held = parser.parse_reasoning_streaming_incremental("</").unwrap();
        assert_split(&held, "", "");

        // Second chunk: text that does not complete the end tag.
        // The result must be "</answer", not just "answer".
        let flushed = parser
            .parse_reasoning_streaming_incremental("answer")
            .unwrap();
        assert_split(&flushed, "</answer", "");
    }

    #[test]
    fn test_streaming_with_stream_reasoning_enabled() {
        let mut parser = make_parser(false, true);

        // Opening of the reasoning block
        let opened = parser
            .parse_reasoning_streaming_incremental("<think>reasoning ")
            .unwrap();
        assert_split(&opened, "", "reasoning ");

        // More reasoning is streamed through
        let continued = parser
            .parse_reasoning_streaming_incremental("content ")
            .unwrap();
        assert_split(&continued, "", "content ");

        // Closing of the reasoning block
        let closed = parser
            .parse_reasoning_streaming_incremental("more</think> normal")
            .unwrap();
        assert_split(&closed, " normal", "more");
    }

    #[test]
    fn test_reset_state() {
        let mut parser = make_parser(false, true);

        // Run some text through the streaming parser
        parser
            .parse_reasoning_streaming_incremental("<think>reasoning</think> normal")
            .unwrap();

        // After a reset the state must match a fresh parser
        parser.reset();
        assert!(!parser.in_reasoning);
        assert!(parser.buffer.is_empty());
        assert!(!parser.stripped_think_start);
    }

    #[test]
    fn test_buffer_overflow_detect_and_parse() {
        let mut parser = tiny_buffer_parser();

        let oversized = "a".repeat(20);
        let outcome = parser.detect_and_parse_reasoning(&oversized);

        assert!(outcome.is_err());
        match outcome {
            Err(ParseError::BufferOverflow(size)) => assert_eq!(size, 20),
            _ => panic!("Expected BufferOverflow error"),
        }
    }

    #[test]
    fn test_buffer_overflow_streaming() {
        let mut parser = tiny_buffer_parser();

        // A partial token is accepted and buffered
        let first = parser.parse_reasoning_streaming_incremental("<thi");
        assert!(first.is_ok());
        assert_eq!(first.unwrap().normal_text, "");

        // The next chunk pushes the buffer past the limit:
        // "<thi" (4 chars) + "this_is_too_large" (17 chars) = 21 total
        let second = parser.parse_reasoning_streaming_incremental("this_is_too_large");
        assert!(second.is_err());
        match second {
            Err(ParseError::BufferOverflow(size)) => assert_eq!(size, 21), // 4 + 17
            _ => panic!("Expected BufferOverflow error"),
        }
    }
}