Skip to main content

forge_client/
timeout.rs

//! Per-server timeout wrappers for tool and resource dispatchers.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use forge_error::DispatchError;
7use forge_sandbox::{ResourceDispatcher, ToolDispatcher};
8use serde_json::Value;
9
10/// A [`ToolDispatcher`] that enforces a per-call timeout on the inner dispatcher.
11pub struct TimeoutDispatcher {
12    inner: Arc<dyn ToolDispatcher>,
13    timeout: Duration,
14    server_name: String,
15}
16
17impl TimeoutDispatcher {
18    /// Wrap an inner dispatcher with a per-call timeout.
19    pub fn new(
20        inner: Arc<dyn ToolDispatcher>,
21        timeout: Duration,
22        server_name: impl Into<String>,
23    ) -> Self {
24        Self {
25            inner,
26            timeout,
27            server_name: server_name.into(),
28        }
29    }
30}
31
32#[async_trait::async_trait]
33impl ToolDispatcher for TimeoutDispatcher {
34    #[tracing::instrument(skip(self, args), fields(server, tool))]
35    async fn call_tool(
36        &self,
37        server: &str,
38        tool: &str,
39        args: Value,
40    ) -> Result<Value, DispatchError> {
41        match tokio::time::timeout(self.timeout, self.inner.call_tool(server, tool, args)).await {
42            Ok(result) => result,
43            Err(_elapsed) => Err(DispatchError::Timeout {
44                server: self.server_name.clone(),
45                timeout_ms: self.timeout.as_millis() as u64,
46            }),
47        }
48    }
49}
50
51/// A [`ResourceDispatcher`] that enforces a per-call timeout on the inner dispatcher.
52pub struct TimeoutResourceDispatcher {
53    inner: Arc<dyn ResourceDispatcher>,
54    timeout: Duration,
55    server_name: String,
56}
57
58impl TimeoutResourceDispatcher {
59    /// Wrap an inner resource dispatcher with a per-call timeout.
60    pub fn new(
61        inner: Arc<dyn ResourceDispatcher>,
62        timeout: Duration,
63        server_name: impl Into<String>,
64    ) -> Self {
65        Self {
66            inner,
67            timeout,
68            server_name: server_name.into(),
69        }
70    }
71}
72
73#[async_trait::async_trait]
74impl ResourceDispatcher for TimeoutResourceDispatcher {
75    #[tracing::instrument(skip(self), fields(server, uri))]
76    async fn read_resource(
77        &self,
78        server: &str,
79        uri: &str,
80    ) -> Result<serde_json::Value, DispatchError> {
81        match tokio::time::timeout(self.timeout, self.inner.read_resource(server, uri)).await {
82            Ok(result) => result,
83            Err(_elapsed) => Err(DispatchError::Timeout {
84                server: self.server_name.clone(),
85                timeout_ms: self.timeout.as_millis() as u64,
86            }),
87        }
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    // --- Tool dispatcher fixtures ---

    /// Succeeds immediately, echoing the requested tool name.
    struct InstantDispatcher;

    #[async_trait::async_trait]
    impl ToolDispatcher for InstantDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            Ok(serde_json::json!({"tool": tool, "status": "ok"}))
        }
    }

    /// Sleeps for `delay` before succeeding; used to trigger timeouts.
    struct SlowDispatcher {
        delay: Duration,
    }

    #[async_trait::async_trait]
    impl ToolDispatcher for SlowDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            _tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            tokio::time::sleep(self.delay).await;
            Ok(serde_json::json!({"status": "ok"}))
        }
    }

    /// Always fails with an internal error.
    struct FailingDispatcher;

    #[async_trait::async_trait]
    impl ToolDispatcher for FailingDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            _tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            Err(DispatchError::Internal(anyhow::anyhow!("inner error")))
        }
    }

    #[tokio::test]
    async fn fast_call_passes_through() {
        let inner = Arc::new(InstantDispatcher);
        let td = TimeoutDispatcher::new(inner, Duration::from_secs(5), "test-server");
        let value = td
            .call_tool("test-server", "echo", serde_json::json!({}))
            .await
            .expect("fast call should succeed");
        assert_eq!(value["tool"], "echo");
    }

    #[tokio::test]
    async fn slow_call_times_out() {
        let inner = Arc::new(SlowDispatcher {
            delay: Duration::from_secs(10),
        });
        let td = TimeoutDispatcher::new(inner, Duration::from_millis(50), "slow-server");
        let result = td
            .call_tool("slow-server", "scan", serde_json::json!({}))
            .await;
        // The error must carry the configured server label and timeout.
        match result {
            Err(DispatchError::Timeout {
                server, timeout_ms, ..
            }) => {
                assert_eq!(server, "slow-server");
                assert_eq!(timeout_ms, 50);
            }
            other => panic!("expected DispatchError::Timeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn timeout_error_message_contains_context() {
        let inner = Arc::new(SlowDispatcher {
            delay: Duration::from_secs(10),
        });
        let td = TimeoutDispatcher::new(inner, Duration::from_millis(50), "narsil");
        let err = td
            .call_tool("narsil", "symbols.find", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(matches!(err, DispatchError::Timeout { ref server, .. } if server == "narsil"));
        let msg = err.to_string();
        assert!(msg.contains("timeout"), "should mention timeout: {msg}");
        assert!(msg.contains("narsil"), "should mention server name: {msg}");
    }

    #[tokio::test]
    async fn inner_error_preserved() {
        let inner = Arc::new(FailingDispatcher);
        let td = TimeoutDispatcher::new(inner, Duration::from_secs(5), "test");
        let err = td
            .call_tool("test", "tool", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(matches!(err, DispatchError::Internal(_)));
    }

    // --- v0.2 Resource Timeout Test (RS-C07) ---

    /// Succeeds immediately, echoing the requested URI.
    struct InstantResourceDispatcher;

    #[async_trait::async_trait]
    impl ResourceDispatcher for InstantResourceDispatcher {
        async fn read_resource(
            &self,
            _server: &str,
            uri: &str,
        ) -> Result<serde_json::Value, DispatchError> {
            Ok(serde_json::json!({"uri": uri}))
        }
    }

    /// Sleeps far longer than any test timeout before succeeding.
    struct SlowResourceDispatcher;

    #[async_trait::async_trait]
    impl ResourceDispatcher for SlowResourceDispatcher {
        async fn read_resource(
            &self,
            _server: &str,
            _uri: &str,
        ) -> Result<serde_json::Value, DispatchError> {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(serde_json::json!({}))
        }
    }

    /// Always fails with an internal error.
    struct FailingResourceDispatcher;

    #[async_trait::async_trait]
    impl ResourceDispatcher for FailingResourceDispatcher {
        async fn read_resource(
            &self,
            _server: &str,
            _uri: &str,
        ) -> Result<serde_json::Value, DispatchError> {
            Err(DispatchError::Internal(anyhow::anyhow!("inner resource error")))
        }
    }

    #[tokio::test]
    async fn rs_c07_timeout_wraps_resource_reads() {
        // Fast read succeeds and returns the inner value untouched.
        let fast = TimeoutResourceDispatcher::new(
            Arc::new(InstantResourceDispatcher),
            Duration::from_secs(5),
            "fast-server",
        );
        let value = fast
            .read_resource("fast-server", "file:///log")
            .await
            .expect("fast read should succeed");
        assert_eq!(value["uri"], "file:///log");

        // Slow read times out, labelled with the wrapper's server name.
        let slow = TimeoutResourceDispatcher::new(
            Arc::new(SlowResourceDispatcher),
            Duration::from_millis(50),
            "slow-server",
        );
        let result = slow.read_resource("slow-server", "file:///log").await;
        match result {
            Err(DispatchError::Timeout {
                server, timeout_ms, ..
            }) => {
                assert_eq!(server, "slow-server");
                assert_eq!(timeout_ms, 50);
            }
            other => panic!("expected DispatchError::Timeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn inner_resource_error_preserved() {
        let td = TimeoutResourceDispatcher::new(
            Arc::new(FailingResourceDispatcher),
            Duration::from_secs(5),
            "test",
        );
        let err = td.read_resource("test", "file:///log").await.unwrap_err();
        assert!(matches!(err, DispatchError::Internal(_)));
    }
}