Skip to main content

stygian_proxy/
health.rs

1//! Async background health checker for proxy liveness verification.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::RwLock;
8use tokio::task::{JoinHandle, JoinSet};
9use tokio_util::sync::CancellationToken;
10use uuid::Uuid;
11
12use crate::storage::ProxyStoragePort;
13use crate::types::ProxyConfig;
14
15/// Shared health map type.
16/// `true` = proxy is currently considered healthy.
17pub type HealthMap = Arc<RwLock<HashMap<Uuid, bool>>>;
18
19/// Continuously verifies proxy liveness and updates the shared [`HealthMap`].
20///
21/// Run one check cycle with [`check_once`](HealthChecker::check_once) or launch
22/// a background task with [`spawn`](HealthChecker::spawn).
23///
24/// When the `tls-profiled` feature is enabled you can supply a
25/// [`ProfiledRequester`](crate::http_client::ProfiledRequester) via
26/// [`HealthChecker::with_profiled_client`] so health-check GET requests carry a
27/// browser TLS fingerprint.
28#[derive(Clone)]
29pub struct HealthChecker {
30    config: ProxyConfig,
31    storage: Arc<dyn ProxyStoragePort>,
32    health_map: HealthMap,
33    /// Optional TLS-profiled HTTP client.  When `None` a vanilla
34    /// `reqwest::Client` is built per check cycle.
35    #[cfg(feature = "tls-profiled")]
36    profiled: Option<crate::http_client::ProfiledRequester>,
37}
38
39impl HealthChecker {
40    /// Access the shared health map (read it to filter candidates).
41    pub const fn health_map(&self) -> &HealthMap {
42        &self.health_map
43    }
44
45    /// Create a new checker.
46    ///
47    /// `health_map` should be the **same** `Arc` held by the `ProxyManager` so
48    /// that selection decisions always see up-to-date health information.
49    pub fn new(
50        config: ProxyConfig,
51        storage: Arc<dyn ProxyStoragePort>,
52        health_map: HealthMap,
53    ) -> Self {
54        Self {
55            config,
56            storage,
57            health_map,
58            #[cfg(feature = "tls-profiled")]
59            profiled: None,
60        }
61    }
62
63    /// Attach a TLS-profiled client so that health-check requests carry a
64    /// browser fingerprint instead of a default `reqwest` TLS handshake.
65    ///
66    /// Only available when the `tls-profiled` feature is enabled.
67    ///
68    /// # Example
69    ///
70    /// ```no_run
71    /// use std::sync::Arc;
72    /// use stygian_proxy::{
73    ///     HealthChecker,
74    ///     ProxyConfig,
75    ///     http_client::{ProfiledRequestMode, ProfiledRequester},
76    /// };
77    /// use stygian_proxy::storage::MemoryProxyStore;
78    ///
79    /// # fn run() -> Result<(), Box<dyn std::error::Error>> {
80    /// let storage = Arc::new(MemoryProxyStore::default());
81    /// let health_map = stygian_proxy::health::HealthMap::default();
82    /// let requester = ProfiledRequester::chrome_mode(ProfiledRequestMode::Preset)?;
83    /// let checker = HealthChecker::new(ProxyConfig::default(), storage, health_map)
84    ///     .with_profiled_client(requester);
85    /// # Ok(())
86    /// # }
87    /// ```
88    #[cfg(feature = "tls-profiled")]
89    #[must_use]
90    pub fn with_profiled_client(
91        mut self,
92        requester: crate::http_client::ProfiledRequester,
93    ) -> Self {
94        self.profiled = Some(requester);
95        self
96    }
97
98    /// Build and attach a profile-mode-based requester.
99    ///
100    /// Uses Chrome 131 as the baseline browser identity and applies `mode`
101    /// to TLS control mapping.
102    ///
103    /// Only available when the `tls-profiled` feature is enabled.
104    ///
105    /// # Errors
106    ///
107    /// Returns [`crate::error::ProxyError::ConfigError`] if the profiled
108    /// requester cannot be constructed.
109    #[cfg(feature = "tls-profiled")]
110    pub fn with_profiled_mode(
111        self,
112        mode: crate::types::ProfiledRequestMode,
113    ) -> crate::error::ProxyResult<Self> {
114        let requester = crate::http_client::ProfiledRequester::chrome_mode(mode)
115            .map_err(|e| crate::error::ProxyError::ConfigError(e.to_string()))?;
116        Ok(self.with_profiled_client(requester))
117    }
118
119    /// Spawn an infinite background task that checks proxies on every
120    /// `config.health_check_interval` tick.
121    ///
122    /// Cancel `token` to stop the task gracefully.  Missed ticks are skipped.
123    pub fn spawn(self, token: CancellationToken) -> JoinHandle<()> {
124        tokio::spawn(async move {
125            let mut interval = tokio::time::interval(self.config.health_check_interval);
126            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
127            loop {
128                tokio::select! {
129                    () = token.cancelled() => {
130                        tracing::info!("health checker: shutdown requested");
131                        break;
132                    }
133                    _ = interval.tick() => {
134                        self.check_all().await;
135                    }
136                }
137            }
138            tracing::info!("health checker: stopped");
139        })
140    }
141
142    /// Run one full check cycle synchronously (useful for tests).
143    pub async fn check_once(&self) {
144        self.check_all().await;
145    }
146
147    async fn check_all(&self) {
148        let records = match self.storage.list().await {
149            Ok(r) => r,
150            Err(e) => {
151                tracing::error!("health checker: storage list failed: {e}");
152                return;
153            }
154        };
155
156        let health_url = self.config.health_check_url.clone();
157        let timeout = self.config.health_check_timeout;
158
159        let mut set: JoinSet<(Uuid, Result<u64, String>)> = JoinSet::new();
160        for record in records {
161            let proxy_url = record.proxy.url.clone();
162            let username = record.proxy.username.clone();
163            let password = record.proxy.password.clone();
164            let id = record.id;
165            let check_url = health_url.clone();
166
167            // When the `tls-profiled` feature is enabled, build a fresh profiled
168            // client per proxy that routes through that proxy's URL so health
169            // checks present a browser TLS fingerprint for proxy-routed requests.
170            #[cfg(feature = "tls-profiled")]
171            let routed_proxy_url =
172                proxy_url_with_auth(&proxy_url, username.as_deref(), password.as_deref());
173
174            #[cfg(feature = "tls-profiled")]
175            let preset_client: Option<reqwest::Client> = self.profiled.as_ref().and_then(|p| {
176                crate::http_client::ProfiledRequester::from_profile(
177                    p.profile(),
178                    Some(&routed_proxy_url),
179                )
180                .map(crate::http_client::ProfiledRequester::into_client)
181                .map_err(|e| {
182                    tracing::warn!(
183                        error = %e,
184                        proxy = %routed_proxy_url,
185                        "tls-profiled health-check client build failed; falling back to vanilla"
186                    );
187                })
188                .ok()
189            });
190
191            #[cfg(not(feature = "tls-profiled"))]
192            let preset_client: Option<reqwest::Client> = None;
193
194            set.spawn(async move {
195                let result = do_check(
196                    &proxy_url,
197                    username.as_deref(),
198                    password.as_deref(),
199                    &check_url,
200                    timeout,
201                    preset_client,
202                )
203                .await;
204                (id, result)
205            });
206        }
207
208        let mut updates: Vec<(Uuid, bool, u64)> = Vec::new();
209        while let Some(task_result) = set.join_next().await {
210            match task_result {
211                Ok((id, Ok(latency_ms))) => updates.push((id, true, latency_ms)),
212                Ok((id, Err(e))) => {
213                    tracing::warn!(proxy = %id, error = %e, "health check failed");
214                    updates.push((id, false, 0));
215                }
216                Err(join_err) => {
217                    tracing::error!("health check task panicked: {join_err}");
218                }
219            }
220        }
221
222        let total = u32::try_from(updates.len()).unwrap_or(u32::MAX);
223        let healthy_count =
224            u32::try_from(updates.iter().filter(|(_, h, _)| *h).count()).unwrap_or(u32::MAX);
225
226        {
227            let mut map = self.health_map.write().await;
228            for (id, healthy, _) in &updates {
229                map.insert(*id, *healthy);
230            }
231        }
232
233        for (id, success, latency) in updates {
234            if let Err(e) = self.storage.update_metrics(id, success, latency).await {
235                tracing::warn!("health checker: metrics update failed for {id}: {e}");
236            }
237        }
238
239        tracing::info!(
240            total,
241            healthy = healthy_count,
242            unhealthy = total - healthy_count,
243            "health check cycle complete"
244        );
245    }
246}
247
/// Returns `proxy_url` with `username`/`password` embedded as URL userinfo.
///
/// The explicit credentials are treated as a pair:
///
/// * If either `username` or `password` is `None`, or `proxy_url` does not
///   parse as a URL, the input is returned unchanged.
/// * If the URL already carries **both** a username and a password, those are
///   kept and the explicit pair is ignored.
/// * Otherwise (no userinfo, or only half of it) the explicit pair replaces
///   whatever is there, so the result never combines a username from one
///   source with a password from the other.
/// * If the URL cannot carry credentials (the setters fail), the input is
///   returned unchanged.
#[cfg(feature = "tls-profiled")]
fn proxy_url_with_auth(proxy_url: &str, username: Option<&str>, password: Option<&str>) -> String {
    let (Some(user), Some(pass)) = (username, password) else {
        return proxy_url.to_string();
    };

    let Ok(mut url) = reqwest::Url::parse(proxy_url) else {
        return proxy_url.to_string();
    };

    let has_complete_credentials = !url.username().is_empty() && url.password().is_some();
    if !has_complete_credentials
        && (url.set_username(user).is_err() || url.set_password(Some(pass)).is_err())
    {
        return proxy_url.to_string();
    }

    url.to_string()
}
267
268async fn do_check(
269    proxy_url: &str,
270    username: Option<&str>,
271    password: Option<&str>,
272    health_url: &str,
273    timeout: std::time::Duration,
274    preset_client: Option<reqwest::Client>,
275) -> Result<u64, String> {
276    // Use the pre-built profiled client (already includes proxy routing) when
277    // available; otherwise build a vanilla client with per-proxy routing and
278    // optional basic-auth credentials.
279    let client = if let Some(c) = preset_client {
280        c
281    } else {
282        let mut proxy = reqwest::Proxy::all(proxy_url).map_err(|e| e.to_string())?;
283        if let (Some(user), Some(pass)) = (username, password) {
284            proxy = proxy.basic_auth(user, pass);
285        }
286        reqwest::Client::builder()
287            .proxy(proxy)
288            .timeout(timeout)
289            .build()
290            .map_err(|e| e.to_string())?
291    };
292
293    let start = Instant::now();
294    client
295        .get(health_url)
296        .timeout(timeout)
297        .send()
298        .await
299        .map_err(|e| e.to_string())?
300        .error_for_status()
301        .map_err(|e| e.to_string())?;
302    Ok(start.elapsed().as_millis().try_into().unwrap_or(u64::MAX))
303}
304
305// ─────────────────────────────────────────────────────────────────────────────
306// Tests
307// ─────────────────────────────────────────────────────────────────────────────
308
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;
    use crate::storage::MemoryProxyStore;
    use crate::types::{Proxy, ProxyType};

    /// Builds an unauthenticated, untagged HTTP proxy entry of weight 1.
    fn make_proxy(url: &str) -> Proxy {
        Proxy {
            url: url.into(),
            proxy_type: ProxyType::Http,
            username: None,
            password: None,
            weight: 1,
            tags: Vec::new(),
        }
    }

    /// A URL without userinfo receives the explicit credentials.
    #[cfg(feature = "tls-profiled")]
    #[test]
    fn proxy_url_with_auth_injects_credentials() {
        let routed =
            proxy_url_with_auth("http://proxy.example.com:8080", Some("alice"), Some("s3cr3t"));
        assert!(routed.starts_with("http://alice:s3cr3t@proxy.example.com:8080"));
    }

    /// A URL that already embeds a username and password is not overwritten.
    #[cfg(feature = "tls-profiled")]
    #[test]
    fn proxy_url_with_auth_leaves_existing_credentials_untouched() {
        let routed = proxy_url_with_auth(
            "http://already:present@proxy.example.com:8080",
            Some("alice"),
            Some("s3cr3t"),
        );
        assert!(routed.starts_with("http://already:present@proxy.example.com:8080"));
    }

    /// One reachable and one unreachable proxy yield one `true` and one
    /// `false` entry in the health map after a single cycle.
    #[tokio::test]
    async fn healthy_and_unhealthy_proxies() -> crate::error::ProxyResult<()> {
        // The mock server doubles as the HTTP proxy and the health-check
        // target: reqwest sends the GET in absolute-form and wiremock
        // answers 200.
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let server_uri = server.uri();

        let storage = Arc::new(MemoryProxyStore::default());
        // First proxy points at the mock server, so its check succeeds.
        storage.add(make_proxy(&server_uri)).await?;
        // Second proxy is an invalid address, so its check fails.
        storage.add(make_proxy("http://192.0.2.1:9999")).await?;

        let health_map = HealthMap::default();
        let config = ProxyConfig {
            health_check_url: format!("{}/", server_uri),
            health_check_interval: Duration::from_secs(3600),
            health_check_timeout: Duration::from_secs(2),
            ..ProxyConfig::default()
        };
        HealthChecker::new(config, storage.clone(), health_map.clone())
            .check_once()
            .await;

        // Tally under the read lock, then release it before asserting.
        let (healthy, unhealthy) = {
            let map = health_map.read().await;
            let up = map.values().filter(|&&ok| ok).count();
            (up, map.len() - up)
        };
        assert_eq!(healthy, 1, "expected 1 healthy proxy");
        assert_eq!(unhealthy, 1, "expected 1 unhealthy proxy");
        Ok(())
    }

    /// Cancelling the token makes the background task finish promptly.
    #[tokio::test]
    async fn graceful_shutdown() {
        let config = ProxyConfig {
            health_check_interval: Duration::from_secs(3600),
            ..ProxyConfig::default()
        };
        let storage = Arc::new(MemoryProxyStore::default());
        let health_map = HealthMap::default();

        let token = CancellationToken::new();
        let handle = HealthChecker::new(config, storage, health_map).spawn(token.clone());
        token.cancel();

        let result = tokio::time::timeout(Duration::from_secs(1), handle).await;
        assert!(
            result.is_ok(),
            "task should exit within 1s after cancellation"
        );
    }
}