use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::{Mutex, oneshot};
use crate::types::{ContentBlock, RuntimeError};
/// How long a parked call waits by default before giving up: five minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Rendezvous point between code waiting on a tool result and code that supplies it.
///
/// Each waiting call is keyed by its `tool_use_id` and registers the sending half of a
/// one-shot channel; `deliver` removes that sender and fires the content through it.
pub struct ToolParkingLot {
    /// Senders for calls that are currently parked, keyed by tool use id.
    pending: Mutex<HashMap<String, oneshot::Sender<Vec<ContentBlock>>>>,
    /// Upper bound on how long a single `park` call waits for its result.
    timeout: Duration,
}
impl ToolParkingLot {
    /// Creates an empty parking lot whose `park` calls wait at most `timeout`.
    pub fn new(timeout: Duration) -> Self {
        Self { pending: Mutex::new(HashMap::new()), timeout }
    }

    /// Creates an empty parking lot using `DEFAULT_TIMEOUT`.
    pub fn with_default_timeout() -> Self {
        Self::new(DEFAULT_TIMEOUT)
    }

    /// Registers `tool_use_id` and waits for a matching [`deliver`](Self::deliver) call.
    ///
    /// If another call is already parked under the same `tool_use_id`, this call replaces
    /// it: the earlier waiter's sender is dropped and that waiter returns an internal error.
    ///
    /// If the returned future is dropped before it completes, its registration is not
    /// removed here. A later `deliver` for that id still reports `NotFound`, because
    /// the receiving side is gone.
    ///
    /// # Errors
    ///
    /// - `RuntimeError::tool_timeout` if nothing is delivered within the configured
    ///   timeout. The reported value is `Duration::as_secs`, so sub-second timeouts
    ///   report `0`.
    /// - `RuntimeError::internal` if this waiter's sender was dropped without a value,
    ///   for example because a later `park` with the same id replaced it.
    pub async fn park(&self, tool_use_id: &str) -> Result<Vec<ContentBlock>, RuntimeError> {
        let (tx, rx) = oneshot::channel();
        self.pending.lock().await.insert(tool_use_id.to_string(), tx);

        // Once this await completes, the timeout future (and with it `rx`) has been
        // dropped. So if our own sender is still registered, it now reports
        // `is_closed()`, which is what the cleanup below keys on.
        let outcome = tokio::time::timeout(self.timeout, rx).await;
        match outcome {
            Ok(Ok(content)) => Ok(content),
            Ok(Err(_recv_error)) => {
                // Our sender is already gone. Whatever is registered under this id
                // belongs to another waiter and must survive unless it is abandoned.
                self.remove_if_abandoned(tool_use_id).await;
                Err(RuntimeError::internal(format!(
                    "parking channel closed unexpectedly for tool_use_id: {tool_use_id}"
                )))
            }
            Err(_elapsed) => {
                self.remove_if_abandoned(tool_use_id).await;
                Err(RuntimeError::tool_timeout(tool_use_id, self.timeout.as_secs()))
            }
        }
    }

