1use std::collections::HashMap;
31use std::time::Duration;
32
33use tokio::sync::{Mutex, oneshot};
34
35use crate::types::{ContentBlock, RuntimeError};
36
/// Wait limit used by `ToolParkingLot::with_default_timeout`: five minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
39
40pub struct ToolParkingLot {
49 pending: Mutex<HashMap<String, oneshot::Sender<Vec<ContentBlock>>>>,
51 timeout: Duration,
53}
54
55impl ToolParkingLot {
56 pub fn new(timeout: Duration) -> Self {
63 Self { pending: Mutex::new(HashMap::new()), timeout }
64 }
65
66 pub fn with_default_timeout() -> Self {
68 Self::new(DEFAULT_TIMEOUT)
69 }
70
71 pub async fn park(&self, tool_use_id: &str) -> Result<Vec<ContentBlock>, RuntimeError> {
82 let (tx, rx) = oneshot::channel();
83
84 {
85 let mut pending = self.pending.lock().await;
86 pending.insert(tool_use_id.to_string(), tx);
87 }
88
89 match tokio::time::timeout(self.timeout, rx).await {
90 Ok(Ok(content)) => Ok(content),
91 Ok(Err(_recv_error)) => {
92 let mut pending = self.pending.lock().await;
94 pending.remove(tool_use_id);
95 Err(RuntimeError::internal(format!(
96 "parking channel closed unexpectedly for tool_use_id: {tool_use_id}"
97 )))
98 }
99 Err(_timeout) => {
100 let mut pending = self.pending.lock().await;
102 pending.remove(tool_use_id);
103 Err(RuntimeError::tool_timeout(tool_use_id, self.timeout.as_secs()))
104 }
105 }
106 }
107
108 pub async fn deliver(
118 &self,
119 tool_use_id: &str,
120 content: Vec<ContentBlock>,
121 ) -> Result<(), RuntimeError> {
122 let tx = {
123 let mut pending = self.pending.lock().await;
124 pending.remove(tool_use_id)
125 };
126
127 match tx {
128 Some(sender) => {
129 let _ = sender.send(content);
131 Ok(())
132 }
133 None => Err(RuntimeError::NotFound {
134 session_id: format!("no pending tool call: {tool_use_id}"),
135 }),
136 }
137 }
138}
139
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Waits until a call is registered under `tool_use_id` in `lot`.
    ///
    /// Synchronizes on the parking lot's own state rather than a fixed sleep, so
    /// `deliver` is never issued before the spawned `park` has registered. The
    /// wait is bounded so a regression fails the test instead of hanging it.
    async fn wait_until_parked(lot: &ToolParkingLot, tool_use_id: &str) {
        let registered = async {
            loop {
                let parked = lot.pending.lock().await.contains_key(tool_use_id);
                if parked {
                    return;
                }
                tokio::task::yield_now().await;
            }
        };
        tokio::time::timeout(Duration::from_secs(1), registered)
            .await
            .unwrap_or_else(|_| panic!("no call was parked under {tool_use_id}"));
    }

    /// Returns the text of a `ContentBlock::Text`, panicking on any other variant.
    fn text_of(block: &ContentBlock) -> &str {
        match block {
            ContentBlock::Text { text } => text.as_str(),
            _ => panic!("expected Text variant"),
        }
    }

    #[tokio::test]
    async fn test_successful_park_and_deliver() {
        let lot = Arc::new(ToolParkingLot::new(Duration::from_secs(5)));
        let tool_id = "tool_use_123";

        let parker = Arc::clone(&lot);
        let park_handle = tokio::spawn(async move { parker.park(tool_id).await });
        wait_until_parked(&lot, tool_id).await;

        let content = vec![ContentBlock::Text { text: "result data".to_string() }];
        lot.deliver(tool_id, content).await.unwrap();

        let result = park_handle.await.unwrap().unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(text_of(&result[0]), "result data");
    }

    #[tokio::test]
    async fn test_timeout_returns_tool_timeout_error() {
        let lot = ToolParkingLot::new(Duration::from_millis(50));
        let tool_id = "tool_use_timeout";

        match lot.park(tool_id).await {
            Err(RuntimeError::ToolTimeout { tool_use_id, timeout_secs }) => {
                assert_eq!(tool_use_id, tool_id);
                // `as_secs` truncates, so a sub-second timeout reports 0.
                assert_eq!(timeout_secs, 0);
            }
            Err(other) => panic!("expected ToolTimeout, got: {other}"),
            Ok(_) => panic!("expected ToolTimeout, got Ok"),
        }

        // A timed-out call must not leave its entry behind.
        assert!(lot.pending.lock().await.is_empty());
    }

    #[tokio::test]
    async fn test_deliver_to_unknown_id_returns_not_found() {
        let lot = ToolParkingLot::new(Duration::from_secs(5));

        let result = lot
            .deliver("nonexistent_id", vec![ContentBlock::Text { text: "hello".to_string() }])
            .await;

        match result {
            Err(RuntimeError::NotFound { session_id }) => {
                assert!(session_id.contains("nonexistent_id"));
            }
            Err(other) => panic!("expected NotFound, got: {other}"),
            Ok(()) => panic!("expected NotFound, got Ok"),
        }
    }

    #[test]
    fn test_default_timeout_is_five_minutes() {
        let lot = ToolParkingLot::with_default_timeout();
        assert_eq!(lot.timeout, Duration::from_secs(300));
    }

    #[tokio::test]
    async fn test_multiple_concurrent_parks() {
        let lot = Arc::new(ToolParkingLot::new(Duration::from_secs(5)));

        let parker_a = Arc::clone(&lot);
        let handle_a = tokio::spawn(async move { parker_a.park("tool_a").await });

        let parker_b = Arc::clone(&lot);
        let handle_b = tokio::spawn(async move { parker_b.park("tool_b").await });

        wait_until_parked(&lot, "tool_a").await;
        wait_until_parked(&lot, "tool_b").await;

        // Deliver out of order to show each id is routed independently.
        lot.deliver("tool_b", vec![ContentBlock::Text { text: "b_result".to_string() }])
            .await
            .unwrap();
        lot.deliver("tool_a", vec![ContentBlock::Text { text: "a_result".to_string() }])
            .await
            .unwrap();

        let result_a = handle_a.await.unwrap().unwrap();
        let result_b = handle_b.await.unwrap().unwrap();

        assert_eq!(text_of(&result_a[0]), "a_result");
        assert_eq!(text_of(&result_b[0]), "b_result");
    }

    #[tokio::test]
    async fn test_deliver_after_timeout_returns_not_found() {
        let lot = ToolParkingLot::new(Duration::from_millis(20));
        let tool_id = "tool_expired";

        let _ = lot.park(tool_id).await;

        let result =
            lot.deliver(tool_id, vec![ContentBlock::Text { text: "late".to_string() }]).await;

        match result {
            Err(RuntimeError::NotFound { .. }) => {}
            Err(other) => panic!("expected NotFound, got: {other}"),
            Ok(()) => panic!("expected NotFound, got Ok"),
        }
    }
}