Skip to main content

agy_bridge/config/
mcp.rs

1//! MCP (Model Context Protocol) server configuration types.
2
3use serde::{Deserialize, Serialize};
4
5use super::{default_mcp_sse_read_timeout, default_mcp_timeout, default_true};
6
7/// Configuration for an MCP server connected via stdio.
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct McpStdioServer {
10    /// The command to run to start the server.
11    pub command: String,
12    /// Arguments to pass to the command.
13    #[serde(default)]
14    pub args: Vec<String>,
15}
16
17/// Configuration for an MCP server connected via SSE.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct McpSseServer {
20    /// The URL of the SSE endpoint.
21    pub url: String,
22    /// Optional headers to send with the connection request.
23    #[serde(default)]
24    pub headers: Option<std::collections::HashMap<String, String>>,
25}
26
27/// Configuration for an MCP server connected via Streamable HTTP.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct McpStreamableHttpServer {
30    /// The URL of the HTTP endpoint.
31    pub url: String,
32    /// Optional headers to send with the connection request.
33    #[serde(default)]
34    pub headers: Option<std::collections::HashMap<String, String>>,
35    /// Connection timeout in seconds.
36    #[serde(default = "default_mcp_timeout")]
37    pub timeout: f64,
38    /// SSE read timeout in seconds.
39    #[serde(default = "default_mcp_sse_read_timeout")]
40    pub sse_read_timeout: f64,
41    /// Whether to terminate the connection on close.
42    #[serde(default = "default_true")]
43    pub terminate_on_close: bool,
44}
45
46/// An MCP server, identified by its transport.
47///
48/// All MCP transports speak JSON-RPC 2.0; the variants describe *how* the
49/// client connects to the server process.
50///
51/// Use the convenience constructors [`McpServer::stdio`], [`McpServer::sse`],
52/// and [`McpServer::http`] to avoid importing the inner transport types.
53#[non_exhaustive]
54#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(tag = "type")]
56pub enum McpServer {
57    #[serde(rename = "stdio")]
58    Stdio(McpStdioServer),
59    #[serde(rename = "sse")]
60    Sse(McpSseServer),
61    #[serde(rename = "http")]
62    Http(McpStreamableHttpServer),
63}
64
65impl McpServer {
66    /// Create a stdio-transport MCP server that spawns `command` as a child process.
67    #[must_use]
68    pub fn stdio(command: impl Into<String>) -> McpStdioServer {
69        McpStdioServer::new(command)
70    }
71
72    /// Create an SSE-transport MCP server at the given `url`.
73    #[must_use]
74    pub fn sse(url: impl Into<String>) -> McpSseServer {
75        McpSseServer::new(url)
76    }
77
78    /// Create a Streamable-HTTP-transport MCP server at the given `url`.
79    #[must_use]
80    pub fn http(url: impl Into<String>) -> McpStreamableHttpServer {
81        McpStreamableHttpServer::new(url)
82    }
83}
84
85// ─── MCP Server Builders ───────────────────────────────────────────────────
86
87impl From<McpStdioServer> for McpServer {
88    fn from(val: McpStdioServer) -> Self {
89        Self::Stdio(val)
90    }
91}
92
93impl McpStdioServer {
94    /// Create a new Stdio MCP Server configuration.
95    #[must_use]
96    pub fn new(command: impl Into<String>) -> Self {
97        Self {
98            command: command.into(),
99            args: Vec::new(),
100        }
101    }
102
103    /// Add an argument to the command.
104    #[must_use]
105    pub fn arg(mut self, arg: impl Into<String>) -> Self {
106        self.args.push(arg.into());
107        self
108    }
109
110    /// Add multiple arguments to the command at once.
111    #[must_use]
112    pub fn args<I, S>(mut self, args: I) -> Self
113    where
114        I: IntoIterator<Item = S>,
115        S: Into<String>,
116    {
117        self.args.extend(args.into_iter().map(Into::into));
118        self
119    }
120
121    /// Build this stdio configuration into an [`McpServer`].
122    #[must_use]
123    pub fn build(self) -> McpServer {
124        McpServer::Stdio(self)
125    }
126}
127
128impl From<McpSseServer> for McpServer {
129    fn from(val: McpSseServer) -> Self {
130        Self::Sse(val)
131    }
132}
133
134impl McpSseServer {
135    /// Create a new SSE MCP Server configuration.
136    #[must_use]
137    pub fn new(url: impl Into<String>) -> Self {
138        Self {
139            url: url.into(),
140            headers: None,
141        }
142    }
143
144    /// Add a header to the SSE connection.
145    #[must_use]
146    pub fn header(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
147        self.headers
148            .get_or_insert_with(std::collections::HashMap::new)
149            .insert(k.into(), v.into());
150        self
151    }
152
153    /// Build this SSE configuration into an [`McpServer`].
154    #[must_use]
155    pub fn build(self) -> McpServer {
156        McpServer::Sse(self)
157    }
158}
159
160impl From<McpStreamableHttpServer> for McpServer {
161    fn from(val: McpStreamableHttpServer) -> Self {
162        Self::Http(val)
163    }
164}
165
166impl McpStreamableHttpServer {
167    /// Create a new Streamable HTTP MCP Server configuration.
168    #[must_use]
169    pub fn new(url: impl Into<String>) -> Self {
170        Self {
171            url: url.into(),
172            headers: None,
173            timeout: default_mcp_timeout(),
174            sse_read_timeout: default_mcp_sse_read_timeout(),
175            terminate_on_close: true,
176        }
177    }
178
179    /// Add a header to the HTTP connection.
180    #[must_use]
181    pub fn header(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
182        self.headers
183            .get_or_insert_with(std::collections::HashMap::new)
184            .insert(k.into(), v.into());
185        self
186    }
187
188    /// Set the HTTP connection/request timeout in seconds.
189    #[must_use]
190    pub const fn timeout(mut self, timeout: f64) -> Self {
191        self.timeout = timeout;
192        self
193    }
194
195    /// Set the streaming read timeout in seconds.
196    #[must_use]
197    pub const fn sse_read_timeout(mut self, timeout: f64) -> Self {
198        self.sse_read_timeout = timeout;
199        self
200    }
201
202    /// Build this HTTP configuration into an [`McpServer`].
203    #[must_use]
204    pub fn build(self) -> McpServer {
205        McpServer::Http(self)
206    }
207}
208
#[cfg(test)]
mod tests {
    use pyo3::types::PyAnyMethods;

