1use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities, ServerInfo};
17use mcpkit_core::error::McpError;
18use mcpkit_core::protocol_version::ProtocolVersion;
19use std::marker::PhantomData;
20use std::sync::Arc;
21
/// Zero-sized type-state tags used as the `S` parameter of `Connection<S>`.
///
/// None of these types carry data; they exist only so that each lifecycle
/// stage exposes a different set of methods on `Connection`.
pub mod markers {
    /// Initial stage, produced by `Connection::new`; `connect` leaves it.
    #[derive(Debug, Clone, Copy)]
    pub struct Disconnected;

    /// Stage reached from `Disconnected::connect`; `initialize` or `close` leave it.
    #[derive(Debug, Clone, Copy)]
    pub struct Connected;

    /// Stage reached from `Connected::initialize`; `complete` or `abort` leave it.
    #[derive(Debug, Clone, Copy)]
    pub struct Initializing;

    /// Stage reached from `Initializing::complete`; capability and version
    /// accessors are available here, and `shutdown` leaves it.
    #[derive(Debug, Clone, Copy)]
    pub struct Ready;

    /// Stage reached from `Ready::shutdown`; `disconnect` ends the lifecycle.
    #[derive(Debug, Clone, Copy)]
    pub struct Closing;
}
49
50#[doc(hidden)]
55#[deprecated(since = "0.2.6", note = "Use `markers` module instead")]
56pub mod state {
57 pub use super::markers::*;
58}
59
60#[derive(Debug)]
62pub struct ConnectionData {
63 pub client_capabilities: Option<ClientCapabilities>,
65 pub server_capabilities: ServerCapabilities,
67 pub server_info: ServerInfo,
69 pub protocol_version: Option<ProtocolVersion>,
73 pub session_id: Option<String>,
75}
76
77impl ConnectionData {
78 #[must_use]
80 pub const fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
81 Self {
82 client_capabilities: None,
83 server_capabilities,
84 server_info,
85 protocol_version: None,
86 session_id: None,
87 }
88 }
89}
90
91pub struct Connection<S> {
114 inner: Arc<ConnectionData>,
116 _state: PhantomData<S>,
118}
119
120impl<S> Clone for Connection<S> {
121 fn clone(&self) -> Self {
122 Self {
123 inner: Arc::clone(&self.inner),
124 _state: PhantomData,
125 }
126 }
127}
128
129impl<S> std::fmt::Debug for Connection<S> {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 f.debug_struct("Connection")
132 .field("inner", &self.inner)
133 .field("state", &std::any::type_name::<S>())
134 .finish()
135 }
136}
137
138impl Connection<markers::Disconnected> {
139 #[must_use]
141 pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
142 Self {
143 inner: Arc::new(ConnectionData::new(server_info, server_capabilities)),
144 _state: PhantomData,
145 }
146 }
147
148 pub async fn connect(self) -> Result<Connection<markers::Connected>, McpError> {
152 Ok(Connection {
154 inner: self.inner,
155 _state: PhantomData,
156 })
157 }
158}
159
160impl Connection<markers::Connected> {
161 pub async fn initialize(
165 self,
166 _protocol_version: ProtocolVersion,
167 ) -> Result<Connection<markers::Initializing>, McpError> {
168 Ok(Connection {
170 inner: self.inner,
171 _state: PhantomData,
172 })
173 }
174
175 pub async fn close(self) -> Result<(), McpError> {
177 Ok(())
179 }
180}
181
182impl Connection<markers::Initializing> {
183 pub async fn complete(
187 self,
188 client_capabilities: ClientCapabilities,
189 protocol_version: ProtocolVersion,
190 ) -> Result<Connection<markers::Ready>, McpError> {
191 let mut data = ConnectionData::new(
194 self.inner.server_info.clone(),
195 self.inner.server_capabilities.clone(),
196 );
197 data.client_capabilities = Some(client_capabilities);
198 data.protocol_version = Some(protocol_version);
199
200 Ok(Connection {
201 inner: Arc::new(data),
202 _state: PhantomData,
203 })
204 }
205
206 pub async fn abort(self) -> Result<Connection<markers::Disconnected>, McpError> {
208 Ok(Connection {
209 inner: self.inner,
210 _state: PhantomData,
211 })
212 }
213}
214
215impl Connection<markers::Ready> {
216 #[must_use]
224 pub fn client_capabilities(&self) -> &ClientCapabilities {
225 self.inner
226 .client_capabilities
227 .as_ref()
228 .expect("Ready connection must have client capabilities")
229 }
230
231 #[must_use]
235 pub fn try_client_capabilities(&self) -> Option<&ClientCapabilities> {
236 self.inner.client_capabilities.as_ref()
237 }
238
239 #[must_use]
241 pub fn server_capabilities(&self) -> &ServerCapabilities {
242 &self.inner.server_capabilities
243 }
244
245 #[must_use]
247 pub fn server_info(&self) -> &ServerInfo {
248 &self.inner.server_info
249 }
250
251 #[must_use]
259 pub fn protocol_version(&self) -> ProtocolVersion {
260 self.inner
261 .protocol_version
262 .expect("Ready connection must have protocol version")
263 }
264
265 #[must_use]
269 pub fn try_protocol_version(&self) -> Option<ProtocolVersion> {
270 self.inner.protocol_version
271 }
272
273 pub async fn shutdown(self) -> Result<Connection<markers::Closing>, McpError> {
277 Ok(Connection {
278 inner: self.inner,
279 _state: PhantomData,
280 })
281 }
282}
283
284impl Connection<markers::Closing> {
285 pub async fn disconnect(self) -> Result<(), McpError> {
287 Ok(())
289 }
290}
291
292#[derive(Debug)]
297pub enum ConnectionState {
298 Disconnected(Connection<markers::Disconnected>),
300 Connected(Connection<markers::Connected>),
302 Initializing(Connection<markers::Initializing>),
304 Ready(Connection<markers::Ready>),
306 Closing(Connection<markers::Closing>),
308}
309
310impl ConnectionState {
311 #[must_use]
313 pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
314 Self::Disconnected(Connection::new(server_info, server_capabilities))
315 }
316
317 #[must_use]
319 pub const fn is_ready(&self) -> bool {
320 matches!(self, Self::Ready(_))
321 }
322
323 #[must_use]
325 pub const fn is_disconnected(&self) -> bool {
326 matches!(self, Self::Disconnected(_))
327 }
328
329 #[must_use]
331 pub const fn state_name(&self) -> &'static str {
332 match self {
333 Self::Disconnected(_) => "Disconnected",
334 Self::Connected(_) => "Connected",
335 Self::Initializing(_) => "Initializing",
336 Self::Ready(_) => "Ready",
337 Self::Closing(_) => "Closing",
338 }
339 }
340}
341
342#[derive(Debug, Clone)]
344pub enum ConnectionEvent {
345 Connected,
347 InitializeStarted,
349 InitializeCompleted {
351 protocol_version: ProtocolVersion,
353 },
354 InitializeFailed {
356 error: String,
358 },
359 ShutdownRequested,
361 Disconnected,
363}
364
#[cfg(test)]
mod tests {
    // `ServerInfo`, `ServerCapabilities`, `ClientCapabilities` and
    // `ProtocolVersion` all arrive through the parent's imports.
    use super::*;

