Skip to main content

a8e_test_support/
mcp.rs

1use crate::session::{ExpectedSessionId, SESSION_ID_HEADER};
2use rmcp::model::{
3    CallToolResult, ClientNotification, ClientRequest, Content, ErrorCode, Implementation, Meta,
4    ProtocolVersion, ServerCapabilities, ServerInfo,
5};
6use rmcp::service::{DynService, NotificationContext, RequestContext, ServiceExt, ServiceRole};
7use rmcp::transport::streamable_http_server::{
8    session::local::LocalSessionManager, StreamableHttpServerConfig, StreamableHttpService,
9};
10use rmcp::{
11    handler::server::router::tool::ToolRouter, tool, tool_handler, tool_router,
12    ErrorData as McpError, RoleServer, ServerHandler, Service,
13};
14use tokio::task::JoinHandle;
15
/// Fixed text payload returned by the fixture server's `get_code` tool.
pub const FAKE_CODE: &str = "test-uuid-12345-67890";
17
/// Contents of `test_assets/test_image.b64`, embedded at compile time with trailing ASCII
/// whitespace removed. Returned by the fixture server's `get_image` tool as `image/png`.
// NOTE(review): presumably base64-encoded PNG data, judging by the file name — confirm.
pub const TEST_IMAGE_B64: &str = include_str!("test_assets/test_image.b64").trim_ascii_end();
19
/// Uniform read access to the [`Meta`] carried by a context value, so request and
/// notification contexts can be handled by the same generic code.
pub trait HasMeta {
    /// Returns a reference to the metadata attached to this context.
    fn meta(&self) -> &Meta;
}
23
24impl<R: ServiceRole> HasMeta for RequestContext<R> {
25    fn meta(&self) -> &Meta {
26        &self.meta
27    }
28}
29
30impl<R: ServiceRole> HasMeta for NotificationContext<R> {
31    fn meta(&self) -> &Meta {
32        &self.meta
33    }
34}
35
/// Service wrapper that checks the session id found in the metadata of incoming requests
/// and notifications (other than the initialize handshake) before delegating to `inner`.
struct ValidatingService<S> {
    /// The wrapped service that every request and notification is forwarded to.
    inner: S,
    /// Validator that the session id read from the message metadata is checked against.
    expected_session_id: ExpectedSessionId,
}
40
41impl<S> ValidatingService<S> {
42    fn new(inner: S, expected_session_id: ExpectedSessionId) -> Self {
43        Self {
44            inner,
45            expected_session_id,
46        }
47    }
48
49    fn validate<C: HasMeta>(&self, context: &C) -> Result<(), McpError> {
50        let actual = context
51            .meta()
52            .0
53            .get(SESSION_ID_HEADER)
54            .and_then(|v| v.as_str());
55        self.expected_session_id
56            .validate(actual)
57            .map_err(|e| McpError::new(ErrorCode::INVALID_REQUEST, e, None))
58    }
59}
60
61impl<S: Service<RoleServer>> Service<RoleServer> for ValidatingService<S> {
62    async fn handle_request(
63        &self,
64        request: ClientRequest,
65        context: RequestContext<RoleServer>,
66    ) -> Result<rmcp::model::ServerResult, McpError> {
67        if !matches!(request, ClientRequest::InitializeRequest(_)) {
68            self.validate(&context)?;
69        }
70        self.inner.handle_request(request, context).await
71    }
72
73    async fn handle_notification(
74        &self,
75        notification: ClientNotification,
76        context: NotificationContext<RoleServer>,
77    ) -> Result<(), McpError> {
78        if !matches!(notification, ClientNotification::InitializedNotification(_)) {
79            self.validate(&context).ok();
80        }
81        self.inner.handle_notification(notification, context).await
82    }
83
84    fn get_info(&self) -> ServerInfo {
85        self.inner.get_info()
86    }
87}
88
/// MCP server used as a test fixture; it exposes the `get_code` and `get_image` tools.
#[derive(Clone)]
pub struct McpFixtureServer {
    /// Tool router obtained from `Self::tool_router()` when the server is constructed.
    tool_router: ToolRouter<McpFixtureServer>,
}
93
94impl Default for McpFixtureServer {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100#[tool_router]
101impl McpFixtureServer {
102    pub fn new() -> Self {
103        Self {
104            tool_router: Self::tool_router(),
105        }
106    }
107
108    #[tool(description = "Get the code")]
109    fn get_code(&self) -> Result<CallToolResult, McpError> {
110        Ok(CallToolResult::success(vec![Content::text(FAKE_CODE)]))
111    }
112
113    #[tool(description = "Get an image")]
114    fn get_image(&self) -> Result<CallToolResult, McpError> {
115        Ok(CallToolResult::success(vec![Content::image(
116            TEST_IMAGE_B64,
117            "image/png",
118        )]))
119    }
120}
121
122#[tool_handler]
123impl ServerHandler for McpFixtureServer {
124    fn get_info(&self) -> ServerInfo {
125        ServerInfo {
126            protocol_version: ProtocolVersion::V_2025_03_26,
127            capabilities: ServerCapabilities::builder().enable_tools().build(),
128            server_info: Implementation {
129                name: "mcp-fixture".into(),
130                version: "1.0.0".into(),
131                ..Default::default()
132            },
133            instructions: Some("Test server with get_code and get_image tools.".into()),
134        }
135    }
136}
137
/// Handle to a fixture MCP server running on a background task.
///
/// The background task is aborted when this value is dropped.
pub struct McpFixture {
    /// URL of the server's MCP endpoint, in the form `http://<addr>/mcp`.
    pub url: String,
    /// Join handle of the spawned task that serves the HTTP listener.
    handle: JoinHandle<()>,
}
142
impl Drop for McpFixture {
    /// Aborts the background server task so it does not outlive the fixture.
    fn drop(&mut self) {
        self.handle.abort();
    }
}
148
/// Boxed, thread-safe factory closure that produces a type-erased server-side service,
/// or an I/O error, each time it is called.
type McpServiceFactory =
    Box<dyn Fn() -> Result<Box<dyn DynService<RoleServer>>, std::io::Error> + Send + Sync>;
151
152impl McpFixture {
153    pub async fn new(expected_session_id: Option<ExpectedSessionId>) -> Self {
154        let service_factory: McpServiceFactory = match expected_session_id {
155            Some(expected_session_id) => Box::new(move || {
156                Ok(
157                    ValidatingService::new(McpFixtureServer::new(), expected_session_id.clone())
158                        .into_dyn(),
159                )
160            }),
161            None => Box::new(|| Ok(McpFixtureServer::new().into_dyn())),
162        };
163
164        let service = StreamableHttpService::new(
165            service_factory,
166            LocalSessionManager::default().into(),
167            StreamableHttpServerConfig::default(),
168        );
169        let router = axum::Router::new().nest_service("/mcp", service);
170        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
171        let addr = listener.local_addr().unwrap();
172        let url = format!("http://{addr}/mcp");
173
174        let handle = tokio::spawn(async move {
175            axum::serve(listener, router).await.unwrap();
176        });
177
178        Self { url, handle }
179    }
180}