Skip to main content

forge_runtime/gateway/
mcp.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::{Duration, Instant};
6
7use axum::extract::{Extension, State};
8use axum::http::header::{HeaderName, HeaderValue};
9use axum::http::{HeaderMap, Method, StatusCode};
10use axum::response::IntoResponse;
11use axum::response::Response;
12use axum::response::sse::{Event, KeepAlive, Sse};
13use axum::{Json, body::Body};
14use forge_core::config::McpConfig;
15use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
16use forge_core::mcp::McpToolContext;
17use forge_core::rate_limit::RateLimitKey;
18use futures_util::Stream;
19use serde_json::Value;
20use tokio::sync::RwLock;
21
22use crate::mcp::McpToolRegistry;
23use crate::rate_limit::RateLimiter;
24
/// MCP protocol revisions accepted by this server, newest first.
const SUPPORTED_VERSIONS: &[&str] = &["2025-11-25", "2025-03-26", "2024-11-05"];

/// Newest protocol revision; only compiled for the test suite.
#[cfg(test)]
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";

/// Header carrying the server-assigned session id (lower-case so it is a
/// valid argument to `HeaderName::from_static`).
const MCP_SESSION_HEADER: &str = "mcp-session-id";

/// Header carrying the negotiated MCP protocol revision.
const MCP_PROTOCOL_HEADER: &str = "mcp-protocol-version";

/// Maximum number of tools returned per `tools/list` page.
const DEFAULT_PAGE_SIZE: usize = 50;

