tool_parser/parsers/
step3.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use openai_protocol::common::Tool;
5use regex::Regex;
6use serde_json::Value;
7
8use crate::{
9    errors::{ParserError, ParserResult},
10    parsers::helpers,
11    traits::ToolParser,
12    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
13};
14
15/// Step3 format parser for tool calls
16///
17/// Handles the Step3 specific format with steptml XML:
18/// `<|tool_calls_begin|><|tool_call_begin|>function<|tool_sep|><steptml:invoke name="{name}"><steptml:parameter name="{k}">{v}</steptml:parameter></steptml:invoke><|tool_call_end|><|tool_calls_end|>`
19///
20/// Features:
21/// - Unicode token delimiters
22/// - StepTML XML format for invocations
23/// - Support for multiple sequential tool calls
24pub struct Step3Parser {
25    /// Regex for extracting tool call blocks
26    tool_call_extractor: Regex,
27    /// Regex for extracting steptml invocations
28    invoke_extractor: Regex,
29    /// Regex for extracting parameters
30    param_extractor: Regex,
31
32    /// Buffer for accumulating chunks
33    buffer: String,
34
35    /// Token configuration
36    bot_token: &'static str,
37    eot_token: &'static str,
38    tool_call_begin: &'static str,
39    tool_call_end: &'static str,
40    tool_sep: &'static str,
41
42    /// Streaming state variables (mirrors Python's Step3Detector)
43    in_tool_block: bool,
44    tool_block_finished: bool,
45    current_function_name: String,
46    current_parameters: serde_json::Map<String, Value>,
47    in_tool_call: bool,
48    function_name_sent: bool,
49
50    /// Standard state machine fields
51    prev_tool_call_arr: Vec<Value>,
52    current_tool_id: i32,
53    streamed_args_for_tool: Vec<String>,
54}
55
56impl Step3Parser {
57    /// Create a new Step3 parser
58    pub fn new() -> Self {
59        // Pattern for individual tool calls
60        let tool_call_pattern = r"(?s)<|tool_call_begin|>.*?<|tool_call_end|>";
61        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
62
63        // Pattern for steptml invocations
64        let invoke_pattern = r#"(?s)<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>"#;
65        let invoke_extractor = Regex::new(invoke_pattern).expect("Valid regex pattern");
66
67        // Pattern for steptml parameters - using non-greedy match for values to handle < characters
68        let param_pattern = r#"(?s)<steptml:parameter name="([^"]+)">(.+?)</steptml:parameter>"#;
69        let param_extractor = Regex::new(param_pattern).expect("Valid regex pattern");
70
71        Self {
72            tool_call_extractor,
73            invoke_extractor,
74            param_extractor,
75
76            buffer: String::new(),
77
78            bot_token: "<|tool_calls_begin|>",
79            eot_token: "<|tool_calls_end|>",
80            tool_call_begin: "<|tool_call_begin|>",
81            tool_call_end: "<|tool_call_end|>",
82            tool_sep: "<|tool_sep|>",
83
84            // Streaming state variables
85            in_tool_block: false,
86            tool_block_finished: false,
87            current_function_name: String::new(),
88            current_parameters: serde_json::Map::new(),
89            in_tool_call: false,
90            function_name_sent: false,
91
92            // Standard state machine fields
93            prev_tool_call_arr: Vec::new(),
94            current_tool_id: -1,
95            streamed_args_for_tool: Vec::new(),
96        }
97    }
98
99    /// Reset streaming state for the next tool call
100    fn reset_streaming_state(&mut self) {
101        self.in_tool_call = false;
102        self.function_name_sent = false;
103        self.current_function_name.clear();
104        self.current_parameters.clear();
105    }
106
107    /// Parse partial tool call for streaming scenarios (mirrors Python's _parse_partial_tool_call)
108    fn parse_partial_tool_call(
109        &mut self,
110        tool_indices: &HashMap<String, usize>,
111    ) -> ParserResult<StreamingParseResult> {
112        let mut calls = Vec::new();
113
114        // Check if we have tool_sep (means we're past the type declaration)
115        if !self.buffer.contains(self.tool_sep) {
116            return Ok(StreamingParseResult {
117                normal_text: String::new(),
118                calls,
119            });
120        }
121
122        // Clone the buffer to avoid borrow conflicts
123        let buffer_clone = self.buffer.clone();
124        let parts: Vec<&str> = buffer_clone.splitn(2, self.tool_sep).collect();
125        if parts.len() != 2 {
126            return Ok(StreamingParseResult {
127                normal_text: String::new(),
128                calls,
129            });
130        }
131
132        let type_part = parts[0].trim();
133        let invoke_part = parts[1];
134
135        // Check if it's a function type
136        if type_part != "function" {
137            // Invalid tool type, skip this tool call
138            self.reset_streaming_state();
139            return Ok(StreamingParseResult {
140                normal_text: String::new(),
141                calls,
142            });
143        }
144
145        // Try to extract function name if not sent yet
146        if !self.function_name_sent {
147            if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
148                let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
149
150                // Validate function name
151                if tool_indices.contains_key(func_name) {
152                    self.current_function_name = func_name.to_string();
153                    self.function_name_sent = true;
154
155                    // Initialize tool tracking
156                    if self.current_tool_id == -1 {
157                        self.current_tool_id = 0;
158                    }
159
160                    // Ensure tracking arrays are large enough
161                    helpers::ensure_capacity(
162                        self.current_tool_id,
163                        &mut self.prev_tool_call_arr,
164                        &mut self.streamed_args_for_tool,
165                    );
166
167                    // Store tool call info
168                    let tool_id = self.current_tool_id as usize;
169                    self.prev_tool_call_arr[tool_id] = serde_json::json!({
170                        "name": func_name,
171                        "arguments": {},
172                    });
173
174                    // Send tool name with empty parameters
175                    calls.push(ToolCallItem {
176                        tool_index: self.current_tool_id as usize,
177                        name: Some(func_name.to_string()),
178                        parameters: String::new(),
179                    });
180                } else {
181                    // Invalid function name
182                    tracing::debug!("Invalid function name: {}", func_name);
183                    self.reset_streaming_state();
184                    return Ok(StreamingParseResult {
185                        normal_text: String::new(),
186                        calls,
187                    });
188                }
189            } else {
190                // Function name not complete yet
191                return Ok(StreamingParseResult {
192                    normal_text: String::new(),
193                    calls,
194                });
195            }
196        }
197
198        // Parse parameters incrementally
199        if self.function_name_sent {
200            // Extract all complete parameters
201            let mut new_params = serde_json::Map::new();
202            for capture in self.param_extractor.captures_iter(invoke_part) {
203                let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
204                let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
205
206                // Try to parse the value as JSON first, fallback to string
207                let param_value =
208                    if let Ok(json_val) = serde_json::from_str::<Value>(param_value_str) {
209                        json_val
210                    } else {
211                        // Try parsing as Python literal
212                        if param_value_str == "true" || param_value_str == "True" {
213                            Value::Bool(true)
214                        } else if param_value_str == "false" || param_value_str == "False" {
215                            Value::Bool(false)
216                        } else if param_value_str == "null" || param_value_str == "None" {
217                            Value::Null
218                        } else if let Ok(num) = param_value_str.parse::<i64>() {
219                            Value::Number(num.into())
220                        } else if let Ok(num) = param_value_str.parse::<f64>() {
221                            if let Some(n) = serde_json::Number::from_f64(num) {
222                                Value::Number(n)
223                            } else {
224                                Value::String(param_value_str.to_string())
225                            }
226                        } else {
227                            Value::String(param_value_str.to_string())
228                        }
229                    };
230
231                new_params.insert(param_name.to_string(), param_value);
232            }
233
234            // Check if we have new parameters to stream
235            if new_params != self.current_parameters {
236                // Build the JSON content without the closing brace for streaming
237                let diff = if self.current_parameters.is_empty() {
238                    // First parameters - send opening brace and content
239                    let params_content =
240                        serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
241                    if params_content.len() > 2 {
242                        // Send everything except the closing brace
243                        params_content[..params_content.len() - 1].to_string()
244                    } else {
245                        "{".to_string()
246                    }
247                } else {
248                    // Subsequent parameters - calculate the incremental diff
249                    let old_json = serde_json::to_string(&self.current_parameters)
250                        .unwrap_or_else(|_| "{}".to_string());
251                    let new_json =
252                        serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
253
254                    // Remove closing braces for comparison
255                    let old_without_brace = &old_json[..old_json.len() - 1];
256                    let new_without_brace = &new_json[..new_json.len() - 1];
257
258                    // The new content should extend the old content
259                    new_without_brace
260                        .strip_prefix(old_without_brace)
261                        .map(|s| s.to_string())
262                        .unwrap_or_default()
263                };
264
265                if !diff.is_empty() {
266                    calls.push(ToolCallItem {
267                        tool_index: self.current_tool_id as usize,
268                        name: None,
269                        parameters: diff.clone(),
270                    });
271                    let tool_id = self.current_tool_id as usize;
272                    if tool_id < self.streamed_args_for_tool.len() {
273                        self.streamed_args_for_tool[tool_id].push_str(&diff);
274                    }
275                }
276
277                // Update current state
278                self.current_parameters = new_params.clone();
279                let tool_id = self.current_tool_id as usize;
280                if tool_id < self.prev_tool_call_arr.len() {
281                    if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
282                        obj.insert("arguments".to_string(), Value::Object(new_params));
283                    }
284                }
285            }
286
287            // Check if tool call is complete
288            if self.buffer.contains(self.tool_call_end) {
289                // Send closing brace if we've sent any parameters
290                let tool_id = self.current_tool_id as usize;
291                if tool_id < self.streamed_args_for_tool.len()
292                    && !self.streamed_args_for_tool[tool_id].is_empty()
293                {
294                    calls.push(ToolCallItem {
295                        tool_index: self.current_tool_id as usize,
296                        name: None,
297                        parameters: "}".to_string(),
298                    });
299                    self.streamed_args_for_tool[tool_id].push('}');
300                }
301
302                // Find the end position
303                if let Some(end_idx) = self.buffer.find(self.tool_call_end) {
304                    // Remove the processed tool call from buffer
305                    self.buffer = self.buffer[end_idx + self.tool_call_end.len()..].to_string();
306                }
307
308                // Reset state for next tool call
309                self.reset_streaming_state();
310                self.current_tool_id += 1;
311            }
312        }
313
314        Ok(StreamingParseResult {
315            normal_text: String::new(),
316            calls,
317        })
318    }
319
320    /// Parse parameters from steptml format
321    fn parse_steptml_parameters(
322        &self,
323        params_text: &str,
324    ) -> ParserResult<serde_json::Map<String, Value>> {
325        let mut parameters = serde_json::Map::new();
326
327        for capture in self.param_extractor.captures_iter(params_text) {
328            let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
329            let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
330
331            // Try to parse the value as JSON first, fallback to string
332            let param_value = if let Ok(json_val) = serde_json::from_str::<Value>(param_value_str) {
333                json_val
334            } else {
335                // Try parsing as Python literal
336                if param_value_str == "true" || param_value_str == "True" {
337                    Value::Bool(true)
338                } else if param_value_str == "false" || param_value_str == "False" {
339                    Value::Bool(false)
340                } else if param_value_str == "null" || param_value_str == "None" {
341                    Value::Null
342                } else if let Ok(num) = param_value_str.parse::<i64>() {
343                    Value::Number(num.into())
344                } else if let Ok(num) = param_value_str.parse::<f64>() {
345                    if let Some(n) = serde_json::Number::from_f64(num) {
346                        Value::Number(n)
347                    } else {
348                        Value::String(param_value_str.to_string())
349                    }
350                } else {
351                    Value::String(param_value_str.to_string())
352                }
353            };
354
355            parameters.insert(param_name.to_string(), param_value);
356        }
357
358        Ok(parameters)
359    }
360
361    /// Parse a single tool call block
362    fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
363        // Check if it contains function marker and tool separator
364        if !block.contains("function") || !block.contains("<|tool_sep|>") {
365            return Ok(None);
366        }
367
368        // Split by tool separator
369        let parts: Vec<&str> = block.split("<|tool_sep|>").collect();
370        if parts.len() != 2 {
371            return Ok(None);
372        }
373
374        // Check if it's a function type
375        if !parts[0].contains("function") {
376            return Ok(None);
377        }
378
379        let invoke_part = parts[1];
380
381        // Extract steptml invoke
382        if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
383            let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
384
385            // Validate function name is not empty
386            if func_name.is_empty() {
387                return Ok(None);
388            }
389
390            let params_text = captures.get(2).map_or("", |m| m.as_str());
391
392            // Parse parameters
393            let parameters = self.parse_steptml_parameters(params_text)?;
394
395            let arguments_str = serde_json::to_string(&parameters)
396                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
397
398            Ok(Some(ToolCall {
399                function: FunctionCall {
400                    name: func_name.to_string(),
401                    arguments: arguments_str,
402                },
403            }))
404        } else {
405            Ok(None)
406        }
407    }
408}
409
410impl Default for Step3Parser {
411    fn default() -> Self {
412        Self::new()
413    }
414}
415
416#[async_trait]
417impl ToolParser for Step3Parser {
418    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
419        if !self.has_tool_markers(text) {
420            return Ok((text.to_string(), vec![]));
421        }
422
423        // Find where tool calls begin
424        let idx = text.find("<|tool_calls_begin|>").unwrap();
425        let normal_text = text[..idx].to_string();
426
427        // Extract tool calls
428        let mut tools = Vec::new();
429        for mat in self.tool_call_extractor.find_iter(text) {
430            match self.parse_tool_call(mat.as_str()) {
431                Ok(Some(tool)) => tools.push(tool),
432                Ok(None) => continue,
433                Err(e) => {
434                    tracing::debug!("Failed to parse tool call: {}", e);
435                    continue;
436                }
437            }
438        }
439
440        // If no tools were successfully parsed despite having markers, return entire text as fallback
441        if tools.is_empty() {
442            return Ok((text.to_string(), vec![]));
443        }
444
445        Ok((normal_text, tools))
446    }
447
448    async fn parse_incremental(
449        &mut self,
450        chunk: &str,
451        tools: &[Tool],
452    ) -> ParserResult<StreamingParseResult> {
453        self.buffer.push_str(chunk);
454
455        // Build tool indices for validation
456        let tool_indices = helpers::get_tool_indices(tools);
457
458        // Stage 1: If we've finished the tool block, everything is normal text
459        if self.tool_block_finished {
460            let normal_text = std::mem::take(&mut self.buffer);
461            return Ok(StreamingParseResult {
462                normal_text,
463                calls: vec![],
464            });
465        }
466
467        // Stage 2: Check if tool block hasn't started yet
468        if !self.in_tool_block {
469            if self.buffer.contains(self.bot_token) {
470                let idx = self.buffer.find(self.bot_token).unwrap();
471                let normal_text = self.buffer[..idx].to_string();
472                self.buffer = self.buffer[idx + self.bot_token.len()..].to_string();
473                self.in_tool_block = true;
474                return Ok(StreamingParseResult {
475                    normal_text,
476                    calls: vec![],
477                });
478            } else {
479                // Check if we might have a partial bot_token
480                if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_some() {
481                    return Ok(StreamingParseResult::default()); // Wait for more text
482                } else {
483                    let normal_text = std::mem::take(&mut self.buffer);
484                    return Ok(StreamingParseResult {
485                        normal_text,
486                        calls: vec![],
487                    });
488                }
489            }
490        }
491
492        // We're inside the tool block
493        let mut calls = Vec::new();
494
495        // Stage 3: Check if tool block is ending
496        if self.buffer.contains(self.eot_token) {
497            let idx = self.buffer.find(self.eot_token).unwrap();
498
499            // If we're in the middle of a tool call, we need to handle it
500            if self.in_tool_call {
501                // The buffer before eot_token might contain the end of the current tool call
502                let before_eot = &self.buffer[..idx];
503                if before_eot.contains(self.tool_call_end) {
504                    // Parse this final tool call
505                    let result = self.parse_partial_tool_call(&tool_indices)?;
506                    calls.extend(result.calls);
507                } else {
508                    // Incomplete tool call - log warning
509                    tracing::warn!("Tool block ended with incomplete tool call");
510                }
511            }
512
513            let remaining = self.buffer[idx + self.eot_token.len()..].to_string();
514            self.buffer.clear();
515            self.tool_block_finished = true;
516
517            // Reset any partial tool call state
518            self.reset_streaming_state();
519
520            return Ok(StreamingParseResult {
521                normal_text: remaining,
522                calls,
523            });
524        }
525
526        // Stage 4: Check if we're in a tool call or need to start one
527        if !self.in_tool_call {
528            if self.buffer.contains(self.tool_call_begin) {
529                let idx = self.buffer.find(self.tool_call_begin).unwrap();
530                // Remove any content before tool call begin (shouldn't happen but be safe)
531                self.buffer = self.buffer[idx + self.tool_call_begin.len()..].to_string();
532                self.in_tool_call = true;
533                self.function_name_sent = false;
534                self.current_function_name.clear();
535                self.current_parameters.clear();
536                // Fall through to parse the partial tool call
537            } else {
538                // Wait for tool call to begin
539                return Ok(StreamingParseResult::default());
540            }
541        }
542
543        // Stage 5: Parse partial tool call
544        if self.in_tool_call {
545            return self.parse_partial_tool_call(&tool_indices);
546        }
547
548        Ok(StreamingParseResult::default())
549    }
550
551    fn has_tool_markers(&self, text: &str) -> bool {
552        text.contains(self.bot_token)
553    }
554
555    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
556        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
557    }
558
559    fn reset(&mut self) {
560        // Reset standard state
561        self.buffer.clear();
562        self.prev_tool_call_arr.clear();
563        self.current_tool_id = -1;
564        self.streamed_args_for_tool.clear();
565
566        // Reset Step3-specific fields
567        self.in_tool_block = false;
568        self.tool_block_finished = false;
569        self.current_function_name.clear();
570        self.current_parameters.clear();
571        self.in_tool_call = false;
572        self.function_name_sent = false;
573    }
574}