    use super::{
        super::{DEFAULT_MCP_SSE_READ_TIMEOUT_SECS, DEFAULT_MCP_TIMEOUT_SECS},
        *,
    };

    /// Read the `default` of pydantic field `module.class.field` from the
    /// Python SDK as an `f64`.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message if Python's `sys.path` cannot be
    /// configured, or if the module, class, `model_fields`, field, its
    /// `default`, or the `f64` extraction cannot be obtained.
    fn py_pydantic_field_default(module: &str, class: &str, field: &str) -> f64 {
        pyo3::prepare_freethreaded_python();
        pyo3::Python::with_gil(|py| {
            crate::runtime::venv::configure_python_sys_path(py)
                .unwrap_or_else(|e| panic!("Failed to configure python sys.path: {e}"));
            let m = py
                .import_bound(module)
                .unwrap_or_else(|e| panic!("Failed to import {module}: {e}"));
            let cls = m
                .getattr(class)
                .unwrap_or_else(|e| panic!("Failed to get {module}.{class}: {e}"));
            let fields = cls
                .getattr("model_fields")
                .unwrap_or_else(|e| panic!("Failed to get {module}.{class}.model_fields: {e}"));
            let field_info = fields.get_item(field).unwrap_or_else(|e| {
                panic!("Failed to get field '{field}' from {module}.{class}.model_fields: {e}")
            });
            field_info
                .getattr("default")
                .unwrap_or_else(|e| {
                    panic!("Failed to get default for {module}.{class}.{field}: {e}")
                })
                .extract::<f64>()
                .unwrap_or_else(|e| {
                    panic!("Failed to extract {module}.{class}.{field} default as f64: {e}")
                })
        })
    }

