1use crate::session::SessionStore;
4use mcpkit_core::capability::{ServerCapabilities, ServerInfo};
5use mcpkit_server::ServerHandler;
6use std::sync::Arc;
7
8pub trait HasServerInfo {
10 fn server_info(&self) -> ServerInfo;
12}
13
14impl<H: ServerHandler> HasServerInfo for H {
15 fn server_info(&self) -> ServerInfo {
16 ServerHandler::server_info(self)
17 }
18}
19
20pub struct McpState<H> {
22 pub handler: Arc<H>,
24 pub server_info: ServerInfo,
26 pub sessions: SessionStore,
28 pub sse_sessions: SessionStore,
30}
31
32impl<H> McpState<H>
33where
34 H: HasServerInfo,
35{
36 pub fn new(handler: H) -> Self {
38 let server_info = handler.server_info();
39 Self {
40 handler: Arc::new(handler),
41 server_info,
42 sessions: SessionStore::new(),
43 sse_sessions: SessionStore::new(),
44 }
45 }
46
47 #[must_use]
49 pub fn capabilities(&self) -> ServerCapabilities
50 where
51 H: ServerHandler,
52 {
53 self.handler.capabilities()
54 }
55}
56
57impl<H> Clone for McpState<H> {
58 fn clone(&self) -> Self {
59 Self {
60 handler: Arc::clone(&self.handler),
61 server_info: self.server_info.clone(),
62 sessions: self.sessions.clone(),
63 sse_sessions: self.sse_sessions.clone(),
64 }
65 }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::error::McpError;
    use mcpkit_core::types::{
        GetPromptResult, Prompt, Resource, ResourceContents, Tool, ToolOutput,
    };
    use mcpkit_server::context::Context;
    use mcpkit_server::handler::{PromptHandler, ResourceHandler, ToolHandler};

    /// Minimal, deliberately non-`Clone` handler used to exercise `McpState`.
    /// It advertises tools and resources and returns canned responses.
    struct TestHandler;

    impl ServerHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            ServerInfo::new("test-server", "1.0.0")
        }

        fn capabilities(&self) -> ServerCapabilities {
            ServerCapabilities::new().with_tools().with_resources()
        }
    }

    impl ToolHandler for TestHandler {
        async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
            Ok(vec![Tool::new("test").description("A test tool")])
        }

        async fn call_tool(
            &self,
            _name: &str,
            _args: serde_json::Value,
            _ctx: &Context<'_>,
        ) -> Result<ToolOutput, McpError> {
            Ok(ToolOutput::text("test result"))
        }
    }

    impl ResourceHandler for TestHandler {
        async fn list_resources(&self, _ctx: &Context<'_>) -> Result<Vec<Resource>, McpError> {
            Ok(vec![])
        }

        async fn read_resource(
            &self,
            uri: &str,
            _ctx: &Context<'_>,
        ) -> Result<Vec<ResourceContents>, McpError> {
            Ok(vec![ResourceContents::text(uri, "content")])
        }
    }

    impl PromptHandler for TestHandler {
        async fn list_prompts(&self, _ctx: &Context<'_>) -> Result<Vec<Prompt>, McpError> {
            Ok(vec![])
        }

        async fn get_prompt(
            &self,
            _name: &str,
            _args: Option<serde_json::Map<String, serde_json::Value>>,
            _ctx: &Context<'_>,
        ) -> Result<GetPromptResult, McpError> {
            Ok(GetPromptResult {
                description: Some("Test".to_string()),
                messages: vec![],
            })
        }
    }

    #[test]
    fn test_mcp_state_creation() {
        let state = McpState::new(TestHandler);

        assert_eq!(state.server_info.name, "test-server");
        assert_eq!(state.server_info.version, "1.0.0");
    }

    #[test]
    fn test_mcp_state_capabilities() {
        let state = McpState::new(TestHandler);
        let caps = state.capabilities();

        assert!(caps.tools.is_some());
        assert!(caps.resources.is_some());
    }

    #[test]
    fn test_mcp_state_clone() {
        let state = McpState::new(TestHandler);
        let cloned = state.clone();

        assert_eq!(cloned.server_info.name, state.server_info.name);
        assert_eq!(cloned.server_info.version, state.server_info.version);
        // The clone must share the handler, not duplicate it. `TestHandler`
        // is not `Clone`; only the `Arc` refcount is bumped.
        assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
        assert_eq!(Arc::strong_count(&state.handler), 2);
    }

    #[test]
    fn test_mcp_state_sessions() {
        let state = McpState::new(TestHandler);

        let id1 = state.sessions.create();
        let id2 = state.sse_sessions.create();

        assert!(state.sessions.exists(&id1));
        assert!(state.sse_sessions.exists(&id2));
    }

    #[test]
    fn test_has_server_info_trait() {
        let handler = TestHandler;
        let info = HasServerInfo::server_info(&handler);

        assert_eq!(info.name, "test-server");
        assert_eq!(info.version, "1.0.0");
    }
}