1use crate::error::{AgentError, Result};
4use crate::runtime::ContainerId;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::time::timeout;
8use zlayer_spec::HealthCheck;
9
10pub type HealthCallback = Arc<dyn Fn(ContainerId, bool) + Send + Sync>;
13
14pub struct HealthChecker {
16 pub check: HealthCheck,
17 target_addr: Option<std::net::IpAddr>,
20}
21
22impl HealthChecker {
23 #[must_use]
29 pub fn new(check: HealthCheck, target_addr: Option<std::net::IpAddr>) -> Self {
30 Self { check, target_addr }
31 }
32
33 pub async fn check(&self, id: &ContainerId, timeout: Duration) -> Result<()> {
38 match &self.check {
39 HealthCheck::Tcp { port } => self.check_tcp(id, *port, timeout).await,
40 HealthCheck::Http { url, expect_status } => {
41 self.check_http(id, url, *expect_status, timeout).await
42 }
43 HealthCheck::Command { command } => self.check_command(id, command, timeout).await,
44 }
45 }
46
47 async fn check_tcp(&self, id: &ContainerId, port: u16, timeout_dur: Duration) -> Result<()> {
48 let host = self
50 .target_addr
51 .map_or_else(|| "127.0.0.1".to_string(), |ip| ip.to_string());
52 let addr = format!("{host}:{port}");
53 match timeout(timeout_dur, tokio::net::TcpStream::connect(&addr)).await {
54 Ok(Ok(_)) => Ok(()),
55 Ok(Err(e)) => Err(AgentError::HealthCheckFailed {
56 id: id.to_string(),
57 reason: format!("TCP connection failed: {e}"),
58 }),
59 Err(_) => Err(AgentError::Timeout {
60 timeout: timeout_dur,
61 }),
62 }
63 }
64
65 async fn check_http(
66 &self,
67 id: &ContainerId,
68 url: &str,
69 expect_status: u16,
70 timeout_dur: Duration,
71 ) -> Result<()> {
72 let url = if let Some(ip) = self.target_addr {
75 let ip_str = ip.to_string();
76 url.replace("localhost", &ip_str)
77 .replace("127.0.0.1", &ip_str)
78 } else {
79 url.to_string()
80 };
81
82 let client = reqwest::Client::builder()
83 .timeout(Duration::from_secs(5))
84 .build()
85 .map_err(|e| AgentError::HealthCheckFailed {
86 id: id.to_string(),
87 reason: format!("failed to create HTTP client: {e}"),
88 })?;
89
90 match timeout(timeout_dur, client.get(&url).send()).await {
91 Ok(Ok(resp)) => {
92 let status = resp.status().as_u16();
93 if status == expect_status {
94 Ok(())
95 } else {
96 Err(AgentError::HealthCheckFailed {
97 id: id.to_string(),
98 reason: format!("unexpected status: {status} (expected {expect_status})"),
99 })
100 }
101 }
102 Ok(Err(e)) => Err(AgentError::HealthCheckFailed {
103 id: id.to_string(),
104 reason: format!("HTTP request failed: {e}"),
105 }),
106 Err(_) => Err(AgentError::Timeout {
107 timeout: timeout_dur,
108 }),
109 }
110 }
111
112 async fn check_command(
113 &self,
114 id: &ContainerId,
115 command: &str,
116 timeout_dur: Duration,
117 ) -> Result<()> {
118 match timeout(
119 timeout_dur,
120 tokio::process::Command::new("sh")
121 .arg("-c")
122 .arg(command)
123 .output(),
124 )
125 .await
126 {
127 Ok(Ok(output)) => {
128 if output.status.success() {
129 Ok(())
130 } else {
131 Err(AgentError::HealthCheckFailed {
132 id: id.to_string(),
133 reason: format!(
134 "command failed with code {}: {}",
135 output.status.code().unwrap_or(-1),
136 String::from_utf8_lossy(&output.stderr)
137 ),
138 })
139 }
140 }
141 Ok(Err(e)) => Err(AgentError::HealthCheckFailed {
142 id: id.to_string(),
143 reason: format!("command execution failed: {e}"),
144 }),
145 Err(_) => Err(AgentError::Timeout {
146 timeout: timeout_dur,
147 }),
148 }
149 }
150}
151
/// Ceiling applied when the delay between failing health checks is doubled.
const MAX_BACKOFF: Duration = Duration::from_millis(60_000);
154
155pub struct HealthMonitor {
157 id: ContainerId,
158 checker: HealthChecker,
159 interval: Duration,
160 retries: u32,
161 check_timeout: Duration,
162 start_grace: Duration,
163 state: tokio::sync::RwLock<HealthState>,
164 on_health_change: Option<HealthCallback>,
165}
166
/// Health of a monitored container as last recorded by its monitor.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HealthState {
    /// Initial state: monitoring has not begun.
    Unknown,
    /// A probe is currently in flight.
    Checking,
    /// The most recent probe succeeded.
    Healthy,
    /// The most recent probe failed.
    Unhealthy {
        /// Number of consecutive failed probes.
        failures: u32,
        /// Error text from the latest failed probe.
        reason: String,
    },
}
174
175impl HealthMonitor {
176 #[must_use]
177 pub fn new(id: ContainerId, checker: HealthChecker, interval: Duration, retries: u32) -> Self {
178 Self {
179 id,
180 checker,
181 interval,
182 retries,
183 check_timeout: Duration::from_secs(5),
184 start_grace: Duration::ZERO,
185 state: tokio::sync::RwLock::new(HealthState::Unknown),
186 on_health_change: None,
187 }
188 }
189
190 #[must_use]
192 pub fn with_callback(mut self, callback: HealthCallback) -> Self {
193 self.on_health_change = Some(callback);
194 self
195 }
196
197 #[must_use]
201 pub fn with_start_grace(mut self, grace: Duration) -> Self {
202 self.start_grace = grace;
203 self
204 }
205
206 #[must_use]
208 pub fn with_check_timeout(mut self, timeout: Duration) -> Self {
209 self.check_timeout = timeout;
210 self
211 }
212
213 pub fn start(self) -> tokio::task::JoinHandle<()> {
215 tokio::spawn(async move {
216 if !self.start_grace.is_zero() {
218 tokio::time::sleep(self.start_grace).await;
219 }
220
221 let mut failures = 0u32;
222 let mut was_healthy: Option<bool> = None;
223 let mut current_interval = self.interval;
224
225 loop {
226 *self.state.write().await = HealthState::Checking;
228
229 match self.checker.check(&self.id, self.check_timeout).await {
230 Ok(()) => {
231 failures = 0;
232 current_interval = self.interval;
233 *self.state.write().await = HealthState::Healthy;
234
235 if was_healthy != Some(true) {
237 if let Some(ref callback) = self.on_health_change {
238 callback(self.id.clone(), true);
239 }
240 was_healthy = Some(true);
241 }
242 }
243 Err(e) => {
244 failures += 1;
245
246 *self.state.write().await = HealthState::Unhealthy {
247 failures,
248 reason: e.to_string(),
249 };
250
251 if was_healthy != Some(false) {
253 if let Some(ref callback) = self.on_health_change {
254 callback(self.id.clone(), false);
255 }
256 was_healthy = Some(false);
257 }
258
259 if failures >= self.retries {
262 current_interval = (current_interval * 2).min(MAX_BACKOFF);
263 }
264 }
265 }
266
267 tokio::time::sleep(current_interval).await;
268 }
269 })
270 }
271
272 pub async fn state(&self) -> HealthState {
274 self.state.read().await.clone()
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// `HealthState` compares structurally, including the payload of `Unhealthy`.
    #[test]
    fn test_health_state() {
        let unhealthy = || HealthState::Unhealthy {
            failures: 3,
            reason: String::from("connection refused"),
        };
        assert_eq!(unhealthy(), unhealthy());
    }
}