1use async_trait::async_trait;
28use serde_json::{Map, Value};
29
30use crate::registry::KernelError;
31use crate::registry::ToolRegistry;
32use crate::tool::ToolName;
33use crate::trace::{DispatchTrace, DispatchTraceEvent, TracedAction, TracedOutcome};
34
/// A single tool call ready for dispatch: which tool to run and the JSON
/// arguments to hand to it.
#[derive(Debug, Clone, PartialEq)]
pub struct ToolInvocation {
    /// Name of the tool; this is what `dispatch` passes to the registry.
    pub name: ToolName,
    /// Arguments forwarded to the tool as a JSON value.
    pub args: Value,
}
45
46impl ToolInvocation {
47 pub fn new(name: impl Into<ToolName>, args: Value) -> Result<Self, KernelError> {
49 let name = name.into();
50 if name.trim().is_empty() {
51 return Err(KernelError::NormalizerFailed(
52 "empty tool name in structured tool call".into(),
53 ));
54 }
55 validate_identifier("tool name", &name)?;
56 Ok(Self { name, args })
57 }
58
59 pub async fn dispatch(&self, tools: &ToolRegistry) -> Result<Value, KernelError> {
61 tools.invoke(&self.name, self.args.clone()).await
62 }
63}
64
/// Pairs an invocation with the output recorded for it by the dispatcher.
#[derive(Debug, Clone, PartialEq)]
pub struct ToolInvocationResult {
    /// The invocation this result belongs to (a clone of the dispatched one).
    pub invocation: ToolInvocation,
    /// Output value: either the tool's return value or, for a skipped
    /// invocation, the output supplied by the hook that skipped it.
    pub output: Value,
}
73
/// Decision returned by [`ToolDispatchHook::before_invocation`].
#[derive(Debug, Clone, PartialEq)]
pub enum ToolDispatchAction {
    /// Proceed: consult the next hook, or dispatch the tool if none remain.
    Continue,
    /// Do not run the tool; record `output` as the invocation's result.
    Skip {
        /// Value used in place of the tool's output.
        output: Value,
        /// Optional explanation, carried into the `Skipped` outcome.
        reason: Option<String>,
    },
    /// Abort dispatch; the dispatcher fails with
    /// `KernelError::ToolDispatchTerminated(reason)`.
    Terminate { reason: String },
}
89
/// How an invocation finished when it did not fail, as reported to
/// [`ToolDispatchHook::after_invocation_with_outcome`].
#[derive(Debug, Clone, PartialEq)]
pub enum ToolInvocationOutcome {
    /// The tool was dispatched and returned successfully.
    Completed,
    /// A hook answered `Skip`, so the tool was not run.
    Skipped {
        /// Reason given by the skipping hook, if any.
        reason: Option<String>,
    },
}
101
/// Observer/interceptor for tool dispatch.
///
/// All methods have default implementations, so an implementor only overrides
/// the callbacks it cares about.
#[async_trait]
pub trait ToolDispatchHook: Send + Sync {
    /// Called before an invocation is dispatched; the returned action decides
    /// whether dispatch continues, is skipped, or is terminated.
    ///
    /// The default implementation always answers `Continue`.
    async fn before_invocation(
        &self,
        _invocation: &ToolInvocation,
    ) -> Result<ToolDispatchAction, KernelError> {
        Ok(ToolDispatchAction::Continue)
    }

    /// Outcome-agnostic "after" callback. The default implementation does
    /// nothing; it is reached through the default
    /// `after_invocation_with_outcome`.
    async fn after_invocation(&self, _result: &ToolInvocationResult) -> Result<(), KernelError> {
        Ok(())
    }

    /// Called by the dispatcher once an invocation has completed or been
    /// skipped. The default implementation ignores the outcome and forwards to
    /// `after_invocation`.
    async fn after_invocation_with_outcome(
        &self,
        result: &ToolInvocationResult,
        _outcome: &ToolInvocationOutcome,
    ) -> Result<(), KernelError> {
        self.after_invocation(result).await
    }

