mcp_protocol_sdk/server/lifecycle.rs

//! Server lifecycle management
//!
//! This module handles the lifecycle events and state management for MCP servers,
//! including initialization, running state, graceful shutdown, and error recovery.
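//!
//! # Quick start
//!
//! A minimal sketch of driving a server through its full lifecycle. The
//! concrete transport is assumed here; pass whichever `ServerTransport`
//! implementation your build provides.
//!
//! ```ignore
//! use mcp_protocol_sdk::server::lifecycle::{LoggingLifecycleListener, ServerRunner};
//! use mcp_protocol_sdk::server::McpServer;
//!
//! let server = McpServer::new("example-server".to_string(), "1.0.0".to_string());
//! let runner = ServerRunner::new(server);
//! runner.add_listener(LoggingLifecycleListener).await;
//!
//! // `transport` is any value implementing `ServerTransport` (assumed here).
//! // Runs until CTRL+C, then shuts down gracefully.
//! runner.run_with_signals(transport).await?;
//! ```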

use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{watch, Mutex, RwLock};
use tokio::time::timeout;

use crate::core::error::{McpError, McpResult};
use crate::server::mcp_server::McpServer;

/// Server lifecycle state
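///
/// States normally advance `Created -> Starting -> Running -> Stopping -> Stopped`;
/// `Error` can be entered from any state when a transition fails.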
#[derive(Debug, Clone, PartialEq)]
pub enum LifecycleState {
    /// Server is created but not yet started
    Created,
    /// Server is starting up
    Starting,
    /// Server is running and ready to handle requests
    Running,
    /// Server is gracefully shutting down
    Stopping,
    /// Server has stopped
    Stopped,
    /// Server encountered an error
    Error(String),
}

/// Server lifecycle manager
#[derive(Clone)]
pub struct LifecycleManager {
    /// Current lifecycle state
    state: Arc<RwLock<LifecycleState>>,
    /// State change broadcaster
    state_tx: Arc<watch::Sender<LifecycleState>>,
    /// State change receiver
    state_rx: watch::Receiver<LifecycleState>,
    /// Server start time
    start_time: Arc<Mutex<Option<Instant>>>,
    /// Shutdown signal
    shutdown_tx: Arc<Mutex<Option<watch::Sender<()>>>>,
}

/// Lifecycle event listener
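///
/// Only `on_state_change` is required; the other hooks default to no-ops. A
/// minimal sketch of a custom listener (the counter field is illustrative,
/// not part of the SDK):
///
/// ```ignore
/// use std::sync::atomic::{AtomicUsize, Ordering};
///
/// struct CountingListener {
///     transitions: AtomicUsize,
/// }
///
/// impl LifecycleListener for CountingListener {
///     fn on_state_change(&self, old: LifecycleState, new: LifecycleState) {
///         // Count transitions; callbacks are sync, so keep them cheap.
///         self.transitions.fetch_add(1, Ordering::Relaxed);
///         tracing::debug!("{:?} -> {:?}", old, new);
///     }
/// }
/// ```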
pub trait LifecycleListener: Send + Sync {
    /// Called when the server state changes
    fn on_state_change(&self, old_state: LifecycleState, new_state: LifecycleState);

    /// Called when the server starts
    fn on_start(&self) {}

    /// Called when the server stops
    fn on_stop(&self) {}

    /// Called when an error occurs
    fn on_error(&self, _error: &McpError) {}
}

impl Default for LifecycleManager {
    fn default() -> Self {
        Self::new()
    }
}

impl LifecycleManager {
    /// Create a new lifecycle manager
    pub fn new() -> Self {
        let (state_tx, state_rx) = watch::channel(LifecycleState::Created);

        Self {
            state: Arc::new(RwLock::new(LifecycleState::Created)),
            state_tx: Arc::new(state_tx),
            state_rx,
            start_time: Arc::new(Mutex::new(None)),
            shutdown_tx: Arc::new(Mutex::new(None)),
        }
    }

    /// Get the current lifecycle state
    pub async fn state(&self) -> LifecycleState {
        let state = self.state.read().await;
        state.clone()
    }

    /// Subscribe to state changes
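    ///
    /// The receiver observes every change broadcast by
    /// [`transition_to`](Self::transition_to). A minimal sketch:
    ///
    /// ```ignore
    /// let manager = LifecycleManager::new();
    /// let mut rx = manager.subscribe();
    /// manager.transition_to(LifecycleState::Running).await?;
    /// rx.changed().await.ok();
    /// assert_eq!(*rx.borrow(), LifecycleState::Running);
    /// ```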
    pub fn subscribe(&self) -> watch::Receiver<LifecycleState> {
        self.state_rx.clone()
    }

    /// Transition to a new state
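    ///
    /// The change is broadcast to all [`subscribe`](Self::subscribe) receivers,
    /// and the start time is recorded on entering `Running` (and cleared on
    /// `Stopped`). Note that `LifecycleListener` callbacks are dispatched by
    /// [`ServerRunner`], not by this method.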
    pub async fn transition_to(&self, new_state: LifecycleState) -> McpResult<()> {
        {
            let mut state = self.state.write().await;
            *state = new_state.clone();
        }

        // Broadcast the state change; an error only means every receiver has
        // been dropped, which is fine.
        let _ = self.state_tx.send(new_state.clone());

        // Handle special state transitions
        match new_state {
            LifecycleState::Running => {
                let mut start_time = self.start_time.lock().await;
                *start_time = Some(Instant::now());
            }
            LifecycleState::Stopped => {
                let mut start_time = self.start_time.lock().await;
                *start_time = None;
            }
            _ => {}
        }

        Ok(())
    }

    /// Check if the server is in a running state
    pub async fn is_running(&self) -> bool {
        let state = self.state.read().await;
        matches!(*state, LifecycleState::Running)
    }

    /// Check if the server can be started
    pub async fn can_start(&self) -> bool {
        let state = self.state.read().await;
        matches!(*state, LifecycleState::Created | LifecycleState::Stopped)
    }

    /// Check if the server can be stopped
    pub async fn can_stop(&self) -> bool {
        let state = self.state.read().await;
        matches!(*state, LifecycleState::Running | LifecycleState::Starting)
    }

    /// Get server uptime
    pub async fn uptime(&self) -> Option<Duration> {
        let start_time = self.start_time.lock().await;
        start_time.map(|start| start.elapsed())
    }

    /// Create a shutdown signal
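    ///
    /// Each call installs a fresh sender, replacing any previous one, so create
    /// the receiver before arranging for
    /// [`trigger_shutdown`](Self::trigger_shutdown) to be called. A minimal
    /// sketch:
    ///
    /// ```ignore
    /// let manager = LifecycleManager::new();
    /// let mut shutdown_rx = manager.create_shutdown_signal().await;
    /// manager.trigger_shutdown().await?;
    /// shutdown_rx.changed().await.ok();
    /// ```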
    pub async fn create_shutdown_signal(&self) -> watch::Receiver<()> {
        let (tx, rx) = watch::channel(());
        let mut shutdown_tx = self.shutdown_tx.lock().await;
        *shutdown_tx = Some(tx);
        rx
    }

    /// Trigger shutdown
    pub async fn trigger_shutdown(&self) -> McpResult<()> {
        let shutdown_tx = self.shutdown_tx.lock().await;
        if let Some(tx) = shutdown_tx.as_ref() {
            let _ = tx.send(()); // Ignore error if receiver is dropped
        }
        Ok(())
    }
}

/// Server runner that manages the complete lifecycle
pub struct ServerRunner {
    /// The MCP server instance
    server: Arc<Mutex<McpServer>>,
    /// Lifecycle manager
    lifecycle: LifecycleManager,
    /// Lifecycle listeners
    listeners: Arc<RwLock<Vec<Box<dyn LifecycleListener>>>>,
}

impl ServerRunner {
    /// Create a new server runner
    pub fn new(server: McpServer) -> Self {
        Self {
            server: Arc::new(Mutex::new(server)),
            lifecycle: LifecycleManager::new(),
            listeners: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Add a lifecycle listener
    pub async fn add_listener<L>(&self, listener: L)
    where
        L: LifecycleListener + 'static,
    {
        let mut listeners = self.listeners.write().await;
        listeners.push(Box::new(listener));
    }

    /// Get the lifecycle manager
    pub fn lifecycle(&self) -> &LifecycleManager {
        &self.lifecycle
    }

    /// Start the server with a transport
    pub async fn start<T>(&self, transport: T) -> McpResult<()>
    where
        T: crate::transport::traits::ServerTransport + 'static,
    {
        // Check if we can start
        if !self.lifecycle.can_start().await {
            return Err(McpError::Protocol(
                "Server cannot be started in current state".to_string(),
            ));
        }

        // Transition to starting state
        self.lifecycle
            .transition_to(LifecycleState::Starting)
            .await?;

        // Notify listeners
        self.notify_listeners(|listener| listener.on_start()).await;

        // Start the server
        let result = {
            let mut server = self.server.lock().await;
            server.start(transport).await
        };

        match result {
            Ok(()) => {
                // Transition to running state
                self.lifecycle
                    .transition_to(LifecycleState::Running)
                    .await?;
                Ok(())
            }
            Err(err) => {
                // Transition to error state
                self.lifecycle
                    .transition_to(LifecycleState::Error(err.to_string()))
                    .await?;

                // Notify listeners
                self.notify_listeners(|listener| listener.on_error(&err))
                    .await;

                Err(err)
            }
        }
    }

    /// Stop the server gracefully
    pub async fn stop(&self) -> McpResult<()> {
        // Check if we can stop
        if !self.lifecycle.can_stop().await {
            return Err(McpError::Protocol(
                "Server cannot be stopped in current state".to_string(),
            ));
        }

        // Transition to stopping state
        self.lifecycle
            .transition_to(LifecycleState::Stopping)
            .await?;

        // Stop the server
        let result = {
            let server = self.server.lock().await;
            server.stop().await
        };

        match result {
            Ok(()) => {
                // Transition to stopped state
                self.lifecycle
                    .transition_to(LifecycleState::Stopped)
                    .await?;

                // Notify listeners
                self.notify_listeners(|listener| listener.on_stop()).await;

                Ok(())
            }
            Err(err) => {
                // Transition to error state
                self.lifecycle
                    .transition_to(LifecycleState::Error(err.to_string()))
                    .await?;

                // Notify listeners
                self.notify_listeners(|listener| listener.on_error(&err))
                    .await;

                Err(err)
            }
        }
    }

    /// Stop the server with a timeout
    pub async fn stop_with_timeout(&self, shutdown_timeout: Duration) -> McpResult<()> {
        match timeout(shutdown_timeout, self.stop()).await {
            Ok(result) => result,
            Err(_) => {
                // Timeout exceeded: the in-flight stop() future is dropped and
                // the server is marked as errored.
                self.lifecycle
                    .transition_to(LifecycleState::Error(
                        "Shutdown timeout exceeded".to_string(),
                    ))
                    .await?;
                Err(McpError::Protocol(
                    "Server shutdown timeout exceeded".to_string(),
                ))
            }
        }
    }

    /// Run the server until shutdown signal
    pub async fn run_until_shutdown<T>(&self, transport: T) -> McpResult<()>
    where
        T: crate::transport::traits::ServerTransport + 'static,
    {
        // Start the server
        self.start(transport).await?;

        // Wait for shutdown signal
        let mut shutdown_rx = self.lifecycle.create_shutdown_signal().await;
        let _ = shutdown_rx.changed().await;

        // Stop the server
        self.stop().await?;

        Ok(())
    }

    /// Run the server with graceful shutdown on CTRL+C
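    ///
    /// The graceful-shutdown window is derived from the server configuration
    /// (twice `request_timeout_ms`). A minimal sketch, assuming a transport
    /// value is already in hand:
    ///
    /// ```ignore
    /// let runner = ServerRunner::new(server);
    /// runner.run_with_signals(transport).await?;
    /// ```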
    pub async fn run_with_signals<T>(&self, transport: T) -> McpResult<()>
    where
        T: crate::transport::traits::ServerTransport + 'static,
    {
        // Start the server
        self.start(transport).await?;

        // Create the shutdown receiver before spawning the signal task, so a
        // signal that arrives early cannot be lost.
        let mut shutdown_rx = self.lifecycle.create_shutdown_signal().await;

        // Set up signal handling
        let lifecycle = self.lifecycle.clone();
        tokio::spawn(async move {
            tokio::signal::ctrl_c()
                .await
                .expect("Failed to listen for ctrl+c");
            let _ = lifecycle.trigger_shutdown().await;
        });

        // Wait for shutdown signal
        let _ = shutdown_rx.changed().await;

        // Stop the server gracefully, allowing twice the configured request
        // timeout for in-flight work to finish
        let config = {
            let server = self.server.lock().await;
            server.config().clone()
        };

        let shutdown_timeout = Duration::from_millis(config.request_timeout_ms * 2);
        self.stop_with_timeout(shutdown_timeout).await?;

        Ok(())
    }

    /// Get the server instance (for advanced usage)
    pub fn server(&self) -> Arc<Mutex<McpServer>> {
        self.server.clone()
    }

    /// Check if the server is running
    pub async fn is_running(&self) -> bool {
        self.lifecycle.is_running().await
    }

    /// Get server uptime
    pub async fn uptime(&self) -> Option<Duration> {
        self.lifecycle.uptime().await
    }

    /// Restart the server
    pub async fn restart<T>(&self, transport: T) -> McpResult<()>
    where
        T: crate::transport::traits::ServerTransport + 'static,
    {
        // Stop if running
        if self.is_running().await {
            self.stop().await?;
        }

        // Start with new transport
        self.start(transport).await?;

        Ok(())
    }

    /// Wait for the server to reach a specific state
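    ///
    /// Returns immediately if the server is already in `target_state`; pass
    /// `None` to wait without a deadline. A minimal sketch with a bounded wait:
    ///
    /// ```ignore
    /// runner
    ///     .wait_for_state(LifecycleState::Running, Some(Duration::from_secs(5)))
    ///     .await?;
    /// ```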
    pub async fn wait_for_state(
        &self,
        target_state: LifecycleState,
        timeout_duration: Option<Duration>,
    ) -> McpResult<()> {
        let mut state_rx = self.lifecycle.subscribe();

        // Check current state first
        if *state_rx.borrow() == target_state {
            return Ok(());
        }

        let wait_future = async {
            while state_rx.changed().await.is_ok() {
                if *state_rx.borrow() == target_state {
                    return Ok(());
                }
            }
            Err(McpError::Protocol(
                "State change channel closed".to_string(),
            ))
        };

        match timeout_duration {
            Some(duration) => timeout(duration, wait_future)
                .await
                .map_err(|_| McpError::Protocol("Timeout waiting for state change".to_string()))?,
            None => wait_future.await,
        }
    }

    /// Notify all listeners with a closure
    pub async fn notify_listeners<F>(&self, f: F)
    where
        F: Fn(&dyn LifecycleListener) + Send + Sync,
    {
        let listeners = self.listeners.read().await;
        for listener in listeners.iter() {
            f(listener.as_ref());
        }
    }
}

/// Default lifecycle listener that logs state changes
pub struct LoggingLifecycleListener;

impl LifecycleListener for LoggingLifecycleListener {
    fn on_state_change(&self, old_state: LifecycleState, new_state: LifecycleState) {
        tracing::info!("Server state changed: {:?} -> {:?}", old_state, new_state);
    }

    fn on_start(&self) {
        tracing::info!("Server started");
    }

    fn on_stop(&self) {
        tracing::info!("Server stopped");
    }

    fn on_error(&self, error: &McpError) {
        tracing::error!("Server error: {}", error);
    }
}

/// Health check information
#[derive(Debug, Clone)]
pub struct HealthStatus {
    /// Current lifecycle state
    pub state: LifecycleState,
    /// Server uptime
    pub uptime: Option<Duration>,
    /// Whether the server is healthy
    pub healthy: bool,
    /// Additional health details
    pub details: std::collections::HashMap<String, String>,
}

impl ServerRunner {
    /// Get health status
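    ///
    /// `healthy` is true only in the `Running` state. A minimal sketch of
    /// using this behind a health endpoint:
    ///
    /// ```ignore
    /// let health = runner.health_status().await;
    /// if !health.healthy {
    ///     eprintln!("server unhealthy: {:?}", health.state);
    /// }
    /// ```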
    pub async fn health_status(&self) -> HealthStatus {
        let state = self.lifecycle.state().await;
        let uptime = self.lifecycle.uptime().await;
        let healthy = matches!(state, LifecycleState::Running);

        let mut details = std::collections::HashMap::new();
        details.insert("state".to_string(), format!("{:?}", state));

        if let Some(uptime) = uptime {
            details.insert("uptime_seconds".to_string(), uptime.as_secs().to_string());
        }

        HealthStatus {
            state,
            uptime,
            healthy,
            details,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::McpServer;

    #[tokio::test]
    async fn test_lifecycle_manager() {
        let manager = LifecycleManager::new();

        // Initial state should be Created
        assert_eq!(manager.state().await, LifecycleState::Created);
        assert!(manager.can_start().await);
        assert!(!manager.can_stop().await);
        assert!(!manager.is_running().await);

        // Transition to Running
        manager
            .transition_to(LifecycleState::Running)
            .await
            .unwrap();
        assert_eq!(manager.state().await, LifecycleState::Running);
        assert!(!manager.can_start().await);
        assert!(manager.can_stop().await);
        assert!(manager.is_running().await);
        assert!(manager.uptime().await.is_some());

        // Transition to Stopped
        manager
            .transition_to(LifecycleState::Stopped)
            .await
            .unwrap();
        assert_eq!(manager.state().await, LifecycleState::Stopped);
        assert!(manager.can_start().await);
        assert!(!manager.can_stop().await);
        assert!(!manager.is_running().await);
        assert!(manager.uptime().await.is_none());
    }

    #[tokio::test]
    async fn test_server_runner() {
        let server = McpServer::new("test-server".to_string(), "1.0.0".to_string());
        let runner = ServerRunner::new(server);

        // Initial state
        assert!(!runner.is_running().await);
        assert_eq!(runner.lifecycle().state().await, LifecycleState::Created);

        // Add a logging listener
        runner.add_listener(LoggingLifecycleListener).await;

        // Test health status
        let health = runner.health_status().await;
        assert_eq!(health.state, LifecycleState::Created);
        assert!(!health.healthy);
    }

    #[tokio::test]
    async fn test_state_subscription() {
        let manager = LifecycleManager::new();
        let mut state_rx = manager.subscribe();

        // Initial state
        assert_eq!(*state_rx.borrow(), LifecycleState::Created);

        // Change state
        manager
            .transition_to(LifecycleState::Running)
            .await
            .unwrap();

        // Wait for change
        state_rx.changed().await.unwrap();
        assert_eq!(*state_rx.borrow(), LifecycleState::Running);
    }

    #[tokio::test]
    async fn test_shutdown_signal() {
        let manager = LifecycleManager::new();
        let mut shutdown_rx = manager.create_shutdown_signal().await;

        // Trigger shutdown
        manager.trigger_shutdown().await.unwrap();

        // Wait for signal
        shutdown_rx.changed().await.unwrap();
    }

    struct TestLifecycleListener {
        events: Arc<Mutex<Vec<String>>>,
    }

    impl TestLifecycleListener {
        fn new() -> (Self, Arc<Mutex<Vec<String>>>) {
            let events = Arc::new(Mutex::new(Vec::new()));
            let listener = Self {
                events: events.clone(),
            };
            (listener, events)
        }
    }

    impl LifecycleListener for TestLifecycleListener {
        fn on_state_change(&self, old_state: LifecycleState, new_state: LifecycleState) {
            // Listener callbacks are sync, so use try_lock; a contended lock
            // simply drops the event, which is acceptable in tests.
            if let Ok(mut events) = self.events.try_lock() {
                events.push(format!("state_change: {:?} -> {:?}", old_state, new_state));
            }
        }

        fn on_start(&self) {
            if let Ok(mut events) = self.events.try_lock() {
                events.push("start".to_string());
            }
        }

        fn on_stop(&self) {
            if let Ok(mut events) = self.events.try_lock() {
                events.push("stop".to_string());
            }
        }

        fn on_error(&self, error: &McpError) {
            if let Ok(mut events) = self.events.try_lock() {
                events.push(format!("error: {}", error));
            }
        }
    }

    #[tokio::test]
    async fn test_lifecycle_listeners() {
        let server = McpServer::new("test-server".to_string(), "1.0.0".to_string());
        let runner = ServerRunner::new(server);

        let (listener, events) = TestLifecycleListener::new();
        runner.add_listener(listener).await;

        // `transition_to` does not invoke listeners itself, so exercise them
        // directly through the runner's notify method.
        runner
            .notify_listeners(|listener| {
                listener.on_state_change(LifecycleState::Created, LifecycleState::Running);
            })
            .await;

        // Check that events were captured
        let events = events.lock().await;
        assert!(
            !events.is_empty(),
            "Expected at least one event, but got: {:?}",
            *events
        );

        // Verify the specific event was captured
        let has_state_change = events.iter().any(|event| event.contains("state_change"));
        assert!(
            has_state_change,
            "Expected state_change event, but got: {:?}",
            *events
        );
    }
}