1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use openai_protocol::common::Tool;
5use regex::Regex;
6use serde_json::Value;
7
8use crate::{
9 errors::{ParserError, ParserResult},
10 parsers::helpers,
11 traits::ToolParser,
12 types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
13};
14
15pub struct Step3Parser {
25 tool_call_extractor: Regex,
27 invoke_extractor: Regex,
29 param_extractor: Regex,
31
32 buffer: String,
34
35 bot_token: &'static str,
37 eot_token: &'static str,
38 tool_call_begin: &'static str,
39 tool_call_end: &'static str,
40 tool_sep: &'static str,
41
42 in_tool_block: bool,
44 tool_block_finished: bool,
45 current_function_name: String,
46 current_parameters: serde_json::Map<String, Value>,
47 in_tool_call: bool,
48 function_name_sent: bool,
49
50 prev_tool_call_arr: Vec<Value>,
52 current_tool_id: i32,
53 streamed_args_for_tool: Vec<String>,
54}
55
56impl Step3Parser {
57 pub fn new() -> Self {
59 let tool_call_pattern = r"(?s)<|tool_call_begin|>.*?<|tool_call_end|>";
61 let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
62
63 let invoke_pattern = r#"(?s)<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>"#;
65 let invoke_extractor = Regex::new(invoke_pattern).expect("Valid regex pattern");
66
67 let param_pattern = r#"(?s)<steptml:parameter name="([^"]+)">(.+?)</steptml:parameter>"#;
69 let param_extractor = Regex::new(param_pattern).expect("Valid regex pattern");
70
71 Self {
72 tool_call_extractor,
73 invoke_extractor,
74 param_extractor,
75
76 buffer: String::new(),
77
78 bot_token: "<|tool_calls_begin|>",
79 eot_token: "<|tool_calls_end|>",
80 tool_call_begin: "<|tool_call_begin|>",
81 tool_call_end: "<|tool_call_end|>",
82 tool_sep: "<|tool_sep|>",
83
84 in_tool_block: false,
86 tool_block_finished: false,
87 current_function_name: String::new(),
88 current_parameters: serde_json::Map::new(),
89 in_tool_call: false,
90 function_name_sent: false,
91
92 prev_tool_call_arr: Vec::new(),
94 current_tool_id: -1,
95 streamed_args_for_tool: Vec::new(),
96 }
97 }
98
99 fn reset_streaming_state(&mut self) {
101 self.in_tool_call = false;
102 self.function_name_sent = false;
103 self.current_function_name.clear();
104 self.current_parameters.clear();
105 }
106
107 fn parse_partial_tool_call(
109 &mut self,
110 tool_indices: &HashMap<String, usize>,
111 ) -> ParserResult<StreamingParseResult> {
112 let mut calls = Vec::new();
113
114 if !self.buffer.contains(self.tool_sep) {
116 return Ok(StreamingParseResult {
117 normal_text: String::new(),
118 calls,
119 });
120 }
121
122 let buffer_clone = self.buffer.clone();
124 let parts: Vec<&str> = buffer_clone.splitn(2, self.tool_sep).collect();
125 if parts.len() != 2 {
126 return Ok(StreamingParseResult {
127 normal_text: String::new(),
128 calls,
129 });
130 }
131
132 let type_part = parts[0].trim();
133 let invoke_part = parts[1];
134
135 if type_part != "function" {
137 self.reset_streaming_state();
139 return Ok(StreamingParseResult {
140 normal_text: String::new(),
141 calls,
142 });
143 }
144
145 if !self.function_name_sent {
147 if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
148 let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
149
150 if tool_indices.contains_key(func_name) {
152 self.current_function_name = func_name.to_string();
153 self.function_name_sent = true;
154
155 if self.current_tool_id == -1 {
157 self.current_tool_id = 0;
158 }
159
160 helpers::ensure_capacity(
162 self.current_tool_id,
163 &mut self.prev_tool_call_arr,
164 &mut self.streamed_args_for_tool,
165 );
166
167 let tool_id = self.current_tool_id as usize;
169 self.prev_tool_call_arr[tool_id] = serde_json::json!({
170 "name": func_name,
171 "arguments": {},
172 });
173
174 calls.push(ToolCallItem {
176 tool_index: self.current_tool_id as usize,
177 name: Some(func_name.to_string()),
178 parameters: String::new(),
179 });
180 } else {
181 tracing::debug!("Invalid function name: {}", func_name);
183 self.reset_streaming_state();
184 return Ok(StreamingParseResult {
185 normal_text: String::new(),
186 calls,
187 });
188 }
189 } else {
190 return Ok(StreamingParseResult {
192 normal_text: String::new(),
193 calls,
194 });
195 }
196 }
197
198 if self.function_name_sent {
200 let mut new_params = serde_json::Map::new();
202 for capture in self.param_extractor.captures_iter(invoke_part) {
203 let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
204 let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
205
206 let param_value =
208 if let Ok(json_val) = serde_json::from_str::<Value>(param_value_str) {
209 json_val
210 } else {
211 if param_value_str == "true" || param_value_str == "True" {
213 Value::Bool(true)
214 } else if param_value_str == "false" || param_value_str == "False" {
215 Value::Bool(false)
216 } else if param_value_str == "null" || param_value_str == "None" {
217 Value::Null
218 } else if let Ok(num) = param_value_str.parse::<i64>() {
219 Value::Number(num.into())
220 } else if let Ok(num) = param_value_str.parse::<f64>() {
221 if let Some(n) = serde_json::Number::from_f64(num) {
222 Value::Number(n)
223 } else {
224 Value::String(param_value_str.to_string())
225 }
226 } else {
227 Value::String(param_value_str.to_string())
228 }
229 };
230
231 new_params.insert(param_name.to_string(), param_value);
232 }
233
234 if new_params != self.current_parameters {
236 let diff = if self.current_parameters.is_empty() {
238 let params_content =
240 serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
241 if params_content.len() > 2 {
242 params_content[..params_content.len() - 1].to_string()
244 } else {
245 "{".to_string()
246 }
247 } else {
248 let old_json = serde_json::to_string(&self.current_parameters)
250 .unwrap_or_else(|_| "{}".to_string());
251 let new_json =
252 serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
253
254 let old_without_brace = &old_json[..old_json.len() - 1];
256 let new_without_brace = &new_json[..new_json.len() - 1];
257
258 new_without_brace
260 .strip_prefix(old_without_brace)
261 .map(|s| s.to_string())
262 .unwrap_or_default()
263 };
264
265 if !diff.is_empty() {
266 calls.push(ToolCallItem {
267 tool_index: self.current_tool_id as usize,
268 name: None,
269 parameters: diff.clone(),
270 });
271 let tool_id = self.current_tool_id as usize;
272 if tool_id < self.streamed_args_for_tool.len() {
273 self.streamed_args_for_tool[tool_id].push_str(&diff);
274 }
275 }
276
277 self.current_parameters = new_params.clone();
279 let tool_id = self.current_tool_id as usize;
280 if tool_id < self.prev_tool_call_arr.len() {
281 if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
282 obj.insert("arguments".to_string(), Value::Object(new_params));
283 }
284 }
285 }
286
287 if self.buffer.contains(self.tool_call_end) {
289 let tool_id = self.current_tool_id as usize;
291 if tool_id < self.streamed_args_for_tool.len()
292 && !self.streamed_args_for_tool[tool_id].is_empty()
293 {
294 calls.push(ToolCallItem {
295 tool_index: self.current_tool_id as usize,
296 name: None,
297 parameters: "}".to_string(),
298 });
299 self.streamed_args_for_tool[tool_id].push('}');
300 }
301
302 if let Some(end_idx) = self.buffer.find(self.tool_call_end) {
304 self.buffer = self.buffer[end_idx + self.tool_call_end.len()..].to_string();
306 }
307
308 self.reset_streaming_state();
310 self.current_tool_id += 1;
311 }
312 }
313
314 Ok(StreamingParseResult {
315 normal_text: String::new(),
316 calls,
317 })
318 }
319
320 fn parse_steptml_parameters(
322 &self,
323 params_text: &str,
324 ) -> ParserResult<serde_json::Map<String, Value>> {
325 let mut parameters = serde_json::Map::new();
326
327 for capture in self.param_extractor.captures_iter(params_text) {
328 let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
329 let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
330
331 let param_value = if let Ok(json_val) = serde_json::from_str::<Value>(param_value_str) {
333 json_val
334 } else {
335 if param_value_str == "true" || param_value_str == "True" {
337 Value::Bool(true)
338 } else if param_value_str == "false" || param_value_str == "False" {
339 Value::Bool(false)
340 } else if param_value_str == "null" || param_value_str == "None" {
341 Value::Null
342 } else if let Ok(num) = param_value_str.parse::<i64>() {
343 Value::Number(num.into())
344 } else if let Ok(num) = param_value_str.parse::<f64>() {
345 if let Some(n) = serde_json::Number::from_f64(num) {
346 Value::Number(n)
347 } else {
348 Value::String(param_value_str.to_string())
349 }
350 } else {
351 Value::String(param_value_str.to_string())
352 }
353 };
354
355 parameters.insert(param_name.to_string(), param_value);
356 }
357
358 Ok(parameters)
359 }
360
361 fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
363 if !block.contains("function") || !block.contains("<|tool_sep|>") {
365 return Ok(None);
366 }
367
368 let parts: Vec<&str> = block.split("<|tool_sep|>").collect();
370 if parts.len() != 2 {
371 return Ok(None);
372 }
373
374 if !parts[0].contains("function") {
376 return Ok(None);
377 }
378
379 let invoke_part = parts[1];
380
381 if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
383 let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
384
385 if func_name.is_empty() {
387 return Ok(None);
388 }
389
390 let params_text = captures.get(2).map_or("", |m| m.as_str());
391
392 let parameters = self.parse_steptml_parameters(params_text)?;
394
395 let arguments_str = serde_json::to_string(¶meters)
396 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
397
398 Ok(Some(ToolCall {
399 function: FunctionCall {
400 name: func_name.to_string(),
401 arguments: arguments_str,
402 },
403 }))
404 } else {
405 Ok(None)
406 }
407 }
408}
409
410impl Default for Step3Parser {
411 fn default() -> Self {
412 Self::new()
413 }
414}
415
416#[async_trait]
417impl ToolParser for Step3Parser {
418 async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
419 if !self.has_tool_markers(text) {
420 return Ok((text.to_string(), vec![]));
421 }
422
423 let idx = text.find("<|tool_calls_begin|>").unwrap();
425 let normal_text = text[..idx].to_string();
426
427 let mut tools = Vec::new();
429 for mat in self.tool_call_extractor.find_iter(text) {
430 match self.parse_tool_call(mat.as_str()) {
431 Ok(Some(tool)) => tools.push(tool),
432 Ok(None) => continue,
433 Err(e) => {
434 tracing::debug!("Failed to parse tool call: {}", e);
435 continue;
436 }
437 }
438 }
439
440 if tools.is_empty() {
442 return Ok((text.to_string(), vec![]));
443 }
444
445 Ok((normal_text, tools))
446 }
447
448 async fn parse_incremental(
449 &mut self,
450 chunk: &str,
451 tools: &[Tool],
452 ) -> ParserResult<StreamingParseResult> {
453 self.buffer.push_str(chunk);
454
455 let tool_indices = helpers::get_tool_indices(tools);
457
458 if self.tool_block_finished {
460 let normal_text = std::mem::take(&mut self.buffer);
461 return Ok(StreamingParseResult {
462 normal_text,
463 calls: vec![],
464 });
465 }
466
467 if !self.in_tool_block {
469 if self.buffer.contains(self.bot_token) {
470 let idx = self.buffer.find(self.bot_token).unwrap();
471 let normal_text = self.buffer[..idx].to_string();
472 self.buffer = self.buffer[idx + self.bot_token.len()..].to_string();
473 self.in_tool_block = true;
474 return Ok(StreamingParseResult {
475 normal_text,
476 calls: vec![],
477 });
478 } else {
479 if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_some() {
481 return Ok(StreamingParseResult::default()); } else {
483 let normal_text = std::mem::take(&mut self.buffer);
484 return Ok(StreamingParseResult {
485 normal_text,
486 calls: vec![],
487 });
488 }
489 }
490 }
491
492 let mut calls = Vec::new();
494
495 if self.buffer.contains(self.eot_token) {
497 let idx = self.buffer.find(self.eot_token).unwrap();
498
499 if self.in_tool_call {
501 let before_eot = &self.buffer[..idx];
503 if before_eot.contains(self.tool_call_end) {
504 let result = self.parse_partial_tool_call(&tool_indices)?;
506 calls.extend(result.calls);
507 } else {
508 tracing::warn!("Tool block ended with incomplete tool call");
510 }
511 }
512
513 let remaining = self.buffer[idx + self.eot_token.len()..].to_string();
514 self.buffer.clear();
515 self.tool_block_finished = true;
516
517 self.reset_streaming_state();
519
520 return Ok(StreamingParseResult {
521 normal_text: remaining,
522 calls,
523 });
524 }
525
526 if !self.in_tool_call {
528 if self.buffer.contains(self.tool_call_begin) {
529 let idx = self.buffer.find(self.tool_call_begin).unwrap();
530 self.buffer = self.buffer[idx + self.tool_call_begin.len()..].to_string();
532 self.in_tool_call = true;
533 self.function_name_sent = false;
534 self.current_function_name.clear();
535 self.current_parameters.clear();
536 } else {
538 return Ok(StreamingParseResult::default());
540 }
541 }
542
543 if self.in_tool_call {
545 return self.parse_partial_tool_call(&tool_indices);
546 }
547
548 Ok(StreamingParseResult::default())
549 }
550
551 fn has_tool_markers(&self, text: &str) -> bool {
552 text.contains(self.bot_token)
553 }
554
555 fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
556 helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
557 }
558
559 fn reset(&mut self) {
560 self.buffer.clear();
562 self.prev_tool_call_arr.clear();
563 self.current_tool_id = -1;
564 self.streamed_args_for_tool.clear();
565
566 self.in_tool_block = false;
568 self.tool_block_finished = false;
569 self.current_function_name.clear();
570 self.current_parameters.clear();
571 self.in_tool_call = false;
572 self.function_name_sent = false;
573 }
574}