    /// Hands `content` to the call currently parked under `tool_use_id`.
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError::NotFound` if no call is parked under `tool_use_id`. It also
    /// returns `NotFound` if the parked call stopped listening (timed out or was dropped)
    /// before the content could be handed over. In that case the content is discarded
    /// rather than silently reported as delivered.
    pub async fn deliver(
        &self,
        tool_use_id: &str,
        content: Vec<ContentBlock>,
    ) -> Result<(), RuntimeError> {
        // The guard is a temporary, so the lock is released at the end of this statement.
        let sender = self.pending.lock().await.remove(tool_use_id);

        // `send` fails exactly when the receiver has been dropped.
        let handed_over = match sender {
            Some(sender) => sender.send(content).is_ok(),
            None => false,
        };

        if handed_over {
            Ok(())
        } else {
            Err(RuntimeError::NotFound {
                session_id: format!("no pending tool call: {tool_use_id}"),
            })
        }
    }

    /// Removes the entry for `tool_use_id`, but only if its receiver has been dropped.
    ///
    /// A sender whose receiver is still alive belongs to a newer `park` call for the
    /// same id and must be left in place.
    async fn remove_if_abandoned(&self, tool_use_id: &str) {
        let mut pending = self.pending.lock().await;
        if matches!(pending.get(tool_use_id), Some(tx) if tx.is_closed()) {
            pending.remove(tool_use_id);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds a payload consisting of a single text block containing `body`.
    fn text_payload(body: &str) -> Vec<ContentBlock> {
        vec![ContentBlock::Text { text: body.to_string() }]
    }

    #[tokio::test]
    async fn test_successful_park_and_deliver() {
        let tool_id = "tool_use_123";
        let lot = Arc::new(ToolParkingLot::new(Duration::from_secs(5)));
        let waiter = {
            let lot = Arc::clone(&lot);
            tokio::spawn(async move { lot.park(tool_id).await })
        };
        // Let the spawned task register itself before delivering.
        tokio::time::sleep(Duration::from_millis(10)).await;

        lot.deliver(tool_id, text_payload("result data")).await.unwrap();

        let blocks = waiter.await.unwrap().unwrap();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::Text { text } => assert_eq!(text, "result data"),
            _ => panic!("expected Text variant"),
        }
    }

    #[tokio::test]
    async fn test_timeout_returns_tool_timeout_error() {
        let tool_id = "tool_use_timeout";
        let lot = ToolParkingLot::new(Duration::from_millis(50));

        let outcome = lot.park(tool_id).await;
        assert!(outcome.is_err());

        // A 50 ms timeout is reported in whole seconds, i.e. as 0.
        match outcome.unwrap_err() {
            RuntimeError::ToolTimeout { tool_use_id, timeout_secs } => {
                assert_eq!(tool_use_id, tool_id);
                assert_eq!(timeout_secs, 0);
            }
            other => panic!("expected ToolTimeout, got: {other}"),
        }
    }

    #[tokio::test]
    async fn test_deliver_to_unknown_id_returns_not_found() {
        let lot = ToolParkingLot::new(Duration::from_secs(5));

        let outcome = lot.deliver("nonexistent_id", text_payload("hello")).await;
        assert!(outcome.is_err());

        match outcome.unwrap_err() {
            RuntimeError::NotFound { session_id } => {
                assert!(session_id.contains("nonexistent_id"));
            }
            other => panic!("expected NotFound, got: {other}"),
        }
    }

    #[tokio::test]
    async fn test_default_timeout_is_five_minutes() {
        let configured = ToolParkingLot::with_default_timeout().timeout;
        assert_eq!(configured, Duration::from_secs(300));
    }

    #[tokio::test]
    async fn test_multiple_concurrent_parks() {
        let lot = Arc::new(ToolParkingLot::new(Duration::from_secs(5)));
        let spawn_park = |id: &'static str| {
            let lot = Arc::clone(&lot);
            tokio::spawn(async move { lot.park(id).await })
        };
        let handle_a = spawn_park("tool_a");
        let handle_b = spawn_park("tool_b");
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Deliver in the reverse order of parking.
        lot.deliver("tool_b", text_payload("b_result")).await.unwrap();
        lot.deliver("tool_a", text_payload("a_result")).await.unwrap();

        let blocks_a = handle_a.await.unwrap().unwrap();
        let blocks_b = handle_b.await.unwrap().unwrap();
        match &blocks_a[0] {
            ContentBlock::Text { text } => assert_eq!(text, "a_result"),
            _ => panic!("expected Text"),
        }
        match &blocks_b[0] {
            ContentBlock::Text { text } => assert_eq!(text, "b_result"),
            _ => panic!("expected Text"),
        }
    }

    #[tokio::test]
    async fn test_deliver_after_timeout_returns_not_found() {
        let tool_id = "tool_expired";
        let lot = ToolParkingLot::new(Duration::from_millis(20));
        let _ = lot.park(tool_id).await;

        let outcome = lot.deliver(tool_id, text_payload("late")).await;
        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            RuntimeError::NotFound { .. } => {}
            other => panic!("expected NotFound, got: {other}"),
        }
    }
}