ricecoder_mcp/lifecycle.rs

use crate::config::MCPServerConfig;
use crate::error::{Error, Result};
use crate::health_check::HealthChecker;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::timeout;
use tracing::{debug, error, info};
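/// Lifecycle state of a managed MCP server.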
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerState {
    Stopped,
    Starting,
    Running,
    Stopping,
    Failed,
}
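/// Snapshot of a server's lifecycle: current state, start/stop times,
/// restart count, and the most recent error, if any.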
#[derive(Debug, Clone)]
pub struct ServerLifecycleInfo {
    pub server_id: String,
    pub state: ServerState,
    pub started_at: Option<std::time::Instant>,
    pub stopped_at: Option<std::time::Instant>,
    pub restart_count: u32,
    pub last_error: Option<String>,
}

impl ServerLifecycleInfo {
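    /// Creates lifecycle info for `server_id` in the `Stopped` state.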
    pub fn new(server_id: String) -> Self {
        Self {
            server_id,
            state: ServerState::Stopped,
            started_at: None,
            stopped_at: None,
            restart_count: 0,
            last_error: None,
        }
    }
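    /// Milliseconds elapsed since the server was last started, or `None`
    /// if it has never been started.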
    pub fn uptime_ms(&self) -> Option<u128> {
        self.started_at.map(|start| start.elapsed().as_millis())
    }
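    /// Returns `true` if the server is currently in the `Running` state.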
    pub fn is_running(&self) -> bool {
        self.state == ServerState::Running
    }
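    /// Returns `true` if the server is in the `Failed` state.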
    pub fn has_failed(&self) -> bool {
        self.state == ServerState::Failed
    }
}
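/// Manages startup, shutdown, health checking, and reconnection for a single
/// MCP server, sharing its lifecycle state behind an async `RwLock`.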
#[derive(Debug, Clone)]
pub struct ServerLifecycle {
    config: Arc<MCPServerConfig>,
    health_checker: Arc<HealthChecker>,
    lifecycle_info: Arc<RwLock<ServerLifecycleInfo>>,
}

impl ServerLifecycle {
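    /// Creates a new lifecycle manager for the given server configuration.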
    pub fn new(config: MCPServerConfig, health_checker: Arc<HealthChecker>) -> Self {
        Self {
            config: Arc::new(config.clone()),
            health_checker,
            lifecycle_info: Arc::new(RwLock::new(ServerLifecycleInfo::new(config.id.clone()))),
        }
    }
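    /// Starts the server, transitioning it to `Running`.
    ///
    /// If the server is already running this is a no-op. Startup is bounded
    /// by `startup_timeout_ms`, falling back to the configured `timeout_ms`;
    /// on failure or timeout the state becomes `Failed` and the error is
    /// recorded.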
    pub async fn start(&self, startup_timeout_ms: Option<u64>) -> Result<()> {
        let mut info = self.lifecycle_info.write().await;

        if info.state == ServerState::Running {
            debug!("Server {} is already running", self.config.id);
            return Ok(());
        }

        info.state = ServerState::Starting;
        drop(info);

        debug!("Starting server: {}", self.config.id);

        let timeout_duration =
            Duration::from_millis(startup_timeout_ms.unwrap_or(self.config.timeout_ms));

        match timeout(timeout_duration, self.perform_startup()).await {
            Ok(Ok(())) => {
                let mut info = self.lifecycle_info.write().await;
                info.state = ServerState::Running;
                info.started_at = Some(std::time::Instant::now());
                info.last_error = None;

                self.health_checker.register_server(&self.config.id).await;

                info!("Server started successfully: {}", self.config.id);
                Ok(())
            }
            Ok(Err(e)) => {
                let mut info = self.lifecycle_info.write().await;
                info.state = ServerState::Failed;
                info.last_error = Some(e.to_string());

                error!("Server startup failed: {}: {}", self.config.id, e);
                Err(e)
            }
            Err(_) => {
                let mut info = self.lifecycle_info.write().await;
                info.state = ServerState::Failed;
                let error_msg =
                    format!("Server startup timeout after {}ms", timeout_duration.as_millis());
                info.last_error = Some(error_msg);

                error!("Server startup timeout: {}", self.config.id);
                Err(Error::TimeoutError(timeout_duration.as_millis() as u64))
            }
        }
    }
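    /// Performs the actual startup work for the server. Currently a stub
    /// that succeeds immediately.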
    async fn perform_startup(&self) -> Result<()> {
        debug!("Performing startup for server: {}", self.config.id);
        Ok(())
    }
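    /// Shuts the server down, unregistering it from health checking and
    /// transitioning it to `Stopped`. A no-op if it is already stopped.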
    pub async fn shutdown(&self) -> Result<()> {
        let mut info = self.lifecycle_info.write().await;

        if info.state == ServerState::Stopped {
            debug!("Server {} is already stopped", self.config.id);
            return Ok(());
        }

        info.state = ServerState::Stopping;
        drop(info);

        debug!("Shutting down server: {}", self.config.id);

        self.health_checker.unregister_server(&self.config.id).await;

        self.perform_cleanup().await?;

        let mut info = self.lifecycle_info.write().await;
        info.state = ServerState::Stopped;
        info.stopped_at = Some(std::time::Instant::now());

        info!("Server shut down successfully: {}", self.config.id);
        Ok(())
    }
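    /// Performs the actual cleanup work for the server. Currently a stub
    /// that succeeds immediately.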
    async fn perform_cleanup(&self) -> Result<()> {
        debug!("Performing cleanup for server: {}", self.config.id);
        Ok(())
    }
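    /// Checks the server's health via the health checker. Returns `Ok(false)`
    /// without consulting the health checker if the server is not running.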
    pub async fn check_health(&self) -> Result<bool> {
        debug!("Checking health of server: {}", self.config.id);

        let info = self.lifecycle_info.read().await;
        if info.state != ServerState::Running {
            return Ok(false);
        }
        drop(info);

        self.health_checker.check_health(&self.config.id).await
    }
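    /// Returns `true` if the health checker currently considers the server
    /// disconnected.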
    pub async fn is_disconnected(&self) -> bool {
        self.health_checker.is_disconnected(&self.config.id).await
    }
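    /// Attempts to reconnect to the server using the health checker's
    /// backoff strategy, incrementing the restart count.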
    pub async fn reconnect(&self) -> Result<()> {
        debug!("Attempting to reconnect to server: {}", self.config.id);

        let mut info = self.lifecycle_info.write().await;
        info.restart_count += 1;
        drop(info);

        let server_id = self.config.id.clone();
        let config = self.config.clone();

        self.health_checker
            .reconnect_with_backoff(&server_id, || {
                let config = config.clone();
                Box::pin(async move {
                    debug!("Attempting reconnection to: {}", config.id);
                    Ok(())
                })
            })
            .await?;

        info!("Successfully reconnected to server: {}", self.config.id);
        Ok(())
    }
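    /// The configured maximum number of retries for this server.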
    pub fn max_retries(&self) -> u32 {
        self.config.max_retries
    }
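    /// The current lifecycle state of the server.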
    pub async fn get_state(&self) -> ServerState {
        self.lifecycle_info.read().await.state
    }
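    /// A clone of the full lifecycle info snapshot.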
    pub async fn get_info(&self) -> ServerLifecycleInfo {
        self.lifecycle_info.read().await.clone()
    }
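    /// The most recent error message recorded for this server, if any.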
    pub async fn get_last_error(&self) -> Option<String> {
        self.lifecycle_info.read().await.last_error.clone()
    }
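    /// The number of reconnection attempts recorded for this server.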
    pub async fn get_restart_count(&self) -> u32 {
        self.lifecycle_info.read().await.restart_count
    }
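    /// Milliseconds since the server was last started, or `None` if it has
    /// never been started.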
    pub async fn get_uptime_ms(&self) -> Option<u128> {
        self.lifecycle_info.read().await.uptime_ms()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    fn create_test_config(id: &str) -> MCPServerConfig {
        MCPServerConfig {
            id: id.to_string(),
            name: format!("Test Server {}", id),
            command: "test".to_string(),
            args: vec![],
            env: HashMap::new(),
            timeout_ms: 5000,
            auto_reconnect: true,
            max_retries: 3,
        }
    }

    #[tokio::test]
    async fn test_create_lifecycle() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        let info = lifecycle.get_info().await;
        assert_eq!(info.server_id, "server1");
        assert_eq!(info.state, ServerState::Stopped);
        assert_eq!(info.restart_count, 0);
    }

    #[tokio::test]
    async fn test_start_server() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        let result = lifecycle.start(Some(5000)).await;
        assert!(result.is_ok());

        let info = lifecycle.get_info().await;
        assert_eq!(info.state, ServerState::Running);
        assert!(info.started_at.is_some());
    }

    #[tokio::test]
    async fn test_shutdown_server() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        lifecycle.start(Some(5000)).await.unwrap();
        let result = lifecycle.shutdown().await;
        assert!(result.is_ok());

        let info = lifecycle.get_info().await;
        assert_eq!(info.state, ServerState::Stopped);
        assert!(info.stopped_at.is_some());
    }

    #[tokio::test]
    async fn test_server_uptime() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        lifecycle.start(Some(5000)).await.unwrap();

        let uptime = lifecycle.get_uptime_ms().await;
        assert!(uptime.is_some());
    }

    #[tokio::test]
    async fn test_restart_count() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        assert_eq!(lifecycle.get_restart_count().await, 0);

        lifecycle.start(Some(5000)).await.unwrap();
        lifecycle.reconnect().await.ok();

        let restart_count = lifecycle.get_restart_count().await;
        assert_eq!(restart_count, 1);
    }

    #[tokio::test]
    async fn test_max_retries() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        assert_eq!(lifecycle.max_retries(), 3);
    }

    #[tokio::test]
    async fn test_is_running() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        let info = lifecycle.get_info().await;
        assert!(!info.is_running());

        lifecycle.start(Some(5000)).await.unwrap();
        let info = lifecycle.get_info().await;
        assert!(info.is_running());
    }

    #[tokio::test]
    async fn test_lifecycle_info_uptime() {
        let mut info = ServerLifecycleInfo::new("server1".to_string());
        assert!(info.uptime_ms().is_none());

        info.started_at = Some(std::time::Instant::now());
        assert!(info.uptime_ms().is_some());
    }

    #[tokio::test]
    async fn test_double_start() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        lifecycle.start(Some(5000)).await.unwrap();
        let result = lifecycle.start(Some(5000)).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_double_shutdown() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        lifecycle.start(Some(5000)).await.unwrap();
        lifecycle.shutdown().await.unwrap();
        let result = lifecycle.shutdown().await;
        assert!(result.is_ok());
    }
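
    // Added sketch (not part of the original suite): exercises the early
    // return in check_health(), which reports unhealthy while the server is
    // stopped without consulting the health checker.
    #[tokio::test]
    async fn test_check_health_when_stopped() {
        let config = create_test_config("server1");
        let health_checker = Arc::new(HealthChecker::new());
        let lifecycle = ServerLifecycle::new(config, health_checker);

        let healthy = lifecycle.check_health().await.unwrap();
        assert!(!healthy);
    }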
}