// zlayer_agent/health.rs
1//! Health checking for containers
2
3use crate::error::{AgentError, Result};
4use crate::runtime::ContainerId;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::time::timeout;
8use zlayer_spec::HealthCheck;
9
/// Callback type for health state changes.
/// Called with (`container_id`, `is_healthy`) when health state transitions.
/// Invoked only on transitions (healthy -> unhealthy or unhealthy -> healthy),
/// not on every check; see `HealthMonitor::start` for the call sites.
pub type HealthCallback = Arc<dyn Fn(ContainerId, bool) + Send + Sync>;
13
/// Health checker for containers.
///
/// Wraps a single [`HealthCheck`] spec (TCP / HTTP / Command) and knows how
/// to execute it against either localhost or an explicit target address.
pub struct HealthChecker {
    /// The health-check specification to execute (TCP port, HTTP URL + status, or shell command).
    pub check: HealthCheck,
    /// Optional target IP address for health checks (e.g., container overlay IP).
    /// When set, TCP and HTTP checks connect to this address instead of 127.0.0.1/localhost.
    /// Command checks are unaffected by this field.
    target_addr: Option<std::net::IpAddr>,
}
21
22impl HealthChecker {
23    /// Create a new health checker
24    ///
25    /// `target_addr` is the IP address to connect to for TCP/HTTP checks.
26    /// Pass `Some(ip)` when the container has an overlay IP, or `None` to
27    /// fall back to `127.0.0.1` / localhost.
28    #[must_use]
29    pub fn new(check: HealthCheck, target_addr: Option<std::net::IpAddr>) -> Self {
30        Self { check, target_addr }
31    }
32
33    /// Perform the health check
34    ///
35    /// # Errors
36    /// Returns an error if the health check fails or times out.
37    pub async fn check(&self, id: &ContainerId, timeout: Duration) -> Result<()> {
38        match &self.check {
39            HealthCheck::Tcp { port } => self.check_tcp(id, *port, timeout).await,
40            HealthCheck::Http { url, expect_status } => {
41                self.check_http(id, url, *expect_status, timeout).await
42            }
43            HealthCheck::Command { command } => self.check_command(id, command, timeout).await,
44        }
45    }
46
47    async fn check_tcp(&self, id: &ContainerId, port: u16, timeout_dur: Duration) -> Result<()> {
48        // Connect to the target address (overlay IP if set, otherwise localhost)
49        let host = self
50            .target_addr
51            .map_or_else(|| "127.0.0.1".to_string(), |ip| ip.to_string());
52        let addr = format!("{host}:{port}");
53        match timeout(timeout_dur, tokio::net::TcpStream::connect(&addr)).await {
54            Ok(Ok(_)) => Ok(()),
55            Ok(Err(e)) => Err(AgentError::HealthCheckFailed {
56                id: id.to_string(),
57                reason: format!("TCP connection failed: {e}"),
58            }),
59            Err(_) => Err(AgentError::Timeout {
60                timeout: timeout_dur,
61            }),
62        }
63    }
64
65    async fn check_http(
66        &self,
67        id: &ContainerId,
68        url: &str,
69        expect_status: u16,
70        timeout_dur: Duration,
71    ) -> Result<()> {
72        // If a target address is set, replace localhost / 127.0.0.1 in the URL
73        // so the health check actually reaches the container's overlay IP.
74        let url = if let Some(ip) = self.target_addr {
75            let ip_str = ip.to_string();
76            url.replace("localhost", &ip_str)
77                .replace("127.0.0.1", &ip_str)
78        } else {
79            url.to_string()
80        };
81
82        let client = reqwest::Client::builder()
83            .timeout(Duration::from_secs(5))
84            .build()
85            .map_err(|e| AgentError::HealthCheckFailed {
86                id: id.to_string(),
87                reason: format!("failed to create HTTP client: {e}"),
88            })?;
89
90        match timeout(timeout_dur, client.get(&url).send()).await {
91            Ok(Ok(resp)) => {
92                let status = resp.status().as_u16();
93                if status == expect_status {
94                    Ok(())
95                } else {
96                    Err(AgentError::HealthCheckFailed {
97                        id: id.to_string(),
98                        reason: format!("unexpected status: {status} (expected {expect_status})"),
99                    })
100                }
101            }
102            Ok(Err(e)) => Err(AgentError::HealthCheckFailed {
103                id: id.to_string(),
104                reason: format!("HTTP request failed: {e}"),
105            }),
106            Err(_) => Err(AgentError::Timeout {
107                timeout: timeout_dur,
108            }),
109        }
110    }
111
112    async fn check_command(
113        &self,
114        id: &ContainerId,
115        command: &str,
116        timeout_dur: Duration,
117    ) -> Result<()> {
118        match timeout(
119            timeout_dur,
120            tokio::process::Command::new("sh")
121                .arg("-c")
122                .arg(command)
123                .output(),
124        )
125        .await
126        {
127            Ok(Ok(output)) => {
128                if output.status.success() {
129                    Ok(())
130                } else {
131                    Err(AgentError::HealthCheckFailed {
132                        id: id.to_string(),
133                        reason: format!(
134                            "command failed with code {}: {}",
135                            output.status.code().unwrap_or(-1),
136                            String::from_utf8_lossy(&output.stderr)
137                        ),
138                    })
139                }
140            }
141            Ok(Err(e)) => Err(AgentError::HealthCheckFailed {
142                id: id.to_string(),
143                reason: format!("command execution failed: {e}"),
144            }),
145            Err(_) => Err(AgentError::Timeout {
146                timeout: timeout_dur,
147            }),
148        }
149    }
150}
151
/// Maximum backoff interval when retries are exhausted (60 seconds).
/// The monitor doubles its polling interval after each failure beyond the
/// retry budget, capped at this value; a successful check resets the interval.
const MAX_BACKOFF: Duration = Duration::from_secs(60);
154
/// Continuous health monitor.
///
/// Owns a [`HealthChecker`] and, once [`start`](HealthMonitor::start)ed,
/// runs checks in a background task on a fixed interval, tracking state
/// transitions and applying exponential backoff after repeated failures.
pub struct HealthMonitor {
    // Container being monitored; passed to the checker and callback.
    id: ContainerId,
    // The check to execute on each tick.
    checker: HealthChecker,
    // Base polling interval between checks (restored after recovery).
    interval: Duration,
    // Failure count after which exponential backoff kicks in.
    retries: u32,
    // Timeout applied to each individual check (default 5s).
    check_timeout: Duration,
    // Initial sleep before the first check (default zero).
    start_grace: Duration,
    // Latest observed state, readable via `state()`.
    state: tokio::sync::RwLock<HealthState>,
    // Optional transition callback (healthy <-> unhealthy only).
    on_health_change: Option<HealthCallback>,
}
166
/// Observed health state of a monitored container.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HealthState {
    /// No check has completed yet.
    Unknown,
    /// A check is currently in flight.
    Checking,
    /// The most recent check succeeded.
    Healthy,
    /// The most recent check failed; `failures` counts consecutive failures
    /// and `reason` carries the last error's message.
    Unhealthy { failures: u32, reason: String },
}
174
175impl HealthMonitor {
176    #[must_use]
177    pub fn new(id: ContainerId, checker: HealthChecker, interval: Duration, retries: u32) -> Self {
178        Self {
179            id,
180            checker,
181            interval,
182            retries,
183            check_timeout: Duration::from_secs(5),
184            start_grace: Duration::ZERO,
185            state: tokio::sync::RwLock::new(HealthState::Unknown),
186            on_health_change: None,
187        }
188    }
189
190    /// Set a callback to be invoked when health state changes (healthy <-> unhealthy).
191    #[must_use]
192    pub fn with_callback(mut self, callback: HealthCallback) -> Self {
193        self.on_health_change = Some(callback);
194        self
195    }
196
197    /// Set a startup grace period. The monitor will sleep for this duration
198    /// before performing the first health check, giving the container time
199    /// to initialize.
200    #[must_use]
201    pub fn with_start_grace(mut self, grace: Duration) -> Self {
202        self.start_grace = grace;
203        self
204    }
205
206    /// Set the timeout applied to each individual health check. Defaults to 5 seconds.
207    #[must_use]
208    pub fn with_check_timeout(mut self, timeout: Duration) -> Self {
209        self.check_timeout = timeout;
210        self
211    }
212
213    /// Start monitoring (spawns background task)
214    pub fn start(self) -> tokio::task::JoinHandle<()> {
215        tokio::spawn(async move {
216            // Startup grace period — let the container initialize before checking
217            if !self.start_grace.is_zero() {
218                tokio::time::sleep(self.start_grace).await;
219            }
220
221            let mut failures = 0u32;
222            let mut was_healthy: Option<bool> = None;
223            let mut current_interval = self.interval;
224
225            loop {
226                // Update state to checking
227                *self.state.write().await = HealthState::Checking;
228
229                match self.checker.check(&self.id, self.check_timeout).await {
230                    Ok(()) => {
231                        failures = 0;
232                        current_interval = self.interval;
233                        *self.state.write().await = HealthState::Healthy;
234
235                        // Check for state transition to healthy
236                        if was_healthy != Some(true) {
237                            if let Some(ref callback) = self.on_health_change {
238                                callback(self.id.clone(), true);
239                            }
240                            was_healthy = Some(true);
241                        }
242                    }
243                    Err(e) => {
244                        failures += 1;
245
246                        *self.state.write().await = HealthState::Unhealthy {
247                            failures,
248                            reason: e.to_string(),
249                        };
250
251                        // Check for state transition to unhealthy
252                        if was_healthy != Some(false) {
253                            if let Some(ref callback) = self.on_health_change {
254                                callback(self.id.clone(), false);
255                            }
256                            was_healthy = Some(false);
257                        }
258
259                        // After exhausting retries, apply exponential backoff
260                        // instead of terminating the monitor
261                        if failures >= self.retries {
262                            current_interval = (current_interval * 2).min(MAX_BACKOFF);
263                        }
264                    }
265                }
266
267                tokio::time::sleep(current_interval).await;
268            }
269        })
270    }
271
272    /// Get current health state
273    pub async fn state(&self) -> HealthState {
274        self.state.read().await.clone()
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_health_state() {
        // Two independently constructed Unhealthy values with identical
        // payloads must compare equal (derived PartialEq on the enum).
        let observed = HealthState::Unhealthy {
            failures: 3,
            reason: String::from("connection refused"),
        };
        let expected = HealthState::Unhealthy {
            failures: 3,
            reason: String::from("connection refused"),
        };
        assert_eq!(observed, expected);
    }
}