// construct/agent/dispatcher.rs

1use crate::providers::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage};
2use crate::tools::{Tool, ToolSpec};
3use serde_json::Value;
4use std::fmt::Write;
5
6#[derive(Debug, Clone)]
7pub struct ParsedToolCall {
8    pub name: String,
9    pub arguments: Value,
10    pub tool_call_id: Option<String>,
11}
12
/// Outcome of running one tool call, in the form the dispatchers format back
/// into the conversation.
#[derive(Debug, Clone)]
pub struct ToolExecutionResult {
    /// Name of the tool that was run.
    pub name: String,
    /// Text produced by the tool.
    pub output: String,
    /// Whether the run succeeded; the XML dispatcher renders this as
    /// `status="ok"` or `status="error"`.
    pub success: bool,
    /// Id of the originating call, if the provider supplied one. The native
    /// dispatcher substitutes `"unknown"` when this is `None`.
    pub tool_call_id: Option<String>,
}
20
21pub trait ToolDispatcher: Send + Sync {
22    fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>);
23    fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage;
24    fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String;
25    fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage>;
26    fn should_send_tool_specs(&self) -> bool;
27}
28
/// Dispatcher for models that request tools inline, as JSON wrapped in
/// `<tool_call>...</tool_call>` tags within the response text.
#[derive(Default)]
pub struct XmlToolDispatcher;
31
32impl XmlToolDispatcher {
33    fn parse_xml_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
34        // Strip `<think>...</think>` blocks before parsing tool calls.
35        // Qwen and other reasoning models may embed chain-of-thought inline.
36        let cleaned = Self::strip_think_tags(response);
37        let mut text_parts = Vec::new();
38        let mut calls = Vec::new();
39        let mut remaining = cleaned.as_str();
40
41        while let Some(start) = remaining.find("<tool_call>") {
42            let before = &remaining[..start];
43            if !before.trim().is_empty() {
44                text_parts.push(before.trim().to_string());
45            }
46
47            if let Some(end) = remaining[start..].find("</tool_call>") {
48                let inner = &remaining[start + 11..start + end];
49                match serde_json::from_str::<Value>(inner.trim()) {
50                    Ok(parsed) => {
51                        let name = parsed
52                            .get("name")
53                            .and_then(Value::as_str)
54                            .unwrap_or("")
55                            .to_string();
56                        if name.is_empty() {
57                            remaining = &remaining[start + end + 12..];
58                            continue;
59                        }
60                        let arguments = parsed
61                            .get("arguments")
62                            .cloned()
63                            .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
64                        calls.push(ParsedToolCall {
65                            name,
66                            arguments,
67                            tool_call_id: None,
68                        });
69                    }
70                    Err(e) => {
71                        tracing::warn!("Malformed <tool_call> JSON: {e}");
72                    }
73                }
74                remaining = &remaining[start + end + 12..];
75            } else {
76                break;
77            }
78        }
79
80        if !remaining.trim().is_empty() {
81            text_parts.push(remaining.trim().to_string());
82        }
83
84        (text_parts.join("\n"), calls)
85    }
86
87    /// Remove `<think>...</think>` blocks from model output.
88    fn strip_think_tags(s: &str) -> String {
89        let mut result = String::with_capacity(s.len());
90        let mut rest = s;
91        loop {
92            if let Some(start) = rest.find("<think>") {
93                result.push_str(&rest[..start]);
94                if let Some(end) = rest[start..].find("</think>") {
95                    rest = &rest[start + end + "</think>".len()..];
96                } else {
97                    break;
98                }
99            } else {
100                result.push_str(rest);
101                break;
102            }
103        }
104        result
105    }
106
107    pub fn tool_specs(tools: &[Box<dyn Tool>]) -> Vec<ToolSpec> {
108        tools.iter().map(|tool| tool.spec()).collect()
109    }
110}
111
112impl ToolDispatcher for XmlToolDispatcher {
113    fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
114        let text = response.text_or_empty();
115        Self::parse_xml_tool_calls(text)
116    }
117
118    fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
119        let mut content = String::new();
120        for result in results {
121            let status = if result.success { "ok" } else { "error" };
122            let _ = writeln!(
123                content,
124                "<tool_result name=\"{}\" status=\"{}\">\n{}\n</tool_result>",
125                result.name, status, result.output
126            );
127        }
128        ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}")))
129    }
130
131    fn prompt_instructions(&self, _tools: &[Box<dyn Tool>]) -> String {
132        let mut instructions = String::new();
133        instructions.push_str("## Tool Use Protocol\n\n");
134        instructions
135            .push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
136        instructions.push_str(
137            "```\n<tool_call>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n",
138        );
139
140        instructions
141    }
142
143    fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
144        history
145            .iter()
146            .flat_map(|msg| match msg {
147                ConversationMessage::Chat(chat) => vec![chat.clone()],
148                ConversationMessage::AssistantToolCalls { text, .. } => {
149                    vec![ChatMessage::assistant(text.clone().unwrap_or_default())]
150                }
151                ConversationMessage::ToolResults(results) => {
152                    let mut content = String::new();
153                    for result in results {
154                        let _ = writeln!(
155                            content,
156                            "<tool_result id=\"{}\">\n{}\n</tool_result>",
157                            result.tool_call_id, result.content
158                        );
159                    }
160                    vec![ChatMessage::user(format!("[Tool results]\n{content}"))]
161                }
162            })
163            .collect()
164    }
165
166    fn should_send_tool_specs(&self) -> bool {
167        false
168    }
169}
170
/// Dispatcher for providers that return tool calls as structured data
/// (`ChatResponse::tool_calls`) rather than inline text.
#[derive(Default)]
pub struct NativeToolDispatcher;
172
173impl ToolDispatcher for NativeToolDispatcher {
174    fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
175        let text = response.text.clone().unwrap_or_default();
176        let calls = response
177            .tool_calls
178            .iter()
179            .map(|tc| ParsedToolCall {
180                name: tc.name.clone(),
181                arguments: {
182                    let raw = tc.arguments.trim();
183                    if raw.is_empty() {
184                        Value::Object(serde_json::Map::new())
185                    } else {
186                        serde_json::from_str(raw).unwrap_or_else(|e| {
187                            tracing::warn!(
188                                tool = %tc.name,
189                                error = %e,
190                                "Failed to parse native tool call arguments as JSON; defaulting to empty object"
191                            );
192                            Value::Object(serde_json::Map::new())
193                        })
194                    }
195                },
196                tool_call_id: Some(tc.id.clone()),
197            })
198            .collect();
199        (text, calls)
200    }
201
202    fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
203        let messages = results
204            .iter()
205            .map(|result| ToolResultMessage {
206                tool_call_id: result
207                    .tool_call_id
208                    .clone()
209                    .unwrap_or_else(|| "unknown".to_string()),
210                content: result.output.clone(),
211            })
212            .collect();
213        ConversationMessage::ToolResults(messages)
214    }
215
216    fn prompt_instructions(&self, _tools: &[Box<dyn Tool>]) -> String {
217        String::new()
218    }
219
220    fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
221        history
222            .iter()
223            .flat_map(|msg| match msg {
224                ConversationMessage::Chat(chat) => vec![chat.clone()],
225                ConversationMessage::AssistantToolCalls {
226                    text,
227                    tool_calls,
228                    reasoning_content,
229                } => {
230                    let mut payload = serde_json::json!({
231                        "content": text,
232                        "tool_calls": tool_calls,
233                    });
234                    if let Some(rc) = reasoning_content {
235                        payload["reasoning_content"] = serde_json::json!(rc);
236                    }
237                    vec![ChatMessage::assistant(payload.to_string())]
238                }
239                ConversationMessage::ToolResults(results) => results
240                    .iter()
241                    .map(|result| {
242                        ChatMessage::tool(
243                            serde_json::json!({
244                                "tool_call_id": result.tool_call_id,
245                                "content": result.content,
246                            })
247                            .to_string(),
248                        )
249                    })
250                    .collect(),
251            })
252            .collect()
253    }
254
255    fn should_send_tool_specs(&self) -> bool {
256        true
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn xml_dispatcher_parses_tool_calls() {
        let response = ChatResponse {
            text: Some(
                "Checking\n<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call>"
                    .into(),
            ),
            tool_calls: vec![],
            usage: None,
            reasoning_content: None,
        };
        let dispatcher = XmlToolDispatcher;
        let (_, calls) = dispatcher.parse_response(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "shell");
    }

    #[test]
    fn xml_dispatcher_strips_think_before_tool_call() {
        let response = ChatResponse {
            text: Some(
                "<think>I should list files</think>\n<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call>"
                    .into(),
            ),
            tool_calls: vec![],
            usage: None,
            reasoning_content: None,
        };
        let dispatcher = XmlToolDispatcher;
        let (text, calls) = dispatcher.parse_response(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "shell");
        assert!(
            !text.contains("<think>"),
            "think tags should be stripped from text"
        );
    }

    #[test]
    fn xml_dispatcher_think_only_returns_no_calls() {
        let response = ChatResponse {
            text: Some("<think>Just thinking</think>".into()),
            tool_calls: vec![],
            usage: None,
            reasoning_content: None,
        };
        let dispatcher = XmlToolDispatcher;
        let (_, calls) = dispatcher.parse_response(&response);
        assert!(calls.is_empty());
    }

    #[test]
    fn xml_dispatcher_skips_malformed_json_and_keeps_surrounding_text() {
        let (text, calls) = XmlToolDispatcher::parse_xml_tool_calls(
            "before <tool_call>not json</tool_call> after",
        );
        assert!(calls.is_empty());
        assert_eq!(text, "before\nafter");
    }

    #[test]
    fn strip_think_tags_drops_unclosed_block() {
        assert_eq!(XmlToolDispatcher::strip_think_tags("a<think>b"), "a");
        assert_eq!(XmlToolDispatcher::strip_think_tags("a<think>b</think>c"), "ac");
        assert_eq!(XmlToolDispatcher::strip_think_tags("plain"), "plain");
    }

    #[test]
    fn native_dispatcher_roundtrip() {
        let response = ChatResponse {
            text: Some("ok".into()),
            tool_calls: vec![crate::providers::ToolCall {
                id: "tc1".into(),
                name: "file_read".into(),
                arguments: "{\"path\":\"a.txt\"}".into(),
            }],
            usage: None,
            reasoning_content: None,
        };
        let dispatcher = NativeToolDispatcher;
        let (_, calls) = dispatcher.parse_response(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc1"));

        let msg = dispatcher.format_results(&[ToolExecutionResult {
            name: "file_read".into(),
            output: "hello".into(),
            success: true,
            tool_call_id: Some("tc1".into()),
        }]);
        match msg {
            ConversationMessage::ToolResults(results) => {
                assert_eq!(results.len(), 1);
                assert_eq!(results[0].tool_call_id, "tc1");
            }
            _ => panic!("expected tool results"),
        }
    }

    #[test]
    fn xml_format_results_contains_tool_result_tags() {
        let dispatcher = XmlToolDispatcher;
        let msg = dispatcher.format_results(&[ToolExecutionResult {
            name: "shell".into(),
            output: "ok".into(),
            success: true,
            tool_call_id: None,
        }]);
        let rendered = match msg {
            ConversationMessage::Chat(chat) => chat.content,
            _ => String::new(),
        };
        assert!(rendered.contains("<tool_result"));
        assert!(rendered.contains("shell"));
    }

    #[test]
    fn native_format_results_keeps_tool_call_id() {
        let dispatcher = NativeToolDispatcher;
        let msg = dispatcher.format_results(&[ToolExecutionResult {
            name: "shell".into(),
            output: "ok".into(),
            success: true,
            tool_call_id: Some("tc-1".into()),
        }]);

        match msg {
            ConversationMessage::ToolResults(results) => {
                assert_eq!(results.len(), 1);
                assert_eq!(results[0].tool_call_id, "tc-1");
            }
            _ => panic!("expected ToolResults variant"),
        }
    }

    // ═══════════════════════════════════════════════════════════════════════
    // reasoning_content pass-through tests
    // ═══════════════════════════════════════════════════════════════════════

    #[test]
    fn native_to_provider_messages_includes_reasoning_content() {
        let dispatcher = NativeToolDispatcher;
        let history = vec![ConversationMessage::AssistantToolCalls {
            text: Some("answer".into()),
            tool_calls: vec![crate::providers::ToolCall {
                id: "tc_1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
            reasoning_content: Some("thinking step".into()),
        }];

        let messages = dispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "assistant");

        let payload: serde_json::Value = serde_json::from_str(&messages[0].content).unwrap();
        assert_eq!(payload["reasoning_content"].as_str(), Some("thinking step"));
        assert_eq!(payload["content"].as_str(), Some("answer"));
        assert!(payload["tool_calls"].is_array());
    }

    #[test]
    fn native_to_provider_messages_omits_reasoning_content_when_none() {
        let dispatcher = NativeToolDispatcher;
        let history = vec![ConversationMessage::AssistantToolCalls {
            text: Some("answer".into()),
            tool_calls: vec![crate::providers::ToolCall {
                id: "tc_1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
            reasoning_content: None,
        }];

        let messages = dispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);

        let payload: serde_json::Value = serde_json::from_str(&messages[0].content).unwrap();
        assert!(payload.get("reasoning_content").is_none());
    }

    #[test]
    fn xml_to_provider_messages_ignores_reasoning_content() {
        let dispatcher = XmlToolDispatcher;
        let history = vec![ConversationMessage::AssistantToolCalls {
            text: Some("answer".into()),
            tool_calls: vec![crate::providers::ToolCall {
                id: "tc_1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
            reasoning_content: Some("should be ignored".into()),
        }];

        let messages = dispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "assistant");
        // XmlToolDispatcher returns text only, not JSON payload
        assert_eq!(messages[0].content, "answer");
        assert!(!messages[0].content.contains("reasoning_content"));
    }
}