1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::RwLock;
8use tokio::task::{JoinHandle, JoinSet};
9use tokio_util::sync::CancellationToken;
10use uuid::Uuid;
11
12use crate::storage::ProxyStoragePort;
13use crate::types::ProxyConfig;
14
15pub type HealthMap = Arc<RwLock<HashMap<Uuid, bool>>>;
18
19#[derive(Clone)]
29pub struct HealthChecker {
30 config: ProxyConfig,
31 storage: Arc<dyn ProxyStoragePort>,
32 health_map: HealthMap,
33 #[cfg(feature = "tls-profiled")]
36 profiled: Option<crate::http_client::ProfiledRequester>,
37}
38
39impl HealthChecker {
40 pub const fn health_map(&self) -> &HealthMap {
42 &self.health_map
43 }
44
45 pub fn new(
50 config: ProxyConfig,
51 storage: Arc<dyn ProxyStoragePort>,
52 health_map: HealthMap,
53 ) -> Self {
54 Self {
55 config,
56 storage,
57 health_map,
58 #[cfg(feature = "tls-profiled")]
59 profiled: None,
60 }
61 }
62
63 #[cfg(feature = "tls-profiled")]
89 #[must_use]
90 pub fn with_profiled_client(
91 mut self,
92 requester: crate::http_client::ProfiledRequester,
93 ) -> Self {
94 self.profiled = Some(requester);
95 self
96 }
97
98 #[cfg(feature = "tls-profiled")]
110 pub fn with_profiled_mode(
111 self,
112 mode: crate::types::ProfiledRequestMode,
113 ) -> crate::error::ProxyResult<Self> {
114 let requester = crate::http_client::ProfiledRequester::chrome_mode(mode)
115 .map_err(|e| crate::error::ProxyError::ConfigError(e.to_string()))?;
116 Ok(self.with_profiled_client(requester))
117 }
118
119 pub fn spawn(self, token: CancellationToken) -> JoinHandle<()> {
124 tokio::spawn(async move {
125 let mut interval = tokio::time::interval(self.config.health_check_interval);
126 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
127 loop {
128 tokio::select! {
129 () = token.cancelled() => {
130 tracing::info!("health checker: shutdown requested");
131 break;
132 }
133 _ = interval.tick() => {
134 self.check_all().await;
135 }
136 }
137 }
138 tracing::info!("health checker: stopped");
139 })
140 }
141
142 pub async fn check_once(&self) {
144 self.check_all().await;
145 }
146
147 async fn check_all(&self) {
148 let records = match self.storage.list().await {
149 Ok(r) => r,
150 Err(e) => {
151 tracing::error!("health checker: storage list failed: {e}");
152 return;
153 }
154 };
155
156 let health_url = self.config.health_check_url.clone();
157 let timeout = self.config.health_check_timeout;
158
159 let mut set: JoinSet<(Uuid, Result<u64, String>)> = JoinSet::new();
160 for record in records {
161 let proxy_url = record.proxy.url.clone();
162 let username = record.proxy.username.clone();
163 let password = record.proxy.password.clone();
164 let id = record.id;
165 let check_url = health_url.clone();
166
167 #[cfg(feature = "tls-profiled")]
171 let routed_proxy_url =
172 proxy_url_with_auth(&proxy_url, username.as_deref(), password.as_deref());
173
174 #[cfg(feature = "tls-profiled")]
175 let preset_client: Option<reqwest::Client> = self.profiled.as_ref().and_then(|p| {
176 crate::http_client::ProfiledRequester::from_profile(
177 p.profile(),
178 Some(&routed_proxy_url),
179 )
180 .map(crate::http_client::ProfiledRequester::into_client)
181 .map_err(|e| {
182 tracing::warn!(
183 error = %e,
184 proxy = %routed_proxy_url,
185 "tls-profiled health-check client build failed; falling back to vanilla"
186 );
187 })
188 .ok()
189 });
190
191 #[cfg(not(feature = "tls-profiled"))]
192 let preset_client: Option<reqwest::Client> = None;
193
194 set.spawn(async move {
195 let result = do_check(
196 &proxy_url,
197 username.as_deref(),
198 password.as_deref(),
199 &check_url,
200 timeout,
201 preset_client,
202 )
203 .await;
204 (id, result)
205 });
206 }
207
208 let mut updates: Vec<(Uuid, bool, u64)> = Vec::new();
209 while let Some(task_result) = set.join_next().await {
210 match task_result {
211 Ok((id, Ok(latency_ms))) => updates.push((id, true, latency_ms)),
212 Ok((id, Err(e))) => {
213 tracing::warn!(proxy = %id, error = %e, "health check failed");
214 updates.push((id, false, 0));
215 }
216 Err(join_err) => {
217 tracing::error!("health check task panicked: {join_err}");
218 }
219 }
220 }
221
222 let total = u32::try_from(updates.len()).unwrap_or(u32::MAX);
223 let healthy_count =
224 u32::try_from(updates.iter().filter(|(_, h, _)| *h).count()).unwrap_or(u32::MAX);
225
226 {
227 let mut map = self.health_map.write().await;
228 for (id, healthy, _) in &updates {
229 map.insert(*id, *healthy);
230 }
231 }
232
233 for (id, success, latency) in updates {
234 if let Err(e) = self.storage.update_metrics(id, success, latency).await {
235 tracing::warn!("health checker: metrics update failed for {id}: {e}");
236 }
237 }
238
239 tracing::info!(
240 total,
241 healthy = healthy_count,
242 unhealthy = total - healthy_count,
243 "health check cycle complete"
244 );
245 }
246}
247
/// Embeds `username`/`password` into `proxy_url` as URL userinfo, for clients
/// that only accept a proxy URL (the tls-profiled probe path).
///
/// `proxy_url` is returned unchanged when:
/// - either credential is missing;
/// - it does not parse as a URL, or the URL cannot carry credentials;
/// - it already carries any userinfo (a username, a password, or both). URL
///   credentials and configured credentials are therefore never spliced
///   together into a mixed pair.
///
/// The returned string contains the password in clear text; never log it.
#[cfg(feature = "tls-profiled")]
fn proxy_url_with_auth(proxy_url: &str, username: Option<&str>, password: Option<&str>) -> String {
    let (Some(user), Some(pass)) = (username, password) else {
        return proxy_url.to_string();
    };

    let Ok(mut url) = reqwest::Url::parse(proxy_url) else {
        return proxy_url.to_string();
    };

    // Existing userinfo wins as a whole. Filling in only the missing half would
    // pair, e.g., the URL's username with an unrelated configured password.
    if !url.username().is_empty() || url.password().is_some() {
        return proxy_url.to_string();
    }

    // Both setters fail for URLs that cannot carry credentials (e.g. no host);
    // fall back to the original string rather than a half-populated URL.
    if url.set_username(user).is_err() || url.set_password(Some(pass)).is_err() {
        return proxy_url.to_string();
    }

    url.to_string()
}
267
268async fn do_check(
269 proxy_url: &str,
270 username: Option<&str>,
271 password: Option<&str>,
272 health_url: &str,
273 timeout: std::time::Duration,
274 preset_client: Option<reqwest::Client>,
275) -> Result<u64, String> {
276 let client = if let Some(c) = preset_client {
280 c
281 } else {
282 let mut proxy = reqwest::Proxy::all(proxy_url).map_err(|e| e.to_string())?;
283 if let (Some(user), Some(pass)) = (username, password) {
284 proxy = proxy.basic_auth(user, pass);
285 }
286 reqwest::Client::builder()
287 .proxy(proxy)
288 .timeout(timeout)
289 .build()
290 .map_err(|e| e.to_string())?
291 };
292
293 let start = Instant::now();
294 client
295 .get(health_url)
296 .timeout(timeout)
297 .send()
298 .await
299 .map_err(|e| e.to_string())?
300 .error_for_status()
301 .map_err(|e| e.to_string())?;
302 Ok(start.elapsed().as_millis().try_into().unwrap_or(u64::MAX))
303}
304
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;
    use crate::storage::MemoryProxyStore;
    use crate::types::{Proxy, ProxyType};

    /// Builds an unauthenticated HTTP proxy entry with weight 1 and no tags.
    fn make_proxy(url: &str) -> Proxy {
        Proxy {
            url: url.into(),
            proxy_type: ProxyType::Http,
            username: None,
            password: None,
            weight: 1,
            tags: vec![],
        }
    }

    #[cfg(feature = "tls-profiled")]
    #[test]
    fn proxy_url_with_auth_injects_credentials() {
        let proxy_url = proxy_url_with_auth(
            "http://proxy.example.com:8080",
            Some("alice"),
            Some("s3cr3t"),
        );
        assert!(proxy_url.starts_with("http://alice:s3cr3t@proxy.example.com:8080"));
    }

    #[cfg(feature = "tls-profiled")]
    #[test]
    fn proxy_url_with_auth_leaves_existing_credentials_untouched() {
        let proxy_url = proxy_url_with_auth(
            "http://already:present@proxy.example.com:8080",
            Some("alice"),
            Some("s3cr3t"),
        );
        assert!(proxy_url.starts_with("http://already:present@proxy.example.com:8080"));
    }

    #[tokio::test]
    async fn healthy_and_unhealthy_proxies() -> crate::error::ProxyResult<()> {
        // The mock server doubles as the "healthy" proxy: it answers any GET
        // with 200, including the request relayed to it in proxy form.
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let storage = Arc::new(MemoryProxyStore::default());
        storage.add(make_proxy(&server.uri())).await?;
        // 192.0.2.0/24 is TEST-NET-1 (RFC 5737): unroutable, so this probe
        // fails once the 2s timeout below elapses.
        storage.add(make_proxy("http://192.0.2.1:9999")).await?;

        let health_map: HealthMap = Arc::new(RwLock::new(HashMap::new()));
        let config = ProxyConfig {
            health_check_url: format!("{}/", server.uri()),
            health_check_interval: Duration::from_secs(3600),
            health_check_timeout: Duration::from_secs(2),
            ..ProxyConfig::default()
        };
        let checker = HealthChecker::new(config, storage.clone(), health_map.clone());
        checker.check_once().await;

        let map = health_map.read().await;
        let healthy = map.values().filter(|&&v| v).count();
        let unhealthy = map.values().filter(|&&v| !v).count();
        drop(map);
        assert_eq!(healthy, 1, "expected 1 healthy proxy");
        assert_eq!(unhealthy, 1, "expected 1 unhealthy proxy");
        Ok(())
    }

    #[tokio::test]
    async fn graceful_shutdown() {
        let storage = Arc::new(MemoryProxyStore::default());
        let health_map: HealthMap = Arc::new(RwLock::new(HashMap::new()));
        let config = ProxyConfig {
            health_check_interval: Duration::from_secs(3600),
            ..ProxyConfig::default()
        };
        let token = CancellationToken::new();
        let checker = HealthChecker::new(config, storage, health_map);
        let handle = checker.spawn(token.clone());

        token.cancel();
        let joined = tokio::time::timeout(Duration::from_secs(1), handle)
            .await
            .expect("task should exit within 1s after cancellation");
        // A panicking task also ends promptly; checking only the timeout would
        // let that pass, so require a clean exit as well.
        assert!(joined.is_ok(), "health checker task panicked: {joined:?}");
    }
}