    /// Called by the dispatcher when an invocation ends in an error (a failing
    /// `before_invocation`, a failing tool, or a `Terminate` decision).
    /// The default implementation does nothing.
    async fn on_invocation_error(
        &self,
        _invocation: &ToolInvocation,
        _error: &KernelError,
    ) -> Result<(), KernelError> {
        Ok(())
    }
}
149
150pub async fn dispatch_tool_invocations(
157 tools: &ToolRegistry,
158 invocations: &[ToolInvocation],
159) -> Result<Vec<ToolInvocationResult>, KernelError> {
160 dispatch_tool_invocations_with_hooks(tools, invocations, &[]).await
161}
162
163pub async fn dispatch_tool_invocations_with_hooks(
170 tools: &ToolRegistry,
171 invocations: &[ToolInvocation],
172 hooks: &[&dyn ToolDispatchHook],
173) -> Result<Vec<ToolInvocationResult>, KernelError> {
174 dispatch_inner(tools, invocations, hooks, None).await
175}
176
177pub async fn dispatch_tool_invocations_with_trace(
185 tools: &ToolRegistry,
186 invocations: &[ToolInvocation],
187 hooks: &[&dyn ToolDispatchHook],
188 trace: &DispatchTrace,
189) -> Result<Vec<ToolInvocationResult>, KernelError> {
190 dispatch_inner(tools, invocations, hooks, Some(trace)).await
191}
192
/// Shared driver behind the public `dispatch_tool_invocations*` entry points.
///
/// Invocations are processed one after another, in slice order. For each one:
/// 1. every hook's `before_invocation` is consulted in order until one fails
///    or answers something other than `Continue`;
/// 2. the resulting action decides whether the tool is dispatched, replaced by
///    a hook-supplied output (`Skip`), or the whole dispatch is aborted
///    (`Terminate`);
/// 3. for completed or skipped invocations every hook's
///    `after_invocation_with_outcome` runs, in order.
///
/// When `trace` is `Some`, events describing these steps are pushed onto it.
///
/// # Errors
/// Returns on the first failure; results gathered for earlier invocations are
/// discarded. Errors from cleanup (`on_invocation_error`) or after hooks are
/// propagated with `?`, so a failing cleanup hook replaces the error that
/// triggered the cleanup and the final `InvocationOutcome` event is not pushed.
async fn dispatch_inner(
    tools: &ToolRegistry,
    invocations: &[ToolInvocation],
    hooks: &[&dyn ToolDispatchHook],
    trace: Option<&DispatchTrace>,
) -> Result<Vec<ToolInvocationResult>, KernelError> {
    let mut results = Vec::with_capacity(invocations.len());

    for (invocation_index, invocation) in invocations.iter().enumerate() {
        let mut action = ToolDispatchAction::Continue;
        // Number of hooks whose `before_invocation` returned `Ok` for this
        // invocation; only those are notified if a later before hook fails.
        let mut observed: usize = 0;
        // Failing hook's index and error, handled after the loop.
        let mut before_err: Option<(usize, KernelError)> = None;
        for (hook_index, hook) in hooks.iter().enumerate() {
            match hook.before_invocation(invocation).await {
                Ok(next) => {
                    observed += 1;
                    if let Some(trace) = trace {
                        trace.push(DispatchTraceEvent::HookBefore {
                            invocation_index,
                            hook_index,
                            decision: TracedAction::from(&next),
                        });
                    }
                    action = next;
                    // The first Skip/Terminate wins; later hooks are not consulted.
                    if !matches!(action, ToolDispatchAction::Continue) {
                        break;
                    }
                }
                Err(error) => {
                    before_err = Some((hook_index, error));
                    break;
                }
            }
        }
        if let Some((hook_index, error)) = before_err {
            if let Some(trace) = trace {
                trace.push(DispatchTraceEvent::HookBeforeError {
                    invocation_index,
                    hook_index,
                    message: error.to_string(),
                });
            }
            // Only the hooks that already saw this invocation get the error
            // callback; the failing hook and those after it do not.
            notify_invocation_error_subset(
                hooks,
                observed,
                invocation,
                &error,
                trace,
                invocation_index,
            )
            .await?;
            if let Some(trace) = trace {
                trace.push(DispatchTraceEvent::InvocationOutcome {
                    invocation_index,
                    outcome: TracedOutcome::Failed {
                        message: error.to_string(),
                    },
                });
            }
            return Err(error);
        }

        let (output, outcome) = match action {
            ToolDispatchAction::Continue => match invocation.dispatch(tools).await {
                Ok(output) => (output, ToolInvocationOutcome::Completed),
                Err(error) => {
                    // Tool failure: every hook is notified before the error is returned.
                    notify_invocation_error(hooks, invocation, &error, trace, invocation_index)
                        .await?;
                    if let Some(trace) = trace {
                        trace.push(DispatchTraceEvent::InvocationOutcome {
                            invocation_index,
                            outcome: TracedOutcome::Failed {
                                message: error.to_string(),
                            },
                        });
                    }
                    return Err(error);
                }
            },
            // The tool is not run; the hook-provided output stands in for it.
            ToolDispatchAction::Skip { output, reason } => {
                (output, ToolInvocationOutcome::Skipped { reason })
            }
            ToolDispatchAction::Terminate { reason } => {
                // Termination is surfaced to hooks and the caller as an error.
                let error = KernelError::ToolDispatchTerminated(reason.clone());
                notify_invocation_error(hooks, invocation, &error, trace, invocation_index).await?;
                if let Some(trace) = trace {
                    trace.push(DispatchTraceEvent::InvocationOutcome {
                        invocation_index,
                        outcome: TracedOutcome::Terminated { reason },
                    });
                }
                return Err(error);
            }
        };

        let result = ToolInvocationResult {
            invocation: invocation.clone(),
            output,
        };

        // After hooks run for every hook, including ones that were not
        // consulted because an earlier hook answered `Skip`.
        for (hook_index, hook) in hooks.iter().enumerate() {
            hook.after_invocation_with_outcome(&result, &outcome)
                .await?;
            if let Some(trace) = trace {
                trace.push(DispatchTraceEvent::HookAfter {
                    invocation_index,
                    hook_index,
                });
            }
        }

        if let Some(trace) = trace {
            let outcome_event = match &outcome {
                ToolInvocationOutcome::Completed => TracedOutcome::Completed,
                ToolInvocationOutcome::Skipped { reason } => TracedOutcome::Skipped {
                    reason: reason.clone(),
                },
            };
            trace.push(DispatchTraceEvent::InvocationOutcome {
                invocation_index,
                outcome: outcome_event,
            });
        }

        results.push(result);
    }

    Ok(results)
}
327
328async fn notify_invocation_error(
329 hooks: &[&dyn ToolDispatchHook],
330 invocation: &ToolInvocation,
331 error: &KernelError,
332 trace: Option<&DispatchTrace>,
333 invocation_index: usize,
334) -> Result<(), KernelError> {
335 for (hook_index, hook) in hooks.iter().enumerate() {
336 hook.on_invocation_error(invocation, error).await?;
337 if let Some(trace) = trace {
338 trace.push(DispatchTraceEvent::HookCleanup {
339 invocation_index,
340 hook_index,
341 });
342 }
343 }
344 Ok(())
345}
346
347async fn notify_invocation_error_subset(
351 hooks: &[&dyn ToolDispatchHook],
352 upto: usize,
353 invocation: &ToolInvocation,
354 error: &KernelError,
355 trace: Option<&DispatchTrace>,
356 invocation_index: usize,
357) -> Result<(), KernelError> {
358 for (hook_index, hook) in hooks.iter().take(upto).enumerate() {
359 hook.on_invocation_error(invocation, error).await?;
360 if let Some(trace) = trace {
361 trace.push(DispatchTraceEvent::HookCleanup {
362 invocation_index,
363 hook_index,
364 });
365 }
366 }
367 Ok(())
368}
369
/// Converts raw model text into structured [`ToolInvocation`]s.
pub trait ToolCallNormalizer: Send + Sync {
    /// Extracts every tool invocation found in `raw`.
    ///
    /// Returns a [`KernelError`] when `raw` cannot be normalized.
    fn normalize(&self, raw: &str) -> Result<Vec<ToolInvocation>, KernelError>;

