mcpkit_rocket/
state.rs

1//! State management for MCP Rocket integration.
2
3use crate::session::SessionStore;
4use mcpkit_core::capability::{ServerCapabilities, ServerInfo};
5use mcpkit_server::ServerHandler;
6use std::sync::Arc;
7
8/// Trait for handlers that provide server info.
9pub trait HasServerInfo {
10    /// Get the server info.
11    fn server_info(&self) -> ServerInfo;
12}
13
14impl<H: ServerHandler> HasServerInfo for H {
15    fn server_info(&self) -> ServerInfo {
16        ServerHandler::server_info(self)
17    }
18}
19
/// Shared state for MCP request handling.
///
/// Bundles the user-supplied handler with the server info and session stores.
/// The struct itself places no trait bounds on `H`; bounds live on the `impl`
/// blocks that need them. Because the handler is held behind an [`Arc`], the
/// state can be cloned without requiring `H: Clone`.
pub struct McpState<H> {
    /// The MCP handler implementation, shared via reference counting between clones.
    pub handler: Arc<H>,
    /// Server info for initialization responses, captured from the handler
    /// when the state is created.
    pub server_info: ServerInfo,
    /// Store tracking client sessions.
    pub sessions: SessionStore,
    /// Separate store tracking Server-Sent Events (SSE) sessions.
    pub sse_sessions: SessionStore,
}
31
32impl<H> McpState<H>
33where
34    H: HasServerInfo,
35{
36    /// Create new MCP state.
37    pub fn new(handler: H) -> Self {
38        let server_info = handler.server_info();
39        Self {
40            handler: Arc::new(handler),
41            server_info,
42            sessions: SessionStore::new(),
43            sse_sessions: SessionStore::new(),
44        }
45    }
46
47    /// Get the handler's capabilities.
48    #[must_use]
49    pub fn capabilities(&self) -> ServerCapabilities
50    where
51        H: ServerHandler,
52    {
53        self.handler.capabilities()
54    }
55}
56
57impl<H> Clone for McpState<H> {
58    fn clone(&self) -> Self {
59        Self {
60            handler: Arc::clone(&self.handler),
61            server_info: self.server_info.clone(),
62            sessions: self.sessions.clone(),
63            sse_sessions: self.sse_sessions.clone(),
64        }
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::error::McpError;
    use mcpkit_core::types::{
        GetPromptResult, Prompt, Resource, ResourceContents, Tool, ToolOutput,
    };
    use mcpkit_server::context::Context;
    use mcpkit_server::handler::{PromptHandler, ResourceHandler, ToolHandler};

    /// Minimal handler used to exercise `McpState`.
    ///
    /// `TestHandler` does not implement `Clone`. Cloning
    /// `McpState<TestHandler>` below therefore also checks that the manual
    /// `Clone` impl does not require `H: Clone`.
    struct TestHandler;

    impl ServerHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            ServerInfo::new("test-server", "1.0.0")
        }

        fn capabilities(&self) -> ServerCapabilities {
            ServerCapabilities::new().with_tools().with_resources()
        }
    }

    impl ToolHandler for TestHandler {
        async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
            Ok(vec![Tool::new("test").description("A test tool")])
        }

        async fn call_tool(
            &self,
            _name: &str,
            _args: serde_json::Value,
            _ctx: &Context<'_>,
        ) -> Result<ToolOutput, McpError> {
            Ok(ToolOutput::text("test result"))
        }
    }

    impl ResourceHandler for TestHandler {
        async fn list_resources(&self, _ctx: &Context<'_>) -> Result<Vec<Resource>, McpError> {
            Ok(vec![])
        }

        async fn read_resource(
            &self,
            uri: &str,
            _ctx: &Context<'_>,
        ) -> Result<Vec<ResourceContents>, McpError> {
            Ok(vec![ResourceContents::text(uri, "content")])
        }
    }

    impl PromptHandler for TestHandler {
        async fn list_prompts(&self, _ctx: &Context<'_>) -> Result<Vec<Prompt>, McpError> {
            Ok(vec![])
        }

        async fn get_prompt(
            &self,
            _name: &str,
            _args: Option<serde_json::Map<String, serde_json::Value>>,
            _ctx: &Context<'_>,
        ) -> Result<GetPromptResult, McpError> {
            Ok(GetPromptResult {
                description: Some("Test".to_string()),
                messages: vec![],
            })
        }
    }

    #[test]
    fn test_mcp_state_creation() {
        let state = McpState::new(TestHandler);

        assert_eq!(state.server_info.name, "test-server");
        assert_eq!(state.server_info.version, "1.0.0");
    }

    #[test]
    fn test_mcp_state_capabilities() {
        let state = McpState::new(TestHandler);
        let caps = state.capabilities();

        assert!(caps.tools.is_some());
        assert!(caps.resources.is_some());
    }

    #[test]
    fn test_mcp_state_clone() {
        let state = McpState::new(TestHandler);
        let cloned = state.clone();

        assert_eq!(cloned.server_info.name, state.server_info.name);
        assert_eq!(cloned.server_info.version, state.server_info.version);
        // The clone must share the same handler allocation, not a copy of it.
        assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
    }

    #[test]
    fn test_mcp_state_sessions() {
        let state = McpState::new(TestHandler);

        // Create sessions in both stores
        let id1 = state.sessions.create();
        let id2 = state.sse_sessions.create();

        assert!(state.sessions.exists(&id1));
        assert!(state.sse_sessions.exists(&id2));
    }

    #[test]
    fn test_has_server_info_trait() {
        let handler = TestHandler;
        let info = HasServerInfo::server_info(&handler);

        assert_eq!(info.name, "test-server");
        assert_eq!(info.version, "1.0.0");
    }
}