Skip to main content

reasoning_parser/parsers/
base.rs

1// Base implementation of reasoning parser that handles common logic
2// for detecting and extracting reasoning blocks from text.
3
4use crate::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
5
6/// Base reasoning parser implementation.
7///
8/// This parser handles the common logic for detecting reasoning blocks
9/// delimited by start and end tokens (e.g., <think> and </think>).
10#[derive(Debug, Clone)]
11pub struct BaseReasoningParser {
12    config: ParserConfig,
13    in_reasoning: bool,
14    buffer: String,
15    stripped_think_start: bool,
16    model_type: String,
17}
18
19impl BaseReasoningParser {
20    /// Create a new BaseReasoningParser with the given configuration.
21    pub fn new(config: ParserConfig) -> Self {
22        let in_reasoning = config.initial_in_reasoning;
23        Self {
24            config,
25            in_reasoning,
26            buffer: String::new(),
27            stripped_think_start: false,
28            model_type: "base".to_string(),
29        }
30    }
31
32    /// Create with custom model type identifier.
33    pub fn with_model_type(mut self, model_type: String) -> Self {
34        self.model_type = model_type;
35        self
36    }
37
38    /// Check if the current buffer is a prefix of one of the tokens.
39    fn is_partial_token(&self, text: &str) -> bool {
40        (self.config.think_start_token.starts_with(text) && self.config.think_start_token != text)
41            || (self.config.think_end_token.starts_with(text)
42                && self.config.think_end_token != text)
43    }
44}
45
46impl ReasoningParser for BaseReasoningParser {
47    fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
48        // Check input size against buffer limit
49        if text.len() > self.config.max_buffer_size {
50            return Err(ParseError::BufferOverflow(text.len()));
51        }
52
53        let in_reasoning = self.in_reasoning || text.contains(&self.config.think_start_token);
54
55        if !in_reasoning {
56            return Ok(ParserResult::normal(text.to_string()));
57        }
58
59        // The text is considered to be in a reasoning block.
60        let processed_text = text
61            .replace(&self.config.think_start_token, "")
62            .trim()
63            .to_string();
64
65        if !processed_text.contains(&self.config.think_end_token) {
66            // Assume reasoning was truncated before end token
67            return Ok(ParserResult::reasoning(processed_text));
68        }
69
70        // Extract reasoning content
71        let splits: Vec<&str> = processed_text
72            .splitn(2, &self.config.think_end_token)
73            .collect();
74        let reasoning_text = (*splits.first().unwrap_or(&"")).to_string();
75        let normal_text = splits
76            .get(1)
77            .map(|s| s.trim().to_string())
78            .unwrap_or_default();
79
80        Ok(ParserResult::new(normal_text, reasoning_text))
81    }
82
83    fn parse_reasoning_streaming_incremental(
84        &mut self,
85        text: &str,
86    ) -> Result<ParserResult, ParseError> {
87        // Check if adding this text would exceed buffer limit
88        if self.buffer.len() + text.len() > self.config.max_buffer_size {
89            return Err(ParseError::BufferOverflow(self.buffer.len() + text.len()));
90        }
91
92        // Incrementally parse the streaming text
93        self.buffer.push_str(text);
94        let mut current_text = self.buffer.clone();
95
96        // If the current text is a prefix of a token, keep buffering
97        if self.is_partial_token(&current_text) {
98            return Ok(ParserResult::default());
99        }
100
101        // Strip start token if present
102        if !self.stripped_think_start && current_text.contains(&self.config.think_start_token) {
103            current_text = current_text.replace(&self.config.think_start_token, "");
104            self.buffer.clone_from(&current_text);
105            self.stripped_think_start = true;
106            self.in_reasoning = true;
107        }
108
109        // Handle end of reasoning block
110        let think_end_idx = if self.in_reasoning {
111            current_text
112                .find(&self.config.think_end_token)
113                .unwrap_or(current_text.len())
114        } else {
115            current_text.len()
116        };
117
118        if self.in_reasoning && think_end_idx < current_text.len() {
119            let reasoning_text = &current_text[..think_end_idx];
120            self.buffer.clear();
121            self.in_reasoning = false;
122            let start_idx = think_end_idx + self.config.think_end_token.len();
123            let normal_text = if start_idx < current_text.len() {
124                &current_text[start_idx..]
125            } else {
126                ""
127            };
128            return Ok(ParserResult::new(
129                normal_text.to_string(),
130                reasoning_text.trim().to_string(),
131            ));
132        }
133
134        // Continue with reasoning content
135        if self.in_reasoning && self.config.stream_reasoning {
136            // Stream the content immediately
137            let reasoning_text = current_text;
138            self.buffer.clear();
139            Ok(ParserResult::reasoning(reasoning_text))
140        } else if !self.in_reasoning {
141            // If we're not in a reasoning block, return as normal text
142            // CRITICAL FIX: Return current_text (with buffer) not just text
143            // This prevents buffer loss when partial tokens are followed by normal text
144            let normal_text = current_text;
145            self.buffer.clear();
146            Ok(ParserResult::normal(normal_text))
147        } else {
148            // If we are in a reasoning block but no end token is found, buffer it
149            Ok(ParserResult::default())
150        }
151    }
152
153    fn reset(&mut self) {
154        self.in_reasoning = self.config.initial_in_reasoning;
155        self.buffer.clear();
156        self.stripped_think_start = false;
157    }
158
159    fn model_type(&self) -> &str {
160        &self.model_type
161    }
162
163    fn is_in_reasoning(&self) -> bool {
164        self.in_reasoning
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::DEFAULT_MAX_BUFFER_SIZE;

    /// Build a parser that uses the `<think>` / `</think>` delimiters and the
    /// default buffer limit.
    fn create_test_parser(
        initial_in_reasoning: bool,
        stream_reasoning: bool,
    ) -> BaseReasoningParser {
        BaseReasoningParser::new(ParserConfig {
            think_start_token: "<think>".to_string(),
            think_end_token: "</think>".to_string(),
            stream_reasoning,
            max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
            initial_in_reasoning,
        })
    }

    /// Build a parser with default settings except for a 10-byte buffer limit.
    fn create_small_buffer_parser() -> BaseReasoningParser {
        BaseReasoningParser::new(ParserConfig {
            max_buffer_size: 10, // Set a very small buffer
            ..Default::default()
        })
    }

    /// Check both halves of a parse result at once.
    #[track_caller]
    fn assert_parsed(result: &ParserResult, normal: &str, reasoning: &str) {
        assert_eq!(result.normal_text, normal);
        assert_eq!(result.reasoning_text, reasoning);
    }

    /// Check that `result` is a `BufferOverflow` error reporting `expected`.
    #[track_caller]
    fn assert_overflow(result: Result<ParserResult, ParseError>, expected: usize) {
        assert!(result.is_err());
        match result {
            Err(ParseError::BufferOverflow(size)) => {
                assert_eq!(size, expected);
            }
            _ => panic!("Expected BufferOverflow error"),
        }
    }

    #[test]
    fn test_detect_and_parse_reasoning() {
        let mut parser = create_test_parser(false, true);
        let parsed = parser
            .detect_and_parse_reasoning("<think>with reasoning</think> and more text.")
            .unwrap();
        assert_parsed(&parsed, "and more text.", "with reasoning");
    }

    #[test]
    fn test_detect_and_parse_no_reasoning() {
        let mut parser = create_test_parser(false, true);
        let parsed = parser
            .detect_and_parse_reasoning("This is a test without reasoning.")
            .unwrap();
        assert_parsed(&parsed, "This is a test without reasoning.", "");
    }

    #[test]
    fn test_detect_and_parse_truncated_reasoning() {
        let mut parser = create_test_parser(false, true);
        let parsed = parser
            .detect_and_parse_reasoning("<think>with truncated reasoning")
            .unwrap();
        assert_parsed(&parsed, "", "with truncated reasoning");
    }

    #[test]
    fn test_parse_streaming_partial_token() {
        let mut parser = create_test_parser(false, true);
        let parsed = parser
            .parse_reasoning_streaming_incremental("<thi")
            .unwrap();
        assert_parsed(&parsed, "", "");
    }

    #[test]
    fn test_parse_streaming_complete() {
        let mut parser = create_test_parser(false, true);
        let parsed = parser
            .parse_reasoning_streaming_incremental("<think>with reasoning</think> and more text.")
            .unwrap();
        assert_parsed(&parsed, " and more text.", "with reasoning");
    }

    #[test]
    fn test_parse_streaming_no_end_token() {
        let mut parser = create_test_parser(true, true);
        let parsed = parser
            .parse_reasoning_streaming_incremental("<think>with reasoning")
            .unwrap();
        assert_parsed(&parsed, "", "with reasoning");
    }

    #[test]
    fn test_initial_in_reasoning_true() {
        // A parser that starts in reasoning mode (like DeepSeek-R1) treats
        // untagged text as reasoning.
        let mut parser = create_test_parser(true, true);
        let parsed = parser
            .detect_and_parse_reasoning("no think tags here")
            .unwrap();
        assert_parsed(&parsed, "", "no think tags here");
    }

    #[test]
    fn test_buffer_loss_bug_fix() {
        // Regression test: buffered text must survive a false-alarm partial token.
        let mut parser = create_test_parser(false, true);

        // A partial end tag outside of reasoning mode is buffered, not emitted.
        let first = parser.parse_reasoning_streaming_incremental("</").unwrap();
        assert_parsed(&first, "", "");

        // The follow-up does not complete the end tag, so the buffered "</"
        // must come back together with it: "</answer", not just "answer".
        let second = parser
            .parse_reasoning_streaming_incremental("answer")
            .unwrap();
        assert_parsed(&second, "</answer", "");
    }

    #[test]
    fn test_streaming_with_stream_reasoning_enabled() {
        let mut parser = create_test_parser(false, true);

        // Opening the reasoning block streams its first content right away.
        let opened = parser
            .parse_reasoning_streaming_incremental("<think>reasoning ")
            .unwrap();
        assert_parsed(&opened, "", "reasoning ");

        // Further reasoning is streamed chunk by chunk.
        let continued = parser
            .parse_reasoning_streaming_incremental("content ")
            .unwrap();
        assert_parsed(&continued, "", "content ");

        // The end token closes the block and starts normal text.
        let closed = parser
            .parse_reasoning_streaming_incremental("more</think> normal")
            .unwrap();
        assert_parsed(&closed, " normal", "more");
    }

    #[test]
    fn test_reset_state() {
        let mut parser = create_test_parser(false, true);

        // Drive the parser through a full reasoning block first.
        parser
            .parse_reasoning_streaming_incremental("<think>reasoning</think> normal")
            .unwrap();

        // After a reset all streaming state is back to its initial values.
        parser.reset();
        assert!(!parser.in_reasoning);
        assert!(parser.buffer.is_empty());
        assert!(!parser.stripped_think_start);
    }

    #[test]
    fn test_buffer_overflow_detect_and_parse() {
        let mut parser = create_small_buffer_parser();

        let large_text = "a".repeat(20);
        let outcome = parser.detect_and_parse_reasoning(&large_text);

        assert_overflow(outcome, 20);
    }

    #[test]
    fn test_buffer_overflow_streaming() {
        let mut parser = create_small_buffer_parser();

        // A partial token fits and is buffered.
        let buffered = parser.parse_reasoning_streaming_incremental("<thi");
        assert!(buffered.is_ok());
        assert_eq!(buffered.unwrap().normal_text, "");

        // The second chunk pushes the total past the limit:
        // "<thi" (4 bytes) + "this_is_too_large" (17 bytes) = 21.
        let overflowing = parser.parse_reasoning_streaming_incremental("this_is_too_large");
        assert_overflow(overflowing, 21); // 4 + 17
    }
}