mcpkit_server/
state.rs

1//! Typestate connection management for MCP servers.
2//!
3//! This module implements the typestate pattern for managing
4//! connection lifecycle, ensuring compile-time correctness of
5//! state transitions.
6//!
7//! # Connection Lifecycle
8//!
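//! ```text
//! Disconnected -> Connected -> Initializing -> Ready -> Closing
//! ```
//!
//! Two early exits also exist: a `Connected` connection can be closed before
//! initialization, and an `Initializing` connection can abort back to
//! `Disconnected`.
//!
//! Each state is a distinct type, so invalid state transitions are rejected
//! at compile time.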
15
16use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities, ServerInfo};
17use mcpkit_core::error::McpError;
18use mcpkit_core::protocol_version::ProtocolVersion;
19use std::marker::PhantomData;
20use std::sync::Arc;
21
/// Zero-sized marker types for the connection lifecycle.
///
/// Each marker names one lifecycle phase. The markers carry no data; they
/// exist only to be used as the type parameter of
/// [`Connection`](super::Connection), so that methods are available only in
/// the phases where they make sense.
pub mod state {
    /// Initial phase: no transport has been established yet.
    #[derive(Clone, Copy, Debug)]
    pub struct Disconnected;

    /// Transport established; the initialization handshake has not begun.
    #[derive(Clone, Copy, Debug)]
    pub struct Connected;

    /// The initialization handshake is in progress.
    #[derive(Clone, Copy, Debug)]
    pub struct Initializing;

    /// Handshake finished; the connection can serve requests.
    #[derive(Clone, Copy, Debug)]
    pub struct Ready;

    /// Graceful shutdown has started.
    #[derive(Clone, Copy, Debug)]
    pub struct Closing;
}
47
48/// Internal connection data shared across states.
49#[derive(Debug)]
50pub struct ConnectionData {
51    /// Client capabilities (set after initialization).
52    pub client_capabilities: Option<ClientCapabilities>,
53    /// Server capabilities advertised.
54    pub server_capabilities: ServerCapabilities,
55    /// Server information.
56    pub server_info: ServerInfo,
57    /// Protocol version negotiated.
58    ///
59    /// Stored as a type-safe [`ProtocolVersion`] enum for feature detection.
60    pub protocol_version: Option<ProtocolVersion>,
61    /// Session ID if applicable.
62    pub session_id: Option<String>,
63}
64
65impl ConnectionData {
66    /// Create new connection data.
67    #[must_use]
68    pub const fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
69        Self {
70            client_capabilities: None,
71            server_capabilities,
72            server_info,
73            protocol_version: None,
74            session_id: None,
75        }
76    }
77}
78
79/// A typestate connection that tracks lifecycle state at the type level.
80///
81/// The state parameter `S` ensures that only valid operations are
82/// available for each connection state.
83///
84/// # Example
85///
86/// ```rust
87/// use mcpkit_server::state::{Connection, state};
88/// use mcpkit_core::capability::{ServerInfo, ServerCapabilities};
89///
90/// // Start disconnected
91/// let conn: Connection<state::Disconnected> = Connection::new(
92///     ServerInfo::new("my-server", "1.0.0"),
93///     ServerCapabilities::new().with_tools(),
94/// );
95///
96/// // The typestate pattern ensures compile-time safety:
97/// // - A Disconnected connection can only call connect()
98/// // - A Connected connection can only call initialize() or close()
99/// // - A Ready connection can access capabilities
100/// ```
101pub struct Connection<S> {
102    /// Shared connection data.
103    inner: Arc<ConnectionData>,
104    /// Phantom data to track state type.
105    _state: PhantomData<S>,
106}
107
108impl<S> Clone for Connection<S> {
109    fn clone(&self) -> Self {
110        Self {
111            inner: Arc::clone(&self.inner),
112            _state: PhantomData,
113        }
114    }
115}
116
117impl<S> std::fmt::Debug for Connection<S> {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        f.debug_struct("Connection")
120            .field("inner", &self.inner)
121            .field("state", &std::any::type_name::<S>())
122            .finish()
123    }
124}
125
126impl Connection<state::Disconnected> {
127    /// Create a new disconnected connection.
128    #[must_use]
129    pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
130        Self {
131            inner: Arc::new(ConnectionData::new(server_info, server_capabilities)),
132            _state: PhantomData,
133        }
134    }
135
136    /// Connect to establish a transport connection.
137    ///
138    /// This transitions from `Disconnected` to `Connected` state.
139    pub async fn connect(self) -> Result<Connection<state::Connected>, McpError> {
140        // In a real implementation, this would establish the transport
141        Ok(Connection {
142            inner: self.inner,
143            _state: PhantomData,
144        })
145    }
146}
147
148impl Connection<state::Connected> {
149    /// Start the initialization handshake.
150    ///
151    /// This transitions from `Connected` to `Initializing` state.
152    pub async fn initialize(
153        self,
154        _protocol_version: ProtocolVersion,
155    ) -> Result<Connection<state::Initializing>, McpError> {
156        // In a real implementation, this would send the initialize request
157        Ok(Connection {
158            inner: self.inner,
159            _state: PhantomData,
160        })
161    }
162
163    /// Close the connection before initialization.
164    pub async fn close(self) -> Result<(), McpError> {
165        // Clean up resources
166        Ok(())
167    }
168}
169
170impl Connection<state::Initializing> {
171    /// Complete the initialization handshake.
172    ///
173    /// This transitions from `Initializing` to `Ready` state.
174    pub async fn complete(
175        self,
176        client_capabilities: ClientCapabilities,
177        protocol_version: ProtocolVersion,
178    ) -> Result<Connection<state::Ready>, McpError> {
179        // Update the connection data with negotiated values
180        // In a real implementation, we'd use interior mutability
181        let mut data = ConnectionData::new(
182            self.inner.server_info.clone(),
183            self.inner.server_capabilities.clone(),
184        );
185        data.client_capabilities = Some(client_capabilities);
186        data.protocol_version = Some(protocol_version);
187
188        Ok(Connection {
189            inner: Arc::new(data),
190            _state: PhantomData,
191        })
192    }
193
194    /// Abort initialization.
195    pub async fn abort(self) -> Result<Connection<state::Disconnected>, McpError> {
196        Ok(Connection {
197            inner: self.inner,
198            _state: PhantomData,
199        })
200    }
201}
202
203impl Connection<state::Ready> {
204    /// Get the client capabilities.
205    ///
206    /// # Panics
207    ///
208    /// This should never panic if the connection was properly initialized
209    /// through the typestate transitions. Use `try_client_capabilities()`
210    /// for a fallible version.
211    #[must_use]
212    pub fn client_capabilities(&self) -> &ClientCapabilities {
213        self.inner
214            .client_capabilities
215            .as_ref()
216            .expect("Ready connection must have client capabilities")
217    }
218
219    /// Try to get the client capabilities.
220    ///
221    /// Returns `None` if capabilities were not set (should not happen in normal use).
222    #[must_use]
223    pub fn try_client_capabilities(&self) -> Option<&ClientCapabilities> {
224        self.inner.client_capabilities.as_ref()
225    }
226
227    /// Get the server capabilities.
228    #[must_use]
229    pub fn server_capabilities(&self) -> &ServerCapabilities {
230        &self.inner.server_capabilities
231    }
232
233    /// Get the server info.
234    #[must_use]
235    pub fn server_info(&self) -> &ServerInfo {
236        &self.inner.server_info
237    }
238
239    /// Get the negotiated protocol version.
240    ///
241    /// # Panics
242    ///
243    /// This should never panic if the connection was properly initialized
244    /// through the typestate transitions. Use `try_protocol_version()`
245    /// for a fallible version.
246    #[must_use]
247    pub fn protocol_version(&self) -> ProtocolVersion {
248        self.inner
249            .protocol_version
250            .expect("Ready connection must have protocol version")
251    }
252
253    /// Try to get the negotiated protocol version.
254    ///
255    /// Returns `None` if version was not set (should not happen in normal use).
256    #[must_use]
257    pub fn try_protocol_version(&self) -> Option<ProtocolVersion> {
258        self.inner.protocol_version
259    }
260
261    /// Start graceful shutdown.
262    ///
263    /// This transitions from `Ready` to `Closing` state.
264    pub async fn shutdown(self) -> Result<Connection<state::Closing>, McpError> {
265        Ok(Connection {
266            inner: self.inner,
267            _state: PhantomData,
268        })
269    }
270}
271
272impl Connection<state::Closing> {
273    /// Complete the shutdown and disconnect.
274    pub async fn disconnect(self) -> Result<(), McpError> {
275        // Clean up resources
276        Ok(())
277    }
278}
279
280/// A state machine wrapper for connections that allows runtime state tracking.
281///
282/// This provides an alternative to the pure typestate approach when
283/// runtime state inspection is needed.
284#[derive(Debug)]
285pub enum ConnectionState {
286    /// Not connected.
287    Disconnected(Connection<state::Disconnected>),
288    /// Connected but not initialized.
289    Connected(Connection<state::Connected>),
290    /// In initialization handshake.
291    Initializing(Connection<state::Initializing>),
292    /// Ready for requests.
293    Ready(Connection<state::Ready>),
294    /// Closing down.
295    Closing(Connection<state::Closing>),
296}
297
298impl ConnectionState {
299    /// Create a new disconnected connection state.
300    #[must_use]
301    pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
302        Self::Disconnected(Connection::new(server_info, server_capabilities))
303    }
304
305    /// Check if the connection is ready for requests.
306    #[must_use]
307    pub const fn is_ready(&self) -> bool {
308        matches!(self, Self::Ready(_))
309    }
310
311    /// Check if the connection is disconnected.
312    #[must_use]
313    pub const fn is_disconnected(&self) -> bool {
314        matches!(self, Self::Disconnected(_))
315    }
316
317    /// Get the current state name.
318    #[must_use]
319    pub const fn state_name(&self) -> &'static str {
320        match self {
321            Self::Disconnected(_) => "Disconnected",
322            Self::Connected(_) => "Connected",
323            Self::Initializing(_) => "Initializing",
324            Self::Ready(_) => "Ready",
325            Self::Closing(_) => "Closing",
326        }
327    }
328}
329
330/// Transition events for connection state changes.
331#[derive(Debug, Clone)]
332pub enum ConnectionEvent {
333    /// Connection established.
334    Connected,
335    /// Initialization started.
336    InitializeStarted,
337    /// Initialization completed successfully.
338    InitializeCompleted {
339        /// Negotiated protocol version.
340        protocol_version: ProtocolVersion,
341    },
342    /// Initialization failed.
343    InitializeFailed {
344        /// Error message.
345        error: String,
346    },
347    /// Shutdown requested.
348    ShutdownRequested,
349    /// Connection closed.
350    Disconnected,
351}
352
#[cfg(test)]
mod tests {
    // The parent module's imports (server/client capability types and
    // `ProtocolVersion`) are reachable through this glob import.
    use super::*;

    /// Shared fixture: the server identity and capabilities every test uses.
    fn server_parts() -> (ServerInfo, ServerCapabilities) {
        (ServerInfo::new("test", "1.0.0"), ServerCapabilities::default())
    }

    #[test]
    fn test_connection_creation() {
        let (info, caps) = server_parts();
        let conn: Connection<state::Disconnected> = Connection::new(info, caps);

        let marker = std::any::type_name_of_val(&conn._state);
        assert!(marker.contains("Disconnected"));
    }

    #[tokio::test]
    async fn test_connection_lifecycle() {
        let version = ProtocolVersion::V2025_11_25;
        let (info, caps) = server_parts();

        // Walk every forward transition in order.
        let connected = Connection::new(info, caps).connect().await.unwrap();
        let initializing = connected.initialize(version).await.unwrap();
        let ready = initializing
            .complete(ClientCapabilities::default(), version)
            .await
            .unwrap();

        assert_eq!(ready.protocol_version(), version);

        let closing = ready.shutdown().await.unwrap();
        closing.disconnect().await.unwrap();
    }

    #[test]
    fn test_connection_state_enum() {
        let (info, caps) = server_parts();
        let wrapped = ConnectionState::new(info, caps);

        assert!(wrapped.is_disconnected());
        assert!(!wrapped.is_ready());
        assert_eq!(wrapped.state_name(), "Disconnected");
    }
}