1use super::AppState;
15use super::mcp_discovery::{McpDiscovery, read_construct_mcp};
16use crate::config::schema::{McpServerConfig, McpTransport};
17use crate::tools::mcp_client::McpServer;
18use axum::{
19 body::Body,
20 extract::{Path, State},
21 http::{HeaderMap, StatusCode, header},
22 response::{IntoResponse, Json, Response},
23};
24use futures_util::StreamExt;
25use serde::Deserialize;
26use serde_json::{Value, json};
27use std::collections::HashMap;
28use std::time::{Duration, Instant};
29use tokio::time::timeout;
30
/// Derives the health endpoint URL from the URL stored in a discovery record.
///
/// Trailing slashes are removed first; if the remainder ends in `/mcp`, that
/// component is dropped as well. `/health` is then appended to what is left.
fn health_url_from_discovery(url: &str) -> String {
    let without_slash = url.trim_end_matches('/');
    let root = without_slash.strip_suffix("/mcp").unwrap_or(without_slash);
    format!("{root}/health")
}
41
42#[async_trait::async_trait]
44pub trait HealthProbe: Send + Sync {
45 async fn get_health(&self, url: &str) -> Result<Value, String>;
46}
47
48pub struct ReqwestHealthProbe;
50
51#[async_trait::async_trait]
52impl HealthProbe for ReqwestHealthProbe {
53 async fn get_health(&self, url: &str) -> Result<Value, String> {
54 let client = reqwest::Client::builder()
55 .timeout(Duration::from_millis(500))
56 .build()
57 .map_err(|e| e.to_string())?;
58 let resp = client.get(url).send().await.map_err(|e| e.to_string())?;
59 if !resp.status().is_success() {
60 return Err(format!("health status {}", resp.status()));
61 }
62 resp.json::<Value>().await.map_err(|e| e.to_string())
63 }
64}
65
66pub async fn build_discovery_payload(
68 discovery: Option<McpDiscovery>,
69 probe: &dyn HealthProbe,
70) -> Value {
71 let Some(d) = discovery else {
72 return json!({
73 "available": false,
74 "reason": "discovery file missing",
75 });
76 };
77 let health_url = health_url_from_discovery(&d.url);
78 match probe.get_health(&health_url).await {
79 Ok(health) => json!({
80 "available": true,
81 "url": d.url,
82 "health": health,
83 }),
84 Err(_) => json!({
85 "available": false,
86 "reason": "health check failed",
87 }),
88 }
89}
90
91pub async fn handle_api_mcp_discovery(
93 State(state): State<AppState>,
94 headers: HeaderMap,
95) -> impl IntoResponse {
96 if let Err(e) = super::api::require_auth(&state, &headers) {
97 return e.into_response();
98 }
99
100 let discovery = read_construct_mcp().ok();
101 let payload = build_discovery_payload(discovery, &ReqwestHealthProbe).await;
102 (StatusCode::OK, Json(payload)).into_response()
103}
104
/// Concatenates `base` (with any trailing slashes removed) and `path`.
///
/// `path` is appended verbatim, so callers pass it with its leading `/`.
fn join_mcp_url(base: &str, path: &str) -> String {
    [base.trim_end_matches('/'), path].concat()
}
122
123fn mcp_upstream_url(state: &AppState, path: &str) -> Option<String> {
126 let base = state.mcp_local_url.as_ref()?;
127 Some(join_mcp_url(base, path))
128}
129
130fn mcp_unavailable() -> Response {
133 (
134 StatusCode::SERVICE_UNAVAILABLE,
135 Json(json!({
136 "available": false,
137 "reason": "mcp server not bound",
138 })),
139 )
140 .into_response()
141}
142
143pub async fn handle_api_mcp_health(State(state): State<AppState>, headers: HeaderMap) -> Response {
146 if let Err(e) = super::api::require_auth(&state, &headers) {
147 return e.into_response();
148 }
149 let Some(url) = mcp_upstream_url(&state, "/health") else {
150 return mcp_unavailable();
151 };
152 let client = match reqwest::Client::builder()
153 .timeout(Duration::from_secs(5))
154 .build()
155 {
156 Ok(c) => c,
157 Err(e) => {
158 tracing::warn!("api_mcp: build client failed: {e}");
159 return (StatusCode::INTERNAL_SERVER_ERROR, "client build failed").into_response();
160 }
161 };
162 match client.get(&url).send().await {
163 Ok(resp) => {
164 let status =
165 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
166 let ct = resp
167 .headers()
168 .get(header::CONTENT_TYPE)
169 .and_then(|v| v.to_str().ok())
170 .unwrap_or("application/json")
171 .to_string();
172 let body = resp.bytes().await.unwrap_or_default();
173 (status, [(header::CONTENT_TYPE, ct)], body).into_response()
174 }
175 Err(e) => {
176 tracing::warn!("api_mcp: health upstream error: {e}");
177 (StatusCode::BAD_GATEWAY, "mcp upstream error").into_response()
178 }
179 }
180}
181
182pub async fn handle_api_mcp_session_create(
188 State(state): State<AppState>,
189 headers: HeaderMap,
190 body: axum::body::Bytes,
191) -> Response {
192 if let Err(e) = super::api::require_auth(&state, &headers) {
193 return e.into_response();
194 }
195 let Some(url) = mcp_upstream_url(&state, "/session") else {
196 return mcp_unavailable();
197 };
198 let client = match reqwest::Client::builder()
199 .timeout(Duration::from_secs(10))
200 .build()
201 {
202 Ok(c) => c,
203 Err(e) => {
204 tracing::warn!("api_mcp: build client failed: {e}");
205 return (StatusCode::INTERNAL_SERVER_ERROR, "client build failed").into_response();
206 }
207 };
208 let mut req = client.post(&url).body(body.to_vec());
209 if let Some(ct) = headers
210 .get(header::CONTENT_TYPE)
211 .and_then(|v| v.to_str().ok())
212 {
213 req = req.header(header::CONTENT_TYPE, ct);
214 } else {
215 req = req.header(header::CONTENT_TYPE, "application/json");
216 }
217 match req.send().await {
218 Ok(resp) => {
219 let status =
220 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
221 let ct = resp
222 .headers()
223 .get(header::CONTENT_TYPE)
224 .and_then(|v| v.to_str().ok())
225 .unwrap_or("application/json")
226 .to_string();
227 let bytes = resp.bytes().await.unwrap_or_default();
228 (status, [(header::CONTENT_TYPE, ct)], bytes).into_response()
229 }
230 Err(e) => {
231 tracing::warn!("api_mcp: session upstream error: {e}");
232 (StatusCode::BAD_GATEWAY, "mcp upstream error").into_response()
233 }
234 }
235}
236
237pub async fn handle_api_mcp_call(
239 State(state): State<AppState>,
240 headers: HeaderMap,
241 body: axum::body::Bytes,
242) -> Response {
243 if let Err(e) = super::api::require_auth(&state, &headers) {
244 return e.into_response();
245 }
246 let Some(url) = mcp_upstream_url(&state, "/mcp") else {
247 return mcp_unavailable();
248 };
249 let client = match reqwest::Client::builder()
252 .timeout(Duration::from_secs(120))
253 .build()
254 {
255 Ok(c) => c,
256 Err(e) => {
257 tracing::warn!("api_mcp: build client failed: {e}");
258 return (StatusCode::INTERNAL_SERVER_ERROR, "client build failed").into_response();
259 }
260 };
261 let mut req = client.post(&url).body(body.to_vec());
262 if let Some(ct) = headers
263 .get(header::CONTENT_TYPE)
264 .and_then(|v| v.to_str().ok())
265 {
266 req = req.header(header::CONTENT_TYPE, ct);
267 } else {
268 req = req.header(header::CONTENT_TYPE, "application/json");
269 }
270 if let Some(auth) = headers
271 .get(header::AUTHORIZATION)
272 .and_then(|v| v.to_str().ok())
273 {
274 req = req.header(header::AUTHORIZATION, auth);
279 }
280 match req.send().await {
281 Ok(resp) => {
282 let status =
283 StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
284 let ct = resp
285 .headers()
286 .get(header::CONTENT_TYPE)
287 .and_then(|v| v.to_str().ok())
288 .unwrap_or("application/json")
289 .to_string();
290 let bytes = resp.bytes().await.unwrap_or_default();
291 (status, [(header::CONTENT_TYPE, ct)], bytes).into_response()
292 }
293 Err(e) => {
294 tracing::warn!("api_mcp: call upstream error: {e}");
295 (StatusCode::BAD_GATEWAY, "mcp upstream error").into_response()
296 }
297 }
298}
299
300pub async fn handle_api_mcp_session_events(
306 State(state): State<AppState>,
307 headers: HeaderMap,
308 Path(session_id): Path<String>,
309) -> Response {
310 if let Err(e) = super::api::require_auth(&state, &headers) {
311 return e.into_response();
312 }
313 let Some(url) = mcp_upstream_url(&state, &format!("/session/{session_id}/events")) else {
314 return mcp_unavailable();
315 };
316 let client = match reqwest::Client::builder()
317 .connect_timeout(Duration::from_secs(5))
318 .build()
320 {
321 Ok(c) => c,
322 Err(e) => {
323 tracing::warn!("api_mcp: build sse client failed: {e}");
324 return (StatusCode::INTERNAL_SERVER_ERROR, "client build failed").into_response();
325 }
326 };
327 let mut req = client.get(&url).header(header::ACCEPT, "text/event-stream");
328 if let Some(auth) = headers
329 .get(header::AUTHORIZATION)
330 .and_then(|v| v.to_str().ok())
331 {
332 req = req.header(header::AUTHORIZATION, auth);
333 }
334 let upstream = match req.send().await {
335 Ok(r) => r,
336 Err(e) => {
337 tracing::warn!("api_mcp: sse upstream connect failed: {e}");
338 return (StatusCode::BAD_GATEWAY, "mcp upstream error").into_response();
339 }
340 };
341 if !upstream.status().is_success() {
342 let status =
343 StatusCode::from_u16(upstream.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
344 let body = upstream.bytes().await.unwrap_or_default();
345 return (status, body).into_response();
346 }
347 let byte_stream = upstream
348 .bytes_stream()
349 .map(|r| r.map_err(std::io::Error::other));
350 Response::builder()
351 .status(StatusCode::OK)
352 .header(header::CONTENT_TYPE, "text/event-stream")
353 .header(header::CACHE_CONTROL, "no-cache")
354 .header("x-accel-buffering", "no")
355 .body(Body::from_stream(byte_stream))
356 .unwrap_or_else(|_| {
357 (StatusCode::INTERNAL_SERVER_ERROR, "response build failed").into_response()
358 })
359}
360
/// Upper bound, in seconds, on the `McpServer::connect` call made by the
/// server-test endpoint; also quoted in its "timed out" error message.
const TEST_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
370
371#[derive(Debug, Deserialize)]
375pub struct TestServerRequest {
376 pub name: String,
377 pub transport: String,
378 #[serde(default)]
379 pub command: Option<String>,
380 #[serde(default)]
381 pub args: Option<Vec<String>>,
382 #[serde(default)]
383 pub env: Option<HashMap<String, String>>,
384 #[serde(default)]
385 pub url: Option<String>,
386 #[serde(default)]
387 pub headers: Option<HashMap<String, String>>,
388 #[serde(default)]
389 pub timeout_ms: Option<u64>,
390}
391
392pub fn request_to_config(req: &TestServerRequest) -> Result<McpServerConfig, String> {
398 if req.name.trim().is_empty() {
399 return Err("name is required".to_string());
400 }
401 let transport = match req.transport.as_str() {
402 "stdio" => McpTransport::Stdio,
403 "http" => McpTransport::Http,
404 "sse" => McpTransport::Sse,
405 other => return Err(format!("unknown transport `{other}`")),
406 };
407 match transport {
408 McpTransport::Stdio => {
409 if req
410 .command
411 .as_deref()
412 .map(str::trim)
413 .unwrap_or("")
414 .is_empty()
415 {
416 return Err("command is required for stdio transport".to_string());
417 }
418 }
419 McpTransport::Http | McpTransport::Sse => {
420 if req.url.as_deref().map(str::trim).unwrap_or("").is_empty() {
421 return Err("url is required for http/sse transport".to_string());
422 }
423 }
424 }
425 let tool_timeout_secs = req.timeout_ms.map(|ms| (ms / 1000).max(1));
426 Ok(McpServerConfig {
427 name: req.name.clone(),
428 transport,
429 url: req.url.clone(),
430 command: req.command.clone().unwrap_or_default(),
431 args: req.args.clone().unwrap_or_default(),
432 env: req.env.clone().unwrap_or_default(),
433 headers: req.headers.clone().unwrap_or_default(),
434 tool_timeout_secs,
435 })
436}
437
438pub async fn handle_api_mcp_servers_test(
440 State(state): State<AppState>,
441 headers: HeaderMap,
442 Json(req): Json<TestServerRequest>,
443) -> impl IntoResponse {
444 if let Err(e) = super::api::require_auth(&state, &headers) {
445 return e.into_response();
446 }
447
448 let config = match request_to_config(&req) {
449 Ok(c) => c,
450 Err(msg) => {
451 return (
452 StatusCode::OK,
453 Json(json!({
454 "ok": false,
455 "error": msg,
456 "latency_ms": 0,
457 })),
458 )
459 .into_response();
460 }
461 };
462
463 let started = Instant::now();
464 let result = timeout(
465 Duration::from_secs(TEST_HANDSHAKE_TIMEOUT_SECS),
466 McpServer::connect(config),
467 )
468 .await;
469 let latency_ms = started.elapsed().as_millis() as u64;
470
471 let payload = match result {
472 Ok(Ok(server)) => {
473 let tools = server.tools().await;
474 let tool_names: Vec<String> = tools.iter().map(|t| t.name.clone()).collect();
475 json!({
476 "ok": true,
477 "tool_count": tools.len(),
478 "tools": tool_names,
479 "latency_ms": latency_ms,
480 })
481 }
482 Ok(Err(e)) => json!({
483 "ok": false,
484 "error": format!("{e:#}"),
485 "latency_ms": latency_ms,
486 }),
487 Err(_) => json!({
488 "ok": false,
489 "error": format!("timed out after {TEST_HANDSHAKE_TIMEOUT_SECS}s"),
490 "latency_ms": latency_ms,
491 }),
492 };
493
494 (StatusCode::OK, Json(payload)).into_response()
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// URL used by every discovery record built in these tests.
    const DISCOVERY_URL: &str = "http://127.0.0.1:50000/mcp";

    /// Probe that always succeeds with a fixed health document.
    struct FakeProbeOk;
    #[async_trait::async_trait]
    impl HealthProbe for FakeProbeOk {
        async fn get_health(&self, _url: &str) -> Result<Value, String> {
            Ok(json!({
                "status": "ok",
                "pid": 123,
                "uptime_seconds": 5,
                "started_at": "2026-04-17T00:00:00Z",
                "protocol_version": "2024-11-05",
            }))
        }
    }

    /// Probe that always fails.
    struct FakeProbeErr;
    #[async_trait::async_trait]
    impl HealthProbe for FakeProbeErr {
        async fn get_health(&self, _url: &str) -> Result<Value, String> {
            Err("connection refused".into())
        }
    }

    /// Probe that counts how many times it was invoked.
    struct CountingProbe(AtomicUsize);
    #[async_trait::async_trait]
    impl HealthProbe for CountingProbe {
        async fn get_health(&self, url: &str) -> Result<Value, String> {
            self.0.fetch_add(1, Ordering::SeqCst);
            Ok(json!({"hit": url}))
        }
    }

    /// Request with only `name` and `transport` set; tests override the rest.
    fn bare_request(name: &str, transport: &str) -> TestServerRequest {
        TestServerRequest {
            name: name.to_string(),
            transport: transport.to_string(),
            command: None,
            args: None,
            env: None,
            url: None,
            headers: None,
            timeout_ms: None,
        }
    }

    /// Discovery record pointing at `DISCOVERY_URL` with pid 42.
    fn local_discovery() -> McpDiscovery {
        McpDiscovery {
            url: DISCOVERY_URL.into(),
            pid: Some(42),
            started_at: None,
        }
    }

    #[test]
    fn join_mcp_url_composes_base_and_path() {
        let cases = [
            (
                "http://127.0.0.1:60004",
                "/session",
                "http://127.0.0.1:60004/session",
            ),
            (
                "http://127.0.0.1:60004/",
                "/session",
                "http://127.0.0.1:60004/session",
            ),
            (
                "http://127.0.0.1:60004",
                "/session/abc/events",
                "http://127.0.0.1:60004/session/abc/events",
            ),
            (
                "http://127.0.0.1:60004",
                "/mcp",
                "http://127.0.0.1:60004/mcp",
            ),
        ];
        for (base, path, expected) in cases {
            assert_eq!(join_mcp_url(base, path), expected);
        }
    }

    #[test]
    fn health_url_strips_mcp_suffix() {
        let inputs = [
            "http://127.0.0.1:54500/mcp",
            "http://127.0.0.1:54500/mcp/",
            "http://127.0.0.1:54500",
        ];
        for input in inputs {
            assert_eq!(
                health_url_from_discovery(input),
                "http://127.0.0.1:54500/health"
            );
        }
    }

    #[tokio::test]
    async fn discovery_missing_file() {
        let payload = build_discovery_payload(None, &FakeProbeOk).await;
        assert_eq!(payload["available"], false);
        assert_eq!(payload["reason"], "discovery file missing");
    }

    #[tokio::test]
    async fn discovery_present_daemon_reachable() {
        let payload = build_discovery_payload(Some(local_discovery()), &FakeProbeOk).await;
        assert_eq!(payload["available"], true);
        assert_eq!(payload["url"], DISCOVERY_URL);
        assert_eq!(payload["health"]["status"], "ok");
        assert_eq!(payload["health"]["pid"], 123);
    }

    #[tokio::test]
    async fn discovery_present_daemon_unreachable() {
        let payload = build_discovery_payload(Some(local_discovery()), &FakeProbeErr).await;
        assert_eq!(payload["available"], false);
        assert_eq!(payload["reason"], "health check failed");
    }

    #[test]
    fn request_to_config_rejects_empty_name() {
        let req = TestServerRequest {
            command: Some("x".into()),
            ..bare_request(" ", "stdio")
        };
        let err = request_to_config(&req).unwrap_err();
        assert!(err.contains("name"));
    }

    #[test]
    fn request_to_config_rejects_unknown_transport() {
        let req = bare_request("m", "carrier-pigeon");
        let err = request_to_config(&req).unwrap_err();
        assert!(err.contains("unknown transport"));
    }

    #[test]
    fn request_to_config_stdio_requires_command() {
        let req = TestServerRequest {
            command: Some(" ".into()),
            ..bare_request("m", "stdio")
        };
        let err = request_to_config(&req).unwrap_err();
        assert!(err.contains("command"));
    }

    #[test]
    fn request_to_config_http_requires_url() {
        let req = TestServerRequest {
            url: Some("".into()),
            ..bare_request("m", "http")
        };
        let err = request_to_config(&req).unwrap_err();
        assert!(err.contains("url"));
    }

    #[test]
    fn request_to_config_maps_stdio_fields() {
        let env: HashMap<String, String> =
            HashMap::from([("API_KEY".to_string(), "secret".to_string())]);
        let req = TestServerRequest {
            command: Some("/usr/local/bin/mcp".into()),
            args: Some(vec!["--flag".into(), "v".into()]),
            env: Some(env.clone()),
            timeout_ms: Some(30_000),
            ..bare_request("memory", "stdio")
        };
        let cfg = request_to_config(&req).unwrap();
        assert_eq!(cfg.name, "memory");
        assert_eq!(cfg.transport, McpTransport::Stdio);
        assert_eq!(cfg.command, "/usr/local/bin/mcp");
        assert_eq!(cfg.args, vec!["--flag", "v"]);
        assert_eq!(cfg.env, env);
        assert_eq!(cfg.tool_timeout_secs, Some(30));
    }

    #[test]
    fn request_to_config_maps_http_fields() {
        let hdr: HashMap<String, String> =
            HashMap::from([("X-Auth".to_string(), "abc".to_string())]);
        let req = TestServerRequest {
            url: Some("https://example.com/mcp".into()),
            headers: Some(hdr.clone()),
            timeout_ms: Some(500),
            ..bare_request("remote", "sse")
        };
        let cfg = request_to_config(&req).unwrap();
        assert_eq!(cfg.transport, McpTransport::Sse);
        assert_eq!(cfg.url.as_deref(), Some("https://example.com/mcp"));
        assert_eq!(cfg.headers, hdr);
        // A sub-second timeout still yields the 1 second minimum.
        assert_eq!(cfg.tool_timeout_secs, Some(1));
    }

    #[tokio::test]
    async fn discovery_hits_health_url_only_once() {
        let probe = CountingProbe(AtomicUsize::new(0));
        let record = McpDiscovery {
            pid: None,
            ..local_discovery()
        };
        let _ = build_discovery_payload(Some(record), &probe).await;
        assert_eq!(probe.0.load(Ordering::SeqCst), 1);
    }
}