/// Upper bound on concurrently tracked sessions, guarding against unbounded
/// growth of the in-memory session table.
const MAX_MCP_SESSIONS: usize = 10_000;
32type ResponseError = Box<Response>;
33
/// Server-side bookkeeping for one MCP session created by `initialize`.
#[derive(Debug, Clone)]
struct McpSession {
    /// Becomes `true` once the client sends `notifications/initialized`.
    initialized: bool,
    /// Protocol revision recorded for this session at `initialize` time.
    protocol_version: String,
    /// Deadline after which the session is treated as expired.
    expires_at: Instant,
}
40
41#[derive(Clone)]
42pub struct McpState {
43    config: McpConfig,
44    registry: McpToolRegistry,
45    pool: sqlx::PgPool,
46    sessions: Arc<RwLock<HashMap<String, McpSession>>>,
47    job_dispatcher: Option<Arc<dyn JobDispatch>>,
48    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
49    rate_limiter: Arc<RateLimiter>,
50}
51
52impl McpState {
53    pub fn new(
54        config: McpConfig,
55        registry: McpToolRegistry,
56        pool: sqlx::PgPool,
57        job_dispatcher: Option<Arc<dyn JobDispatch>>,
58        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
59    ) -> Self {
60        Self {
61            config,
62            registry,
63            pool: pool.clone(),
64            sessions: Arc::new(RwLock::new(HashMap::new())),
65            job_dispatcher,
66            workflow_dispatcher,
67            rate_limiter: Arc::new(RateLimiter::new(pool)),
68        }
69    }
70
71    async fn cleanup_expired_sessions(&self) {
72        let mut sessions = self.sessions.write().await;
73        let now = Instant::now();
74        sessions.retain(|_, session| session.expires_at > now);
75    }
76
77    async fn touch_session(&self, session_id: &str) {
78        let mut sessions = self.sessions.write().await;
79        if let Some(session) = sessions.get_mut(session_id) {
80            session.expires_at = Instant::now() + Duration::from_secs(self.config.session_ttl_secs);
81        }
82    }
83
84    fn session_ttl(&self) -> Duration {
85        Duration::from_secs(self.config.session_ttl_secs)
86    }
87}
88
89/// Wraps an mpsc::Receiver as a Stream for MCP SSE.
90struct McpReceiverStream {
91    rx: tokio::sync::mpsc::Receiver<Result<Event, std::convert::Infallible>>,
92}
93
94impl Stream for McpReceiverStream {
95    type Item = Result<Event, std::convert::Infallible>;
96    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
97        self.rx.poll_recv(cx)
98    }
99}
100
101/// MCP Streamable HTTP GET handler.
102///
103/// Opens a Server-Sent Events stream for server-initiated messages.
104/// Clients use this to receive notifications and asynchronous responses
105/// from the MCP server. The stream starts with an `endpoint` event
106/// containing the session ID, then sends keepalive pings every 30 seconds.
107pub async fn mcp_get_handler(State(state): State<Arc<McpState>>, headers: HeaderMap) -> Response {
108    if let Err(resp) = validate_origin(&headers, &state.config) {
109        return *resp;
110    }
111
112    if let Err(resp) = enforce_protocol_header(&state.config, &headers) {
113        return *resp;
114    }
115
116    // Require a valid session (created via POST initialize)
117    let session_id = match required_session_id(&state, &headers, true).await {
118        Ok(v) => v,
119        Err(resp) => return resp,
120    };
121
122    state.touch_session(&session_id).await;
123
124    // Create a channel for server-to-client messages
125    let (tx, rx) = tokio::sync::mpsc::channel::<Result<Event, std::convert::Infallible>>(32);
126
127    // Send the initial endpoint event with the session binding
128    let session_id_clone = session_id.clone();
129    tokio::spawn(async move {
130        let endpoint_data = serde_json::json!({
131            "sessionId": session_id_clone,
132        });
133        let _ = tx
134            .send(Ok(Event::default()
135                .event("endpoint")
136                .data(endpoint_data.to_string())))
137            .await;
138
139        // Keep channel open until client disconnects.
140        // The SSE keepalive mechanism handles pings.
141        // When the client disconnects, the tx will be dropped.
142        loop {
143            tokio::time::sleep(Duration::from_secs(30)).await;
144            if tx.is_closed() {
145                break;
146            }
147        }
148    });
149
150    let stream = McpReceiverStream { rx };
151
152    let mut response = Sse::new(stream)
153        .keep_alive(KeepAlive::new().interval(Duration::from_secs(30)))
154        .into_response();
155
156    // Set MCP session header on the response
157    if let Ok(val) = HeaderValue::from_str(&session_id) {
158        response
159            .headers_mut()
160            .insert(HeaderName::from_static(MCP_SESSION_HEADER), val);
161    }
162
163    response
164}
165
166pub async fn mcp_post_handler(
167    State(state): State<Arc<McpState>>,
168    Extension(auth): Extension<AuthContext>,
169    Extension(tracing): Extension<super::tracing::TracingState>,
170    method: Method,
171    headers: HeaderMap,
172    Json(payload): Json<Value>,
173) -> Response {
174    if method != Method::POST {
175        return (
176            StatusCode::METHOD_NOT_ALLOWED,
177            Json(json_rpc_error(None, -32601, "Only POST is supported", None)),
178        )
179            .into_response();
180    }
181
182    if let Err(resp) = validate_origin(&headers, &state.config) {
183        return *resp;
184    }
185
186    state.cleanup_expired_sessions().await;
187
188    let Some(method_name) = payload.get("method").and_then(Value::as_str) else {
189        // Notifications / responses sent by client should get 202 when accepted.
190        if payload.get("id").is_some()
191            && (payload.get("result").is_some() || payload.get("error").is_some())
192        {
193            return StatusCode::ACCEPTED.into_response();
194        }
195        return (
196            StatusCode::BAD_REQUEST,
197            Json(json_rpc_error(
198                None,
199                -32600,
200                "Invalid JSON-RPC payload",
201                None,
202            )),
203        )
204            .into_response();
205    };
206
207    let id = payload.get("id").cloned();
208    let params = payload
209        .get("params")
210        .cloned()
211        .unwrap_or(Value::Object(Default::default()));
212
213    // Notification flow
214    if id.is_none() {
215        return handle_notification(&state, method_name, params, &headers).await;
216    }
217
218    // Request flow
219    if method_name != "initialize"
220        && let Err(resp) = enforce_protocol_header(&state.config, &headers)
221    {
222        return *resp;
223    }
224
225    match method_name {
226        "initialize" => handle_initialize(&state, id, &params).await,
227        "tools/list" => {
228            let session_id = match required_session_id(&state, &headers, true).await {
229                Ok(v) => v,
230                Err(resp) => return resp,
231            };
232            state.touch_session(&session_id).await;
233            handle_tools_list(&state, id, &params)
234        }
235        "tools/call" => {
236            let session_id = match required_session_id(&state, &headers, true).await {
237                Ok(v) => v,
238                Err(resp) => return resp,
239            };
240            state.touch_session(&session_id).await;
241
242            let metadata = build_request_metadata(&tracing, &headers);
243            handle_tools_call(&state, id, &params, &auth, metadata).await
244        }
245        _ => (
246            StatusCode::OK,
247            Json(json_rpc_error(id, -32601, "Method not found", None)),
248        )
249            .into_response(),
250    }
251}
252
253async fn handle_notification(
254    state: &Arc<McpState>,
255    method_name: &str,
256    _params: Value,
257    headers: &HeaderMap,
258) -> Response {
259    if let Err(resp) = enforce_protocol_header(&state.config, headers) {
260        return *resp;
261    }
262
263    match method_name {
264        "notifications/initialized" => {
265            let session_id = match required_session_id(state, headers, false).await {
266                Ok(v) => v,
267                Err(resp) => return resp,
268            };
269
270            let mut sessions = state.sessions.write().await;
271            if let Some(session) = sessions.get_mut(&session_id) {
272                session.initialized = true;
273                session.expires_at = Instant::now() + state.session_ttl();
274                return StatusCode::ACCEPTED.into_response();
275            }
276
277            (
278                StatusCode::BAD_REQUEST,
279                Json(json_rpc_error(
280                    None,
281                    -32600,
282                    "Unknown MCP session. Re-initialize the connection.",
283                    None,
284                )),
285            )
286                .into_response()
287        }
288        _ => StatusCode::ACCEPTED.into_response(),
289    }
290}
291
292async fn handle_initialize(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
293    let Some(requested_version) = params.get("protocolVersion").and_then(Value::as_str) else {
294        return (
295            StatusCode::OK,
296            Json(json_rpc_error(
297                id,
298                -32602,
299                "Missing protocolVersion in initialize params",
300                None,
301            )),
302        )
303            .into_response();
304    };
305
306    if !SUPPORTED_VERSIONS.contains(&requested_version) {
307        return (
308            StatusCode::OK,
309            Json(json_rpc_error(
310                id,
311                -32602,
312                "Unsupported protocolVersion",
313                Some(serde_json::json!({
314                    "supported": SUPPORTED_VERSIONS
315                })),
316            )),
317        )
318            .into_response();
319    }
320
321    let session_id = uuid::Uuid::new_v4().to_string();
322    {
323        let mut sessions = state.sessions.write().await;
324        // Enforce session limit to prevent memory exhaustion DoS
325        if sessions.len() >= MAX_MCP_SESSIONS {
326            return (
327                StatusCode::SERVICE_UNAVAILABLE,
328                Json(json_rpc_error(
329                    id,
330                    -32000,
331                    "Server at MCP session capacity",
332                    None,
333                )),
334            )
335                .into_response();
336        }
337        sessions.insert(
338            session_id.clone(),
339            McpSession {
340                initialized: false,
341                protocol_version: requested_version.to_string(),
342                expires_at: Instant::now() + state.session_ttl(),
343            },
344        );
345    }
346
347    let mut response = (
348        StatusCode::OK,
349        Json(json_rpc_success(
350            id,
351            serde_json::json!({
352                "protocolVersion": requested_version,
353                "capabilities": {
354                    "tools": {
355                        "listChanged": false
356                    }
357                },
358                "serverInfo": {
359                    "name": "forge",
360                    "version": env!("CARGO_PKG_VERSION")
361                }
362            }),
363        )),
364    )
365        .into_response();
366
367    set_header(&mut response, MCP_SESSION_HEADER, &session_id);
368    set_header(&mut response, MCP_PROTOCOL_HEADER, requested_version);
369    response
370}
371
372fn handle_tools_list(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
373    let cursor = params.get("cursor").and_then(Value::as_str);
374    let start = match cursor {
375        Some(c) => match c.parse::<usize>() {
376            Ok(v) => v,
377            Err(_) => {
378                return (
379                    StatusCode::OK,
380                    Json(json_rpc_error(
381                        id,
382                        -32602,
383                        "Invalid cursor in tools/list request",
384                        None,
385                    )),
386                )
387                    .into_response();
388            }
389        },
390        None => 0,
391    };
392
393    let mut tools: Vec<_> = state.registry.list().collect();
394    tools.sort_by(|a, b| a.info.name.cmp(b.info.name));
395
396    let page: Vec<_> = tools
397        .iter()
398        .skip(start)
399        .take(DEFAULT_PAGE_SIZE)
400        .map(|entry| {
401            // Only include annotation fields that are set (non-null).
402            // Claude Code's MCP client rejects null booleans.
403            let mut annotations = serde_json::Map::new();
404            if let Some(title) = &entry.info.annotations.title {
405                annotations.insert("title".into(), serde_json::Value::String(title.to_string()));
406            }
407            if let Some(v) = entry.info.annotations.read_only_hint {
408                annotations.insert("readOnlyHint".into(), serde_json::Value::Bool(v));
409            }
410            if let Some(v) = entry.info.annotations.destructive_hint {
411                annotations.insert("destructiveHint".into(), serde_json::Value::Bool(v));
412            }
413            if let Some(v) = entry.info.annotations.idempotent_hint {
414                annotations.insert("idempotentHint".into(), serde_json::Value::Bool(v));
415            }
416            if let Some(v) = entry.info.annotations.open_world_hint {
417                annotations.insert("openWorldHint".into(), serde_json::Value::Bool(v));
418            }
419
420            let mut value = serde_json::json!({
421                "name": entry.info.name,
422                "description": entry.info.description,
423                "inputSchema": entry.input_schema,
424            });
425            // json!({}) always produces Value::Object
426            let obj = value.as_object_mut().expect("json! object literal");
427
428            if let Some(title) = &entry.info.title {
429                obj.insert("title".into(), serde_json::Value::String(title.to_string()));
430            }
431            if !annotations.is_empty() {
432                obj.insert("annotations".into(), serde_json::Value::Object(annotations));
433            }
434            if !entry.info.icons.is_empty() {
435                let icons: Vec<_> = entry
436                    .info
437                    .icons
438                    .iter()
439                    .map(|icon| {
440                        serde_json::json!({
441                            "src": icon.src,
442                            "mimeType": icon.mime_type,
443                            "sizes": icon.sizes,
444                            "theme": icon.theme
445                        })
446                    })
447                    .collect();
448                obj.insert("icons".into(), serde_json::Value::Array(icons));
449            }
450            if let Some(output_schema) = &entry.output_schema {
451                // MCP spec requires outputSchema to have type: "object".
452                // schemars may generate type: "array" (Vec<T>) or anyOf (Option<T>).
453                // Wrap non-object schemas so they conform.
454                let schema = normalize_output_schema(output_schema);
455                obj.insert("outputSchema".into(), schema);
456            }
457            value
458        })
459        .collect();
460
461    let end = start.saturating_add(page.len());
462
463    // Build result: omit nextCursor when null (Claude Code expects string or absent)
464    let mut result = serde_json::json!({ "tools": page });
465    if end < tools.len() && result.is_object() {
466        // json!({}) always produces Value::Object
467        result
468            .as_object_mut()
469            .expect("json! object literal")
470            .insert(
471                "nextCursor".into(),
472                serde_json::Value::String(end.to_string()),
473            );
474    }
475
476    (StatusCode::OK, Json(json_rpc_success(id, result))).into_response()
477}
478
479/// MCP spec requires outputSchema to be `type: "object"`. Wrap schemas that
480/// schemars generates as arrays or union types (anyOf/oneOf for Option<T>).
481fn normalize_output_schema(schema: &Value) -> Value {
482    let type_str = schema.get("type").and_then(Value::as_str).unwrap_or("");
483    if type_str == "object" {
484        return schema.clone();
485    }
486
487    // Wrap non-object schemas: put the original under a "result" property
488    let mut wrapper = serde_json::json!({
489        "type": "object",
490        "properties": {
491            "result": schema
492        }
493    });
494
495    // Hoist $schema and definitions to the wrapper level
496    if let (Some(s), Some(obj)) = (schema.get("$schema"), wrapper.as_object_mut()) {
497        obj.insert("$schema".into(), s.clone());
498    }
499    if let (Some(d), Some(obj)) = (schema.get("definitions"), wrapper.as_object_mut()) {
500        obj.insert("definitions".into(), d.clone());
501        // Remove from the nested copy to avoid duplication
502        if let Some(inner) = wrapper.pointer_mut("/properties/result") {
503            inner.as_object_mut().map(|o| o.remove("definitions"));
504        }
505    }
506
507    wrapper
508}
509
510async fn handle_tools_call(
511    state: &Arc<McpState>,
512    id: Option<Value>,
513    params: &Value,
514    auth: &AuthContext,
515    request_metadata: RequestMetadata,
516) -> Response {
517    let Some(tool_name) = params.get("name").and_then(Value::as_str) else {
518        return (
519            StatusCode::OK,
520            Json(json_rpc_error(id, -32602, "Missing tool name", None)),
521        )
522            .into_response();
523    };
524
525    let Some(entry) = state.registry.get(tool_name) else {
526        return (
527            StatusCode::OK,
528            Json(json_rpc_error(id, -32602, "Unknown tool", None)),
529        )
530            .into_response();
531    };
532
533    if !entry.info.is_public && !auth.is_authenticated() {
534        if state.config.oauth {
535            // Return HTTP 401 with discovery header so MCP clients trigger OAuth flow
536            let mut response = (
537                StatusCode::UNAUTHORIZED,
538                Json(json_rpc_error(id, -32001, "Authentication required", None)),
539            )
540                .into_response();
541            response.headers_mut().insert(
542                "WWW-Authenticate",
543                axum::http::header::HeaderValue::from_static(
544                    "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\"",
545                ),
546            );
547            return response;
548        }
549        return (
550            StatusCode::OK,
551            Json(json_rpc_error(id, -32001, "Authentication required", None)),
552        )
553            .into_response();
554    }
555    if let Some(role) = entry.info.required_role
556        && !auth.has_role(role)
557    {
558        return (
559            StatusCode::OK,
560            Json(json_rpc_error(
561                id,
562                -32003,
563                format!("Role '{}' required", role),
564                None,
565            )),
566        )
567            .into_response();
568    }
569
570    if let (Some(requests), Some(per_secs)) = (
571        entry.info.rate_limit_requests,
572        entry.info.rate_limit_per_secs,
573    ) {
574        let key_type = entry
575            .info
576            .rate_limit_key
577            .and_then(|k| k.parse::<RateLimitKey>().ok())
578            .unwrap_or_default();
579
580        let config = forge_core::RateLimitConfig::new(requests, Duration::from_secs(per_secs))
581            .with_key(key_type);
582        let bucket_key = state
583            .rate_limiter
584            .build_key(key_type, tool_name, auth, &request_metadata);
585
586        if let Err(e) = state.rate_limiter.enforce(&bucket_key, &config).await {
587            return (
588                StatusCode::OK,
589                Json(json_rpc_error(id, -32029, e.to_string(), None)),
590            )
591                .into_response();
592        }
593    }
594
595    let args = params
596        .get("arguments")
597        .cloned()
598        .unwrap_or(Value::Object(Default::default()));
599
600    let ctx = McpToolContext::with_dispatch(
601        state.pool.clone(),
602        auth.clone(),
603        request_metadata,
604        state.job_dispatcher.clone(),
605        state.workflow_dispatcher.clone(),
606    );
607
608    let result = if let Some(timeout_secs) = entry.info.timeout {
609        match tokio::time::timeout(
610            Duration::from_secs(timeout_secs),
611            (entry.handler)(&ctx, args),
612        )
613        .await
614        {
615            Ok(inner) => inner,
616            Err(_) => {
617                return (
618                    StatusCode::OK,
619                    Json(json_rpc_error(id, -32000, "Tool timed out", None)),
620                )
621                    .into_response();
622            }
623        }
624    } else {
625        (entry.handler)(&ctx, args).await
626    };
627
628    match result {
629        Ok(output) => {
630            let result = tool_success_result(output);
631            (
632                StatusCode::OK,
633                Json(json_rpc_success(id, serde_json::json!(result))),
634            )
635                .into_response()
636        }
637        Err(e) => match e {
638            forge_core::ForgeError::Validation(msg)
639            | forge_core::ForgeError::InvalidArgument(msg) => (
640                StatusCode::OK,
641                Json(json_rpc_success(
642                    id,
643                    serde_json::json!({
644                        "content": [{ "type": "text", "text": msg }],
645                        "isError": true
646                    }),
647                )),
648            )
649                .into_response(),
650            forge_core::ForgeError::Unauthorized(msg) => {
651                (StatusCode::OK, Json(json_rpc_error(id, -32001, msg, None))).into_response()
652            }
653            forge_core::ForgeError::Forbidden(msg) => {
654                (StatusCode::OK, Json(json_rpc_error(id, -32003, msg, None))).into_response()
655            }
656            _ => (
657                StatusCode::OK,
658                Json(json_rpc_error(id, -32603, "Internal server error", None)),
659            )
660                .into_response(),
661        },
662    }
663}
664
665fn tool_success_result(output: Value) -> Value {
666    match output {
667        Value::Object(_) => serde_json::json!({
668            "content": [{
669                "type": "text",
670                "text": serde_json::to_string(&output).unwrap_or_else(|_| "{}".to_string())
671            }],
672            "structuredContent": output
673        }),
674        Value::String(text) => serde_json::json!({
675            "content": [{ "type": "text", "text": text }]
676        }),
677        other => serde_json::json!({
678            "content": [{
679                "type": "text",
680                "text": serde_json::to_string(&other).unwrap_or_else(|_| "null".to_string())
681            }]
682        }),
683    }
684}
685
686async fn required_session_id(
687    state: &Arc<McpState>,
688    headers: &HeaderMap,
689    require_initialized: bool,
690) -> std::result::Result<String, Response> {
691    let Some(session_id) = headers
692        .get(MCP_SESSION_HEADER)
693        .and_then(|v| v.to_str().ok())
694    else {
695        return Err((
696            StatusCode::BAD_REQUEST,
697            Json(json_rpc_error(
698                None,
699                -32600,
700                "Missing MCP-Session-Id header",
701                None,
702            )),
703        )
704            .into_response());
705    };
706
707    let sessions = state.sessions.read().await;
708    match sessions.get(session_id) {
709        Some(session) => {
710            if !SUPPORTED_VERSIONS.contains(&session.protocol_version.as_str()) {
711                return Err((
712                    StatusCode::BAD_REQUEST,
713                    Json(json_rpc_error(
714                        None,
715                        -32600,
716                        "Session protocol version mismatch",
717                        None,
718                    )),
719                )
720                    .into_response());
721            }
722            if require_initialized && !session.initialized {
723                return Err((
724                    StatusCode::BAD_REQUEST,
725                    Json(json_rpc_error(
726                        None,
727                        -32600,
728                        "MCP session is not initialized",
729                        None,
730                    )),
731                )
732                    .into_response());
733            }
734            Ok(session_id.to_string())
735        }
736        None => Err((
737            StatusCode::BAD_REQUEST,
738            Json(json_rpc_error(
739                None,
740                -32600,
741                "Unknown MCP session. Re-initialize.",
742                None,
743            )),
744        )
745            .into_response()),
746    }
747}
748
749fn validate_origin(
750    headers: &HeaderMap,
751    config: &McpConfig,
752) -> std::result::Result<(), ResponseError> {
753    let Some(origin) = headers.get("origin").and_then(|v| v.to_str().ok()) else {
754        return Ok(());
755    };
756
757    // When no allowed_origins are configured, reject cross-origin requests
758    // rather than allowing all origins (secure by default)
759    if config.allowed_origins.is_empty() {
760        return Err(Box::new(
761            (
762                StatusCode::FORBIDDEN,
763                Json(json_rpc_error(
764                    None,
765                    -32600,
766                    "Cross-origin requests require allowed_origins to be configured",
767                    None,
768                )),
769            )
770                .into_response(),
771        ));
772    }
773
774    let allowed = config
775        .allowed_origins
776        .iter()
777        .any(|candidate| candidate == "*" || candidate.eq_ignore_ascii_case(origin));
778    if allowed {
779        return Ok(());
780    }
781
782    Err(Box::new(
783        (
784            StatusCode::FORBIDDEN,
785            Json(json_rpc_error(None, -32600, "Invalid Origin header", None)),
786        )
787            .into_response(),
788    ))
789}
790
791fn enforce_protocol_header(
792    config: &McpConfig,
793    headers: &HeaderMap,
794) -> std::result::Result<(), ResponseError> {
795    if !config.require_protocol_version_header {
796        return Ok(());
797    }
798
799    let Some(version) = headers
800        .get(MCP_PROTOCOL_HEADER)
801        .and_then(|v| v.to_str().ok())
802    else {
803        return Err(Box::new(
804            (
805                StatusCode::BAD_REQUEST,
806                Json(json_rpc_error(
807                    None,
808                    -32600,
809                    "Missing MCP-Protocol-Version header",
810                    None,
811                )),
812            )
813                .into_response(),
814        ));
815    };
816
817    if !SUPPORTED_VERSIONS.contains(&version) {
818        return Err(Box::new(
819            (
820                StatusCode::BAD_REQUEST,
821                Json(json_rpc_error(
822                    None,
823                    -32600,
824                    "Unsupported MCP-Protocol-Version",
825                    Some(serde_json::json!({ "supported": SUPPORTED_VERSIONS })),
826                )),
827            )
828                .into_response(),
829        ));
830    }
831
832    Ok(())
833}
834
835use super::extract_client_ip;
836
837fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
838    headers
839        .get(axum::http::header::USER_AGENT)
840        .and_then(|v| v.to_str().ok())
841        .map(String::from)
842}
843
844fn build_request_metadata(
845    tracing: &super::tracing::TracingState,
846    headers: &HeaderMap,
847) -> RequestMetadata {
848    RequestMetadata {
849        request_id: uuid::Uuid::parse_str(&tracing.request_id)
850            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
851        trace_id: tracing.trace_id.clone(),
852        client_ip: extract_client_ip(headers),
853        user_agent: extract_user_agent(headers),
854        correlation_id: None,
855        timestamp: chrono::Utc::now(),
856    }
857}
858
859fn json_rpc_success(id: Option<Value>, result: Value) -> Value {
860    serde_json::json!({
861        "jsonrpc": "2.0",
862        "id": id.unwrap_or(Value::Null),
863        "result": result
864    })
865}
866
867fn json_rpc_error(
868    id: Option<Value>,
869    code: i32,
870    message: impl Into<String>,
871    data: Option<Value>,
872) -> Value {
873    let mut error = serde_json::json!({
874        "code": code,
875        "message": message.into()
876    });
877    if let Some(data) = data
878        && let Some(obj) = error.as_object_mut()
879    {
880        obj.insert("data".to_string(), data);
881    }
882
883    serde_json::json!({
884        "jsonrpc": "2.0",
885        "id": id.unwrap_or(Value::Null),
886        "error": error
887    })
888}
889
890fn set_header(response: &mut Response<Body>, name: &str, value: &str) {
891    if let (Ok(name), Ok(value)) = (HeaderName::try_from(name), HeaderValue::from_str(value)) {
892        response.headers_mut().insert(name, value);
893    }
894}
895
#[cfg(test)]
// Tests assert through `expect`/indexing on JSON values; panics are the failure signal.
#[allow(clippy::expect_used, clippy::indexing_slicing, clippy::unwrap_used)]
mod tests {
    use super::super::tracing::TracingState;
    use super::*;
    use axum::body::to_bytes;
    use forge_core::function::AuthContext;
    use forge_core::mcp::{ForgeMcpTool, McpToolAnnotations, McpToolInfo};
    use forge_core::schemars::{self, JsonSchema};
    use serde::{Deserialize, Serialize};
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;

