mcpkit_server/
state.rs

1//! Typestate connection management for MCP servers.
2//!
3//! This module implements the typestate pattern for managing
4//! connection lifecycle, ensuring compile-time correctness of
5//! state transitions.
6//!
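//! # Connection Lifecycle
//!
//! ```text
//! Disconnected -> Connected -> Initializing -> Ready -> Closing
//! ```
//!
//! Besides this happy path, a `Connected` connection can be closed before
//! initialization (`close`), an `Initializing` connection can be aborted back to
//! `Disconnected` (`abort`), and a `Closing` connection finishes with `disconnect`.
//!
//! Each state is a distinct type, so calling an operation that is not valid for
//! the current state is rejected at compile time rather than at runtime.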
15
16use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities, ServerInfo};
17use mcpkit_core::error::McpError;
18use mcpkit_core::protocol_version::ProtocolVersion;
19use std::marker::PhantomData;
20use std::sync::Arc;
21
/// Connection state markers.
///
/// These types represent different states in the connection lifecycle.
/// They contain no data and are used purely for type-level state tracking.
///
/// Besides `Debug`/`Clone`/`Copy`, each marker is `Default`, `PartialEq`, `Eq`
/// and `Hash`, so a marker value can be created, compared, or used as a map key
/// without any ceremony.
pub mod markers {
    // Note: The types in this module are also re-exported through the
    // deprecated `state` sibling module (`state::state::*`) for backwards
    // compatibility with v0.2.5.

    /// Connection is disconnected (initial state).
    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
    pub struct Disconnected;

    /// Connection is established but not initialized.
    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
    pub struct Connected;

    /// Connection is in the initialization handshake.
    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
    pub struct Initializing;

    /// Connection is fully initialized and ready for requests.
    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
    pub struct Ready;

    /// Connection is closing down.
    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
    pub struct Closing;
}
49
50/// Backwards compatibility alias for the `markers` module.
51///
52/// This module was renamed from `state` to `markers` in v0.2.6.
53/// This alias is provided for backwards compatibility with v0.2.5.
54#[doc(hidden)]
55#[deprecated(since = "0.2.6", note = "Use `markers` module instead")]
56pub mod state {
57    pub use super::markers::*;
58}
59
/// Internal connection data shared across states.
///
/// Every [`Connection`] holds this behind an [`Arc`], so clones of a
/// connection share a single instance of it.
#[derive(Debug)]
pub struct ConnectionData {
    /// Client capabilities (set after initialization).
    ///
    /// `None` until the initialization handshake completes.
    pub client_capabilities: Option<ClientCapabilities>,
    /// Server capabilities advertised.
    pub server_capabilities: ServerCapabilities,
    /// Server information.
    pub server_info: ServerInfo,
    /// Protocol version negotiated.
    ///
    /// Stored as a type-safe [`ProtocolVersion`] enum for feature detection.
    /// `None` until the initialization handshake completes.
    pub protocol_version: Option<ProtocolVersion>,
    /// Session ID if applicable.
    ///
    /// NOTE(review): none of the code in this module assigns this field; it
    /// starts out as `None` — confirm where (or whether) it is populated.
    pub session_id: Option<String>,
}
76
77impl ConnectionData {
78    /// Create new connection data.
79    #[must_use]
80    pub const fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
81        Self {
82            client_capabilities: None,
83            server_capabilities,
84            server_info,
85            protocol_version: None,
86            session_id: None,
87        }
88    }
89}
90
/// A typestate connection that tracks lifecycle state at the type level.
///
/// The state parameter `S` ensures that only valid operations are
/// available for each connection state. The operations in this module are
/// implemented for the state types in [`markers`], and every transition takes
/// `self` by value and returns a connection with a different state type.
///
/// Cloning a `Connection` is cheap: clones share the same [`ConnectionData`]
/// through an [`Arc`].
///
/// # Example
///
/// ```rust
/// use mcpkit_server::state::{Connection, markers};
/// use mcpkit_core::capability::{ServerInfo, ServerCapabilities};
///
/// // Start disconnected
/// let conn: Connection<markers::Disconnected> = Connection::new(
///     ServerInfo::new("my-server", "1.0.0"),
///     ServerCapabilities::new().with_tools(),
/// );
///
/// // The typestate pattern ensures compile-time safety:
/// // - A Disconnected connection can only call connect()
/// // - A Connected connection can only call initialize() or close()
/// // - A Ready connection can access capabilities
/// ```
pub struct Connection<S> {
    /// Shared connection data.
    inner: Arc<ConnectionData>,
    /// Zero-sized marker recording the state type `S`; it stores no data.
    _state: PhantomData<S>,
}
119
120impl<S> Clone for Connection<S> {
121    fn clone(&self) -> Self {
122        Self {
123            inner: Arc::clone(&self.inner),
124            _state: PhantomData,
125        }
126    }
127}
128
129impl<S> std::fmt::Debug for Connection<S> {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.debug_struct("Connection")
132            .field("inner", &self.inner)
133            .field("state", &std::any::type_name::<S>())
134            .finish()
135    }
136}
137
138impl Connection<markers::Disconnected> {
139    /// Create a new disconnected connection.
140    #[must_use]
141    pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
142        Self {
143            inner: Arc::new(ConnectionData::new(server_info, server_capabilities)),
144            _state: PhantomData,
145        }
146    }
147
148    /// Connect to establish a transport connection.
149    ///
150    /// This transitions from `Disconnected` to `Connected` state.
151    pub async fn connect(self) -> Result<Connection<markers::Connected>, McpError> {
152        // In a real implementation, this would establish the transport
153        Ok(Connection {
154            inner: self.inner,
155            _state: PhantomData,
156        })
157    }
158}
159
160impl Connection<markers::Connected> {
161    /// Start the initialization handshake.
162    ///
163    /// This transitions from `Connected` to `Initializing` state.
164    pub async fn initialize(
165        self,
166        _protocol_version: ProtocolVersion,
167    ) -> Result<Connection<markers::Initializing>, McpError> {
168        // In a real implementation, this would send the initialize request
169        Ok(Connection {
170            inner: self.inner,
171            _state: PhantomData,
172        })
173    }
174
175    /// Close the connection before initialization.
176    pub async fn close(self) -> Result<(), McpError> {
177        // Clean up resources
178        Ok(())
179    }
180}
181
182impl Connection<markers::Initializing> {
183    /// Complete the initialization handshake.
184    ///
185    /// This transitions from `Initializing` to `Ready` state.
186    pub async fn complete(
187        self,
188        client_capabilities: ClientCapabilities,
189        protocol_version: ProtocolVersion,
190    ) -> Result<Connection<markers::Ready>, McpError> {
191        // Update the connection data with negotiated values
192        // In a real implementation, we'd use interior mutability
193        let mut data = ConnectionData::new(
194            self.inner.server_info.clone(),
195            self.inner.server_capabilities.clone(),
196        );
197        data.client_capabilities = Some(client_capabilities);
198        data.protocol_version = Some(protocol_version);
199
200        Ok(Connection {
201            inner: Arc::new(data),
202            _state: PhantomData,
203        })
204    }
205
206    /// Abort initialization.
207    pub async fn abort(self) -> Result<Connection<markers::Disconnected>, McpError> {
208        Ok(Connection {
209            inner: self.inner,
210            _state: PhantomData,
211        })
212    }
213}
214
215impl Connection<markers::Ready> {
216    /// Get the client capabilities.
217    ///
218    /// # Panics
219    ///
220    /// This should never panic if the connection was properly initialized
221    /// through the typestate transitions. Use `try_client_capabilities()`
222    /// for a fallible version.
223    #[must_use]
224    pub fn client_capabilities(&self) -> &ClientCapabilities {
225        self.inner
226            .client_capabilities
227            .as_ref()
228            .expect("Ready connection must have client capabilities")
229    }
230
231    /// Try to get the client capabilities.
232    ///
233    /// Returns `None` if capabilities were not set (should not happen in normal use).
234    #[must_use]
235    pub fn try_client_capabilities(&self) -> Option<&ClientCapabilities> {
236        self.inner.client_capabilities.as_ref()
237    }
238
239    /// Get the server capabilities.
240    #[must_use]
241    pub fn server_capabilities(&self) -> &ServerCapabilities {
242        &self.inner.server_capabilities
243    }
244
245    /// Get the server info.
246    #[must_use]
247    pub fn server_info(&self) -> &ServerInfo {
248        &self.inner.server_info
249    }
250
251    /// Get the negotiated protocol version.
252    ///
253    /// # Panics
254    ///
255    /// This should never panic if the connection was properly initialized
256    /// through the typestate transitions. Use `try_protocol_version()`
257    /// for a fallible version.
258    #[must_use]
259    pub fn protocol_version(&self) -> ProtocolVersion {
260        self.inner
261            .protocol_version
262            .expect("Ready connection must have protocol version")
263    }
264
265    /// Try to get the negotiated protocol version.
266    ///
267    /// Returns `None` if version was not set (should not happen in normal use).
268    #[must_use]
269    pub fn try_protocol_version(&self) -> Option<ProtocolVersion> {
270        self.inner.protocol_version
271    }
272
273    /// Start graceful shutdown.
274    ///
275    /// This transitions from `Ready` to `Closing` state.
276    pub async fn shutdown(self) -> Result<Connection<markers::Closing>, McpError> {
277        Ok(Connection {
278            inner: self.inner,
279            _state: PhantomData,
280        })
281    }
282}
283
284impl Connection<markers::Closing> {
285    /// Complete the shutdown and disconnect.
286    pub async fn disconnect(self) -> Result<(), McpError> {
287        // Clean up resources
288        Ok(())
289    }
290}
291
292/// A state machine wrapper for connections that allows runtime state tracking.
293///
294/// This provides an alternative to the pure typestate approach when
295/// runtime state inspection is needed.
296#[derive(Debug)]
297pub enum ConnectionState {
298    /// Not connected.
299    Disconnected(Connection<markers::Disconnected>),
300    /// Connected but not initialized.
301    Connected(Connection<markers::Connected>),
302    /// In initialization handshake.
303    Initializing(Connection<markers::Initializing>),
304    /// Ready for requests.
305    Ready(Connection<markers::Ready>),
306    /// Closing down.
307    Closing(Connection<markers::Closing>),
308}
309
310impl ConnectionState {
311    /// Create a new disconnected connection state.
312    #[must_use]
313    pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
314        Self::Disconnected(Connection::new(server_info, server_capabilities))
315    }
316
317    /// Check if the connection is ready for requests.
318    #[must_use]
319    pub const fn is_ready(&self) -> bool {
320        matches!(self, Self::Ready(_))
321    }
322
323    /// Check if the connection is disconnected.
324    #[must_use]
325    pub const fn is_disconnected(&self) -> bool {
326        matches!(self, Self::Disconnected(_))
327    }
328
329    /// Get the current state name.
330    #[must_use]
331    pub const fn state_name(&self) -> &'static str {
332        match self {
333            Self::Disconnected(_) => "Disconnected",
334            Self::Connected(_) => "Connected",
335            Self::Initializing(_) => "Initializing",
336            Self::Ready(_) => "Ready",
337            Self::Closing(_) => "Closing",
338        }
339    }
340}
341
342/// Transition events for connection state changes.
343#[derive(Debug, Clone)]
344pub enum ConnectionEvent {
345    /// Connection established.
346    Connected,
347    /// Initialization started.
348    InitializeStarted,
349    /// Initialization completed successfully.
350    InitializeCompleted {
351        /// Negotiated protocol version.
352        protocol_version: ProtocolVersion,
353    },
354    /// Initialization failed.
355    InitializeFailed {
356        /// Error message.
357        error: String,
358    },
359    /// Shutdown requested.
360    ShutdownRequested,
361    /// Connection closed.
362    Disconnected,
363}
364
#[cfg(test)]
mod tests {
    // `super::*` already brings in ServerInfo, ServerCapabilities,
    // ClientCapabilities, ProtocolVersion and Arc from the parent's imports.
    use super::*;

    /// Builds a fresh `Disconnected` connection with fixed test metadata.
    fn new_connection() -> Connection<markers::Disconnected> {
        Connection::new(ServerInfo::new("test", "1.0.0"), ServerCapabilities::default())
    }

    #[test]
    fn test_connection_creation() {
        let conn = new_connection();

        assert!(std::any::type_name_of_val(&conn._state).contains("Disconnected"));
    }

    #[test]
    fn test_clone_shares_connection_data() {
        let conn = new_connection();
        let copy = conn.clone();

        assert!(Arc::ptr_eq(&conn.inner, &copy.inner));
    }

    #[test]
    fn test_debug_reports_state() {
        let rendered = format!("{:?}", new_connection());

        assert!(rendered.starts_with("Connection"));
        assert!(rendered.contains("Disconnected"));
    }

    #[tokio::test]
    async fn test_connection_lifecycle() -> Result<(), Box<dyn std::error::Error>> {
        // Start disconnected, then connect and initialize.
        let conn = new_connection();
        let conn = conn.connect().await?;
        let conn = conn.initialize(ProtocolVersion::V2025_11_25).await?;

        // Complete the handshake.
        let conn = conn
            .complete(ClientCapabilities::default(), ProtocolVersion::V2025_11_25)
            .await?;

        // Verify ready state.
        assert_eq!(conn.protocol_version(), ProtocolVersion::V2025_11_25);
        assert_eq!(conn.try_protocol_version(), Some(ProtocolVersion::V2025_11_25));
        assert!(conn.try_client_capabilities().is_some());

        let state = ConnectionState::Ready(conn.clone());
        assert!(state.is_ready());
        assert!(!state.is_disconnected());
        assert_eq!(state.state_name(), "Ready");

        // Shutdown and disconnect.
        let conn = conn.shutdown().await?;
        assert_eq!(ConnectionState::Closing(conn.clone()).state_name(), "Closing");
        conn.disconnect().await?;

        Ok(())
    }

    #[tokio::test]
    async fn test_abort_and_reconnect() -> Result<(), Box<dyn std::error::Error>> {
        let conn = new_connection().connect().await?;
        let conn = conn.initialize(ProtocolVersion::V2025_11_25).await?;
        assert_eq!(
            ConnectionState::Initializing(conn.clone()).state_name(),
            "Initializing"
        );

        // Aborting returns to Disconnected, from which we can connect again.
        let conn = conn.abort().await?;
        assert!(ConnectionState::Disconnected(conn.clone()).is_disconnected());
        let conn = conn.connect().await?;
        assert_eq!(ConnectionState::Connected(conn.clone()).state_name(), "Connected");

        // A connected-but-uninitialized connection can be closed directly.
        conn.close().await?;

        Ok(())
    }

    #[test]
    fn test_connection_state_enum() {
        let info = ServerInfo::new("test", "1.0.0");
        let caps = ServerCapabilities::default();

        let state = ConnectionState::new(info, caps);
        assert!(state.is_disconnected());
        assert!(!state.is_ready());
        assert_eq!(state.state_name(), "Disconnected");
    }
}