1use crate::session::{ExpectedSessionId, SESSION_ID_HEADER};
2use rmcp::model::{
3 CallToolResult, ClientNotification, ClientRequest, Content, ErrorCode, Implementation, Meta,
4 ProtocolVersion, ServerCapabilities, ServerInfo,
5};
6use rmcp::service::{DynService, NotificationContext, RequestContext, ServiceExt, ServiceRole};
7use rmcp::transport::streamable_http_server::{
8 session::local::LocalSessionManager, StreamableHttpServerConfig, StreamableHttpService,
9};
10use rmcp::{
11 handler::server::router::tool::ToolRouter, tool, tool_handler, tool_router,
12 ErrorData as McpError, RoleServer, ServerHandler, Service,
13};
14use tokio::task::JoinHandle;
15
/// Fixed string returned as the text content of the `get_code` tool.
pub const FAKE_CODE: &str = "test-uuid-12345-67890";

/// Contents of `test_assets/test_image.b64`, embedded at compile time with trailing ASCII
/// whitespace trimmed. Served by the `get_image` tool with MIME type `image/png`.
/// Presumably base64-encoded image data, judging by the file name — confirm against the asset.
pub const TEST_IMAGE_B64: &str = include_str!("test_assets/test_image.b64").trim_ascii_end();
19
20pub trait HasMeta {
21 fn meta(&self) -> &Meta;
22}
23
24impl<R: ServiceRole> HasMeta for RequestContext<R> {
25 fn meta(&self) -> &Meta {
26 &self.meta
27 }
28}
29
30impl<R: ServiceRole> HasMeta for NotificationContext<R> {
31 fn meta(&self) -> &Meta {
32 &self.meta
33 }
34}
35
36struct ValidatingService<S> {
37 inner: S,
38 expected_session_id: ExpectedSessionId,
39}
40
41impl<S> ValidatingService<S> {
42 fn new(inner: S, expected_session_id: ExpectedSessionId) -> Self {
43 Self {
44 inner,
45 expected_session_id,
46 }
47 }
48
49 fn validate<C: HasMeta>(&self, context: &C) -> Result<(), McpError> {
50 let actual = context
51 .meta()
52 .0
53 .get(SESSION_ID_HEADER)
54 .and_then(|v| v.as_str());
55 self.expected_session_id
56 .validate(actual)
57 .map_err(|e| McpError::new(ErrorCode::INVALID_REQUEST, e, None))
58 }
59}
60
61impl<S: Service<RoleServer>> Service<RoleServer> for ValidatingService<S> {
62 async fn handle_request(
63 &self,
64 request: ClientRequest,
65 context: RequestContext<RoleServer>,
66 ) -> Result<rmcp::model::ServerResult, McpError> {
67 if !matches!(request, ClientRequest::InitializeRequest(_)) {
68 self.validate(&context)?;
69 }
70 self.inner.handle_request(request, context).await
71 }
72
73 async fn handle_notification(
74 &self,
75 notification: ClientNotification,
76 context: NotificationContext<RoleServer>,
77 ) -> Result<(), McpError> {
78 if !matches!(notification, ClientNotification::InitializedNotification(_)) {
79 self.validate(&context).ok();
80 }
81 self.inner.handle_notification(notification, context).await
82 }
83
84 fn get_info(&self) -> ServerInfo {
85 self.inner.get_info()
86 }
87}
88
89#[derive(Clone)]
90pub struct McpFixtureServer {
91 tool_router: ToolRouter<McpFixtureServer>,
92}
93
94impl Default for McpFixtureServer {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100#[tool_router]
101impl McpFixtureServer {
102 pub fn new() -> Self {
103 Self {
104 tool_router: Self::tool_router(),
105 }
106 }
107
108 #[tool(description = "Get the code")]
109 fn get_code(&self) -> Result<CallToolResult, McpError> {
110 Ok(CallToolResult::success(vec![Content::text(FAKE_CODE)]))
111 }
112
113 #[tool(description = "Get an image")]
114 fn get_image(&self) -> Result<CallToolResult, McpError> {
115 Ok(CallToolResult::success(vec![Content::image(
116 TEST_IMAGE_B64,
117 "image/png",
118 )]))
119 }
120}
121
122#[tool_handler]
123impl ServerHandler for McpFixtureServer {
124 fn get_info(&self) -> ServerInfo {
125 ServerInfo {
126 protocol_version: ProtocolVersion::V_2025_03_26,
127 capabilities: ServerCapabilities::builder().enable_tools().build(),
128 server_info: Implementation {
129 name: "mcp-fixture".into(),
130 version: "1.0.0".into(),
131 ..Default::default()
132 },
133 instructions: Some("Test server with get_code and get_image tools.".into()),
134 }
135 }
136}
137
138pub struct McpFixture {
139 pub url: String,
140 handle: JoinHandle<()>,
141}
142
143impl Drop for McpFixture {
144 fn drop(&mut self) {
145 self.handle.abort();
146 }
147}
148
/// Boxed `Send + Sync` factory closure that, on each call, yields either a boxed
/// type-erased server-role service or a `std::io::Error`.
type McpServiceFactory =
    Box<dyn Fn() -> Result<Box<dyn DynService<RoleServer>>, std::io::Error> + Send + Sync>;
151
152impl McpFixture {
153 pub async fn new(expected_session_id: Option<ExpectedSessionId>) -> Self {
154 let service_factory: McpServiceFactory = match expected_session_id {
155 Some(expected_session_id) => Box::new(move || {
156 Ok(
157 ValidatingService::new(McpFixtureServer::new(), expected_session_id.clone())
158 .into_dyn(),
159 )
160 }),
161 None => Box::new(|| Ok(McpFixtureServer::new().into_dyn())),
162 };
163
164 let service = StreamableHttpService::new(
165 service_factory,
166 LocalSessionManager::default().into(),
167 StreamableHttpServerConfig::default(),
168 );
169 let router = axum::Router::new().nest_service("/mcp", service);
170 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
171 let addr = listener.local_addr().unwrap();
172 let url = format!("http://{addr}/mcp");
173
174 let handle = tokio::spawn(async move {
175 axum::serve(listener, router).await.unwrap();
176 });
177
178 Self { url, handle }
179 }
180}