mcpkit_core/
state.rs

1//! Typestate pattern for connection lifecycle management.
2//!
3//! This module implements the typestate pattern to enforce correct
4//! connection state transitions at compile time. This prevents runtime
5//! errors from calling methods on connections in invalid states.
6//!
7//! # Connection States
8//!
9//! ```text
10//! Disconnected -> Connected -> Initializing -> Ready -> Closing -> Disconnected
11//! ```
12//!
13//! # Example
14//!
15//! ```rust
16//! use mcpkit_core::state::{Connection, Disconnected, Connected};
17//!
18//! // Connection starts in Disconnected state
19//! let conn: Connection<Disconnected> = Connection::new();
20//!
21//! // Each state has appropriate methods
22//! let id = conn.id();
23//! assert!(!id.is_empty());
24//!
25//! // Connect to transition to Connected state
26//! let connected: Connection<Connected> = conn.connect();
27//! assert!(connected.connected_at().is_some());
28//! ```
29
30use std::marker::PhantomData;
31use std::time::{Duration, Instant};
32
33use crate::capability::{
34    ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult, ServerCapabilities,
35    ServerInfo,
36};
37use crate::error::McpError;
38use crate::protocol::RequestId;
39
/// Typestate marker: no transport is established.
///
/// This is the state a [`Connection`] is created in, and the state that
/// `disconnect`, `abort` and `close` return to.
///
/// The marker carries no data; it also implements `Default` and `Hash` so it
/// can be used freely in generic code and as a map/set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Disconnected;
43
/// Typestate marker: the transport is established, but the initialization
/// handshake has not started yet.
///
/// The marker carries no data; it also implements `Default` and `Hash` so it
/// can be used freely in generic code and as a map/set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Connected;
47
/// Typestate marker: the initialization handshake is in progress.
///
/// The marker carries no data; it also implements `Default` and `Hash` so it
/// can be used freely in generic code and as a map/set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Initializing;
51
/// Typestate marker: the handshake has completed and the connection is fully
/// operational.
///
/// The marker carries no data; it also implements `Default` and `Hash` so it
/// can be used freely in generic code and as a map/set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Ready;
55
/// Typestate marker: shutdown is in progress.
///
/// The marker carries no data; it also implements `Default` and `Hash` so it
/// can be used freely in generic code and as a map/set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Closing;
59
60/// Internal connection data shared across states.
61#[derive(Debug)]
62pub struct ConnectionInner {
63    /// Unique connection identifier.
64    pub id: String,
65    /// When the connection was established.
66    pub connected_at: Option<Instant>,
67    /// Last activity timestamp.
68    pub last_activity: Option<Instant>,
69    /// Request counter for generating IDs.
70    pub request_counter: u64,
71    /// Client info (available after initialization).
72    pub client_info: Option<ClientInfo>,
73    /// Server info (available after initialization).
74    pub server_info: Option<ServerInfo>,
75    /// Client capabilities (available after initialization).
76    pub client_capabilities: Option<ClientCapabilities>,
77    /// Server capabilities (available after initialization).
78    pub server_capabilities: Option<ServerCapabilities>,
79}
80
81impl ConnectionInner {
82    /// Create new connection inner data.
83    fn new() -> Self {
84        Self {
85            id: uuid::Uuid::new_v4().to_string(),
86            connected_at: None,
87            last_activity: None,
88            request_counter: 0,
89            client_info: None,
90            server_info: None,
91            client_capabilities: None,
92            server_capabilities: None,
93        }
94    }
95
96    /// Generate the next request ID.
97    fn next_request_id(&mut self) -> RequestId {
98        self.request_counter += 1;
99        RequestId::Number(self.request_counter)
100    }
101
102    /// Update last activity timestamp.
103    fn touch(&mut self) {
104        self.last_activity = Some(Instant::now());
105    }
106}
107
108impl Default for ConnectionInner {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114/// A connection in a specific state.
115///
116/// The type parameter `S` represents the current state of the connection.
117/// Different methods are available depending on the state.
118#[derive(Debug)]
119pub struct Connection<S> {
120    inner: ConnectionInner,
121    _state: PhantomData<S>,
122}
123
124impl Connection<Disconnected> {
125    /// Create a new disconnected connection.
126    #[must_use]
127    pub fn new() -> Self {
128        Self {
129            inner: ConnectionInner::new(),
130            _state: PhantomData,
131        }
132    }
133
134    /// Get the connection ID.
135    #[must_use]
136    pub fn id(&self) -> &str {
137        &self.inner.id
138    }
139
140    /// Establish the connection (transition to Connected state).
141    ///
142    /// In a real implementation, this would take a transport and
143    /// establish the connection. Here we just transition the state.
144    #[must_use]
145    pub fn connect(mut self) -> Connection<Connected> {
146        self.inner.connected_at = Some(Instant::now());
147        self.inner.touch();
148        Connection {
149            inner: self.inner,
150            _state: PhantomData,
151        }
152    }
153}
154
155impl Default for Connection<Disconnected> {
156    fn default() -> Self {
157        Self::new()
158    }
159}
160
161impl Connection<Connected> {
162    /// Get the connection ID.
163    #[must_use]
164    pub fn id(&self) -> &str {
165        &self.inner.id
166    }
167
168    /// Get when the connection was established.
169    #[must_use]
170    pub fn connected_at(&self) -> Option<Instant> {
171        self.inner.connected_at
172    }
173
174    /// Get how long the connection has been active.
175    #[must_use]
176    pub fn uptime(&self) -> Duration {
177        self.inner
178            .connected_at
179            .map(|t| t.elapsed())
180            .unwrap_or_default()
181    }
182
183    /// Begin initialization (transition to Initializing state).
184    ///
185    /// For clients: Send initialize request with client info and capabilities.
186    /// For servers: This is called when receiving an initialize request.
187    pub fn initialize(
188        mut self,
189        client_info: ClientInfo,
190        client_capabilities: ClientCapabilities,
191    ) -> (Connection<Initializing>, InitializeRequest) {
192        self.inner.client_info = Some(client_info.clone());
193        self.inner.client_capabilities = Some(client_capabilities.clone());
194        self.inner.touch();
195
196        let request = InitializeRequest::new(client_info, client_capabilities);
197
198        (
199            Connection {
200                inner: self.inner,
201                _state: PhantomData,
202            },
203            request,
204        )
205    }
206
207    /// Disconnect (transition back to Disconnected state).
208    #[must_use]
209    pub fn disconnect(self) -> Connection<Disconnected> {
210        Connection {
211            inner: ConnectionInner::new(),
212            _state: PhantomData,
213        }
214    }
215}
216
217impl Connection<Initializing> {
218    /// Get the connection ID.
219    #[must_use]
220    pub fn id(&self) -> &str {
221        &self.inner.id
222    }
223
224    /// Get the client info.
225    #[must_use]
226    pub fn client_info(&self) -> Option<&ClientInfo> {
227        self.inner.client_info.as_ref()
228    }
229
230    /// Get the client capabilities.
231    #[must_use]
232    pub fn client_capabilities(&self) -> Option<&ClientCapabilities> {
233        self.inner.client_capabilities.as_ref()
234    }
235
236    /// Complete initialization (transition to Ready state).
237    ///
238    /// This is called after the initialize response is received (client)
239    /// or sent (server).
240    pub fn complete(
241        mut self,
242        server_info: ServerInfo,
243        server_capabilities: ServerCapabilities,
244    ) -> Connection<Ready> {
245        self.inner.server_info = Some(server_info);
246        self.inner.server_capabilities = Some(server_capabilities);
247        self.inner.touch();
248
249        Connection {
250            inner: self.inner,
251            _state: PhantomData,
252        }
253    }
254
255    /// Abort initialization (transition back to Disconnected).
256    #[must_use]
257    pub fn abort(self) -> Connection<Disconnected> {
258        Connection {
259            inner: ConnectionInner::new(),
260            _state: PhantomData,
261        }
262    }
263}
264
265impl Connection<Ready> {
266    /// Get the connection ID.
267    #[must_use]
268    pub fn id(&self) -> &str {
269        &self.inner.id
270    }
271
272    /// Get when the connection was established.
273    #[must_use]
274    pub fn connected_at(&self) -> Option<Instant> {
275        self.inner.connected_at
276    }
277
278    /// Get how long the connection has been active.
279    #[must_use]
280    pub fn uptime(&self) -> Duration {
281        self.inner
282            .connected_at
283            .map(|t| t.elapsed())
284            .unwrap_or_default()
285    }
286
287    /// Get the last activity timestamp.
288    #[must_use]
289    pub fn last_activity(&self) -> Option<Instant> {
290        self.inner.last_activity
291    }
292
293    /// Get the client info.
294    ///
295    /// # Panics
296    ///
297    /// This should never panic if the connection was properly initialized,
298    /// as the typestate pattern ensures this is only callable in Ready state.
299    /// Use `try_client_info()` for a fallible version.
300    #[must_use]
301    pub fn client_info(&self) -> &ClientInfo {
302        self.inner.client_info.as_ref().expect("client_info should be set in Ready state")
303    }
304
305    /// Try to get the client info.
306    ///
307    /// Returns `None` if the client info was not set (should not happen in normal use).
308    #[must_use]
309    pub fn try_client_info(&self) -> Option<&ClientInfo> {
310        self.inner.client_info.as_ref()
311    }
312
313    /// Get the server info.
314    ///
315    /// # Panics
316    ///
317    /// This should never panic if the connection was properly initialized,
318    /// as the typestate pattern ensures this is only callable in Ready state.
319    /// Use `try_server_info()` for a fallible version.
320    #[must_use]
321    pub fn server_info(&self) -> &ServerInfo {
322        self.inner.server_info.as_ref().expect("server_info should be set in Ready state")
323    }
324
325    /// Try to get the server info.
326    ///
327    /// Returns `None` if the server info was not set (should not happen in normal use).
328    #[must_use]
329    pub fn try_server_info(&self) -> Option<&ServerInfo> {
330        self.inner.server_info.as_ref()
331    }
332
333    /// Get the client capabilities.
334    ///
335    /// # Panics
336    ///
337    /// This should never panic if the connection was properly initialized,
338    /// as the typestate pattern ensures this is only callable in Ready state.
339    /// Use `try_client_capabilities()` for a fallible version.
340    #[must_use]
341    pub fn client_capabilities(&self) -> &ClientCapabilities {
342        self.inner.client_capabilities.as_ref().expect("client_capabilities should be set in Ready state")
343    }
344
345    /// Try to get the client capabilities.
346    ///
347    /// Returns `None` if the client capabilities were not set (should not happen in normal use).
348    #[must_use]
349    pub fn try_client_capabilities(&self) -> Option<&ClientCapabilities> {
350        self.inner.client_capabilities.as_ref()
351    }
352
353    /// Get the server capabilities.
354    ///
355    /// # Panics
356    ///
357    /// This should never panic if the connection was properly initialized,
358    /// as the typestate pattern ensures this is only callable in Ready state.
359    /// Use `try_server_capabilities()` for a fallible version.
360    #[must_use]
361    pub fn server_capabilities(&self) -> &ServerCapabilities {
362        self.inner.server_capabilities.as_ref().expect("server_capabilities should be set in Ready state")
363    }
364
365    /// Try to get the server capabilities.
366    ///
367    /// Returns `None` if the server capabilities were not set (should not happen in normal use).
368    #[must_use]
369    pub fn try_server_capabilities(&self) -> Option<&ServerCapabilities> {
370        self.inner.server_capabilities.as_ref()
371    }
372
373    /// Generate the next request ID.
374    pub fn next_request_id(&mut self) -> RequestId {
375        self.inner.next_request_id()
376    }
377
378    /// Update the last activity timestamp.
379    pub fn touch(&mut self) {
380        self.inner.touch();
381    }
382
383    /// Check if the connection has been idle for longer than the given duration.
384    #[must_use]
385    pub fn is_idle(&self, timeout: Duration) -> bool {
386        self.inner
387            .last_activity
388            .map(|t| t.elapsed() > timeout)
389            .unwrap_or(false)
390    }
391
392    /// Begin shutdown (transition to Closing state).
393    #[must_use]
394    pub fn shutdown(self) -> Connection<Closing> {
395        Connection {
396            inner: self.inner,
397            _state: PhantomData,
398        }
399    }
400}
401
402impl Connection<Closing> {
403    /// Get the connection ID.
404    #[must_use]
405    pub fn id(&self) -> &str {
406        &self.inner.id
407    }
408
409    /// Complete the shutdown (transition to Disconnected state).
410    #[must_use]
411    pub fn close(self) -> Connection<Disconnected> {
412        Connection {
413            inner: ConnectionInner::new(),
414            _state: PhantomData,
415        }
416    }
417}
418
419/// Builder for creating initialize results (used by servers).
420pub struct InitializeResultBuilder {
421    server_info: ServerInfo,
422    capabilities: ServerCapabilities,
423    instructions: Option<String>,
424}
425
426impl InitializeResultBuilder {
427    /// Create a new builder with server info.
428    #[must_use]
429    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
430        Self {
431            server_info: ServerInfo::new(name, version),
432            capabilities: ServerCapabilities::new(),
433            instructions: None,
434        }
435    }
436
437    /// Set the capabilities.
438    #[must_use]
439    pub fn capabilities(mut self, caps: ServerCapabilities) -> Self {
440        self.capabilities = caps;
441        self
442    }
443
444    /// Enable tool support.
445    #[must_use]
446    pub fn with_tools(mut self) -> Self {
447        self.capabilities = self.capabilities.with_tools();
448        self
449    }
450
451    /// Enable resource support.
452    #[must_use]
453    pub fn with_resources(mut self) -> Self {
454        self.capabilities = self.capabilities.with_resources();
455        self
456    }
457
458    /// Enable prompt support.
459    #[must_use]
460    pub fn with_prompts(mut self) -> Self {
461        self.capabilities = self.capabilities.with_prompts();
462        self
463    }
464
465    /// Enable task support.
466    #[must_use]
467    pub fn with_tasks(mut self) -> Self {
468        self.capabilities = self.capabilities.with_tasks();
469        self
470    }
471
472    /// Set instructions.
473    #[must_use]
474    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
475        self.instructions = Some(instructions.into());
476        self
477    }
478
479    /// Build the initialize result.
480    #[must_use]
481    pub fn build(self) -> InitializeResult {
482        let mut result = InitializeResult::new(self.server_info, self.capabilities);
483        if let Some(instructions) = self.instructions {
484            result = result.instructions(instructions);
485        }
486        result
487    }
488}
489
490/// Validate that a connection can transition to the ready state.
491pub fn validate_initialization(
492    _client_caps: &ClientCapabilities,
493    _server_caps: &ServerCapabilities,
494) -> Result<(), McpError> {
495    // For now, just return Ok. In a real implementation, you might
496    // check for required capability combinations.
497    Ok(())
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks one connection through every state in order.
    #[test]
    fn test_connection_lifecycle() {
        // Start disconnected
        let conn: Connection<Disconnected> = Connection::new();
        assert!(!conn.id().is_empty());

        // Connect
        let conn: Connection<Connected> = conn.connect();
        assert!(conn.connected_at().is_some());

        // Initialize (the generated request itself is not inspected here)
        let client = ClientInfo::new("test", "1.0.0");
        let caps = ClientCapabilities::new();
        let (conn, _request): (Connection<Initializing>, _) = conn.initialize(client, caps);
        assert!(conn.client_info().is_some());

        // Complete
        let server = ServerInfo::new("server", "1.0.0");
        let server_caps = ServerCapabilities::new().with_tools();
        let mut conn: Connection<Ready> = conn.complete(server, server_caps);
        assert!(conn.server_capabilities().has_tools());

        // Generate request IDs
        let id1 = conn.next_request_id();
        let id2 = conn.next_request_id();
        assert_ne!(id1, id2);

        // Shutdown
        let conn: Connection<Closing> = conn.shutdown();
        let _conn: Connection<Disconnected> = conn.close();
    }

    /// Uptime must be at least as long as the time slept after connecting.
    #[test]
    fn test_uptime() {
        let conn = Connection::new().connect();
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(conn.uptime() >= std::time::Duration::from_millis(10));
    }

    /// A connection that has just become ready is not idle.
    #[test]
    fn test_idle_detection() {
        let client = ClientInfo::new("test", "1.0.0");
        let server = ServerInfo::new("server", "1.0.0");

        let (conn, _) = Connection::new()
            .connect()
            .initialize(client, ClientCapabilities::new());

        let conn = conn.complete(server, ServerCapabilities::new());

        // Should not be idle immediately
        assert!(!conn.is_idle(Duration::from_secs(1)));
    }

    /// The builder carries name, capabilities and instructions into the result.
    #[test]
    fn test_initialize_result_builder() {
        let result = InitializeResultBuilder::new("my-server", "1.0.0")
            .with_tools()
            .with_resources()
            .instructions("Use this server to access tools and resources")
            .build();

        assert_eq!(result.server_info.name, "my-server");
        assert!(result.capabilities.has_tools());
        assert!(result.capabilities.has_resources());
        assert!(result.instructions.is_some());
    }

    /// Aborting during initialization yields a disconnected connection.
    #[test]
    fn test_abort_initialization() {
        let client = ClientInfo::new("test", "1.0.0");
        let (conn, _) = Connection::new()
            .connect()
            .initialize(client, ClientCapabilities::new());

        // Abort should return to disconnected
        let _conn: Connection<Disconnected> = conn.abort();
    }

    /// Disconnecting from the connected state yields a disconnected connection.
    #[test]
    fn test_disconnect_from_connected() {
        let conn = Connection::new().connect();
        let _conn: Connection<Disconnected> = conn.disconnect();
    }

    /// The `try_*` accessors return the values supplied during the handshake.
    #[test]
    fn test_fallible_accessors() {
        let client = ClientInfo::new("test-client", "1.0.0");
        let server = ServerInfo::new("test-server", "2.0.0");
        let client_caps = ClientCapabilities::new();
        let server_caps = ServerCapabilities::new().with_tools();

        // The inputs are not needed afterwards, so they are moved, not cloned.
        let (conn, _) = Connection::new()
            .connect()
            .initialize(client, client_caps);

        let conn = conn.complete(server, server_caps);

        // Test fallible accessors return Some
        assert!(conn.try_client_info().is_some());
        assert!(conn.try_server_info().is_some());
        assert!(conn.try_client_capabilities().is_some());
        assert!(conn.try_server_capabilities().is_some());

        // Test values are correct
        assert_eq!(conn.try_client_info().unwrap().name, "test-client");
        assert_eq!(conn.try_server_info().unwrap().name, "test-server");
        assert!(conn.try_server_capabilities().unwrap().has_tools());
    }
}