1use std::collections::HashMap;
2
3use openai_protocol::common::Tool;
4use serde::de::{Deserialize, IgnoredAny};
5use serde_json::{de::Deserializer, Value};
6
7use crate::{
8 errors::{ParserError, ParserResult},
9 types::{StreamingParseResult, ToolCallItem},
10};
11
12pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
14 tools
15 .iter()
16 .enumerate()
17 .map(|(i, tool)| (tool.function.name.clone(), i))
18 .collect()
19}
20
/// Returns the longest prefix shared by `s1` and `s2`, compared character by
/// character (not byte by byte).
///
/// The result is empty when the first characters differ or either input is empty.
pub fn find_common_prefix(s1: &str, s2: &str) -> String {
    let mut shared = String::new();
    for (left, right) in s1.chars().zip(s2.chars()) {
        if left != right {
            break;
        }
        shared.push(left);
    }
    shared
}
30
31pub fn get_unstreamed_args(
35 prev_tool_call_arr: &[Value],
36 streamed_args_for_tool: &[String],
37) -> Option<Vec<ToolCallItem>> {
38 if prev_tool_call_arr.is_empty() || streamed_args_for_tool.is_empty() {
40 return None;
41 }
42
43 let tool_index = prev_tool_call_arr.len() - 1;
45 if tool_index >= streamed_args_for_tool.len() {
46 return None;
47 }
48
49 let expected_args = prev_tool_call_arr[tool_index].get("arguments")?;
51 let expected_str = serde_json::to_string(expected_args).ok()?;
52 let actual_str = &streamed_args_for_tool[tool_index];
53
54 let remaining = if expected_str.starts_with(actual_str) {
56 &expected_str[actual_str.len()..]
57 } else {
58 return None;
59 };
60
61 if remaining.is_empty() {
62 return None;
63 }
64
65 Some(vec![ToolCallItem {
67 tool_index,
68 name: None, parameters: remaining.to_string(),
70 }])
71}
72
/// Checks whether `buffer` ends with the beginning of `token`.
///
/// Every non-empty, proper prefix of `token` is tried from shortest to longest.
/// The byte length of the first prefix that `buffer` ends with is returned. The
/// full token itself is never tried.
///
/// Returns `None` when either argument is empty or no prefix matches.
///
/// Only prefix lengths on UTF-8 character boundaries of `token` are tried, so
/// tokens containing multi-byte characters (for example `<｜tool｜>`) cannot
/// cause a slicing panic.
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
    if buffer.is_empty() || token.is_empty() {
        return None;
    }

    (1..token.len())
        // Slicing `token` off a char boundary would panic; skip those offsets.
        .filter(|&len| token.is_char_boundary(len))
        .find(|&len| buffer.ends_with(&token[..len]))
}
82
83pub fn reset_current_tool_state(
87 buffer: &mut String,
88 current_tool_name_sent: &mut bool,
89 streamed_args_for_tool: &mut Vec<String>,
90 prev_tool_call_arr: &[Value],
91) {
92 buffer.clear();
93 *current_tool_name_sent = false;
94
95 if streamed_args_for_tool.len() > prev_tool_call_arr.len() {
98 streamed_args_for_tool.pop();
99 }
100}
101
102pub fn reset_parser_state(
105 buffer: &mut String,
106 prev_tool_call_arr: &mut Vec<Value>,
107 current_tool_id: &mut i32,
108 current_tool_name_sent: &mut bool,
109 streamed_args_for_tool: &mut Vec<String>,
110) {
111 buffer.clear();
112 prev_tool_call_arr.clear();
113 *current_tool_id = -1;
114 *current_tool_name_sent = false;
115 streamed_args_for_tool.clear();
116}
117
118pub fn ensure_capacity(
120 current_tool_id: i32,
121 prev_tool_call_arr: &mut Vec<Value>,
122 streamed_args_for_tool: &mut Vec<String>,
123) {
124 if current_tool_id < 0 {
125 return;
126 }
127 let needed = (current_tool_id + 1) as usize;
128
129 if prev_tool_call_arr.len() < needed {
130 prev_tool_call_arr.resize_with(needed, || Value::Null);
131 }
132 if streamed_args_for_tool.len() < needed {
133 streamed_args_for_tool.resize_with(needed, String::new);
134 }
135}
136
137pub fn is_complete_json(input: &str) -> bool {
139 let mut de = Deserializer::from_str(input);
140 IgnoredAny::deserialize(&mut de).is_ok() && de.end().is_ok()
141}
142
143pub fn normalize_arguments_field(mut obj: Value) -> Value {
153 if obj.get("arguments").is_none() {
154 if let Some(params) = obj.get("parameters").cloned() {
155 if let Value::Object(ref mut map) = obj {
156 map.insert("arguments".to_string(), params);
157 }
158 }
159 }
160 obj
161}
162
163pub fn normalize_name_field(mut obj: Value) -> Value {
173 if obj.get("name").is_none() {
174 if let Some(tool_name) = obj.get("tool_name").cloned() {
175 if let Value::Object(ref mut map) = obj {
176 map.insert("name".to_string(), tool_name);
177 }
178 }
179 }
180 obj
181}
182
183pub fn normalize_tool_call_fields(obj: Value) -> Value {
189 let obj = normalize_name_field(obj);
190 normalize_arguments_field(obj)
191}
192
193#[expect(clippy::too_many_arguments)]
219pub(crate) fn handle_json_tool_streaming(
220 current_text: &str,
221 start_idx: usize,
222 partial_json: &mut crate::partial_json::PartialJson,
223 tool_indices: &HashMap<String, usize>,
224 buffer: &mut String,
225 current_tool_id: &mut i32,
226 current_tool_name_sent: &mut bool,
227 streamed_args_for_tool: &mut Vec<String>,
228 prev_tool_call_arr: &mut Vec<Value>,
229) -> ParserResult<StreamingParseResult> {
230 if start_idx >= current_text.len() {
232 return Ok(StreamingParseResult::default());
233 }
234
235 let json_str = ¤t_text[start_idx..];
237
238 let allow_partial_strings = *current_tool_name_sent;
241
242 let (obj, end_idx) = match partial_json.parse_value(json_str, allow_partial_strings) {
244 Ok(result) => result,
245 Err(_) => {
246 return Ok(StreamingParseResult::default());
247 }
248 };
249
250 let safe_end_idx = if json_str.is_char_boundary(end_idx) {
253 end_idx
254 } else {
255 (0..end_idx)
257 .rev()
258 .find(|&i| json_str.is_char_boundary(i))
259 .unwrap_or(0)
260 };
261 let is_complete = is_complete_json(&json_str[..safe_end_idx]);
262
263 let current_tool_call = normalize_tool_call_fields(obj);
266
267 if let Some(name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
269 if !tool_indices.contains_key(name) {
270 tracing::debug!("Invalid tool name '{}' - skipping", name);
272 reset_current_tool_state(
273 buffer,
274 current_tool_name_sent,
275 streamed_args_for_tool,
276 prev_tool_call_arr,
277 );
278 return Ok(StreamingParseResult::default());
279 }
280 }
281
282 let mut result = StreamingParseResult::default();
283
284 if !*current_tool_name_sent {
286 if let Some(function_name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
287 if tool_indices.contains_key(function_name) {
288 if *current_tool_id == -1 {
290 *current_tool_id = 0;
291 streamed_args_for_tool.push(String::new());
292 } else if *current_tool_id as usize >= streamed_args_for_tool.len() {
293 ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
295 }
296
297 *current_tool_name_sent = true;
299 result.calls.push(ToolCallItem {
300 tool_index: *current_tool_id as usize,
301 name: Some(function_name.to_string()),
302 parameters: String::new(),
303 });
304 }
305 }
306 }
307 else if let Some(cur_arguments) = current_tool_call.get("arguments") {
309 let tool_id = *current_tool_id as usize;
310 let sent = streamed_args_for_tool
311 .get(tool_id)
312 .map(|s| s.len())
313 .unwrap_or(0);
314 let cur_args_json = serde_json::to_string(cur_arguments)
315 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
316
317 let prev_arguments = if tool_id < prev_tool_call_arr.len() {
319 prev_tool_call_arr[tool_id].get("arguments")
320 } else {
321 None
322 };
323
324 let mut argument_diff = None;
326
327 if is_complete {
328 argument_diff = if sent < cur_args_json.len() {
331 Some(cur_args_json[sent..].to_string())
332 } else {
333 Some(String::new())
334 };
335 } else if let Some(prev_args) = prev_arguments {
336 let prev_args_json = serde_json::to_string(prev_args)
337 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
338
339 if cur_args_json != prev_args_json {
340 let prefix = find_common_prefix(&prev_args_json, &cur_args_json);
341 argument_diff = if sent < prefix.len() {
342 Some(prefix[sent..].to_string())
343 } else {
344 Some(String::new())
345 };
346 }
347 }
348
349 if let Some(diff) = argument_diff {
351 if !diff.is_empty() {
352 if tool_id < streamed_args_for_tool.len() {
353 streamed_args_for_tool[tool_id].push_str(&diff);
354 }
355 result.calls.push(ToolCallItem {
356 tool_index: tool_id,
357 name: None,
358 parameters: diff,
359 });
360 }
361 }
362
363 if *current_tool_id >= 0 {
365 ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
366
367 if tool_id < prev_tool_call_arr.len() {
368 prev_tool_call_arr[tool_id] = current_tool_call;
369 }
370 }
371
372 if is_complete {
374 *buffer = current_text[start_idx + end_idx..].to_string();
375 *current_tool_name_sent = false;
376 *current_tool_id += 1;
377 }
378 }
379
380 Ok(result)
381}
382
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Special token used by the partial-token tests.
    const PYTHON_TAG: &str = "<|python_tag|>";

    /// A buffer ending in a proper prefix of the token is detected; a complete
    /// token, an empty buffer and unrelated text are not.
    #[test]
    fn test_ends_with_partial_token() {
        for partial in ["hello <|py", "hello <|python_tag"] {
            assert!(ends_with_partial_token(partial, PYTHON_TAG).is_some());
        }
        for other in ["hello <|python_tag|>", "", "hello world"] {
            assert!(ends_with_partial_token(other, PYTHON_TAG).is_none());
        }
    }

    /// With more streamed entries than recorded tool calls, the extra entry is dropped.
    #[test]
    fn test_reset_current_tool_state() {
        let mut buffer = "partial json".to_string();
        let mut name_sent = true;
        let mut streamed = vec![String::from("tool0_args"), String::from("tool1_partial")];
        let recorded = vec![json!({"name": "tool0"})];

        reset_current_tool_state(&mut buffer, &mut name_sent, &mut streamed, &recorded);

        assert!(buffer.is_empty());
        assert!(!name_sent);
        assert_eq!(streamed, vec![String::from("tool0_args")]);
    }

    /// With matching lengths, no streamed entry is removed.
    #[test]
    fn test_reset_current_tool_state_no_pop_when_synced() {
        let mut buffer = "partial json".to_string();
        let mut name_sent = true;
        let mut streamed = vec![String::from("tool0_args")];
        let recorded = vec![json!({"name": "tool0"})];

        reset_current_tool_state(&mut buffer, &mut name_sent, &mut streamed, &recorded);

        assert!(buffer.is_empty());
        assert!(!name_sent);
        assert_eq!(streamed.len(), 1);
    }

    /// A full reset empties every container and restores the scalar defaults.
    #[test]
    fn test_reset_parser_state() {
        let mut buffer = "some buffer".to_string();
        let mut recorded = vec![json!({"name": "tool0"})];
        let mut tool_id = 5;
        let mut name_sent = true;
        let mut streamed = vec![String::from("args")];

        reset_parser_state(
            &mut buffer,
            &mut recorded,
            &mut tool_id,
            &mut name_sent,
            &mut streamed,
        );

        assert!(buffer.is_empty());
        assert!(recorded.is_empty());
        assert_eq!(tool_id, -1);
        assert!(!name_sent);
        assert!(streamed.is_empty());
    }

    /// Both vectors grow to `id + 1` slots filled with default values.
    #[test]
    fn test_ensure_capacity() {
        let mut recorded: Vec<Value> = Vec::new();
        let mut streamed: Vec<String> = Vec::new();

        ensure_capacity(2, &mut recorded, &mut streamed);

        assert_eq!(recorded.len(), 3);
        assert_eq!(streamed.len(), 3);
        assert!(recorded[0].is_null());
        assert!(streamed[0].is_empty());
    }

    /// A negative tool id leaves both vectors untouched.
    #[test]
    fn test_ensure_capacity_negative_id() {
        let mut recorded: Vec<Value> = Vec::new();
        let mut streamed: Vec<String> = Vec::new();

        ensure_capacity(-1, &mut recorded, &mut streamed);

        assert!(recorded.is_empty());
        assert!(streamed.is_empty());
    }

    /// Whole JSON values are complete; truncated ones are not.
    #[test]
    fn test_is_complete_json() {
        for complete in [r#"{"name": "test"}"#, "[1, 2, 3]", "42", "true"] {
            assert!(is_complete_json(complete));
        }
        for truncated in [r#"{"name": "#, "[1, 2,"] {
            assert!(!is_complete_json(truncated));
        }
    }

    /// `parameters` is mirrored into `arguments` only when `arguments` is missing.
    #[test]
    fn test_normalize_arguments_field() {
        let with_parameters = json!({
            "name": "test",
            "parameters": {"key": "value"}
        });
        let normalized = normalize_arguments_field(with_parameters);
        assert_eq!(normalized.get("arguments"), Some(&json!({"key": "value"})));

        let with_arguments = json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(normalize_arguments_field(with_arguments.clone()), with_arguments);

        let name_only = json!({"name": "test"});
        assert_eq!(normalize_arguments_field(name_only.clone()), name_only);
    }

    /// `tool_name` is mirrored into `name` only when `name` is missing.
    #[test]
    fn test_normalize_name_field() {
        let with_tool_name = json!({
            "tool_name": "search",
            "parameters": {"query": "test"}
        });
        let normalized = normalize_name_field(with_tool_name);
        assert_eq!(normalized.get("name").and_then(Value::as_str), Some("search"));

        let with_name = json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(normalize_name_field(with_name.clone()), with_name);

        let with_both = json!({
            "tool_name": "cohere_name",
            "name": "standard_name",
            "parameters": {}
        });
        let normalized = normalize_name_field(with_both);
        assert_eq!(
            normalized.get("name").and_then(Value::as_str),
            Some("standard_name")
        );

        let without_name = json!({"parameters": {}});
        let normalized = normalize_name_field(without_name);
        assert!(normalized.get("name").is_none());
    }

    /// The combined normalization handles both aliases at once.
    #[test]
    fn test_normalize_tool_call_fields() {
        let aliased = json!({
            "tool_name": "search",
            "parameters": {"query": "rust programming"}
        });
        let normalized = normalize_tool_call_fields(aliased);
        assert_eq!(normalized.get("name").and_then(Value::as_str), Some("search"));
        assert_eq!(
            normalized.get("arguments"),
            Some(&json!({"query": "rust programming"}))
        );

        let standard = json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(normalize_tool_call_fields(standard.clone()), standard);

        let mixed = json!({
            "name": "test",
            "parameters": {"key": "value"}
        });
        let normalized = normalize_tool_call_fields(mixed);
        assert_eq!(normalized.get("name").and_then(Value::as_str), Some("test"));
        assert_eq!(normalized.get("arguments"), Some(&json!({"key": "value"})));
    }
}