tool_parser/parsers/
mistral.rs1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use serde_json::Value;
4
5use crate::{
6 errors::{ParserError, ParserResult},
7 parsers::helpers,
8 partial_json::PartialJson,
9 traits::ToolParser,
10 types::{FunctionCall, StreamingParseResult, ToolCall},
11};
12
13pub struct MistralParser {
20 partial_json: PartialJson,
22
23 buffer: String,
25
26 prev_tool_call_arr: Vec<Value>,
28
29 current_tool_id: i32,
31
32 current_tool_name_sent: bool,
34
35 streamed_args_for_tool: Vec<String>,
37
38 bot_token: &'static str,
40 eot_token: &'static str,
41 tool_call_separator: &'static str,
42
43 array_closed: bool,
45}
46
47impl MistralParser {
48 pub fn new() -> Self {
50 Self {
51 partial_json: PartialJson::default(),
52 buffer: String::new(),
53 prev_tool_call_arr: Vec::new(),
54 current_tool_id: -1,
55 current_tool_name_sent: false,
56 streamed_args_for_tool: Vec::new(),
57 bot_token: "[TOOL_CALLS] [",
58 eot_token: "]",
59 tool_call_separator: ", ",
60 array_closed: false,
61 }
62 }
63
64 fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> {
65 const BOT_TOKEN: &str = "[TOOL_CALLS] [";
66
67 let start_idx = text.find(BOT_TOKEN)?;
69
70 let json_start = start_idx + BOT_TOKEN.len() - 1;
73
74 let mut bracket_count = 0;
75 let mut in_string = false;
76 let mut escape_next = false;
77
78 let bytes = text.as_bytes();
79
80 for i in json_start..text.len() {
81 let char = bytes[i];
82
83 if escape_next {
84 escape_next = false;
85 continue;
86 }
87
88 if char == b'\\' {
89 escape_next = true;
90 continue;
91 }
92
93 if char == b'"' && !escape_next {
94 in_string = !in_string;
95 continue;
96 }
97
98 if !in_string {
99 if char == b'[' {
100 bracket_count += 1;
101 } else if char == b']' {
102 bracket_count -= 1;
103 if bracket_count == 0 {
104 return Some((start_idx, &text[json_start..=i]));
106 }
107 }
108 }
109 }
110
111 None
113 }
114
115 fn parse_json_array(&self, json_str: &str) -> ParserResult<Vec<ToolCall>> {
117 let value: Value = serde_json::from_str(json_str)
118 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
119
120 let mut tools = Vec::new();
121
122 if let Value::Array(arr) = value {
123 for item in arr.iter() {
124 if let Some(tool) = self.parse_single_object(item)? {
125 tools.push(tool);
126 }
127 }
128 } else {
129 if let Some(tool) = self.parse_single_object(&value)? {
131 tools.push(tool);
132 }
133 }
134
135 Ok(tools)
136 }
137
138 fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
140 let name = obj.get("name").and_then(|v| v.as_str());
141
142 if let Some(name) = name {
143 let empty_obj = Value::Object(serde_json::Map::new());
145 let args = obj.get("arguments").unwrap_or(&empty_obj);
146
147 let arguments = serde_json::to_string(args)
149 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
150
151 Ok(Some(ToolCall {
152 function: FunctionCall {
153 name: name.to_string(),
154 arguments,
155 },
156 }))
157 } else {
158 Ok(None)
159 }
160 }
161}
162
163impl Default for MistralParser {
164 fn default() -> Self {
165 Self::new()
166 }
167}
168
169#[async_trait]
170impl ToolParser for MistralParser {
171 async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
172 if !self.has_tool_markers(text) {
174 return Ok((text.to_string(), vec![]));
175 }
176
177 if let Some((start_idx, json_array)) = self.extract_json_array_with_pos(text) {
179 let normal_text_before = if start_idx > 0 {
181 text[..start_idx].to_string()
182 } else {
183 String::new()
184 };
185
186 match self.parse_json_array(json_array) {
187 Ok(tools) => Ok((normal_text_before, tools)),
188 Err(e) => {
189 tracing::debug!("Failed to parse tool call: {}", e);
191 Ok((text.to_string(), vec![]))
192 }
193 }
194 } else {
195 Ok((text.to_string(), vec![]))
197 }
198 }
199
200 async fn parse_incremental(
201 &mut self,
202 chunk: &str,
203 tools: &[Tool],
204 ) -> ParserResult<StreamingParseResult> {
205 self.buffer.push_str(chunk);
207 let current_text = &self.buffer.clone();
208
209 let has_tool_start = self.has_tool_markers(current_text)
211 || (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));
212
213 if !has_tool_start {
214 if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
216 let mut normal_text = self.buffer.clone();
217 self.buffer.clear();
218
219 if !self.array_closed
222 && self.current_tool_id > 0
223 && normal_text.starts_with(self.eot_token)
224 {
225 normal_text = normal_text
226 .strip_prefix(self.eot_token)
227 .unwrap()
228 .to_string();
229 self.array_closed = true;
230 }
231
232 return Ok(StreamingParseResult {
233 normal_text,
234 calls: vec![],
235 });
236 } else {
237 return Ok(StreamingParseResult::default());
239 }
240 }
241
242 let tool_indices = helpers::get_tool_indices(tools);
244
245 let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
247 pos + self.bot_token.len()
248 } else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
249 self.tool_call_separator.len()
250 } else {
251 0
252 };
253
254 helpers::handle_json_tool_streaming(
255 current_text,
256 start_idx,
257 &mut self.partial_json,
258 &tool_indices,
259 &mut self.buffer,
260 &mut self.current_tool_id,
261 &mut self.current_tool_name_sent,
262 &mut self.streamed_args_for_tool,
263 &mut self.prev_tool_call_arr,
264 )
265 }
266
267 fn has_tool_markers(&self, text: &str) -> bool {
268 text.contains("[TOOL_CALLS]")
269 }
270
271 fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::types::ToolCallItem>> {
272 helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
273 }
274
275 fn reset(&mut self) {
276 helpers::reset_parser_state(
277 &mut self.buffer,
278 &mut self.prev_tool_call_arr,
279 &mut self.current_tool_id,
280 &mut self.current_tool_name_sent,
281 &mut self.streamed_args_for_tool,
282 );
283 self.array_closed = false;
284 }
285}