    /// Reports whether this normalizer recognizes the format used in `raw`.
    fn is_applicable(&self, raw: &str) -> bool;
}
391
/// Stateless normalizer for tool calls that arrive as structured JSON
/// (OpenAI Responses and Chat Completions shapes).
#[derive(Debug, Clone, Default)]
pub struct StructuredToolCallNormalizer;
405
406impl StructuredToolCallNormalizer {
407 pub fn normalize_openai_responses(value: &Value) -> Result<Vec<ToolInvocation>, KernelError> {
410 match value {
411 Value::Object(object) => {
412 if let Some(output) = object.get("output") {
413 return normalize_responses_output(output);
414 }
415 if is_responses_function_call(object) {
416 return parse_responses_function_call(object).map(|call| vec![call]);
417 }
418 Ok(Vec::new())
419 }
420 Value::Array(items) => items
421 .iter()
422 .map(normalize_responses_output_item)
423 .collect::<Result<Vec<_>, _>>()
424 .map(flatten_invocations),
425 _ => Ok(Vec::new()),
426 }
427 }
428
429 pub fn normalize_openai_chat_completions(
432 value: &Value,
433 ) -> Result<Vec<ToolInvocation>, KernelError> {
434 match value {
435 Value::Object(object) => {
436 if let Some(choices) = object.get("choices") {
437 return normalize_chat_choices(choices);
438 }
439 if let Some(tool_calls) = object.get("tool_calls") {
440 return normalize_chat_tool_calls(tool_calls);
441 }
442 if is_chat_tool_call(object) {
443 return parse_chat_tool_call(object).map(|call| vec![call]);
444 }
445 Ok(Vec::new())
446 }
447 Value::Array(items) => normalize_chat_tool_calls_array(items),
448 _ => Ok(Vec::new()),
449 }
450 }
451
452 pub fn normalize(value: &Value) -> Result<Vec<ToolInvocation>, KernelError> {
458 let mut invocations = Self::normalize_openai_responses(value)?;
459 invocations.extend(Self::normalize_openai_chat_completions(value)?);
460 Ok(invocations)
461 }
462}
463
/// Marker that opens an LFM tool-call block in raw model output.
const LFM_START: &str = "<|tool_call_start|>";
/// Marker that closes an LFM tool-call block.
const LFM_END: &str = "<|tool_call_end|>";
468
/// Stateless normalizer for tool calls written as call expressions between
/// `<|tool_call_start|>` and `<|tool_call_end|>` markers, e.g.
/// `<|tool_call_start|>[get_weather(city='Berlin')]<|tool_call_end|>`.
#[derive(Debug, Clone, Default)]
pub struct LfmNormalizer;
494
495impl ToolCallNormalizer for LfmNormalizer {
496 fn is_applicable(&self, raw: &str) -> bool {
497 raw.contains(LFM_START)
498 }
499
500 fn normalize(&self, raw: &str) -> Result<Vec<ToolInvocation>, KernelError> {
501 let mut results = Vec::new();
502 let mut remaining = raw;
503
504 while let Some(block_start) = remaining.find(LFM_START) {
505 let after_start = remaining
507 .get(block_start + LFM_START.len()..)
508 .ok_or_else(|| KernelError::NormalizerFailed("LFM: start marker overrun".into()))?;
509
510 let block_end = after_start.find(LFM_END).ok_or_else(|| {
511 KernelError::NormalizerFailed("LFM: unclosed <|tool_call_start|> marker".into())
512 })?;
513
514 let block = after_start.get(..block_end).ok_or_else(|| {
515 KernelError::NormalizerFailed("LFM: block slice out of bounds".into())
516 })?;
517
518 remaining = after_start.get(block_end + LFM_END.len()..).unwrap_or("");
520
521 let calls = parse_lfm_block(block)?;
522 results.extend(calls);
523 }
524
525 Ok(results)
526 }
527}
528
529fn normalize_responses_output(value: &Value) -> Result<Vec<ToolInvocation>, KernelError> {
532 match value {
533 Value::Array(items) => items
534 .iter()
535 .map(normalize_responses_output_item)
536 .collect::<Result<Vec<_>, _>>()
537 .map(flatten_invocations),
538 Value::Object(object) if is_responses_function_call(object) => {
539 parse_responses_function_call(object).map(|call| vec![call])
540 }
541 _ => Ok(Vec::new()),
542 }
543}
544
545fn normalize_responses_output_item(value: &Value) -> Result<Vec<ToolInvocation>, KernelError> {
546 match value {
547 Value::Object(object) if is_responses_function_call(object) => {
548 parse_responses_function_call(object).map(|call| vec![call])
549 }
550 _ => Ok(Vec::new()),
551 }
552}
553
554fn is_responses_function_call(object: &Map<String, Value>) -> bool {
555 object
556 .get("type")
557 .and_then(Value::as_str)
558 .is_some_and(|kind| kind == "function_call")
559}
560
561fn parse_responses_function_call(
562 object: &Map<String, Value>,
563) -> Result<ToolInvocation, KernelError> {
564 let name = required_string_field(object, "name", "OpenAI Responses function_call")?;
565 let args = object
566 .get("arguments")
567 .map(parse_standard_arguments)
568 .transpose()?
569 .unwrap_or_else(|| Value::Object(Map::new()));
570 ToolInvocation::new(name, args)
571}
572
573fn normalize_chat_choices(value: &Value) -> Result<Vec<ToolInvocation>, KernelError> {
574 let choices = value.as_array().ok_or_else(|| {
575 KernelError::NormalizerFailed("OpenAI Chat Completions choices must be an array".into())
576 })?;
577
578 let mut invocations = Vec::new();
579 for choice in choices {
580 let Some(message) = choice.get("message") else {
581 continue;
582 };
583 invocations
584 .extend(StructuredToolCallNormalizer::normalize_openai_chat_completions(message)?);
585 }
586
587 Ok(invocations)
588}
589
590fn normalize_chat_tool_calls(value: &Value) -> Result<Vec<ToolInvocation>, KernelError> {
591 match value {
592 Value::Array(items) => normalize_chat_tool_calls_array(items),
593 Value::Object(object) if is_chat_tool_call(object) => {
594 parse_chat_tool_call(object).map(|call| vec![call])
595 }
596 _ => Ok(Vec::new()),
597 }
598}
599
600fn normalize_chat_tool_calls_array(items: &[Value]) -> Result<Vec<ToolInvocation>, KernelError> {
601 items
602 .iter()
603 .map(|item| match item {
604 Value::Object(object) if is_chat_tool_call(object) => parse_chat_tool_call(object),
605 Value::Object(_) => Err(KernelError::NormalizerFailed(
606 "OpenAI Chat Completions tool call missing function payload".into(),
607 )),
608 _ => Err(KernelError::NormalizerFailed(
609 "OpenAI Chat Completions tool call must be an object".into(),
610 )),
611 })
612 .collect()
613}
614
615fn is_chat_tool_call(object: &Map<String, Value>) -> bool {
616 object.get("function").is_some()
617}
618
619fn parse_chat_tool_call(object: &Map<String, Value>) -> Result<ToolInvocation, KernelError> {
620 let function = object
621 .get("function")
622 .and_then(Value::as_object)
623 .ok_or_else(|| {
624 KernelError::NormalizerFailed(
625 "OpenAI Chat Completions tool call missing function object".into(),
626 )
627 })?;
628 let name = required_string_field(function, "name", "OpenAI Chat Completions function")?;
629 let args = function
630 .get("arguments")
631 .map(parse_standard_arguments)
632 .transpose()?
633 .unwrap_or_else(|| Value::Object(Map::new()));
634
635 ToolInvocation::new(name, args)
636}
637
638fn parse_standard_arguments(value: &Value) -> Result<Value, KernelError> {
639 match value {
640 Value::String(raw) => {
641 let trimmed = raw.trim();
642 if trimmed.is_empty() {
643 return Ok(Value::Object(Map::new()));
644 }
645 serde_json::from_str(trimmed).map_err(|err| {
646 KernelError::NormalizerFailed(format!(
647 "failed to parse standard tool-call arguments JSON: {err}"
648 ))
649 })
650 }
651 Value::Null => Ok(Value::Object(Map::new())),
652 other => Ok(other.clone()),
653 }
654}
655
656fn required_string_field(
657 object: &Map<String, Value>,
658 field: &str,
659 context: &str,
660) -> Result<String, KernelError> {
661 object
662 .get(field)
663 .and_then(Value::as_str)
664 .map(ToOwned::to_owned)
665 .ok_or_else(|| KernelError::NormalizerFailed(format!("{context} missing `{field}` string")))
666}
667
668fn flatten_invocations(nested: Vec<Vec<ToolInvocation>>) -> Vec<ToolInvocation> {
669 nested.into_iter().flatten().collect()
670}
671
672fn parse_lfm_block(block: &str) -> Result<Vec<ToolInvocation>, KernelError> {
676 let block = block.trim();
677 let inner = block
679 .strip_prefix('[')
680 .and_then(|s| s.strip_suffix(']'))
681 .unwrap_or(block);
682
683 split_top_level(inner, ',')
684 .into_iter()
685 .filter(|s| !s.trim().is_empty())
686 .map(|s| parse_lfm_call(s.trim()))
687 .collect()
688}
689
690fn parse_lfm_call(expr: &str) -> Result<ToolInvocation, KernelError> {
692 let (name_raw, rest) = expr.split_once('(').ok_or_else(|| {
693 KernelError::NormalizerFailed(format!("LFM: expected '(' in call: {expr:?}"))
694 })?;
695
696 let name = name_raw.trim().to_string();
697 if name.is_empty() {
698 return Err(KernelError::NormalizerFailed(
699 "LFM: empty tool name in call expression".into(),
700 ));
701 }
702 validate_identifier("tool name", &name)?;
703
704 let (kwargs_str, trailing) = rest.rsplit_once(')').ok_or_else(|| {
706 KernelError::NormalizerFailed(format!("LFM: missing closing ')' in: {expr:?}"))
707 })?;
708 if !trailing.trim().is_empty() {
709 return Err(KernelError::NormalizerFailed(format!(
710 "LFM: trailing content after call expression: {trailing:?}"
711 )));
712 }
713
714 let args = parse_kwargs(kwargs_str)?;
715 Ok(ToolInvocation { name, args })
716}
717
718fn parse_kwargs(s: &str) -> Result<Value, KernelError> {
720 let s = s.trim();
721 if s.is_empty() {
722 return Ok(Value::Object(Map::new()));
723 }
724
725 let mut map = Map::new();
726 for pair in split_top_level(s, ',') {
727 let pair = pair.trim();
728 if pair.is_empty() {
729 continue;
730 }
731 let (key_raw, val_raw) = pair.split_once('=').ok_or_else(|| {
732 KernelError::NormalizerFailed(format!("LFM: kwarg without '=': {pair:?}"))
733 })?;
734 let key = key_raw.trim().to_string();
735 if key.is_empty() {
736 return Err(KernelError::NormalizerFailed(
737 "LFM: empty kwarg name".into(),
738 ));
739 }
740 validate_identifier("kwarg name", &key)?;
741 if map.contains_key(&key) {
742 return Err(KernelError::NormalizerFailed(format!(
743 "LFM: duplicate kwarg: {key}"
744 )));
745 }
746 let val = parse_value(val_raw.trim())?;
747 map.insert(key, val);
748 }
749
750 Ok(Value::Object(map))
751}
752
753fn parse_value(s: &str) -> Result<Value, KernelError> {
759 let s = s.trim();
760
761 if s.is_empty() {
762 return Ok(Value::String(String::new()));
763 }
764
765 if let Some(inner) = s.strip_prefix('\'').and_then(|t| t.strip_suffix('\'')) {
767 return Ok(Value::String(
768 inner.replace("\\'", "'").replace("\\\"", "\""),
769 ));
770 }
771 if s.starts_with('\'') {
772 return Err(KernelError::NormalizerFailed(
773 "LFM: unterminated single-quoted string".into(),
774 ));
775 }
776 if let Some(inner) = s.strip_prefix('"').and_then(|t| t.strip_suffix('"')) {
778 return Ok(Value::String(
779 inner.replace("\\'", "'").replace("\\\"", "\""),
780 ));
781 }
782 if s.starts_with('"') {
783 return Err(KernelError::NormalizerFailed(
784 "LFM: unterminated double-quoted string".into(),
785 ));
786 }
787 if s == "True" {
789 return Ok(Value::Bool(true));
790 }
791 if s == "False" {
792 return Ok(Value::Bool(false));
793 }
794 if s == "None" || s == "null" {
796 return Ok(Value::Null);
797 }
798 if let Some(inner) = s.strip_prefix('[').and_then(|t| t.strip_suffix(']')) {
800 return parse_array(inner);
801 }
802 if s.starts_with('[') {
803 return Err(KernelError::NormalizerFailed(
804 "LFM: unterminated list literal".into(),
805 ));
806 }
807 if let Some(inner) = s.strip_prefix('{').and_then(|t| t.strip_suffix('}')) {
809 return parse_object(inner);
810 }
811 if s.starts_with('{') {
812 return Err(KernelError::NormalizerFailed(
813 "LFM: unterminated object literal".into(),
814 ));
815 }
816 if let Ok(n) = s.parse::<i64>() {
818 return Ok(Value::Number(n.into()));
819 }
820 if let Ok(f) = s.parse::<f64>() {
822 let num = serde_json::Number::from_f64(f).ok_or_else(|| {
823 KernelError::NormalizerFailed(format!("LFM: non-finite float in argument: {s:?}"))
824 })?;
825 return Ok(Value::Number(num));
826 }
827 Ok(Value::String(s.to_string()))
829}
830
831fn parse_array(inner: &str) -> Result<Value, KernelError> {
832 let inner = inner.trim();
833 if inner.is_empty() {
834 return Ok(Value::Array(Vec::new()));
835 }
836
837 let values = split_top_level(inner, ',')
838 .into_iter()
839 .filter(|part| !part.trim().is_empty())
840 .map(|part| parse_value(part.trim()))
841 .collect::<Result<Vec<_>, _>>()?;
842
843 Ok(Value::Array(values))
844}
845
846fn parse_object(inner: &str) -> Result<Value, KernelError> {
847 let inner = inner.trim();
848 if inner.is_empty() {
849 return Ok(Value::Object(Map::new()));
850 }
851
852 let mut map = Map::new();
853 for entry in split_top_level(inner, ',') {
854 let entry = entry.trim();
855 if entry.is_empty() {
856 continue;
857 }
858
859 let (key_raw, value_raw) = split_once_top_level(entry, ':').ok_or_else(|| {
860 KernelError::NormalizerFailed(format!("LFM: object entry without ':': {entry:?}"))
861 })?;
862 let key = parse_object_key(key_raw.trim())?;
863 if map.contains_key(&key) {
864 return Err(KernelError::NormalizerFailed(format!(
865 "LFM: duplicate object key: {key}"
866 )));
867 }
868
869 map.insert(key, parse_value(value_raw.trim())?);
870 }
871
872 Ok(Value::Object(map))
873}
874
875fn parse_object_key(raw: &str) -> Result<String, KernelError> {
876 match parse_value(raw)? {
877 Value::String(key) => Ok(key),
878 _ => Err(KernelError::NormalizerFailed(format!(
879 "LFM: object key must be a string: {raw:?}"
880 ))),
881 }
882}
883
884fn validate_identifier(kind: &str, value: &str) -> Result<(), KernelError> {
888 let valid = value
889 .chars()
890 .all(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-' | '.'));
891
892 if valid {
893 return Ok(());
894 }
895
896 Err(KernelError::NormalizerFailed(format!(
897 "invalid {kind}: {value:?}"
898 )))
899}
900
/// Splits `s` at every `delim` that sits outside quotes and outside any
/// bracket pair (`()`, `[]`, `{}`).
///
/// Inside a quoted string a backslash protects the following character.
/// Segments are returned untrimmed; the result always has at least one
/// element. Unbalanced closing brackets never push the depth below zero.
fn split_top_level(s: &str, delim: char) -> Vec<&str> {
    let mut pieces: Vec<&str> = Vec::new();
    let mut nesting: usize = 0;
    // The quote character of the string currently being scanned, if any.
    let mut open_quote: Option<char> = None;
    let mut skip_next = false;
    let mut piece_start = 0usize;

    for (offset, ch) in s.char_indices() {
        if skip_next {
            skip_next = false;
            continue;
        }

        if let Some(quote) = open_quote {
            if ch == '\\' {
                skip_next = true;
            } else if ch == quote {
                open_quote = None;
            }
            continue;
        }

        if ch == '\'' || ch == '"' {
            open_quote = Some(ch);
        } else if matches!(ch, '(' | '[' | '{') {
            nesting = nesting.saturating_add(1);
        } else if matches!(ch, ')' | ']' | '}') {
            nesting = nesting.saturating_sub(1);
        } else if ch == delim && nesting == 0 {
            pieces.push(s.get(piece_start..offset).unwrap_or(""));
            piece_start = offset + ch.len_utf8();
        }
    }

    pieces.push(s.get(piece_start..).unwrap_or(""));
    pieces
}
949
950fn split_once_top_level(s: &str, delim: char) -> Option<(&str, &str)> {
951 split_index_top_level(s, delim).map(|idx| {
952 let left = s.get(..idx).unwrap_or("");
953 let right = s.get(idx + delim.len_utf8()..).unwrap_or("");
954 (left, right)
955 })
956}
957
/// Returns the byte offset of the first `delim` that sits outside quotes and
/// outside any bracket pair (`()`, `[]`, `{}`), or `None` if there is none.
///
/// Inside a quoted string a backslash protects the following character.
fn split_index_top_level(s: &str, delim: char) -> Option<usize> {
    let mut nesting: usize = 0;
    // The quote character of the string currently being scanned, if any.
    let mut open_quote: Option<char> = None;
    let mut skip_next = false;

    for (offset, ch) in s.char_indices() {
        if skip_next {
            skip_next = false;
            continue;
        }

        if let Some(quote) = open_quote {
            if ch == '\\' {
                skip_next = true;
            } else if ch == quote {
                open_quote = None;
            }
            continue;
        }

        if ch == '\'' || ch == '"' {
            open_quote = Some(ch);
        } else if matches!(ch, '(' | '[' | '{') {
            nesting = nesting.saturating_add(1);
        } else if matches!(ch, ')' | ']' | '}') {
            nesting = nesting.saturating_sub(1);
        } else if ch == delim && nesting == 0 {
            return Some(offset);
        }
    }

    None
}
997
// Unit tests for the LFM and structured normalizers, the dispatcher, and the
// top-level splitter.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LocalTool, ToolRegistry, ToolSchema};
    use serde_json::json;
    use std::sync::Arc;

    // ---- LfmNormalizer::is_applicable ----

    #[test]
    fn not_applicable_for_plain_text() {
        assert!(!LfmNormalizer.is_applicable("hello world"));
    }

    #[test]
    fn applicable_when_start_marker_present() {
        assert!(
            LfmNormalizer
                .is_applicable("<|tool_call_start|>[get_weather(city='Berlin')]<|tool_call_end|>")
        );
    }

    // ---- LfmNormalizer::normalize: well-formed input ----

    #[test]
    fn plain_text_returns_empty() {
        let calls = LfmNormalizer
            .normalize("The weather in Berlin is sunny.")
            .unwrap();
        assert!(calls.is_empty());
    }

    #[test]
    fn single_call_string_arg() {
        let raw = "<|tool_call_start|>[get_weather(city='Berlin')]<|tool_call_end|>";
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
        assert_eq!(calls[0].args, json!({"city": "Berlin"}));
    }

    #[test]
    fn single_call_multiple_args() {
        let raw = "<|tool_call_start|>[search(query='rust async', limit=10)]<|tool_call_end|>";
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search");
        assert_eq!(calls[0].args, json!({"query": "rust async", "limit": 10}));
    }

    #[test]
    fn single_call_no_args() {
        let raw = "<|tool_call_start|>[list_tools()]<|tool_call_end|>";
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "list_tools");
        assert_eq!(calls[0].args, json!({}));
    }

    #[test]
    fn multiple_calls_in_one_block() {
        let raw = "<|tool_call_start|>[get_weather(city='Berlin'), get_time(zone='UTC')]<|tool_call_end|>";
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "get_weather");
        assert_eq!(calls[0].args, json!({"city": "Berlin"}));
        assert_eq!(calls[1].name, "get_time");
        assert_eq!(calls[1].args, json!({"zone": "UTC"}));
    }

    #[test]
    fn multiple_blocks_in_one_message() {
        let raw = concat!(
            "<|tool_call_start|>[step_one(x=1)]<|tool_call_end|>",
            " some text ",
            "<|tool_call_start|>[step_two(y=2)]<|tool_call_end|>",
        );
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "step_one");
        assert_eq!(calls[1].name, "step_two");
    }

    #[test]
    fn block_without_brackets_is_parsed() {
        // The surrounding square brackets are optional.
        let raw = "<|tool_call_start|>ping(target='8.8.8.8')<|tool_call_end|>";
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "ping");
        assert_eq!(calls[0].args, json!({"target": "8.8.8.8"}));
    }

    // ---- Argument value types ----

    #[test]
    fn integer_arg() {
        let raw = "<|tool_call_start|>[set_limit(n=42)]<|tool_call_end|>";
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls[0].args, json!({"n": 42}));
    }

    #[test]
    fn float_arg() {
        let raw = "<|tool_call_start|>[set_temp(t=0.7)]<|tool_call_end|>";
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls[0].args["t"].as_f64().unwrap(), 0.7);
    }

    #[test]
    fn boolean_args() {
        let raw = "<|tool_call_start|>[configure(verbose=True, strict=False)]<|tool_call_end|>";
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls[0].args, json!({"verbose": true, "strict": false}));
    }

    #[test]
    fn null_args() {
        let raw = "<|tool_call_start|>[reset(ctx=None)]<|tool_call_end|>";
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls[0].args, json!({"ctx": null}));
    }

    #[test]
    fn double_quoted_string_arg() {
        let raw = r#"<|tool_call_start|>[greet(name="world")]<|tool_call_end|>"#;
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls[0].args, json!({"name": "world"}));
    }

    #[test]
    fn nested_list_and_object_args() {
        // Commas inside quotes, lists and objects must not split arguments.
        let raw = "<|tool_call_start|>[plan(items=['a,b', 'c'], meta={'city': 'Berlin', 'coords': [52.52, 13.405], 'active': True})]<|tool_call_end|>";
        let calls = LfmNormalizer.normalize(raw).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(
            calls[0].args,
            json!({
                "items": ["a,b", "c"],
                "meta": {
                    "city": "Berlin",
                    "coords": [52.52, 13.405],
                    "active": true
                }
            })
        );
    }

    // ---- StructuredToolCallNormalizer ----

    #[test]
    fn openai_responses_function_call_item() {
        // A single function_call item with JSON-encoded string arguments.
        let value = json!({
            "type": "function_call",
            "id": "fc_123",
            "call_id": "call_123",
            "name": "get_weather",
            "arguments": "{\"city\":\"Berlin\"}",
            "status": "completed"
        });

        let calls = StructuredToolCallNormalizer::normalize_openai_responses(&value).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
        assert_eq!(calls[0].args, json!({"city": "Berlin"}));
    }

    #[test]
    fn openai_responses_full_response() {
        // Non-function_call output items are ignored; object arguments pass through.
        let value = json!({
            "id": "resp_123",
            "output": [
                { "type": "message", "content": [] },
                {
                    "type": "function_call",
                    "id": "fc_123",
                    "call_id": "call_123",
                    "name": "search.docs",
                    "arguments": {"query": "tool calls"},
                    "status": "completed"
                }
            ]
        });

        let calls = StructuredToolCallNormalizer::normalize_openai_responses(&value).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search.docs");
        assert_eq!(calls[0].args, json!({"query": "tool calls"}));
    }

    #[test]
    fn openai_chat_completions_tool_calls() {
        let value = json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"city\":\"Berlin\"}"
                        }
                    }]
                }
            }]
        });

        let calls =
            StructuredToolCallNormalizer::normalize_openai_chat_completions(&value).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
        assert_eq!(calls[0].args, json!({"city": "Berlin"}));
    }

    #[test]
    fn structured_normalizer_aggregates_supported_shapes() {
        // `normalize` must handle both the Responses and the Chat shape.
        let responses_value = json!({
            "output": [{
                "type": "function_call",
                "name": "first",
                "arguments": "{}"
            }]
        });
        let chat_value = json!({
            "tool_calls": [{
                "function": {
                    "name": "second",
                    "arguments": {"ok": true}
                }
            }]
        });

        let responses_calls = StructuredToolCallNormalizer::normalize(&responses_value).unwrap();
        let chat_calls = StructuredToolCallNormalizer::normalize(&chat_value).unwrap();

        assert_eq!(responses_calls[0].name, "first");
        assert_eq!(chat_calls[0].name, "second");
        assert_eq!(chat_calls[0].args, json!({"ok": true}));
    }

    // ---- Error cases ----

    #[test]
    fn unclosed_marker_returns_error() {
        let raw = "<|tool_call_start|>[get_weather(city='Berlin')]";
        let err = LfmNormalizer.normalize(raw).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("unclosed"), "expected 'unclosed' in: {msg}");
    }

    #[test]
    fn missing_paren_returns_error() {
        // A bare identifier is not a call expression.
        let raw = "<|tool_call_start|>[not_a_call]<|tool_call_end|>";
        let err = LfmNormalizer.normalize(raw).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("expected '('"), "got: {msg}");
    }

    #[test]
    fn kwarg_without_equals_returns_error() {
        let raw = "<|tool_call_start|>[fn(badarg)]<|tool_call_end|>";
        let err = LfmNormalizer.normalize(raw).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("kwarg without '='"), "got: {msg}");
    }

    #[test]
    fn invalid_tool_name_returns_error() {
        let raw = "<|tool_call_start|>[bad/name(arg=1)]<|tool_call_end|>";
        let err = LfmNormalizer.normalize(raw).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("invalid tool name"), "got: {msg}");
    }

    #[test]
    fn empty_kwarg_name_returns_error() {
        let raw = "<|tool_call_start|>[fn(=1)]<|tool_call_end|>";
        let err = LfmNormalizer.normalize(raw).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("empty kwarg name"), "got: {msg}");
    }

    #[test]
    fn duplicate_kwarg_returns_error() {
        let raw = "<|tool_call_start|>[fn(city='Berlin', city='Paris')]<|tool_call_end|>";
        let err = LfmNormalizer.normalize(raw).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("duplicate kwarg"), "got: {msg}");
    }

    #[test]
    fn malformed_standard_arguments_return_error() {
        let value = json!({
            "type": "function_call",
            "name": "bad_args",
            "arguments": "{not json}"
        });

        let err = StructuredToolCallNormalizer::normalize_openai_responses(&value).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("arguments JSON"), "got: {msg}");
    }

    #[test]
    fn trailing_call_content_returns_error() {
        let raw = "<|tool_call_start|>[fn(arg=1) extra]<|tool_call_end|>";
        let err = LfmNormalizer.normalize(raw).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("trailing content"), "got: {msg}");
    }

    #[test]
    fn unterminated_nested_literal_returns_error() {
        let raw = "<|tool_call_start|>[fn(items=['a', 'b')]<|tool_call_end|>";
        let err = LfmNormalizer.normalize(raw).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("unterminated list"), "got: {msg}");
    }

    // ---- Dispatch ----

    #[tokio::test]
    async fn dispatch_invocations_runs_tools_in_order() {
        // End to end: normalize an LFM call, dispatch it to a registered echo tool.
        let tools = ToolRegistry::new();
        tools.register(Arc::new(LocalTool::new(
            ToolSchema {
                name: "echo".into(),
                description: "echoes args".into(),
                args_schema: json!({"type": "object"}),
                result_schema: json!({"type": "object"}),
            },
            |args| async move { Ok(json!({"seen": args})) },
        )));

        let invocations = LfmNormalizer
            .normalize("<|tool_call_start|>[echo(value={'nested': [1, 2]})]<|tool_call_end|>")
            .unwrap();
        let results = dispatch_tool_invocations(&tools, &invocations)
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].invocation.name, "echo");
        assert_eq!(
            results[0].output,
            json!({"seen": {"value": {"nested": [1, 2]}}})
        );
    }

    // ---- split_top_level ----

    #[test]
    fn split_respects_parens() {
        // Segments are returned untrimmed, hence the leading space.
        let parts = split_top_level("fn(a, b), fn2(c)", ',');
        assert_eq!(parts, vec!["fn(a, b)", " fn2(c)"]);
    }

    #[test]
    fn split_respects_single_quotes() {
        let parts = split_top_level("a='x,y', b=2", ',');
        assert_eq!(parts, vec!["a='x,y'", " b=2"]);
    }

    #[test]
    fn split_respects_nested_arrays_and_objects() {
        let parts = split_top_level("a=[1, 2], b={'x': 'y,z'}, c=3", ',');
        assert_eq!(parts, vec!["a=[1, 2]", " b={'x': 'y,z'}", " c=3"]);
    }
}