1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::RwLock;
8use tokio::task::{JoinHandle, JoinSet};
9use tokio_util::sync::CancellationToken;
10use uuid::Uuid;
11
12use crate::storage::ProxyStoragePort;
13use crate::types::ProxyConfig;
14
/// Shared, async-lockable map from proxy id to its most recently recorded
/// health flag (`true` when the last completed check for that proxy succeeded).
pub type HealthMap = Arc<RwLock<HashMap<Uuid, bool>>>;
18
/// Periodically probes every stored proxy and records whether it is healthy.
///
/// Cloning is cheap with respect to `storage` and `health_map`, which are
/// reference-counted handles; clones share the same underlying storage and map.
#[derive(Clone)]
pub struct HealthChecker {
    /// Supplies the check interval, the health-check URL, and the per-check timeout.
    config: ProxyConfig,
    /// Source of the proxy records to probe and sink for per-proxy metrics updates.
    storage: Arc<dyn ProxyStoragePort>,
    /// Shared map that receives the healthy/unhealthy result of each check.
    health_map: HealthMap,
}
29
30impl HealthChecker {
31 pub fn health_map(&self) -> &HealthMap {
33 &self.health_map
34 }
35
36 pub fn new(
41 config: ProxyConfig,
42 storage: Arc<dyn ProxyStoragePort>,
43 health_map: HealthMap,
44 ) -> Self {
45 Self {
46 config,
47 storage,
48 health_map,
49 }
50 }
51
52 pub fn spawn(self, token: CancellationToken) -> JoinHandle<()> {
57 tokio::spawn(async move {
58 let mut interval = tokio::time::interval(self.config.health_check_interval);
59 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
60 loop {
61 tokio::select! {
62 _ = token.cancelled() => {
63 tracing::info!("health checker: shutdown requested");
64 break;
65 }
66 _ = interval.tick() => {
67 self.check_all().await;
68 }
69 }
70 }
71 tracing::info!("health checker: stopped");
72 })
73 }
74
75 pub async fn check_once(&self) {
77 self.check_all().await;
78 }
79
80 async fn check_all(&self) {
81 let records = match self.storage.list().await {
82 Ok(r) => r,
83 Err(e) => {
84 tracing::error!("health checker: storage list failed: {e}");
85 return;
86 }
87 };
88
89 let health_url = self.config.health_check_url.clone();
90 let timeout = self.config.health_check_timeout;
91
92 let mut set: JoinSet<(Uuid, Result<u64, String>)> = JoinSet::new();
93 for record in records {
94 let proxy_url = record.proxy.url.clone();
95 let username = record.proxy.username.clone();
96 let password = record.proxy.password.clone();
97 let id = record.id;
98 let check_url = health_url.clone();
99 set.spawn(async move {
100 let result = do_check(
101 &proxy_url,
102 username.as_deref(),
103 password.as_deref(),
104 &check_url,
105 timeout,
106 )
107 .await;
108 (id, result)
109 });
110 }
111
112 let mut updates: Vec<(Uuid, bool, u64)> = Vec::new();
113 while let Some(task_result) = set.join_next().await {
114 match task_result {
115 Ok((id, Ok(latency_ms))) => updates.push((id, true, latency_ms)),
116 Ok((id, Err(e))) => {
117 tracing::warn!(proxy = %id, error = %e, "health check failed");
118 updates.push((id, false, 0));
119 }
120 Err(join_err) => {
121 tracing::error!("health check task panicked: {join_err}");
122 }
123 }
124 }
125
126 let total = updates.len() as u32;
127 let healthy_count = updates.iter().filter(|(_, h, _)| *h).count() as u32;
128
129 {
130 let mut map = self.health_map.write().await;
131 for (id, healthy, _) in &updates {
132 map.insert(*id, *healthy);
133 }
134 }
135
136 for (id, success, latency) in updates {
137 if let Err(e) = self.storage.update_metrics(id, success, latency).await {
138 tracing::warn!("health checker: metrics update failed for {id}: {e}");
139 }
140 }
141
142 tracing::info!(
143 total,
144 healthy = healthy_count,
145 unhealthy = total - healthy_count,
146 "health check cycle complete"
147 );
148 }
149}
150
151async fn do_check(
154 proxy_url: &str,
155 username: Option<&str>,
156 password: Option<&str>,
157 health_url: &str,
158 timeout: std::time::Duration,
159) -> Result<u64, String> {
160 let mut proxy = reqwest::Proxy::all(proxy_url).map_err(|e| e.to_string())?;
161 if let (Some(user), Some(pass)) = (username, password) {
162 proxy = proxy.basic_auth(user, pass);
163 }
164 let client = reqwest::Client::builder()
165 .proxy(proxy)
166 .timeout(timeout)
167 .build()
168 .map_err(|e| e.to_string())?;
169
170 let start = Instant::now();
171 client
172 .get(health_url)
173 .send()
174 .await
175 .map_err(|e| e.to_string())?
176 .error_for_status()
177 .map_err(|e| e.to_string())?;
178 Ok(start.elapsed().as_millis() as u64)
179}
180
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;
    use crate::storage::MemoryProxyStore;
    use crate::types::{Proxy, ProxyType};

    /// Builds an unauthenticated HTTP proxy entry pointing at `url`.
    fn make_proxy(url: &str) -> Proxy {
        Proxy {
            url: url.into(),
            proxy_type: ProxyType::Http,
            username: None,
            password: None,
            weight: 1,
            tags: vec![],
        }
    }

    /// One reachable and one unreachable proxy must end up as one healthy and
    /// one unhealthy entry in the health map after a single cycle.
    #[tokio::test]
    async fn healthy_and_unhealthy_proxies() {
        // The mock server answers every GET with 200. It is registered both as
        // a proxy and as the health-check target.
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let storage = Arc::new(MemoryProxyStore::default());
        storage.add(make_proxy(&server.uri())).await.unwrap();
        // 192.0.2.1 is in TEST-NET-1 (reserved for documentation), so this
        // check is expected to fail once the timeout is reached.
        storage
            .add(make_proxy("http://192.0.2.1:9999"))
            .await
            .unwrap();

        let health_map: HealthMap = Arc::new(RwLock::new(HashMap::new()));
        let config = ProxyConfig {
            health_check_url: format!("{}/", server.uri()),
            health_check_interval: Duration::from_secs(3600),
            health_check_timeout: Duration::from_secs(2),
            ..ProxyConfig::default()
        };
        let checker = HealthChecker::new(config, storage.clone(), health_map.clone());
        checker.check_once().await;

        let map = health_map.read().await;
        let healthy = map.values().filter(|&&v| v).count();
        let unhealthy = map.values().filter(|&&v| !v).count();
        assert_eq!(healthy, 1, "expected 1 healthy proxy");
        assert_eq!(unhealthy, 1, "expected 1 unhealthy proxy");
    }

    /// Cancelling the token must stop the spawned task promptly and without
    /// the task having panicked.
    #[tokio::test]
    async fn graceful_shutdown() {
        let storage = Arc::new(MemoryProxyStore::default());
        let health_map: HealthMap = Arc::new(RwLock::new(HashMap::new()));
        let config = ProxyConfig {
            health_check_interval: Duration::from_secs(3600),
            ..ProxyConfig::default()
        };
        let token = CancellationToken::new();
        let checker = HealthChecker::new(config, storage, health_map);
        let handle = checker.spawn(token.clone());

        token.cancel();
        let joined = tokio::time::timeout(Duration::from_secs(1), handle)
            .await
            .expect("task should exit within 1s after cancellation");
        // Finishing in time is not enough: a task that panicked also finishes.
        assert!(
            joined.is_ok(),
            "health checker task should exit cleanly, got {joined:?}"
        );
    }
}