codetether_agent/session/helper/
markup.rs1use crate::provider::{CompletionResponse, ContentPart, FinishReason, ToolDefinition};
2use regex::Regex;
3use serde_json::Value;
4use std::sync::OnceLock;
5use uuid::Uuid;
6
7pub fn tool_call_markup_re() -> &'static Regex {
8 static TOOL_CALL_RE: OnceLock<Regex> = OnceLock::new();
9 TOOL_CALL_RE.get_or_init(|| {
10 Regex::new(r"(?s)<tool_call>\s*(?:```(?:json)?\s*)?(\{.*?\})(?:\s*```)?\s*</tool_call>")
11 .expect("tool_call regex must compile")
12 })
13}
14
15pub fn extract_markup_tool_calls(text: &str) -> (String, Vec<(String, String)>) {
16 let mut calls = Vec::new();
17 let re = tool_call_markup_re();
18
19 for capture in re.captures_iter(text) {
20 let Some(block) = capture.get(1).map(|m| m.as_str()) else {
21 continue;
22 };
23 let Ok(payload) = serde_json::from_str::<Value>(block) else {
24 continue;
25 };
26 let Some(name) = payload.get("name").and_then(Value::as_str) else {
27 continue;
28 };
29 let arguments = payload
30 .get("arguments")
31 .or_else(|| payload.get("args"))
32 .or_else(|| payload.get("input"))
33 .cloned()
34 .map(|v| serde_json::to_string(&v).unwrap_or_else(|_| "{}".to_string()))
35 .unwrap_or_else(|| {
36 if let Some(obj) = payload.as_object() {
39 let params: serde_json::Map<String, Value> = obj
40 .iter()
41 .filter(|(k, _)| *k != "name")
42 .map(|(k, v)| (k.clone(), v.clone()))
43 .collect();
44 serde_json::to_string(&Value::Object(params))
45 .unwrap_or_else(|_| "{}".to_string())
46 } else {
47 "{}".to_string()
48 }
49 });
50 calls.push((name.to_string(), arguments));
51 }
52
53 let cleaned = re.replace_all(text, "").into_owned();
54 (cleaned, calls)
55}
56
57pub fn normalize_textual_tool_calls(
58 mut response: CompletionResponse,
59 tools: &[ToolDefinition],
60) -> CompletionResponse {
61 if response
62 .message
63 .content
64 .iter()
65 .any(|p| matches!(p, ContentPart::ToolCall { .. }))
66 {
67 return response;
68 }
69
70 if tools.is_empty() {
71 return response;
72 }
73
74 let mut rewritten = Vec::with_capacity(response.message.content.len());
75 let mut parsed_calls: Vec<(String, String)> = Vec::new();
76 let allowed_tools: std::collections::HashSet<&str> =
77 tools.iter().map(|t| t.name.as_str()).collect();
78
79 for part in response.message.content {
80 match part {
81 ContentPart::Text { text } => {
82 let (cleaned, calls) = extract_markup_tool_calls(&text);
83 for (name, arguments) in calls {
84 if allowed_tools.contains(name.as_str()) {
85 parsed_calls.push((name, arguments));
86 } else {
87 tracing::warn!(tool = %name, "Ignoring unknown <tool_call> tool name");
88 }
89 }
90
91 if !cleaned.trim().is_empty() {
92 rewritten.push(ContentPart::Text {
93 text: cleaned.trim().to_string(),
94 });
95 }
96 }
97 other => rewritten.push(other),
98 }
99 }
100
101 if parsed_calls.is_empty() {
102 response.message.content = rewritten;
103 return response;
104 }
105
106 for (name, arguments) in parsed_calls {
107 rewritten.push(ContentPart::ToolCall {
108 id: Uuid::new_v4().to_string(),
109 name,
110 arguments,
111 thought_signature: None,
112 });
113 }
114
115 response.message.content = rewritten;
116 response.finish_reason = FinishReason::ToolCalls;
117 response
118}