1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use axum::extract::{Extension, State};
6use axum::http::header::{HeaderName, HeaderValue};
7use axum::http::{HeaderMap, Method, StatusCode};
8use axum::response::{IntoResponse, Response};
9use axum::{Json, body::Body};
10use forge_core::config::McpConfig;
11use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
12use forge_core::mcp::McpToolContext;
13use forge_core::rate_limit::RateLimitKey;
14use serde_json::Value;
15use tokio::sync::RwLock;
16
17use crate::mcp::McpToolRegistry;
18use crate::rate_limit::RateLimiter;
19
/// MCP protocol revisions accepted by `initialize` and by the protocol-version header check.
const SUPPORTED_VERSIONS: &[&str] = &["2025-11-25", "2025-03-26", "2024-11-05"];
/// Protocol revision the unit tests send in their request headers.
#[cfg(test)]
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Header that carries the session identifier issued by `initialize`.
const MCP_SESSION_HEADER: &str = "mcp-session-id";
/// Header that carries the protocol revision on a request or response.
const MCP_PROTOCOL_HEADER: &str = "mcp-protocol-version";
/// Maximum number of tools returned in a single `tools/list` page.
const DEFAULT_PAGE_SIZE: usize = 50;
26type ResponseError = Box<Response>;
27
/// In-memory record kept for one MCP client session.
#[derive(Clone, Debug)]
struct McpSession {
    /// Becomes `true` once the client has sent `notifications/initialized`.
    initialized: bool,
    /// Protocol revision the client asked for in `initialize`.
    protocol_version: String,
    /// Point in time after which the session is dropped by the cleanup pass.
    expires_at: Instant,
}
34
35#[derive(Clone)]
36pub struct McpState {
37 config: McpConfig,
38 registry: McpToolRegistry,
39 pool: sqlx::PgPool,
40 sessions: Arc<RwLock<HashMap<String, McpSession>>>,
41 job_dispatcher: Option<Arc<dyn JobDispatch>>,
42 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
43 rate_limiter: Arc<RateLimiter>,
44}
45
46impl McpState {
47 pub fn new(
48 config: McpConfig,
49 registry: McpToolRegistry,
50 pool: sqlx::PgPool,
51 job_dispatcher: Option<Arc<dyn JobDispatch>>,
52 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
53 ) -> Self {
54 Self {
55 config,
56 registry,
57 pool: pool.clone(),
58 sessions: Arc::new(RwLock::new(HashMap::new())),
59 job_dispatcher,
60 workflow_dispatcher,
61 rate_limiter: Arc::new(RateLimiter::new(pool)),
62 }
63 }
64
65 async fn cleanup_expired_sessions(&self) {
66 let mut sessions = self.sessions.write().await;
67 let now = Instant::now();
68 sessions.retain(|_, session| session.expires_at > now);
69 }
70
71 async fn touch_session(&self, session_id: &str) {
72 let mut sessions = self.sessions.write().await;
73 if let Some(session) = sessions.get_mut(session_id) {
74 session.expires_at = Instant::now() + Duration::from_secs(self.config.session_ttl_secs);
75 }
76 }
77
78 fn session_ttl(&self) -> Duration {
79 Duration::from_secs(self.config.session_ttl_secs)
80 }
81}
82
83pub async fn mcp_get_handler(State(state): State<Arc<McpState>>, headers: HeaderMap) -> Response {
84 if let Err(resp) = validate_origin(&headers, &state.config) {
85 return *resp;
86 }
87
88 (
89 StatusCode::METHOD_NOT_ALLOWED,
90 Json(json_rpc_error(
91 None,
92 -32601,
93 "GET stream is not supported in Forge MCP v1",
94 None,
95 )),
96 )
97 .into_response()
98}
99
100pub async fn mcp_post_handler(
101 State(state): State<Arc<McpState>>,
102 Extension(auth): Extension<AuthContext>,
103 Extension(tracing): Extension<super::tracing::TracingState>,
104 method: Method,
105 headers: HeaderMap,
106 Json(payload): Json<Value>,
107) -> Response {
108 if method != Method::POST {
109 return (
110 StatusCode::METHOD_NOT_ALLOWED,
111 Json(json_rpc_error(None, -32601, "Only POST is supported", None)),
112 )
113 .into_response();
114 }
115
116 if let Err(resp) = validate_origin(&headers, &state.config) {
117 return *resp;
118 }
119
120 state.cleanup_expired_sessions().await;
121
122 let Some(method_name) = payload.get("method").and_then(Value::as_str) else {
123 if payload.get("id").is_some()
125 && (payload.get("result").is_some() || payload.get("error").is_some())
126 {
127 return StatusCode::ACCEPTED.into_response();
128 }
129 return (
130 StatusCode::BAD_REQUEST,
131 Json(json_rpc_error(
132 None,
133 -32600,
134 "Invalid JSON-RPC payload",
135 None,
136 )),
137 )
138 .into_response();
139 };
140
141 let id = payload.get("id").cloned();
142 let params = payload
143 .get("params")
144 .cloned()
145 .unwrap_or(Value::Object(Default::default()));
146
147 if id.is_none() {
149 return handle_notification(&state, method_name, params, &headers).await;
150 }
151
152 if method_name != "initialize"
154 && let Err(resp) = enforce_protocol_header(&state.config, &headers)
155 {
156 return *resp;
157 }
158
159 match method_name {
160 "initialize" => handle_initialize(&state, id, ¶ms).await,
161 "tools/list" => {
162 let session_id = match required_session_id(&state, &headers, true).await {
163 Ok(v) => v,
164 Err(resp) => return resp,
165 };
166 state.touch_session(&session_id).await;
167 handle_tools_list(&state, id, ¶ms)
168 }
169 "tools/call" => {
170 let session_id = match required_session_id(&state, &headers, true).await {
171 Ok(v) => v,
172 Err(resp) => return resp,
173 };
174 state.touch_session(&session_id).await;
175
176 let metadata = build_request_metadata(&tracing, &headers);
177 handle_tools_call(&state, id, ¶ms, &auth, metadata).await
178 }
179 _ => (
180 StatusCode::OK,
181 Json(json_rpc_error(id, -32601, "Method not found", None)),
182 )
183 .into_response(),
184 }
185}
186
187async fn handle_notification(
188 state: &Arc<McpState>,
189 method_name: &str,
190 _params: Value,
191 headers: &HeaderMap,
192) -> Response {
193 if let Err(resp) = enforce_protocol_header(&state.config, headers) {
194 return *resp;
195 }
196
197 match method_name {
198 "notifications/initialized" => {
199 let session_id = match required_session_id(state, headers, false).await {
200 Ok(v) => v,
201 Err(resp) => return resp,
202 };
203
204 let mut sessions = state.sessions.write().await;
205 if let Some(session) = sessions.get_mut(&session_id) {
206 session.initialized = true;
207 session.expires_at = Instant::now() + state.session_ttl();
208 return StatusCode::ACCEPTED.into_response();
209 }
210
211 (
212 StatusCode::BAD_REQUEST,
213 Json(json_rpc_error(
214 None,
215 -32600,
216 "Unknown MCP session. Re-initialize the connection.",
217 None,
218 )),
219 )
220 .into_response()
221 }
222 _ => StatusCode::ACCEPTED.into_response(),
223 }
224}
225
226async fn handle_initialize(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
227 let Some(requested_version) = params.get("protocolVersion").and_then(Value::as_str) else {
228 return (
229 StatusCode::OK,
230 Json(json_rpc_error(
231 id,
232 -32602,
233 "Missing protocolVersion in initialize params",
234 None,
235 )),
236 )
237 .into_response();
238 };
239
240 if !SUPPORTED_VERSIONS.contains(&requested_version) {
241 return (
242 StatusCode::OK,
243 Json(json_rpc_error(
244 id,
245 -32602,
246 "Unsupported protocolVersion",
247 Some(serde_json::json!({
248 "supported": SUPPORTED_VERSIONS
249 })),
250 )),
251 )
252 .into_response();
253 }
254
255 let session_id = uuid::Uuid::new_v4().to_string();
256 {
257 let mut sessions = state.sessions.write().await;
258 sessions.insert(
259 session_id.clone(),
260 McpSession {
261 initialized: false,
262 protocol_version: requested_version.to_string(),
263 expires_at: Instant::now() + state.session_ttl(),
264 },
265 );
266 }
267
268 let mut response = (
269 StatusCode::OK,
270 Json(json_rpc_success(
271 id,
272 serde_json::json!({
273 "protocolVersion": requested_version,
274 "capabilities": {
275 "tools": {
276 "listChanged": false
277 }
278 },
279 "serverInfo": {
280 "name": "forge",
281 "version": env!("CARGO_PKG_VERSION")
282 }
283 }),
284 )),
285 )
286 .into_response();
287
288 set_header(&mut response, MCP_SESSION_HEADER, &session_id);
289 set_header(&mut response, MCP_PROTOCOL_HEADER, requested_version);
290 response
291}
292
293fn handle_tools_list(state: &Arc<McpState>, id: Option<Value>, params: &Value) -> Response {
294 let cursor = params.get("cursor").and_then(Value::as_str);
295 let start = match cursor {
296 Some(c) => match c.parse::<usize>() {
297 Ok(v) => v,
298 Err(_) => {
299 return (
300 StatusCode::OK,
301 Json(json_rpc_error(
302 id,
303 -32602,
304 "Invalid cursor in tools/list request",
305 None,
306 )),
307 )
308 .into_response();
309 }
310 },
311 None => 0,
312 };
313
314 let mut tools: Vec<_> = state.registry.list().collect();
315 tools.sort_by(|a, b| a.info.name.cmp(b.info.name));
316
317 let page: Vec<_> = tools
318 .iter()
319 .skip(start)
320 .take(DEFAULT_PAGE_SIZE)
321 .map(|entry| {
322 let mut annotations = serde_json::Map::new();
325 if let Some(title) = &entry.info.annotations.title {
326 annotations.insert("title".into(), serde_json::Value::String(title.to_string()));
327 }
328 if let Some(v) = entry.info.annotations.read_only_hint {
329 annotations.insert("readOnlyHint".into(), serde_json::Value::Bool(v));
330 }
331 if let Some(v) = entry.info.annotations.destructive_hint {
332 annotations.insert("destructiveHint".into(), serde_json::Value::Bool(v));
333 }
334 if let Some(v) = entry.info.annotations.idempotent_hint {
335 annotations.insert("idempotentHint".into(), serde_json::Value::Bool(v));
336 }
337 if let Some(v) = entry.info.annotations.open_world_hint {
338 annotations.insert("openWorldHint".into(), serde_json::Value::Bool(v));
339 }
340
341 let mut value = serde_json::json!({
342 "name": entry.info.name,
343 "description": entry.info.description,
344 "inputSchema": entry.input_schema,
345 });
346 let obj = value.as_object_mut().expect("json! object literal");
348
349 if let Some(title) = &entry.info.title {
350 obj.insert("title".into(), serde_json::Value::String(title.to_string()));
351 }
352 if !annotations.is_empty() {
353 obj.insert("annotations".into(), serde_json::Value::Object(annotations));
354 }
355 if !entry.info.icons.is_empty() {
356 let icons: Vec<_> = entry
357 .info
358 .icons
359 .iter()
360 .map(|icon| {
361 serde_json::json!({
362 "src": icon.src,
363 "mimeType": icon.mime_type,
364 "sizes": icon.sizes,
365 "theme": icon.theme
366 })
367 })
368 .collect();
369 obj.insert("icons".into(), serde_json::Value::Array(icons));
370 }
371 if let Some(output_schema) = &entry.output_schema {
372 let schema = normalize_output_schema(output_schema);
376 obj.insert("outputSchema".into(), schema);
377 }
378 value
379 })
380 .collect();
381
382 let end = start.saturating_add(page.len());
383
384 let mut result = serde_json::json!({ "tools": page });
386 if end < tools.len() && result.is_object() {
387 result
389 .as_object_mut()
390 .expect("json! object literal")
391 .insert(
392 "nextCursor".into(),
393 serde_json::Value::String(end.to_string()),
394 );
395 }
396
397 (StatusCode::OK, Json(json_rpc_success(id, result))).into_response()
398}
399
400fn normalize_output_schema(schema: &Value) -> Value {
403 let type_str = schema.get("type").and_then(Value::as_str).unwrap_or("");
404 if type_str == "object" {
405 return schema.clone();
406 }
407
408 let mut wrapper = serde_json::json!({
410 "type": "object",
411 "properties": {
412 "result": schema
413 }
414 });
415
416 if let (Some(s), Some(obj)) = (schema.get("$schema"), wrapper.as_object_mut()) {
418 obj.insert("$schema".into(), s.clone());
419 }
420 if let (Some(d), Some(obj)) = (schema.get("definitions"), wrapper.as_object_mut()) {
421 obj.insert("definitions".into(), d.clone());
422 if let Some(inner) = wrapper.pointer_mut("/properties/result") {
424 inner.as_object_mut().map(|o| o.remove("definitions"));
425 }
426 }
427
428 wrapper
429}
430
431async fn handle_tools_call(
432 state: &Arc<McpState>,
433 id: Option<Value>,
434 params: &Value,
435 auth: &AuthContext,
436 request_metadata: RequestMetadata,
437) -> Response {
438 let Some(tool_name) = params.get("name").and_then(Value::as_str) else {
439 return (
440 StatusCode::OK,
441 Json(json_rpc_error(id, -32602, "Missing tool name", None)),
442 )
443 .into_response();
444 };
445
446 let Some(entry) = state.registry.get(tool_name) else {
447 return (
448 StatusCode::OK,
449 Json(json_rpc_error(id, -32602, "Unknown tool", None)),
450 )
451 .into_response();
452 };
453
454 if !entry.info.is_public && !auth.is_authenticated() {
455 if state.config.oauth {
456 let mut response = (
458 StatusCode::UNAUTHORIZED,
459 Json(json_rpc_error(id, -32001, "Authentication required", None)),
460 )
461 .into_response();
462 response.headers_mut().insert(
463 "WWW-Authenticate",
464 axum::http::header::HeaderValue::from_static(
465 "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\"",
466 ),
467 );
468 return response;
469 }
470 return (
471 StatusCode::OK,
472 Json(json_rpc_error(id, -32001, "Authentication required", None)),
473 )
474 .into_response();
475 }
476 if let Some(role) = entry.info.required_role
477 && !auth.has_role(role)
478 {
479 return (
480 StatusCode::OK,
481 Json(json_rpc_error(
482 id,
483 -32003,
484 format!("Role '{}' required", role),
485 None,
486 )),
487 )
488 .into_response();
489 }
490
491 if let (Some(requests), Some(per_secs)) = (
492 entry.info.rate_limit_requests,
493 entry.info.rate_limit_per_secs,
494 ) {
495 let key_type = entry
496 .info
497 .rate_limit_key
498 .and_then(|k| k.parse::<RateLimitKey>().ok())
499 .unwrap_or_default();
500
501 let config = forge_core::RateLimitConfig::new(requests, Duration::from_secs(per_secs))
502 .with_key(key_type);
503 let bucket_key = state
504 .rate_limiter
505 .build_key(key_type, tool_name, auth, &request_metadata);
506
507 if let Err(e) = state.rate_limiter.enforce(&bucket_key, &config).await {
508 return (
509 StatusCode::OK,
510 Json(json_rpc_error(id, -32029, e.to_string(), None)),
511 )
512 .into_response();
513 }
514 }
515
516 let args = params
517 .get("arguments")
518 .cloned()
519 .unwrap_or(Value::Object(Default::default()));
520
521 let ctx = McpToolContext::with_dispatch(
522 state.pool.clone(),
523 auth.clone(),
524 request_metadata,
525 state.job_dispatcher.clone(),
526 state.workflow_dispatcher.clone(),
527 );
528
529 let result = if let Some(timeout_secs) = entry.info.timeout {
530 match tokio::time::timeout(
531 Duration::from_secs(timeout_secs),
532 (entry.handler)(&ctx, args),
533 )
534 .await
535 {
536 Ok(inner) => inner,
537 Err(_) => {
538 return (
539 StatusCode::OK,
540 Json(json_rpc_error(id, -32000, "Tool timed out", None)),
541 )
542 .into_response();
543 }
544 }
545 } else {
546 (entry.handler)(&ctx, args).await
547 };
548
549 match result {
550 Ok(output) => {
551 let result = tool_success_result(output);
552 (
553 StatusCode::OK,
554 Json(json_rpc_success(id, serde_json::json!(result))),
555 )
556 .into_response()
557 }
558 Err(e) => match e {
559 forge_core::ForgeError::Validation(msg)
560 | forge_core::ForgeError::InvalidArgument(msg) => (
561 StatusCode::OK,
562 Json(json_rpc_success(
563 id,
564 serde_json::json!({
565 "content": [{ "type": "text", "text": msg }],
566 "isError": true
567 }),
568 )),
569 )
570 .into_response(),
571 forge_core::ForgeError::Unauthorized(msg) => {
572 (StatusCode::OK, Json(json_rpc_error(id, -32001, msg, None))).into_response()
573 }
574 forge_core::ForgeError::Forbidden(msg) => {
575 (StatusCode::OK, Json(json_rpc_error(id, -32003, msg, None))).into_response()
576 }
577 _ => (
578 StatusCode::OK,
579 Json(json_rpc_error(id, -32603, "Internal server error", None)),
580 )
581 .into_response(),
582 },
583 }
584}
585
586fn tool_success_result(output: Value) -> Value {
587 match output {
588 Value::Object(_) => serde_json::json!({
589 "content": [{
590 "type": "text",
591 "text": serde_json::to_string(&output).unwrap_or_else(|_| "{}".to_string())
592 }],
593 "structuredContent": output
594 }),
595 Value::String(text) => serde_json::json!({
596 "content": [{ "type": "text", "text": text }]
597 }),
598 other => serde_json::json!({
599 "content": [{
600 "type": "text",
601 "text": serde_json::to_string(&other).unwrap_or_else(|_| "null".to_string())
602 }]
603 }),
604 }
605}
606
607async fn required_session_id(
608 state: &Arc<McpState>,
609 headers: &HeaderMap,
610 require_initialized: bool,
611) -> std::result::Result<String, Response> {
612 let Some(session_id) = headers
613 .get(MCP_SESSION_HEADER)
614 .and_then(|v| v.to_str().ok())
615 else {
616 return Err((
617 StatusCode::BAD_REQUEST,
618 Json(json_rpc_error(
619 None,
620 -32600,
621 "Missing MCP-Session-Id header",
622 None,
623 )),
624 )
625 .into_response());
626 };
627
628 let sessions = state.sessions.read().await;
629 match sessions.get(session_id) {
630 Some(session) => {
631 if !SUPPORTED_VERSIONS.contains(&session.protocol_version.as_str()) {
632 return Err((
633 StatusCode::BAD_REQUEST,
634 Json(json_rpc_error(
635 None,
636 -32600,
637 "Session protocol version mismatch",
638 None,
639 )),
640 )
641 .into_response());
642 }
643 if require_initialized && !session.initialized {
644 return Err((
645 StatusCode::BAD_REQUEST,
646 Json(json_rpc_error(
647 None,
648 -32600,
649 "MCP session is not initialized",
650 None,
651 )),
652 )
653 .into_response());
654 }
655 Ok(session_id.to_string())
656 }
657 None => Err((
658 StatusCode::BAD_REQUEST,
659 Json(json_rpc_error(
660 None,
661 -32600,
662 "Unknown MCP session. Re-initialize.",
663 None,
664 )),
665 )
666 .into_response()),
667 }
668}
669
670fn validate_origin(
671 headers: &HeaderMap,
672 config: &McpConfig,
673) -> std::result::Result<(), ResponseError> {
674 let Some(origin) = headers.get("origin").and_then(|v| v.to_str().ok()) else {
675 return Ok(());
676 };
677
678 if config.allowed_origins.is_empty() {
679 return Ok(());
680 }
681
682 let allowed = config
683 .allowed_origins
684 .iter()
685 .any(|candidate| candidate == "*" || candidate.eq_ignore_ascii_case(origin));
686 if allowed {
687 return Ok(());
688 }
689
690 Err(Box::new(
691 (
692 StatusCode::FORBIDDEN,
693 Json(json_rpc_error(None, -32600, "Invalid Origin header", None)),
694 )
695 .into_response(),
696 ))
697}
698
699fn enforce_protocol_header(
700 config: &McpConfig,
701 headers: &HeaderMap,
702) -> std::result::Result<(), ResponseError> {
703 if !config.require_protocol_version_header {
704 return Ok(());
705 }
706
707 let Some(version) = headers
708 .get(MCP_PROTOCOL_HEADER)
709 .and_then(|v| v.to_str().ok())
710 else {
711 return Err(Box::new(
712 (
713 StatusCode::BAD_REQUEST,
714 Json(json_rpc_error(
715 None,
716 -32600,
717 "Missing MCP-Protocol-Version header",
718 None,
719 )),
720 )
721 .into_response(),
722 ));
723 };
724
725 if !SUPPORTED_VERSIONS.contains(&version) {
726 return Err(Box::new(
727 (
728 StatusCode::BAD_REQUEST,
729 Json(json_rpc_error(
730 None,
731 -32600,
732 "Unsupported MCP-Protocol-Version",
733 Some(serde_json::json!({ "supported": SUPPORTED_VERSIONS })),
734 )),
735 )
736 .into_response(),
737 ));
738 }
739
740 Ok(())
741}
742
743fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
744 headers
745 .get("x-forwarded-for")
746 .and_then(|v| v.to_str().ok())
747 .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
748 .filter(|s| !s.is_empty())
749 .or_else(|| {
750 headers
751 .get("x-real-ip")
752 .and_then(|v| v.to_str().ok())
753 .map(|s| s.trim().to_string())
754 .filter(|s| !s.is_empty())
755 })
756}
757
758fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
759 headers
760 .get(axum::http::header::USER_AGENT)
761 .and_then(|v| v.to_str().ok())
762 .map(String::from)
763}
764
765fn build_request_metadata(
766 tracing: &super::tracing::TracingState,
767 headers: &HeaderMap,
768) -> RequestMetadata {
769 RequestMetadata {
770 request_id: uuid::Uuid::parse_str(&tracing.request_id)
771 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
772 trace_id: tracing.trace_id.clone(),
773 client_ip: extract_client_ip(headers),
774 user_agent: extract_user_agent(headers),
775 timestamp: chrono::Utc::now(),
776 }
777}
778
779fn json_rpc_success(id: Option<Value>, result: Value) -> Value {
780 serde_json::json!({
781 "jsonrpc": "2.0",
782 "id": id.unwrap_or(Value::Null),
783 "result": result
784 })
785}
786
787fn json_rpc_error(
788 id: Option<Value>,
789 code: i32,
790 message: impl Into<String>,
791 data: Option<Value>,
792) -> Value {
793 let mut error = serde_json::json!({
794 "code": code,
795 "message": message.into()
796 });
797 if let Some(data) = data
798 && let Some(obj) = error.as_object_mut()
799 {
800 obj.insert("data".to_string(), data);
801 }
802
803 serde_json::json!({
804 "jsonrpc": "2.0",
805 "id": id.unwrap_or(Value::Null),
806 "error": error
807 })
808}
809
810fn set_header(response: &mut Response<Body>, name: &str, value: &str) {
811 if let (Ok(name), Ok(value)) = (HeaderName::try_from(name), HeaderValue::from_str(value)) {
812 response.headers_mut().insert(name, value);
813 }
814}
815
816#[cfg(test)]
817#[allow(clippy::expect_used, clippy::indexing_slicing, clippy::unwrap_used)]
818mod tests {
819 use super::super::tracing::TracingState;
820 use super::*;
821 use axum::body::to_bytes;
822 use forge_core::function::AuthContext;
823 use forge_core::mcp::{ForgeMcpTool, McpToolAnnotations, McpToolInfo};
824 use forge_core::schemars::{self, JsonSchema};
825 use serde::{Deserialize, Serialize};
826 use std::collections::HashMap;
827 use std::future::Future;
828 use std::pin::Pin;
829
830 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
831 struct EchoArgs {
832 message: String,
833 }
834
835 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
836 struct EchoOutput {
837 echoed: String,
838 }
839
840 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
841 #[serde(rename_all = "snake_case")]
842 enum ExportFormat {
843 Json,
844 Csv,
845 }
846
847 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
848 struct MetadataArgs {
849 #[schemars(description = "Project UUID to export")]
850 project_id: String,
851 format: ExportFormat,
852 }
853
854 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
855 struct MetadataOutput {
856 accepted: bool,
857 }
858
859 struct EchoTool;
860
861 impl ForgeMcpTool for EchoTool {
862 type Args = EchoArgs;
863 type Output = EchoOutput;
864
865 fn info() -> McpToolInfo {
866 McpToolInfo {
867 name: "echo",
868 title: Some("Echo"),
869 description: Some("Echo back the message"),
870 required_role: None,
871 is_public: false,
872 timeout: None,
873 rate_limit_requests: None,
874 rate_limit_per_secs: None,
875 rate_limit_key: None,
876 annotations: McpToolAnnotations::default(),
877 icons: &[],
878 }
879 }
880
881 fn execute(
882 _ctx: &McpToolContext,
883 args: Self::Args,
884 ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
885 Box::pin(async move {
886 Ok(EchoOutput {
887 echoed: args.message,
888 })
889 })
890 }
891 }
892
893 struct AdminTool;
894
895 impl ForgeMcpTool for AdminTool {
896 type Args = EchoArgs;
897 type Output = EchoOutput;
898
899 fn info() -> McpToolInfo {
900 McpToolInfo {
901 name: "admin.echo",
902 title: Some("Admin Echo"),
903 description: Some("Admin only echo"),
904 required_role: Some("admin"),
905 is_public: false,
906 timeout: None,
907 rate_limit_requests: None,
908 rate_limit_per_secs: None,
909 rate_limit_key: None,
910 annotations: McpToolAnnotations::default(),
911 icons: &[],
912 }
913 }
914
915 fn execute(
916 _ctx: &McpToolContext,
917 args: Self::Args,
918 ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
919 Box::pin(async move {
920 Ok(EchoOutput {
921 echoed: args.message,
922 })
923 })
924 }
925 }
926
927 struct MetadataTool;
928
929 impl ForgeMcpTool for MetadataTool {
930 type Args = MetadataArgs;
931 type Output = MetadataOutput;
932
933 fn info() -> McpToolInfo {
934 McpToolInfo {
935 name: "export.project",
936 title: Some("Export Project"),
937 description: Some("Export project data"),
938 required_role: None,
939 is_public: false,
940 timeout: None,
941 rate_limit_requests: None,
942 rate_limit_per_secs: None,
943 rate_limit_key: None,
944 annotations: McpToolAnnotations::default(),
945 icons: &[],
946 }
947 }
948
949 fn execute(
950 _ctx: &McpToolContext,
951 _args: Self::Args,
952 ) -> Pin<Box<dyn Future<Output = forge_core::Result<Self::Output>> + Send + '_>> {
953 Box::pin(async move { Ok(MetadataOutput { accepted: true }) })
954 }
955 }
956
957 #[test]
958 fn test_json_rpc_helpers() {
959 let success = json_rpc_success(
960 Some(serde_json::json!(1)),
961 serde_json::json!({ "ok": true }),
962 );
963 assert_eq!(success["jsonrpc"], "2.0");
964 assert!(success.get("result").is_some());
965
966 let err = json_rpc_error(Some(serde_json::json!(1)), -32601, "not found", None);
967 assert_eq!(err["error"]["code"], -32601);
968 }
969
970 fn test_state(config: McpConfig) -> Arc<McpState> {
971 test_state_with_registry(config, McpToolRegistry::new())
972 }
973
974 fn test_state_with_registry(config: McpConfig, registry: McpToolRegistry) -> Arc<McpState> {
975 let pool = sqlx::postgres::PgPoolOptions::new()
976 .max_connections(1)
977 .connect_lazy("postgres://localhost/nonexistent")
978 .expect("lazy pool must build");
979 Arc::new(McpState::new(config, registry, pool, None, None))
980 }
981
982 async fn response_json(response: Response) -> Value {
983 let bytes = to_bytes(response.into_body(), usize::MAX)
984 .await
985 .expect("body bytes");
986 if bytes.is_empty() {
987 return serde_json::json!({});
988 }
989 serde_json::from_slice(&bytes).expect("valid json")
990 }
991
992 async fn initialize_session(state: Arc<McpState>) -> String {
993 let payload = serde_json::json!({
994 "jsonrpc": "2.0",
995 "id": 1,
996 "method": "initialize",
997 "params": {
998 "protocolVersion": "2025-11-25",
999 "capabilities": {},
1000 "clientInfo": { "name": "test", "version": "1.0.0" }
1001 }
1002 });
1003 let response = mcp_post_handler(
1004 State(state),
1005 Extension(AuthContext::unauthenticated()),
1006 Extension(TracingState::new()),
1007 Method::POST,
1008 HeaderMap::new(),
1009 Json(payload),
1010 )
1011 .await;
1012
1013 assert_eq!(response.status(), StatusCode::OK);
1014 response
1015 .headers()
1016 .get(MCP_SESSION_HEADER)
1017 .and_then(|v| v.to_str().ok())
1018 .expect("session id must exist")
1019 .to_string()
1020 }
1021
1022 async fn mark_initialized(state: Arc<McpState>, headers: HeaderMap) {
1023 let payload = serde_json::json!({
1024 "jsonrpc": "2.0",
1025 "method": "notifications/initialized",
1026 "params": {}
1027 });
1028 let response = mcp_post_handler(
1029 State(state),
1030 Extension(AuthContext::unauthenticated()),
1031 Extension(TracingState::new()),
1032 Method::POST,
1033 headers,
1034 Json(payload),
1035 )
1036 .await;
1037 assert_eq!(response.status(), StatusCode::ACCEPTED);
1038 }
1039
1040 async fn initialized_headers(state: Arc<McpState>) -> HeaderMap {
1041 let session_id = initialize_session(state.clone()).await;
1042 let mut headers = HeaderMap::new();
1043 headers.insert(
1044 MCP_SESSION_HEADER,
1045 HeaderValue::from_str(&session_id).expect("valid session id header"),
1046 );
1047 headers.insert(
1048 MCP_PROTOCOL_HEADER,
1049 HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1050 );
1051 mark_initialized(state, headers.clone()).await;
1052 headers
1053 }
1054
1055 #[tokio::test]
1056 async fn test_initialize_sets_session_header() {
1057 let state = test_state(McpConfig {
1058 enabled: true,
1059 ..Default::default()
1060 });
1061 let session = initialize_session(state).await;
1062 assert!(!session.is_empty());
1063 }
1064
1065 #[tokio::test]
1066 async fn test_tools_list_requires_initialized_session() {
1067 let state = test_state(McpConfig {
1068 enabled: true,
1069 ..Default::default()
1070 });
1071
1072 let session_id = initialize_session(state.clone()).await;
1073
1074 let mut headers = HeaderMap::new();
1075 headers.insert(
1076 MCP_SESSION_HEADER,
1077 HeaderValue::from_str(&session_id).expect("valid"),
1078 );
1079 headers.insert(
1080 MCP_PROTOCOL_HEADER,
1081 HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1082 );
1083
1084 let list_payload = serde_json::json!({
1085 "jsonrpc": "2.0",
1086 "id": 2,
1087 "method": "tools/list",
1088 "params": {}
1089 });
1090 let response = mcp_post_handler(
1091 State(state),
1092 Extension(AuthContext::unauthenticated()),
1093 Extension(TracingState::new()),
1094 Method::POST,
1095 headers,
1096 Json(list_payload),
1097 )
1098 .await;
1099
1100 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1101 }
1102
1103 #[tokio::test]
1104 async fn test_tools_list_returns_registered_tools() {
1105 let mut registry = McpToolRegistry::new();
1106 registry.register::<EchoTool>();
1107
1108 let state = test_state_with_registry(
1109 McpConfig {
1110 enabled: true,
1111 ..Default::default()
1112 },
1113 registry,
1114 );
1115 let headers = initialized_headers(state.clone()).await;
1116 let payload = serde_json::json!({
1117 "jsonrpc": "2.0",
1118 "id": 2,
1119 "method": "tools/list",
1120 "params": {}
1121 });
1122
1123 let response = mcp_post_handler(
1124 State(state),
1125 Extension(AuthContext::unauthenticated()),
1126 Extension(TracingState::new()),
1127 Method::POST,
1128 headers,
1129 Json(payload),
1130 )
1131 .await;
1132
1133 assert_eq!(response.status(), StatusCode::OK);
1134 let body = response_json(response).await;
1135 let tools = body["result"]["tools"]
1136 .as_array()
1137 .expect("tools list should be array");
1138 assert_eq!(tools.len(), 1);
1139 assert_eq!(tools[0]["name"], "echo");
1140 assert!(tools[0].get("inputSchema").is_some());
1141 assert!(tools[0].get("outputSchema").is_some());
1142 }
1143
1144 #[tokio::test]
1145 async fn test_tools_list_exposes_parameter_metadata() {
1146 let mut registry = McpToolRegistry::new();
1147 registry.register::<MetadataTool>();
1148
1149 let state = test_state_with_registry(
1150 McpConfig {
1151 enabled: true,
1152 ..Default::default()
1153 },
1154 registry,
1155 );
1156 let headers = initialized_headers(state.clone()).await;
1157 let payload = serde_json::json!({
1158 "jsonrpc": "2.0",
1159 "id": 9,
1160 "method": "tools/list",
1161 "params": {}
1162 });
1163
1164 let response = mcp_post_handler(
1165 State(state),
1166 Extension(AuthContext::unauthenticated()),
1167 Extension(TracingState::new()),
1168 Method::POST,
1169 headers,
1170 Json(payload),
1171 )
1172 .await;
1173
1174 assert_eq!(response.status(), StatusCode::OK);
1175 let body = response_json(response).await;
1176 let tools = body["result"]["tools"]
1177 .as_array()
1178 .expect("tools list should be array");
1179 assert_eq!(tools.len(), 1);
1180
1181 let input_schema = &tools[0]["inputSchema"];
1182 assert_eq!(
1183 input_schema["properties"]["project_id"]["description"],
1184 "Project UUID to export"
1185 );
1186
1187 let schema_text = input_schema.to_string();
1188 assert!(schema_text.contains("\"json\""));
1189 assert!(schema_text.contains("\"csv\""));
1190 }
1191
1192 #[tokio::test]
1193 async fn test_tools_call_success_returns_structured_content() {
1194 let mut registry = McpToolRegistry::new();
1195 registry.register::<EchoTool>();
1196
1197 let state = test_state_with_registry(
1198 McpConfig {
1199 enabled: true,
1200 ..Default::default()
1201 },
1202 registry,
1203 );
1204 let headers = initialized_headers(state.clone()).await;
1205 let auth = AuthContext::authenticated(
1206 uuid::Uuid::new_v4(),
1207 vec!["member".to_string()],
1208 HashMap::new(),
1209 );
1210 let payload = serde_json::json!({
1211 "jsonrpc": "2.0",
1212 "id": 3,
1213 "method": "tools/call",
1214 "params": {
1215 "name": "echo",
1216 "arguments": { "message": "hello" }
1217 }
1218 });
1219
1220 let response = mcp_post_handler(
1221 State(state),
1222 Extension(auth),
1223 Extension(TracingState::new()),
1224 Method::POST,
1225 headers,
1226 Json(payload),
1227 )
1228 .await;
1229
1230 assert_eq!(response.status(), StatusCode::OK);
1231 let body = response_json(response).await;
1232 assert_eq!(body["result"]["structuredContent"]["echoed"], "hello");
1233 assert_eq!(body["result"]["content"][0]["type"], "text");
1234 }
1235
1236 #[tokio::test]
1237 async fn test_tools_call_validation_failure_returns_is_error() {
1238 let mut registry = McpToolRegistry::new();
1239 registry.register::<EchoTool>();
1240
1241 let state = test_state_with_registry(
1242 McpConfig {
1243 enabled: true,
1244 ..Default::default()
1245 },
1246 registry,
1247 );
1248 let headers = initialized_headers(state.clone()).await;
1249 let auth = AuthContext::authenticated(
1250 uuid::Uuid::new_v4(),
1251 vec!["member".to_string()],
1252 HashMap::new(),
1253 );
1254 let payload = serde_json::json!({
1255 "jsonrpc": "2.0",
1256 "id": 4,
1257 "method": "tools/call",
1258 "params": {
1259 "name": "echo",
1260 "arguments": {}
1261 }
1262 });
1263
1264 let response = mcp_post_handler(
1265 State(state),
1266 Extension(auth),
1267 Extension(TracingState::new()),
1268 Method::POST,
1269 headers,
1270 Json(payload),
1271 )
1272 .await;
1273
1274 assert_eq!(response.status(), StatusCode::OK);
1275 let body = response_json(response).await;
1276 assert_eq!(body["result"]["isError"], true);
1277 }
1278
1279 #[tokio::test]
1280 async fn test_tools_call_requires_authentication() {
1281 let mut registry = McpToolRegistry::new();
1282 registry.register::<EchoTool>();
1283
1284 let state = test_state_with_registry(
1285 McpConfig {
1286 enabled: true,
1287 ..Default::default()
1288 },
1289 registry,
1290 );
1291 let headers = initialized_headers(state.clone()).await;
1292 let payload = serde_json::json!({
1293 "jsonrpc": "2.0",
1294 "id": 5,
1295 "method": "tools/call",
1296 "params": {
1297 "name": "echo",
1298 "arguments": { "message": "hello" }
1299 }
1300 });
1301
1302 let response = mcp_post_handler(
1303 State(state),
1304 Extension(AuthContext::unauthenticated()),
1305 Extension(TracingState::new()),
1306 Method::POST,
1307 headers,
1308 Json(payload),
1309 )
1310 .await;
1311
1312 assert_eq!(response.status(), StatusCode::OK);
1313 let body = response_json(response).await;
1314 assert_eq!(body["error"]["code"], -32001);
1315 }
1316
1317 #[tokio::test]
1318 async fn test_tools_call_requires_role() {
1319 let mut registry = McpToolRegistry::new();
1320 registry.register::<AdminTool>();
1321
1322 let state = test_state_with_registry(
1323 McpConfig {
1324 enabled: true,
1325 ..Default::default()
1326 },
1327 registry,
1328 );
1329 let headers = initialized_headers(state.clone()).await;
1330 let auth = AuthContext::authenticated(
1331 uuid::Uuid::new_v4(),
1332 vec!["member".to_string()],
1333 HashMap::new(),
1334 );
1335 let payload = serde_json::json!({
1336 "jsonrpc": "2.0",
1337 "id": 6,
1338 "method": "tools/call",
1339 "params": {
1340 "name": "admin.echo",
1341 "arguments": { "message": "hello" }
1342 }
1343 });
1344
1345 let response = mcp_post_handler(
1346 State(state),
1347 Extension(auth),
1348 Extension(TracingState::new()),
1349 Method::POST,
1350 headers,
1351 Json(payload),
1352 )
1353 .await;
1354
1355 assert_eq!(response.status(), StatusCode::OK);
1356 let body = response_json(response).await;
1357 assert_eq!(body["error"]["code"], -32003);
1358 }
1359
1360 #[tokio::test]
1361 async fn test_invalid_protocol_header_returns_400() {
1362 let state = test_state(McpConfig {
1363 enabled: true,
1364 ..Default::default()
1365 });
1366 let session_id = initialize_session(state.clone()).await;
1367 let mut headers = HeaderMap::new();
1368 headers.insert(
1369 MCP_SESSION_HEADER,
1370 HeaderValue::from_str(&session_id).expect("valid"),
1371 );
1372 headers.insert(
1373 MCP_PROTOCOL_HEADER,
1374 HeaderValue::from_static("invalid-version"),
1375 );
1376
1377 let payload = serde_json::json!({
1378 "jsonrpc": "2.0",
1379 "id": 7,
1380 "method": "tools/list",
1381 "params": {}
1382 });
1383
1384 let response = mcp_post_handler(
1385 State(state),
1386 Extension(AuthContext::unauthenticated()),
1387 Extension(TracingState::new()),
1388 Method::POST,
1389 headers,
1390 Json(payload),
1391 )
1392 .await;
1393 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1394 }
1395
1396 #[tokio::test]
1397 async fn test_missing_protocol_header_returns_400() {
1398 let state = test_state(McpConfig {
1399 enabled: true,
1400 ..Default::default()
1401 });
1402 let session_id = initialize_session(state.clone()).await;
1403 let mut headers = HeaderMap::new();
1404 headers.insert(
1405 MCP_SESSION_HEADER,
1406 HeaderValue::from_str(&session_id).expect("valid"),
1407 );
1408
1409 let payload = serde_json::json!({
1410 "jsonrpc": "2.0",
1411 "id": 8,
1412 "method": "tools/list",
1413 "params": {}
1414 });
1415
1416 let response = mcp_post_handler(
1417 State(state),
1418 Extension(AuthContext::unauthenticated()),
1419 Extension(TracingState::new()),
1420 Method::POST,
1421 headers,
1422 Json(payload),
1423 )
1424 .await;
1425 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
1426 }
1427
1428 #[tokio::test]
1429 async fn test_notifications_return_202() {
1430 let state = test_state(McpConfig {
1431 enabled: true,
1432 ..Default::default()
1433 });
1434 let mut headers = HeaderMap::new();
1435 headers.insert(
1436 MCP_PROTOCOL_HEADER,
1437 HeaderValue::from_static(MCP_PROTOCOL_VERSION),
1438 );
1439 let payload = serde_json::json!({
1440 "jsonrpc": "2.0",
1441 "method": "notifications/tools/list_changed",
1442 "params": {}
1443 });
1444 let response = mcp_post_handler(
1445 State(state),
1446 Extension(AuthContext::unauthenticated()),
1447 Extension(TracingState::new()),
1448 Method::POST,
1449 headers,
1450 Json(payload),
1451 )
1452 .await;
1453 assert_eq!(response.status(), StatusCode::ACCEPTED);
1454 }
1455
1456 #[tokio::test]
1457 async fn test_invalid_origin_rejected() {
1458 let state = test_state(McpConfig {
1459 enabled: true,
1460 allowed_origins: vec!["https://allowed.example".to_string()],
1461 ..Default::default()
1462 });
1463 let payload = serde_json::json!({
1464 "jsonrpc": "2.0",
1465 "id": 1,
1466 "method": "initialize",
1467 "params": {
1468 "protocolVersion": "2025-11-25",
1469 "capabilities": {},
1470 "clientInfo": { "name": "test", "version": "1.0.0" }
1471 }
1472 });
1473
1474 let mut headers = HeaderMap::new();
1475 headers.insert("origin", HeaderValue::from_static("https://evil.example"));
1476
1477 let response = mcp_post_handler(
1478 State(state),
1479 Extension(AuthContext::unauthenticated()),
1480 Extension(TracingState::new()),
1481 Method::POST,
1482 headers,
1483 Json(payload),
1484 )
1485 .await;
1486
1487 assert_eq!(response.status(), StatusCode::FORBIDDEN);
1488 }
1489}