mcpkit_server/
state.rs

//! Typestate connection management for MCP servers.
//!
//! This module implements the typestate pattern for managing the
//! connection lifecycle, so that state transitions are checked at
//! compile time.
//!
//! # Connection Lifecycle
//!
//! ```text
//! Disconnected -> Connected -> Initializing -> Ready -> Closing
//! ```
//!
//! Besides this main path, a `Connected` connection can be closed before
//! initialization (`close()`), and an `Initializing` connection can be
//! aborted back to `Disconnected` (`abort()`).
//!
//! Each state is a distinct type, so invalid state transitions are
//! rejected by the compiler.
15
16use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities, ServerInfo};
17use mcpkit_core::error::McpError;
18use std::marker::PhantomData;
19use std::sync::Arc;
20
/// Connection state markers.
///
/// These zero-sized types name the states of the connection lifecycle.
/// They carry no data and are used purely for type-level state tracking
/// as the type parameter of `Connection`.
///
/// Besides `Debug`/`Clone`/`Copy`, every marker derives `PartialEq`, `Eq`,
/// `Hash` and `Default`, so markers can be compared, used as map keys, or
/// produced generically without extra boilerplate.
pub mod state {
    /// Connection is disconnected (initial state).
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
    pub struct Disconnected;

    /// Connection is established but not initialized.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
    pub struct Connected;

    /// Connection is in the initialization handshake.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
    pub struct Initializing;

    /// Connection is fully initialized and ready for requests.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
    pub struct Ready;

