reasoning_parser/parsers/
base.rs1use crate::traits::{ParseError, ParserConfig, ParserResult, ReasoningParser};
5
/// Reasoning parser that separates "thinking" content from normal output using the
/// start/end delimiter tokens supplied in its [`ParserConfig`].
///
/// It keeps state between calls so it can be fed a stream chunk by chunk.
#[derive(Debug, Clone)]
pub struct BaseReasoningParser {
    /// Delimiter tokens, buffer limit, streaming flag and initial reasoning state.
    config: ParserConfig,
    /// Whether the parser currently considers itself inside a reasoning block.
    /// Seeded from `config.initial_in_reasoning`.
    in_reasoning: bool,
    /// Streamed text received so far that has not yet been emitted.
    buffer: String,
    /// Set once the streaming path has removed the start token from the input.
    stripped_think_start: bool,
    /// Label returned by `model_type()`; `"base"` unless overridden.
    model_type: String,
}
18
19impl BaseReasoningParser {
20 pub fn new(config: ParserConfig) -> Self {
22 let in_reasoning = config.initial_in_reasoning;
23 Self {
24 config,
25 in_reasoning,
26 buffer: String::new(),
27 stripped_think_start: false,
28 model_type: "base".to_string(),
29 }
30 }
31
32 pub fn with_model_type(mut self, model_type: String) -> Self {
34 self.model_type = model_type;
35 self
36 }
37
38 fn is_partial_token(&self, text: &str) -> bool {
40 (self.config.think_start_token.starts_with(text) && self.config.think_start_token != text)
41 || (self.config.think_end_token.starts_with(text)
42 && self.config.think_end_token != text)
43 }
44}
45
46impl ReasoningParser for BaseReasoningParser {
47 fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError> {
48 if text.len() > self.config.max_buffer_size {
50 return Err(ParseError::BufferOverflow(text.len()));
51 }
52
53 let in_reasoning = self.in_reasoning || text.contains(&self.config.think_start_token);
54
55 if !in_reasoning {
56 return Ok(ParserResult::normal(text.to_string()));
57 }
58
59 let processed_text = text
61 .replace(&self.config.think_start_token, "")
62 .trim()
63 .to_string();
64
65 if !processed_text.contains(&self.config.think_end_token) {
66 return Ok(ParserResult::reasoning(processed_text));
68 }
69
70 let splits: Vec<&str> = processed_text
72 .splitn(2, &self.config.think_end_token)
73 .collect();
74 let reasoning_text = splits.first().unwrap_or(&"").to_string();
75 let normal_text = splits
76 .get(1)
77 .map(|s| s.trim().to_string())
78 .unwrap_or_default();
79
80 Ok(ParserResult::new(normal_text, reasoning_text))
81 }
82
83 fn parse_reasoning_streaming_incremental(
84 &mut self,
85 text: &str,
86 ) -> Result<ParserResult, ParseError> {
87 if self.buffer.len() + text.len() > self.config.max_buffer_size {
89 return Err(ParseError::BufferOverflow(self.buffer.len() + text.len()));
90 }
91
92 self.buffer.push_str(text);
94 let mut current_text = self.buffer.clone();
95
96 if self.is_partial_token(¤t_text) {
98 return Ok(ParserResult::default());
99 }
100
101 if !self.stripped_think_start && current_text.contains(&self.config.think_start_token) {
103 current_text = current_text.replace(&self.config.think_start_token, "");
104 self.buffer = current_text.clone();
105 self.stripped_think_start = true;
106 self.in_reasoning = true;
107 }
108
109 let think_end_idx = if self.in_reasoning {
111 current_text
112 .find(&self.config.think_end_token)
113 .unwrap_or(current_text.len())
114 } else {
115 current_text.len()
116 };
117
118 if self.in_reasoning && think_end_idx < current_text.len() {
119 let reasoning_text = ¤t_text[..think_end_idx];
120 self.buffer.clear();
121 self.in_reasoning = false;
122 let start_idx = think_end_idx + self.config.think_end_token.len();
123 let normal_text = if start_idx < current_text.len() {
124 ¤t_text[start_idx..]
125 } else {
126 ""
127 };
128 return Ok(ParserResult::new(
129 normal_text.to_string(),
130 reasoning_text.trim().to_string(),
131 ));
132 }
133
134 if self.in_reasoning && self.config.stream_reasoning {
136 let reasoning_text = current_text;
138 self.buffer.clear();
139 Ok(ParserResult::reasoning(reasoning_text))
140 } else if !self.in_reasoning {
141 let normal_text = current_text;
145 self.buffer.clear();
146 Ok(ParserResult::normal(normal_text))
147 } else {
148 Ok(ParserResult::default())
150 }
151 }
152
153 fn reset(&mut self) {
154 self.in_reasoning = self.config.initial_in_reasoning;
155 self.buffer.clear();
156 self.stripped_think_start = false;
157 }
158
159 fn model_type(&self) -> &str {
160 &self.model_type
161 }
162
163 fn is_in_reasoning(&self) -> bool {
164 self.in_reasoning
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parser that uses the `<think>` / `</think>` delimiters and a 64 KiB limit.
    fn create_test_parser(
        initial_in_reasoning: bool,
        stream_reasoning: bool,
    ) -> BaseReasoningParser {
        BaseReasoningParser::new(ParserConfig {
            think_start_token: String::from("<think>"),
            think_end_token: String::from("</think>"),
            stream_reasoning,
            max_buffer_size: 65536,
            initial_in_reasoning,
        })
    }

    /// Asserts both halves of a parse result, normal text first.
    #[track_caller]
    fn assert_parsed(result: &ParserResult, normal_text: &str, reasoning_text: &str) {
        assert_eq!(result.normal_text, normal_text);
        assert_eq!(result.reasoning_text, reasoning_text);
    }

    #[test]
    fn test_detect_and_parse_reasoning() {
        let parsed = create_test_parser(false, true)
            .detect_and_parse_reasoning("<think>with reasoning</think> and more text.")
            .unwrap();
        assert_parsed(&parsed, "and more text.", "with reasoning");
    }

    #[test]
    fn test_detect_and_parse_no_reasoning() {
        let parsed = create_test_parser(false, true)
            .detect_and_parse_reasoning("This is a test without reasoning.")
            .unwrap();
        assert_parsed(&parsed, "This is a test without reasoning.", "");
    }

    #[test]
    fn test_detect_and_parse_truncated_reasoning() {
        let parsed = create_test_parser(false, true)
            .detect_and_parse_reasoning("<think>with truncated reasoning")
            .unwrap();
        assert_parsed(&parsed, "", "with truncated reasoning");
    }

    #[test]
    fn test_parse_streaming_partial_token() {
        let parsed = create_test_parser(false, true)
            .parse_reasoning_streaming_incremental("<thi")
            .unwrap();
        assert_parsed(&parsed, "", "");
    }

    #[test]
    fn test_parse_streaming_complete() {
        let parsed = create_test_parser(false, true)
            .parse_reasoning_streaming_incremental("<think>with reasoning</think> and more text.")
            .unwrap();
        // Streaming keeps the leading space of the normal text, unlike the one-shot path.
        assert_parsed(&parsed, " and more text.", "with reasoning");
    }

    #[test]
    fn test_parse_streaming_no_end_token() {
        let parsed = create_test_parser(true, true)
            .parse_reasoning_streaming_incremental("<think>with reasoning")
            .unwrap();
        assert_parsed(&parsed, "", "with reasoning");
    }

    #[test]
    fn test_initial_in_reasoning_true() {
        // Without any tags the text counts as reasoning because the parser starts in
        // reasoning mode.
        let parsed = create_test_parser(true, true)
            .detect_and_parse_reasoning("no think tags here")
            .unwrap();
        assert_parsed(&parsed, "", "no think tags here");
    }

    #[test]
    fn test_buffer_loss_bug_fix() {
        let mut parser = create_test_parser(false, true);

        // "</" could be the start of the end token, so it is held back.
        let held = parser.parse_reasoning_streaming_incremental("</").unwrap();
        assert_parsed(&held, "", "");

        // Once it is clear no token follows, the held-back text is emitted as well.
        let flushed = parser
            .parse_reasoning_streaming_incremental("answer")
            .unwrap();
        assert_parsed(&flushed, "</answer", "");
    }

    #[test]
    fn test_streaming_with_stream_reasoning_enabled() {
        let mut parser = create_test_parser(false, true);

        let first = parser
            .parse_reasoning_streaming_incremental("<think>reasoning ")
            .unwrap();
        assert_parsed(&first, "", "reasoning ");

        let second = parser
            .parse_reasoning_streaming_incremental("content ")
            .unwrap();
        assert_parsed(&second, "", "content ");

        let third = parser
            .parse_reasoning_streaming_incremental("more</think> normal")
            .unwrap();
        assert_parsed(&third, " normal", "more");
    }

    #[test]
    fn test_reset_state() {
        let mut parser = create_test_parser(false, true);

        parser
            .parse_reasoning_streaming_incremental("<think>reasoning</think> normal")
            .unwrap();
        parser.reset();

        assert!(!parser.in_reasoning);
        assert!(parser.buffer.is_empty());
        assert!(!parser.stripped_think_start);
    }

    #[test]
    fn test_buffer_overflow_detect_and_parse() {
        let mut parser = BaseReasoningParser::new(ParserConfig {
            max_buffer_size: 10,
            ..Default::default()
        });

        let oversized = "a".repeat(20);
        let outcome = parser.detect_and_parse_reasoning(&oversized);

        assert!(outcome.is_err());
        match outcome {
            Err(ParseError::BufferOverflow(size)) => assert_eq!(size, 20),
            _ => panic!("Expected BufferOverflow error"),
        }
    }

    #[test]
    fn test_buffer_overflow_streaming() {
        let mut parser = BaseReasoningParser::new(ParserConfig {
            max_buffer_size: 10,
            ..Default::default()
        });

        // A partial token stays in the buffer and produces no output.
        let buffered = parser.parse_reasoning_streaming_incremental("<thi");
        assert!(buffered.is_ok());
        assert_eq!(buffered.unwrap().normal_text, "");

        let overflow = parser.parse_reasoning_streaming_incremental("this_is_too_large");
        assert!(overflow.is_err());
        match overflow {
            // 4 bytes already buffered + 17 incoming bytes.
            Err(ParseError::BufferOverflow(size)) => assert_eq!(size, 21),
            _ => panic!("Expected BufferOverflow error"),
        }
    }
}