// ricecoder_mcp/lifecycle.rs

//! Server lifecycle management for MCP servers
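//!
//! A minimal usage sketch (illustrative only; the crate path is assumed from the file
//! location, and the `MCPServerConfig` fields mirror the test helper in this module):
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! use ricecoder_mcp::config::MCPServerConfig;
//! use ricecoder_mcp::health_check::HealthChecker;
//! use ricecoder_mcp::lifecycle::ServerLifecycle;
//!
//! let config = MCPServerConfig {
//!     id: "server1".to_string(),
//!     name: "Example Server".to_string(),
//!     command: "example-mcp".to_string(),
//!     args: vec![],
//!     env: std::collections::HashMap::new(),
//!     timeout_ms: 5_000,
//!     auto_reconnect: true,
//!     max_retries: 3,
//! };
//!
//! let lifecycle = ServerLifecycle::new(config, Arc::new(HealthChecker::new()));
//! lifecycle.start(None).await?;
//! assert!(lifecycle.check_health().await?);
//! lifecycle.shutdown().await?;
//! ```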

use crate::config::MCPServerConfig;
use crate::error::{Error, Result};
use crate::health_check::HealthChecker;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::timeout;
use tracing::{debug, error, info};

/// Server lifecycle state
///
/// Typical transitions: `Stopped` -> `Starting` -> `Running` -> `Stopping` -> `Stopped`,
/// with `Failed` entered when startup fails or times out.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerState {
    Stopped,
    Starting,
    Running,
    Stopping,
    Failed,
}

/// Server lifecycle information
#[derive(Debug, Clone)]
pub struct ServerLifecycleInfo {
    pub server_id: String,
    pub state: ServerState,
    pub started_at: Option<std::time::Instant>,
    pub stopped_at: Option<std::time::Instant>,
    pub restart_count: u32,
    pub last_error: Option<String>,
}

impl ServerLifecycleInfo {
    /// Creates a new server lifecycle info
    pub fn new(server_id: String) -> Self {
        Self {
            server_id,
            state: ServerState::Stopped,
            started_at: None,
            stopped_at: None,
            restart_count: 0,
            last_error: None,
        }
    }

    /// Gets the uptime in milliseconds
    pub fn uptime_ms(&self) -> Option<u128> {
        self.started_at.map(|start| start.elapsed().as_millis())
    }

    /// Checks if the server is running
    pub fn is_running(&self) -> bool {
        self.state == ServerState::Running
    }

    /// Checks if the server has failed
    pub fn has_failed(&self) -> bool {
        self.state == ServerState::Failed
    }
}

/// Server lifecycle manager
#[derive(Debug, Clone)]
pub struct ServerLifecycle {
    config: Arc<MCPServerConfig>,
    health_checker: Arc<HealthChecker>,
    lifecycle_info: Arc<RwLock<ServerLifecycleInfo>>,
}

impl ServerLifecycle {
    /// Creates a new server lifecycle manager
    pub fn new(config: MCPServerConfig, health_checker: Arc<HealthChecker>) -> Self {
        let lifecycle_info = Arc::new(RwLock::new(ServerLifecycleInfo::new(config.id.clone())));
        Self {
            config: Arc::new(config),
            health_checker,
            lifecycle_info,
        }
    }

    /// Starts the server with timeout handling
    ///
    /// # Arguments
    /// * `startup_timeout_ms` - Timeout for server startup in milliseconds; falls back to
    ///   `config.timeout_ms` when `None`
    ///
    /// # Returns
    /// `Ok(())` on successful startup, or an error if startup fails or times out
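    ///
    /// # Examples
    ///
    /// A minimal sketch (not compiled here); `lifecycle` is assumed to be a constructed
    /// `ServerLifecycle`:
    ///
    /// ```ignore
    /// // Use an explicit 1-second startup timeout...
    /// lifecycle.start(Some(1_000)).await?;
    /// // ...or pass `None` to fall back to `config.timeout_ms`.
    /// lifecycle.start(None).await?;
    /// ```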
    pub async fn start(&self, startup_timeout_ms: Option<u64>) -> Result<()> {
        let mut info = self.lifecycle_info.write().await;

        if info.state == ServerState::Running {
            debug!("Server {} is already running", self.config.id);
            return Ok(());
        }

        info.state = ServerState::Starting;
        drop(info);

        debug!("Starting server: {}", self.config.id);

        let timeout_duration =
            Duration::from_millis(startup_timeout_ms.unwrap_or(self.config.timeout_ms));

        match timeout(timeout_duration, self.perform_startup()).await {
            Ok(Ok(())) => {
                let mut info = self.lifecycle_info.write().await;
                info.state = ServerState::Running;
                info.started_at = Some(std::time::Instant::now());
                info.last_error = None;

                self.health_checker.register_server(&self.config.id).await;

                info!("Server started successfully: {}", self.config.id);
                Ok(())
            }
            Ok(Err(e)) => {
                let mut info = self.lifecycle_info.write().await;
                info.state = ServerState::Failed;
                info.last_error = Some(e.to_string());

                error!("Server startup failed: {}: {}", self.config.id, e);
                Err(e)
            }
            Err(_) => {
                let mut info = self.lifecycle_info.write().await;
                info.state = ServerState::Failed;
                info.last_error = Some(format!(
                    "Server startup timeout after {}ms",
                    timeout_duration.as_millis()
                ));

                error!("Server startup timeout: {}", self.config.id);
                Err(Error::TimeoutError(timeout_duration.as_millis() as u64))
            }
        }
    }

    /// Performs the actual server startup
    async fn perform_startup(&self) -> Result<()> {
        // In a real implementation, this would spawn the server process
        // For now, we simulate successful startup
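        //
        // A rough sketch of what spawning could look like (illustrative only; assumes a
        // stdio-based transport and reuses the existing `MCPServerConfig` fields; error
        // mapping into this crate's `Error` type is omitted):
        //
        //     let child = tokio::process::Command::new(&self.config.command)
        //         .args(&self.config.args)
        //         .envs(&self.config.env)
        //         .stdin(std::process::Stdio::piped())
        //         .stdout(std::process::Stdio::piped())
        //         .spawn();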
        debug!("Performing startup for server: {}", self.config.id);
        Ok(())
    }

    /// Shuts down the server and performs cleanup
    pub async fn shutdown(&self) -> Result<()> {
        let mut info = self.lifecycle_info.write().await;

        if info.state == ServerState::Stopped {
            debug!("Server {} is already stopped", self.config.id);
            return Ok(());
        }

        info.state = ServerState::Stopping;
        drop(info);

        debug!("Shutting down server: {}", self.config.id);

        // Unregister from health checker
        self.health_checker.unregister_server(&self.config.id).await;

        // Perform cleanup
        self.perform_cleanup().await?;

        let mut info = self.lifecycle_info.write().await;
        info.state = ServerState::Stopped;
        info.stopped_at = Some(std::time::Instant::now());

        info!("Server shut down successfully: {}", self.config.id);
        Ok(())
    }

    /// Performs cleanup operations
    async fn perform_cleanup(&self) -> Result<()> {
        debug!("Performing cleanup for server: {}", self.config.id);
        // In a real implementation, this would clean up resources
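        //
        // A sketch of what cleanup could involve (illustrative only; `child` is a
        // hypothetical handle to a process spawned in `perform_startup`):
        //
        //     if let Some(mut child) = child {
        //         let _ = child.kill().await;
        //     }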
        Ok(())
    }

    /// Checks server health and availability
    ///
    /// Returns `Ok(false)` without consulting the health checker if the server is not running.
    pub async fn check_health(&self) -> Result<bool> {
        debug!("Checking health of server: {}", self.config.id);

        let info = self.lifecycle_info.read().await;
        if info.state != ServerState::Running {
            return Ok(false);
        }
        drop(info);

        self.health_checker.check_health(&self.config.id).await
    }

    /// Detects server disconnection
    pub async fn is_disconnected(&self) -> bool {
        self.health_checker.is_disconnected(&self.config.id).await
    }

    /// Attempts to reconnect to the server using exponential backoff
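    ///
    /// # Examples
    ///
    /// A minimal sketch (not compiled here); `lifecycle` is assumed to be a running
    /// `ServerLifecycle`:
    ///
    /// ```ignore
    /// if lifecycle.is_disconnected().await {
    ///     lifecycle.reconnect().await?;
    /// }
    /// ```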
    pub async fn reconnect(&self) -> Result<()> {
        debug!("Attempting to reconnect to server: {}", self.config.id);

        let mut info = self.lifecycle_info.write().await;
        info.restart_count += 1;
        drop(info);

        let server_id = self.config.id.clone();
        let config = self.config.clone();

        self.health_checker
            .reconnect_with_backoff(&server_id, || {
                let config = config.clone();
                Box::pin(async move {
                    // Placeholder reconnection attempt; a real implementation would
                    // re-establish the server transport here
                    debug!("Attempting reconnection to: {}", config.id);
                    Ok(())
                })
            })
            .await?;

        info!("Successfully reconnected to server: {}", self.config.id);
        Ok(())
    }

    /// Returns the configured maximum number of retries
    pub fn max_retries(&self) -> u32 {
        self.config.max_retries
    }

    /// Gets the current lifecycle state
    pub async fn get_state(&self) -> ServerState {
        self.lifecycle_info.read().await.state
    }

    /// Gets the lifecycle information
    pub async fn get_info(&self) -> ServerLifecycleInfo {
        self.lifecycle_info.read().await.clone()
    }

    /// Reports the last error
    pub async fn get_last_error(&self) -> Option<String> {
        self.lifecycle_info.read().await.last_error.clone()
    }

    /// Gets the restart count
    pub async fn get_restart_count(&self) -> u32 {
        self.lifecycle_info.read().await.restart_count
    }

    /// Gets the uptime in milliseconds
    pub async fn get_uptime_ms(&self) -> Option<u128> {
        self.lifecycle_info.read().await.uptime_ms()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    fn create_test_config(id: &str) -> MCPServerConfig {
        MCPServerConfig {
            id: id.to_string(),
            name: format!("Test Server {}", id),
            command: "test".to_string(),
            args: vec![],
            env: HashMap::new(),
            timeout_ms: 5000,
            auto_reconnect: true,
            max_retries: 3,
        }
    }

    #[tokio::test]
    async fn test_create_lifecycle() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        let info = lifecycle.get_info().await;
        assert_eq!(info.server_id, "server1");
        assert_eq!(info.state, ServerState::Stopped);
        assert_eq!(info.restart_count, 0);
    }

    #[tokio::test]
    async fn test_start_server() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        let result = lifecycle.start(Some(5000)).await;
        assert!(result.is_ok());

        let info = lifecycle.get_info().await;
        assert_eq!(info.state, ServerState::Running);
        assert!(info.started_at.is_some());
    }

    #[tokio::test]
    async fn test_shutdown_server() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        lifecycle.start(Some(5000)).await.unwrap();
        let result = lifecycle.shutdown().await;
        assert!(result.is_ok());

        let info = lifecycle.get_info().await;
        assert_eq!(info.state, ServerState::Stopped);
        assert!(info.stopped_at.is_some());
    }

    #[tokio::test]
    async fn test_server_uptime() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        lifecycle.start(Some(5000)).await.unwrap();

        let uptime = lifecycle.get_uptime_ms().await;
        assert!(uptime.is_some());
    }

    #[tokio::test]
    async fn test_restart_count() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        assert_eq!(lifecycle.get_restart_count().await, 0);

        lifecycle.start(Some(5000)).await.unwrap();
        lifecycle.reconnect().await.ok();

        let restart_count = lifecycle.get_restart_count().await;
        assert_eq!(restart_count, 1);
    }

    #[tokio::test]
    async fn test_max_retries() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        assert_eq!(lifecycle.max_retries(), 3);
    }

    #[tokio::test]
    async fn test_is_running() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        let info = lifecycle.get_info().await;
        assert!(!info.is_running());

        lifecycle.start(Some(5000)).await.unwrap();
        let info = lifecycle.get_info().await;
        assert!(info.is_running());
    }

    #[tokio::test]
    async fn test_lifecycle_info_uptime() {
        let mut info = ServerLifecycleInfo::new("server1".to_string());
        assert!(info.uptime_ms().is_none());

        info.started_at = Some(std::time::Instant::now());
        assert!(info.uptime_ms().is_some());
    }

    #[tokio::test]
    async fn test_double_start() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        lifecycle.start(Some(5000)).await.unwrap();
        let result = lifecycle.start(Some(5000)).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_double_shutdown() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        lifecycle.start(Some(5000)).await.unwrap();
        lifecycle.shutdown().await.unwrap();
        let result = lifecycle.shutdown().await;
        assert!(result.is_ok());
    }
}