1use serde::{Deserialize, Serialize};
4
5use super::{default_mcp_sse_read_timeout, default_mcp_timeout, default_true};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct McpStdioServer {
10 pub command: String,
12 #[serde(default)]
14 pub args: Vec<String>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct McpSseServer {
20 pub url: String,
22 #[serde(default)]
24 pub headers: Option<std::collections::HashMap<String, String>>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct McpStreamableHttpServer {
30 pub url: String,
32 #[serde(default)]
34 pub headers: Option<std::collections::HashMap<String, String>>,
35 #[serde(default = "default_mcp_timeout")]
37 pub timeout: f64,
38 #[serde(default = "default_mcp_sse_read_timeout")]
40 pub sse_read_timeout: f64,
41 #[serde(default = "default_true")]
43 pub terminate_on_close: bool,
44}
45
46#[non_exhaustive]
54#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(tag = "type")]
56pub enum McpServer {
57 #[serde(rename = "stdio")]
58 Stdio(McpStdioServer),
59 #[serde(rename = "sse")]
60 Sse(McpSseServer),
61 #[serde(rename = "http")]
62 Http(McpStreamableHttpServer),
63}
64
65impl McpServer {
66 #[must_use]
68 pub fn stdio(command: impl Into<String>) -> McpStdioServer {
69 McpStdioServer::new(command)
70 }
71
72 #[must_use]
74 pub fn sse(url: impl Into<String>) -> McpSseServer {
75 McpSseServer::new(url)
76 }
77
78 #[must_use]
80 pub fn http(url: impl Into<String>) -> McpStreamableHttpServer {
81 McpStreamableHttpServer::new(url)
82 }
83}
84
85impl From<McpStdioServer> for McpServer {
88 fn from(val: McpStdioServer) -> Self {
89 Self::Stdio(val)
90 }
91}
92
93impl McpStdioServer {
94 #[must_use]
96 pub fn new(command: impl Into<String>) -> Self {
97 Self {
98 command: command.into(),
99 args: Vec::new(),
100 }
101 }
102
103 #[must_use]
105 pub fn arg(mut self, arg: impl Into<String>) -> Self {
106 self.args.push(arg.into());
107 self
108 }
109
110 #[must_use]
112 pub fn args<I, S>(mut self, args: I) -> Self
113 where
114 I: IntoIterator<Item = S>,
115 S: Into<String>,
116 {
117 self.args.extend(args.into_iter().map(Into::into));
118 self
119 }
120
121 #[must_use]
123 pub fn build(self) -> McpServer {
124 McpServer::Stdio(self)
125 }
126}
127
128impl From<McpSseServer> for McpServer {
129 fn from(val: McpSseServer) -> Self {
130 Self::Sse(val)
131 }
132}
133
134impl McpSseServer {
135 #[must_use]
137 pub fn new(url: impl Into<String>) -> Self {
138 Self {
139 url: url.into(),
140 headers: None,
141 }
142 }
143
144 #[must_use]
146 pub fn header(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
147 self.headers
148 .get_or_insert_with(std::collections::HashMap::new)
149 .insert(k.into(), v.into());
150 self
151 }
152
153 #[must_use]
155 pub fn build(self) -> McpServer {
156 McpServer::Sse(self)
157 }
158}
159
160impl From<McpStreamableHttpServer> for McpServer {
161 fn from(val: McpStreamableHttpServer) -> Self {
162 Self::Http(val)
163 }
164}
165
166impl McpStreamableHttpServer {
167 #[must_use]
169 pub fn new(url: impl Into<String>) -> Self {
170 Self {
171 url: url.into(),
172 headers: None,
173 timeout: default_mcp_timeout(),
174 sse_read_timeout: default_mcp_sse_read_timeout(),
175 terminate_on_close: true,
176 }
177 }
178
179 #[must_use]
181 pub fn header(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
182 self.headers
183 .get_or_insert_with(std::collections::HashMap::new)
184 .insert(k.into(), v.into());
185 self
186 }
187
188 #[must_use]
190 pub const fn timeout(mut self, timeout: f64) -> Self {
191 self.timeout = timeout;
192 self
193 }
194
195 #[must_use]
197 pub const fn sse_read_timeout(mut self, timeout: f64) -> Self {
198 self.sse_read_timeout = timeout;
199 self
200 }
201
202 #[must_use]
204 pub fn build(self) -> McpServer {
205 McpServer::Http(self)
206 }
207}
208
#[cfg(test)]
mod tests {
    use pyo3::types::PyAnyMethods;

    use super::{
        super::{DEFAULT_MCP_SSE_READ_TIMEOUT_SECS, DEFAULT_MCP_TIMEOUT_SECS},
        *,
    };

    /// Reads the default value of pydantic field `field` on `module.class` as an `f64`.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message if sys.path configuration, the import, any attribute
    /// lookup, or the `f64` extraction fails.
    fn py_pydantic_field_default(module: &str, class: &str, field: &str) -> f64 {
        pyo3::prepare_freethreaded_python();
        pyo3::Python::with_gil(|py| {
            crate::runtime::venv::configure_python_sys_path(py)
                .unwrap_or_else(|e| panic!("Failed to configure python sys.path: {e}"));
            let m = py
                .import_bound(module)
                .unwrap_or_else(|e| panic!("Failed to import {module}: {e}"));
            let cls = m
                .getattr(class)
                .unwrap_or_else(|e| panic!("Failed to get {module}.{class}: {e}"));
            let fields = cls
                .getattr("model_fields")
                .unwrap_or_else(|e| panic!("Failed to get {module}.{class}.model_fields: {e}"));
            let field_info = fields.get_item(field).unwrap_or_else(|e| {
                panic!("Failed to get field '{field}' from {module}.{class}.model_fields: {e}")
            });
            field_info
                .getattr("default")
                .unwrap_or_else(|e| {
                    panic!("Failed to get default for {module}.{class}.{field}: {e}")
                })
                .extract::<f64>()
                .unwrap_or_else(|e| {
                    panic!("Failed to extract {module}.{class}.{field} default as f64: {e}")
                })
        })
    }

    #[test]
    fn mcp_server_config_stdio_roundtrip() {
        let config = McpServer::Stdio(McpStdioServer {
            command: "npx".to_string(),
            args: vec![
                "-y".to_string(),
                "@modelcontextprotocol/server-filesystem".to_string(),
            ],
        });
        let json = serde_json::to_string(&config).unwrap();
        let parsed: McpServer = serde_json::from_str(&json).unwrap();
        match parsed {
            McpServer::Stdio(s) => {
                assert_eq!(s.command, "npx");
                assert_eq!(
                    s.args,
                    vec!["-y", "@modelcontextprotocol/server-filesystem"]
                );
            }
            other => panic!("Expected Stdio, got {other:?}"),
        }
        // The transport tag is emitted as the `type` field.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["type"], "stdio");
    }

    #[test]
    fn mcp_server_config_sse_roundtrip() {
        let config = McpServer::Sse(McpSseServer {
            url: "http://localhost:8080/sse".to_string(),
            headers: Some(std::collections::HashMap::from([(
                "Authorization".to_string(),
                "Bearer token123".to_string(),
            )])),
        });
        let json = serde_json::to_string(&config).unwrap();
        let parsed: McpServer = serde_json::from_str(&json).unwrap();
        match parsed {
            McpServer::Sse(s) => {
                assert_eq!(s.url, "http://localhost:8080/sse");
                assert_eq!(
                    s.headers.as_ref().unwrap()["Authorization"],
                    "Bearer token123"
                );
            }
            other => panic!("Expected Sse, got {other:?}"),
        }
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["type"], "sse");
    }

    #[test]
    fn mcp_server_config_http_roundtrip() {
        let config = McpServer::Http(McpStreamableHttpServer {
            url: "http://localhost:9090/mcp".to_string(),
            headers: None,
            timeout: 60.0,
            sse_read_timeout: 120.0,
            terminate_on_close: false,
        });
        let json = serde_json::to_string(&config).unwrap();
        let parsed: McpServer = serde_json::from_str(&json).unwrap();
        match parsed {
            McpServer::Http(s) => {
                assert_eq!(s.url, "http://localhost:9090/mcp");
                assert!(s.headers.is_none());
                assert!((s.timeout - 60.0).abs() < f64::EPSILON);
                assert!((s.sse_read_timeout - 120.0).abs() < f64::EPSILON);
                assert!(!s.terminate_on_close);
            }
            other => panic!("Expected Http, got {other:?}"),
        }
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["type"], "http");
    }

    #[test]
    fn mcp_server_config_http_defaults_roundtrip() {
        // Only the required fields are present; serde must fill in the rest.
        let json = r#"{"type":"http","url":"http://example.com/mcp"}"#;
        let parsed: McpServer = serde_json::from_str(json).unwrap();
        match parsed {
            McpServer::Http(s) => {
                assert_eq!(s.url, "http://example.com/mcp");
                assert!(s.headers.is_none());
                // Compare against the shared constants rather than duplicated literals.
                assert!((s.timeout - DEFAULT_MCP_TIMEOUT_SECS).abs() < f64::EPSILON);
                assert!(
                    (s.sse_read_timeout - DEFAULT_MCP_SSE_READ_TIMEOUT_SECS).abs() < f64::EPSILON
                );
                assert!(s.terminate_on_close);
            }
            other => panic!("Expected Http, got {other:?}"),
        }
    }

    #[test]
    fn mcp_timeout_matches_python_sdk() {
        let py_val = py_pydantic_field_default(
            "google.antigravity.types",
            "McpStreamableHttpServer",
            "timeout",
        );
        assert!(
            (DEFAULT_MCP_TIMEOUT_SECS - py_val).abs() < f64::EPSILON,
            "Rust DEFAULT_MCP_TIMEOUT_SECS ({DEFAULT_MCP_TIMEOUT_SECS}) != Python SDK ({py_val})"
        );
    }

    #[test]
    fn mcp_sse_read_timeout_matches_python_sdk() {
        let py_val = py_pydantic_field_default(
            "google.antigravity.types",
            "McpStreamableHttpServer",
            "sse_read_timeout",
        );
        assert!(
            (DEFAULT_MCP_SSE_READ_TIMEOUT_SECS - py_val).abs() < f64::EPSILON,
            "Rust DEFAULT_MCP_SSE_READ_TIMEOUT_SECS ({DEFAULT_MCP_SSE_READ_TIMEOUT_SECS}) != Python SDK ({py_val})"
        );
    }

    #[test]
    fn test_mcp_server_builders() {
        let stdio = McpServer::stdio("npx")
            .args(["-y", "@modelcontextprotocol/server-postgres"])
            .build();
        match stdio {
            McpServer::Stdio(s) => {
                assert_eq!(s.command, "npx");
                assert_eq!(s.args, vec!["-y", "@modelcontextprotocol/server-postgres"]);
            }
            other => panic!("Expected Stdio, got {other:?}"),
        }

        let sse = McpServer::sse("http://example.com/sse")
            .header("Auth", "token")
            .build();
        match sse {
            McpServer::Sse(s) => {
                assert_eq!(s.url, "http://example.com/sse");
                assert_eq!(s.headers.as_ref().unwrap()["Auth"], "token");
            }
            other => panic!("Expected Sse, got {other:?}"),
        }

        let http = McpServer::http("http://example.com/http")
            .header("Auth", "token")
            .timeout(10.0)
            .build();
        match http {
            McpServer::Http(s) => {
                assert_eq!(s.url, "http://example.com/http");
                assert_eq!(s.headers.as_ref().unwrap()["Auth"], "token");
                assert!((s.timeout - 10.0).abs() < f64::EPSILON);
                // Fields not touched by the builder keep the same defaults serde uses.
                assert!(
                    (s.sse_read_timeout - DEFAULT_MCP_SSE_READ_TIMEOUT_SECS).abs() < f64::EPSILON
                );
                assert!(s.terminate_on_close);
            }
            other => panic!("Expected Http, got {other:?}"),
        }
    }
}