1use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities, ServerInfo};
17use mcpkit_core::error::McpError;
18use std::marker::PhantomData;
19use std::sync::Arc;
20
/// Zero-sized marker types naming each stage of a [`Connection`]'s lifecycle.
///
/// They are used only as the type parameter of `Connection<S>`; no value of these
/// types is stored at runtime.
pub mod state {
    /// No connection has been established yet.
    #[derive(Clone, Copy, Debug)]
    pub struct Disconnected;

    /// A connection exists but the initialization handshake has not started.
    #[derive(Clone, Copy, Debug)]
    pub struct Connected;

    /// The initialization handshake is in progress.
    #[derive(Clone, Copy, Debug)]
    pub struct Initializing;

    /// Initialization finished; client capabilities and protocol version are known.
    #[derive(Clone, Copy, Debug)]
    pub struct Ready;

    /// Shutdown was requested and the connection is winding down.
    #[derive(Clone, Copy, Debug)]
    pub struct Closing;
}
46
47#[derive(Debug)]
49pub struct ConnectionData {
50 pub client_capabilities: Option<ClientCapabilities>,
52 pub server_capabilities: ServerCapabilities,
54 pub server_info: ServerInfo,
56 pub protocol_version: Option<String>,
58 pub session_id: Option<String>,
60}
61
62impl ConnectionData {
63 #[must_use]
65 pub const fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
66 Self {
67 client_capabilities: None,
68 server_capabilities,
69 server_info,
70 protocol_version: None,
71 session_id: None,
72 }
73 }
74}
75
76pub struct Connection<S> {
99 inner: Arc<ConnectionData>,
101 _state: PhantomData<S>,
103}
104
105impl<S> Clone for Connection<S> {
106 fn clone(&self) -> Self {
107 Self {
108 inner: Arc::clone(&self.inner),
109 _state: PhantomData,
110 }
111 }
112}
113
114impl<S> std::fmt::Debug for Connection<S> {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 f.debug_struct("Connection")
117 .field("inner", &self.inner)
118 .field("state", &std::any::type_name::<S>())
119 .finish()
120 }
121}
122
123impl Connection<state::Disconnected> {
124 #[must_use]
126 pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
127 Self {
128 inner: Arc::new(ConnectionData::new(server_info, server_capabilities)),
129 _state: PhantomData,
130 }
131 }
132
133 pub async fn connect(self) -> Result<Connection<state::Connected>, McpError> {
137 Ok(Connection {
139 inner: self.inner,
140 _state: PhantomData,
141 })
142 }
143}
144
145impl Connection<state::Connected> {
146 pub async fn initialize(
150 self,
151 _protocol_version: &str,
152 ) -> Result<Connection<state::Initializing>, McpError> {
153 Ok(Connection {
155 inner: self.inner,
156 _state: PhantomData,
157 })
158 }
159
160 pub async fn close(self) -> Result<(), McpError> {
162 Ok(())
164 }
165}
166
167impl Connection<state::Initializing> {
168 pub async fn complete(
172 self,
173 client_capabilities: ClientCapabilities,
174 protocol_version: String,
175 ) -> Result<Connection<state::Ready>, McpError> {
176 let mut data = ConnectionData::new(
179 self.inner.server_info.clone(),
180 self.inner.server_capabilities.clone(),
181 );
182 data.client_capabilities = Some(client_capabilities);
183 data.protocol_version = Some(protocol_version);
184
185 Ok(Connection {
186 inner: Arc::new(data),
187 _state: PhantomData,
188 })
189 }
190
191 pub async fn abort(self) -> Result<Connection<state::Disconnected>, McpError> {
193 Ok(Connection {
194 inner: self.inner,
195 _state: PhantomData,
196 })
197 }
198}
199
200impl Connection<state::Ready> {
201 #[must_use]
209 pub fn client_capabilities(&self) -> &ClientCapabilities {
210 self.inner
211 .client_capabilities
212 .as_ref()
213 .expect("Ready connection must have client capabilities")
214 }
215
216 #[must_use]
220 pub fn try_client_capabilities(&self) -> Option<&ClientCapabilities> {
221 self.inner.client_capabilities.as_ref()
222 }
223
224 #[must_use]
226 pub fn server_capabilities(&self) -> &ServerCapabilities {
227 &self.inner.server_capabilities
228 }
229
230 #[must_use]
232 pub fn server_info(&self) -> &ServerInfo {
233 &self.inner.server_info
234 }
235
236 #[must_use]
244 pub fn protocol_version(&self) -> &str {
245 self.inner
246 .protocol_version
247 .as_ref()
248 .expect("Ready connection must have protocol version")
249 }
250
251 #[must_use]
255 pub fn try_protocol_version(&self) -> Option<&str> {
256 self.inner.protocol_version.as_deref()
257 }
258
259 pub async fn shutdown(self) -> Result<Connection<state::Closing>, McpError> {
263 Ok(Connection {
264 inner: self.inner,
265 _state: PhantomData,
266 })
267 }
268}
269
270impl Connection<state::Closing> {
271 pub async fn disconnect(self) -> Result<(), McpError> {
273 Ok(())
275 }
276}
277
278#[derive(Debug)]
283pub enum ConnectionState {
284 Disconnected(Connection<state::Disconnected>),
286 Connected(Connection<state::Connected>),
288 Initializing(Connection<state::Initializing>),
290 Ready(Connection<state::Ready>),
292 Closing(Connection<state::Closing>),
294}
295
296impl ConnectionState {
297 #[must_use]
299 pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
300 Self::Disconnected(Connection::new(server_info, server_capabilities))
301 }
302
303 #[must_use]
305 pub const fn is_ready(&self) -> bool {
306 matches!(self, Self::Ready(_))
307 }
308
309 #[must_use]
311 pub const fn is_disconnected(&self) -> bool {
312 matches!(self, Self::Disconnected(_))
313 }
314
315 #[must_use]
317 pub const fn state_name(&self) -> &'static str {
318 match self {
319 Self::Disconnected(_) => "Disconnected",
320 Self::Connected(_) => "Connected",
321 Self::Initializing(_) => "Initializing",
322 Self::Ready(_) => "Ready",
323 Self::Closing(_) => "Closing",
324 }
325 }
326}
327
/// Lifecycle events that can be reported for a connection.
///
/// This file only defines the events; nothing shown here emits them.
#[derive(Clone, Debug)]
pub enum ConnectionEvent {
    /// The connection entered the connected stage.
    Connected,
    /// The initialization handshake started.
    InitializeStarted,
    /// The initialization handshake finished.
    InitializeCompleted {
        /// Protocol version agreed during the handshake.
        protocol_version: String,
    },
    /// The initialization handshake failed.
    InitializeFailed {
        /// Description of the failure.
        error: String,
    },
    /// Shutdown was requested.
    ShutdownRequested,
    /// The connection was disconnected.
    Disconnected,
}
350
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::capability::{ServerCapabilities, ServerInfo};

    /// Builds the server identity and capabilities shared by every test.
    fn fixture() -> (ServerInfo, ServerCapabilities) {
        (ServerInfo::new("test", "1.0.0"), ServerCapabilities::default())
    }

    #[test]
    fn test_connection_creation() {
        let (info, caps) = fixture();
        let conn: Connection<state::Disconnected> = Connection::new(info, caps);

        let marker = std::any::type_name_of_val(&conn._state);
        assert!(marker.contains("Disconnected"));
    }

    #[tokio::test]
    async fn test_connection_lifecycle() {
        let (info, caps) = fixture();
        let version = "2025-11-25";

        let disconnected = Connection::new(info, caps);
        let connected = disconnected.connect().await.unwrap();
        let initializing = connected.initialize(version).await.unwrap();
        let ready = initializing
            .complete(ClientCapabilities::default(), version.to_string())
            .await
            .unwrap();

        assert_eq!(ready.protocol_version(), version);

        let closing = ready.shutdown().await.unwrap();
        closing.disconnect().await.unwrap();
    }

    #[test]
    fn test_connection_state_enum() {
        let (info, caps) = fixture();
        let wrapped = ConnectionState::new(info, caps);

        assert!(wrapped.is_disconnected());
        assert!(!wrapped.is_ready());
        assert_eq!(wrapped.state_name(), "Disconnected");
    }
}