    #[test]
    fn mcp_server_config_stdio_roundtrip() {
        let config = McpServer::Stdio(McpStdioServer {
            command: "npx".to_string(),
            args: vec![
                "-y".to_string(),
                "@modelcontextprotocol/server-filesystem".to_string(),
            ],
        });
        let json = serde_json::to_string(&config).unwrap();
        let parsed: McpServer = serde_json::from_str(&json).unwrap();
        match parsed {
            McpServer::Stdio(s) => {
                assert_eq!(s.command, "npx");
                assert_eq!(
                    s.args,
                    vec!["-y", "@modelcontextprotocol/server-filesystem"]
                );
            }
            other => panic!("Expected Stdio, got {other:?}"),
        }
        // Verify the JSON contains the "type" tag from serde.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["type"], "stdio");
    }

    #[test]
    fn mcp_server_config_stdio_defaults() {
        // `args` may be omitted on the wire and then defaults to empty.
        let json = r#"{"type":"stdio","command":"npx"}"#;
        let parsed: McpServer = serde_json::from_str(json).unwrap();
        match parsed {
            McpServer::Stdio(s) => {
                assert_eq!(s.command, "npx");
                assert!(s.args.is_empty());
            }
            other => panic!("Expected Stdio, got {other:?}"),
        }
    }

