1use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities, ServerInfo};
17use mcpkit_core::error::McpError;
18use std::marker::PhantomData;
19use std::sync::Arc;
20
/// Zero-sized marker types naming each lifecycle phase of a `Connection`.
///
/// These are only ever used as the `S` type parameter of `Connection<S>`; they
/// carry no data and exist so that invalid transitions fail to compile.
pub mod state {
    /// Initial phase: no transport connection has been established yet.
    #[derive(Clone, Copy, Debug)]
    pub struct Disconnected;

    /// A transport connection exists but the initialize handshake has not begun.
    #[derive(Clone, Copy, Debug)]
    pub struct Connected;

    /// The initialize handshake is in progress.
    #[derive(Clone, Copy, Debug)]
    pub struct Initializing;

    /// The handshake finished; negotiated capabilities and version are available.
    #[derive(Clone, Copy, Debug)]
    pub struct Ready;

    /// Shutdown has been requested and the connection is winding down.
    #[derive(Clone, Copy, Debug)]
    pub struct Closing;
}
46
47#[derive(Debug)]
49pub struct ConnectionData {
50 pub client_capabilities: Option<ClientCapabilities>,
52 pub server_capabilities: ServerCapabilities,
54 pub server_info: ServerInfo,
56 pub protocol_version: Option<String>,
58 pub session_id: Option<String>,
60}
61
62impl ConnectionData {
63 pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
65 Self {
66 client_capabilities: None,
67 server_capabilities,
68 server_info,
69 protocol_version: None,
70 session_id: None,
71 }
72 }
73}
74
75pub struct Connection<S> {
98 inner: Arc<ConnectionData>,
100 _state: PhantomData<S>,
102}
103
104impl<S> Clone for Connection<S> {
105 fn clone(&self) -> Self {
106 Self {
107 inner: Arc::clone(&self.inner),
108 _state: PhantomData,
109 }
110 }
111}
112
113impl<S> std::fmt::Debug for Connection<S> {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("Connection")
116 .field("inner", &self.inner)
117 .field("state", &std::any::type_name::<S>())
118 .finish()
119 }
120}
121
122impl Connection<state::Disconnected> {
123 pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
125 Self {
126 inner: Arc::new(ConnectionData::new(server_info, server_capabilities)),
127 _state: PhantomData,
128 }
129 }
130
131 pub async fn connect(self) -> Result<Connection<state::Connected>, McpError> {
135 Ok(Connection {
137 inner: self.inner,
138 _state: PhantomData,
139 })
140 }
141}
142
143impl Connection<state::Connected> {
144 pub async fn initialize(
148 self,
149 _protocol_version: &str,
150 ) -> Result<Connection<state::Initializing>, McpError> {
151 Ok(Connection {
153 inner: self.inner,
154 _state: PhantomData,
155 })
156 }
157
158 pub async fn close(self) -> Result<(), McpError> {
160 Ok(())
162 }
163}
164
165impl Connection<state::Initializing> {
166 pub async fn complete(
170 self,
171 client_capabilities: ClientCapabilities,
172 protocol_version: String,
173 ) -> Result<Connection<state::Ready>, McpError> {
174 let mut data = ConnectionData::new(
177 self.inner.server_info.clone(),
178 self.inner.server_capabilities.clone(),
179 );
180 data.client_capabilities = Some(client_capabilities);
181 data.protocol_version = Some(protocol_version);
182
183 Ok(Connection {
184 inner: Arc::new(data),
185 _state: PhantomData,
186 })
187 }
188
189 pub async fn abort(self) -> Result<Connection<state::Disconnected>, McpError> {
191 Ok(Connection {
192 inner: self.inner,
193 _state: PhantomData,
194 })
195 }
196}
197
198impl Connection<state::Ready> {
199 pub fn client_capabilities(&self) -> &ClientCapabilities {
201 self.inner
202 .client_capabilities
203 .as_ref()
204 .expect("Ready connection must have client capabilities")
205 }
206
207 pub fn server_capabilities(&self) -> &ServerCapabilities {
209 &self.inner.server_capabilities
210 }
211
212 pub fn server_info(&self) -> &ServerInfo {
214 &self.inner.server_info
215 }
216
217 pub fn protocol_version(&self) -> &str {
219 self.inner
220 .protocol_version
221 .as_ref()
222 .expect("Ready connection must have protocol version")
223 }
224
225 pub async fn shutdown(self) -> Result<Connection<state::Closing>, McpError> {
229 Ok(Connection {
230 inner: self.inner,
231 _state: PhantomData,
232 })
233 }
234}
235
236impl Connection<state::Closing> {
237 pub async fn disconnect(self) -> Result<(), McpError> {
239 Ok(())
241 }
242}
243
244#[derive(Debug)]
249pub enum ConnectionState {
250 Disconnected(Connection<state::Disconnected>),
252 Connected(Connection<state::Connected>),
254 Initializing(Connection<state::Initializing>),
256 Ready(Connection<state::Ready>),
258 Closing(Connection<state::Closing>),
260}
261
262impl ConnectionState {
263 pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
265 Self::Disconnected(Connection::new(server_info, server_capabilities))
266 }
267
268 pub fn is_ready(&self) -> bool {
270 matches!(self, ConnectionState::Ready(_))
271 }
272
273 pub fn is_disconnected(&self) -> bool {
275 matches!(self, ConnectionState::Disconnected(_))
276 }
277
278 pub fn state_name(&self) -> &'static str {
280 match self {
281 ConnectionState::Disconnected(_) => "Disconnected",
282 ConnectionState::Connected(_) => "Connected",
283 ConnectionState::Initializing(_) => "Initializing",
284 ConnectionState::Ready(_) => "Ready",
285 ConnectionState::Closing(_) => "Closing",
286 }
287 }
288}
289
/// Notable lifecycle events for a connection.
///
/// Nothing in this module emits these yet; presumably callers use them for
/// logging or observers — confirm against usage.
#[derive(Debug, Clone)]
pub enum ConnectionEvent {
    /// A transport connection was established.
    Connected,
    /// The initialize handshake began.
    InitializeStarted,
    /// The initialize handshake finished with the given protocol version.
    InitializeCompleted {
        /// Protocol version agreed during the handshake.
        protocol_version: String,
    },
    /// The initialize handshake failed.
    InitializeFailed {
        /// Human-readable description of the failure.
        error: String,
    },
    /// A shutdown was requested.
    ShutdownRequested,
    /// The connection was torn down.
    Disconnected,
}
312
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::capability::{ServerCapabilities, ServerInfo};

    /// Server identity and capabilities shared by every test in this module.
    fn fixture() -> (ServerInfo, ServerCapabilities) {
        (ServerInfo::new("test", "1.0.0"), ServerCapabilities::default())
    }

    #[test]
    fn test_connection_creation() {
        let (info, caps) = fixture();
        let conn: Connection<state::Disconnected> = Connection::new(info, caps);

        let marker_type = std::any::type_name_of_val(&conn._state);
        assert!(marker_type.contains("Disconnected"));
    }

    #[tokio::test]
    async fn test_connection_lifecycle() {
        const VERSION: &str = "2025-11-25";
        let (info, caps) = fixture();

        // Walk every forward transition: Disconnected -> Connected ->
        // Initializing -> Ready.
        let ready = Connection::<state::Disconnected>::new(info, caps)
            .connect()
            .await
            .unwrap()
            .initialize(VERSION)
            .await
            .unwrap()
            .complete(ClientCapabilities::default(), VERSION.to_string())
            .await
            .unwrap();
        assert_eq!(ready.protocol_version(), VERSION);

        // Then tear down: Ready -> Closing -> gone.
        let closing = ready.shutdown().await.unwrap();
        closing.disconnect().await.unwrap();
    }

    #[test]
    fn test_connection_state_enum() {
        let (info, caps) = fixture();
        let wrapped = ConnectionState::new(info, caps);

        assert!(wrapped.is_disconnected());
        assert!(!wrapped.is_ready());
        assert_eq!(wrapped.state_name(), "Disconnected");
    }
}