Skip to main content

forge_runtime/gateway/
mcp.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use axum::extract::{Extension, State};
6use axum::http::header::{HeaderName, HeaderValue};
7use axum::http::{HeaderMap, Method, StatusCode};
8use axum::response::{IntoResponse, Response};
9use axum::{Json, body::Body};
10use forge_core::config::McpConfig;
11use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
12use forge_core::mcp::McpToolContext;
13use forge_core::rate_limit::RateLimitKey;
14use serde_json::Value;
15use tokio::sync::RwLock;
16
17use crate::mcp::McpToolRegistry;
18use crate::rate_limit::RateLimiter;
19
/// MCP protocol revision served here; `initialize` accepts only this exact value.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Header carrying the session id that `initialize` issues and later requests echo back.
const MCP_SESSION_HEADER: &str = "mcp-session-id";
/// Header through which clients state the protocol version on post-`initialize` messages.
const MCP_PROTOCOL_HEADER: &str = "mcp-protocol-version";
/// Maximum number of tools returned in one `tools/list` page.
const DEFAULT_PAGE_SIZE: usize = 50;
/// Error type of the fallible request-validation helpers: a ready-to-send response.
/// NOTE(review): presumably boxed to keep the `Result` small (clippy `result_large_err`).
type ResponseError = Box<Response>;
25
/// Server-side record of one MCP session, stored in the session map keyed by its id.
#[derive(Debug, Clone)]
struct McpSession {
    /// Becomes `true` once the client sends `notifications/initialized`; `tools/*` requests
    /// are rejected until then.
    initialized: bool,
    /// Protocol version fixed when the session was created by `initialize`.
    protocol_version: String,
    /// Eviction deadline; pushed one TTL forward whenever the session is used.
    expires_at: Instant,
}
32
/// Shared state behind the MCP HTTP endpoint.
#[derive(Clone)]
pub struct McpState {
    /// Endpoint settings (allowed origins, session TTL, protocol-header enforcement, ...).
    config: McpConfig,
    /// Registered tools exposed through `tools/list` and `tools/call`.
    registry: McpToolRegistry,
    /// Database pool handed to every tool invocation's context.
    pool: sqlx::PgPool,
    /// Live sessions keyed by `mcp-session-id`. Held in memory only, so sessions are local to
    /// this process — NOTE(review): confirm that is acceptable behind a load balancer.
    sessions: Arc<RwLock<HashMap<String, McpSession>>>,
    /// Optional dispatcher forwarded to tool contexts for enqueueing jobs.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional dispatcher forwarded to tool contexts for starting workflows.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Enforces per-tool rate limits configured on `McpToolInfo`.
    rate_limiter: Arc<RateLimiter>,
}
43
44impl McpState {
45    pub fn new(
46        config: McpConfig,
47        registry: McpToolRegistry,
48        pool: sqlx::PgPool,
49        job_dispatcher: Option<Arc<dyn JobDispatch>>,
50        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
51    ) -> Self {
52        Self {
53            config,
54            registry,
55            pool: pool.clone(),
56            sessions: Arc::new(RwLock::new(HashMap::new())),
57            job_dispatcher,
58            workflow_dispatcher,
59            rate_limiter: Arc::new(RateLimiter::new(pool)),
60        }
61    }
62
63    async fn cleanup_expired_sessions(&self) {
64        let mut sessions = self.sessions.write().await;
65        let now = Instant::now();
66        sessions.retain(|_, session| session.expires_at > now);
67    }
68
69    async fn touch_session(&self, session_id: &str) {
70        let mut sessions = self.sessions.write().await;
71        if let Some(session) = sessions.get_mut(session_id) {
72            session.expires_at = Instant::now() + Duration::from_secs(self.config.session_ttl_secs);
73        }
74    }
75
76    fn session_ttl(&self) -> Duration {
77        Duration::from_secs(self.config.session_ttl_secs)
78    }
79}
80
81pub async fn mcp_get_handler(State(state): State<Arc<McpState>>, headers: HeaderMap) -> Response {
82    if let Err(resp) = validate_origin(&headers, &state.config) {
83        return *resp;
84    }
85
86    (
87        StatusCode::METHOD_NOT_ALLOWED,
88        Json(json_rpc_error(
89            None,
90            -32601,
91            "GET stream is not supported in Forge MCP v1",
92            None,
93        )),
94    )
95        .into_response()
96}
97
98pub async fn mcp_post_handler(
99    State(state): State<Arc<McpState>>,
100    Extension(auth): Extension<AuthContext>,
101    Extension(tracing): Extension<super::tracing::TracingState>,
102    method: Method,
103    headers: HeaderMap,
104    Json(payload): Json<Value>,
105) -> Response {
106    if method != Method::POST {
107        return (
108            StatusCode::METHOD_NOT_ALLOWED,
109            Json(json_rpc_error(None, -32601, "Only POST is supported", None)),
110        )
111            .into_response();
112    }
113
114    if let Err(resp) = validate_origin(&headers, &state.config) {
115        return *resp;
116    }
117
118    state.cleanup_expired_sessions().await;
119
120    let Some(method_name) = payload.get("method").and_then(Value::as_str) else {
121        // Notifications / responses sent by client should get 202 when accepted.
122        if payload.get("id").is_some()
123            && (payload.get("result").is_some() || payload.get("error").is_some())
124        {
125            return StatusCode::ACCEPTED.into_response();
126        }
127        return (
128            StatusCode::BAD_REQUEST,
129            Json(json_rpc_error(
130                None,
131                -32600,
132                "Invalid JSON-RPC payload",
133                None,
134            )),
135        )
136            .into_response();
137    };
138
139    let id = payload.get("id").cloned();
140    let params = payload
141        .get("params")
142        .cloned()
143        .unwrap_or(Value::Object(Default::default()));
144
145    // Notification flow
146    if id.is_none() {
147        return handle_notification(&state, method_name, params, &headers).await;
148    }
149
150    // Request flow
151    if method_name != "initialize"
152        && let Err(resp) = enforce_protocol_header(&state.config, &headers)
153    {
154        return *resp;
155    }
156
157    match method_name {
158        "initialize" => handle_initialize(&state, id, &params).await,
159        "tools/list" => {
160            let session_id = match required_session_id(&state, &headers, true).await {
161                Ok(v) => v,
162                Err(resp) => return resp,
163            };
164            state.touch_session(&session_id).await;
165            handle_tools_list(&state, id, &params)
166        }
167        "tools/call" => {
168            let session_id = match required_session_id(&state, &headers, true).await {
169                Ok(v) => v,
170                Err(resp) => return resp,
171            };
172            state.touch_session(&session_id).await;
173
174            let metadata = build_request_metadata(&tracing, &headers);
175            handle_tools_call(&state, id, &params, &auth, metadata).await
176        }
177        _ => (
178            StatusCode::OK,
179            Json(json_rpc_error(id, -32601, "Method not found", None)),
180        )
181            .into_response(),
182    }
183}
184
185async fn handle_notification(
186    state: &Arc<McpState>,
187    method_name: &str,
188    _params: Value,
189    headers: &HeaderMap,
190) -> Response {
191    if let Err(resp) = enforce_protocol_header(&state.config, headers) {
192        return *resp;
193    }
194
195    match method_name {
196        "notifications/initialized" => {
197            let session_id = match required_session_id(state, headers, false).await {
198                Ok(v) => v,
199                Err(resp) => return resp,
200            };
201
202            let mut sessions = state.sessions.write().await;
203            if let Some(session) = sessions.get_mut(&session_id) {
204                session.initialized = true;
205                session.expires_at = Instant::now() + state.session_ttl();
206                return StatusCode::ACCEPTED.into_response();
207            }
208
209            (
210                StatusCode::BAD_REQUEST,
211                Json(json_rpc_error(
212                    None,
213                    -32600,
214                    "Unknown MCP session. Re-initialize the connection.",
215                    None,
216                )),
217            )
218                .into_response()
219        }
220        _ => StatusCode::ACCEPTED.into_response(),
221    }
222}
223
224async fn handle_initialize(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
225    let Some(requested_version) = params.get("protocolVersion").and_then(Value::as_str) else {
226        return (
227            StatusCode::OK,
228            Json(json_rpc_error(
229                id,
230                -32602,
231                "Missing protocolVersion in initialize params",
232                None,
233            )),
234        )
235            .into_response();
236    };
237
238    if requested_version != MCP_PROTOCOL_VERSION {
239        return (
240            StatusCode::OK,
241            Json(json_rpc_error(
242                id,
243                -32602,
244                "Unsupported protocolVersion",
245                Some(serde_json::json!({
246                    "supported": MCP_PROTOCOL_VERSION
247                })),
248            )),
249        )
250            .into_response();
251    }
252
253    let session_id = uuid::Uuid::new_v4().to_string();
254    {
255        let mut sessions = state.sessions.write().await;
256        sessions.insert(
257            session_id.clone(),
258            McpSession {
259                initialized: false,
260                protocol_version: MCP_PROTOCOL_VERSION.to_string(),
261                expires_at: Instant::now() + state.session_ttl(),
262            },
263        );
264    }
265
266    let mut response = (
267        StatusCode::OK,
268        Json(json_rpc_success(
269            id,
270            serde_json::json!({
271                "protocolVersion": MCP_PROTOCOL_VERSION,
272                "capabilities": {
273                    "tools": {}
274                },
275                "serverInfo": {
276                    "name": "forge",
277                    "version": env!("CARGO_PKG_VERSION")
278                }
279            }),
280        )),
281    )
282        .into_response();
283
284    set_header(&mut response, MCP_SESSION_HEADER, &session_id);
285    set_header(&mut response, MCP_PROTOCOL_HEADER, MCP_PROTOCOL_VERSION);
286    response
287}
288
289fn handle_tools_list(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
290    let cursor = params.get("cursor").and_then(Value::as_str);
291    let start = match cursor {
292        Some(c) => match c.parse::<usize>() {
293            Ok(v) => v,
294            Err(_) => {
295                return (
296                    StatusCode::OK,
297                    Json(json_rpc_error(
298                        id,
299                        -32602,
300                        "Invalid cursor in tools/list request",
301                        None,
302                    )),
303                )
304                    .into_response();
305            }
306        },
307        None => 0,
308    };
309
310    let mut tools: Vec<_> = state.registry.list().collect();
311    tools.sort_by(|a, b| a.info.name.cmp(b.info.name));
312
313    let page: Vec<_> = tools
314        .iter()
315        .skip(start)
316        .take(DEFAULT_PAGE_SIZE)
317        .map(|entry| {
318            let annotations = serde_json::json!({
319                "title": entry.info.annotations.title,
320                "readOnlyHint": entry.info.annotations.read_only_hint,
321                "destructiveHint": entry.info.annotations.destructive_hint,
322                "idempotentHint": entry.info.annotations.idempotent_hint,
323                "openWorldHint": entry.info.annotations.open_world_hint
324            });
325
326            let icons: Vec<_> = entry
327                .info
328                .icons
329                .iter()
330                .map(|icon| {
331                    serde_json::json!({
332                        "src": icon.src,
333                        "mimeType": icon.mime_type,
334                        "sizes": icon.sizes,
335                        "theme": icon.theme
336                    })
337                })
338                .collect();
339
340            let mut value = serde_json::json!({
341                "name": entry.info.name,
342                "title": entry.info.title,
343                "description": entry.info.description,
344                "inputSchema": entry.input_schema,
345                "annotations": annotations,
346                "icons": icons
347            });
348            if let Some(output_schema) = &entry.output_schema
349                && let Some(obj) = value.as_object_mut()
350            {
351                obj.insert("outputSchema".to_string(), output_schema.clone());
352            }
353            value
354        })
355        .collect();
356
357    let end = start.saturating_add(page.len());
358    let next_cursor = if end < tools.len() {
359        Some(end.to_string())
360    } else {
361        None
362    };
363
364    (
365        StatusCode::OK,
366        Json(json_rpc_success(
367            id,
368            serde_json::json!({
369                "tools": page,
370                "nextCursor": next_cursor
371            }),
372        )),
373    )
374        .into_response()
375}
376
377async fn handle_tools_call(
378    state: &Arc<McpState>,
379    id: Option<Value>,
380    params: &Value,
381    auth: &AuthContext,
382    request_metadata: RequestMetadata,
383) -> Response {
384    let Some(tool_name) = params.get("name").and_then(Value::as_str) else {
385        return (
386            StatusCode::OK,
387            Json(json_rpc_error(id, -32602, "Missing tool name", None)),
388        )
389            .into_response();
390    };
391
392    let Some(entry) = state.registry.get(tool_name) else {
393        return (
394            StatusCode::OK,
395            Json(json_rpc_error(id, -32602, "Unknown tool", None)),
396        )
397            .into_response();
398    };
399
400    if !entry.info.is_public && !auth.is_authenticated() {
401        return (
402            StatusCode::OK,
403            Json(json_rpc_error(id, -32001, "Authentication required", None)),
404        )
405            .into_response();
406    }
407    if let Some(role) = entry.info.required_role
408        && !auth.has_role(role)
409    {
410        return (
411            StatusCode::OK,
412            Json(json_rpc_error(
413                id,
414                -32003,
415                format!("Role '{}' required", role),
416                None,
417            )),
418        )
419            .into_response();
420    }
421
422    if let (Some(requests), Some(per_secs)) = (
423        entry.info.rate_limit_requests,
424        entry.info.rate_limit_per_secs,
425    ) {
426        let key_type = entry
427            .info
428            .rate_limit_key
429            .and_then(|k| k.parse::<RateLimitKey>().ok())
430            .unwrap_or_default();
431
432        let config = forge_core::RateLimitConfig::new(requests, Duration::from_secs(per_secs))
433            .with_key(key_type);
434        let bucket_key = state
435            .rate_limiter
436            .build_key(key_type, tool_name, auth, &request_metadata);
437
438        if let Err(e) = state.rate_limiter.enforce(&bucket_key, &config).await {
439            return (
440                StatusCode::OK,
441                Json(json_rpc_error(id, -32029, e.to_string(), None)),
442            )
443                .into_response();
444        }
445    }
446
447    let args = params
448        .get("arguments")
449        .cloned()
450        .unwrap_or(Value::Object(Default::default()));
451
452    let ctx = McpToolContext::with_dispatch(
453        state.pool.clone(),
454        auth.clone(),
455        request_metadata,
456        state.job_dispatcher.clone(),
457        state.workflow_dispatcher.clone(),
458    );
459
460    let result = if let Some(timeout_secs) = entry.info.timeout {
461        match tokio::time::timeout(
462            Duration::from_secs(timeout_secs),
463            (entry.handler)(&ctx, args),
464        )
465        .await
466        {
467            Ok(inner) => inner,
468            Err(_) => {
469                return (
470                    StatusCode::OK,
471                    Json(json_rpc_error(id, -32000, "Tool timed out", None)),
472                )
473                    .into_response();
474            }
475        }
476    } else {
477        (entry.handler)(&ctx, args).await
478    };
479
480    match result {
481        Ok(output) => {
482            let result = tool_success_result(output);
483            (
484                StatusCode::OK,
485                Json(json_rpc_success(id, serde_json::json!(result))),
486            )
487                .into_response()
488        }
489        Err(e) => match e {
490            forge_core::ForgeError::Validation(msg)
491            | forge_core::ForgeError::InvalidArgument(msg) => (
492                StatusCode::OK,
493                Json(json_rpc_success(
494                    id,
495                    serde_json::json!({
496                        "content": [{ "type": "text", "text": msg }],
497                        "isError": true
498                    }),
499                )),
500            )
501                .into_response(),
502            forge_core::ForgeError::Unauthorized(msg) => {
503                (StatusCode::OK, Json(json_rpc_error(id, -32001, msg, None))).into_response()
504            }
505            forge_core::ForgeError::Forbidden(msg) => {
506                (StatusCode::OK, Json(json_rpc_error(id, -32003, msg, None))).into_response()
507            }
508            _ => (
509                StatusCode::OK,
510                Json(json_rpc_error(id, -32603, "Internal server error", None)),
511            )
512                .into_response(),
513        },
514    }
515}
516
517fn tool_success_result(output: Value) -> Value {
518    match output {
519        Value::Object(_) => serde_json::json!({
520            "content": [{
521                "type": "text",
522                "text": serde_json::to_string(&output).unwrap_or_else(|_| "{}".to_string())
523            }],
524            "structuredContent": output
525        }),
526        Value::String(text) => serde_json::json!({
527            "content": [{ "type": "text", "text": text }]
528        }),
529        other => serde_json::json!({
530            "content": [{
531                "type": "text",
532                "text": serde_json::to_string(&other).unwrap_or_else(|_| "null".to_string())
533            }]
534        }),
535    }
536}
537
538async fn required_session_id(
539    state: &Arc<McpState>,
540    headers: &HeaderMap,
541    require_initialized: bool,
542) -> std::result::Result<String, Response> {
543    let Some(session_id) = headers
544        .get(MCP_SESSION_HEADER)
545        .and_then(|v| v.to_str().ok())
546    else {
547        return Err((
548            StatusCode::BAD_REQUEST,
549            Json(json_rpc_error(
550                None,
551                -32600,
552                "Missing MCP-Session-Id header",
553                None,
554            )),
555        )
556            .into_response());
557    };
558
559    let sessions = state.sessions.read().await;
560    match sessions.get(session_id) {
561        Some(session) => {
562            if session.protocol_version != MCP_PROTOCOL_VERSION {
563                return Err((
564                    StatusCode::BAD_REQUEST,
565                    Json(json_rpc_error(
566                        None,
567                        -32600,
568                        "Session protocol version mismatch",
569                        None,
570                    )),
571                )
572                    .into_response());
573            }
574            if require_initialized && !session.initialized {
575                return Err((
576                    StatusCode::BAD_REQUEST,
577                    Json(json_rpc_error(
578                        None,
579                        -32600,
580                        "MCP session is not initialized",
581                        None,
582                    )),
583                )
584                    .into_response());
585            }
586            Ok(session_id.to_string())
587        }
588        None => Err((
589            StatusCode::BAD_REQUEST,
590            Json(json_rpc_error(
591                None,
592                -32600,
593                "Unknown MCP session. Re-initialize.",
594                None,
595            )),
596        )
597            .into_response()),
598    }
599}
600
601fn validate_origin(
602    headers: &HeaderMap,
603    config: &McpConfig,
604) -> std::result::Result<(), ResponseError> {
605    let Some(origin) = headers.get("origin").and_then(|v| v.to_str().ok()) else {
606        return Ok(());
607    };
608
609    if config.allowed_origins.is_empty() {
610        return Ok(());
611    }
612
613    let allowed = config
614        .allowed_origins
615        .iter()
616        .any(|candidate| candidate == "*" || candidate.eq_ignore_ascii_case(origin));
617    if allowed {
618        return Ok(());
619    }
620
621    Err(Box::new(
622        (
623            StatusCode::FORBIDDEN,
624            Json(json_rpc_error(None, -32600, "Invalid Origin header", None)),
625        )
626            .into_response(),
627    ))
628}
629
630fn enforce_protocol_header(
631    config: &McpConfig,
632    headers: &HeaderMap,
633) -> std::result::Result<(), ResponseError> {
634    if !config.require_protocol_version_header {
635        return Ok(());
636    }
637
638    let Some(version) = headers
639        .get(MCP_PROTOCOL_HEADER)
640        .and_then(|v| v.to_str().ok())
641    else {
642        return Err(Box::new(
643            (
644                StatusCode::BAD_REQUEST,
645                Json(json_rpc_error(
646                    None,
647                    -32600,
648                    "Missing MCP-Protocol-Version header",
649                    None,
650                )),
651            )
652                .into_response(),
653        ));
654    };
655
656    if version != MCP_PROTOCOL_VERSION {
657        return Err(Box::new(
658            (
659                StatusCode::BAD_REQUEST,
660                Json(json_rpc_error(
661                    None,
662                    -32600,
663                    "Unsupported MCP-Protocol-Version",
664                    Some(serde_json::json!({ "supported": MCP_PROTOCOL_VERSION })),
665                )),
666            )
667                .into_response(),
668        ));
669    }
670
671    Ok(())
672}
673
674fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
675    headers
676        .get("x-forwarded-for")
677        .and_then(|v| v.to_str().ok())
678        .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
679        .filter(|s| !s.is_empty())
680        .or_else(|| {
681            headers
682                .get("x-real-ip")
683                .and_then(|v| v.to_str().ok())
684                .map(|s| s.trim().to_string())
685                .filter(|s| !s.is_empty())
686        })
687}
688
689fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
690    headers
691        .get(axum::http::header::USER_AGENT)
692        .and_then(|v| v.to_str().ok())
693        .map(String::from)
694}
695
696fn build_request_metadata(
697    tracing: &super::tracing::TracingState,
698    headers: &HeaderMap,
699) -> RequestMetadata {
700    RequestMetadata {
701        request_id: uuid::Uuid::parse_str(&tracing.request_id)
702            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
703        trace_id: tracing.trace_id.clone(),
704        client_ip: extract_client_ip(headers),
705        user_agent: extract_user_agent(headers),
706        timestamp: chrono::Utc::now(),
707    }
708}
709
710fn json_rpc_success(id: Option<Value>, result: Value) -> Value {
711    serde_json::json!({
712        "jsonrpc": "2.0",
713        "id": id.unwrap_or(Value::Null),
714        "result": result
715    })
716}
717
718fn json_rpc_error(
719    id: Option<Value>,
720    code: i32,
721    message: impl Into<String>,
722    data: Option<Value>,
723) -> Value {
724    let mut error = serde_json::json!({
725        "code": code,
726        "message": message.into()
727    });
728    if let Some(data) = data
729        && let Some(obj) = error.as_object_mut()
730    {
731        obj.insert("data".to_string(), data);
732    }
733
734    serde_json::json!({
735        "jsonrpc": "2.0",
736        "id": id.unwrap_or(Value::Null),
737        "error": error
738    })
739}
740
741fn set_header(response: &mut Response<Body>, name: &str, value: &str) {
742    if let (Ok(name), Ok(value)) = (HeaderName::try_from(name), HeaderValue::from_str(value)) {
743        response.headers_mut().insert(name, value);
744    }
745}
746
747#[cfg(test)]
748#[allow(clippy::expect_used, clippy::indexing_slicing, clippy::unwrap_used)]
749mod tests {
750    use super::super::tracing::TracingState;
751    use super::*;
752    use axum::body::to_bytes;
753    use forge_core::function::AuthContext;
754    use forge_core::mcp::{ForgeMcpTool, McpToolAnnotations, McpToolInfo};
755    use forge_core::schemars::{self, JsonSchema};
756    use serde::{Deserialize, Serialize};
757    use std::collections::HashMap;
758    use std::future::Future;
759    use std::pin::Pin;
760
761    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
762    struct EchoArgs {
763        message: String,
764    }
765
766    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
767    struct EchoOutput {
768        echoed: String,
769    }
770
771    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
772    #[serde(rename_all = "snake_case")]
773    enum ExportFormat {
774        Json,
775        Csv,
776    }
777
778    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
779    struct MetadataArgs {
780        #[schemars(description = "Project UUID to export")]
781        project_id: String,
782        format: ExportFormat,
783    }
784
785    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
786    struct MetadataOutput {
787        accepted: bool,
788    }
789
790    struct EchoTool;
791
792    impl ForgeMcpTool for EchoTool {
793        type Args = EchoArgs;
794        type Output = EchoOutput;
795
796        fn info() -> McpToolInfo {
797            McpToolInfo {
798                name: "echo",
799                title: Some("Echo"),
800                description: Some("Echo back the message"),
801                required_role: None,
802                is_public: false,
803                timeout: None,
804                rate_limit_requests: None,
805                rate_limit_per_secs: None,
806                rate_limit_key: None,
807                annotations: McpToolAnnotations::default(),
808                icons: &[],
809            }
810        }
811
812        fn execute(
813            _ctx: &McpToolContext,
814            args: Self::Args,
815        ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
816            Box::pin(async move {
817                Ok(EchoOutput {
818                    echoed: args.message,
819                })
820            })
821        }
822    }
823
824    struct AdminTool;
825
826    impl ForgeMcpTool for AdminTool {
827        type Args = EchoArgs;
828        type Output = EchoOutput;
829
830        fn info() -> McpToolInfo {
831            McpToolInfo {
832                name: "admin.echo",
833                title: Some("Admin Echo"),
834                description: Some("Admin only echo"),
835                required_role: Some("admin"),
836                is_public: false,
837                timeout: None,
838                rate_limit_requests: None,
839                rate_limit_per_secs: None,
840                rate_limit_key: None,
841                annotations: McpToolAnnotations::default(),
842                icons: &[],
843            }
844        }
845
846        fn execute(
847            _ctx: &McpToolContext,
848            args: Self::Args,
849        ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
850            Box::pin(async move {
851                Ok(EchoOutput {
852                    echoed: args.message,
853                })
854            })
855        }
856    }
857
858    struct MetadataTool;
859
860    impl ForgeMcpTool for MetadataTool {
861        type Args = MetadataArgs;
862        type Output = MetadataOutput;
863
864        fn info() -> McpToolInfo {
865            McpToolInfo {
866                name: "export.project",
867                title: Some("Export Project"),
868                description: Some("Export project data"),
869                required_role: None,
870                is_public: false,
871                timeout: None,
872                rate_limit_requests: None,
873                rate_limit_per_secs: None,
874                rate_limit_key: None,
875                annotations: McpToolAnnotations::default(),
876                icons: &[],
877            }
878        }
879
880        fn execute(
881            _ctx: &McpToolContext,
882            _args: Self::Args,
883        ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
884            Box::pin(async move { Ok(MetadataOutput { accepted: true }) })
885        }
886    }
887
888    #[test]
889    fn test_json_rpc_helpers() {
890        let success = json_rpc_success(
891            Some(serde_json::json!(1)),
892            serde_json::json!({ "ok": true }),
893        );
894        assert_eq!(success["jsonrpc"], "2.0");
895        assert!(success.get("result").is_some());
896
897        let err = json_rpc_error(Some(serde_json::json!(1)), -32601, "not found", None);
898        assert_eq!(err["error"]["code"], -32601);
899    }
900
901    fn test_state(config: McpConfig) -> Arc<McpState> {
902        test_state_with_registry(config, McpToolRegistry::new())
903    }
904
905    fn test_state_with_registry(config: McpConfig, registry: McpToolRegistry) -> Arc<McpState> {
906        let pool = sqlx::postgres::PgPoolOptions::new()
907            .max_connections(1)
908            .connect_lazy("postgres://localhost/nonexistent")
909            .expect("lazy pool must build");
910        Arc::new(McpState::new(config, registry, pool, None, None))
911    }
912
913    async fn response_json(response: Response) -> Value {
914        let bytes = to_bytes(response.into_body(), usize::MAX)
915            .await
916            .expect("body bytes");
917        if bytes.is_empty() {
918            return serde_json::json!({});
919        }
920        serde_json::from_slice(&bytes).expect("valid json")
921    }
922
923    async fn initialize_session(state: Arc<McpState>) -> String {
924        let payload = serde_json::json!({
925            "jsonrpc": "2.0",
926            "id": 1,
927            "method": "initialize",
928            "params": {
929                "protocolVersion": "2025-11-25",
930                "capabilities": {},
931                "clientInfo": { "name": "test", "version": "1.0.0" }
932            }
933        });
934        let response = mcp_post_handler(
935            State(state),
936            Extension(AuthContext::unauthenticated()),
937            Extension(TracingState::new()),
938            Method::POST,
939            HeaderMap::new(),
940            Json(payload),
941        )
942        .await;
943
944        assert_eq!(response.status(), StatusCode::OK);
945        response
946            .headers()
947            .get(MCP_SESSION_HEADER)
948            .and_then(|v| v.to_str().ok())
949            .expect("session id must exist")
950            .to_string()
951    }
952
953    async fn mark_initialized(state: Arc<McpState>, headers: HeaderMap) {
954        let payload = serde_json::json!({
955            "jsonrpc": "2.0",
956            "method": "notifications/initialized",
957            "params": {}
958        });
959        let response = mcp_post_handler(
960            State(state),
961            Extension(AuthContext::unauthenticated()),
962            Extension(TracingState::new()),
963            Method::POST,
964            headers,
965            Json(payload),
966        )
967        .await;
968        assert_eq!(response.status(), StatusCode::ACCEPTED);
969    }
970
971    async fn initialized_headers(state: Arc<McpState>) -> HeaderMap {
972        let session_id = initialize_session(state.clone()).await;
973        let mut headers = HeaderMap::new();
974        headers.insert(
975            MCP_SESSION_HEADER,
976            HeaderValue::from_str(&session_id).expect("valid session id header"),
977        );
978        headers.insert(
979            MCP_PROTOCOL_HEADER,
980            HeaderValue::from_static(MCP_PROTOCOL_VERSION),
981        );
982        mark_initialized(state, headers.clone()).await;
983        headers
984    }
985
986    #[tokio::test]
987    async fn test_initialize_sets_session_header() {
988        let state = test_state(McpConfig {
989            enabled: true,
990            ..Default::default()
991        });
992        let session = initialize_session(state).await;
993        assert!(!session.is_empty());
994    }
995
996    #[tokio::test]
997    async fn test_tools_list_requires_initialized_session() {
998        let state = test_state(McpConfig {
999            enabled: true,
1000            ..Default::default()
1001        });
1002
1003        let session_id = initialize_session(state.clone()).await;
1004
1005        let mut headers = HeaderMap::new();
1006        headers.insert(
1007            MCP_SESSION_HEADER,
1008            HeaderValue::from_str(&session_id).expect("valid"),
1009        );
1010        headers.insert(
1011            MCP_PROTOCOL_HEADER,
1012            HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1013        );
1014
1015        let list_payload = serde_json::json!({
1016            "jsonrpc": "2.0",
1017            "id": 2,
1018            "method": "tools/list",
1019            "params": {}
1020        });
1021        let response = mcp_post_handler(
1022            State(state),
1023            Extension(AuthContext::unauthenticated()),
1024            Extension(TracingState::new()),
1025            Method::POST,
1026            headers,
1027            Json(list_payload),
1028        )
1029        .await;
1030
1031        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1032    }
1033
1034    #[tokio::test]
1035    async fn test_tools_list_returns_registered_tools() {
1036        let mut registry = McpToolRegistry::new();
1037        registry.register::<EchoTool>();
1038
1039        let state = test_state_with_registry(
1040            McpConfig {
1041                enabled: true,
1042                ..Default::default()
1043            },
1044            registry,
1045        );
1046        let headers = initialized_headers(state.clone()).await;
1047        let payload = serde_json::json!({
1048            "jsonrpc": "2.0",
1049            "id": 2,
1050            "method": "tools/list",
1051            "params": {}
1052        });
1053
1054        let response = mcp_post_handler(
1055            State(state),
1056            Extension(AuthContext::unauthenticated()),
1057            Extension(TracingState::new()),
1058            Method::POST,
1059            headers,
1060            Json(payload),
1061        )
1062        .await;
1063
1064        assert_eq!(response.status(), StatusCode::OK);
1065        let body = response_json(response).await;
1066        let tools = body["result"]["tools"]
1067            .as_array()
1068            .expect("tools list should be array");
1069        assert_eq!(tools.len(), 1);
1070        assert_eq!(tools[0]["name"], "echo");
1071        assert!(tools[0].get("inputSchema").is_some());
1072        assert!(tools[0].get("outputSchema").is_some());
1073    }
1074
1075    #[tokio::test]
1076    async fn test_tools_list_exposes_parameter_metadata() {
1077        let mut registry = McpToolRegistry::new();
1078        registry.register::<MetadataTool>();
1079
1080        let state = test_state_with_registry(
1081            McpConfig {
1082                enabled: true,
1083                ..Default::default()
1084            },
1085            registry,
1086        );
1087        let headers = initialized_headers(state.clone()).await;
1088        let payload = serde_json::json!({
1089            "jsonrpc": "2.0",
1090            "id": 9,
1091            "method": "tools/list",
1092            "params": {}
1093        });
1094
1095        let response = mcp_post_handler(
1096            State(state),
1097            Extension(AuthContext::unauthenticated()),
1098            Extension(TracingState::new()),
1099            Method::POST,
1100            headers,
1101            Json(payload),
1102        )
1103        .await;
1104
1105        assert_eq!(response.status(), StatusCode::OK);
1106        let body = response_json(response).await;
1107        let tools = body["result"]["tools"]
1108            .as_array()
1109            .expect("tools list should be array");
1110        assert_eq!(tools.len(), 1);
1111
1112        let input_schema = &tools[0]["inputSchema"];
1113        assert_eq!(
1114            input_schema["properties"]["project_id"]["description"],
1115            "Project UUID to export"
1116        );
1117
1118        let schema_text = input_schema.to_string();
1119        assert!(schema_text.contains("\"json\""));
1120        assert!(schema_text.contains("\"csv\""));
1121    }
1122
1123    #[tokio::test]
1124    async fn test_tools_call_success_returns_structured_content() {
1125        let mut registry = McpToolRegistry::new();
1126        registry.register::<EchoTool>();
1127
1128        let state = test_state_with_registry(
1129            McpConfig {
1130                enabled: true,
1131                ..Default::default()
1132            },
1133            registry,
1134        );
1135        let headers = initialized_headers(state.clone()).await;
1136        let auth = AuthContext::authenticated(
1137            uuid::Uuid::new_v4(),
1138            vec!["member".to_string()],
1139            HashMap::new(),
1140        );
1141        let payload = serde_json::json!({
1142            "jsonrpc": "2.0",
1143            "id": 3,
1144            "method": "tools/call",
1145            "params": {
1146                "name": "echo",
1147                "arguments": { "message": "hello" }
1148            }
1149        });
1150
1151        let response = mcp_post_handler(
1152            State(state),
1153            Extension(auth),
1154            Extension(TracingState::new()),
1155            Method::POST,
1156            headers,
1157            Json(payload),
1158        )
1159        .await;
1160
1161        assert_eq!(response.status(), StatusCode::OK);
1162        let body = response_json(response).await;
1163        assert_eq!(body["result"]["structuredContent"]["echoed"], "hello");
1164        assert_eq!(body["result"]["content"][0]["type"], "text");
1165    }
1166
1167    #[tokio::test]
1168    async fn test_tools_call_validation_failure_returns_is_error() {
1169        let mut registry = McpToolRegistry::new();
1170        registry.register::<EchoTool>();
1171
1172        let state = test_state_with_registry(
1173            McpConfig {
1174                enabled: true,
1175                ..Default::default()
1176            },
1177            registry,
1178        );
1179        let headers = initialized_headers(state.clone()).await;
1180        let auth = AuthContext::authenticated(
1181            uuid::Uuid::new_v4(),
1182            vec!["member".to_string()],
1183            HashMap::new(),
1184        );
1185        let payload = serde_json::json!({
1186            "jsonrpc": "2.0",
1187            "id": 4,
1188            "method": "tools/call",
1189            "params": {
1190                "name": "echo",
1191                "arguments": {}
1192            }
1193        });
1194
1195        let response = mcp_post_handler(
1196            State(state),
1197            Extension(auth),
1198            Extension(TracingState::new()),
1199            Method::POST,
1200            headers,
1201            Json(payload),
1202        )
1203        .await;
1204
1205        assert_eq!(response.status(), StatusCode::OK);
1206        let body = response_json(response).await;
1207        assert_eq!(body["result"]["isError"], true);
1208    }
1209
1210    #[tokio::test]
1211    async fn test_tools_call_requires_authentication() {
1212        let mut registry = McpToolRegistry::new();
1213        registry.register::<EchoTool>();
1214
1215        let state = test_state_with_registry(
1216            McpConfig {
1217                enabled: true,
1218                ..Default::default()
1219            },
1220            registry,
1221        );
1222        let headers = initialized_headers(state.clone()).await;
1223        let payload = serde_json::json!({
1224            "jsonrpc": "2.0",
1225            "id": 5,
1226            "method": "tools/call",
1227            "params": {
1228                "name": "echo",
1229                "arguments": { "message": "hello" }
1230            }
1231        });
1232
1233        let response = mcp_post_handler(
1234            State(state),
1235            Extension(AuthContext::unauthenticated()),
1236            Extension(TracingState::new()),
1237            Method::POST,
1238            headers,
1239            Json(payload),
1240        )
1241        .await;
1242
1243        assert_eq!(response.status(), StatusCode::OK);
1244        let body = response_json(response).await;
1245        assert_eq!(body["error"]["code"], -32001);
1246    }
1247
1248    #[tokio::test]
1249    async fn test_tools_call_requires_role() {
1250        let mut registry = McpToolRegistry::new();
1251        registry.register::<AdminTool>();
1252
1253        let state = test_state_with_registry(
1254            McpConfig {
1255                enabled: true,
1256                ..Default::default()
1257            },
1258            registry,
1259        );
1260        let headers = initialized_headers(state.clone()).await;
1261        let auth = AuthContext::authenticated(
1262            uuid::Uuid::new_v4(),
1263            vec!["member".to_string()],
1264            HashMap::new(),
1265        );
1266        let payload = serde_json::json!({
1267            "jsonrpc": "2.0",
1268            "id": 6,
1269            "method": "tools/call",
1270            "params": {
1271                "name": "admin.echo",
1272                "arguments": { "message": "hello" }
1273            }
1274        });
1275
1276        let response = mcp_post_handler(
1277            State(state),
1278            Extension(auth),
1279            Extension(TracingState::new()),
1280            Method::POST,
1281            headers,
1282            Json(payload),
1283        )
1284        .await;
1285
1286        assert_eq!(response.status(), StatusCode::OK);
1287        let body = response_json(response).await;
1288        assert_eq!(body["error"]["code"], -32003);
1289    }
1290
1291    #[tokio::test]
1292    async fn test_invalid_protocol_header_returns_400() {
1293        let state = test_state(McpConfig {
1294            enabled: true,
1295            ..Default::default()
1296        });
1297        let session_id = initialize_session(state.clone()).await;
1298        let mut headers = HeaderMap::new();
1299        headers.insert(
1300            MCP_SESSION_HEADER,
1301            HeaderValue::from_str(&session_id).expect("valid"),
1302        );
1303        headers.insert(
1304            MCP_PROTOCOL_HEADER,
1305            HeaderValue::from_static("invalid-version"),
1306        );
1307
1308        let payload = serde_json::json!({
1309            "jsonrpc": "2.0",
1310            "id": 7,
1311            "method": "tools/list",
1312            "params": {}
1313        });
1314
1315        let response = mcp_post_handler(
1316            State(state),
1317            Extension(AuthContext::unauthenticated()),
1318            Extension(TracingState::new()),
1319            Method::POST,
1320            headers,
1321            Json(payload),
1322        )
1323        .await;
1324        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1325    }
1326
1327    #[tokio::test]
1328    async fn test_missing_protocol_header_returns_400() {
1329        let state = test_state(McpConfig {
1330            enabled: true,
1331            ..Default::default()
1332        });
1333        let session_id = initialize_session(state.clone()).await;
1334        let mut headers = HeaderMap::new();
1335        headers.insert(
1336            MCP_SESSION_HEADER,
1337            HeaderValue::from_str(&session_id).expect("valid"),
1338        );
1339
1340        let payload = serde_json::json!({
1341            "jsonrpc": "2.0",
1342            "id": 8,
1343            "method": "tools/list",
1344            "params": {}
1345        });
1346
1347        let response = mcp_post_handler(
1348            State(state),
1349            Extension(AuthContext::unauthenticated()),
1350            Extension(TracingState::new()),
1351            Method::POST,
1352            headers,
1353            Json(payload),
1354        )
1355        .await;
1356        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1357    }
1358
1359    #[tokio::test]
1360    async fn test_notifications_return_202() {
1361        let state = test_state(McpConfig {
1362            enabled: true,
1363            ..Default::default()
1364        });
1365        let mut headers = HeaderMap::new();
1366        headers.insert(
1367            MCP_PROTOCOL_HEADER,
1368            HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1369        );
1370        let payload = serde_json::json!({
1371            "jsonrpc": "2.0",
1372            "method": "notifications/tools/list_changed",
1373            "params": {}
1374        });
1375        let response = mcp_post_handler(
1376            State(state),
1377            Extension(AuthContext::unauthenticated()),
1378            Extension(TracingState::new()),
1379            Method::POST,
1380            headers,
1381            Json(payload),
1382        )
1383        .await;
1384        assert_eq!(response.status(), StatusCode::ACCEPTED);
1385    }
1386
1387    #[tokio::test]
1388    async fn test_invalid_origin_rejected() {
1389        let state = test_state(McpConfig {
1390            enabled: true,
1391            allowed_origins: vec!["https://allowed.example".to_string()],
1392            ..Default::default()
1393        });
1394        let payload = serde_json::json!({
1395            "jsonrpc": "2.0",
1396            "id": 1,
1397            "method": "initialize",
1398            "params": {
1399                "protocolVersion": "2025-11-25",
1400                "capabilities": {},
1401                "clientInfo": { "name": "test", "version": "1.0.0" }
1402            }
1403        });
1404
1405        let mut headers = HeaderMap::new();
1406        headers.insert("origin", HeaderValue::from_static("https://evil.example"));
1407
1408        let response = mcp_post_handler(
1409            State(state),
1410            Extension(AuthContext::unauthenticated()),
1411            Extension(TracingState::new()),
1412            Method::POST,
1413            headers,
1414            Json(payload),
1415        )
1416        .await;
1417
1418        assert_eq!(response.status(), StatusCode::FORBIDDEN);
1419    }
1420}