    #[test]
    fn mcp_server_config_sse_roundtrip() {
        let config = McpServer::Sse(McpSseServer {
            url: "http://localhost:8080/sse".to_string(),
            headers: Some(std::collections::HashMap::from([(
                "Authorization".to_string(),
                "Bearer token123".to_string(),
            )])),
        });
        let json = serde_json::to_string(&config).unwrap();
        let parsed: McpServer = serde_json::from_str(&json).unwrap();
        match parsed {
            McpServer::Sse(s) => {
                assert_eq!(s.url, "http://localhost:8080/sse");
                assert_eq!(
                    s.headers.as_ref().unwrap()["Authorization"],
                    "Bearer token123"
                );
            }
            other => panic!("Expected Sse, got {other:?}"),
        }
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["type"], "sse");
    }

    #[test]
    fn mcp_server_config_sse_defaults() {
        // `headers` may be omitted on the wire and then stays `None`.
        let json = r#"{"type":"sse","url":"http://localhost:8080/sse"}"#;
        let parsed: McpServer = serde_json::from_str(json).unwrap();
        match parsed {
            McpServer::Sse(s) => {
                assert_eq!(s.url, "http://localhost:8080/sse");
                assert!(s.headers.is_none());
            }
            other => panic!("Expected Sse, got {other:?}"),
        }
    }

    #[test]
    fn mcp_server_config_http_roundtrip() {
        let config = McpServer::Http(McpStreamableHttpServer {
            url: "http://localhost:9090/mcp".to_string(),
            headers: None,
            timeout: 60.0,
            sse_read_timeout: 120.0,
            terminate_on_close: false,
        });
        let json = serde_json::to_string(&config).unwrap();
        let parsed: McpServer = serde_json::from_str(&json).unwrap();
        match parsed {
            McpServer::Http(s) => {
                assert_eq!(s.url, "http://localhost:9090/mcp");
                assert!(s.headers.is_none());
                assert!((s.timeout - 60.0).abs() < f64::EPSILON);
                assert!((s.sse_read_timeout - 120.0).abs() < f64::EPSILON);
                assert!(!s.terminate_on_close);
            }
            other => panic!("Expected Http, got {other:?}"),
        }
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["type"], "http");
    }

    #[test]
    fn mcp_server_config_http_defaults_roundtrip() {
        // Deserialize with only required fields. The serde defaults must equal
        // the shared constants, which the `*_matches_python_sdk` tests pin to
        // the Python SDK.
        let json = r#"{"type":"http","url":"http://example.com/mcp"}"#;
        let parsed: McpServer = serde_json::from_str(json).unwrap();
        match parsed {
            McpServer::Http(s) => {
                assert_eq!(s.url, "http://example.com/mcp");
                assert!(s.headers.is_none());
                assert!((s.timeout - DEFAULT_MCP_TIMEOUT_SECS).abs() < f64::EPSILON);
                assert!(
                    (s.sse_read_timeout - DEFAULT_MCP_SSE_READ_TIMEOUT_SECS).abs() < f64::EPSILON
                );
                assert!(s.terminate_on_close);
            }
            other => panic!("Expected Http, got {other:?}"),
        }
    }

    #[test]
    fn mcp_server_config_rejects_unknown_or_missing_type() {
        let unknown = r#"{"type":"websocket","url":"ws://example.com"}"#;
        assert!(serde_json::from_str::<McpServer>(unknown).is_err());
        let missing = r#"{"url":"http://example.com/mcp"}"#;
        assert!(serde_json::from_str::<McpServer>(missing).is_err());
    }

    #[test]
    fn mcp_timeout_matches_python_sdk() {
        let py_val = py_pydantic_field_default(
            "google.antigravity.types",
            "McpStreamableHttpServer",
            "timeout",
        );
        assert!(
            (DEFAULT_MCP_TIMEOUT_SECS - py_val).abs() < f64::EPSILON,
            "Rust DEFAULT_MCP_TIMEOUT_SECS ({DEFAULT_MCP_TIMEOUT_SECS}) != Python SDK ({py_val})"
        );
    }

    #[test]
    fn mcp_sse_read_timeout_matches_python_sdk() {
        let py_val = py_pydantic_field_default(
            "google.antigravity.types",
            "McpStreamableHttpServer",
            "sse_read_timeout",
        );
        assert!(
            (DEFAULT_MCP_SSE_READ_TIMEOUT_SECS - py_val).abs() < f64::EPSILON,
            "Rust DEFAULT_MCP_SSE_READ_TIMEOUT_SECS ({DEFAULT_MCP_SSE_READ_TIMEOUT_SECS}) != Python SDK ({py_val})"
        );
    }

    #[test]
    fn test_mcp_server_builders() {
        let stdio = McpServer::stdio("npx")
            .args(["-y", "@modelcontextprotocol/server-postgres"])
            .build();
        match stdio {
            McpServer::Stdio(s) => {
                assert_eq!(s.command, "npx");
                assert_eq!(s.args, vec!["-y", "@modelcontextprotocol/server-postgres"]);
            }
            other => panic!("Expected Stdio, got {other:?}"),
        }

        let sse = McpServer::sse("http://example.com/sse")
            .header("Auth", "token")
            .build();
        match sse {
            McpServer::Sse(s) => {
                assert_eq!(s.url, "http://example.com/sse");
                assert_eq!(s.headers.as_ref().unwrap()["Auth"], "token");
            }
            other => panic!("Expected Sse, got {other:?}"),
        }

        let http = McpServer::http("http://example.com/http")
            .header("Auth", "token")
            .timeout(10.0)
            .build();
        match http {
            McpServer::Http(s) => {
                assert_eq!(s.url, "http://example.com/http");
                assert_eq!(s.headers.as_ref().unwrap()["Auth"], "token");
                assert!((s.timeout - 10.0).abs() < f64::EPSILON);
            }
            other => panic!("Expected Http, got {other:?}"),
        }
    }

    #[test]
    fn mcp_server_builder_details() {
        // `arg` and `args` append in call order.
        let stdio = McpServer::stdio("uvx")
            .arg("mcp-server-git")
            .args(["--repo", "."]);
        assert_eq!(stdio.command, "uvx");
        assert_eq!(stdio.args, vec!["mcp-server-git", "--repo", "."]);

        // A repeated header name keeps the last value.
        let sse = McpServer::sse("http://example.com/sse")
            .header("Auth", "old")
            .header("Auth", "new");
        {
            let headers = sse.headers.as_ref().unwrap();
            assert_eq!(headers.len(), 1);
            assert_eq!(headers["Auth"], "new");
        }

        // Setting one timeout leaves the other and `terminate_on_close` at
        // their defaults.
        let http = McpServer::http("http://example.com/http").sse_read_timeout(45.0);
        assert!(http.headers.is_none());
        assert!((http.timeout - DEFAULT_MCP_TIMEOUT_SECS).abs() < f64::EPSILON);
        assert!((http.sse_read_timeout - 45.0).abs() < f64::EPSILON);
        assert!(http.terminate_on_close);

        // `From` conversions yield the matching variant, just like `build`.
        assert!(matches!(McpServer::from(stdio), McpServer::Stdio(_)));
        let sse_server: McpServer = sse.into();
        assert!(matches!(sse_server, McpServer::Sse(_)));
        assert!(matches!(McpServer::from(http), McpServer::Http(_)));
    }
}