    /// Server info and capabilities shared by every test.
    fn fixture() -> (ServerInfo, ServerCapabilities) {
        (ServerInfo::new("test", "1.0.0"), ServerCapabilities::default())
    }

    #[test]
    fn test_connection_creation() {
        let (info, caps) = fixture();
        let conn: Connection<markers::Disconnected> = Connection::new(info, caps);

        let state_type = std::any::type_name_of_val(&conn._state);
        assert!(state_type.contains("Disconnected"));
    }

    #[tokio::test]
    async fn test_connection_lifecycle() -> Result<(), Box<dyn std::error::Error>> {
        let (info, caps) = fixture();
        let version = ProtocolVersion::V2025_11_25;

        // Walk Disconnected -> Connected -> Initializing -> Ready.
        let ready = Connection::new(info, caps)
            .connect()
            .await?
            .initialize(version)
            .await?
            .complete(ClientCapabilities::default(), version)
            .await?;

        assert_eq!(ready.protocol_version(), ProtocolVersion::V2025_11_25);

        // Then Ready -> Closing -> gone.
        ready.shutdown().await?.disconnect().await?;

        Ok(())
    }

    #[test]
    fn test_connection_state_enum() {
        let (info, caps) = fixture();
        let state = ConnectionState::new(info, caps);

        assert!(state.is_disconnected());
        assert!(!state.is_ready());
        assert_eq!(state.state_name(), "Disconnected");
    }
}