mcpkit_server/
state.rs

1//! Typestate connection management for MCP servers.
2//!
3//! This module implements the typestate pattern for managing
4//! connection lifecycle, ensuring compile-time correctness of
5//! state transitions.
6//!
7//! # Connection Lifecycle
8//!
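//! ```text
//! Disconnected --connect-->   Connected --initialize--> Initializing --complete--> Ready
//! Connected    --close-->     (connection consumed)
//! Initializing --abort-->     Disconnected
//! Ready        --shutdown-->  Closing --disconnect--> (connection consumed)
//! ```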
12//!
13//! Each state transition is enforced at compile time through
14//! different types, preventing invalid state transitions.
15
16use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities, ServerInfo};
17use mcpkit_core::error::McpError;
18use std::marker::PhantomData;
19use std::sync::Arc;
20
/// Zero-sized marker types, one per lifecycle state.
///
/// None of them hold any data. They only appear as the `S` parameter of
/// [`Connection`], which lets the compiler distinguish one state from another.
pub mod state {
    /// Starting state: no transport connection exists yet.
    #[derive(Clone, Copy, Debug)]
    pub struct Disconnected;

    /// A transport connection exists, but no handshake has started.
    #[derive(Clone, Copy, Debug)]
    pub struct Connected;

    /// The initialization handshake is underway.
    #[derive(Clone, Copy, Debug)]
    pub struct Initializing;

    /// The handshake has finished and requests may be served.
    #[derive(Clone, Copy, Debug)]
    pub struct Ready;

    /// Shutdown has begun.
    #[derive(Clone, Copy, Debug)]
    pub struct Closing;
}
46
47/// Internal connection data shared across states.
48#[derive(Debug)]
49pub struct ConnectionData {
50    /// Client capabilities (set after initialization).
51    pub client_capabilities: Option<ClientCapabilities>,
52    /// Server capabilities advertised.
53    pub server_capabilities: ServerCapabilities,
54    /// Server information.
55    pub server_info: ServerInfo,
56    /// Protocol version negotiated.
57    pub protocol_version: Option<String>,
58    /// Session ID if applicable.
59    pub session_id: Option<String>,
60}
61
62impl ConnectionData {
63    /// Create new connection data.
64    #[must_use]
65    pub const fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
66        Self {
67            client_capabilities: None,
68            server_capabilities,
69            server_info,
70            protocol_version: None,
71            session_id: None,
72        }
73    }
74}
75
76/// A typestate connection that tracks lifecycle state at the type level.
77///
78/// The state parameter `S` ensures that only valid operations are
79/// available for each connection state.
80///
81/// # Example
82///
83/// ```rust
84/// use mcpkit_server::state::{Connection, state};
85/// use mcpkit_core::capability::{ServerInfo, ServerCapabilities};
86///
87/// // Start disconnected
88/// let conn: Connection<state::Disconnected> = Connection::new(
89///     ServerInfo::new("my-server", "1.0.0"),
90///     ServerCapabilities::new().with_tools(),
91/// );
92///
93/// // The typestate pattern ensures compile-time safety:
94/// // - A Disconnected connection can only call connect()
95/// // - A Connected connection can only call initialize() or close()
96/// // - A Ready connection can access capabilities
97/// ```
98pub struct Connection<S> {
99    /// Shared connection data.
100    inner: Arc<ConnectionData>,
101    /// Phantom data to track state type.
102    _state: PhantomData<S>,
103}
104
105impl<S> Clone for Connection<S> {
106    fn clone(&self) -> Self {
107        Self {
108            inner: Arc::clone(&self.inner),
109            _state: PhantomData,
110        }
111    }
112}
113
114impl<S> std::fmt::Debug for Connection<S> {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        f.debug_struct("Connection")
117            .field("inner", &self.inner)
118            .field("state", &std::any::type_name::<S>())
119            .finish()
120    }
121}
122
123impl Connection<state::Disconnected> {
124    /// Create a new disconnected connection.
125    #[must_use]
126    pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
127        Self {
128            inner: Arc::new(ConnectionData::new(server_info, server_capabilities)),
129            _state: PhantomData,
130        }
131    }
132
133    /// Connect to establish a transport connection.
134    ///
135    /// This transitions from `Disconnected` to `Connected` state.
136    pub async fn connect(self) -> Result<Connection<state::Connected>, McpError> {
137        // In a real implementation, this would establish the transport
138        Ok(Connection {
139            inner: self.inner,
140            _state: PhantomData,
141        })
142    }
143}
144
145impl Connection<state::Connected> {
146    /// Start the initialization handshake.
147    ///
148    /// This transitions from `Connected` to `Initializing` state.
149    pub async fn initialize(
150        self,
151        _protocol_version: &str,
152    ) -> Result<Connection<state::Initializing>, McpError> {
153        // In a real implementation, this would send the initialize request
154        Ok(Connection {
155            inner: self.inner,
156            _state: PhantomData,
157        })
158    }
159
160    /// Close the connection before initialization.
161    pub async fn close(self) -> Result<(), McpError> {
162        // Clean up resources
163        Ok(())
164    }
165}
166
167impl Connection<state::Initializing> {
168    /// Complete the initialization handshake.
169    ///
170    /// This transitions from `Initializing` to `Ready` state.
171    pub async fn complete(
172        self,
173        client_capabilities: ClientCapabilities,
174        protocol_version: String,
175    ) -> Result<Connection<state::Ready>, McpError> {
176        // Update the connection data with negotiated values
177        // In a real implementation, we'd use interior mutability
178        let mut data = ConnectionData::new(
179            self.inner.server_info.clone(),
180            self.inner.server_capabilities.clone(),
181        );
182        data.client_capabilities = Some(client_capabilities);
183        data.protocol_version = Some(protocol_version);
184
185        Ok(Connection {
186            inner: Arc::new(data),
187            _state: PhantomData,
188        })
189    }
190
191    /// Abort initialization.
192    pub async fn abort(self) -> Result<Connection<state::Disconnected>, McpError> {
193        Ok(Connection {
194            inner: self.inner,
195            _state: PhantomData,
196        })
197    }
198}
199
200impl Connection<state::Ready> {
201    /// Get the client capabilities.
202    ///
203    /// # Panics
204    ///
205    /// This should never panic if the connection was properly initialized
206    /// through the typestate transitions. Use `try_client_capabilities()`
207    /// for a fallible version.
208    #[must_use]
209    pub fn client_capabilities(&self) -> &ClientCapabilities {
210        self.inner
211            .client_capabilities
212            .as_ref()
213            .expect("Ready connection must have client capabilities")
214    }
215
216    /// Try to get the client capabilities.
217    ///
218    /// Returns `None` if capabilities were not set (should not happen in normal use).
219    #[must_use]
220    pub fn try_client_capabilities(&self) -> Option<&ClientCapabilities> {
221        self.inner.client_capabilities.as_ref()
222    }
223
224    /// Get the server capabilities.
225    #[must_use]
226    pub fn server_capabilities(&self) -> &ServerCapabilities {
227        &self.inner.server_capabilities
228    }
229
230    /// Get the server info.
231    #[must_use]
232    pub fn server_info(&self) -> &ServerInfo {
233        &self.inner.server_info
234    }
235
236    /// Get the negotiated protocol version.
237    ///
238    /// # Panics
239    ///
240    /// This should never panic if the connection was properly initialized
241    /// through the typestate transitions. Use `try_protocol_version()`
242    /// for a fallible version.
243    #[must_use]
244    pub fn protocol_version(&self) -> &str {
245        self.inner
246            .protocol_version
247            .as_ref()
248            .expect("Ready connection must have protocol version")
249    }
250
251    /// Try to get the negotiated protocol version.
252    ///
253    /// Returns `None` if version was not set (should not happen in normal use).
254    #[must_use]
255    pub fn try_protocol_version(&self) -> Option<&str> {
256        self.inner.protocol_version.as_deref()
257    }
258
259    /// Start graceful shutdown.
260    ///
261    /// This transitions from `Ready` to `Closing` state.
262    pub async fn shutdown(self) -> Result<Connection<state::Closing>, McpError> {
263        Ok(Connection {
264            inner: self.inner,
265            _state: PhantomData,
266        })
267    }
268}
269
270impl Connection<state::Closing> {
271    /// Complete the shutdown and disconnect.
272    pub async fn disconnect(self) -> Result<(), McpError> {
273        // Clean up resources
274        Ok(())
275    }
276}
277
278/// A state machine wrapper for connections that allows runtime state tracking.
279///
280/// This provides an alternative to the pure typestate approach when
281/// runtime state inspection is needed.
282#[derive(Debug)]
283pub enum ConnectionState {
284    /// Not connected.
285    Disconnected(Connection<state::Disconnected>),
286    /// Connected but not initialized.
287    Connected(Connection<state::Connected>),
288    /// In initialization handshake.
289    Initializing(Connection<state::Initializing>),
290    /// Ready for requests.
291    Ready(Connection<state::Ready>),
292    /// Closing down.
293    Closing(Connection<state::Closing>),
294}
295
296impl ConnectionState {
297    /// Create a new disconnected connection state.
298    #[must_use]
299    pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
300        Self::Disconnected(Connection::new(server_info, server_capabilities))
301    }
302
303    /// Check if the connection is ready for requests.
304    #[must_use]
305    pub const fn is_ready(&self) -> bool {
306        matches!(self, Self::Ready(_))
307    }
308
309    /// Check if the connection is disconnected.
310    #[must_use]
311    pub const fn is_disconnected(&self) -> bool {
312        matches!(self, Self::Disconnected(_))
313    }
314
315    /// Get the current state name.
316    #[must_use]
317    pub const fn state_name(&self) -> &'static str {
318        match self {
319            Self::Disconnected(_) => "Disconnected",
320            Self::Connected(_) => "Connected",
321            Self::Initializing(_) => "Initializing",
322            Self::Ready(_) => "Ready",
323            Self::Closing(_) => "Closing",
324        }
325    }
326}
327
/// Events describing a change in a connection's lifecycle state.
#[derive(Clone, Debug)]
pub enum ConnectionEvent {
    /// The transport connection was established.
    Connected,
    /// The initialization handshake began.
    InitializeStarted,
    /// The initialization handshake succeeded.
    InitializeCompleted {
        /// Protocol version both sides agreed on.
        protocol_version: String,
    },
    /// The initialization handshake failed.
    InitializeFailed {
        /// Description of the failure.
        error: String,
    },
    /// A shutdown was requested.
    ShutdownRequested,
    /// The connection has been closed.
    Disconnected,
}
350
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::capability::{ServerCapabilities, ServerInfo};

    /// Protocol version used for both `initialize` and `complete` in the lifecycle test.
    const VERSION: &str = "2025-11-25";

    /// Server info and capabilities that every test starts from.
    fn fixture() -> (ServerInfo, ServerCapabilities) {
        (ServerInfo::new("test", "1.0.0"), ServerCapabilities::default())
    }

    #[test]
    fn test_connection_creation() {
        let (info, caps) = fixture();
        let conn: Connection<state::Disconnected> = Connection::new(info, caps);

        let marker_type = std::any::type_name_of_val(&conn._state);
        assert!(marker_type.contains("Disconnected"));
    }

    #[tokio::test]
    async fn test_connection_lifecycle() {
        let (info, caps) = fixture();

        let disconnected = Connection::new(info, caps);
        let connected = disconnected.connect().await.unwrap();
        let initializing = connected.initialize(VERSION).await.unwrap();
        let ready = initializing
            .complete(ClientCapabilities::default(), VERSION.to_string())
            .await
            .unwrap();

        assert_eq!(ready.protocol_version(), VERSION);

        let closing = ready.shutdown().await.unwrap();
        closing.disconnect().await.unwrap();
    }

    #[test]
    fn test_connection_state_enum() {
        let (info, caps) = fixture();
        let state = ConnectionState::new(info, caps);

        assert!(state.is_disconnected());
        assert!(!state.is_ready());
        assert_eq!(state.state_name(), "Disconnected");
    }
}