1use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities, ServerInfo};
17use mcpkit_core::error::McpError;
18use mcpkit_core::protocol_version::ProtocolVersion;
19use std::marker::PhantomData;
20use std::sync::Arc;
21
/// Zero-sized marker types used as the type-state parameter `S` of `Connection<S>`.
///
/// Each marker names one lifecycle phase. A connection only exposes the
/// operations implemented for its current marker, and transitions consume the
/// old handle and return one tagged with the next marker.
pub mod state {
    /// Initial phase, produced by `Connection::new`; `connect` leaves it.
    #[derive(Clone, Copy, Debug)]
    pub struct Disconnected;

    /// Connected but not yet initializing; `initialize` or `close` leave it.
    #[derive(Clone, Copy, Debug)]
    pub struct Connected;

    /// Initialization in progress; ends with `complete` or `abort`.
    #[derive(Clone, Copy, Debug)]
    pub struct Initializing;

    /// Initialization finished; capability and version accessors are available.
    #[derive(Clone, Copy, Debug)]
    pub struct Ready;

    /// Shutdown requested; only `disconnect` remains.
    #[derive(Clone, Copy, Debug)]
    pub struct Closing;
}
47
48#[derive(Debug)]
50pub struct ConnectionData {
51 pub client_capabilities: Option<ClientCapabilities>,
53 pub server_capabilities: ServerCapabilities,
55 pub server_info: ServerInfo,
57 pub protocol_version: Option<ProtocolVersion>,
61 pub session_id: Option<String>,
63}
64
65impl ConnectionData {
66 #[must_use]
68 pub const fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
69 Self {
70 client_capabilities: None,
71 server_capabilities,
72 server_info,
73 protocol_version: None,
74 session_id: None,
75 }
76 }
77}
78
79pub struct Connection<S> {
102 inner: Arc<ConnectionData>,
104 _state: PhantomData<S>,
106}
107
108impl<S> Clone for Connection<S> {
109 fn clone(&self) -> Self {
110 Self {
111 inner: Arc::clone(&self.inner),
112 _state: PhantomData,
113 }
114 }
115}
116
117impl<S> std::fmt::Debug for Connection<S> {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("Connection")
120 .field("inner", &self.inner)
121 .field("state", &std::any::type_name::<S>())
122 .finish()
123 }
124}
125
126impl Connection<state::Disconnected> {
127 #[must_use]
129 pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
130 Self {
131 inner: Arc::new(ConnectionData::new(server_info, server_capabilities)),
132 _state: PhantomData,
133 }
134 }
135
136 pub async fn connect(self) -> Result<Connection<state::Connected>, McpError> {
140 Ok(Connection {
142 inner: self.inner,
143 _state: PhantomData,
144 })
145 }
146}
147
148impl Connection<state::Connected> {
149 pub async fn initialize(
153 self,
154 _protocol_version: ProtocolVersion,
155 ) -> Result<Connection<state::Initializing>, McpError> {
156 Ok(Connection {
158 inner: self.inner,
159 _state: PhantomData,
160 })
161 }
162
163 pub async fn close(self) -> Result<(), McpError> {
165 Ok(())
167 }
168}
169
170impl Connection<state::Initializing> {
171 pub async fn complete(
175 self,
176 client_capabilities: ClientCapabilities,
177 protocol_version: ProtocolVersion,
178 ) -> Result<Connection<state::Ready>, McpError> {
179 let mut data = ConnectionData::new(
182 self.inner.server_info.clone(),
183 self.inner.server_capabilities.clone(),
184 );
185 data.client_capabilities = Some(client_capabilities);
186 data.protocol_version = Some(protocol_version);
187
188 Ok(Connection {
189 inner: Arc::new(data),
190 _state: PhantomData,
191 })
192 }
193
194 pub async fn abort(self) -> Result<Connection<state::Disconnected>, McpError> {
196 Ok(Connection {
197 inner: self.inner,
198 _state: PhantomData,
199 })
200 }
201}
202
203impl Connection<state::Ready> {
204 #[must_use]
212 pub fn client_capabilities(&self) -> &ClientCapabilities {
213 self.inner
214 .client_capabilities
215 .as_ref()
216 .expect("Ready connection must have client capabilities")
217 }
218
219 #[must_use]
223 pub fn try_client_capabilities(&self) -> Option<&ClientCapabilities> {
224 self.inner.client_capabilities.as_ref()
225 }
226
227 #[must_use]
229 pub fn server_capabilities(&self) -> &ServerCapabilities {
230 &self.inner.server_capabilities
231 }
232
233 #[must_use]
235 pub fn server_info(&self) -> &ServerInfo {
236 &self.inner.server_info
237 }
238
239 #[must_use]
247 pub fn protocol_version(&self) -> ProtocolVersion {
248 self.inner
249 .protocol_version
250 .expect("Ready connection must have protocol version")
251 }
252
253 #[must_use]
257 pub fn try_protocol_version(&self) -> Option<ProtocolVersion> {
258 self.inner.protocol_version
259 }
260
261 pub async fn shutdown(self) -> Result<Connection<state::Closing>, McpError> {
265 Ok(Connection {
266 inner: self.inner,
267 _state: PhantomData,
268 })
269 }
270}
271
272impl Connection<state::Closing> {
273 pub async fn disconnect(self) -> Result<(), McpError> {
275 Ok(())
277 }
278}
279
280#[derive(Debug)]
285pub enum ConnectionState {
286 Disconnected(Connection<state::Disconnected>),
288 Connected(Connection<state::Connected>),
290 Initializing(Connection<state::Initializing>),
292 Ready(Connection<state::Ready>),
294 Closing(Connection<state::Closing>),
296}
297
298impl ConnectionState {
299 #[must_use]
301 pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
302 Self::Disconnected(Connection::new(server_info, server_capabilities))
303 }
304
305 #[must_use]
307 pub const fn is_ready(&self) -> bool {
308 matches!(self, Self::Ready(_))
309 }
310
311 #[must_use]
313 pub const fn is_disconnected(&self) -> bool {
314 matches!(self, Self::Disconnected(_))
315 }
316
317 #[must_use]
319 pub const fn state_name(&self) -> &'static str {
320 match self {
321 Self::Disconnected(_) => "Disconnected",
322 Self::Connected(_) => "Connected",
323 Self::Initializing(_) => "Initializing",
324 Self::Ready(_) => "Ready",
325 Self::Closing(_) => "Closing",
326 }
327 }
328}
329
330#[derive(Debug, Clone)]
332pub enum ConnectionEvent {
333 Connected,
335 InitializeStarted,
337 InitializeCompleted {
339 protocol_version: ProtocolVersion,
341 },
342 InitializeFailed {
344 error: String,
346 },
347 ShutdownRequested,
349 Disconnected,
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::capability::{ServerCapabilities, ServerInfo};
    use mcpkit_core::protocol_version::ProtocolVersion;

    /// Server identity and default capabilities shared by every test.
    fn fixture() -> (ServerInfo, ServerCapabilities) {
        (ServerInfo::new("test", "1.0.0"), ServerCapabilities::default())
    }

    #[test]
    fn test_connection_creation() {
        let (info, caps) = fixture();
        let conn: Connection<state::Disconnected> = Connection::new(info, caps);

        let state_type = std::any::type_name_of_val(&conn._state);
        assert!(state_type.contains("Disconnected"));
    }

    #[tokio::test]
    async fn test_connection_lifecycle() {
        let (info, caps) = fixture();
        let version = ProtocolVersion::V2025_11_25;

        // Disconnected -> Connected -> Initializing -> Ready
        let ready = Connection::new(info, caps)
            .connect()
            .await
            .unwrap()
            .initialize(version)
            .await
            .unwrap()
            .complete(ClientCapabilities::default(), version)
            .await
            .unwrap();
        assert_eq!(ready.protocol_version(), version);

        // Ready -> Closing -> (gone)
        ready
            .shutdown()
            .await
            .unwrap()
            .disconnect()
            .await
            .unwrap();
    }

    #[test]
    fn test_connection_state_enum() {
        let (info, caps) = fixture();
        let state = ConnectionState::new(info, caps);

        assert!(state.is_disconnected());
        assert!(!state.is_ready());
        assert_eq!(state.state_name(), "Disconnected");
    }
}