1use async_trait::async_trait;
31use openai_protocol::common::Tool;
32use serde_json::Value;
33
34use crate::{
35 errors::{ParserError, ParserResult},
36 parsers::helpers,
37 partial_json::PartialJson,
38 traits::ToolParser,
39 types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
40};
41
// Tool-call wrapper markers: the JSON payload between these two is parsed
// into tool calls.
/// Opens an action (tool-call) block.
const START_ACTION: &str = "<|START_ACTION|>";
/// Closes an action (tool-call) block.
const END_ACTION: &str = "<|END_ACTION|>";

// Plain-text wrapper markers: these are stripped from normal text output.
/// Opens a response segment.
const START_RESPONSE: &str = "<|START_RESPONSE|>";
/// Closes a response segment.
const END_RESPONSE: &str = "<|END_RESPONSE|>";
/// Opens a text segment.
const START_TEXT: &str = "<|START_TEXT|>";
/// Closes a text segment.
const END_TEXT: &str = "<|END_TEXT|>";
48
/// Where the streaming parser currently is relative to an action block.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ParseState {
    /// Outside any action block; buffered input is emitted as normal text.
    Text,
    /// Between `<|START_ACTION|>` and `<|END_ACTION|>`; input is buffered as
    /// the tool-call JSON payload until the closing marker arrives.
    InAction,
}
57
58pub struct CohereParser {
63 state: ParseState,
65
66 partial_json: PartialJson,
68
69 buffer: String,
71
72 prev_tool_call_arr: Vec<Value>,
74
75 current_tool_id: i32,
77
78 current_tool_name_sent: bool,
80
81 streamed_args_for_tool: Vec<String>,
83}
84
85impl CohereParser {
86 pub fn new() -> Self {
88 Self {
89 state: ParseState::Text,
90 partial_json: PartialJson::default(),
91 buffer: String::new(),
92 prev_tool_call_arr: Vec::new(),
93 current_tool_id: -1,
94 current_tool_name_sent: false,
95 streamed_args_for_tool: Vec::new(),
96 }
97 }
98
99 fn clean_text(text: &str) -> String {
101 text.replace(START_RESPONSE, "")
102 .replace(END_RESPONSE, "")
103 .replace(START_TEXT, "")
104 .replace(END_TEXT, "")
105 }
106
107 fn convert_tool_call(json_str: &str) -> ParserResult<Vec<ToolCall>> {
109 let value: Value = serde_json::from_str(json_str.trim())
110 .map_err(|e| ParserError::ParsingFailed(format!("Invalid JSON: {e}")))?;
111
112 let tools = match value {
113 Value::Array(arr) => arr,
114 single => vec![single],
115 };
116
117 tools
118 .into_iter()
119 .filter_map(|tool| {
120 let name = tool
122 .get("tool_name")
123 .and_then(|v| v.as_str())
124 .or_else(|| tool.get("name").and_then(|v| v.as_str()))?;
125
126 let parameters = tool
128 .get("parameters")
129 .or_else(|| tool.get("arguments"))
130 .map(|v| v.to_string())
131 .unwrap_or_else(|| "{}".to_string());
132
133 Some(Ok(ToolCall {
134 function: FunctionCall {
135 name: name.to_string(),
136 arguments: parameters,
137 },
138 }))
139 })
140 .collect()
141 }
142
143 fn extract_action_json(text: &str) -> Option<(usize, &str, usize)> {
145 let start_idx = text.find(START_ACTION)?;
146 let json_start = start_idx + START_ACTION.len();
147
148 if let Some(end_offset) = text[json_start..].find(END_ACTION) {
149 let json_str = &text[json_start..json_start + end_offset];
150 Some((
151 start_idx,
152 json_str,
153 json_start + end_offset + END_ACTION.len(),
154 ))
155 } else {
156 None
158 }
159 }
160}
161
162impl Default for CohereParser {
163 fn default() -> Self {
164 Self::new()
165 }
166}
167
168#[async_trait]
169impl ToolParser for CohereParser {
170 async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
171 if !self.has_tool_markers(text) {
173 let cleaned = Self::clean_text(text);
174 return Ok((cleaned.trim().to_string(), vec![]));
175 }
176
177 let mut normal_text = String::new();
178 let mut tool_calls = Vec::new();
179 let mut remaining = text;
180
181 while let Some((start_idx, json_str, end_idx)) = Self::extract_action_json(remaining) {
182 normal_text.push_str(&remaining[..start_idx]);
184
185 match Self::convert_tool_call(json_str) {
187 Ok(calls) => tool_calls.extend(calls),
188 Err(e) => {
189 tracing::debug!("Failed to parse Cohere tool call: {}", e);
190 }
191 }
192
193 remaining = &remaining[end_idx..];
194 }
195
196 normal_text.push_str(remaining);
198
199 let cleaned_text = Self::clean_text(&normal_text);
201
202 Ok((cleaned_text.trim().to_string(), tool_calls))
203 }
204
205 async fn parse_incremental(
206 &mut self,
207 chunk: &str,
208 tools: &[Tool],
209 ) -> ParserResult<StreamingParseResult> {
210 self.buffer.push_str(chunk);
211
212 match self.state {
213 ParseState::Text => {
214 let start_pos = self.buffer.find(START_ACTION);
216 if let Some(pos) = start_pos {
217 let text_before = Self::clean_text(&self.buffer[..pos]);
219
220 self.state = ParseState::InAction;
222 self.buffer.drain(..pos + START_ACTION.len());
223
224 return Ok(StreamingParseResult {
225 normal_text: text_before,
226 calls: vec![],
227 });
228 }
229
230 if helpers::ends_with_partial_token(&self.buffer, START_ACTION).is_some() {
232 return Ok(StreamingParseResult::default());
234 }
235
236 let cleaned = Self::clean_text(&self.buffer);
238 self.buffer.clear();
239 Ok(StreamingParseResult {
240 normal_text: cleaned,
241 calls: vec![],
242 })
243 }
244
245 ParseState::InAction => {
246 if let Some(pos) = self.buffer.find(END_ACTION) {
248 let json_content = self.buffer[..pos].to_string();
250
251 let tool_indices = helpers::get_tool_indices(tools);
253
254 let mut temp_buffer = String::new();
256
257 let result = helpers::handle_json_tool_streaming(
259 &json_content,
260 0,
261 &mut self.partial_json,
262 &tool_indices,
263 &mut temp_buffer,
264 &mut self.current_tool_id,
265 &mut self.current_tool_name_sent,
266 &mut self.streamed_args_for_tool,
267 &mut self.prev_tool_call_arr,
268 )?;
269
270 self.buffer.drain(..pos + END_ACTION.len());
272 self.state = ParseState::Text;
273
274 return Ok(result);
275 }
276
277 Ok(StreamingParseResult::default())
281 }
282 }
283 }
284
285 fn has_tool_markers(&self, text: &str) -> bool {
286 text.contains(START_ACTION) || text.contains(END_ACTION)
287 }
288
289 fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
290 helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
291 }
292
293 fn reset(&mut self) {
294 self.state = ParseState::Text;
295 helpers::reset_parser_state(
296 &mut self.buffer,
297 &mut self.prev_tool_call_arr,
298 &mut self.current_tool_id,
299 &mut self.current_tool_name_sent,
300 &mut self.streamed_args_for_tool,
301 );
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_single_tool_call() {
        let parser = CohereParser::new();
        let input = r#"<|START_RESPONSE|>Let me search for that.<|END_RESPONSE|>
<|START_ACTION|>
{"tool_name": "search", "parameters": {"query": "rust programming"}}
<|END_ACTION|>"#;

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(normal_text, "Let me search for that.");
        assert_eq!(tools[0].function.name, "search");

        let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
        assert_eq!(args["query"], "rust programming");
    }

    #[tokio::test]
    async fn test_multiple_tool_calls_array() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>
[
    {"tool_name": "search", "parameters": {"query": "rust"}},
    {"tool_name": "get_weather", "parameters": {"city": "Paris"}}
]
<|END_ACTION|>"#;

        let (_, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].function.name, "search");
        assert_eq!(tools[1].function.name, "get_weather");
    }

    #[tokio::test]
    async fn test_multiple_action_blocks() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>{"tool_name": "a", "parameters": {}}<|END_ACTION|> middle <|START_ACTION|>{"tool_name": "b", "parameters": {}}<|END_ACTION|>"#;

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].function.name, "a");
        assert_eq!(tools[1].function.name, "b");
        assert_eq!(normal_text, "middle");
    }

    #[tokio::test]
    async fn test_no_tool_calls() {
        let parser = CohereParser::new();
        let input = "<|START_RESPONSE|>Hello, how can I help?<|END_RESPONSE|>";

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 0);
        assert_eq!(normal_text, "Hello, how can I help?");
    }

    #[tokio::test]
    async fn test_has_tool_markers() {
        let parser = CohereParser::new();

        assert!(parser.has_tool_markers("<|START_ACTION|>"));
        assert!(parser.has_tool_markers("<|END_ACTION|>"));
        assert!(parser.has_tool_markers("Some text <|START_ACTION|> more"));
        assert!(!parser.has_tool_markers("Just plain text"));
        assert!(!parser.has_tool_markers("[TOOL_CALLS]"));
    }

    #[tokio::test]
    async fn test_empty_parameters() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>{"tool_name": "ping"}<|END_ACTION|>"#;

        let (_, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "ping");
        assert_eq!(tools[0].function.arguments, "{}");
    }

    #[tokio::test]
    async fn test_name_and_arguments_aliases() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>{"name": "search", "arguments": {"query": "rust"}}<|END_ACTION|>"#;

        let (_, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "search");

        let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
        assert_eq!(args["query"], "rust");
    }

    #[tokio::test]
    async fn test_entry_without_name_is_skipped() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>[{"parameters": {"x": 1}}, {"tool_name": "ping"}]<|END_ACTION|>"#;

        let (_, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "ping");
    }

    #[tokio::test]
    async fn test_nested_json() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>
{"tool_name": "process", "parameters": {"config": {"nested": {"value": [1, 2, 3]}}}}
<|END_ACTION|>"#;

        let (_, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);

        let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
        assert_eq!(
            args["config"]["nested"]["value"],
            serde_json::json!([1, 2, 3])
        );
    }

    #[tokio::test]
    async fn test_text_markers_cleaned() {
        let parser = CohereParser::new();
        let input = r#"<|START_TEXT|>Some intro<|END_TEXT|>
<|START_ACTION|>{"tool_name": "test", "parameters": {}}<|END_ACTION|>
<|START_TEXT|>Conclusion<|END_TEXT|>"#;

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(normal_text.contains("Some intro"));
        assert!(normal_text.contains("Conclusion"));
        assert!(!normal_text.contains("<|START_TEXT|>"));
        assert!(!normal_text.contains("<|END_TEXT|>"));
    }

    #[tokio::test]
    async fn test_malformed_json() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>{"tool_name": invalid}<|END_ACTION|>"#;

        let (_, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 0);
    }
}