1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use axum::extract::{Extension, State};
6use axum::http::header::{HeaderName, HeaderValue};
7use axum::http::{HeaderMap, Method, StatusCode};
8use axum::response::{IntoResponse, Response};
9use axum::{Json, body::Body};
10use forge_core::config::McpConfig;
11use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
12use forge_core::mcp::McpToolContext;
13use forge_core::rate_limit::RateLimitKey;
14use serde_json::Value;
15use tokio::sync::RwLock;
16
17use crate::mcp::McpToolRegistry;
18use crate::rate_limit::RateLimiter;
19
/// MCP protocol revision this server speaks; `initialize` params and the
/// protocol-version header are compared against it.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Header carrying the session id issued by `initialize`.
const MCP_SESSION_HEADER: &str = "mcp-session-id";
/// Header carrying the protocol version on requests after `initialize`.
const MCP_PROTOCOL_HEADER: &str = "mcp-protocol-version";
/// Maximum number of tools returned in one `tools/list` page.
const DEFAULT_PAGE_SIZE: usize = 50;
/// Boxed rejection response returned by the header validation helpers.
type ResponseError = Box<Response>;
25
/// In-memory record kept for one MCP session.
#[derive(Debug, Clone)]
struct McpSession {
    /// Set to `true` once the client has sent `notifications/initialized`.
    initialized: bool,
    /// Protocol version stored when the session was created.
    protocol_version: String,
    /// Point in time after which the expiry sweep removes the session.
    expires_at: Instant,
}
32
/// Shared state used by the MCP HTTP handlers.
#[derive(Clone)]
pub struct McpState {
    /// MCP settings (allowed origins, session TTL, protocol-header enforcement).
    config: McpConfig,
    /// Tools served by `tools/list` and `tools/call`.
    registry: McpToolRegistry,
    /// Database pool handed to every tool invocation context.
    pool: sqlx::PgPool,
    /// Live sessions keyed by session id.
    sessions: Arc<RwLock<HashMap<String, McpSession>>>,
    /// Optional job dispatcher forwarded to tool contexts.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher forwarded to tool contexts.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Rate limiter consulted for tools that declare a limit.
    rate_limiter: Arc<RateLimiter>,
}
43
44impl McpState {
45 pub fn new(
46 config: McpConfig,
47 registry: McpToolRegistry,
48 pool: sqlx::PgPool,
49 job_dispatcher: Option<Arc<dyn JobDispatch>>,
50 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
51 ) -> Self {
52 Self {
53 config,
54 registry,
55 pool: pool.clone(),
56 sessions: Arc::new(RwLock::new(HashMap::new())),
57 job_dispatcher,
58 workflow_dispatcher,
59 rate_limiter: Arc::new(RateLimiter::new(pool)),
60 }
61 }
62
63 async fn cleanup_expired_sessions(&self) {
64 let mut sessions = self.sessions.write().await;
65 let now = Instant::now();
66 sessions.retain(|_, session| session.expires_at > now);
67 }
68
69 async fn touch_session(&self, session_id: &str) {
70 let mut sessions = self.sessions.write().await;
71 if let Some(session) = sessions.get_mut(session_id) {
72 session.expires_at = Instant::now() + Duration::from_secs(self.config.session_ttl_secs);
73 }
74 }
75
76 fn session_ttl(&self) -> Duration {
77 Duration::from_secs(self.config.session_ttl_secs)
78 }
79}
80
81pub async fn mcp_get_handler(State(state): State<Arc<McpState>>, headers: HeaderMap) -> Response {
82 if let Err(resp) = validate_origin(&headers, &state.config) {
83 return *resp;
84 }
85
86 (
87 StatusCode::METHOD_NOT_ALLOWED,
88 Json(json_rpc_error(
89 None,
90 -32601,
91 "GET stream is not supported in Forge MCP v1",
92 None,
93 )),
94 )
95 .into_response()
96}
97
98pub async fn mcp_post_handler(
99 State(state): State<Arc<McpState>>,
100 Extension(auth): Extension<AuthContext>,
101 Extension(tracing): Extension<super::tracing::TracingState>,
102 method: Method,
103 headers: HeaderMap,
104 Json(payload): Json<Value>,
105) -> Response {
106 if method != Method::POST {
107 return (
108 StatusCode::METHOD_NOT_ALLOWED,
109 Json(json_rpc_error(None, -32601, "Only POST is supported", None)),
110 )
111 .into_response();
112 }
113
114 if let Err(resp) = validate_origin(&headers, &state.config) {
115 return *resp;
116 }
117
118 state.cleanup_expired_sessions().await;
119
120 let Some(method_name) = payload.get("method").and_then(Value::as_str) else {
121 if payload.get("id").is_some()
123 && (payload.get("result").is_some() || payload.get("error").is_some())
124 {
125 return StatusCode::ACCEPTED.into_response();
126 }
127 return (
128 StatusCode::BAD_REQUEST,
129 Json(json_rpc_error(
130 None,
131 -32600,
132 "Invalid JSON-RPC payload",
133 None,
134 )),
135 )
136 .into_response();
137 };
138
139 let id = payload.get("id").cloned();
140 let params = payload
141 .get("params")
142 .cloned()
143 .unwrap_or(Value::Object(Default::default()));
144
145 if id.is_none() {
147 return handle_notification(&state, method_name, params, &headers).await;
148 }
149
150 if method_name != "initialize"
152 && let Err(resp) = enforce_protocol_header(&state.config, &headers)
153 {
154 return *resp;
155 }
156
157 match method_name {
158 "initialize" => handle_initialize(&state, id, ¶ms).await,
159 "tools/list" => {
160 let session_id = match required_session_id(&state, &headers, true).await {
161 Ok(v) => v,
162 Err(resp) => return resp,
163 };
164 state.touch_session(&session_id).await;
165 handle_tools_list(&state, id, ¶ms)
166 }
167 "tools/call" => {
168 let session_id = match required_session_id(&state, &headers, true).await {
169 Ok(v) => v,
170 Err(resp) => return resp,
171 };
172 state.touch_session(&session_id).await;
173
174 let metadata = build_request_metadata(&tracing, &headers);
175 handle_tools_call(&state, id, ¶ms, &auth, metadata).await
176 }
177 _ => (
178 StatusCode::OK,
179 Json(json_rpc_error(id, -32601, "Method not found", None)),
180 )
181 .into_response(),
182 }
183}
184
185async fn handle_notification(
186 state: &Arc<McpState>,
187 method_name: &str,
188 _params: Value,
189 headers: &HeaderMap,
190) -> Response {
191 if let Err(resp) = enforce_protocol_header(&state.config, headers) {
192 return *resp;
193 }
194
195 match method_name {
196 "notifications/initialized" => {
197 let session_id = match required_session_id(state, headers, false).await {
198 Ok(v) => v,
199 Err(resp) => return resp,
200 };
201
202 let mut sessions = state.sessions.write().await;
203 if let Some(session) = sessions.get_mut(&session_id) {
204 session.initialized = true;
205 session.expires_at = Instant::now() + state.session_ttl();
206 return StatusCode::ACCEPTED.into_response();
207 }
208
209 (
210 StatusCode::BAD_REQUEST,
211 Json(json_rpc_error(
212 None,
213 -32600,
214 "Unknown MCP session. Re-initialize the connection.",
215 None,
216 )),
217 )
218 .into_response()
219 }
220 _ => StatusCode::ACCEPTED.into_response(),
221 }
222}
223
224async fn handle_initialize(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
225 let Some(requested_version) = params.get("protocolVersion").and_then(Value::as_str) else {
226 return (
227 StatusCode::OK,
228 Json(json_rpc_error(
229 id,
230 -32602,
231 "Missing protocolVersion in initialize params",
232 None,
233 )),
234 )
235 .into_response();
236 };
237
238 if requested_version != MCP_PROTOCOL_VERSION {
239 return (
240 StatusCode::OK,
241 Json(json_rpc_error(
242 id,
243 -32602,
244 "Unsupported protocolVersion",
245 Some(serde_json::json!({
246 "supported": MCP_PROTOCOL_VERSION
247 })),
248 )),
249 )
250 .into_response();
251 }
252
253 let session_id = uuid::Uuid::new_v4().to_string();
254 {
255 let mut sessions = state.sessions.write().await;
256 sessions.insert(
257 session_id.clone(),
258 McpSession {
259 initialized: false,
260 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
261 expires_at: Instant::now() + state.session_ttl(),
262 },
263 );
264 }
265
266 let mut response = (
267 StatusCode::OK,
268 Json(json_rpc_success(
269 id,
270 serde_json::json!({
271 "protocolVersion": MCP_PROTOCOL_VERSION,
272 "capabilities": {
273 "tools": {}
274 },
275 "serverInfo": {
276 "name": "forge",
277 "version": env!("CARGO_PKG_VERSION")
278 }
279 }),
280 )),
281 )
282 .into_response();
283
284 set_header(&mut response, MCP_SESSION_HEADER, &session_id);
285 set_header(&mut response, MCP_PROTOCOL_HEADER, MCP_PROTOCOL_VERSION);
286 response
287}
288
289fn handle_tools_list(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
290 let cursor = params.get("cursor").and_then(Value::as_str);
291 let start = match cursor {
292 Some(c) => match c.parse::<usize>() {
293 Ok(v) => v,
294 Err(_) => {
295 return (
296 StatusCode::OK,
297 Json(json_rpc_error(
298 id,
299 -32602,
300 "Invalid cursor in tools/list request",
301 None,
302 )),
303 )
304 .into_response();
305 }
306 },
307 None => 0,
308 };
309
310 let mut tools: Vec<_> = state.registry.list().collect();
311 tools.sort_by(|a, b| a.info.name.cmp(b.info.name));
312
313 let page: Vec<_> = tools
314 .iter()
315 .skip(start)
316 .take(DEFAULT_PAGE_SIZE)
317 .map(|entry| {
318 let annotations = serde_json::json!({
319 "title": entry.info.annotations.title,
320 "readOnlyHint": entry.info.annotations.read_only_hint,
321 "destructiveHint": entry.info.annotations.destructive_hint,
322 "idempotentHint": entry.info.annotations.idempotent_hint,
323 "openWorldHint": entry.info.annotations.open_world_hint
324 });
325
326 let icons: Vec<_> = entry
327 .info
328 .icons
329 .iter()
330 .map(|icon| {
331 serde_json::json!({
332 "src": icon.src,
333 "mimeType": icon.mime_type,
334 "sizes": icon.sizes,
335 "theme": icon.theme
336 })
337 })
338 .collect();
339
340 let mut value = serde_json::json!({
341 "name": entry.info.name,
342 "title": entry.info.title,
343 "description": entry.info.description,
344 "inputSchema": entry.input_schema,
345 "annotations": annotations,
346 "icons": icons
347 });
348 if let Some(output_schema) = &entry.output_schema
349 && let Some(obj) = value.as_object_mut()
350 {
351 obj.insert("outputSchema".to_string(), output_schema.clone());
352 }
353 value
354 })
355 .collect();
356
357 let end = start.saturating_add(page.len());
358 let next_cursor = if end < tools.len() {
359 Some(end.to_string())
360 } else {
361 None
362 };
363
364 (
365 StatusCode::OK,
366 Json(json_rpc_success(
367 id,
368 serde_json::json!({
369 "tools": page,
370 "nextCursor": next_cursor
371 }),
372 )),
373 )
374 .into_response()
375}
376
377async fn handle_tools_call(
378 state: &Arc<McpState>,
379 id: Option<Value>,
380 params: &Value,
381 auth: &AuthContext,
382 request_metadata: RequestMetadata,
383) -> Response {
384 let Some(tool_name) = params.get("name").and_then(Value::as_str) else {
385 return (
386 StatusCode::OK,
387 Json(json_rpc_error(id, -32602, "Missing tool name", None)),
388 )
389 .into_response();
390 };
391
392 let Some(entry) = state.registry.get(tool_name) else {
393 return (
394 StatusCode::OK,
395 Json(json_rpc_error(id, -32602, "Unknown tool", None)),
396 )
397 .into_response();
398 };
399
400 if !entry.info.is_public && !auth.is_authenticated() {
401 return (
402 StatusCode::OK,
403 Json(json_rpc_error(id, -32001, "Authentication required", None)),
404 )
405 .into_response();
406 }
407 if let Some(role) = entry.info.required_role
408 && !auth.has_role(role)
409 {
410 return (
411 StatusCode::OK,
412 Json(json_rpc_error(
413 id,
414 -32003,
415 format!("Role '{}' required", role),
416 None,
417 )),
418 )
419 .into_response();
420 }
421
422 if let (Some(requests), Some(per_secs)) = (
423 entry.info.rate_limit_requests,
424 entry.info.rate_limit_per_secs,
425 ) {
426 let key_type = entry
427 .info
428 .rate_limit_key
429 .and_then(|k| k.parse::<RateLimitKey>().ok())
430 .unwrap_or_default();
431
432 let config = forge_core::RateLimitConfig::new(requests, Duration::from_secs(per_secs))
433 .with_key(key_type);
434 let bucket_key = state
435 .rate_limiter
436 .build_key(key_type, tool_name, auth, &request_metadata);
437
438 if let Err(e) = state.rate_limiter.enforce(&bucket_key, &config).await {
439 return (
440 StatusCode::OK,
441 Json(json_rpc_error(id, -32029, e.to_string(), None)),
442 )
443 .into_response();
444 }
445 }
446
447 let args = params
448 .get("arguments")
449 .cloned()
450 .unwrap_or(Value::Object(Default::default()));
451
452 let ctx = McpToolContext::with_dispatch(
453 state.pool.clone(),
454 auth.clone(),
455 request_metadata,
456 state.job_dispatcher.clone(),
457 state.workflow_dispatcher.clone(),
458 );
459
460 let result = if let Some(timeout_secs) = entry.info.timeout {
461 match tokio::time::timeout(
462 Duration::from_secs(timeout_secs),
463 (entry.handler)(&ctx, args),
464 )
465 .await
466 {
467 Ok(inner) => inner,
468 Err(_) => {
469 return (
470 StatusCode::OK,
471 Json(json_rpc_error(id, -32000, "Tool timed out", None)),
472 )
473 .into_response();
474 }
475 }
476 } else {
477 (entry.handler)(&ctx, args).await
478 };
479
480 match result {
481 Ok(output) => {
482 let result = tool_success_result(output);
483 (
484 StatusCode::OK,
485 Json(json_rpc_success(id, serde_json::json!(result))),
486 )
487 .into_response()
488 }
489 Err(e) => match e {
490 forge_core::ForgeError::Validation(msg)
491 | forge_core::ForgeError::InvalidArgument(msg) => (
492 StatusCode::OK,
493 Json(json_rpc_success(
494 id,
495 serde_json::json!({
496 "content": [{ "type": "text", "text": msg }],
497 "isError": true
498 }),
499 )),
500 )
501 .into_response(),
502 forge_core::ForgeError::Unauthorized(msg) => {
503 (StatusCode::OK, Json(json_rpc_error(id, -32001, msg, None))).into_response()
504 }
505 forge_core::ForgeError::Forbidden(msg) => {
506 (StatusCode::OK, Json(json_rpc_error(id, -32003, msg, None))).into_response()
507 }
508 _ => (
509 StatusCode::OK,
510 Json(json_rpc_error(id, -32603, "Internal server error", None)),
511 )
512 .into_response(),
513 },
514 }
515}
516
517fn tool_success_result(output: Value) -> Value {
518 match output {
519 Value::Object(_) => serde_json::json!({
520 "content": [{
521 "type": "text",
522 "text": serde_json::to_string(&output).unwrap_or_else(|_| "{}".to_string())
523 }],
524 "structuredContent": output
525 }),
526 Value::String(text) => serde_json::json!({
527 "content": [{ "type": "text", "text": text }]
528 }),
529 other => serde_json::json!({
530 "content": [{
531 "type": "text",
532 "text": serde_json::to_string(&other).unwrap_or_else(|_| "null".to_string())
533 }]
534 }),
535 }
536}
537
538async fn required_session_id(
539 state: &Arc<McpState>,
540 headers: &HeaderMap,
541 require_initialized: bool,
542) -> std::result::Result<String, Response> {
543 let Some(session_id) = headers
544 .get(MCP_SESSION_HEADER)
545 .and_then(|v| v.to_str().ok())
546 else {
547 return Err((
548 StatusCode::BAD_REQUEST,
549 Json(json_rpc_error(
550 None,
551 -32600,
552 "Missing MCP-Session-Id header",
553 None,
554 )),
555 )
556 .into_response());
557 };
558
559 let sessions = state.sessions.read().await;
560 match sessions.get(session_id) {
561 Some(session) => {
562 if session.protocol_version != MCP_PROTOCOL_VERSION {
563 return Err((
564 StatusCode::BAD_REQUEST,
565 Json(json_rpc_error(
566 None,
567 -32600,
568 "Session protocol version mismatch",
569 None,
570 )),
571 )
572 .into_response());
573 }
574 if require_initialized && !session.initialized {
575 return Err((
576 StatusCode::BAD_REQUEST,
577 Json(json_rpc_error(
578 None,
579 -32600,
580 "MCP session is not initialized",
581 None,
582 )),
583 )
584 .into_response());
585 }
586 Ok(session_id.to_string())
587 }
588 None => Err((
589 StatusCode::BAD_REQUEST,
590 Json(json_rpc_error(
591 None,
592 -32600,
593 "Unknown MCP session. Re-initialize.",
594 None,
595 )),
596 )
597 .into_response()),
598 }
599}
600
601fn validate_origin(
602 headers: &HeaderMap,
603 config: &McpConfig,
604) -> std::result::Result<(), ResponseError> {
605 let Some(origin) = headers.get("origin").and_then(|v| v.to_str().ok()) else {
606 return Ok(());
607 };
608
609 if config.allowed_origins.is_empty() {
610 return Ok(());
611 }
612
613 let allowed = config
614 .allowed_origins
615 .iter()
616 .any(|candidate| candidate == "*" || candidate.eq_ignore_ascii_case(origin));
617 if allowed {
618 return Ok(());
619 }
620
621 Err(Box::new(
622 (
623 StatusCode::FORBIDDEN,
624 Json(json_rpc_error(None, -32600, "Invalid Origin header", None)),
625 )
626 .into_response(),
627 ))
628}
629
630fn enforce_protocol_header(
631 config: &McpConfig,
632 headers: &HeaderMap,
633) -> std::result::Result<(), ResponseError> {
634 if !config.require_protocol_version_header {
635 return Ok(());
636 }
637
638 let Some(version) = headers
639 .get(MCP_PROTOCOL_HEADER)
640 .and_then(|v| v.to_str().ok())
641 else {
642 return Err(Box::new(
643 (
644 StatusCode::BAD_REQUEST,
645 Json(json_rpc_error(
646 None,
647 -32600,
648 "Missing MCP-Protocol-Version header",
649 None,
650 )),
651 )
652 .into_response(),
653 ));
654 };
655
656 if version != MCP_PROTOCOL_VERSION {
657 return Err(Box::new(
658 (
659 StatusCode::BAD_REQUEST,
660 Json(json_rpc_error(
661 None,
662 -32600,
663 "Unsupported MCP-Protocol-Version",
664 Some(serde_json::json!({ "supported": MCP_PROTOCOL_VERSION })),
665 )),
666 )
667 .into_response(),
668 ));
669 }
670
671 Ok(())
672}
673
674fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
675 headers
676 .get("x-forwarded-for")
677 .and_then(|v| v.to_str().ok())
678 .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
679 .filter(|s| !s.is_empty())
680 .or_else(|| {
681 headers
682 .get("x-real-ip")
683 .and_then(|v| v.to_str().ok())
684 .map(|s| s.trim().to_string())
685 .filter(|s| !s.is_empty())
686 })
687}
688
689fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
690 headers
691 .get(axum::http::header::USER_AGENT)
692 .and_then(|v| v.to_str().ok())
693 .map(String::from)
694}
695
696fn build_request_metadata(
697 tracing: &super::tracing::TracingState,
698 headers: &HeaderMap,
699) -> RequestMetadata {
700 RequestMetadata {
701 request_id: uuid::Uuid::parse_str(&tracing.request_id)
702 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
703 trace_id: tracing.trace_id.clone(),
704 client_ip: extract_client_ip(headers),
705 user_agent: extract_user_agent(headers),
706 timestamp: chrono::Utc::now(),
707 }
708}
709
710fn json_rpc_success(id: Option<Value>, result: Value) -> Value {
711 serde_json::json!({
712 "jsonrpc": "2.0",
713 "id": id.unwrap_or(Value::Null),
714 "result": result
715 })
716}
717
718fn json_rpc_error(
719 id: Option<Value>,
720 code: i32,
721 message: impl Into<String>,
722 data: Option<Value>,
723) -> Value {
724 let mut error = serde_json::json!({
725 "code": code,
726 "message": message.into()
727 });
728 if let Some(data) = data
729 && let Some(obj) = error.as_object_mut()
730 {
731 obj.insert("data".to_string(), data);
732 }
733
734 serde_json::json!({
735 "jsonrpc": "2.0",
736 "id": id.unwrap_or(Value::Null),
737 "error": error
738 })
739}
740
741fn set_header(response: &mut Response<Body>, name: &str, value: &str) {
742 if let (Ok(name), Ok(value)) = (HeaderName::try_from(name), HeaderValue::from_str(value)) {
743 response.headers_mut().insert(name, value);
744 }
745}
746
#[cfg(test)]
#[allow(clippy::expect_used, clippy::indexing_slicing, clippy::unwrap_used)]
mod tests {
    //! Handler-level tests that call `mcp_post_handler` directly with a
    //! lazily-connected pool, so no database is contacted.

