1use std::collections::HashMap;
9use std::sync::Arc;
10
11use motosan_agent_tool::{Tool, ToolContext, ToolResult};
12use tokio::task::JoinHandle;
13
14use crate::llm::ToolCallItem;
15
16pub struct StreamingToolExecutor {
22 pending: Vec<(ToolCallItem, JoinHandle<ToolResult>)>,
23}
24
25impl StreamingToolExecutor {
26 pub fn new() -> Self {
28 Self {
29 pending: Vec::new(),
30 }
31 }
32
33 pub fn submit(
38 &mut self,
39 item: ToolCallItem,
40 tool_map: &HashMap<String, Arc<dyn Tool>>,
41 timeout: Option<std::time::Duration>,
42 ctx: &ToolContext,
43 ) {
44 let tool = tool_map.get(&item.name).cloned();
45 let name = item.name.clone();
46 let args = item.args.clone();
47 let ctx = ctx.clone();
48
49 let handle = tokio::spawn(async move {
50 let fut = async {
51 if let Some(tool) = tool {
52 tool.call(args, &ctx).await
53 } else {
54 ToolResult::error(format!("unknown tool: {name}"))
55 }
56 };
57 if let Some(dur) = timeout {
58 match tokio::time::timeout(dur, fut).await {
59 Ok(result) => result,
60 Err(_) => ToolResult::error(format!("tool '{name}' timed out after {dur:?}")),
61 }
62 } else {
63 fut.await
64 }
65 });
66
67 self.pending.push((item, handle));
68 }
69
70 pub fn has_pending(&self) -> bool {
72 !self.pending.is_empty()
73 }
74
75 pub async fn collect(self) -> (Vec<ToolCallItem>, Vec<ToolResult>) {
80 let mut items = Vec::with_capacity(self.pending.len());
81 let mut results = Vec::with_capacity(self.pending.len());
82
83 for (item, handle) in self.pending {
84 let result = handle
85 .await
86 .unwrap_or_else(|e| ToolResult::error(format!("tool task panicked: {e}")));
87 items.push(item);
88 results.push(result);
89 }
90
91 (items, results)
92 }
93}
94
95impl Default for StreamingToolExecutor {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use motosan_agent_tool::ToolDef;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    /// Test tool that sets `started` as soon as it is first polled, sleeps for
    /// `delay`, then returns `result` as text.
    struct TimestampTool {
        name: String,
        started: Arc<AtomicBool>,
        result: String,
        delay: Duration,
    }

    impl TimestampTool {
        /// Tool that takes 10ms to finish.
        fn new(name: &str, result: &str, started: Arc<AtomicBool>) -> Self {
            Self::with_delay(name, result, started, Duration::from_millis(10))
        }

        /// Tool that takes `delay` to finish.
        fn with_delay(
            name: &str,
            result: &str,
            started: Arc<AtomicBool>,
            delay: Duration,
        ) -> Self {
            Self {
                name: name.to_string(),
                started,
                result: result.to_string(),
                delay,
            }
        }
    }

    impl Tool for TimestampTool {
        fn def(&self) -> ToolDef {
            ToolDef {
                name: self.name.clone(),
                description: "test tool".into(),
                input_schema: serde_json::json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            }
        }

        fn call(
            &self,
            _args: serde_json::Value,
            _ctx: &ToolContext,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolResult> + Send + '_>> {
            let started = self.started.clone();
            let result = self.result.clone();
            let delay = self.delay;
            Box::pin(async move {
                started.store(true, Ordering::SeqCst);
                tokio::time::sleep(delay).await;
                ToolResult::text(result)
            })
        }
    }

    /// Builds a tool call with empty arguments.
    fn call_item(id: &str, name: &str) -> ToolCallItem {
        ToolCallItem {
            id: id.to_string(),
            name: name.to_string(),
            args: serde_json::json!({}),
        }
    }

    /// Debug rendering of a result. The assertions assume it includes the
    /// result's text, as the unknown-tool test already relies on.
    fn debug_text(result: &ToolResult) -> String {
        format!("{result:?}")
    }

    #[tokio::test]
    async fn submit_and_collect_returns_results_in_order() {
        let started_a = Arc::new(AtomicBool::new(false));
        let started_b = Arc::new(AtomicBool::new(false));

        // tool_a finishes last, so completion order differs from submission order.
        let tool_a: Arc<dyn Tool> = Arc::new(TimestampTool::with_delay(
            "tool_a",
            "result_a",
            started_a.clone(),
            Duration::from_millis(30),
        ));
        let tool_b: Arc<dyn Tool> = Arc::new(TimestampTool::with_delay(
            "tool_b",
            "result_b",
            started_b.clone(),
            Duration::from_millis(1),
        ));

        let mut tool_map: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tool_map.insert("tool_a".to_string(), tool_a);
        tool_map.insert("tool_b".to_string(), tool_b);

        let ctx = ToolContext::default();
        let mut executor = StreamingToolExecutor::new();

        executor.submit(call_item("call_1", "tool_a"), &tool_map, None, &ctx);
        executor.submit(call_item("call_2", "tool_b"), &tool_map, None, &ctx);

        assert!(executor.has_pending());

        let (items, results) = executor.collect().await;

        assert_eq!(items.len(), 2);
        assert_eq!(items[0].id, "call_1");
        assert_eq!(items[1].id, "call_2");
        assert_eq!(results.len(), 2);
        // Each result must line up with its own call, not with completion order.
        let first = debug_text(&results[0]);
        let second = debug_text(&results[1]);
        assert!(first.contains("result_a"), "got: {first}");
        assert!(second.contains("result_b"), "got: {second}");
        assert!(started_a.load(Ordering::SeqCst));
        assert!(started_b.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn unknown_tool_returns_error_result() {
        let tool_map: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        let ctx = ToolContext::default();
        let mut executor = StreamingToolExecutor::new();

        executor.submit(call_item("call_1", "nonexistent"), &tool_map, None, &ctx);
        let (items, results) = executor.collect().await;

        assert_eq!(items.len(), 1);
        assert_eq!(results.len(), 1);
        let text = debug_text(&results[0]);
        assert!(text.contains("unknown tool"), "got: {text}");
    }

    #[tokio::test]
    async fn timed_out_tool_returns_error_result() {
        let started = Arc::new(AtomicBool::new(false));
        // Far longer than the timeout, so the deadline always wins. The tool's
        // future is dropped on timeout, so the test does not wait this long.
        let tool: Arc<dyn Tool> = Arc::new(TimestampTool::with_delay(
            "hang_tool",
            "never",
            started.clone(),
            Duration::from_secs(60),
        ));

        let mut tool_map: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tool_map.insert("hang_tool".to_string(), tool);

        let ctx = ToolContext::default();
        let mut executor = StreamingToolExecutor::new();

        executor.submit(
            call_item("call_1", "hang_tool"),
            &tool_map,
            Some(Duration::from_millis(10)),
            &ctx,
        );
        let (items, results) = executor.collect().await;

        assert_eq!(items.len(), 1);
        assert_eq!(results.len(), 1);
        let text = debug_text(&results[0]);
        assert!(text.contains("timed out"), "got: {text}");
        assert!(started.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn tools_start_executing_immediately_after_submit() {
        let started = Arc::new(AtomicBool::new(false));
        let tool: Arc<dyn Tool> =
            Arc::new(TimestampTool::new("slow_tool", "done", started.clone()));

        let mut tool_map: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tool_map.insert("slow_tool".to_string(), tool);

        let ctx = ToolContext::default();
        let mut executor = StreamingToolExecutor::new();

        executor.submit(call_item("call_1", "slow_tool"), &tool_map, None, &ctx);

        // Give the spawned task a chance to be polled without calling collect().
        tokio::task::yield_now().await;
        tokio::time::sleep(Duration::from_millis(5)).await;

        assert!(
            started.load(Ordering::SeqCst),
            "tool should start executing immediately after submit, before collect()"
        );

        let (_, results) = executor.collect().await;
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn empty_executor_collects_nothing() {
        let executor = StreamingToolExecutor::new();
        assert!(!executor.has_pending());

        let (items, results) = executor.collect().await;
        assert!(items.is_empty());
        assert!(results.is_empty());
    }
}