mcpkit_core/
state.rs

1//! Typestate pattern for connection lifecycle management.
2//!
3//! This module implements the typestate pattern to enforce correct
4//! connection state transitions at compile time. This prevents runtime
5//! errors from calling methods on connections in invalid states.
6//!
7//! # Connection States
8//!
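//! ```text
//! Disconnected -> Connected -> Initializing -> Ready -> Closing -> Disconnected
//! ```
//!
//! Two shortcuts also lead back to `Disconnected`: `disconnect` from `Connected`
//! and `abort` from `Initializing`. Every transition into `Disconnected` discards
//! the previous connection data, so the resulting connection has a new ID.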
12//!
13//! # Example
14//!
15//! ```rust
16//! use mcpkit_core::state::{Connection, Disconnected, Connected};
17//!
18//! // Connection starts in Disconnected state
19//! let conn: Connection<Disconnected> = Connection::new();
20//!
21//! // Each state has appropriate methods
22//! let id = conn.id();
23//! assert!(!id.is_empty());
24//!
25//! // Connect to transition to Connected state
26//! let connected: Connection<Connected> = conn.connect();
27//! assert!(connected.connected_at().is_some());
28//! ```
29
30use std::marker::PhantomData;
31use std::time::{Duration, Instant};
32
33use crate::capability::{
34    ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult, ServerCapabilities,
35    ServerInfo,
36};
37use crate::error::McpError;
38use crate::protocol::RequestId;
39
/// Typestate marker: no transport is attached to the connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Disconnected;

/// Typestate marker: the transport is established but no handshake has begun.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Connected;

/// Typestate marker: the initialize handshake is under way.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Initializing;

/// Typestate marker: the handshake finished and the connection is fully operational.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Ready;

/// Typestate marker: a shutdown has been started but not yet completed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Closing;
59
60/// Internal connection data shared across states.
61#[derive(Debug)]
62pub struct ConnectionInner {
63    /// Unique connection identifier.
64    pub id: String,
65    /// When the connection was established.
66    pub connected_at: Option<Instant>,
67    /// Last activity timestamp.
68    pub last_activity: Option<Instant>,
69    /// Request counter for generating IDs.
70    pub request_counter: u64,
71    /// Client info (available after initialization).
72    pub client_info: Option<ClientInfo>,
73    /// Server info (available after initialization).
74    pub server_info: Option<ServerInfo>,
75    /// Client capabilities (available after initialization).
76    pub client_capabilities: Option<ClientCapabilities>,
77    /// Server capabilities (available after initialization).
78    pub server_capabilities: Option<ServerCapabilities>,
79}
80
81impl ConnectionInner {
82    /// Create new connection inner data.
83    fn new() -> Self {
84        Self {
85            id: uuid::Uuid::new_v4().to_string(),
86            connected_at: None,
87            last_activity: None,
88            request_counter: 0,
89            client_info: None,
90            server_info: None,
91            client_capabilities: None,
92            server_capabilities: None,
93        }
94    }
95
96    /// Generate the next request ID.
97    fn next_request_id(&mut self) -> RequestId {
98        self.request_counter += 1;
99        RequestId::Number(self.request_counter)
100    }
101
102    /// Update last activity timestamp.
103    fn touch(&mut self) {
104        self.last_activity = Some(Instant::now());
105    }
106}
107
108impl Default for ConnectionInner {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114/// A connection in a specific state.
115///
116/// The type parameter `S` represents the current state of the connection.
117/// Different methods are available depending on the state.
118#[derive(Debug)]
119pub struct Connection<S> {
120    inner: ConnectionInner,
121    _state: PhantomData<S>,
122}
123
124impl Connection<Disconnected> {
125    /// Create a new disconnected connection.
126    #[must_use]
127    pub fn new() -> Self {
128        Self {
129            inner: ConnectionInner::new(),
130            _state: PhantomData,
131        }
132    }
133
134    /// Get the connection ID.
135    #[must_use]
136    pub fn id(&self) -> &str {
137        &self.inner.id
138    }
139
140    /// Establish the connection (transition to Connected state).
141    ///
142    /// In a real implementation, this would take a transport and
143    /// establish the connection. Here we just transition the state.
144    #[must_use]
145    pub fn connect(mut self) -> Connection<Connected> {
146        self.inner.connected_at = Some(Instant::now());
147        self.inner.touch();
148        Connection {
149            inner: self.inner,
150            _state: PhantomData,
151        }
152    }
153}
154
155impl Default for Connection<Disconnected> {
156    fn default() -> Self {
157        Self::new()
158    }
159}
160
161impl Connection<Connected> {
162    /// Get the connection ID.
163    #[must_use]
164    pub fn id(&self) -> &str {
165        &self.inner.id
166    }
167
168    /// Get when the connection was established.
169    #[must_use]
170    pub const fn connected_at(&self) -> Option<Instant> {
171        self.inner.connected_at
172    }
173
174    /// Get how long the connection has been active.
175    #[must_use]
176    pub fn uptime(&self) -> Duration {
177        self.inner
178            .connected_at
179            .map(|t| t.elapsed())
180            .unwrap_or_default()
181    }
182
183    /// Begin initialization (transition to Initializing state).
184    ///
185    /// For clients: Send initialize request with client info and capabilities.
186    /// For servers: This is called when receiving an initialize request.
187    #[must_use]
188    pub fn initialize(
189        mut self,
190        client_info: ClientInfo,
191        client_capabilities: ClientCapabilities,
192    ) -> (Connection<Initializing>, InitializeRequest) {
193        self.inner.client_info = Some(client_info.clone());
194        self.inner.client_capabilities = Some(client_capabilities.clone());
195        self.inner.touch();
196
197        let request = InitializeRequest::new(client_info, client_capabilities);
198
199        (
200            Connection {
201                inner: self.inner,
202                _state: PhantomData,
203            },
204            request,
205        )
206    }
207
208    /// Disconnect (transition back to Disconnected state).
209    #[must_use]
210    pub fn disconnect(self) -> Connection<Disconnected> {
211        Connection {
212            inner: ConnectionInner::new(),
213            _state: PhantomData,
214        }
215    }
216}
217
218impl Connection<Initializing> {
219    /// Get the connection ID.
220    #[must_use]
221    pub fn id(&self) -> &str {
222        &self.inner.id
223    }
224
225    /// Get the client info.
226    #[must_use]
227    pub const fn client_info(&self) -> Option<&ClientInfo> {
228        self.inner.client_info.as_ref()
229    }
230
231    /// Get the client capabilities.
232    #[must_use]
233    pub const fn client_capabilities(&self) -> Option<&ClientCapabilities> {
234        self.inner.client_capabilities.as_ref()
235    }
236
237    /// Complete initialization (transition to Ready state).
238    ///
239    /// This is called after the initialize response is received (client)
240    /// or sent (server).
241    #[must_use]
242    pub fn complete(
243        mut self,
244        server_info: ServerInfo,
245        server_capabilities: ServerCapabilities,
246    ) -> Connection<Ready> {
247        self.inner.server_info = Some(server_info);
248        self.inner.server_capabilities = Some(server_capabilities);
249        self.inner.touch();
250
251        Connection {
252            inner: self.inner,
253            _state: PhantomData,
254        }
255    }
256
257    /// Abort initialization (transition back to Disconnected).
258    #[must_use]
259    pub fn abort(self) -> Connection<Disconnected> {
260        Connection {
261            inner: ConnectionInner::new(),
262            _state: PhantomData,
263        }
264    }
265}
266
267impl Connection<Ready> {
268    /// Get the connection ID.
269    #[must_use]
270    pub fn id(&self) -> &str {
271        &self.inner.id
272    }
273
274    /// Get when the connection was established.
275    #[must_use]
276    pub const fn connected_at(&self) -> Option<Instant> {
277        self.inner.connected_at
278    }
279
280    /// Get how long the connection has been active.
281    #[must_use]
282    pub fn uptime(&self) -> Duration {
283        self.inner
284            .connected_at
285            .map(|t| t.elapsed())
286            .unwrap_or_default()
287    }
288
289    /// Get the last activity timestamp.
290    #[must_use]
291    pub const fn last_activity(&self) -> Option<Instant> {
292        self.inner.last_activity
293    }
294
295    /// Get the client info.
296    ///
297    /// # Panics
298    ///
299    /// This should never panic if the connection was properly initialized,
300    /// as the typestate pattern ensures this is only callable in Ready state.
301    /// Use `try_client_info()` for a fallible version.
302    #[must_use]
303    pub fn client_info(&self) -> &ClientInfo {
304        self.inner
305            .client_info
306            .as_ref()
307            .expect("client_info should be set in Ready state")
308    }
309
310    /// Try to get the client info.
311    ///
312    /// Returns `None` if the client info was not set (should not happen in normal use).
313    #[must_use]
314    pub const fn try_client_info(&self) -> Option<&ClientInfo> {
315        self.inner.client_info.as_ref()
316    }
317
318    /// Get the server info.
319    ///
320    /// # Panics
321    ///
322    /// This should never panic if the connection was properly initialized,
323    /// as the typestate pattern ensures this is only callable in Ready state.
324    /// Use `try_server_info()` for a fallible version.
325    #[must_use]
326    pub fn server_info(&self) -> &ServerInfo {
327        self.inner
328            .server_info
329            .as_ref()
330            .expect("server_info should be set in Ready state")
331    }
332
333    /// Try to get the server info.
334    ///
335    /// Returns `None` if the server info was not set (should not happen in normal use).
336    #[must_use]
337    pub const fn try_server_info(&self) -> Option<&ServerInfo> {
338        self.inner.server_info.as_ref()
339    }
340
341    /// Get the client capabilities.
342    ///
343    /// # Panics
344    ///
345    /// This should never panic if the connection was properly initialized,
346    /// as the typestate pattern ensures this is only callable in Ready state.
347    /// Use `try_client_capabilities()` for a fallible version.
348    #[must_use]
349    pub fn client_capabilities(&self) -> &ClientCapabilities {
350        self.inner
351            .client_capabilities
352            .as_ref()
353            .expect("client_capabilities should be set in Ready state")
354    }
355
356    /// Try to get the client capabilities.
357    ///
358    /// Returns `None` if the client capabilities were not set (should not happen in normal use).
359    #[must_use]
360    pub const fn try_client_capabilities(&self) -> Option<&ClientCapabilities> {
361        self.inner.client_capabilities.as_ref()
362    }
363
364    /// Get the server capabilities.
365    ///
366    /// # Panics
367    ///
368    /// This should never panic if the connection was properly initialized,
369    /// as the typestate pattern ensures this is only callable in Ready state.
370    /// Use `try_server_capabilities()` for a fallible version.
371    #[must_use]
372    pub fn server_capabilities(&self) -> &ServerCapabilities {
373        self.inner
374            .server_capabilities
375            .as_ref()
376            .expect("server_capabilities should be set in Ready state")
377    }
378
379    /// Try to get the server capabilities.
380    ///
381    /// Returns `None` if the server capabilities were not set (should not happen in normal use).
382    #[must_use]
383    pub const fn try_server_capabilities(&self) -> Option<&ServerCapabilities> {
384        self.inner.server_capabilities.as_ref()
385    }
386
387    /// Generate the next request ID.
388    pub fn next_request_id(&mut self) -> RequestId {
389        self.inner.next_request_id()
390    }
391
392    /// Update the last activity timestamp.
393    pub fn touch(&mut self) {
394        self.inner.touch();
395    }
396
397    /// Check if the connection has been idle for longer than the given duration.
398    #[must_use]
399    pub fn is_idle(&self, timeout: Duration) -> bool {
400        self.inner
401            .last_activity
402            .is_some_and(|t| t.elapsed() > timeout)
403    }
404
405    /// Begin shutdown (transition to Closing state).
406    #[must_use]
407    pub fn shutdown(self) -> Connection<Closing> {
408        Connection {
409            inner: self.inner,
410            _state: PhantomData,
411        }
412    }
413}
414
415impl Connection<Closing> {
416    /// Get the connection ID.
417    #[must_use]
418    pub fn id(&self) -> &str {
419        &self.inner.id
420    }
421
422    /// Complete the shutdown (transition to Disconnected state).
423    #[must_use]
424    pub fn close(self) -> Connection<Disconnected> {
425        Connection {
426            inner: ConnectionInner::new(),
427            _state: PhantomData,
428        }
429    }
430}
431
432/// Builder for creating initialize results (used by servers).
433pub struct InitializeResultBuilder {
434    server_info: ServerInfo,
435    capabilities: ServerCapabilities,
436    instructions: Option<String>,
437}
438
439impl InitializeResultBuilder {
440    /// Create a new builder with server info.
441    #[must_use]
442    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
443        Self {
444            server_info: ServerInfo::new(name, version),
445            capabilities: ServerCapabilities::new(),
446            instructions: None,
447        }
448    }
449
450    /// Set the capabilities.
451    #[must_use]
452    pub fn capabilities(mut self, caps: ServerCapabilities) -> Self {
453        self.capabilities = caps;
454        self
455    }
456
457    /// Enable tool support.
458    #[must_use]
459    pub fn with_tools(mut self) -> Self {
460        self.capabilities = self.capabilities.with_tools();
461        self
462    }
463
464    /// Enable resource support.
465    #[must_use]
466    pub fn with_resources(mut self) -> Self {
467        self.capabilities = self.capabilities.with_resources();
468        self
469    }
470
471    /// Enable prompt support.
472    #[must_use]
473    pub fn with_prompts(mut self) -> Self {
474        self.capabilities = self.capabilities.with_prompts();
475        self
476    }
477
478    /// Enable task support.
479    #[must_use]
480    pub fn with_tasks(mut self) -> Self {
481        self.capabilities = self.capabilities.with_tasks();
482        self
483    }
484
485    /// Set instructions.
486    #[must_use]
487    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
488        self.instructions = Some(instructions.into());
489        self
490    }
491
492    /// Build the initialize result.
493    #[must_use]
494    pub fn build(self) -> InitializeResult {
495        let mut result = InitializeResult::new(self.server_info, self.capabilities);
496        if let Some(instructions) = self.instructions {
497            result = result.instructions(instructions);
498        }
499        result
500    }
501}
502
503/// Validate that a connection can transition to the ready state.
504pub const fn validate_initialization(
505    _client_caps: &ClientCapabilities,
506    _server_caps: &ServerCapabilities,
507) -> Result<(), McpError> {
508    // For now, just return Ok. In a real implementation, you might
509    // check for required capability combinations.
510    Ok(())
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_lifecycle() {
        // Start disconnected
        let conn: Connection<Disconnected> = Connection::new();
        assert!(!conn.id().is_empty());
        let original_id = conn.id().to_owned();

        // Connect
        let conn: Connection<Connected> = conn.connect();
        assert!(conn.connected_at().is_some());

        // Initialize
        let client = ClientInfo::new("test", "1.0.0");
        let caps = ClientCapabilities::new();
        let (conn, _request): (Connection<Initializing>, _) = conn.initialize(client, caps);
        assert!(conn.client_info().is_some());

        // Complete
        let server = ServerInfo::new("server", "1.0.0");
        let server_caps = ServerCapabilities::new().with_tools();
        let mut conn: Connection<Ready> = conn.complete(server, server_caps);
        assert!(conn.server_capabilities().has_tools());

        // Forward transitions carry the same connection data (and ID) along.
        assert_eq!(conn.id(), original_id.as_str());

        // Request IDs are sequential, starting at 1.
        let id1 = conn.next_request_id();
        let id2 = conn.next_request_id();
        assert_ne!(id1, id2);
        assert_eq!(id1, RequestId::Number(1));
        assert_eq!(id2, RequestId::Number(2));

        // Shutdown keeps the connection data...
        let conn: Connection<Closing> = conn.shutdown();
        assert_eq!(conn.id(), original_id.as_str());

        // ...while closing discards it and starts over with a fresh ID.
        let conn: Connection<Disconnected> = conn.close();
        assert_ne!(conn.id(), original_id.as_str());
    }

    #[test]
    fn test_uptime() {
        let conn = Connection::new().connect();
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(conn.uptime() >= std::time::Duration::from_millis(10));
    }

    #[test]
    fn test_idle_detection() {
        let client = ClientInfo::new("test", "1.0.0");
        let server = ServerInfo::new("server", "1.0.0");

        let (conn, _) = Connection::new()
            .connect()
            .initialize(client, ClientCapabilities::new());

        let mut conn = conn.complete(server, ServerCapabilities::new());

        // Should not be idle immediately
        assert!(!conn.is_idle(Duration::from_secs(1)));

        // After pausing longer than the timeout, the connection reports idle...
        std::thread::sleep(Duration::from_millis(20));
        assert!(conn.is_idle(Duration::from_millis(5)));

        // ...and recording activity resets the idle clock.
        conn.touch();
        assert!(!conn.is_idle(Duration::from_secs(1)));
    }

    #[test]
    fn test_initialize_result_builder() {
        let result = InitializeResultBuilder::new("my-server", "1.0.0")
            .with_tools()
            .with_resources()
            .instructions("Use this server to access tools and resources")
            .build();

        assert_eq!(result.server_info.name, "my-server");
        assert!(result.capabilities.has_tools());
        assert!(result.capabilities.has_resources());
        assert!(result.instructions.is_some());
    }

    #[test]
    fn test_abort_initialization() {
        let client = ClientInfo::new("test", "1.0.0");
        let (conn, _) = Connection::new()
            .connect()
            .initialize(client, ClientCapabilities::new());
        let initializing_id = conn.id().to_owned();

        // Abort should return to disconnected with fresh connection data
        let conn: Connection<Disconnected> = conn.abort();
        assert_ne!(conn.id(), initializing_id.as_str());
    }

    #[test]
    fn test_disconnect_from_connected() {
        let conn = Connection::new().connect();
        let connected_id = conn.id().to_owned();

        // Disconnect should return to disconnected with fresh connection data
        let conn: Connection<Disconnected> = conn.disconnect();
        assert_ne!(conn.id(), connected_id.as_str());
    }

    #[test]
    fn test_fallible_accessors() -> Result<(), Box<dyn std::error::Error>> {
        let client = ClientInfo::new("test-client", "1.0.0");
        let server = ServerInfo::new("test-server", "2.0.0");
        let client_caps = ClientCapabilities::new();
        let server_caps = ServerCapabilities::new().with_tools();

        let (conn, _) = Connection::new().connect().initialize(client, client_caps);

        let conn = conn.complete(server, server_caps);

        // Test fallible accessors return Some
        assert!(conn.try_client_info().is_some());
        assert!(conn.try_server_info().is_some());
        assert!(conn.try_client_capabilities().is_some());
        assert!(conn.try_server_capabilities().is_some());

        // Test values are correct
        assert_eq!(
            conn.try_client_info().ok_or("Expected client info")?.name,
            "test-client"
        );
        assert_eq!(
            conn.try_server_info().ok_or("Expected server info")?.name,
            "test-server"
        );
        assert!(
            conn.try_server_capabilities()
                .ok_or("Expected server capabilities")?
                .has_tools()
        );
        Ok(())
    }
}