    use super::super::tracing::TracingState;
    use super::*;
    use axum::body::to_bytes;
    use forge_core::function::AuthContext;
    use forge_core::mcp::{ForgeMcpTool, McpToolAnnotations, McpToolInfo};
    use forge_core::schemars::{self, JsonSchema};
    use serde::{Deserialize, Serialize};
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;

    /// Boxed future returned by the fixture tools' `execute`.
    type ToolFuture<'a, T> = Pin<Box<dyn Future<Output = forge_core::Result<T>> + Send + 'a>>;

    // ---- fixture argument / output types ----

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct EchoArgs {
        message: String,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct EchoOutput {
        echoed: String,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    #[serde(rename_all = "snake_case")]
    enum ExportFormat {
        Json,
        Csv,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct MetadataArgs {
        #[schemars(description = "Project UUID to export")]
        project_id: String,
        format: ExportFormat,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct MetadataOutput {
        accepted: bool,
    }

    // ---- fixture tools ----

    /// Tool info shared by the fixtures: non-public, no timeout, no rate limit.
    fn tool_info(
        name: &'static str,
        title: &'static str,
        description: &'static str,
        required_role: Option<&'static str>,
    ) -> McpToolInfo {
        McpToolInfo {
            name,
            title: Some(title),
            description: Some(description),
            required_role,
            is_public: false,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            annotations: McpToolAnnotations::default(),
            icons: &[],
        }
    }

    /// Authenticated tool that echoes its `message` argument.
    struct EchoTool;

    impl ForgeMcpTool for EchoTool {
        type Args = EchoArgs;
        type Output = EchoOutput;

        fn info() -> McpToolInfo {
            tool_info("echo", "Echo", "Echo back the message", None)
        }

        fn execute(_ctx: &McpToolContext, args: Self::Args) -> ToolFuture<'_, Self::Output> {
            Box::pin(async move {
                Ok(EchoOutput {
                    echoed: args.message,
                })
            })
        }
    }

    /// Echo tool that additionally requires the `admin` role.
    struct AdminTool;

    impl ForgeMcpTool for AdminTool {
        type Args = EchoArgs;
        type Output = EchoOutput;

        fn info() -> McpToolInfo {
            tool_info("admin.echo", "Admin Echo", "Admin only echo", Some("admin"))
        }

        fn execute(_ctx: &McpToolContext, args: Self::Args) -> ToolFuture<'_, Self::Output> {
            Box::pin(async move {
                Ok(EchoOutput {
                    echoed: args.message,
                })
            })
        }
    }

    /// Tool whose argument schema carries a description and an enum.
    struct MetadataTool;

    impl ForgeMcpTool for MetadataTool {
        type Args = MetadataArgs;
        type Output = MetadataOutput;

        fn info() -> McpToolInfo {
            tool_info("export.project", "Export Project", "Export project data", None)
        }

        fn execute(_ctx: &McpToolContext, _args: Self::Args) -> ToolFuture<'_, Self::Output> {
            Box::pin(async move { Ok(MetadataOutput { accepted: true }) })
        }
    }

    // ---- helpers ----

    /// Default config with MCP switched on.
    fn enabled_config() -> McpConfig {
        McpConfig {
            enabled: true,
            ..Default::default()
        }
    }

    fn test_state(config: McpConfig) -> Arc<McpState> {
        test_state_with_registry(config, McpToolRegistry::new())
    }

    fn test_state_with_registry(config: McpConfig, registry: McpToolRegistry) -> Arc<McpState> {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("lazy pool must build");
        Arc::new(McpState::new(config, registry, pool, None, None))
    }

    /// State with default-enabled config and a single registered tool.
    fn state_with_tool<T: ForgeMcpTool>() -> Arc<McpState> {
        let mut registry = McpToolRegistry::new();
        registry.register::<T>();
        test_state_with_registry(enabled_config(), registry)
    }

    /// Authenticated caller holding only the `member` role.
    fn member_auth() -> AuthContext {
        AuthContext::authenticated(
            uuid::Uuid::new_v4(),
            vec!["member".to_string()],
            HashMap::new(),
        )
    }

    /// JSON-RPC request envelope (has an `id`).
    fn rpc_request(id: u64, method: &str, params: Value) -> Value {
        serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "method": method,
            "params": params
        })
    }

    /// JSON-RPC notification envelope (no `id`) with empty params.
    fn rpc_notification(method: &str) -> Value {
        serde_json::json!({
            "jsonrpc": "2.0",
            "method": method,
            "params": {}
        })
    }

    /// `initialize` request for the supported protocol version.
    fn initialize_payload() -> Value {
        rpc_request(
            1,
            "initialize",
            serde_json::json!({
                "protocolVersion": "2025-11-25",
                "capabilities": {},
                "clientInfo": { "name": "test", "version": "1.0.0" }
            }),
        )
    }

    /// `tools/call` request for `tool` with the given arguments.
    fn tool_call(id: u64, tool: &str, arguments: Value) -> Value {
        rpc_request(
            id,
            "tools/call",
            serde_json::json!({
                "name": tool,
                "arguments": arguments
            }),
        )
    }

    /// Headers with the optional session id and protocol version set.
    fn request_headers(
        session_id: Option<&str>,
        protocol_version: Option<&'static str>,
    ) -> HeaderMap {
        let mut headers = HeaderMap::new();
        if let Some(session_id) = session_id {
            headers.insert(
                MCP_SESSION_HEADER,
                HeaderValue::from_str(session_id).expect("valid session id header"),
            );
        }
        if let Some(version) = protocol_version {
            headers.insert(MCP_PROTOCOL_HEADER, HeaderValue::from_static(version));
        }
        headers
    }

    /// Runs the POST handler with a fresh tracing state.
    async fn post(
        state: Arc<McpState>,
        auth: AuthContext,
        headers: HeaderMap,
        payload: Value,
    ) -> Response {
        mcp_post_handler(
            State(state),
            Extension(auth),
            Extension(TracingState::new()),
            Method::POST,
            headers,
            Json(payload),
        )
        .await
    }

    /// Reads the response body as JSON; an empty body becomes `{}`.
    async fn response_json(response: Response) -> Value {
        let bytes = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body bytes");
        if bytes.is_empty() {
            return serde_json::json!({});
        }
        serde_json::from_slice(&bytes).expect("valid json")
    }

    /// Sends `initialize` and returns the issued session id.
    async fn initialize_session(state: Arc<McpState>) -> String {
        let response = post(
            state,
            AuthContext::unauthenticated(),
            HeaderMap::new(),
            initialize_payload(),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        response
            .headers()
            .get(MCP_SESSION_HEADER)
            .and_then(|v| v.to_str().ok())
            .expect("session id must exist")
            .to_string()
    }

    /// Sends `notifications/initialized` for the session in `headers`.
    async fn mark_initialized(state: Arc<McpState>, headers: HeaderMap) {
        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            rpc_notification("notifications/initialized"),
        )
        .await;
        assert_eq!(response.status(), StatusCode::ACCEPTED);
    }

    /// Opens and initializes a session, returning headers that address it.
    async fn initialized_headers(state: Arc<McpState>) -> HeaderMap {
        let session_id = initialize_session(state.clone()).await;
        let headers = request_headers(Some(&session_id), Some(MCP_PROTOCOL_VERSION));
        mark_initialized(state, headers.clone()).await;
        headers
    }

    // ---- tests ----

    #[test]
    fn test_json_rpc_helpers() {
        let success = json_rpc_success(
            Some(serde_json::json!(1)),
            serde_json::json!({ "ok": true }),
        );
        assert_eq!(success["jsonrpc"], "2.0");
        assert!(success.get("result").is_some());

        let err = json_rpc_error(Some(serde_json::json!(1)), -32601, "not found", None);
        assert_eq!(err["error"]["code"], -32601);
    }

    #[tokio::test]
    async fn test_initialize_sets_session_header() {
        let state = test_state(enabled_config());
        let session = initialize_session(state).await;
        assert!(!session.is_empty());
    }

    #[tokio::test]
    async fn test_tools_list_requires_initialized_session() {
        let state = test_state(enabled_config());

        // Session exists but `notifications/initialized` was never sent.
        let session_id = initialize_session(state.clone()).await;
        let headers = request_headers(Some(&session_id), Some(MCP_PROTOCOL_VERSION));

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            rpc_request(2, "tools/list", serde_json::json!({})),
        )
        .await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_tools_list_returns_registered_tools() {
        let state = state_with_tool::<EchoTool>();
        let headers = initialized_headers(state.clone()).await;

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            rpc_request(2, "tools/list", serde_json::json!({})),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        let tools = body["result"]["tools"]
            .as_array()
            .expect("tools list should be array");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "echo");
        assert!(tools[0].get("inputSchema").is_some());
        assert!(tools[0].get("outputSchema").is_some());
    }

    #[tokio::test]
    async fn test_tools_list_exposes_parameter_metadata() {
        let state = state_with_tool::<MetadataTool>();
        let headers = initialized_headers(state.clone()).await;

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            rpc_request(9, "tools/list", serde_json::json!({})),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        let tools = body["result"]["tools"]
            .as_array()
            .expect("tools list should be array");
        assert_eq!(tools.len(), 1);

        let input_schema = &tools[0]["inputSchema"];
        assert_eq!(
            input_schema["properties"]["project_id"]["description"],
            "Project UUID to export"
        );

        // The enum variants must appear somewhere in the schema text.
        let schema_text = input_schema.to_string();
        assert!(schema_text.contains("\"json\""));
        assert!(schema_text.contains("\"csv\""));
    }

    #[tokio::test]
    async fn test_tools_call_success_returns_structured_content() {
        let state = state_with_tool::<EchoTool>();
        let headers = initialized_headers(state.clone()).await;

        let response = post(
            state,
            member_auth(),
            headers,
            tool_call(3, "echo", serde_json::json!({ "message": "hello" })),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        assert_eq!(body["result"]["structuredContent"]["echoed"], "hello");
        assert_eq!(body["result"]["content"][0]["type"], "text");
    }

    #[tokio::test]
    async fn test_tools_call_validation_failure_returns_is_error() {
        let state = state_with_tool::<EchoTool>();
        let headers = initialized_headers(state.clone()).await;

        // `message` is required by `EchoArgs` but omitted here.
        let response = post(
            state,
            member_auth(),
            headers,
            tool_call(4, "echo", serde_json::json!({})),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        assert_eq!(body["result"]["isError"], true);
    }

    #[tokio::test]
    async fn test_tools_call_requires_authentication() {
        let state = state_with_tool::<EchoTool>();
        let headers = initialized_headers(state.clone()).await;

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            tool_call(5, "echo", serde_json::json!({ "message": "hello" })),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        assert_eq!(body["error"]["code"], -32001);
    }

    #[tokio::test]
    async fn test_tools_call_requires_role() {
        let state = state_with_tool::<AdminTool>();
        let headers = initialized_headers(state.clone()).await;

        let response = post(
            state,
            member_auth(),
            headers,
            tool_call(6, "admin.echo", serde_json::json!({ "message": "hello" })),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_json(response).await;
        assert_eq!(body["error"]["code"], -32003);
    }

    #[tokio::test]
    async fn test_invalid_protocol_header_returns_400() {
        let state = test_state(enabled_config());
        let session_id = initialize_session(state.clone()).await;
        let headers = request_headers(Some(&session_id), Some("invalid-version"));

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            rpc_request(7, "tools/list", serde_json::json!({})),
        )
        .await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_missing_protocol_header_returns_400() {
        let state = test_state(enabled_config());
        let session_id = initialize_session(state.clone()).await;
        let headers = request_headers(Some(&session_id), None);

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            rpc_request(8, "tools/list", serde_json::json!({})),
        )
        .await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_notifications_return_202() {
        let state = test_state(enabled_config());
        let headers = request_headers(None, Some(MCP_PROTOCOL_VERSION));

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            rpc_notification("notifications/tools/list_changed"),
        )
        .await;
        assert_eq!(response.status(), StatusCode::ACCEPTED);
    }

    #[tokio::test]
    async fn test_invalid_origin_rejected() {
        let state = test_state(McpConfig {
            enabled: true,
            allowed_origins: vec!["https://allowed.example".to_string()],
            ..Default::default()
        });

        let mut headers = HeaderMap::new();
        headers.insert("origin", HeaderValue::from_static("https://evil.example"));

        let response = post(
            state,
            AuthContext::unauthenticated(),
            headers,
            initialize_payload(),
        )
        .await;

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }
}