    /// Boxed future returned by `ForgeMcpTool::execute` in the fixtures below.
    type ToolFuture<'a, T> = Pin<Box<dyn Future<Output = forge_core::Result<T>> + Send + 'a>>;

    // ---------------------------------------------------------------------
    // Fixture argument / output types
    // ---------------------------------------------------------------------

    /// Input for the echo-style tools: a single message string.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct EchoArgs {
        message: String,
    }

    /// Output of the echo-style tools: the message handed back unchanged.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct EchoOutput {
        echoed: String,
    }

    /// Enum used to check that variant names surface in the generated schema.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    #[serde(rename_all = "snake_case")]
    enum ExportFormat {
        Json,
        Csv,
    }

    /// Input carrying a field description and an enum, for schema metadata tests.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct MetadataArgs {
        #[schemars(description = "Project UUID to export")]
        project_id: String,
        format: ExportFormat,
    }

    /// Trivial output for the metadata tool.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct MetadataOutput {
        accepted: bool,
    }

    // ---------------------------------------------------------------------
    // Fixture tools
    // ---------------------------------------------------------------------

    /// Non-public tool with no role requirement that echoes its input.
    struct EchoTool;

    impl ForgeMcpTool for EchoTool {
        type Args = EchoArgs;
        type Output = EchoOutput;

        fn info() -> McpToolInfo {
            McpToolInfo {
                name: "echo",
                title: Some("Echo"),
                description: Some("Echo back the message"),
                required_role: None,
                is_public: false,
                timeout: None,
                rate_limit_requests: None,
                rate_limit_per_secs: None,
                rate_limit_key: None,
                annotations: McpToolAnnotations::default(),
                icons: &[],
            }
        }

        fn execute(_ctx: &McpToolContext, args: Self::Args) -> ToolFuture<'_, Self::Output> {
            let echoed = args.message;
            Box::pin(async move { Ok(EchoOutput { echoed }) })
        }
    }

    /// Echo tool gated behind the `admin` role.
    struct AdminTool;

    impl ForgeMcpTool for AdminTool {
        type Args = EchoArgs;
        type Output = EchoOutput;

        fn info() -> McpToolInfo {
            McpToolInfo {
                name: "admin.echo",
                title: Some("Admin Echo"),
                description: Some("Admin only echo"),
                required_role: Some("admin"),
                is_public: false,
                timeout: None,
                rate_limit_requests: None,
                rate_limit_per_secs: None,
                rate_limit_key: None,
                annotations: McpToolAnnotations::default(),
                icons: &[],
            }
        }

        fn execute(_ctx: &McpToolContext, args: Self::Args) -> ToolFuture<'_, Self::Output> {
            let echoed = args.message;
            Box::pin(async move { Ok(EchoOutput { echoed }) })
        }
    }

    /// Tool whose argument schema carries descriptions and enum variants.
    struct MetadataTool;

    impl ForgeMcpTool for MetadataTool {
        type Args = MetadataArgs;
        type Output = MetadataOutput;

        fn info() -> McpToolInfo {
            McpToolInfo {
                name: "export.project",
                title: Some("Export Project"),
                description: Some("Export project data"),
                required_role: None,
                is_public: false,
                timeout: None,
                rate_limit_requests: None,
                rate_limit_per_secs: None,
                rate_limit_key: None,
                annotations: McpToolAnnotations::default(),
                icons: &[],
            }
        }

        fn execute(_ctx: &McpToolContext, _args: Self::Args) -> ToolFuture<'_, Self::Output> {
            Box::pin(async move { Ok(MetadataOutput { accepted: true }) })
        }
    }

    // ---------------------------------------------------------------------
    // Shared helpers
    // ---------------------------------------------------------------------

    /// An `McpConfig` with MCP enabled and everything else defaulted.
    fn enabled_config() -> McpConfig {
        McpConfig {
            enabled: true,
            ..Default::default()
        }
    }

    /// State with an empty tool registry.
    fn test_state(config: McpConfig) -> Arc<McpState> {
        test_state_with_registry(config, McpToolRegistry::new())
    }

    /// State backed by a lazily-connecting pool, so no database is contacted
    /// unless a test actually issues a query.
    fn test_state_with_registry(config: McpConfig, registry: McpToolRegistry) -> Arc<McpState> {
        let lazy_pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("lazy pool must build");
        Arc::new(McpState::new(config, registry, lazy_pool, None, None))
    }

    /// An authenticated caller holding only the `member` role.
    fn member_auth() -> AuthContext {
        AuthContext::authenticated(
            uuid::Uuid::new_v4(),
            vec!["member".to_string()],
            HashMap::new(),
        )
    }

    /// Drives `mcp_post_handler` for a single POST request.
    async fn post(
        state: Arc<McpState>,
        auth: AuthContext,
        headers: HeaderMap,
        payload: Value,
    ) -> Response {
        mcp_post_handler(
            State(state),
            Extension(auth),
            Extension(TracingState::new()),
            Method::POST,
            headers,
            Json(payload),
        )
        .await
    }

    /// Builds a header map with the optional session-id and protocol-version headers.
    fn mcp_headers(session: Option<HeaderValue>, protocol: Option<HeaderValue>) -> HeaderMap {
        let mut headers = HeaderMap::new();
        if let Some(value) = session {
            headers.insert(MCP_SESSION_HEADER, value);
        }
        if let Some(value) = protocol {
            headers.insert(MCP_PROTOCOL_HEADER, value);
        }
        headers
    }

    /// A JSON-RPC request with an id.
    fn rpc_request(id: i64, method: &str, params: Value) -> Value {
        serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "method": method,
            "params": params
        })
    }

    /// A JSON-RPC notification (no id) with empty params.
    fn rpc_notification(method: &str) -> Value {
        serde_json::json!({
            "jsonrpc": "2.0",
            "method": method,
            "params": {}
        })
    }

    /// An `initialize` request advertising `protocol_version`.
    fn initialize_payload(protocol_version: &str) -> Value {
        rpc_request(
            1,
            "initialize",
            serde_json::json!({
                "protocolVersion": protocol_version,
                "capabilities": {},
                "clientInfo": { "name": "test", "version": "1.0.0" }
            }),
        )
    }

    /// A `tools/list` request with empty params.
    fn tools_list_request(id: i64) -> Value {
        rpc_request(id, "tools/list", serde_json::json!({}))
    }

    /// A `tools/call` request for tool `name` with the given `arguments`.
    fn tools_call_request(id: i64, name: &str, arguments: Value) -> Value {
        rpc_request(
            id,
            "tools/call",
            serde_json::json!({ "name": name, "arguments": arguments }),
        )
    }

    /// Reads the response body as JSON; an empty body decodes as `{}`.
    async fn response_json(response: Response) -> Value {
        let bytes = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body bytes");
        if bytes.is_empty() {
            serde_json::json!({})
        } else {
            serde_json::from_slice(&bytes).expect("valid json")
        }
    }

    /// Sends `initialize` and returns the session id from the response header.
    async fn initialize_session(state: Arc<McpState>) -> String {
        let response = post(
            state,
            AuthContext::unauthenticated(),
            HeaderMap::new(),
            initialize_payload("2025-11-25"),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let session_id = response
            .headers()
            .get(MCP_SESSION_HEADER)
            .and_then(|value| value.to_str().ok())
            .expect("session id must exist");
        session_id.to_string()
    }

    /// Sends `notifications/initialized` on the session described by `headers`.
    async fn mark_initialized(state: Arc<McpState>, headers: HeaderMap) {
        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            rpc_notification("notifications/initialized"),
        )
        .await;
        assert_eq!(response.status(), StatusCode::ACCEPTED);
    }

    /// Performs the full handshake and returns headers for the ready session.
    async fn initialized_headers(state: Arc<McpState>) -> HeaderMap {
        let session_id = initialize_session(Arc::clone(&state)).await;
        let headers = mcp_headers(
            Some(HeaderValue::from_str(&session_id).expect("valid session id header")),
            Some(HeaderValue::from_static(MCP_PROTOCOL_VERSION)),
        );
        mark_initialized(state, headers.clone()).await;
        headers
    }

    // ---------------------------------------------------------------------
    // Tests
    // ---------------------------------------------------------------------

    #[test]
    fn test_json_rpc_helpers() {
        let ok = json_rpc_success(
            Some(serde_json::json!(1)),
            serde_json::json!({ "ok": true }),
        );
        assert_eq!(ok["jsonrpc"], "2.0");
        assert!(ok.get("result").is_some());

        let failure = json_rpc_error(Some(serde_json::json!(1)), -32601, "not found", None);
        assert_eq!(failure["error"]["code"], -32601);
    }

    #[tokio::test]
    async fn test_initialize_sets_session_header() {
        let session = initialize_session(test_state(enabled_config())).await;
        assert!(!session.is_empty());
    }

    #[tokio::test]
    async fn test_initialize_rejects_unsupported_protocol_version() {
        let state = test_state(enabled_config());
        let response = post(
            state,
            AuthContext::unauthenticated(),
            HeaderMap::new(),
            initialize_payload("2024-01-01"),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        assert_eq!(body["error"]["code"], -32602);
        let supported = body["error"]["data"]["supported"]
            .as_array()
            .expect("supported versions array");
        assert!(
            supported
                .iter()
                .filter_map(Value::as_str)
                .any(|version| version == MCP_PROTOCOL_VERSION)
        );
    }

    #[tokio::test]
    async fn test_tools_list_requires_initialized_session() {
        let state = test_state(enabled_config());
        // Session exists but `notifications/initialized` was never sent.
        let session_id = initialize_session(Arc::clone(&state)).await;
        let headers = mcp_headers(
            Some(HeaderValue::from_str(&session_id).expect("valid")),
            Some(HeaderValue::from_static(MCP_PROTOCOL_VERSION)),
        );

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            tools_list_request(2),
        )
        .await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_tools_list_returns_registered_tools() {
        let mut registry = McpToolRegistry::new();
        registry.register::<EchoTool>();
        let state = test_state_with_registry(enabled_config(), registry);
        let headers = initialized_headers(Arc::clone(&state)).await;

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            tools_list_request(2),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        let tools = body["result"]["tools"]
            .as_array()
            .expect("tools list should be array");
        assert_eq!(tools.len(), 1);
        let tool = &tools[0];
        assert_eq!(tool["name"], "echo");
        assert!(tool.get("inputSchema").is_some());
        assert!(tool.get("outputSchema").is_some());
    }

    #[tokio::test]
    async fn test_tools_list_exposes_parameter_metadata() {
        let mut registry = McpToolRegistry::new();
        registry.register::<MetadataTool>();
        let state = test_state_with_registry(enabled_config(), registry);
        let headers = initialized_headers(Arc::clone(&state)).await;

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            tools_list_request(9),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        let tools = body["result"]["tools"]
            .as_array()
            .expect("tools list should be array");
        assert_eq!(tools.len(), 1);

        let input_schema = &tools[0]["inputSchema"];
        assert_eq!(
            input_schema["properties"]["project_id"]["description"],
            "Project UUID to export"
        );

        // Enum variants may be inlined or referenced via definitions; just
        // check that both serialized names appear somewhere in the schema.
        let rendered = input_schema.to_string();
        assert!(rendered.contains("\"json\""));
        assert!(rendered.contains("\"csv\""));
    }

    #[tokio::test]
    async fn test_tools_call_success_returns_structured_content() {
        let mut registry = McpToolRegistry::new();
        registry.register::<EchoTool>();
        let state = test_state_with_registry(enabled_config(), registry);
        let headers = initialized_headers(Arc::clone(&state)).await;

        let response = post(
            state,
            member_auth(),
            headers,
            tools_call_request(3, "echo", serde_json::json!({ "message": "hello" })),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        assert_eq!(body["result"]["structuredContent"]["echoed"], "hello");
        assert_eq!(body["result"]["content"][0]["type"], "text");
    }

    #[tokio::test]
    async fn test_tools_call_validation_failure_returns_is_error() {
        let mut registry = McpToolRegistry::new();
        registry.register::<EchoTool>();
        let state = test_state_with_registry(enabled_config(), registry);
        let headers = initialized_headers(Arc::clone(&state)).await;

        // `message` is required by `EchoArgs`, so empty arguments must fail validation.
        let response = post(
            state,
            member_auth(),
            headers,
            tools_call_request(4, "echo", serde_json::json!({})),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        assert_eq!(body["result"]["isError"], true);
    }

    #[tokio::test]
    async fn test_tools_call_requires_authentication() {
        let mut registry = McpToolRegistry::new();
        registry.register::<EchoTool>();
        let state = test_state_with_registry(enabled_config(), registry);
        let headers = initialized_headers(Arc::clone(&state)).await;

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            tools_call_request(5, "echo", serde_json::json!({ "message": "hello" })),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        assert_eq!(body["error"]["code"], -32001);
    }

    #[tokio::test]
    async fn test_tools_call_requires_role() {
        let mut registry = McpToolRegistry::new();
        registry.register::<AdminTool>();
        let state = test_state_with_registry(enabled_config(), registry);
        let headers = initialized_headers(Arc::clone(&state)).await;

        // A `member` caller must be refused by the `admin`-gated tool.
        let response = post(
            state,
            member_auth(),
            headers,
            tools_call_request(6, "admin.echo", serde_json::json!({ "message": "hello" })),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        assert_eq!(body["error"]["code"], -32003);
    }

    #[tokio::test]
    async fn test_invalid_protocol_header_returns_400() {
        let state = test_state(enabled_config());
        let session_id = initialize_session(Arc::clone(&state)).await;
        let headers = mcp_headers(
            Some(HeaderValue::from_str(&session_id).expect("valid")),
            Some(HeaderValue::from_static("invalid-version")),
        );

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            tools_list_request(7),
        )
        .await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_expired_session_is_rejected_after_cleanup() {
        let state = test_state(enabled_config());
        let session_id = String::from("expired-session");
        // Seed an already-expired but otherwise fully initialized session.
        state.sessions.write().await.insert(
            session_id.clone(),
            McpSession {
                initialized: true,
                protocol_version: MCP_PROTOCOL_VERSION.to_string(),
                expires_at: Instant::now() - Duration::from_secs(1),
            },
        );

        let headers = mcp_headers(
            Some(HeaderValue::from_str(&session_id).expect("valid session id")),
            Some(HeaderValue::from_static(MCP_PROTOCOL_VERSION)),
        );

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            tools_list_request(10),
        )
        .await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        let body = response_json(response).await;
        assert_eq!(body["error"]["code"], -32600);
        assert_eq!(
            body["error"]["message"],
            "Unknown MCP session. Re-initialize."
        );
    }

    #[tokio::test]
    async fn test_missing_protocol_header_returns_400() {
        let state = test_state(enabled_config());
        let session_id = initialize_session(Arc::clone(&state)).await;
        let headers = mcp_headers(
            Some(HeaderValue::from_str(&session_id).expect("valid")),
            None,
        );

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            tools_list_request(8),
        )
        .await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_notifications_return_202() {
        let state = test_state(enabled_config());
        let headers = mcp_headers(None, Some(HeaderValue::from_static(MCP_PROTOCOL_VERSION)));

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            rpc_notification("notifications/tools/list_changed"),
        )
        .await;
        assert_eq!(response.status(), StatusCode::ACCEPTED);
    }

    #[tokio::test]
    async fn test_invalid_origin_rejected() {
        let state = test_state(McpConfig {
            enabled: true,
            allowed_origins: vec!["https://allowed.example".to_string()],
            ..Default::default()
        });
        let mut headers = HeaderMap::new();
        headers.insert("origin", HeaderValue::from_static("https://evil.example"));

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            initialize_payload("2025-11-25"),
        )
        .await;

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }
}