1use std::sync::Arc;
4use std::time::Duration;
5
6use forge_error::DispatchError;
7use forge_sandbox::{ResourceDispatcher, ToolDispatcher};
8use serde_json::Value;
9
10pub struct TimeoutDispatcher {
12 inner: Arc<dyn ToolDispatcher>,
13 timeout: Duration,
14 server_name: String,
15}
16
17impl TimeoutDispatcher {
18 pub fn new(
20 inner: Arc<dyn ToolDispatcher>,
21 timeout: Duration,
22 server_name: impl Into<String>,
23 ) -> Self {
24 Self {
25 inner,
26 timeout,
27 server_name: server_name.into(),
28 }
29 }
30}
31
32#[async_trait::async_trait]
33impl ToolDispatcher for TimeoutDispatcher {
34 #[tracing::instrument(skip(self, args), fields(server, tool))]
35 async fn call_tool(
36 &self,
37 server: &str,
38 tool: &str,
39 args: Value,
40 ) -> Result<Value, DispatchError> {
41 match tokio::time::timeout(self.timeout, self.inner.call_tool(server, tool, args)).await {
42 Ok(result) => result,
43 Err(_elapsed) => Err(DispatchError::Timeout {
44 server: self.server_name.clone(),
45 timeout_ms: self.timeout.as_millis() as u64,
46 }),
47 }
48 }
49}
50
51pub struct TimeoutResourceDispatcher {
53 inner: Arc<dyn ResourceDispatcher>,
54 timeout: Duration,
55 server_name: String,
56}
57
58impl TimeoutResourceDispatcher {
59 pub fn new(
61 inner: Arc<dyn ResourceDispatcher>,
62 timeout: Duration,
63 server_name: impl Into<String>,
64 ) -> Self {
65 Self {
66 inner,
67 timeout,
68 server_name: server_name.into(),
69 }
70 }
71}
72
73#[async_trait::async_trait]
74impl ResourceDispatcher for TimeoutResourceDispatcher {
75 #[tracing::instrument(skip(self), fields(server, uri))]
76 async fn read_resource(
77 &self,
78 server: &str,
79 uri: &str,
80 ) -> Result<serde_json::Value, DispatchError> {
81 match tokio::time::timeout(self.timeout, self.inner.read_resource(server, uri)).await {
82 Ok(result) => result,
83 Err(_elapsed) => Err(DispatchError::Timeout {
84 server: self.server_name.clone(),
85 timeout_ms: self.timeout.as_millis() as u64,
86 }),
87 }
88 }
89}
90
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    // ---- tool dispatcher fixtures ----------------------------------------

    /// Tool dispatcher that answers immediately, echoing the tool name.
    struct InstantDispatcher;

    #[async_trait::async_trait]
    impl ToolDispatcher for InstantDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            Ok(json!({"tool": tool, "status": "ok"}))
        }
    }

    /// Tool dispatcher that sleeps for `delay` before answering.
    struct SlowDispatcher {
        delay: Duration,
    }

    #[async_trait::async_trait]
    impl ToolDispatcher for SlowDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            _tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            tokio::time::sleep(self.delay).await;
            Ok(json!({"status": "ok"}))
        }
    }

    /// Tool dispatcher that always fails with an internal error.
    struct FailingDispatcher;

    #[async_trait::async_trait]
    impl ToolDispatcher for FailingDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            _tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            Err(DispatchError::Internal(anyhow::anyhow!("inner error")))
        }
    }

    /// A 50 ms timeout wrapped around a dispatcher that needs 10 s to answer.
    fn timing_out_dispatcher(server_name: &str) -> TimeoutDispatcher {
        let slow = SlowDispatcher {
            delay: Duration::from_secs(10),
        };
        TimeoutDispatcher::new(Arc::new(slow), Duration::from_millis(50), server_name)
    }

    // ---- resource dispatcher fixtures ------------------------------------

    /// Resource dispatcher that answers immediately, echoing the URI.
    struct InstantResourceDispatcher;

    #[async_trait::async_trait]
    impl ResourceDispatcher for InstantResourceDispatcher {
        async fn read_resource(
            &self,
            _server: &str,
            uri: &str,
        ) -> Result<serde_json::Value, DispatchError> {
            Ok(json!({"uri": uri}))
        }
    }

    /// Resource dispatcher that sleeps for 10 s before answering.
    struct SlowResourceDispatcher;

    #[async_trait::async_trait]
    impl ResourceDispatcher for SlowResourceDispatcher {
        async fn read_resource(
            &self,
            _server: &str,
            _uri: &str,
        ) -> Result<serde_json::Value, DispatchError> {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(json!({}))
        }
    }

    // ---- tool dispatcher tests -------------------------------------------

    #[tokio::test]
    async fn fast_call_passes_through() {
        let td = TimeoutDispatcher::new(
            Arc::new(InstantDispatcher),
            Duration::from_secs(5),
            "test-server",
        );

        let outcome = td.call_tool("test-server", "echo", json!({})).await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap()["tool"], "echo");
    }

    #[tokio::test]
    async fn slow_call_times_out() {
        let td = timing_out_dispatcher("slow-server");

        let outcome = td.call_tool("slow-server", "scan", json!({})).await;

        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(DispatchError::Timeout { .. })));
    }

    #[tokio::test]
    async fn timeout_error_message_contains_context() {
        let td = timing_out_dispatcher("narsil");

        let err = td
            .call_tool("narsil", "symbols.find", json!({}))
            .await
            .unwrap_err();

        // The error must carry the configured server name...
        assert!(matches!(&err, DispatchError::Timeout { server, .. } if server == "narsil"));
        // ...and its rendered message must mention both the cause and the server.
        let msg = err.to_string();
        assert!(msg.contains("timeout"), "should mention timeout: {msg}");
        assert!(msg.contains("narsil"), "should mention server name: {msg}");
    }

    #[tokio::test]
    async fn inner_error_preserved() {
        let td = TimeoutDispatcher::new(
            Arc::new(FailingDispatcher),
            Duration::from_secs(5),
            "test",
        );

        let err = td.call_tool("test", "tool", json!({})).await.unwrap_err();

        assert!(matches!(err, DispatchError::Internal(_)));
    }

    // ---- resource dispatcher tests ---------------------------------------

    #[tokio::test]
    async fn rs_c07_timeout_wraps_resource_reads() {
        // A read that completes well inside the limit passes through.
        let fast = TimeoutResourceDispatcher::new(
            Arc::new(InstantResourceDispatcher),
            Duration::from_secs(5),
            "fast-server",
        );
        let fast_outcome = fast.read_resource("fast-server", "file:///log").await;
        assert!(fast_outcome.is_ok());

        // A read that outlives the limit surfaces as a timeout error.
        let slow = TimeoutResourceDispatcher::new(
            Arc::new(SlowResourceDispatcher),
            Duration::from_millis(50),
            "slow-server",
        );
        let slow_outcome = slow.read_resource("slow-server", "file:///log").await;
        assert!(matches!(slow_outcome, Err(DispatchError::Timeout { .. })));
    }
}