1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::{Duration, Instant};
6
7use axum::extract::{Extension, State};
8use axum::http::header::{HeaderName, HeaderValue};
9use axum::http::{HeaderMap, Method, StatusCode};
10use axum::response::IntoResponse;
11use axum::response::Response;
12use axum::response::sse::{Event, KeepAlive, Sse};
13use axum::{Json, body::Body};
14use forge_core::config::McpConfig;
15use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
16use forge_core::mcp::McpToolContext;
17use forge_core::rate_limit::RateLimitKey;
18use futures_util::Stream;
19use serde_json::Value;
20use tokio::sync::RwLock;
21
22use crate::mcp::McpToolRegistry;
23use crate::rate_limit::RateLimiter;
24
/// MCP protocol revisions this server accepts, listed newest first.
const SUPPORTED_VERSIONS: &[&str] = &["2025-11-25", "2025-03-26", "2024-11-05"];
/// Protocol revision used by the unit tests (the newest entry of `SUPPORTED_VERSIONS`).
#[cfg(test)]
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Header carrying the session id issued by `initialize`. Kept lowercase because it is
/// passed to `HeaderName::from_static`, which rejects uppercase names.
const MCP_SESSION_HEADER: &str = "mcp-session-id";
/// Header carrying the protocol revision agreed during `initialize`.
const MCP_PROTOCOL_HEADER: &str = "mcp-protocol-version";
/// Maximum number of tools returned by a single `tools/list` page.
const DEFAULT_PAGE_SIZE: usize = 50;
/// Upper bound on stored sessions; `initialize` answers 503 once it is reached.
const MAX_MCP_SESSIONS: usize = 10_000;
/// Error type of the request-validation helpers; boxing keeps the `Err` variant pointer-sized.
type ResponseError = Box<Response>;
33
/// Server-side record of one MCP session, stored in `McpState::sessions` under its id.
#[derive(Debug, Clone)]
struct McpSession {
    /// Set once the client sends `notifications/initialized`; `tools/*` requests require it.
    initialized: bool,
    /// Protocol revision the client requested (and the server accepted) in `initialize`.
    protocol_version: String,
    /// Deadline after which the session is swept by `cleanup_expired_sessions`; pushed
    /// forward each time the session is used.
    expires_at: Instant,
}
40
/// State shared by the MCP HTTP handlers, which receive it as `State<Arc<McpState>>`.
#[derive(Clone)]
pub struct McpState {
    /// MCP settings read here: session TTL, allowed origins, whether the protocol-version
    /// header is required, and whether OAuth challenges are emitted.
    config: McpConfig,
    /// Tools exposed through `tools/list` and `tools/call`.
    registry: McpToolRegistry,
    /// Database pool handed to every tool invocation's context.
    pool: sqlx::PgPool,
    /// Live sessions keyed by session id (the `mcp-session-id` header value).
    sessions: Arc<RwLock<HashMap<String, McpSession>>>,
    /// Optional job dispatcher forwarded to tool contexts.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher forwarded to tool contexts.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Enforces per-tool rate limits declared in tool metadata.
    rate_limiter: Arc<RateLimiter>,
}
51
52impl McpState {
53 pub fn new(
54 config: McpConfig,
55 registry: McpToolRegistry,
56 pool: sqlx::PgPool,
57 job_dispatcher: Option<Arc<dyn JobDispatch>>,
58 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
59 ) -> Self {
60 Self {
61 config,
62 registry,
63 pool: pool.clone(),
64 sessions: Arc::new(RwLock::new(HashMap::new())),
65 job_dispatcher,
66 workflow_dispatcher,
67 rate_limiter: Arc::new(RateLimiter::new(pool)),
68 }
69 }
70
71 async fn cleanup_expired_sessions(&self) {
72 let mut sessions = self.sessions.write().await;
73 let now = Instant::now();
74 sessions.retain(|_, session| session.expires_at > now);
75 }
76
77 async fn touch_session(&self, session_id: &str) {
78 let mut sessions = self.sessions.write().await;
79 if let Some(session) = sessions.get_mut(session_id) {
80 session.expires_at = Instant::now() + Duration::from_secs(self.config.session_ttl_secs);
81 }
82 }
83
84 fn session_ttl(&self) -> Duration {
85 Duration::from_secs(self.config.session_ttl_secs)
86 }
87}
88
89struct McpReceiverStream {
91 rx: tokio::sync::mpsc::Receiver<Result<Event, std::convert::Infallible>>,
92}
93
94impl Stream for McpReceiverStream {
95 type Item = Result<Event, std::convert::Infallible>;
96 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
97 self.rx.poll_recv(cx)
98 }
99}
100
101pub async fn mcp_get_handler(State(state): State<Arc<McpState>>, headers: HeaderMap) -> Response {
108 if let Err(resp) = validate_origin(&headers, &state.config) {
109 return *resp;
110 }
111
112 if let Err(resp) = enforce_protocol_header(&state.config, &headers) {
113 return *resp;
114 }
115
116 let session_id = match required_session_id(&state, &headers, true).await {
118 Ok(v) => v,
119 Err(resp) => return resp,
120 };
121
122 state.touch_session(&session_id).await;
123
124 let (tx, rx) = tokio::sync::mpsc::channel::<Result<Event, std::convert::Infallible>>(32);
126
127 let session_id_clone = session_id.clone();
129 tokio::spawn(async move {
130 let endpoint_data = serde_json::json!({
131 "sessionId": session_id_clone,
132 });
133 let _ = tx
134 .send(Ok(Event::default()
135 .event("endpoint")
136 .data(endpoint_data.to_string())))
137 .await;
138
139 loop {
143 tokio::time::sleep(Duration::from_secs(30)).await;
144 if tx.is_closed() {
145 break;
146 }
147 }
148 });
149
150 let stream = McpReceiverStream { rx };
151
152 let mut response = Sse::new(stream)
153 .keep_alive(KeepAlive::new().interval(Duration::from_secs(30)))
154 .into_response();
155
156 if let Ok(val) = HeaderValue::from_str(&session_id) {
158 response
159 .headers_mut()
160 .insert(HeaderName::from_static(MCP_SESSION_HEADER), val);
161 }
162
163 response
164}
165
166pub async fn mcp_post_handler(
167 State(state): State<Arc<McpState>>,
168 Extension(auth): Extension<AuthContext>,
169 Extension(tracing): Extension<super::tracing::TracingState>,
170 method: Method,
171 headers: HeaderMap,
172 Json(payload): Json<Value>,
173) -> Response {
174 if method != Method::POST {
175 return (
176 StatusCode::METHOD_NOT_ALLOWED,
177 Json(json_rpc_error(None, -32601, "Only POST is supported", None)),
178 )
179 .into_response();
180 }
181
182 if let Err(resp) = validate_origin(&headers, &state.config) {
183 return *resp;
184 }
185
186 state.cleanup_expired_sessions().await;
187
188 let Some(method_name) = payload.get("method").and_then(Value::as_str) else {
189 if payload.get("id").is_some()
191 && (payload.get("result").is_some() || payload.get("error").is_some())
192 {
193 return StatusCode::ACCEPTED.into_response();
194 }
195 return (
196 StatusCode::BAD_REQUEST,
197 Json(json_rpc_error(
198 None,
199 -32600,
200 "Invalid JSON-RPC payload",
201 None,
202 )),
203 )
204 .into_response();
205 };
206
207 let id = payload.get("id").cloned();
208 let params = payload
209 .get("params")
210 .cloned()
211 .unwrap_or(Value::Object(Default::default()));
212
213 if id.is_none() {
215 return handle_notification(&state, method_name, params, &headers).await;
216 }
217
218 if method_name != "initialize"
220 && let Err(resp) = enforce_protocol_header(&state.config, &headers)
221 {
222 return *resp;
223 }
224
225 match method_name {
226 "initialize" => handle_initialize(&state, id, ¶ms).await,
227 "tools/list" => {
228 let session_id = match required_session_id(&state, &headers, true).await {
229 Ok(v) => v,
230 Err(resp) => return resp,
231 };
232 state.touch_session(&session_id).await;
233 handle_tools_list(&state, id, ¶ms)
234 }
235 "tools/call" => {
236 let session_id = match required_session_id(&state, &headers, true).await {
237 Ok(v) => v,
238 Err(resp) => return resp,
239 };
240 state.touch_session(&session_id).await;
241
242 let metadata = build_request_metadata(&tracing, &headers);
243 handle_tools_call(&state, id, ¶ms, &auth, metadata).await
244 }
245 _ => (
246 StatusCode::OK,
247 Json(json_rpc_error(id, -32601, "Method not found", None)),
248 )
249 .into_response(),
250 }
251}
252
253async fn handle_notification(
254 state: &Arc<McpState>,
255 method_name: &str,
256 _params: Value,
257 headers: &HeaderMap,
258) -> Response {
259 if let Err(resp) = enforce_protocol_header(&state.config, headers) {
260 return *resp;
261 }
262
263 match method_name {
264 "notifications/initialized" => {
265 let session_id = match required_session_id(state, headers, false).await {
266 Ok(v) => v,
267 Err(resp) => return resp,
268 };
269
270 let mut sessions = state.sessions.write().await;
271 if let Some(session) = sessions.get_mut(&session_id) {
272 session.initialized = true;
273 session.expires_at = Instant::now() + state.session_ttl();
274 return StatusCode::ACCEPTED.into_response();
275 }
276
277 (
278 StatusCode::BAD_REQUEST,
279 Json(json_rpc_error(
280 None,
281 -32600,
282 "Unknown MCP session. Re-initialize the connection.",
283 None,
284 )),
285 )
286 .into_response()
287 }
288 _ => StatusCode::ACCEPTED.into_response(),
289 }
290}
291
292async fn handle_initialize(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
293 let Some(requested_version) = params.get("protocolVersion").and_then(Value::as_str) else {
294 return (
295 StatusCode::OK,
296 Json(json_rpc_error(
297 id,
298 -32602,
299 "Missing protocolVersion in initialize params",
300 None,
301 )),
302 )
303 .into_response();
304 };
305
306 if !SUPPORTED_VERSIONS.contains(&requested_version) {
307 return (
308 StatusCode::OK,
309 Json(json_rpc_error(
310 id,
311 -32602,
312 "Unsupported protocolVersion",
313 Some(serde_json::json!({
314 "supported": SUPPORTED_VERSIONS
315 })),
316 )),
317 )
318 .into_response();
319 }
320
321 let session_id = uuid::Uuid::new_v4().to_string();
322 {
323 let mut sessions = state.sessions.write().await;
324 if sessions.len() >= MAX_MCP_SESSIONS {
326 return (
327 StatusCode::SERVICE_UNAVAILABLE,
328 Json(json_rpc_error(
329 id,
330 -32000,
331 "Server at MCP session capacity",
332 None,
333 )),
334 )
335 .into_response();
336 }
337 sessions.insert(
338 session_id.clone(),
339 McpSession {
340 initialized: false,
341 protocol_version: requested_version.to_string(),
342 expires_at: Instant::now() + state.session_ttl(),
343 },
344 );
345 }
346
347 let mut response = (
348 StatusCode::OK,
349 Json(json_rpc_success(
350 id,
351 serde_json::json!({
352 "protocolVersion": requested_version,
353 "capabilities": {
354 "tools": {
355 "listChanged": false
356 }
357 },
358 "serverInfo": {
359 "name": "forge",
360 "version": env!("CARGO_PKG_VERSION")
361 }
362 }),
363 )),
364 )
365 .into_response();
366
367 set_header(&mut response, MCP_SESSION_HEADER, &session_id);
368 set_header(&mut response, MCP_PROTOCOL_HEADER, requested_version);
369 response
370}
371
372fn handle_tools_list(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
373 let cursor = params.get("cursor").and_then(Value::as_str);
374 let start = match cursor {
375 Some(c) => match c.parse::<usize>() {
376 Ok(v) => v,
377 Err(_) => {
378 return (
379 StatusCode::OK,
380 Json(json_rpc_error(
381 id,
382 -32602,
383 "Invalid cursor in tools/list request",
384 None,
385 )),
386 )
387 .into_response();
388 }
389 },
390 None => 0,
391 };
392
393 let mut tools: Vec<_> = state.registry.list().collect();
394 tools.sort_by(|a, b| a.info.name.cmp(b.info.name));
395
396 let page: Vec<_> = tools
397 .iter()
398 .skip(start)
399 .take(DEFAULT_PAGE_SIZE)
400 .map(|entry| {
401 let mut annotations = serde_json::Map::new();
404 if let Some(title) = &entry.info.annotations.title {
405 annotations.insert("title".into(), serde_json::Value::String(title.to_string()));
406 }
407 if let Some(v) = entry.info.annotations.read_only_hint {
408 annotations.insert("readOnlyHint".into(), serde_json::Value::Bool(v));
409 }
410 if let Some(v) = entry.info.annotations.destructive_hint {
411 annotations.insert("destructiveHint".into(), serde_json::Value::Bool(v));
412 }
413 if let Some(v) = entry.info.annotations.idempotent_hint {
414 annotations.insert("idempotentHint".into(), serde_json::Value::Bool(v));
415 }
416 if let Some(v) = entry.info.annotations.open_world_hint {
417 annotations.insert("openWorldHint".into(), serde_json::Value::Bool(v));
418 }
419
420 let mut value = serde_json::json!({
421 "name": entry.info.name,
422 "description": entry.info.description,
423 "inputSchema": entry.input_schema,
424 });
425 let obj = value.as_object_mut().expect("json! object literal");
427
428 if let Some(title) = &entry.info.title {
429 obj.insert("title".into(), serde_json::Value::String(title.to_string()));
430 }
431 if !annotations.is_empty() {
432 obj.insert("annotations".into(), serde_json::Value::Object(annotations));
433 }
434 if !entry.info.icons.is_empty() {
435 let icons: Vec<_> = entry
436 .info
437 .icons
438 .iter()
439 .map(|icon| {
440 serde_json::json!({
441 "src": icon.src,
442 "mimeType": icon.mime_type,
443 "sizes": icon.sizes,
444 "theme": icon.theme
445 })
446 })
447 .collect();
448 obj.insert("icons".into(), serde_json::Value::Array(icons));
449 }
450 if let Some(output_schema) = &entry.output_schema {
451 let schema = normalize_output_schema(output_schema);
455 obj.insert("outputSchema".into(), schema);
456 }
457 value
458 })
459 .collect();
460
461 let end = start.saturating_add(page.len());
462
463 let mut result = serde_json::json!({ "tools": page });
465 if end < tools.len() && result.is_object() {
466 result
468 .as_object_mut()
469 .expect("json! object literal")
470 .insert(
471 "nextCursor".into(),
472 serde_json::Value::String(end.to_string()),
473 );
474 }
475
476 (StatusCode::OK, Json(json_rpc_success(id, result))).into_response()
477}
478
479fn normalize_output_schema(schema: &Value) -> Value {
482 let type_str = schema.get("type").and_then(Value::as_str).unwrap_or("");
483 if type_str == "object" {
484 return schema.clone();
485 }
486
487 let mut wrapper = serde_json::json!({
489 "type": "object",
490 "properties": {
491 "result": schema
492 }
493 });
494
495 if let (Some(s), Some(obj)) = (schema.get("$schema"), wrapper.as_object_mut()) {
497 obj.insert("$schema".into(), s.clone());
498 }
499 if let (Some(d), Some(obj)) = (schema.get("definitions"), wrapper.as_object_mut()) {
500 obj.insert("definitions".into(), d.clone());
501 if let Some(inner) = wrapper.pointer_mut("/properties/result") {
503 inner.as_object_mut().map(|o| o.remove("definitions"));
504 }
505 }
506
507 wrapper
508}
509
510async fn handle_tools_call(
511 state: &Arc<McpState>,
512 id: Option<Value>,
513 params: &Value,
514 auth: &AuthContext,
515 request_metadata: RequestMetadata,
516) -> Response {
517 let Some(tool_name) = params.get("name").and_then(Value::as_str) else {
518 return (
519 StatusCode::OK,
520 Json(json_rpc_error(id, -32602, "Missing tool name", None)),
521 )
522 .into_response();
523 };
524
525 let Some(entry) = state.registry.get(tool_name) else {
526 return (
527 StatusCode::OK,
528 Json(json_rpc_error(id, -32602, "Unknown tool", None)),
529 )
530 .into_response();
531 };
532
533 if !entry.info.is_public && !auth.is_authenticated() {
534 if state.config.oauth {
535 let mut response = (
537 StatusCode::UNAUTHORIZED,
538 Json(json_rpc_error(id, -32001, "Authentication required", None)),
539 )
540 .into_response();
541 response.headers_mut().insert(
542 "WWW-Authenticate",
543 axum::http::header::HeaderValue::from_static(
544 "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\"",
545 ),
546 );
547 return response;
548 }
549 return (
550 StatusCode::OK,
551 Json(json_rpc_error(id, -32001, "Authentication required", None)),
552 )
553 .into_response();
554 }
555 if let Some(role) = entry.info.required_role
556 && !auth.has_role(role)
557 {
558 return (
559 StatusCode::OK,
560 Json(json_rpc_error(
561 id,
562 -32003,
563 format!("Role '{}' required", role),
564 None,
565 )),
566 )
567 .into_response();
568 }
569
570 if let (Some(requests), Some(per_secs)) = (
571 entry.info.rate_limit_requests,
572 entry.info.rate_limit_per_secs,
573 ) {
574 let key_type = entry
575 .info
576 .rate_limit_key
577 .and_then(|k| k.parse::<RateLimitKey>().ok())
578 .unwrap_or_default();
579
580 let config = forge_core::RateLimitConfig::new(requests, Duration::from_secs(per_secs))
581 .with_key(key_type);
582 let bucket_key = state
583 .rate_limiter
584 .build_key(key_type, tool_name, auth, &request_metadata);
585
586 if let Err(e) = state.rate_limiter.enforce(&bucket_key, &config).await {
587 return (
588 StatusCode::OK,
589 Json(json_rpc_error(id, -32029, e.to_string(), None)),
590 )
591 .into_response();
592 }
593 }
594
595 let args = params
596 .get("arguments")
597 .cloned()
598 .unwrap_or(Value::Object(Default::default()));
599
600 let ctx = McpToolContext::with_dispatch(
601 state.pool.clone(),
602 auth.clone(),
603 request_metadata,
604 state.job_dispatcher.clone(),
605 state.workflow_dispatcher.clone(),
606 );
607
608 let result = if let Some(timeout_secs) = entry.info.timeout {
609 match tokio::time::timeout(
610 Duration::from_secs(timeout_secs),
611 (entry.handler)(&ctx, args),
612 )
613 .await
614 {
615 Ok(inner) => inner,
616 Err(_) => {
617 return (
618 StatusCode::OK,
619 Json(json_rpc_error(id, -32000, "Tool timed out", None)),
620 )
621 .into_response();
622 }
623 }
624 } else {
625 (entry.handler)(&ctx, args).await
626 };
627
628 match result {
629 Ok(output) => {
630 let result = tool_success_result(output);
631 (
632 StatusCode::OK,
633 Json(json_rpc_success(id, serde_json::json!(result))),
634 )
635 .into_response()
636 }
637 Err(e) => match e {
638 forge_core::ForgeError::Validation(msg)
639 | forge_core::ForgeError::InvalidArgument(msg) => (
640 StatusCode::OK,
641 Json(json_rpc_success(
642 id,
643 serde_json::json!({
644 "content": [{ "type": "text", "text": msg }],
645 "isError": true
646 }),
647 )),
648 )
649 .into_response(),
650 forge_core::ForgeError::Unauthorized(msg) => {
651 (StatusCode::OK, Json(json_rpc_error(id, -32001, msg, None))).into_response()
652 }
653 forge_core::ForgeError::Forbidden(msg) => {
654 (StatusCode::OK, Json(json_rpc_error(id, -32003, msg, None))).into_response()
655 }
656 _ => (
657 StatusCode::OK,
658 Json(json_rpc_error(id, -32603, "Internal server error", None)),
659 )
660 .into_response(),
661 },
662 }
663}
664
665fn tool_success_result(output: Value) -> Value {
666 match output {
667 Value::Object(_) => serde_json::json!({
668 "content": [{
669 "type": "text",
670 "text": serde_json::to_string(&output).unwrap_or_else(|_| "{}".to_string())
671 }],
672 "structuredContent": output
673 }),
674 Value::String(text) => serde_json::json!({
675 "content": [{ "type": "text", "text": text }]
676 }),
677 other => serde_json::json!({
678 "content": [{
679 "type": "text",
680 "text": serde_json::to_string(&other).unwrap_or_else(|_| "null".to_string())
681 }]
682 }),
683 }
684}
685
686async fn required_session_id(
687 state: &Arc<McpState>,
688 headers: &HeaderMap,
689 require_initialized: bool,
690) -> std::result::Result<String, Response> {
691 let Some(session_id) = headers
692 .get(MCP_SESSION_HEADER)
693 .and_then(|v| v.to_str().ok())
694 else {
695 return Err((
696 StatusCode::BAD_REQUEST,
697 Json(json_rpc_error(
698 None,
699 -32600,
700 "Missing MCP-Session-Id header",
701 None,
702 )),
703 )
704 .into_response());
705 };
706
707 let sessions = state.sessions.read().await;
708 match sessions.get(session_id) {
709 Some(session) => {
710 if !SUPPORTED_VERSIONS.contains(&session.protocol_version.as_str()) {
711 return Err((
712 StatusCode::BAD_REQUEST,
713 Json(json_rpc_error(
714 None,
715 -32600,
716 "Session protocol version mismatch",
717 None,
718 )),
719 )
720 .into_response());
721 }
722 if require_initialized && !session.initialized {
723 return Err((
724 StatusCode::BAD_REQUEST,
725 Json(json_rpc_error(
726 None,
727 -32600,
728 "MCP session is not initialized",
729 None,
730 )),
731 )
732 .into_response());
733 }
734 Ok(session_id.to_string())
735 }
736 None => Err((
737 StatusCode::BAD_REQUEST,
738 Json(json_rpc_error(
739 None,
740 -32600,
741 "Unknown MCP session. Re-initialize.",
742 None,
743 )),
744 )
745 .into_response()),
746 }
747}
748
749fn validate_origin(
750 headers: &HeaderMap,
751 config: &McpConfig,
752) -> std::result::Result<(), ResponseError> {
753 let Some(origin) = headers.get("origin").and_then(|v| v.to_str().ok()) else {
754 return Ok(());
755 };
756
757 if config.allowed_origins.is_empty() {
760 return Err(Box::new(
761 (
762 StatusCode::FORBIDDEN,
763 Json(json_rpc_error(
764 None,
765 -32600,
766 "Cross-origin requests require allowed_origins to be configured",
767 None,
768 )),
769 )
770 .into_response(),
771 ));
772 }
773
774 let allowed = config
775 .allowed_origins
776 .iter()
777 .any(|candidate| candidate == "*" || candidate.eq_ignore_ascii_case(origin));
778 if allowed {
779 return Ok(());
780 }
781
782 Err(Box::new(
783 (
784 StatusCode::FORBIDDEN,
785 Json(json_rpc_error(None, -32600, "Invalid Origin header", None)),
786 )
787 .into_response(),
788 ))
789}
790
791fn enforce_protocol_header(
792 config: &McpConfig,
793 headers: &HeaderMap,
794) -> std::result::Result<(), ResponseError> {
795 if !config.require_protocol_version_header {
796 return Ok(());
797 }
798
799 let Some(version) = headers
800 .get(MCP_PROTOCOL_HEADER)
801 .and_then(|v| v.to_str().ok())
802 else {
803 return Err(Box::new(
804 (
805 StatusCode::BAD_REQUEST,
806 Json(json_rpc_error(
807 None,
808 -32600,
809 "Missing MCP-Protocol-Version header",
810 None,
811 )),
812 )
813 .into_response(),
814 ));
815 };
816
817 if !SUPPORTED_VERSIONS.contains(&version) {
818 return Err(Box::new(
819 (
820 StatusCode::BAD_REQUEST,
821 Json(json_rpc_error(
822 None,
823 -32600,
824 "Unsupported MCP-Protocol-Version",
825 Some(serde_json::json!({ "supported": SUPPORTED_VERSIONS })),
826 )),
827 )
828 .into_response(),
829 ));
830 }
831
832 Ok(())
833}
834
835use super::extract_client_ip;
836
837fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
838 headers
839 .get(axum::http::header::USER_AGENT)
840 .and_then(|v| v.to_str().ok())
841 .map(String::from)
842}
843
844fn build_request_metadata(
845 tracing: &super::tracing::TracingState,
846 headers: &HeaderMap,
847) -> RequestMetadata {
848 RequestMetadata {
849 request_id: uuid::Uuid::parse_str(&tracing.request_id)
850 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
851 trace_id: tracing.trace_id.clone(),
852 client_ip: extract_client_ip(headers),
853 user_agent: extract_user_agent(headers),
854 correlation_id: None,
855 timestamp: chrono::Utc::now(),
856 }
857}
858
859fn json_rpc_success(id: Option<Value>, result: Value) -> Value {
860 serde_json::json!({
861 "jsonrpc": "2.0",
862 "id": id.unwrap_or(Value::Null),
863 "result": result
864 })
865}
866
867fn json_rpc_error(
868 id: Option<Value>,
869 code: i32,
870 message: impl Into<String>,
871 data: Option<Value>,
872) -> Value {
873 let mut error = serde_json::json!({
874 "code": code,
875 "message": message.into()
876 });
877 if let Some(data) = data
878 && let Some(obj) = error.as_object_mut()
879 {
880 obj.insert("data".to_string(), data);
881 }
882
883 serde_json::json!({
884 "jsonrpc": "2.0",
885 "id": id.unwrap_or(Value::Null),
886 "error": error
887 })
888}
889
890fn set_header(response: &mut Response<Body>, name: &str, value: &str) {
891 if let (Ok(name), Ok(value)) = (HeaderName::try_from(name), HeaderValue::from_str(value)) {
892 response.headers_mut().insert(name, value);
893 }
894}
895
896#[cfg(test)]
897#[allow(clippy::expect_used, clippy::indexing_slicing, clippy::unwrap_used)]
898mod tests {
899 use super::super::tracing::TracingState;
900 use super::*;
901 use axum::body::to_bytes;
902 use forge_core::function::AuthContext;
903 use forge_core::mcp::{ForgeMcpTool, McpToolAnnotations, McpToolInfo};
904 use forge_core::schemars::{self, JsonSchema};
905 use serde::{Deserialize, Serialize};
906 use std::collections::HashMap;
907 use std::future::Future;
908 use std::pin::Pin;
909
910 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
911 struct EchoArgs {
912 message: String,
913 }
914
915 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
916 struct EchoOutput {
917 echoed: String,
918 }
919
920 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
921 #[serde(rename_all = "snake_case")]
922 enum ExportFormat {
923 Json,
924 Csv,
925 }
926
927 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
928 struct MetadataArgs {
929 #[schemars(description = "Project UUID to export")]
930 project_id: String,
931 format: ExportFormat,
932 }
933
934 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
935 struct MetadataOutput {
936 accepted: bool,
937 }
938
939 struct EchoTool;
940
941 impl ForgeMcpTool for EchoTool {
942 type Args = EchoArgs;
943 type Output = EchoOutput;
944
945 fn info() -> McpToolInfo {
946 McpToolInfo {
947 name: "echo",
948 title: Some("Echo"),
949 description: Some("Echo back the message"),
950 required_role: None,
951 is_public: false,
952 timeout: None,
953 rate_limit_requests: None,
954 rate_limit_per_secs: None,
955 rate_limit_key: None,
956 annotations: McpToolAnnotations::default(),
957 icons: &[],
958 }
959 }
960
961 fn execute(
962 _ctx: &McpToolContext,
963 args: Self::Args,
964 ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
965 Box::pin(async move {
966 Ok(EchoOutput {
967 echoed: args.message,
968 })
969 })
970 }
971 }
972
973 struct AdminTool;
974
975 impl ForgeMcpTool for AdminTool {
976 type Args = EchoArgs;
977 type Output = EchoOutput;
978
979 fn info() -> McpToolInfo {
980 McpToolInfo {
981 name: "admin.echo",
982 title: Some("Admin Echo"),
983 description: Some("Admin only echo"),
984 required_role: Some("admin"),
985 is_public: false,
986 timeout: None,
987 rate_limit_requests: None,
988 rate_limit_per_secs: None,
989 rate_limit_key: None,
990 annotations: McpToolAnnotations::default(),
991 icons: &[],
992 }
993 }
994
995 fn execute(
996 _ctx: &McpToolContext,
997 args: Self::Args,
998 ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
999 Box::pin(async move {
1000 Ok(EchoOutput {
1001 echoed: args.message,
1002 })
1003 })
1004 }
1005 }
1006
1007 struct MetadataTool;
1008
1009 impl ForgeMcpTool for MetadataTool {
1010 type Args = MetadataArgs;
1011 type Output = MetadataOutput;
1012
1013 fn info() -> McpToolInfo {
1014 McpToolInfo {
1015 name: "export.project",
1016 title: Some("Export Project"),
1017 description: Some("Export project data"),
1018 required_role: None,
1019 is_public: false,
1020 timeout: None,
1021 rate_limit_requests: None,
1022 rate_limit_per_secs: None,
1023 rate_limit_key: None,
1024 annotations: McpToolAnnotations::default(),
1025 icons: &[],
1026 }
1027 }
1028
1029 fn execute(
1030 _ctx: &McpToolContext,
1031 _args: Self::Args,
1032 ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
1033 Box::pin(async move { Ok(MetadataOutput { accepted: true }) })
1034 }
1035 }
1036
1037 #[test]
1038 fn test_json_rpc_helpers() {
1039 let success = json_rpc_success(
1040 Some(serde_json::json!(1)),
1041 serde_json::json!({ "ok": true }),
1042 );
1043 assert_eq!(success["jsonrpc"], "2.0");
1044 assert!(success.get("result").is_some());
1045
1046 let err = json_rpc_error(Some(serde_json::json!(1)), -32601, "not found", None);
1047 assert_eq!(err["error"]["code"], -32601);
1048 }
1049
1050 fn test_state(config: McpConfig) -> Arc<McpState> {
1051 test_state_with_registry(config, McpToolRegistry::new())
1052 }
1053
1054 fn test_state_with_registry(config: McpConfig, registry: McpToolRegistry) -> Arc<McpState> {
1055 let pool = sqlx::postgres::PgPoolOptions::new()
1056 .max_connections(1)
1057 .connect_lazy("postgres://localhost/nonexistent")
1058 .expect("lazy pool must build");
1059 Arc::new(McpState::new(config, registry, pool, None, None))
1060 }
1061
1062 async fn response_json(response: Response) -> Value {
1063 let bytes = to_bytes(response.into_body(), usize::MAX)
1064 .await
1065 .expect("body bytes");
1066 if bytes.is_empty() {
1067 return serde_json::json!({});
1068 }
1069 serde_json::from_slice(&bytes).expect("valid json")
1070 }
1071
1072 async fn initialize_session(state: Arc<McpState>) -> String {
1073 let payload = serde_json::json!({
1074 "jsonrpc": "2.0",
1075 "id": 1,
1076 "method": "initialize",
1077 "params": {
1078 "protocolVersion": "2025-11-25",
1079 "capabilities": {},
1080 "clientInfo": { "name": "test", "version": "1.0.0" }
1081 }
1082 });
1083 let response = mcp_post_handler(
1084 State(state),
1085 Extension(AuthContext::unauthenticated()),
1086 Extension(TracingState::new()),
1087 Method::POST,
1088 HeaderMap::new(),
1089 Json(payload),
1090 )
1091 .await;
1092
1093 assert_eq!(response.status(), StatusCode::OK);
1094 response
1095 .headers()
1096 .get(MCP_SESSION_HEADER)
1097 .and_then(|v| v.to_str().ok())
1098 .expect("session id must exist")
1099 .to_string()
1100 }
1101
1102 async fn mark_initialized(state: Arc<McpState>, headers: HeaderMap) {
1103 let payload = serde_json::json!({
1104 "jsonrpc": "2.0",
1105 "method": "notifications/initialized",
1106 "params": {}
1107 });
1108 let response = mcp_post_handler(
1109 State(state),
1110 Extension(AuthContext::unauthenticated()),
1111 Extension(TracingState::new()),
1112 Method::POST,
1113 headers,
1114 Json(payload),
1115 )
1116 .await;
1117 assert_eq!(response.status(), StatusCode::ACCEPTED);
1118 }
1119
1120 async fn initialized_headers(state: Arc<McpState>) -> HeaderMap {
1121 let session_id = initialize_session(state.clone()).await;
1122 let mut headers = HeaderMap::new();
1123 headers.insert(
1124 MCP_SESSION_HEADER,
1125 HeaderValue::from_str(&session_id).expect("valid session id header"),
1126 );
1127 headers.insert(
1128 MCP_PROTOCOL_HEADER,
1129 HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1130 );
1131 mark_initialized(state, headers.clone()).await;
1132 headers
1133 }
1134
1135 #[tokio::test]
1136 async fn test_initialize_sets_session_header() {
1137 let state = test_state(McpConfig {
1138 enabled: true,
1139 ..Default::default()
1140 });
1141 let session = initialize_session(state).await;
1142 assert!(!session.is_empty());
1143 }
1144
1145 #[tokio::test]
1146 async fn test_initialize_rejects_unsupported_protocol_version() {
1147 let state = test_state(McpConfig {
1148 enabled: true,
1149 ..Default::default()
1150 });
1151 let payload = serde_json::json!({
1152 "jsonrpc": "2.0",
1153 "id": 1,
1154 "method": "initialize",
1155 "params": {
1156 "protocolVersion": "2024-01-01",
1157 "capabilities": {},
1158 "clientInfo": { "name": "test", "version": "1.0.0" }
1159 }
1160 });
1161
1162 let response = mcp_post_handler(
1163 State(state),
1164 Extension(AuthContext::unauthenticated()),
1165 Extension(TracingState::new()),
1166 Method::POST,
1167 HeaderMap::new(),
1168 Json(payload),
1169 )
1170 .await;
1171
1172 assert_eq!(response.status(), StatusCode::OK);
1173 let body = response_json(response).await;
1174 assert_eq!(body["error"]["code"], -32602);
1175 let supported = body["error"]["data"]["supported"]
1176 .as_array()
1177 .expect("supported versions array");
1178 assert!(
1179 supported
1180 .iter()
1181 .any(|value| value.as_str() == Some(MCP_PROTOCOL_VERSION))
1182 );
1183 }
1184
1185 #[tokio::test]
1186 async fn test_tools_list_requires_initialized_session() {
1187 let state = test_state(McpConfig {
1188 enabled: true,
1189 ..Default::default()
1190 });
1191
1192 let session_id = initialize_session(state.clone()).await;
1193
1194 let mut headers = HeaderMap::new();
1195 headers.insert(
1196 MCP_SESSION_HEADER,
1197 HeaderValue::from_str(&session_id).expect("valid"),
1198 );
1199 headers.insert(
1200 MCP_PROTOCOL_HEADER,
1201 HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1202 );
1203
1204 let list_payload = serde_json::json!({
1205 "jsonrpc": "2.0",
1206 "id": 2,
1207 "method": "tools/list",
1208 "params": {}
1209 });
1210 let response = mcp_post_handler(
1211 State(state),
1212 Extension(AuthContext::unauthenticated()),
1213 Extension(TracingState::new()),
1214 Method::POST,
1215 headers,
1216 Json(list_payload),
1217 )
1218 .await;
1219
1220 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1221 }
1222
1223 #[tokio::test]
1224 async fn test_tools_list_returns_registered_tools() {
1225 let mut registry = McpToolRegistry::new();
1226 registry.register::<EchoTool>();
1227
1228 let state = test_state_with_registry(
1229 McpConfig {
1230 enabled: true,
1231 ..Default::default()
1232 },
1233 registry,
1234 );
1235 let headers = initialized_headers(state.clone()).await;
1236 let payload = serde_json::json!({
1237 "jsonrpc": "2.0",
1238 "id": 2,
1239 "method": "tools/list",
1240 "params": {}
1241 });
1242
1243 let response = mcp_post_handler(
1244 State(state),
1245 Extension(AuthContext::unauthenticated()),
1246 Extension(TracingState::new()),
1247 Method::POST,
1248 headers,
1249 Json(payload),
1250 )
1251 .await;
1252
1253 assert_eq!(response.status(), StatusCode::OK);
1254 let body = response_json(response).await;
1255 let tools = body["result"]["tools"]
1256 .as_array()
1257 .expect("tools list should be array");
1258 assert_eq!(tools.len(), 1);
1259 assert_eq!(tools[0]["name"], "echo");
1260 assert!(tools[0].get("inputSchema").is_some());
1261 assert!(tools[0].get("outputSchema").is_some());
1262 }
1263
1264 #[tokio::test]
1265 async fn test_tools_list_exposes_parameter_metadata() {
1266 let mut registry = McpToolRegistry::new();
1267 registry.register::<MetadataTool>();
1268
1269 let state = test_state_with_registry(
1270 McpConfig {
1271 enabled: true,
1272 ..Default::default()
1273 },
1274 registry,
1275 );
1276 let headers = initialized_headers(state.clone()).await;
1277 let payload = serde_json::json!({
1278 "jsonrpc": "2.0",
1279 "id": 9,
1280 "method": "tools/list",
1281 "params": {}
1282 });
1283
1284 let response = mcp_post_handler(
1285 State(state),
1286 Extension(AuthContext::unauthenticated()),
1287 Extension(TracingState::new()),
1288 Method::POST,
1289 headers,
1290 Json(payload),
1291 )
1292 .await;
1293
1294 assert_eq!(response.status(), StatusCode::OK);
1295 let body = response_json(response).await;
1296 let tools = body["result"]["tools"]
1297 .as_array()
1298 .expect("tools list should be array");
1299 assert_eq!(tools.len(), 1);
1300
1301 let input_schema = &tools[0]["inputSchema"];
1302 assert_eq!(
1303 input_schema["properties"]["project_id"]["description"],
1304 "Project UUID to export"
1305 );
1306
1307 let schema_text = input_schema.to_string();
1308 assert!(schema_text.contains("\"json\""));
1309 assert!(schema_text.contains("\"csv\""));
1310 }
1311
1312 #[tokio::test]
1313 async fn test_tools_call_success_returns_structured_content() {
1314 let mut registry = McpToolRegistry::new();
1315 registry.register::<EchoTool>();
1316
1317 let state = test_state_with_registry(
1318 McpConfig {
1319 enabled: true,
1320 ..Default::default()
1321 },
1322 registry,
1323 );
1324 let headers = initialized_headers(state.clone()).await;
1325 let auth = AuthContext::authenticated(
1326 uuid::Uuid::new_v4(),
1327 vec!["member".to_string()],
1328 HashMap::new(),
1329 );
1330 let payload = serde_json::json!({
1331 "jsonrpc": "2.0",
1332 "id": 3,
1333 "method": "tools/call",
1334 "params": {
1335 "name": "echo",
1336 "arguments": { "message": "hello" }
1337 }
1338 });
1339
1340 let response = mcp_post_handler(
1341 State(state),
1342 Extension(auth),
1343 Extension(TracingState::new()),
1344 Method::POST,
1345 headers,
1346 Json(payload),
1347 )
1348 .await;
1349
1350 assert_eq!(response.status(), StatusCode::OK);
1351 let body = response_json(response).await;
1352 assert_eq!(body["result"]["structuredContent"]["echoed"], "hello");
1353 assert_eq!(body["result"]["content"][0]["type"], "text");
1354 }
1355
1356 #[tokio::test]
1357 async fn test_tools_call_validation_failure_returns_is_error() {
1358 let mut registry = McpToolRegistry::new();
1359 registry.register::<EchoTool>();
1360
1361 let state = test_state_with_registry(
1362 McpConfig {
1363 enabled: true,
1364 ..Default::default()
1365 },
1366 registry,
1367 );
1368 let headers = initialized_headers(state.clone()).await;
1369 let auth = AuthContext::authenticated(
1370 uuid::Uuid::new_v4(),
1371 vec!["member".to_string()],
1372 HashMap::new(),
1373 );
1374 let payload = serde_json::json!({
1375 "jsonrpc": "2.0",
1376 "id": 4,
1377 "method": "tools/call",
1378 "params": {
1379 "name": "echo",
1380 "arguments": {}
1381 }
1382 });
1383
1384 let response = mcp_post_handler(
1385 State(state),
1386 Extension(auth),
1387 Extension(TracingState::new()),
1388 Method::POST,
1389 headers,
1390 Json(payload),
1391 )
1392 .await;
1393
1394 assert_eq!(response.status(), StatusCode::OK);
1395 let body = response_json(response).await;
1396 assert_eq!(body["result"]["isError"], true);
1397 }
1398
1399 #[tokio::test]
1400 async fn test_tools_call_requires_authentication() {
1401 let mut registry = McpToolRegistry::new();
1402 registry.register::<EchoTool>();
1403
1404 let state = test_state_with_registry(
1405 McpConfig {
1406 enabled: true,
1407 ..Default::default()
1408 },
1409 registry,
1410 );
1411 let headers = initialized_headers(state.clone()).await;
1412 let payload = serde_json::json!({
1413 "jsonrpc": "2.0",
1414 "id": 5,
1415 "method": "tools/call",
1416 "params": {
1417 "name": "echo",
1418 "arguments": { "message": "hello" }
1419 }
1420 });
1421
1422 let response = mcp_post_handler(
1423 State(state),
1424 Extension(AuthContext::unauthenticated()),
1425 Extension(TracingState::new()),
1426 Method::POST,
1427 headers,
1428 Json(payload),
1429 )
1430 .await;
1431
1432 assert_eq!(response.status(), StatusCode::OK);
1433 let body = response_json(response).await;
1434 assert_eq!(body["error"]["code"], -32001);
1435 }
1436
1437 #[tokio::test]
1438 async fn test_tools_call_requires_role() {
1439 let mut registry = McpToolRegistry::new();
1440 registry.register::<AdminTool>();
1441
1442 let state = test_state_with_registry(
1443 McpConfig {
1444 enabled: true,
1445 ..Default::default()
1446 },
1447 registry,
1448 );
1449 let headers = initialized_headers(state.clone()).await;
1450 let auth = AuthContext::authenticated(
1451 uuid::Uuid::new_v4(),
1452 vec!["member".to_string()],
1453 HashMap::new(),
1454 );
1455 let payload = serde_json::json!({
1456 "jsonrpc": "2.0",
1457 "id": 6,
1458 "method": "tools/call",
1459 "params": {
1460 "name": "admin.echo",
1461 "arguments": { "message": "hello" }
1462 }
1463 });
1464
1465 let response = mcp_post_handler(
1466 State(state),
1467 Extension(auth),
1468 Extension(TracingState::new()),
1469 Method::POST,
1470 headers,
1471 Json(payload),
1472 )
1473 .await;
1474
1475 assert_eq!(response.status(), StatusCode::OK);
1476 let body = response_json(response).await;
1477 assert_eq!(body["error"]["code"], -32003);
1478 }
1479
1480 #[tokio::test]
1481 async fn test_invalid_protocol_header_returns_400() {
1482 let state = test_state(McpConfig {
1483 enabled: true,
1484 ..Default::default()
1485 });
1486 let session_id = initialize_session(state.clone()).await;
1487 let mut headers = HeaderMap::new();
1488 headers.insert(
1489 MCP_SESSION_HEADER,
1490 HeaderValue::from_str(&session_id).expect("valid"),
1491 );
1492 headers.insert(
1493 MCP_PROTOCOL_HEADER,
1494 HeaderValue::from_static("invalid-version"),
1495 );
1496
1497 let payload = serde_json::json!({
1498 "jsonrpc": "2.0",
1499 "id": 7,
1500 "method": "tools/list",
1501 "params": {}
1502 });
1503
1504 let response = mcp_post_handler(
1505 State(state),
1506 Extension(AuthContext::unauthenticated()),
1507 Extension(TracingState::new()),
1508 Method::POST,
1509 headers,
1510 Json(payload),
1511 )
1512 .await;
1513 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1514 }
1515
1516 #[tokio::test]
1517 async fn test_expired_session_is_rejected_after_cleanup() {
1518 let state = test_state(McpConfig {
1519 enabled: true,
1520 ..Default::default()
1521 });
1522 let session_id = "expired-session".to_string();
1523 {
1524 let mut sessions = state.sessions.write().await;
1525 sessions.insert(
1526 session_id.clone(),
1527 McpSession {
1528 initialized: true,
1529 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
1530 expires_at: Instant::now() - Duration::from_secs(1),
1531 },
1532 );
1533 }
1534
1535 let mut headers = HeaderMap::new();
1536 headers.insert(
1537 MCP_SESSION_HEADER,
1538 HeaderValue::from_str(&session_id).expect("valid session id"),
1539 );
1540 headers.insert(
1541 MCP_PROTOCOL_HEADER,
1542 HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1543 );
1544
1545 let payload = serde_json::json!({
1546 "jsonrpc": "2.0",
1547 "id": 10,
1548 "method": "tools/list",
1549 "params": {}
1550 });
1551
1552 let response = mcp_post_handler(
1553 State(state),
1554 Extension(AuthContext::unauthenticated()),
1555 Extension(TracingState::new()),
1556 Method::POST,
1557 headers,
1558 Json(payload),
1559 )
1560 .await;
1561
1562 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1563 let body = response_json(response).await;
1564 assert_eq!(body["error"]["code"], -32600);
1565 assert_eq!(
1566 body["error"]["message"],
1567 "Unknown MCP session. Re-initialize."
1568 );
1569 }
1570
1571 #[tokio::test]
1572 async fn test_missing_protocol_header_returns_400() {
1573 let state = test_state(McpConfig {
1574 enabled: true,
1575 ..Default::default()
1576 });
1577 let session_id = initialize_session(state.clone()).await;
1578 let mut headers = HeaderMap::new();
1579 headers.insert(
1580 MCP_SESSION_HEADER,
1581 HeaderValue::from_str(&session_id).expect("valid"),
1582 );
1583
1584 let payload = serde_json::json!({
1585 "jsonrpc": "2.0",
1586 "id": 8,
1587 "method": "tools/list",
1588 "params": {}
1589 });
1590
1591 let response = mcp_post_handler(
1592 State(state),
1593 Extension(AuthContext::unauthenticated()),
1594 Extension(TracingState::new()),
1595 Method::POST,
1596 headers,
1597 Json(payload),
1598 )
1599 .await;
1600 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1601 }
1602
1603 #[tokio::test]
1604 async fn test_notifications_return_202() {
1605 let state = test_state(McpConfig {
1606 enabled: true,
1607 ..Default::default()
1608 });
1609 let mut headers = HeaderMap::new();
1610 headers.insert(
1611 MCP_PROTOCOL_HEADER,
1612 HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1613 );
1614 let payload = serde_json::json!({
1615 "jsonrpc": "2.0",
1616 "method": "notifications/tools/list_changed",
1617 "params": {}
1618 });
1619 let response = mcp_post_handler(
1620 State(state),
1621 Extension(AuthContext::unauthenticated()),
1622 Extension(TracingState::new()),
1623 Method::POST,
1624 headers,
1625 Json(payload),
1626 )
1627 .await;
1628 assert_eq!(response.status(), StatusCode::ACCEPTED);
1629 }
1630
1631 #[tokio::test]
1632 async fn test_invalid_origin_rejected() {
1633 let state = test_state(McpConfig {
1634 enabled: true,
1635 allowed_origins: vec!["https://allowed.example".to_string()],
1636 ..Default::default()
1637 });
1638 let payload = serde_json::json!({
1639 "jsonrpc": "2.0",
1640 "id": 1,
1641 "method": "initialize",
1642 "params": {
1643 "protocolVersion": "2025-11-25",
1644 "capabilities": {},
1645 "clientInfo": { "name": "test", "version": "1.0.0" }
1646 }
1647 });
1648
1649 let mut headers = HeaderMap::new();
1650 headers.insert("origin", HeaderValue::from_static("https://evil.example"));
1651
1652 let response = mcp_post_handler(
1653 State(state),
1654 Extension(AuthContext::unauthenticated()),
1655 Extension(TracingState::new()),
1656 Method::POST,
1657 headers,
1658 Json(payload),
1659 )
1660 .await;
1661
1662 assert_eq!(response.status(), StatusCode::FORBIDDEN);
1663 }
1664}