Skip to main content

motosan_agent_loop/
streaming_executor.rs

1//! Streaming tool executor — starts executing tools as LLM streams tool_use blocks.
2//!
3//! Instead of waiting for the LLM to finish streaming before executing any tools,
4//! [`StreamingToolExecutor`] spawns tool execution tasks the moment a complete
5//! `tool_use` block arrives. This overlaps tool execution with ongoing LLM output,
6//! reducing wall-clock latency.
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use motosan_agent_tool::{Tool, ToolContext, ToolResult};
12use tokio::task::JoinHandle;
13
14use crate::llm::ToolCallItem;
15
16/// Executes tools eagerly as they arrive from a streaming LLM response.
17///
18/// Call [`submit`](Self::submit) each time a complete `tool_use` block is received
19/// from the stream. After the stream ends, call [`collect`](Self::collect) to
20/// await all pending tasks and retrieve results in submission order.
21pub struct StreamingToolExecutor {
22    pending: Vec<(ToolCallItem, JoinHandle<ToolResult>)>,
23}
24
25impl StreamingToolExecutor {
26    /// Create a new executor with no pending tasks.
27    pub fn new() -> Self {
28        Self {
29            pending: Vec::new(),
30        }
31    }
32
33    /// Submit a tool for immediate execution.
34    ///
35    /// The tool is looked up in `tool_map` and spawned as a background task.
36    /// If the tool is not found, the task will resolve with an error result.
37    pub fn submit(
38        &mut self,
39        item: ToolCallItem,
40        tool_map: &HashMap<String, Arc<dyn Tool>>,
41        timeout: Option<std::time::Duration>,
42        ctx: &ToolContext,
43    ) {
44        let tool = tool_map.get(&item.name).cloned();
45        let name = item.name.clone();
46        let args = item.args.clone();
47        let ctx = ctx.clone();
48
49        let handle = tokio::spawn(async move {
50            let fut = async {
51                if let Some(tool) = tool {
52                    tool.call(args, &ctx).await
53                } else {
54                    ToolResult::error(format!("unknown tool: {name}"))
55                }
56            };
57            if let Some(dur) = timeout {
58                match tokio::time::timeout(dur, fut).await {
59                    Ok(result) => result,
60                    Err(_) => ToolResult::error(format!("tool '{name}' timed out after {dur:?}")),
61                }
62            } else {
63                fut.await
64            }
65        });
66
67        self.pending.push((item, handle));
68    }
69
70    /// Returns true if any tools have been submitted.
71    pub fn has_pending(&self) -> bool {
72        !self.pending.is_empty()
73    }
74
75    /// Await all submitted tasks and return `(items, results)` in submission order.
76    ///
77    /// Items and results are returned as separate vectors so callers can use
78    /// the existing `execute_and_record_tool_calls` pipeline unchanged.
79    pub async fn collect(self) -> (Vec<ToolCallItem>, Vec<ToolResult>) {
80        let mut items = Vec::with_capacity(self.pending.len());
81        let mut results = Vec::with_capacity(self.pending.len());
82
83        for (item, handle) in self.pending {
84            let result = handle
85                .await
86                .unwrap_or_else(|e| ToolResult::error(format!("tool task panicked: {e}")));
87            items.push(item);
88            results.push(result);
89        }
90
91        (items, results)
92    }
93}
94
95impl Default for StreamingToolExecutor {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use motosan_agent_tool::ToolDef;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    /// A tool that records when it starts executing, sleeps for a configurable
    /// delay, then returns a fixed text result.
    struct TimestampTool {
        name: String,
        started: Arc<AtomicBool>,
        result: String,
        delay: Duration,
    }

    impl TimestampTool {
        fn new(name: &str, result: &str, started: Arc<AtomicBool>) -> Self {
            Self {
                name: name.to_string(),
                started,
                result: result.to_string(),
                delay: Duration::from_millis(10),
            }
        }

        /// Override how long the simulated work takes.
        fn with_delay(mut self, delay: Duration) -> Self {
            self.delay = delay;
            self
        }
    }

    impl Tool for TimestampTool {
        fn def(&self) -> ToolDef {
            ToolDef {
                name: self.name.clone(),
                description: "test tool".into(),
                input_schema: serde_json::json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                }),
            }
        }

        fn call(
            &self,
            _args: serde_json::Value,
            _ctx: &ToolContext,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolResult> + Send + '_>> {
            let started = self.started.clone();
            let result = self.result.clone();
            let delay = self.delay;
            Box::pin(async move {
                started.store(true, Ordering::SeqCst);
                // Simulate work so the call is still in flight when tests inspect it.
                tokio::time::sleep(delay).await;
                ToolResult::text(result)
            })
        }
    }

    /// Build a `ToolCallItem` with empty JSON arguments.
    fn call_item(id: &str, name: &str) -> ToolCallItem {
        ToolCallItem {
            id: id.to_string(),
            name: name.to_string(),
            args: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn submit_and_collect_returns_results_in_order() {
        let started_a = Arc::new(AtomicBool::new(false));
        let started_b = Arc::new(AtomicBool::new(false));

        // tool_a finishes *after* tool_b, so results returned in completion order
        // (rather than submission order) would fail the assertions below.
        let tool_a: Arc<dyn Tool> = Arc::new(
            TimestampTool::new("tool_a", "result_a", started_a.clone())
                .with_delay(Duration::from_millis(30)),
        );
        let tool_b: Arc<dyn Tool> = Arc::new(
            TimestampTool::new("tool_b", "result_b", started_b.clone())
                .with_delay(Duration::from_millis(1)),
        );

        let mut tool_map: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tool_map.insert("tool_a".to_string(), tool_a);
        tool_map.insert("tool_b".to_string(), tool_b);

        let ctx = ToolContext::default();
        let mut executor = StreamingToolExecutor::new();

        executor.submit(call_item("call_1", "tool_a"), &tool_map, None, &ctx);
        executor.submit(call_item("call_2", "tool_b"), &tool_map, None, &ctx);

        assert!(executor.has_pending());

        let (items, results) = executor.collect().await;

        assert_eq!(items.len(), 2);
        assert_eq!(items[0].id, "call_1");
        assert_eq!(items[1].id, "call_2");
        assert_eq!(results.len(), 2);
        // Each result must sit in the same slot as the item that produced it.
        let text_a = format!("{:?}", results[0]);
        let text_b = format!("{:?}", results[1]);
        assert!(text_a.contains("result_a"), "got: {text_a}");
        assert!(text_b.contains("result_b"), "got: {text_b}");
        // Both tools should have started
        assert!(started_a.load(Ordering::SeqCst));
        assert!(started_b.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn unknown_tool_returns_error_result() {
        let tool_map: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        let ctx = ToolContext::default();
        let mut executor = StreamingToolExecutor::new();

        executor.submit(call_item("call_1", "nonexistent"), &tool_map, None, &ctx);
        let (items, results) = executor.collect().await;

        assert_eq!(items.len(), 1);
        assert_eq!(results.len(), 1);
        // The result should be an error about unknown tool
        let text = format!("{:?}", results[0]);
        assert!(text.contains("unknown tool"), "got: {text}");
    }

    #[tokio::test]
    async fn tools_start_executing_immediately_after_submit() {
        let started = Arc::new(AtomicBool::new(false));
        let tool: Arc<dyn Tool> =
            Arc::new(TimestampTool::new("slow_tool", "done", started.clone()));

        let mut tool_map: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tool_map.insert("slow_tool".to_string(), tool);

        let ctx = ToolContext::default();
        let mut executor = StreamingToolExecutor::new();

        executor.submit(call_item("call_1", "slow_tool"), &tool_map, None, &ctx);

        // Yield to let the spawned task start
        tokio::task::yield_now().await;
        tokio::time::sleep(Duration::from_millis(5)).await;

        // Tool should have started executing before we call collect()
        assert!(
            started.load(Ordering::SeqCst),
            "tool should start executing immediately after submit, before collect()"
        );

        let (_, results) = executor.collect().await;
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn expired_timeout_returns_error_result() {
        let started = Arc::new(AtomicBool::new(false));
        let tool: Arc<dyn Tool> = Arc::new(
            TimestampTool::new("slow_tool", "done", started.clone())
                .with_delay(Duration::from_millis(50)),
        );

        let mut tool_map: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tool_map.insert("slow_tool".to_string(), tool);

        let ctx = ToolContext::default();
        let mut executor = StreamingToolExecutor::new();

        executor.submit(
            call_item("call_1", "slow_tool"),
            &tool_map,
            Some(Duration::from_millis(1)),
            &ctx,
        );
        let (items, results) = executor.collect().await;

        assert_eq!(items.len(), 1);
        assert_eq!(results.len(), 1);
        let text = format!("{:?}", results[0]);
        assert!(text.contains("timed out"), "got: {text}");
    }

    #[tokio::test]
    async fn generous_timeout_returns_tool_result() {
        let started = Arc::new(AtomicBool::new(false));
        let tool: Arc<dyn Tool> =
            Arc::new(TimestampTool::new("quick_tool", "done", started.clone()));

        let mut tool_map: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tool_map.insert("quick_tool".to_string(), tool);

        let ctx = ToolContext::default();
        let mut executor = StreamingToolExecutor::new();

        executor.submit(
            call_item("call_1", "quick_tool"),
            &tool_map,
            Some(Duration::from_secs(5)),
            &ctx,
        );
        let (_, results) = executor.collect().await;

        assert_eq!(results.len(), 1);
        let text = format!("{:?}", results[0]);
        assert!(text.contains("done"), "got: {text}");
        assert!(!text.contains("timed out"), "got: {text}");
    }

    #[tokio::test]
    async fn empty_executor_collects_nothing() {
        let executor = StreamingToolExecutor::new();
        assert!(!executor.has_pending());

        let (items, results) = executor.collect().await;
        assert!(items.is_empty());
        assert!(results.is_empty());
    }
}