1use crate::providers::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage};
2use crate::tools::{Tool, ToolSpec};
3use serde_json::Value;
4use std::fmt::Write;
5
6#[derive(Debug, Clone)]
7pub struct ParsedToolCall {
8 pub name: String,
9 pub arguments: Value,
10 pub tool_call_id: Option<String>,
11}
12
/// Outcome of running one requested tool call, ready to be reported back.
#[derive(Clone, Debug)]
pub struct ToolExecutionResult {
    /// Name of the executed tool.
    pub name: String,
    /// Output text to show the model.
    pub output: String,
    /// `false` when execution failed (the XML protocol renders this as
    /// `status="error"`).
    pub success: bool,
    /// Id of the call this result answers, when the call carried one.
    pub tool_call_id: Option<String>,
}
20
21pub trait ToolDispatcher: Send + Sync {
22 fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>);
23 fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage;
24 fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String;
25 fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage>;
26 fn should_send_tool_specs(&self) -> bool;
27}
28
/// Tool dispatcher that reads tool calls from plain response text, where each
/// call is a JSON object wrapped in `<tool_call>...</tool_call>` tags. Results
/// are returned as a user-role chat message, and native tool specs are not sent.
#[derive(Debug, Default, Clone, Copy)]
pub struct XmlToolDispatcher;
31
32impl XmlToolDispatcher {
33 fn parse_xml_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
34 let cleaned = Self::strip_think_tags(response);
37 let mut text_parts = Vec::new();
38 let mut calls = Vec::new();
39 let mut remaining = cleaned.as_str();
40
41 while let Some(start) = remaining.find("<tool_call>") {
42 let before = &remaining[..start];
43 if !before.trim().is_empty() {
44 text_parts.push(before.trim().to_string());
45 }
46
47 if let Some(end) = remaining[start..].find("</tool_call>") {
48 let inner = &remaining[start + 11..start + end];
49 match serde_json::from_str::<Value>(inner.trim()) {
50 Ok(parsed) => {
51 let name = parsed
52 .get("name")
53 .and_then(Value::as_str)
54 .unwrap_or("")
55 .to_string();
56 if name.is_empty() {
57 remaining = &remaining[start + end + 12..];
58 continue;
59 }
60 let arguments = parsed
61 .get("arguments")
62 .cloned()
63 .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
64 calls.push(ParsedToolCall {
65 name,
66 arguments,
67 tool_call_id: None,
68 });
69 }
70 Err(e) => {
71 tracing::warn!("Malformed <tool_call> JSON: {e}");
72 }
73 }
74 remaining = &remaining[start + end + 12..];
75 } else {
76 break;
77 }
78 }
79
80 if !remaining.trim().is_empty() {
81 text_parts.push(remaining.trim().to_string());
82 }
83
84 (text_parts.join("\n"), calls)
85 }
86
87 fn strip_think_tags(s: &str) -> String {
89 let mut result = String::with_capacity(s.len());
90 let mut rest = s;
91 loop {
92 if let Some(start) = rest.find("<think>") {
93 result.push_str(&rest[..start]);
94 if let Some(end) = rest[start..].find("</think>") {
95 rest = &rest[start + end + "</think>".len()..];
96 } else {
97 break;
98 }
99 } else {
100 result.push_str(rest);
101 break;
102 }
103 }
104 result
105 }
106
107 pub fn tool_specs(tools: &[Box<dyn Tool>]) -> Vec<ToolSpec> {
108 tools.iter().map(|tool| tool.spec()).collect()
109 }
110}
111
112impl ToolDispatcher for XmlToolDispatcher {
113 fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
114 let text = response.text_or_empty();
115 Self::parse_xml_tool_calls(text)
116 }
117
118 fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
119 let mut content = String::new();
120 for result in results {
121 let status = if result.success { "ok" } else { "error" };
122 let _ = writeln!(
123 content,
124 "<tool_result name=\"{}\" status=\"{}\">\n{}\n</tool_result>",
125 result.name, status, result.output
126 );
127 }
128 ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}")))
129 }
130
131 fn prompt_instructions(&self, _tools: &[Box<dyn Tool>]) -> String {
132 let mut instructions = String::new();
133 instructions.push_str("## Tool Use Protocol\n\n");
134 instructions
135 .push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
136 instructions.push_str(
137 "```\n<tool_call>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n",
138 );
139
140 instructions
141 }
142
143 fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
144 history
145 .iter()
146 .flat_map(|msg| match msg {
147 ConversationMessage::Chat(chat) => vec![chat.clone()],
148 ConversationMessage::AssistantToolCalls { text, .. } => {
149 vec![ChatMessage::assistant(text.clone().unwrap_or_default())]
150 }
151 ConversationMessage::ToolResults(results) => {
152 let mut content = String::new();
153 for result in results {
154 let _ = writeln!(
155 content,
156 "<tool_result id=\"{}\">\n{}\n</tool_result>",
157 result.tool_call_id, result.content
158 );
159 }
160 vec![ChatMessage::user(format!("[Tool results]\n{content}"))]
161 }
162 })
163 .collect()
164 }
165
166 fn should_send_tool_specs(&self) -> bool {
167 false
168 }
169}
170
/// Tool dispatcher that uses the provider's structured tool-calling support.
/// Calls are read from `ChatResponse::tool_calls`, results are returned as
/// `ConversationMessage::ToolResults`, and `should_send_tool_specs` is `true`.
#[derive(Debug, Default, Clone, Copy)]
pub struct NativeToolDispatcher;
172
173impl ToolDispatcher for NativeToolDispatcher {
174 fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
175 let text = response.text.clone().unwrap_or_default();
176 let calls = response
177 .tool_calls
178 .iter()
179 .map(|tc| ParsedToolCall {
180 name: tc.name.clone(),
181 arguments: {
182 let raw = tc.arguments.trim();
183 if raw.is_empty() {
184 Value::Object(serde_json::Map::new())
185 } else {
186 serde_json::from_str(raw).unwrap_or_else(|e| {
187 tracing::warn!(
188 tool = %tc.name,
189 error = %e,
190 "Failed to parse native tool call arguments as JSON; defaulting to empty object"
191 );
192 Value::Object(serde_json::Map::new())
193 })
194 }
195 },
196 tool_call_id: Some(tc.id.clone()),
197 })
198 .collect();
199 (text, calls)
200 }
201
202 fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
203 let messages = results
204 .iter()
205 .map(|result| ToolResultMessage {
206 tool_call_id: result
207 .tool_call_id
208 .clone()
209 .unwrap_or_else(|| "unknown".to_string()),
210 content: result.output.clone(),
211 })
212 .collect();
213 ConversationMessage::ToolResults(messages)
214 }
215
216 fn prompt_instructions(&self, _tools: &[Box<dyn Tool>]) -> String {
217 String::new()
218 }
219
220 fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
221 history
222 .iter()
223 .flat_map(|msg| match msg {
224 ConversationMessage::Chat(chat) => vec![chat.clone()],
225 ConversationMessage::AssistantToolCalls {
226 text,
227 tool_calls,
228 reasoning_content,
229 } => {
230 let mut payload = serde_json::json!({
231 "content": text,
232 "tool_calls": tool_calls,
233 });
234 if let Some(rc) = reasoning_content {
235 payload["reasoning_content"] = serde_json::json!(rc);
236 }
237 vec![ChatMessage::assistant(payload.to_string())]
238 }
239 ConversationMessage::ToolResults(results) => results
240 .iter()
241 .map(|result| {
242 ChatMessage::tool(
243 serde_json::json!({
244 "tool_call_id": result.tool_call_id,
245 "content": result.content,
246 })
247 .to_string(),
248 )
249 })
250 .collect(),
251 })
252 .collect()
253 }
254
255 fn should_send_tool_specs(&self) -> bool {
256 true
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text-only `ChatResponse`, the shape the XML protocol parses.
    fn text_response(text: &str) -> ChatResponse {
        ChatResponse {
            text: Some(text.into()),
            tool_calls: vec![],
            usage: None,
            reasoning_content: None,
        }
    }

    #[test]
    fn xml_dispatcher_parses_tool_calls() {
        let response = text_response(
            "Checking\n<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call>",
        );
        let dispatcher = XmlToolDispatcher;
        let (text, calls) = dispatcher.parse_response(&response);
        assert_eq!(text, "Checking");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "shell");
        assert_eq!(calls[0].arguments["command"].as_str(), Some("ls"));
        assert!(calls[0].tool_call_id.is_none());
    }

    #[test]
    fn xml_dispatcher_parses_multiple_calls_and_interleaved_text() {
        let response = text_response(
            "First\n<tool_call>{\"name\":\"a\",\"arguments\":{}}</tool_call>\nMiddle\n<tool_call>{\"name\":\"b\"}</tool_call>\nLast",
        );
        let (text, calls) = XmlToolDispatcher.parse_response(&response);
        assert_eq!(text, "First\nMiddle\nLast");
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "a");
        assert_eq!(calls[1].name, "b");
        // A missing `arguments` field defaults to an empty object.
        assert_eq!(calls[1].arguments, serde_json::json!({}));
    }

    #[test]
    fn xml_dispatcher_skips_malformed_json_but_keeps_text() {
        let response = text_response("Before<tool_call>not json</tool_call>After");
        let (text, calls) = XmlToolDispatcher.parse_response(&response);
        assert!(calls.is_empty());
        assert_eq!(text, "Before\nAfter");
    }

    #[test]
    fn xml_dispatcher_skips_call_without_name() {
        let response =
            text_response("<tool_call>{\"arguments\":{\"command\":\"ls\"}}</tool_call>");
        let (text, calls) = XmlToolDispatcher.parse_response(&response);
        assert!(calls.is_empty());
        assert!(text.is_empty());
    }

    #[test]
    fn xml_dispatcher_strips_think_before_tool_call() {
        let response = text_response(
            "<think>I should list files</think>\n<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call>",
        );
        let dispatcher = XmlToolDispatcher;
        let (text, calls) = dispatcher.parse_response(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "shell");
        assert!(
            !text.contains("<think>"),
            "think tags should be stripped from text"
        );
    }

    #[test]
    fn xml_dispatcher_think_only_returns_no_calls() {
        let response = text_response("<think>Just thinking</think>");
        let dispatcher = XmlToolDispatcher;
        let (_, calls) = dispatcher.parse_response(&response);
        assert!(calls.is_empty());
    }

    #[test]
    fn native_dispatcher_roundtrip() {
        let response = ChatResponse {
            text: Some("ok".into()),
            tool_calls: vec![crate::providers::ToolCall {
                id: "tc1".into(),
                name: "file_read".into(),
                arguments: "{\"path\":\"a.txt\"}".into(),
            }],
            usage: None,
            reasoning_content: None,
        };
        let dispatcher = NativeToolDispatcher;
        let (_, calls) = dispatcher.parse_response(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc1"));

        let msg = dispatcher.format_results(&[ToolExecutionResult {
            name: "file_read".into(),
            output: "hello".into(),
            success: true,
            tool_call_id: Some("tc1".into()),
        }]);
        match msg {
            ConversationMessage::ToolResults(results) => {
                assert_eq!(results.len(), 1);
                assert_eq!(results[0].tool_call_id, "tc1");
            }
            _ => panic!("expected tool results"),
        }
    }

    #[test]
    fn native_dispatcher_defaults_blank_or_invalid_arguments_to_empty_object() {
        let response = ChatResponse {
            text: None,
            tool_calls: vec![
                crate::providers::ToolCall {
                    id: "tc_blank".into(),
                    name: "shell".into(),
                    arguments: "   ".into(),
                },
                crate::providers::ToolCall {
                    id: "tc_bad".into(),
                    name: "shell".into(),
                    arguments: "not json".into(),
                },
            ],
            usage: None,
            reasoning_content: None,
        };
        let (text, calls) = NativeToolDispatcher.parse_response(&response);
        assert!(text.is_empty());
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].arguments, serde_json::json!({}));
        assert_eq!(calls[1].arguments, serde_json::json!({}));
        assert_eq!(calls[1].tool_call_id.as_deref(), Some("tc_bad"));
    }

    #[test]
    fn xml_format_results_contains_tool_result_tags() {
        let dispatcher = XmlToolDispatcher;
        let msg = dispatcher.format_results(&[ToolExecutionResult {
            name: "shell".into(),
            output: "ok".into(),
            success: true,
            tool_call_id: None,
        }]);
        let rendered = match msg {
            ConversationMessage::Chat(chat) => chat.content,
            _ => String::new(),
        };
        assert!(rendered.contains("<tool_result"));
        assert!(rendered.contains("shell"));
    }

    #[test]
    fn xml_format_results_marks_failures_as_error() {
        let msg = XmlToolDispatcher.format_results(&[ToolExecutionResult {
            name: "shell".into(),
            output: "boom".into(),
            success: false,
            tool_call_id: None,
        }]);
        let rendered = match msg {
            ConversationMessage::Chat(chat) => chat.content,
            _ => panic!("expected Chat variant"),
        };
        assert_eq!(
            rendered,
            "[Tool results]\n<tool_result name=\"shell\" status=\"error\">\nboom\n</tool_result>\n"
        );
    }

    #[test]
    fn native_format_results_keeps_tool_call_id() {
        let dispatcher = NativeToolDispatcher;
        let msg = dispatcher.format_results(&[ToolExecutionResult {
            name: "shell".into(),
            output: "ok".into(),
            success: true,
            tool_call_id: Some("tc-1".into()),
        }]);

        match msg {
            ConversationMessage::ToolResults(results) => {
                assert_eq!(results.len(), 1);
                assert_eq!(results[0].tool_call_id, "tc-1");
            }
            _ => panic!("expected ToolResults variant"),
        }
    }

    #[test]
    fn native_format_results_falls_back_to_unknown_id() {
        let msg = NativeToolDispatcher.format_results(&[ToolExecutionResult {
            name: "shell".into(),
            output: "ok".into(),
            success: false,
            tool_call_id: None,
        }]);

        match msg {
            ConversationMessage::ToolResults(results) => {
                assert_eq!(results.len(), 1);
                assert_eq!(results[0].tool_call_id, "unknown");
                assert_eq!(results[0].content, "ok");
            }
            _ => panic!("expected ToolResults variant"),
        }
    }

    #[test]
    fn native_to_provider_messages_includes_reasoning_content() {
        let dispatcher = NativeToolDispatcher;
        let history = vec![ConversationMessage::AssistantToolCalls {
            text: Some("answer".into()),
            tool_calls: vec![crate::providers::ToolCall {
                id: "tc_1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
            reasoning_content: Some("thinking step".into()),
        }];

        let messages = dispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "assistant");

        let payload: serde_json::Value = serde_json::from_str(&messages[0].content).unwrap();
        assert_eq!(payload["reasoning_content"].as_str(), Some("thinking step"));
        assert_eq!(payload["content"].as_str(), Some("answer"));
        assert!(payload["tool_calls"].is_array());
    }

    #[test]
    fn native_to_provider_messages_omits_reasoning_content_when_none() {
        let dispatcher = NativeToolDispatcher;
        let history = vec![ConversationMessage::AssistantToolCalls {
            text: Some("answer".into()),
            tool_calls: vec![crate::providers::ToolCall {
                id: "tc_1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
            reasoning_content: None,
        }];

        let messages = dispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);

        let payload: serde_json::Value = serde_json::from_str(&messages[0].content).unwrap();
        assert!(payload.get("reasoning_content").is_none());
    }

    #[test]
    fn native_to_provider_messages_emits_one_message_per_tool_result() {
        let dispatcher = NativeToolDispatcher;
        let history = vec![ConversationMessage::ToolResults(vec![
            ToolResultMessage {
                tool_call_id: "tc_1".into(),
                content: "one".into(),
            },
            ToolResultMessage {
                tool_call_id: "tc_2".into(),
                content: "two".into(),
            },
        ])];

        let messages = dispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 2);

        let second: serde_json::Value = serde_json::from_str(&messages[1].content).unwrap();
        assert_eq!(second["tool_call_id"].as_str(), Some("tc_2"));
        assert_eq!(second["content"].as_str(), Some("two"));
    }

    #[test]
    fn xml_to_provider_messages_ignores_reasoning_content() {
        let dispatcher = XmlToolDispatcher;
        let history = vec![ConversationMessage::AssistantToolCalls {
            text: Some("answer".into()),
            tool_calls: vec![crate::providers::ToolCall {
                id: "tc_1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
            reasoning_content: Some("should be ignored".into()),
        }];

        let messages = dispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "assistant");
        assert_eq!(messages[0].content, "answer");
        // Reasoning must not leak into the plain-text transcript.
        assert!(!messages[0].content.contains("reasoning_content"));
    }

    #[test]
    fn xml_to_provider_messages_renders_tool_results_as_text() {
        let dispatcher = XmlToolDispatcher;
        let history = vec![ConversationMessage::ToolResults(vec![ToolResultMessage {
            tool_call_id: "tc_1".into(),
            content: "done".into(),
        }])];

        let messages = dispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);
        assert_eq!(
            messages[0].content,
            "[Tool results]\n<tool_result id=\"tc_1\">\ndone\n</tool_result>\n"
        );
    }
}