Skip to main content

reasoning_parser/parsers/
base.rs

1// Base implementation of reasoning parser that handles common logic
2// for detecting and extracting reasoning blocks from text.
3
4use crate::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
5
/// Base reasoning parser implementation.
///
/// This parser handles the common logic for detecting reasoning blocks
/// delimited by start and end tokens (e.g., `<think>` and `</think>`).
///
/// The struct carries the state needed for incremental (streaming) parsing:
/// text is accumulated in `buffer` between calls and `in_reasoning` tracks
/// whether the parser currently treats incoming text as reasoning content.
#[derive(Debug, Clone)]
pub struct BaseReasoningParser {
    /// Parser settings: the start/end tokens, the maximum buffer size, and
    /// the `stream_reasoning` / `always_in_reasoning` flags.
    config: ParserConfig,
    /// `true` while incoming text is treated as reasoning content.
    /// Initialised (and reset) to `config.always_in_reasoning`.
    in_reasoning: bool,
    /// Streamed text that has been received but not yet returned to the caller.
    buffer: String,
    /// Set once the start token has been removed from the streamed text, or
    /// when the caller invokes `mark_think_start_stripped`; while set, the
    /// streaming parser no longer looks for the start token.
    stripped_think_start: bool,
    /// Identifier returned by `model_type()`; `"base"` unless replaced via
    /// `with_model_type`.
    model_type: String,
}
18
19impl BaseReasoningParser {
20    /// Create a new BaseReasoningParser with the given configuration.
21    pub fn new(config: ParserConfig) -> Self {
22        let in_reasoning = config.always_in_reasoning;
23        Self {
24            config,
25            in_reasoning,
26            buffer: String::new(),
27            stripped_think_start: false,
28            model_type: "base".to_string(),
29        }
30    }
31
32    /// Create with custom model type identifier.
33    pub fn with_model_type(mut self, model_type: String) -> Self {
34        self.model_type = model_type;
35        self
36    }
37
38    /// Check if the current buffer is a prefix of one of the tokens.
39    fn is_partial_token(&self, text: &str) -> bool {
40        (self.config.think_start_token.starts_with(text) && self.config.think_start_token != text)
41            || (self.config.think_end_token.starts_with(text)
42                && self.config.think_end_token != text)
43    }
44}
45
46impl ReasoningParser for BaseReasoningParser {
47    fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
48        // Check input size against buffer limit
49        if text.len() > self.config.max_buffer_size {
50            return Err(ParseError::BufferOverflow(text.len()));
51        }
52
53        let in_reasoning = self.in_reasoning || text.contains(&self.config.think_start_token);
54
55        if !in_reasoning {
56            return Ok(ParserResult::normal(text.to_string()));
57        }
58
59        // The text is considered to be in a reasoning block.
60        let processed_text = text
61            .replace(&self.config.think_start_token, "")
62            .trim()
63            .to_string();
64
65        if !processed_text.contains(&self.config.think_end_token) {
66            // Assume reasoning was truncated before end token
67            return Ok(ParserResult::reasoning(processed_text));
68        }
69
70        // Extract reasoning content
71        let splits: Vec<&str> = processed_text
72            .splitn(2, &self.config.think_end_token)
73            .collect();
74        let reasoning_text = (*splits.first().unwrap_or(&"")).to_string();
75        let normal_text = splits
76            .get(1)
77            .map(|s| s.trim().to_string())
78            .unwrap_or_default();
79
80        Ok(ParserResult::new(normal_text, reasoning_text))
81    }
82
83    fn parse_reasoning_streaming_incremental(
84        &mut self,
85        text: &str,
86    ) -> Result<ParserResult, ParseError> {
87        // Check if adding this text would exceed buffer limit
88        if self.buffer.len() + text.len() > self.config.max_buffer_size {
89            return Err(ParseError::BufferOverflow(self.buffer.len() + text.len()));
90        }
91
92        // Incrementally parse the streaming text
93        self.buffer.push_str(text);
94        let mut current_text = self.buffer.clone();
95
96        // If the current text is a prefix of a token, keep buffering
97        if self.is_partial_token(&current_text) {
98            return Ok(ParserResult::default());
99        }
100
101        // Strip start token if present
102        if !self.stripped_think_start && current_text.contains(&self.config.think_start_token) {
103            current_text = current_text.replace(&self.config.think_start_token, "");
104            self.buffer.clone_from(&current_text);
105            self.stripped_think_start = true;
106            self.in_reasoning = true;
107        }
108
109        // Handle end of reasoning block
110        let think_end_idx = if self.in_reasoning {
111            current_text
112                .find(&self.config.think_end_token)
113                .unwrap_or(current_text.len())
114        } else {
115            current_text.len()
116        };
117
118        if self.in_reasoning && think_end_idx < current_text.len() {
119            let reasoning_text = &current_text[..think_end_idx];
120            self.buffer.clear();
121            self.in_reasoning = false;
122            let start_idx = think_end_idx + self.config.think_end_token.len();
123            let normal_text = if start_idx < current_text.len() {
124                &current_text[start_idx..]
125            } else {
126                ""
127            };
128            return Ok(ParserResult::new(
129                normal_text.to_string(),
130                reasoning_text.trim().to_string(),
131            ));
132        }
133
134        // Continue with reasoning content
135        if self.in_reasoning && self.config.stream_reasoning {
136            // Stream the content immediately
137            let reasoning_text = current_text;
138            self.buffer.clear();
139            Ok(ParserResult::reasoning(reasoning_text))
140        } else if !self.in_reasoning {
141            // If we're not in a reasoning block, return as normal text
142            // CRITICAL FIX: Return current_text (with buffer) not just text
143            // This prevents buffer loss when partial tokens are followed by normal text
144            let normal_text = current_text;
145            self.buffer.clear();
146            Ok(ParserResult::normal(normal_text))
147        } else {
148            // If we are in a reasoning block but no end token is found, buffer it
149            Ok(ParserResult::default())
150        }
151    }
152
153    fn reset(&mut self) {
154        self.in_reasoning = self.config.always_in_reasoning;
155        self.buffer.clear();
156        self.stripped_think_start = false;
157    }
158
159    fn mark_reasoning_started(&mut self) {
160        self.in_reasoning = true;
161    }
162
163    fn mark_think_start_stripped(&mut self) {
164        self.stripped_think_start = true;
165    }
166
167    fn model_type(&self) -> &str {
168        &self.model_type
169    }
170
171    fn is_in_reasoning(&self) -> bool {
172        self.in_reasoning
173    }
174}
175
/// Unit tests for `BaseReasoningParser`, covering one-shot parsing, streaming
/// parsing, state reset and the buffer size limit.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::DEFAULT_MAX_BUFFER_SIZE;

    /// Builds a parser using `<think>` / `</think>` as tokens and the default
    /// buffer limit, with the two mode flags chosen by the caller.
    fn create_test_parser(
        always_in_reasoning: bool,
        stream_reasoning: bool,
    ) -> BaseReasoningParser {
        let config = ParserConfig {
            think_start_token: "<think>".to_string(),
            think_end_token: "</think>".to_string(),
            stream_reasoning,
            max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
            always_in_reasoning,
        };
        BaseReasoningParser::new(config)
    }

    #[test]
    fn test_detect_and_parse_reasoning() {
        let mut parser = create_test_parser(false, true);
        let result = parser
            .detect_and_parse_reasoning("<think>with reasoning</think> and more text.")
            .unwrap();
        assert_eq!(result.normal_text, "and more text.");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_detect_and_parse_no_reasoning() {
        let mut parser = create_test_parser(false, true);
        let result = parser
            .detect_and_parse_reasoning("This is a test without reasoning.")
            .unwrap();
        assert_eq!(result.normal_text, "This is a test without reasoning.");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_detect_and_parse_truncated_reasoning() {
        // No end token: everything after the start token counts as reasoning.
        let mut parser = create_test_parser(false, true);
        let result = parser
            .detect_and_parse_reasoning("<think>with truncated reasoning")
            .unwrap();
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "with truncated reasoning");
    }

    #[test]
    fn test_parse_streaming_partial_token() {
        // A strict prefix of the start token is buffered, nothing is emitted.
        let mut parser = create_test_parser(false, true);
        let result = parser
            .parse_reasoning_streaming_incremental("<thi")
            .unwrap();
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_parse_streaming_complete() {
        // Unlike the one-shot parser, streaming does not trim the normal text.
        let mut parser = create_test_parser(false, true);
        let result = parser
            .parse_reasoning_streaming_incremental("<think>with reasoning</think> and more text.")
            .unwrap();
        assert_eq!(result.normal_text, " and more text.");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_parse_streaming_no_end_token() {
        let mut parser = create_test_parser(true, true);
        let result = parser
            .parse_reasoning_streaming_incremental("<think>with reasoning")
            .unwrap();
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_always_in_reasoning_true() {
        // Parser starts with in_reasoning=true (like DeepSeek-R1)
        let mut parser = create_test_parser(true, true);
        let result = parser
            .detect_and_parse_reasoning("no think tags here")
            .unwrap();
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "no think tags here");
    }

    #[test]
    fn test_buffer_loss_bug_fix() {
        // Critical test for buffer preservation
        let mut parser = create_test_parser(false, true);

        // Step 1: Send partial end tag when not in reasoning mode
        let result1 = parser.parse_reasoning_streaming_incremental("</").unwrap();
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Step 2: Send normal text that doesn't complete the end tag
        // Must return "</answer" not just "answer"
        let result2 = parser
            .parse_reasoning_streaming_incremental("answer")
            .unwrap();
        assert_eq!(result2.normal_text, "</answer");
        assert_eq!(result2.reasoning_text, "");
    }

    #[test]
    fn test_streaming_with_stream_reasoning_enabled() {
        let mut parser = create_test_parser(false, true);

        // Start reasoning block
        let result1 = parser
            .parse_reasoning_streaming_incremental("<think>reasoning ")
            .unwrap();
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "reasoning ");

        // Continue streaming reasoning
        let result2 = parser
            .parse_reasoning_streaming_incremental("content ")
            .unwrap();
        assert_eq!(result2.normal_text, "");
        assert_eq!(result2.reasoning_text, "content ");

        // End reasoning block
        let result3 = parser
            .parse_reasoning_streaming_incremental("more</think> normal")
            .unwrap();
        assert_eq!(result3.normal_text, " normal");
        assert_eq!(result3.reasoning_text, "more");
    }

    #[test]
    fn test_reset_state() {
        // stream_reasoning is disabled so the unfinished reasoning block stays
        // buffered: every field that reset() touches is dirty before the call,
        // which keeps the post-reset assertions from passing vacuously.
        let mut parser = create_test_parser(false, false);

        let result = parser
            .parse_reasoning_streaming_incremental("<think>unfinished reasoning")
            .unwrap();
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "");

        // Confirm the dirty state first
        assert!(parser.in_reasoning);
        assert!(!parser.buffer.is_empty());
        assert!(parser.stripped_think_start);

        // Reset and verify state
        parser.reset();
        assert!(!parser.in_reasoning);
        assert!(parser.buffer.is_empty());
        assert!(!parser.stripped_think_start);
    }

    #[test]
    fn test_buffer_overflow_detect_and_parse() {
        let config = ParserConfig {
            max_buffer_size: 10, // Set a very small buffer
            ..Default::default()
        };
        let mut parser = BaseReasoningParser::new(config);

        let large_text = "a".repeat(20);
        let result = parser.detect_and_parse_reasoning(&large_text);

        assert!(result.is_err());
        match result {
            Err(ParseError::BufferOverflow(size)) => {
                assert_eq!(size, 20);
            }
            _ => panic!("Expected BufferOverflow error"),
        }
    }

    #[test]
    fn test_buffer_overflow_streaming() {
        let config = ParserConfig {
            max_buffer_size: 10, // Set a very small buffer
            ..Default::default()
        };
        let mut parser = BaseReasoningParser::new(config);

        // Send a partial token that will be buffered
        let result1 = parser.parse_reasoning_streaming_incremental("<thi");
        assert!(result1.is_ok());
        assert_eq!(result1.unwrap().normal_text, "");

        // Second chunk would exceed buffer
        // Buffer has "<thi" (4 chars) + "this_is_too_large" (17 chars) = 21 total
        let result2 = parser.parse_reasoning_streaming_incremental("this_is_too_large");
        assert!(result2.is_err());
        match result2 {
            Err(ParseError::BufferOverflow(size)) => {
                assert_eq!(size, 21); // 4 + 17
            }
            _ => panic!("Expected BufferOverflow error"),
        }
    }
}