    /// Connection is closing down.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
    pub struct Closing;
}
46
47/// Internal connection data shared across states.
48#[derive(Debug)]
49pub struct ConnectionData {
50    /// Client capabilities (set after initialization).
51    pub client_capabilities: Option<ClientCapabilities>,
52    /// Server capabilities advertised.
53    pub server_capabilities: ServerCapabilities,
54    /// Server information.
55    pub server_info: ServerInfo,
56    /// Protocol version negotiated.
57    pub protocol_version: Option<String>,
58    /// Session ID if applicable.
59    pub session_id: Option<String>,
60}
61
62impl ConnectionData {
63    /// Create new connection data.
64    pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
65        Self {
66            client_capabilities: None,
67            server_capabilities,
68            server_info,
69            protocol_version: None,
70            session_id: None,
71        }
72    }
73}
74
75/// A typestate connection that tracks lifecycle state at the type level.
76///
77/// The state parameter `S` ensures that only valid operations are
78/// available for each connection state.
79///
80/// # Example
81///
82/// ```rust
83/// use mcpkit_server::state::{Connection, state};
84/// use mcpkit_core::capability::{ServerInfo, ServerCapabilities};
85///
86/// // Start disconnected
87/// let conn: Connection<state::Disconnected> = Connection::new(
88///     ServerInfo::new("my-server", "1.0.0"),
89///     ServerCapabilities::new().with_tools(),
90/// );
91///
92/// // The typestate pattern ensures compile-time safety:
93/// // - A Disconnected connection can only call connect()
94/// // - A Connected connection can only call initialize() or close()
95/// // - A Ready connection can access capabilities
96/// ```
97pub struct Connection<S> {
98    /// Shared connection data.
99    inner: Arc<ConnectionData>,
100    /// Phantom data to track state type.
101    _state: PhantomData<S>,
102}
103
104impl<S> Clone for Connection<S> {
105    fn clone(&self) -> Self {
106        Self {
107            inner: Arc::clone(&self.inner),
108            _state: PhantomData,
109        }
110    }
111}
112
113impl<S> std::fmt::Debug for Connection<S> {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("Connection")
116            .field("inner", &self.inner)
117            .field("state", &std::any::type_name::<S>())
118            .finish()
119    }
120}
121
122impl Connection<state::Disconnected> {
123    /// Create a new disconnected connection.
124    pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
125        Self {
126            inner: Arc::new(ConnectionData::new(server_info, server_capabilities)),
127            _state: PhantomData,
128        }
129    }
130
131    /// Connect to establish a transport connection.
132    ///
133    /// This transitions from `Disconnected` to `Connected` state.
134    pub async fn connect(self) -> Result<Connection<state::Connected>, McpError> {
135        // In a real implementation, this would establish the transport
136        Ok(Connection {
137            inner: self.inner,
138            _state: PhantomData,
139        })
140    }
141}
142
143impl Connection<state::Connected> {
144    /// Start the initialization handshake.
145    ///
146    /// This transitions from `Connected` to `Initializing` state.
147    pub async fn initialize(
148        self,
149        _protocol_version: &str,
150    ) -> Result<Connection<state::Initializing>, McpError> {
151        // In a real implementation, this would send the initialize request
152        Ok(Connection {
153            inner: self.inner,
154            _state: PhantomData,
155        })
156    }
157
158    /// Close the connection before initialization.
159    pub async fn close(self) -> Result<(), McpError> {
160        // Clean up resources
161        Ok(())
162    }
163}
164
165impl Connection<state::Initializing> {
166    /// Complete the initialization handshake.
167    ///
168    /// This transitions from `Initializing` to `Ready` state.
169    pub async fn complete(
170        self,
171        client_capabilities: ClientCapabilities,
172        protocol_version: String,
173    ) -> Result<Connection<state::Ready>, McpError> {
174        // Update the connection data with negotiated values
175        // In a real implementation, we'd use interior mutability
176        let mut data = ConnectionData::new(
177            self.inner.server_info.clone(),
178            self.inner.server_capabilities.clone(),
179        );
180        data.client_capabilities = Some(client_capabilities);
181        data.protocol_version = Some(protocol_version);
182
183        Ok(Connection {
184            inner: Arc::new(data),
185            _state: PhantomData,
186        })
187    }
188
189    /// Abort initialization.
190    pub async fn abort(self) -> Result<Connection<state::Disconnected>, McpError> {
191        Ok(Connection {
192            inner: self.inner,
193            _state: PhantomData,
194        })
195    }
196}
197
198impl Connection<state::Ready> {
199    /// Get the client capabilities.
200    pub fn client_capabilities(&self) -> &ClientCapabilities {
201        self.inner
202            .client_capabilities
203            .as_ref()
204            .expect("Ready connection must have client capabilities")
205    }
206
207    /// Get the server capabilities.
208    pub fn server_capabilities(&self) -> &ServerCapabilities {
209        &self.inner.server_capabilities
210    }
211
212    /// Get the server info.
213    pub fn server_info(&self) -> &ServerInfo {
214        &self.inner.server_info
215    }
216
217    /// Get the negotiated protocol version.
218    pub fn protocol_version(&self) -> &str {
219        self.inner
220            .protocol_version
221            .as_ref()
222            .expect("Ready connection must have protocol version")
223    }
224
225    /// Start graceful shutdown.
226    ///
227    /// This transitions from `Ready` to `Closing` state.
228    pub async fn shutdown(self) -> Result<Connection<state::Closing>, McpError> {
229        Ok(Connection {
230            inner: self.inner,
231            _state: PhantomData,
232        })
233    }
234}
235
236impl Connection<state::Closing> {
237    /// Complete the shutdown and disconnect.
238    pub async fn disconnect(self) -> Result<(), McpError> {
239        // Clean up resources
240        Ok(())
241    }
242}
243
244/// A state machine wrapper for connections that allows runtime state tracking.
245///
246/// This provides an alternative to the pure typestate approach when
247/// runtime state inspection is needed.
248#[derive(Debug)]
249pub enum ConnectionState {
250    /// Not connected.
251    Disconnected(Connection<state::Disconnected>),
252    /// Connected but not initialized.
253    Connected(Connection<state::Connected>),
254    /// In initialization handshake.
255    Initializing(Connection<state::Initializing>),
256    /// Ready for requests.
257    Ready(Connection<state::Ready>),
258    /// Closing down.
259    Closing(Connection<state::Closing>),
260}
261
262impl ConnectionState {
263    /// Create a new disconnected connection state.
264    pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
265        Self::Disconnected(Connection::new(server_info, server_capabilities))
266    }
267
268    /// Check if the connection is ready for requests.
269    pub fn is_ready(&self) -> bool {
270        matches!(self, ConnectionState::Ready(_))
271    }
272
273    /// Check if the connection is disconnected.
274    pub fn is_disconnected(&self) -> bool {
275        matches!(self, ConnectionState::Disconnected(_))
276    }
277
278    /// Get the current state name.
279    pub fn state_name(&self) -> &'static str {
280        match self {
281            ConnectionState::Disconnected(_) => "Disconnected",
282            ConnectionState::Connected(_) => "Connected",
283            ConnectionState::Initializing(_) => "Initializing",
284            ConnectionState::Ready(_) => "Ready",
285            ConnectionState::Closing(_) => "Closing",
286        }
287    }
288}
289
/// Events describing connection state changes.
///
/// Derives `PartialEq`/`Eq` so events, including their payloads, can be
/// compared directly, for example in assertions or when deduplicating.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionEvent {
    /// Connection established.
    Connected,
    /// Initialization started.
    InitializeStarted,
    /// Initialization completed successfully.
    InitializeCompleted {
        /// Negotiated protocol version.
        protocol_version: String,
    },
    /// Initialization failed.
    InitializeFailed {
        /// Error message.
        error: String,
    },
    /// Shutdown requested.
    ShutdownRequested,
    /// Connection closed.
    Disconnected,
}
312
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::capability::{ServerCapabilities, ServerInfo};

    /// Builds a fresh `Disconnected` connection with test fixtures.
    fn new_connection() -> Connection<state::Disconnected> {
        Connection::new(ServerInfo::new("test", "1.0.0"), ServerCapabilities::default())
    }

    #[test]
    fn test_connection_creation() {
        let info = ServerInfo::new("test", "1.0.0");
        let caps = ServerCapabilities::default();
        let conn: Connection<state::Disconnected> = Connection::new(info, caps);

        assert!(std::any::type_name_of_val(&conn._state).contains("Disconnected"));
    }

    #[tokio::test]
    async fn test_connection_lifecycle() {
        let info = ServerInfo::new("test", "1.0.0");
        let caps = ServerCapabilities::default();

        // Start disconnected
        let conn = Connection::new(info, caps);

        // Connect
        let conn = conn.connect().await.unwrap();

        // Initialize
        let conn = conn.initialize("2025-11-25").await.unwrap();

        // Complete
        let conn = conn
            .complete(ClientCapabilities::default(), "2025-11-25".to_string())
            .await
            .unwrap();

        // Verify ready state
        assert_eq!(conn.protocol_version(), "2025-11-25");

        // Shutdown
        let conn = conn.shutdown().await.unwrap();

        // Disconnect
        conn.disconnect().await.unwrap();
    }

    #[tokio::test]
    async fn test_abort_returns_to_disconnected() {
        let conn = new_connection().connect().await.unwrap();
        let conn = conn.initialize("2025-11-25").await.unwrap();

        // Aborting drops back to `Disconnected`; the handle can reconnect.
        let conn = conn.abort().await.unwrap();
        assert!(std::any::type_name_of_val(&conn._state).contains("Disconnected"));

        // A reconnected but uninitialized connection can be closed directly.
        let conn = conn.connect().await.unwrap();
        conn.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_connection_state_ready() {
        let conn = new_connection()
            .connect()
            .await
            .unwrap()
            .initialize("2025-11-25")
            .await
            .unwrap()
            .complete(ClientCapabilities::default(), "2025-11-25".to_string())
            .await
            .unwrap();

        // Accessors must not panic on a connection produced by `complete`.
        let _ = conn.client_capabilities();
        let _ = conn.server_capabilities();
        let _ = conn.server_info();

        let state = ConnectionState::Ready(conn);
        assert!(state.is_ready());
        assert!(!state.is_disconnected());
        assert_eq!(state.state_name(), "Ready");
    }

    #[test]
    fn test_connection_state_enum() {
        let info = ServerInfo::new("test", "1.0.0");
        let caps = ServerCapabilities::default();

        let state = ConnectionState::new(info, caps);
        assert!(state.is_disconnected());
        assert!(!state.is_ready());
        assert_eq!(state.state_name(), "Disconnected");
    }
}