Skip to main content

adk_managed/
parking.rs

1//! Custom tool parking lot for the managed agent runtime.
2//!
3//! When the agent emits a `custom_tool_use` event, the session loop parks
4//! execution until the client delivers a result (or timeout elapses). This
5//! module implements that channel-based wait mechanism.
6//!
7//! # Architecture
8//!
9//! Internally, each parked tool call gets a `tokio::sync::oneshot` channel.
10//! The [`ToolParkingLot::park`] method creates the sender, stores it, and awaits
11//! the receiver with a timeout. [`ToolParkingLot::deliver`] looks up the sender
12//! and pushes the content through.
13//!
14//! # Example
15//!
16//! ```rust,ignore
17//! use std::time::Duration;
18//! use adk_managed::parking::ToolParkingLot;
19//! use adk_managed::types::ContentBlock;
20//!
21//! let lot = ToolParkingLot::new(Duration::from_secs(300));
22//!
23//! // In the session loop task:
24//! let result = lot.park("tool_use_abc").await?;
25//!
26//! // In the event dispatch task:
27//! lot.deliver("tool_use_abc", vec![ContentBlock::Text { text: "done".into() }]).await?;
28//! ```
29
30use std::collections::HashMap;
31use std::time::Duration;
32
33use tokio::sync::{Mutex, oneshot};
34
35use crate::types::{ContentBlock, RuntimeError};
36
/// Timeout used by [`ToolParkingLot::with_default_timeout`]: five minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
39
40/// A parking lot for custom tool calls awaiting client-delivered results.
41///
42/// The session loop calls [`park`](Self::park) when `agent.custom_tool_use` is emitted,
43/// blocking until either:
44/// 1. The client sends `user.custom_tool_result` and the runtime calls [`deliver`](Self::deliver)
45/// 2. The configured timeout elapses, returning [`RuntimeError::ToolTimeout`]
46///
47/// Thread-safe: the internal map is protected by a [`Mutex`].
48pub struct ToolParkingLot {
49    /// Pending tool calls: tool_use_id → sender that will deliver the result.
50    pending: Mutex<HashMap<String, oneshot::Sender<Vec<ContentBlock>>>>,
51    /// How long to wait before timing out a parked call.
52    timeout: Duration,
53}
54
55impl ToolParkingLot {
56    /// Create a new parking lot with the specified timeout.
57    ///
58    /// # Arguments
59    ///
60    /// * `timeout` - Maximum duration to wait for a tool result before returning
61    ///   [`RuntimeError::ToolTimeout`].
62    pub fn new(timeout: Duration) -> Self {
63        Self { pending: Mutex::new(HashMap::new()), timeout }
64    }
65
66    /// Create a new parking lot with the default timeout (5 minutes).
67    pub fn with_default_timeout() -> Self {
68        Self::new(DEFAULT_TIMEOUT)
69    }
70
71    /// Park the session loop, waiting for a custom tool result.
72    ///
73    /// Creates a oneshot channel, stores the sender under `tool_use_id`, and
74    /// awaits the receiver. Returns the content blocks when delivered, or
75    /// [`RuntimeError::ToolTimeout`] if the timeout elapses.
76    ///
77    /// # Errors
78    ///
79    /// - [`RuntimeError::ToolTimeout`] if no result is delivered within the timeout.
80    /// - [`RuntimeError::Internal`] if the sender is dropped unexpectedly.
81    pub async fn park(&self, tool_use_id: &str) -> Result<Vec<ContentBlock>, RuntimeError> {
82        let (tx, rx) = oneshot::channel();
83
84        {
85            let mut pending = self.pending.lock().await;
86            pending.insert(tool_use_id.to_string(), tx);
87        }
88
89        match tokio::time::timeout(self.timeout, rx).await {
90            Ok(Ok(content)) => Ok(content),
91            Ok(Err(_recv_error)) => {
92                // Sender was dropped without sending — clean up and report internal error.
93                let mut pending = self.pending.lock().await;
94                pending.remove(tool_use_id);
95                Err(RuntimeError::internal(format!(
96                    "parking channel closed unexpectedly for tool_use_id: {tool_use_id}"
97                )))
98            }
99            Err(_timeout) => {
100                // Timeout elapsed — remove the pending entry and return timeout error.
101                let mut pending = self.pending.lock().await;
102                pending.remove(tool_use_id);
103                Err(RuntimeError::tool_timeout(tool_use_id, self.timeout.as_secs()))
104            }
105        }
106    }
107
108    /// Deliver a result to a parked tool call.
109    ///
110    /// Looks up the sender by `tool_use_id` and sends the content. Returns an
111    /// error if no pending call with this ID exists (e.g., it already timed out
112    /// or was never parked).
113    ///
114    /// # Errors
115    ///
116    /// - [`RuntimeError::NotFound`] if no pending call exists for the given ID.
117    pub async fn deliver(
118        &self,
119        tool_use_id: &str,
120        content: Vec<ContentBlock>,
121    ) -> Result<(), RuntimeError> {
122        let tx = {
123            let mut pending = self.pending.lock().await;
124            pending.remove(tool_use_id)
125        };
126
127        match tx {
128            Some(sender) => {
129                // If the receiver was already dropped (e.g., task cancelled), ignore the error.
130                let _ = sender.send(content);
131                Ok(())
132            }
133            None => Err(RuntimeError::NotFound {
134                session_id: format!("no pending tool call: {tool_use_id}"),
135            }),
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Build a single text content block holding `text`.
    fn text_block(text: &str) -> ContentBlock {
        ContentBlock::Text { text: text.to_string() }
    }

    /// A parked call receives exactly what is delivered for its id.
    #[tokio::test]
    async fn test_successful_park_and_deliver() {
        let tool_id = "tool_use_123";
        let lot = Arc::new(ToolParkingLot::new(Duration::from_secs(5)));

        let parked = tokio::spawn({
            let lot = Arc::clone(&lot);
            async move { lot.park(tool_id).await }
        });

        // Let the spawned task run far enough to register its entry.
        tokio::time::sleep(Duration::from_millis(10)).await;

        lot.deliver(tool_id, vec![text_block("result data")]).await.unwrap();

        let blocks = parked.await.unwrap().unwrap();
        assert_eq!(blocks.len(), 1);
        let ContentBlock::Text { text } = &blocks[0] else {
            panic!("expected Text variant");
        };
        assert_eq!(text, "result data");
    }

    /// With nothing delivered, `park` fails with `ToolTimeout` for its id.
    #[tokio::test]
    async fn test_timeout_returns_tool_timeout_error() {
        let tool_id = "tool_use_timeout";
        let lot = ToolParkingLot::new(Duration::from_millis(50));

        let outcome = lot.park(tool_id).await;
        assert!(outcome.is_err());

        match outcome.unwrap_err() {
            RuntimeError::ToolTimeout { tool_use_id, timeout_secs } => {
                assert_eq!(tool_use_id, tool_id);
                // 50ms expressed in whole seconds is 0.
                assert_eq!(timeout_secs, 0);
            }
            other => panic!("expected ToolTimeout, got: {other}"),
        }
    }

    /// Delivering to an id that was never parked reports `NotFound`.
    #[tokio::test]
    async fn test_deliver_to_unknown_id_returns_not_found() {
        let lot = ToolParkingLot::new(Duration::from_secs(5));

        let outcome = lot.deliver("nonexistent_id", vec![text_block("hello")]).await;
        assert!(outcome.is_err());

        match outcome.unwrap_err() {
            RuntimeError::NotFound { session_id } => {
                assert!(session_id.contains("nonexistent_id"));
            }
            other => panic!("expected NotFound, got: {other}"),
        }
    }

    /// The convenience constructor configures a 300-second timeout.
    #[tokio::test]
    async fn test_default_timeout_is_five_minutes() {
        let expected = Duration::from_secs(300);
        let lot = ToolParkingLot::with_default_timeout();
        assert_eq!(lot.timeout, expected);
    }

    /// Two calls parked at once each get their own result, whatever the delivery order.
    #[tokio::test]
    async fn test_multiple_concurrent_parks() {
        let lot = Arc::new(ToolParkingLot::new(Duration::from_secs(5)));

        let parked_a = tokio::spawn({
            let lot = Arc::clone(&lot);
            async move { lot.park("tool_a").await }
        });
        let parked_b = tokio::spawn({
            let lot = Arc::clone(&lot);
            async move { lot.park("tool_b").await }
        });

        // Let both spawned tasks register before delivering.
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Deliver in the opposite order to the one the calls were parked in.
        lot.deliver("tool_b", vec![text_block("b_result")]).await.unwrap();
        lot.deliver("tool_a", vec![text_block("a_result")]).await.unwrap();

        let blocks_a = parked_a.await.unwrap().unwrap();
        let blocks_b = parked_b.await.unwrap().unwrap();

        let ContentBlock::Text { text: text_a } = &blocks_a[0] else {
            panic!("expected Text");
        };
        assert_eq!(text_a, "a_result");

        let ContentBlock::Text { text: text_b } = &blocks_b[0] else {
            panic!("expected Text");
        };
        assert_eq!(text_b, "b_result");
    }

    /// Once a parked call has timed out, a late delivery for it is rejected.
    #[tokio::test]
    async fn test_deliver_after_timeout_returns_not_found() {
        let tool_id = "tool_expired";
        let lot = ToolParkingLot::new(Duration::from_millis(20));

        // Park with nothing delivered so the call times out.
        let _ = lot.park(tool_id).await;

        // The timed-out call cleaned up its entry, so there is nothing to deliver to.
        let outcome = lot.deliver(tool_id, vec![text_block("late")]).await;
        assert!(outcome.is_err());

        match outcome.unwrap_err() {
            RuntimeError::NotFound { .. } => {}
            other => panic!("expected NotFound, got: {other}"),
        }
    }
}