Skip to main content

tool_parser/parsers/
deepseek.rs

1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7    errors::{ParserError, ParserResult},
8    parsers::helpers,
9    traits::ToolParser,
10    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
13/// DeepSeek V3 format parser for tool calls
14///
15/// Handles the DeepSeek V3 specific format that uses Unicode tokens:
16/// `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>{name}\n```json\n{args}\n```<|tool▁call▁end|><|tool▁calls▁end|>`
17///
18/// Features:
19/// - Unicode token delimiters
20/// - JSON arguments in code blocks
21/// - Support for multiple sequential tool calls
22///
23/// Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?chat_template=default
24pub struct DeepSeekParser {
25    /// Regex for extracting complete tool calls
26    tool_call_extractor: Regex,
27    /// Regex for extracting function details
28    func_detail_extractor: Regex,
29    /// Regex for matching partial tool calls during streaming
30    partial_tool_call_regex: Regex,
31    /// Regex pattern for removing completed tool calls from buffer
32    tool_call_end_pattern: Regex,
33
34    /// Buffer for accumulating incomplete patterns across chunks
35    buffer: String,
36
37    /// Stores complete tool call info (name and arguments) for each tool being parsed
38    prev_tool_call_arr: Vec<Value>,
39
40    /// Index of currently streaming tool call (-1 means no active tool)
41    current_tool_id: i32,
42
43    /// Flag for whether current tool's name has been sent to client
44    current_tool_name_sent: bool,
45
46    /// Tracks raw JSON string content streamed to client for each tool's arguments
47    streamed_args_for_tool: Vec<String>,
48}
49
50impl DeepSeekParser {
51    /// Create a new DeepSeek parser
52    #[expect(
53        clippy::expect_used,
54        reason = "regex patterns are compile-time string literals"
55    )]
56    pub fn new() -> Self {
57        // Use (?s) flag for DOTALL mode to handle newlines
58        let tool_call_pattern = r"(?s)<|tool▁call▁begin|>.*?<|tool▁call▁end|>";
59        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
60
61        let func_detail_pattern = r"(?s)<|tool▁call▁begin|>(.*?)<|tool▁sep|>(.*?)\n```json\n(.*?)\n```<|tool▁call▁end|>";
62        let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
63
64        // Partial pattern for streaming - uses .* (greedy) not .*? to match all partial content
65        let partial_pattern = r"(?s)<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)";
66        let partial_tool_call_regex = Regex::new(partial_pattern).expect("Valid regex pattern");
67
68        // Pattern for removing completed tool calls
69        let end_pattern = r"(?s)<|tool▁call▁begin|>.*?<|tool▁call▁end|>";
70        let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
71
72        Self {
73            tool_call_extractor,
74            func_detail_extractor,
75            partial_tool_call_regex,
76            tool_call_end_pattern,
77            buffer: String::new(),
78            prev_tool_call_arr: Vec::new(),
79            current_tool_id: -1,
80            current_tool_name_sent: false,
81            streamed_args_for_tool: Vec::new(),
82        }
83    }
84
85    /// Parse a single tool call block - throws error if parsing fails
86    fn parse_tool_call(&self, block: &str) -> ParserResult<ToolCall> {
87        let captures = self.func_detail_extractor.captures(block).ok_or_else(|| {
88            ParserError::ParsingFailed("Failed to match tool call pattern".to_string())
89        })?;
90
91        // Get function type (should be "function")
92        let func_type = captures.get(1).map_or("", |m| m.as_str());
93        if func_type != "function" {
94            return Err(ParserError::ParsingFailed(format!(
95                "Invalid function type: {func_type}"
96            )));
97        }
98
99        // Get function name
100        let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
101        if func_name.is_empty() {
102            return Err(ParserError::ParsingFailed(
103                "Empty function name".to_string(),
104            ));
105        }
106
107        // Get JSON arguments
108        let json_args = captures.get(3).map_or("{}", |m| m.as_str()).trim();
109
110        // Parse JSON arguments
111        let value = serde_json::from_str::<Value>(json_args)
112            .map_err(|e| ParserError::ParsingFailed(format!("Invalid JSON: {e}")))?;
113
114        // Create arguments object
115        let args = if value.is_object() {
116            value
117        } else {
118            // If not an object, wrap it
119            serde_json::json!({ "value": value })
120        };
121
122        let arguments =
123            serde_json::to_string(&args).map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
124
125        Ok(ToolCall {
126            function: FunctionCall {
127                name: func_name.to_string(),
128                arguments,
129            },
130        })
131    }
132}
133
134impl Default for DeepSeekParser {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
140#[async_trait]
141impl ToolParser for DeepSeekParser {
142    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
143        if !self.has_tool_markers(text) {
144            return Ok((text.to_string(), vec![]));
145        }
146
147        // Find where tool calls begin
148        // Safe: has_tool_markers() already confirmed the marker exists
149        let idx = text
150            .find("<|tool▁calls▁begin|>")
151            .ok_or_else(|| ParserError::ParsingFailed("tool call marker not found".to_string()))?;
152        let normal_text = text[..idx].to_string();
153
154        // Try to extract tool calls, log warnings for failures
155        let mut tools = Vec::new();
156        for mat in self.tool_call_extractor.find_iter(text) {
157            match self.parse_tool_call(mat.as_str()) {
158                Ok(tool) => tools.push(tool),
159                Err(e) => {
160                    tracing::debug!("Failed to parse tool call: {}", e);
161                    continue;
162                }
163            }
164        }
165
166        // If no tools were successfully parsed despite having markers, return entire text as fallback
167        if tools.is_empty() {
168            return Ok((text.to_string(), vec![]));
169        }
170
171        Ok((normal_text, tools))
172    }
173
174    async fn parse_incremental(
175        &mut self,
176        chunk: &str,
177        tools: &[Tool],
178    ) -> ParserResult<StreamingParseResult> {
179        self.buffer.push_str(chunk);
180        let current_text = &self.buffer.clone();
181
182        // Check if we have a tool call (either the start token or individual tool call)
183        let has_tool_call =
184            self.has_tool_markers(current_text) || current_text.contains("<|tool▁call▁begin|>");
185
186        if !has_tool_call {
187            // No tool markers detected - return all buffered content as normal text
188            // Strip out end tokens if present
189            let mut normal_text = std::mem::take(&mut self.buffer);
190            for e_token in ["<|tool▁calls▁end|>", "```", "<|tool▁call▁end|>"] {
191                normal_text = normal_text.replace(e_token, "");
192            }
193            return Ok(StreamingParseResult {
194                normal_text,
195                calls: vec![],
196            });
197        }
198
199        // Build tool indices for validation
200        let tool_indices = helpers::get_tool_indices(tools);
201
202        let mut calls: Vec<ToolCallItem> = Vec::new();
203
204        // Try to match the partial tool call pattern
205        if let Some(captures) = self.partial_tool_call_regex.captures(current_text) {
206            let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
207            let func_args_raw = captures.get(3).map_or("", |m| m.as_str()).trim();
208
209            // Validate tool name
210            if !tool_indices.contains_key(func_name) {
211                // Invalid tool name - skip this tool, preserve indexing for next tool
212                tracing::debug!("Invalid tool name '{}' - skipping", func_name);
213                helpers::reset_current_tool_state(
214                    &mut self.buffer,
215                    &mut self.current_tool_name_sent,
216                    &mut self.streamed_args_for_tool,
217                    &self.prev_tool_call_arr,
218                );
219                return Ok(StreamingParseResult::default());
220            }
221
222            // Initialize state if this is the first tool call
223            if self.current_tool_id == -1 {
224                self.current_tool_id = 0;
225                self.prev_tool_call_arr = Vec::new();
226                self.streamed_args_for_tool = vec![String::new()];
227            }
228
229            // Ensure we have enough entries in our tracking arrays
230            helpers::ensure_capacity(
231                self.current_tool_id,
232                &mut self.prev_tool_call_arr,
233                &mut self.streamed_args_for_tool,
234            );
235
236            // Send tool name if not sent yet
237            if self.current_tool_name_sent {
238                // Compute incremental diff
239                let tool_id = self.current_tool_id as usize;
240                let last_sent = self
241                    .streamed_args_for_tool
242                    .get(tool_id)
243                    .map(|s| s.as_str())
244                    .unwrap_or("");
245
246                let argument_diff = func_args_raw
247                    .strip_prefix(last_sent)
248                    .unwrap_or(func_args_raw);
249
250                if !argument_diff.is_empty() {
251                    calls.push(ToolCallItem {
252                        tool_index: tool_id,
253                        name: None,
254                        parameters: argument_diff.to_string(),
255                    });
256                    if tool_id < self.streamed_args_for_tool.len() {
257                        self.streamed_args_for_tool[tool_id].push_str(argument_diff);
258                    }
259                }
260
261                // Check if JSON is complete
262                if helpers::is_complete_json(func_args_raw) {
263                    // Update the stored arguments
264                    if let Ok(parsed_args) = serde_json::from_str::<Value>(func_args_raw) {
265                        let tool_id = self.current_tool_id as usize;
266                        if tool_id < self.prev_tool_call_arr.len() {
267                            if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
268                                obj.insert("arguments".to_string(), parsed_args);
269                            }
270                        }
271                    }
272
273                    // Find the end of the current tool call and remove only that part from buffer
274                    if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
275                        // Remove the completed tool call from buffer, keep any remaining content
276                        self.buffer = current_text[mat.end()..].to_string();
277                    } else {
278                        self.buffer.clear();
279                    }
280
281                    let result = StreamingParseResult {
282                        normal_text: String::new(),
283                        calls,
284                    };
285
286                    self.current_tool_id += 1;
287                    self.current_tool_name_sent = false;
288                    return Ok(result);
289                }
290            } else {
291                calls.push(ToolCallItem {
292                    tool_index: self.current_tool_id as usize,
293                    name: Some(func_name.to_string()),
294                    parameters: String::new(),
295                });
296                self.current_tool_name_sent = true;
297
298                // Store the tool call info for serving layer completions endpoint
299                let tool_id = self.current_tool_id as usize;
300                if self.prev_tool_call_arr.len() <= tool_id {
301                    self.prev_tool_call_arr
302                        .resize_with(tool_id + 1, || Value::Null);
303                }
304                self.prev_tool_call_arr[tool_id] = serde_json::json!({
305                    "name": func_name,
306                    "arguments": {},
307                });
308            }
309        }
310
311        Ok(StreamingParseResult {
312            normal_text: String::new(),
313            calls,
314        })
315    }
316
317    fn has_tool_markers(&self, text: &str) -> bool {
318        text.contains("<|tool▁calls▁begin|>")
319    }
320
321    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
322        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
323    }
324
325    fn reset(&mut self) {
326        self.buffer.clear();
327        self.prev_tool_call_arr.clear();
328        self.current_tool_id = -1;
329        self.current_tool_name_sent = false;
330        self.streamed_args_for_tool.clear();
331    }
332}