Skip to main content

forge_runtime/gateway/
mcp.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use axum::extract::{Extension, State};
6use axum::http::header::{HeaderName, HeaderValue};
7use axum::http::{HeaderMap, Method, StatusCode};
8use axum::response::{IntoResponse, Response};
9use axum::{Json, body::Body};
10use forge_core::config::McpConfig;
11use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
12use forge_core::mcp::McpToolContext;
13use forge_core::rate_limit::RateLimitKey;
14use serde_json::Value;
15use tokio::sync::RwLock;
16
17use crate::mcp::McpToolRegistry;
18use crate::rate_limit::RateLimiter;
19
/// MCP protocol revisions accepted by this server, newest first.
const SUPPORTED_VERSIONS: &[&str] = &["2025-11-25", "2025-03-26", "2024-11-05"];
/// Newest supported protocol revision; only compiled for tests.
#[cfg(test)]
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// HTTP header carrying the session id issued by `initialize`.
const MCP_SESSION_HEADER: &str = "mcp-session-id";
/// HTTP header carrying the MCP protocol version.
const MCP_PROTOCOL_HEADER: &str = "mcp-protocol-version";
/// Maximum number of tools returned in one `tools/list` page.
const DEFAULT_PAGE_SIZE: usize = 50;
/// Error half of the validation helpers' `Result`s. Boxed, presumably to keep
/// the `Result` small (clippy `result_large_err`) — confirm.
type ResponseError = Box<Response>;
27
/// Server-side state for one MCP session, keyed by its session id.
#[derive(Debug, Clone)]
struct McpSession {
    /// Set when the client sends `notifications/initialized`; `tools/*`
    /// requests are refused until then.
    initialized: bool,
    /// Protocol version the client requested in `initialize`.
    protocol_version: String,
    /// Deadline after which the session is evicted; pushed out whenever the
    /// session is used.
    expires_at: Instant,
}
34
/// Shared state for the MCP HTTP handlers.
#[derive(Clone)]
pub struct McpState {
    /// MCP gateway configuration (allowed origins, session TTL, OAuth flag,
    /// protocol-header policy are read from it in this file).
    config: McpConfig,
    /// Registered tools served by `tools/list` and `tools/call`.
    registry: McpToolRegistry,
    /// Database pool handed to every tool invocation's context.
    pool: sqlx::PgPool,
    /// Live sessions keyed by their `mcp-session-id` value.
    sessions: Arc<RwLock<HashMap<String, McpSession>>>,
    /// Optional job dispatcher forwarded into tool contexts.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher forwarded into tool contexts.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Per-tool rate limiter, backed by the same pool as `pool`.
    rate_limiter: Arc<RateLimiter>,
}
45
46impl McpState {
47    pub fn new(
48        config: McpConfig,
49        registry: McpToolRegistry,
50        pool: sqlx::PgPool,
51        job_dispatcher: Option<Arc<dyn JobDispatch>>,
52        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
53    ) -> Self {
54        Self {
55            config,
56            registry,
57            pool: pool.clone(),
58            sessions: Arc::new(RwLock::new(HashMap::new())),
59            job_dispatcher,
60            workflow_dispatcher,
61            rate_limiter: Arc::new(RateLimiter::new(pool)),
62        }
63    }
64
65    async fn cleanup_expired_sessions(&self) {
66        let mut sessions = self.sessions.write().await;
67        let now = Instant::now();
68        sessions.retain(|_, session| session.expires_at > now);
69    }
70
71    async fn touch_session(&self, session_id: &str) {
72        let mut sessions = self.sessions.write().await;
73        if let Some(session) = sessions.get_mut(session_id) {
74            session.expires_at = Instant::now() + Duration::from_secs(self.config.session_ttl_secs);
75        }
76    }
77
78    fn session_ttl(&self) -> Duration {
79        Duration::from_secs(self.config.session_ttl_secs)
80    }
81}
82
83pub async fn mcp_get_handler(State(state): State<Arc<McpState>>, headers: HeaderMap) -> Response {
84    if let Err(resp) = validate_origin(&headers, &state.config) {
85        return *resp;
86    }
87
88    (
89        StatusCode::METHOD_NOT_ALLOWED,
90        Json(json_rpc_error(
91            None,
92            -32601,
93            "GET stream is not supported in Forge MCP v1",
94            None,
95        )),
96    )
97        .into_response()
98}
99
100pub async fn mcp_post_handler(
101    State(state): State<Arc<McpState>>,
102    Extension(auth): Extension<AuthContext>,
103    Extension(tracing): Extension<super::tracing::TracingState>,
104    method: Method,
105    headers: HeaderMap,
106    Json(payload): Json<Value>,
107) -> Response {
108    if method != Method::POST {
109        return (
110            StatusCode::METHOD_NOT_ALLOWED,
111            Json(json_rpc_error(None, -32601, "Only POST is supported", None)),
112        )
113            .into_response();
114    }
115
116    if let Err(resp) = validate_origin(&headers, &state.config) {
117        return *resp;
118    }
119
120    state.cleanup_expired_sessions().await;
121
122    let Some(method_name) = payload.get("method").and_then(Value::as_str) else {
123        // Notifications / responses sent by client should get 202 when accepted.
124        if payload.get("id").is_some()
125            && (payload.get("result").is_some() || payload.get("error").is_some())
126        {
127            return StatusCode::ACCEPTED.into_response();
128        }
129        return (
130            StatusCode::BAD_REQUEST,
131            Json(json_rpc_error(
132                None,
133                -32600,
134                "Invalid JSON-RPC payload",
135                None,
136            )),
137        )
138            .into_response();
139    };
140
141    let id = payload.get("id").cloned();
142    let params = payload
143        .get("params")
144        .cloned()
145        .unwrap_or(Value::Object(Default::default()));
146
147    // Notification flow
148    if id.is_none() {
149        return handle_notification(&state, method_name, params, &headers).await;
150    }
151
152    // Request flow
153    if method_name != "initialize"
154        && let Err(resp) = enforce_protocol_header(&state.config, &headers)
155    {
156        return *resp;
157    }
158
159    match method_name {
160        "initialize" => handle_initialize(&state, id, &params).await,
161        "tools/list" => {
162            let session_id = match required_session_id(&state, &headers, true).await {
163                Ok(v) => v,
164                Err(resp) => return resp,
165            };
166            state.touch_session(&session_id).await;
167            handle_tools_list(&state, id, &params)
168        }
169        "tools/call" => {
170            let session_id = match required_session_id(&state, &headers, true).await {
171                Ok(v) => v,
172                Err(resp) => return resp,
173            };
174            state.touch_session(&session_id).await;
175
176            let metadata = build_request_metadata(&tracing, &headers);
177            handle_tools_call(&state, id, &params, &auth, metadata).await
178        }
179        _ => (
180            StatusCode::OK,
181            Json(json_rpc_error(id, -32601, "Method not found", None)),
182        )
183            .into_response(),
184    }
185}
186
187async fn handle_notification(
188    state: &Arc<McpState>,
189    method_name: &str,
190    _params: Value,
191    headers: &HeaderMap,
192) -> Response {
193    if let Err(resp) = enforce_protocol_header(&state.config, headers) {
194        return *resp;
195    }
196
197    match method_name {
198        "notifications/initialized" => {
199            let session_id = match required_session_id(state, headers, false).await {
200                Ok(v) => v,
201                Err(resp) => return resp,
202            };
203
204            let mut sessions = state.sessions.write().await;
205            if let Some(session) = sessions.get_mut(&session_id) {
206                session.initialized = true;
207                session.expires_at = Instant::now() + state.session_ttl();
208                return StatusCode::ACCEPTED.into_response();
209            }
210
211            (
212                StatusCode::BAD_REQUEST,
213                Json(json_rpc_error(
214                    None,
215                    -32600,
216                    "Unknown MCP session. Re-initialize the connection.",
217                    None,
218                )),
219            )
220                .into_response()
221        }
222        _ => StatusCode::ACCEPTED.into_response(),
223    }
224}
225
226async fn handle_initialize(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
227    let Some(requested_version) = params.get("protocolVersion").and_then(Value::as_str) else {
228        return (
229            StatusCode::OK,
230            Json(json_rpc_error(
231                id,
232                -32602,
233                "Missing protocolVersion in initialize params",
234                None,
235            )),
236        )
237            .into_response();
238    };
239
240    if !SUPPORTED_VERSIONS.contains(&requested_version) {
241        return (
242            StatusCode::OK,
243            Json(json_rpc_error(
244                id,
245                -32602,
246                "Unsupported protocolVersion",
247                Some(serde_json::json!({
248                    "supported": SUPPORTED_VERSIONS
249                })),
250            )),
251        )
252            .into_response();
253    }
254
255    let session_id = uuid::Uuid::new_v4().to_string();
256    {
257        let mut sessions = state.sessions.write().await;
258        sessions.insert(
259            session_id.clone(),
260            McpSession {
261                initialized: false,
262                protocol_version: requested_version.to_string(),
263                expires_at: Instant::now() + state.session_ttl(),
264            },
265        );
266    }
267
268    let mut response = (
269        StatusCode::OK,
270        Json(json_rpc_success(
271            id,
272            serde_json::json!({
273                "protocolVersion": requested_version,
274                "capabilities": {
275                    "tools": {
276                        "listChanged": false
277                    }
278                },
279                "serverInfo": {
280                    "name": "forge",
281                    "version": env!("CARGO_PKG_VERSION")
282                }
283            }),
284        )),
285    )
286        .into_response();
287
288    set_header(&mut response, MCP_SESSION_HEADER, &session_id);
289    set_header(&mut response, MCP_PROTOCOL_HEADER, requested_version);
290    response
291}
292
293fn handle_tools_list(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
294    let cursor = params.get("cursor").and_then(Value::as_str);
295    let start = match cursor {
296        Some(c) => match c.parse::<usize>() {
297            Ok(v) => v,
298            Err(_) => {
299                return (
300                    StatusCode::OK,
301                    Json(json_rpc_error(
302                        id,
303                        -32602,
304                        "Invalid cursor in tools/list request",
305                        None,
306                    )),
307                )
308                    .into_response();
309            }
310        },
311        None => 0,
312    };
313
314    let mut tools: Vec<_> = state.registry.list().collect();
315    tools.sort_by(|a, b| a.info.name.cmp(b.info.name));
316
317    let page: Vec<_> = tools
318        .iter()
319        .skip(start)
320        .take(DEFAULT_PAGE_SIZE)
321        .map(|entry| {
322            // Only include annotation fields that are set (non-null).
323            // Claude Code's MCP client rejects null booleans.
324            let mut annotations = serde_json::Map::new();
325            if let Some(title) = &entry.info.annotations.title {
326                annotations.insert("title".into(), serde_json::Value::String(title.to_string()));
327            }
328            if let Some(v) = entry.info.annotations.read_only_hint {
329                annotations.insert("readOnlyHint".into(), serde_json::Value::Bool(v));
330            }
331            if let Some(v) = entry.info.annotations.destructive_hint {
332                annotations.insert("destructiveHint".into(), serde_json::Value::Bool(v));
333            }
334            if let Some(v) = entry.info.annotations.idempotent_hint {
335                annotations.insert("idempotentHint".into(), serde_json::Value::Bool(v));
336            }
337            if let Some(v) = entry.info.annotations.open_world_hint {
338                annotations.insert("openWorldHint".into(), serde_json::Value::Bool(v));
339            }
340
341            let mut value = serde_json::json!({
342                "name": entry.info.name,
343                "description": entry.info.description,
344                "inputSchema": entry.input_schema,
345            });
346            // json!({}) always produces Value::Object
347            let obj = value.as_object_mut().expect("json! object literal");
348
349            if let Some(title) = &entry.info.title {
350                obj.insert("title".into(), serde_json::Value::String(title.to_string()));
351            }
352            if !annotations.is_empty() {
353                obj.insert("annotations".into(), serde_json::Value::Object(annotations));
354            }
355            if !entry.info.icons.is_empty() {
356                let icons: Vec<_> = entry
357                    .info
358                    .icons
359                    .iter()
360                    .map(|icon| {
361                        serde_json::json!({
362                            "src": icon.src,
363                            "mimeType": icon.mime_type,
364                            "sizes": icon.sizes,
365                            "theme": icon.theme
366                        })
367                    })
368                    .collect();
369                obj.insert("icons".into(), serde_json::Value::Array(icons));
370            }
371            if let Some(output_schema) = &entry.output_schema {
372                // MCP spec requires outputSchema to have type: "object".
373                // schemars may generate type: "array" (Vec<T>) or anyOf (Option<T>).
374                // Wrap non-object schemas so they conform.
375                let schema = normalize_output_schema(output_schema);
376                obj.insert("outputSchema".into(), schema);
377            }
378            value
379        })
380        .collect();
381
382    let end = start.saturating_add(page.len());
383
384    // Build result: omit nextCursor when null (Claude Code expects string or absent)
385    let mut result = serde_json::json!({ "tools": page });
386    if end < tools.len() && result.is_object() {
387        // json!({}) always produces Value::Object
388        result
389            .as_object_mut()
390            .expect("json! object literal")
391            .insert(
392                "nextCursor".into(),
393                serde_json::Value::String(end.to_string()),
394            );
395    }
396
397    (StatusCode::OK, Json(json_rpc_success(id, result))).into_response()
398}
399
400/// MCP spec requires outputSchema to be `type: "object"`. Wrap schemas that
401/// schemars generates as arrays or union types (anyOf/oneOf for Option<T>).
402fn normalize_output_schema(schema: &Value) -> Value {
403    let type_str = schema.get("type").and_then(Value::as_str).unwrap_or("");
404    if type_str == "object" {
405        return schema.clone();
406    }
407
408    // Wrap non-object schemas: put the original under a "result" property
409    let mut wrapper = serde_json::json!({
410        "type": "object",
411        "properties": {
412            "result": schema
413        }
414    });
415
416    // Hoist $schema and definitions to the wrapper level
417    if let (Some(s), Some(obj)) = (schema.get("$schema"), wrapper.as_object_mut()) {
418        obj.insert("$schema".into(), s.clone());
419    }
420    if let (Some(d), Some(obj)) = (schema.get("definitions"), wrapper.as_object_mut()) {
421        obj.insert("definitions".into(), d.clone());
422        // Remove from the nested copy to avoid duplication
423        if let Some(inner) = wrapper.pointer_mut("/properties/result") {
424            inner.as_object_mut().map(|o| o.remove("definitions"));
425        }
426    }
427
428    wrapper
429}
430
431async fn handle_tools_call(
432    state: &Arc<McpState>,
433    id: Option<Value>,
434    params: &Value,
435    auth: &AuthContext,
436    request_metadata: RequestMetadata,
437) -> Response {
438    let Some(tool_name) = params.get("name").and_then(Value::as_str) else {
439        return (
440            StatusCode::OK,
441            Json(json_rpc_error(id, -32602, "Missing tool name", None)),
442        )
443            .into_response();
444    };
445
446    let Some(entry) = state.registry.get(tool_name) else {
447        return (
448            StatusCode::OK,
449            Json(json_rpc_error(id, -32602, "Unknown tool", None)),
450        )
451            .into_response();
452    };
453
454    if !entry.info.is_public && !auth.is_authenticated() {
455        if state.config.oauth {
456            // Return HTTP 401 with discovery header so MCP clients trigger OAuth flow
457            let mut response = (
458                StatusCode::UNAUTHORIZED,
459                Json(json_rpc_error(id, -32001, "Authentication required", None)),
460            )
461                .into_response();
462            response.headers_mut().insert(
463                "WWW-Authenticate",
464                axum::http::header::HeaderValue::from_static(
465                    "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\"",
466                ),
467            );
468            return response;
469        }
470        return (
471            StatusCode::OK,
472            Json(json_rpc_error(id, -32001, "Authentication required", None)),
473        )
474            .into_response();
475    }
476    if let Some(role) = entry.info.required_role
477        && !auth.has_role(role)
478    {
479        return (
480            StatusCode::OK,
481            Json(json_rpc_error(
482                id,
483                -32003,
484                format!("Role '{}' required", role),
485                None,
486            )),
487        )
488            .into_response();
489    }
490
491    if let (Some(requests), Some(per_secs)) = (
492        entry.info.rate_limit_requests,
493        entry.info.rate_limit_per_secs,
494    ) {
495        let key_type = entry
496            .info
497            .rate_limit_key
498            .and_then(|k| k.parse::<RateLimitKey>().ok())
499            .unwrap_or_default();
500
501        let config = forge_core::RateLimitConfig::new(requests, Duration::from_secs(per_secs))
502            .with_key(key_type);
503        let bucket_key = state
504            .rate_limiter
505            .build_key(key_type, tool_name, auth, &request_metadata);
506
507        if let Err(e) = state.rate_limiter.enforce(&bucket_key, &config).await {
508            return (
509                StatusCode::OK,
510                Json(json_rpc_error(id, -32029, e.to_string(), None)),
511            )
512                .into_response();
513        }
514    }
515
516    let args = params
517        .get("arguments")
518        .cloned()
519        .unwrap_or(Value::Object(Default::default()));
520
521    let ctx = McpToolContext::with_dispatch(
522        state.pool.clone(),
523        auth.clone(),
524        request_metadata,
525        state.job_dispatcher.clone(),
526        state.workflow_dispatcher.clone(),
527    );
528
529    let result = if let Some(timeout_secs) = entry.info.timeout {
530        match tokio::time::timeout(
531            Duration::from_secs(timeout_secs),
532            (entry.handler)(&ctx, args),
533        )
534        .await
535        {
536            Ok(inner) => inner,
537            Err(_) => {
538                return (
539                    StatusCode::OK,
540                    Json(json_rpc_error(id, -32000, "Tool timed out", None)),
541                )
542                    .into_response();
543            }
544        }
545    } else {
546        (entry.handler)(&ctx, args).await
547    };
548
549    match result {
550        Ok(output) => {
551            let result = tool_success_result(output);
552            (
553                StatusCode::OK,
554                Json(json_rpc_success(id, serde_json::json!(result))),
555            )
556                .into_response()
557        }
558        Err(e) => match e {
559            forge_core::ForgeError::Validation(msg)
560            | forge_core::ForgeError::InvalidArgument(msg) => (
561                StatusCode::OK,
562                Json(json_rpc_success(
563                    id,
564                    serde_json::json!({
565                        "content": [{ "type": "text", "text": msg }],
566                        "isError": true
567                    }),
568                )),
569            )
570                .into_response(),
571            forge_core::ForgeError::Unauthorized(msg) => {
572                (StatusCode::OK, Json(json_rpc_error(id, -32001, msg, None))).into_response()
573            }
574            forge_core::ForgeError::Forbidden(msg) => {
575                (StatusCode::OK, Json(json_rpc_error(id, -32003, msg, None))).into_response()
576            }
577            _ => (
578                StatusCode::OK,
579                Json(json_rpc_error(id, -32603, "Internal server error", None)),
580            )
581                .into_response(),
582        },
583    }
584}
585
586fn tool_success_result(output: Value) -> Value {
587    match output {
588        Value::Object(_) => serde_json::json!({
589            "content": [{
590                "type": "text",
591                "text": serde_json::to_string(&output).unwrap_or_else(|_| "{}".to_string())
592            }],
593            "structuredContent": output
594        }),
595        Value::String(text) => serde_json::json!({
596            "content": [{ "type": "text", "text": text }]
597        }),
598        other => serde_json::json!({
599            "content": [{
600                "type": "text",
601                "text": serde_json::to_string(&other).unwrap_or_else(|_| "null".to_string())
602            }]
603        }),
604    }
605}
606
607async fn required_session_id(
608    state: &Arc<McpState>,
609    headers: &HeaderMap,
610    require_initialized: bool,
611) -> std::result::Result<String, Response> {
612    let Some(session_id) = headers
613        .get(MCP_SESSION_HEADER)
614        .and_then(|v| v.to_str().ok())
615    else {
616        return Err((
617            StatusCode::BAD_REQUEST,
618            Json(json_rpc_error(
619                None,
620                -32600,
621                "Missing MCP-Session-Id header",
622                None,
623            )),
624        )
625            .into_response());
626    };
627
628    let sessions = state.sessions.read().await;
629    match sessions.get(session_id) {
630        Some(session) => {
631            if !SUPPORTED_VERSIONS.contains(&session.protocol_version.as_str()) {
632                return Err((
633                    StatusCode::BAD_REQUEST,
634                    Json(json_rpc_error(
635                        None,
636                        -32600,
637                        "Session protocol version mismatch",
638                        None,
639                    )),
640                )
641                    .into_response());
642            }
643            if require_initialized && !session.initialized {
644                return Err((
645                    StatusCode::BAD_REQUEST,
646                    Json(json_rpc_error(
647                        None,
648                        -32600,
649                        "MCP session is not initialized",
650                        None,
651                    )),
652                )
653                    .into_response());
654            }
655            Ok(session_id.to_string())
656        }
657        None => Err((
658            StatusCode::BAD_REQUEST,
659            Json(json_rpc_error(
660                None,
661                -32600,
662                "Unknown MCP session. Re-initialize.",
663                None,
664            )),
665        )
666            .into_response()),
667    }
668}
669
670fn validate_origin(
671    headers: &HeaderMap,
672    config: &McpConfig,
673) -> std::result::Result<(), ResponseError> {
674    let Some(origin) = headers.get("origin").and_then(|v| v.to_str().ok()) else {
675        return Ok(());
676    };
677
678    if config.allowed_origins.is_empty() {
679        return Ok(());
680    }
681
682    let allowed = config
683        .allowed_origins
684        .iter()
685        .any(|candidate| candidate == "*" || candidate.eq_ignore_ascii_case(origin));
686    if allowed {
687        return Ok(());
688    }
689
690    Err(Box::new(
691        (
692            StatusCode::FORBIDDEN,
693            Json(json_rpc_error(None, -32600, "Invalid Origin header", None)),
694        )
695            .into_response(),
696    ))
697}
698
699fn enforce_protocol_header(
700    config: &McpConfig,
701    headers: &HeaderMap,
702) -> std::result::Result<(), ResponseError> {
703    if !config.require_protocol_version_header {
704        return Ok(());
705    }
706
707    let Some(version) = headers
708        .get(MCP_PROTOCOL_HEADER)
709        .and_then(|v| v.to_str().ok())
710    else {
711        return Err(Box::new(
712            (
713                StatusCode::BAD_REQUEST,
714                Json(json_rpc_error(
715                    None,
716                    -32600,
717                    "Missing MCP-Protocol-Version header",
718                    None,
719                )),
720            )
721                .into_response(),
722        ));
723    };
724
725    if !SUPPORTED_VERSIONS.contains(&version) {
726        return Err(Box::new(
727            (
728                StatusCode::BAD_REQUEST,
729                Json(json_rpc_error(
730                    None,
731                    -32600,
732                    "Unsupported MCP-Protocol-Version",
733                    Some(serde_json::json!({ "supported": SUPPORTED_VERSIONS })),
734                )),
735            )
736                .into_response(),
737        ));
738    }
739
740    Ok(())
741}
742
743fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
744    headers
745        .get("x-forwarded-for")
746        .and_then(|v| v.to_str().ok())
747        .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
748        .filter(|s| !s.is_empty())
749        .or_else(|| {
750            headers
751                .get("x-real-ip")
752                .and_then(|v| v.to_str().ok())
753                .map(|s| s.trim().to_string())
754                .filter(|s| !s.is_empty())
755        })
756}
757
758fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
759    headers
760        .get(axum::http::header::USER_AGENT)
761        .and_then(|v| v.to_str().ok())
762        .map(String::from)
763}
764
765fn build_request_metadata(
766    tracing: &super::tracing::TracingState,
767    headers: &HeaderMap,
768) -> RequestMetadata {
769    RequestMetadata {
770        request_id: uuid::Uuid::parse_str(&tracing.request_id)
771            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
772        trace_id: tracing.trace_id.clone(),
773        client_ip: extract_client_ip(headers),
774        user_agent: extract_user_agent(headers),
775        timestamp: chrono::Utc::now(),
776    }
777}
778
779fn json_rpc_success(id: Option<Value>, result: Value) -> Value {
780    serde_json::json!({
781        "jsonrpc": "2.0",
782        "id": id.unwrap_or(Value::Null),
783        "result": result
784    })
785}
786
787fn json_rpc_error(
788    id: Option<Value>,
789    code: i32,
790    message: impl Into<String>,
791    data: Option<Value>,
792) -> Value {
793    let mut error = serde_json::json!({
794        "code": code,
795        "message": message.into()
796    });
797    if let Some(data) = data
798        && let Some(obj) = error.as_object_mut()
799    {
800        obj.insert("data".to_string(), data);
801    }
802
803    serde_json::json!({
804        "jsonrpc": "2.0",
805        "id": id.unwrap_or(Value::Null),
806        "error": error
807    })
808}
809
810fn set_header(response: &mut Response<Body>, name: &str, value: &str) {
811    if let (Ok(name), Ok(value)) = (HeaderName::try_from(name), HeaderValue::from_str(value)) {
812        response.headers_mut().insert(name, value);
813    }
814}
815
816#[cfg(test)]
817#[allow(clippy::expect_used, clippy::indexing_slicing, clippy::unwrap_used)]
818mod tests {
819    use super::super::tracing::TracingState;
820    use super::*;
821    use axum::body::to_bytes;
822    use forge_core::function::AuthContext;
823    use forge_core::mcp::{ForgeMcpTool, McpToolAnnotations, McpToolInfo};
824    use forge_core::schemars::{self, JsonSchema};
825    use serde::{Deserialize, Serialize};
826    use std::collections::HashMap;
827    use std::future::Future;
828    use std::pin::Pin;
829
830    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
831    struct EchoArgs {
832        message: String,
833    }
834
835    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
836    struct EchoOutput {
837        echoed: String,
838    }
839
840    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
841    #[serde(rename_all = "snake_case")]
842    enum ExportFormat {
843        Json,
844        Csv,
845    }
846
847    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
848    struct MetadataArgs {
849        #[schemars(description = "Project UUID to export")]
850        project_id: String,
851        format: ExportFormat,
852    }
853
854    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
855    struct MetadataOutput {
856        accepted: bool,
857    }
858
859    struct EchoTool;
860
861    impl ForgeMcpTool for EchoTool {
862        type Args = EchoArgs;
863        type Output = EchoOutput;
864
865        fn info() -> McpToolInfo {
866            McpToolInfo {
867                name: "echo",
868                title: Some("Echo"),
869                description: Some("Echo back the message"),
870                required_role: None,
871                is_public: false,
872                timeout: None,
873                rate_limit_requests: None,
874                rate_limit_per_secs: None,
875                rate_limit_key: None,
876                annotations: McpToolAnnotations::default(),
877                icons: &[],
878            }
879        }
880
881        fn execute(
882            _ctx: &McpToolContext,
883            args: Self::Args,
884        ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
885            Box::pin(async move {
886                Ok(EchoOutput {
887                    echoed: args.message,
888                })
889            })
890        }
891    }
892
893    struct AdminTool;
894
895    impl ForgeMcpTool for AdminTool {
896        type Args = EchoArgs;
897        type Output = EchoOutput;
898
899        fn info() -> McpToolInfo {
900            McpToolInfo {
901                name: "admin.echo",
902                title: Some("Admin Echo"),
903                description: Some("Admin only echo"),
904                required_role: Some("admin"),
905                is_public: false,
906                timeout: None,
907                rate_limit_requests: None,
908                rate_limit_per_secs: None,
909                rate_limit_key: None,
910                annotations: McpToolAnnotations::default(),
911                icons: &[],
912            }
913        }
914
915        fn execute(
916            _ctx: &McpToolContext,
917            args: Self::Args,
918        ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
919            Box::pin(async move {
920                Ok(EchoOutput {
921                    echoed: args.message,
922                })
923            })
924        }
925    }
926
927    struct MetadataTool;
928
929    impl ForgeMcpTool for MetadataTool {
930        type Args = MetadataArgs;
931        type Output = MetadataOutput;
932
933        fn info() -> McpToolInfo {
934            McpToolInfo {
935                name: "export.project",
936                title: Some("Export Project"),
937                description: Some("Export project data"),
938                required_role: None,
939                is_public: false,
940                timeout: None,
941                rate_limit_requests: None,
942                rate_limit_per_secs: None,
943                rate_limit_key: None,
944                annotations: McpToolAnnotations::default(),
945                icons: &[],
946            }
947        }
948
949        fn execute(
950            _ctx: &McpToolContext,
951            _args: Self::Args,
952        ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
953            Box::pin(async move { Ok(MetadataOutput { accepted: true }) })
954        }
955    }
956
957    #[test]
958    fn test_json_rpc_helpers() {
959        let success = json_rpc_success(
960            Some(serde_json::json!(1)),
961            serde_json::json!({ "ok": true }),
962        );
963        assert_eq!(success["jsonrpc"], "2.0");
964        assert!(success.get("result").is_some());
965
966        let err = json_rpc_error(Some(serde_json::json!(1)), -32601, "not found", None);
967        assert_eq!(err["error"]["code"], -32601);
968    }
969
970    fn test_state(config: McpConfig) -> Arc<McpState> {
971        test_state_with_registry(config, McpToolRegistry::new())
972    }
973
974    fn test_state_with_registry(config: McpConfig, registry: McpToolRegistry) -> Arc<McpState> {
975        let pool = sqlx::postgres::PgPoolOptions::new()
976            .max_connections(1)
977            .connect_lazy("postgres://localhost/nonexistent")
978            .expect("lazy pool must build");
979        Arc::new(McpState::new(config, registry, pool, None, None))
980    }
981
982    async fn response_json(response: Response) -> Value {
983        let bytes = to_bytes(response.into_body(), usize::MAX)
984            .await
985            .expect("body bytes");
986        if bytes.is_empty() {
987            return serde_json::json!({});
988        }
989        serde_json::from_slice(&bytes).expect("valid json")
990    }
991
992    async fn initialize_session(state: Arc<McpState>) -> String {
993        let payload = serde_json::json!({
994            "jsonrpc": "2.0",
995            "id": 1,
996            "method": "initialize",
997            "params": {
998                "protocolVersion": "2025-11-25",
999                "capabilities": {},
1000                "clientInfo": { "name": "test", "version": "1.0.0" }
1001            }
1002        });
1003        let response = mcp_post_handler(
1004            State(state),
1005            Extension(AuthContext::unauthenticated()),
1006            Extension(TracingState::new()),
1007            Method::POST,
1008            HeaderMap::new(),
1009            Json(payload),
1010        )
1011        .await;
1012
1013        assert_eq!(response.status(), StatusCode::OK);
1014        response
1015            .headers()
1016            .get(MCP_SESSION_HEADER)
1017            .and_then(|v| v.to_str().ok())
1018            .expect("session id must exist")
1019            .to_string()
1020    }
1021
1022    async fn mark_initialized(state: Arc<McpState>, headers: HeaderMap) {
1023        let payload = serde_json::json!({
1024            "jsonrpc": "2.0",
1025            "method": "notifications/initialized",
1026            "params": {}
1027        });
1028        let response = mcp_post_handler(
1029            State(state),
1030            Extension(AuthContext::unauthenticated()),
1031            Extension(TracingState::new()),
1032            Method::POST,
1033            headers,
1034            Json(payload),
1035        )
1036        .await;
1037        assert_eq!(response.status(), StatusCode::ACCEPTED);
1038    }
1039
1040    async fn initialized_headers(state: Arc<McpState>) -> HeaderMap {
1041        let session_id = initialize_session(state.clone()).await;
1042        let mut headers = HeaderMap::new();
1043        headers.insert(
1044            MCP_SESSION_HEADER,
1045            HeaderValue::from_str(&session_id).expect("valid session id header"),
1046        );
1047        headers.insert(
1048            MCP_PROTOCOL_HEADER,
1049            HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1050        );
1051        mark_initialized(state, headers.clone()).await;
1052        headers
1053    }
1054
1055    #[tokio::test]
1056    async fn test_initialize_sets_session_header() {
1057        let state = test_state(McpConfig {
1058            enabled: true,
1059            ..Default::default()
1060        });
1061        let session = initialize_session(state).await;
1062        assert!(!session.is_empty());
1063    }
1064
1065    #[tokio::test]
1066    async fn test_tools_list_requires_initialized_session() {
1067        let state = test_state(McpConfig {
1068            enabled: true,
1069            ..Default::default()
1070        });
1071
1072        let session_id = initialize_session(state.clone()).await;
1073
1074        let mut headers = HeaderMap::new();
1075        headers.insert(
1076            MCP_SESSION_HEADER,
1077            HeaderValue::from_str(&session_id).expect("valid"),
1078        );
1079        headers.insert(
1080            MCP_PROTOCOL_HEADER,
1081            HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1082        );
1083
1084        let list_payload = serde_json::json!({
1085            "jsonrpc": "2.0",
1086            "id": 2,
1087            "method": "tools/list",
1088            "params": {}
1089        });
1090        let response = mcp_post_handler(
1091            State(state),
1092            Extension(AuthContext::unauthenticated()),
1093            Extension(TracingState::new()),
1094            Method::POST,
1095            headers,
1096            Json(list_payload),
1097        )
1098        .await;
1099
1100        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1101    }
1102
1103    #[tokio::test]
1104    async fn test_tools_list_returns_registered_tools() {
1105        let mut registry = McpToolRegistry::new();
1106        registry.register::<EchoTool>();
1107
1108        let state = test_state_with_registry(
1109            McpConfig {
1110                enabled: true,
1111                ..Default::default()
1112            },
1113            registry,
1114        );
1115        let headers = initialized_headers(state.clone()).await;
1116        let payload = serde_json::json!({
1117            "jsonrpc": "2.0",
1118            "id": 2,
1119            "method": "tools/list",
1120            "params": {}
1121        });
1122
1123        let response = mcp_post_handler(
1124            State(state),
1125            Extension(AuthContext::unauthenticated()),
1126            Extension(TracingState::new()),
1127            Method::POST,
1128            headers,
1129            Json(payload),
1130        )
1131        .await;
1132
1133        assert_eq!(response.status(), StatusCode::OK);
1134        let body = response_json(response).await;
1135        let tools = body["result"]["tools"]
1136            .as_array()
1137            .expect("tools list should be array");
1138        assert_eq!(tools.len(), 1);
1139        assert_eq!(tools[0]["name"], "echo");
1140        assert!(tools[0].get("inputSchema").is_some());
1141        assert!(tools[0].get("outputSchema").is_some());
1142    }
1143
1144    #[tokio::test]
1145    async fn test_tools_list_exposes_parameter_metadata() {
1146        let mut registry = McpToolRegistry::new();
1147        registry.register::<MetadataTool>();
1148
1149        let state = test_state_with_registry(
1150            McpConfig {
1151                enabled: true,
1152                ..Default::default()
1153            },
1154            registry,
1155        );
1156        let headers = initialized_headers(state.clone()).await;
1157        let payload = serde_json::json!({
1158            "jsonrpc": "2.0",
1159            "id": 9,
1160            "method": "tools/list",
1161            "params": {}
1162        });
1163
1164        let response = mcp_post_handler(
1165            State(state),
1166            Extension(AuthContext::unauthenticated()),
1167            Extension(TracingState::new()),
1168            Method::POST,
1169            headers,
1170            Json(payload),
1171        )
1172        .await;
1173
1174        assert_eq!(response.status(), StatusCode::OK);
1175        let body = response_json(response).await;
1176        let tools = body["result"]["tools"]
1177            .as_array()
1178            .expect("tools list should be array");
1179        assert_eq!(tools.len(), 1);
1180
1181        let input_schema = &tools[0]["inputSchema"];
1182        assert_eq!(
1183            input_schema["properties"]["project_id"]["description"],
1184            "Project UUID to export"
1185        );
1186
1187        let schema_text = input_schema.to_string();
1188        assert!(schema_text.contains("\"json\""));
1189        assert!(schema_text.contains("\"csv\""));
1190    }
1191
1192    #[tokio::test]
1193    async fn test_tools_call_success_returns_structured_content() {
1194        let mut registry = McpToolRegistry::new();
1195        registry.register::<EchoTool>();
1196
1197        let state = test_state_with_registry(
1198            McpConfig {
1199                enabled: true,
1200                ..Default::default()
1201            },
1202            registry,
1203        );
1204        let headers = initialized_headers(state.clone()).await;
1205        let auth = AuthContext::authenticated(
1206            uuid::Uuid::new_v4(),
1207            vec!["member".to_string()],
1208            HashMap::new(),
1209        );
1210        let payload = serde_json::json!({
1211            "jsonrpc": "2.0",
1212            "id": 3,
1213            "method": "tools/call",
1214            "params": {
1215                "name": "echo",
1216                "arguments": { "message": "hello" }
1217            }
1218        });
1219
1220        let response = mcp_post_handler(
1221            State(state),
1222            Extension(auth),
1223            Extension(TracingState::new()),
1224            Method::POST,
1225            headers,
1226            Json(payload),
1227        )
1228        .await;
1229
1230        assert_eq!(response.status(), StatusCode::OK);
1231        let body = response_json(response).await;
1232        assert_eq!(body["result"]["structuredContent"]["echoed"], "hello");
1233        assert_eq!(body["result"]["content"][0]["type"], "text");
1234    }
1235
1236    #[tokio::test]
1237    async fn test_tools_call_validation_failure_returns_is_error() {
1238        let mut registry = McpToolRegistry::new();
1239        registry.register::<EchoTool>();
1240
1241        let state = test_state_with_registry(
1242            McpConfig {
1243                enabled: true,
1244                ..Default::default()
1245            },
1246            registry,
1247        );
1248        let headers = initialized_headers(state.clone()).await;
1249        let auth = AuthContext::authenticated(
1250            uuid::Uuid::new_v4(),
1251            vec!["member".to_string()],
1252            HashMap::new(),
1253        );
1254        let payload = serde_json::json!({
1255            "jsonrpc": "2.0",
1256            "id": 4,
1257            "method": "tools/call",
1258            "params": {
1259                "name": "echo",
1260                "arguments": {}
1261            }
1262        });
1263
1264        let response = mcp_post_handler(
1265            State(state),
1266            Extension(auth),
1267            Extension(TracingState::new()),
1268            Method::POST,
1269            headers,
1270            Json(payload),
1271        )
1272        .await;
1273
1274        assert_eq!(response.status(), StatusCode::OK);
1275        let body = response_json(response).await;
1276        assert_eq!(body["result"]["isError"], true);
1277    }
1278
1279    #[tokio::test]
1280    async fn test_tools_call_requires_authentication() {
1281        let mut registry = McpToolRegistry::new();
1282        registry.register::<EchoTool>();
1283
1284        let state = test_state_with_registry(
1285            McpConfig {
1286                enabled: true,
1287                ..Default::default()
1288            },
1289            registry,
1290        );
1291        let headers = initialized_headers(state.clone()).await;
1292        let payload = serde_json::json!({
1293            "jsonrpc": "2.0",
1294            "id": 5,
1295            "method": "tools/call",
1296            "params": {
1297                "name": "echo",
1298                "arguments": { "message": "hello" }
1299            }
1300        });
1301
1302        let response = mcp_post_handler(
1303            State(state),
1304            Extension(AuthContext::unauthenticated()),
1305            Extension(TracingState::new()),
1306            Method::POST,
1307            headers,
1308            Json(payload),
1309        )
1310        .await;
1311
1312        assert_eq!(response.status(), StatusCode::OK);
1313        let body = response_json(response).await;
1314        assert_eq!(body["error"]["code"], -32001);
1315    }
1316
1317    #[tokio::test]
1318    async fn test_tools_call_requires_role() {
1319        let mut registry = McpToolRegistry::new();
1320        registry.register::<AdminTool>();
1321
1322        let state = test_state_with_registry(
1323            McpConfig {
1324                enabled: true,
1325                ..Default::default()
1326            },
1327            registry,
1328        );
1329        let headers = initialized_headers(state.clone()).await;
1330        let auth = AuthContext::authenticated(
1331            uuid::Uuid::new_v4(),
1332            vec!["member".to_string()],
1333            HashMap::new(),
1334        );
1335        let payload = serde_json::json!({
1336            "jsonrpc": "2.0",
1337            "id": 6,
1338            "method": "tools/call",
1339            "params": {
1340                "name": "admin.echo",
1341                "arguments": { "message": "hello" }
1342            }
1343        });
1344
1345        let response = mcp_post_handler(
1346            State(state),
1347            Extension(auth),
1348            Extension(TracingState::new()),
1349            Method::POST,
1350            headers,
1351            Json(payload),
1352        )
1353        .await;
1354
1355        assert_eq!(response.status(), StatusCode::OK);
1356        let body = response_json(response).await;
1357        assert_eq!(body["error"]["code"], -32003);
1358    }
1359
1360    #[tokio::test]
1361    async fn test_invalid_protocol_header_returns_400() {
1362        let state = test_state(McpConfig {
1363            enabled: true,
1364            ..Default::default()
1365        });
1366        let session_id = initialize_session(state.clone()).await;
1367        let mut headers = HeaderMap::new();
1368        headers.insert(
1369            MCP_SESSION_HEADER,
1370            HeaderValue::from_str(&session_id).expect("valid"),
1371        );
1372        headers.insert(
1373            MCP_PROTOCOL_HEADER,
1374            HeaderValue::from_static("invalid-version"),
1375        );
1376
1377        let payload = serde_json::json!({
1378            "jsonrpc": "2.0",
1379            "id": 7,
1380            "method": "tools/list",
1381            "params": {}
1382        });
1383
1384        let response = mcp_post_handler(
1385            State(state),
1386            Extension(AuthContext::unauthenticated()),
1387            Extension(TracingState::new()),
1388            Method::POST,
1389            headers,
1390            Json(payload),
1391        )
1392        .await;
1393        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1394    }
1395
1396    #[tokio::test]
1397    async fn test_missing_protocol_header_returns_400() {
1398        let state = test_state(McpConfig {
1399            enabled: true,
1400            ..Default::default()
1401        });
1402        let session_id = initialize_session(state.clone()).await;
1403        let mut headers = HeaderMap::new();
1404        headers.insert(
1405            MCP_SESSION_HEADER,
1406            HeaderValue::from_str(&session_id).expect("valid"),
1407        );
1408
1409        let payload = serde_json::json!({
1410            "jsonrpc": "2.0",
1411            "id": 8,
1412            "method": "tools/list",
1413            "params": {}
1414        });
1415
1416        let response = mcp_post_handler(
1417            State(state),
1418            Extension(AuthContext::unauthenticated()),
1419            Extension(TracingState::new()),
1420            Method::POST,
1421            headers,
1422            Json(payload),
1423        )
1424        .await;
1425        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1426    }
1427
1428    #[tokio::test]
1429    async fn test_notifications_return_202() {
1430        let state = test_state(McpConfig {
1431            enabled: true,
1432            ..Default::default()
1433        });
1434        let mut headers = HeaderMap::new();
1435        headers.insert(
1436            MCP_PROTOCOL_HEADER,
1437            HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1438        );
1439        let payload = serde_json::json!({
1440            "jsonrpc": "2.0",
1441            "method": "notifications/tools/list_changed",
1442            "params": {}
1443        });
1444        let response = mcp_post_handler(
1445            State(state),
1446            Extension(AuthContext::unauthenticated()),
1447            Extension(TracingState::new()),
1448            Method::POST,
1449            headers,
1450            Json(payload),
1451        )
1452        .await;
1453        assert_eq!(response.status(), StatusCode::ACCEPTED);
1454    }
1455
1456    #[tokio::test]
1457    async fn test_invalid_origin_rejected() {
1458        let state = test_state(McpConfig {
1459            enabled: true,
1460            allowed_origins: vec!["https://allowed.example".to_string()],
1461            ..Default::default()
1462        });
1463        let payload = serde_json::json!({
1464            "jsonrpc": "2.0",
1465            "id": 1,
1466            "method": "initialize",
1467            "params": {
1468                "protocolVersion": "2025-11-25",
1469                "capabilities": {},
1470                "clientInfo": { "name": "test", "version": "1.0.0" }
1471            }
1472        });
1473
1474        let mut headers = HeaderMap::new();
1475        headers.insert("origin", HeaderValue::from_static("https://evil.example"));
1476
1477        let response = mcp_post_handler(
1478            State(state),
1479            Extension(AuthContext::unauthenticated()),
1480            Extension(TracingState::new()),
1481            Method::POST,
1482            headers,
1483            Json(payload),
1484        )
1485        .await;
1486
1487        assert_eq!(response.status(), StatusCode::FORBIDDEN);
1488    }
1489}