// Source: gatel_core/proxy/upstream.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3
4use hyper_util::client::legacy::Client;
5use hyper_util::client::legacy::connect::HttpConnector;
6use hyper_util::rt::TokioExecutor;
7use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
8use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
9
10use crate::Body;
11use crate::config::ProxyConfig;
12
/// TLS server-certificate verifier that trusts every peer unconditionally.
///
/// It is installed only when the proxy config sets `tls_skip_verify`; with it,
/// upstream TLS connections get no authentication of the server.
#[derive(Debug)]
struct NoVerifier;
17
18impl ServerCertVerifier for NoVerifier {
19    fn verify_server_cert(
20        &self,
21        _end_entity: &CertificateDer<'_>,
22        _intermediates: &[CertificateDer<'_>],
23        _server_name: &ServerName<'_>,
24        _ocsp_response: &[u8],
25        _now: UnixTime,
26    ) -> Result<ServerCertVerified, rustls::Error> {
27        Ok(ServerCertVerified::assertion())
28    }
29
30    fn verify_tls12_signature(
31        &self,
32        _message: &[u8],
33        _cert: &CertificateDer<'_>,
34        _dss: &rustls::DigitallySignedStruct,
35    ) -> Result<HandshakeSignatureValid, rustls::Error> {
36        Ok(HandshakeSignatureValid::assertion())
37    }
38
39    fn verify_tls13_signature(
40        &self,
41        _message: &[u8],
42        _cert: &CertificateDer<'_>,
43        _dss: &rustls::DigitallySignedStruct,
44    ) -> Result<HandshakeSignatureValid, rustls::Error> {
45        Ok(HandshakeSignatureValid::assertion())
46    }
47
48    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
49        rustls::crypto::ring::default_provider()
50            .signature_verification_algorithms
51            .supported_schemes()
52    }
53}
54
/// One upstream server that proxied requests can be routed to.
#[derive(Debug, Clone)]
pub struct Backend {
    /// Backend address, copied verbatim from the upstream configuration.
    pub addr: String,
    /// Relative weight, copied verbatim from the upstream configuration.
    pub weight: u32,
}
61
62/// Pool of upstream backends with a shared HTTP client, health status,
63/// and active-connection counters.
64pub struct UpstreamPool {
65    pub backends: Vec<Backend>,
66    pub client: Client<hyper_rustls::HttpsConnector<HttpConnector>, Body>,
67    /// Per-backend health flag. `true` = healthy, `false` = unhealthy.
68    pub healthy: Vec<AtomicBool>,
69    /// Per-backend active connection count (used by LeastConn).
70    pub active_conns: Vec<AtomicUsize>,
71    /// Optional total connection limit across all backends.
72    pub max_connections: Option<usize>,
73}
74
75impl UpstreamPool {
76    pub fn from_config(config: &ProxyConfig) -> Self {
77        let backends: Vec<Backend> = config
78            .upstreams
79            .iter()
80            .map(|u| Backend {
81                addr: u.addr.clone(),
82                weight: u.weight,
83            })
84            .collect();
85
86        let n = backends.len();
87        let healthy: Vec<AtomicBool> = (0..n).map(|_| AtomicBool::new(true)).collect();
88        let active_conns: Vec<AtomicUsize> = (0..n).map(|_| AtomicUsize::new(0)).collect();
89
90        // Build the HTTPS connector — handles both HTTP and HTTPS upstreams.
91        let connector = if config.tls_skip_verify {
92            // Build a rustls ClientConfig that skips certificate verification.
93            let tls_config = rustls::ClientConfig::builder_with_provider(Arc::new(
94                rustls::crypto::ring::default_provider(),
95            ))
96            .with_safe_default_protocol_versions()
97            .expect("default protocol versions are valid")
98            .dangerous()
99            .with_custom_certificate_verifier(Arc::new(NoVerifier))
100            .with_no_client_auth();
101
102            hyper_rustls::HttpsConnectorBuilder::new()
103                .with_tls_config(tls_config)
104                .https_or_http()
105                .enable_http1()
106                .enable_http2()
107                .build()
108        } else {
109            // Build a standard rustls ClientConfig with an empty root store.
110            // Upstreams on HTTPS should use certificates trusted by the system,
111            // but since we don't have native-roots or webpki-roots features enabled
112            // here, we use an empty store — suitable for internal/self-signed CAs
113            // when skip-verify is not set. For production use with public CAs,
114            // enable the webpki-roots or native-roots feature on hyper-rustls.
115            let root_store = rustls::RootCertStore::empty();
116            let tls_config = rustls::ClientConfig::builder_with_provider(Arc::new(
117                rustls::crypto::ring::default_provider(),
118            ))
119            .with_safe_default_protocol_versions()
120            .expect("default protocol versions are valid")
121            .with_root_certificates(root_store)
122            .with_no_client_auth();
123
124            hyper_rustls::HttpsConnectorBuilder::new()
125                .with_tls_config(tls_config)
126                .https_or_http()
127                .enable_http1()
128                .enable_http2()
129                .build()
130        };
131
132        let mut builder = Client::builder(TokioExecutor::new());
133
134        if config.upstream_http2 {
135            builder.http2_only(true);
136        }
137
138        if let Some(timeout) = config.keepalive_timeout {
139            builder.pool_idle_timeout(timeout);
140        }
141
142        let client = builder.build(connector);
143
144        Self {
145            backends,
146            client,
147            healthy,
148            active_conns,
149            max_connections: config.max_connections,
150        }
151    }
152
153    /// Returns `true` if the backend at `idx` is currently marked healthy.
154    pub fn is_healthy(&self, idx: usize) -> bool {
155        self.healthy
156            .get(idx)
157            .map(|h| h.load(Ordering::Relaxed))
158            .unwrap_or(false)
159    }
160
161    /// Mark a backend as healthy or unhealthy.
162    pub fn set_healthy(&self, idx: usize, val: bool) {
163        if let Some(h) = self.healthy.get(idx) {
164            h.store(val, Ordering::Relaxed);
165        }
166    }
167
168    /// Increment the active connection count for a backend. Returns a guard
169    /// that decrements on drop.
170    pub fn acquire_conn(&self, idx: usize) -> ConnGuard<'_> {
171        if let Some(c) = self.active_conns.get(idx) {
172            c.fetch_add(1, Ordering::Relaxed);
173        }
174        ConnGuard {
175            active_conns: &self.active_conns,
176            idx,
177        }
178    }
179
180    /// Get the current active connection count for a backend.
181    pub fn conn_count(&self, idx: usize) -> usize {
182        self.active_conns
183            .get(idx)
184            .map(|c| c.load(Ordering::Relaxed))
185            .unwrap_or(usize::MAX)
186    }
187
188    /// Total active connections across all backends.
189    pub fn total_active_conns(&self) -> usize {
190        self.active_conns
191            .iter()
192            .map(|c| c.load(Ordering::Relaxed))
193            .sum()
194    }
195
196    /// Number of backends.
197    pub fn len(&self) -> usize {
198        self.backends.len()
199    }
200
201    /// Whether the pool has no backends.
202    pub fn is_empty(&self) -> bool {
203        self.backends.is_empty()
204    }
205}
206
/// Guard returned by `UpstreamPool::acquire_conn` that decrements a backend's
/// active-connection counter when dropped.
///
/// Keep the guard alive for as long as the upstream request is in flight.
/// Dropping it early makes the backend look less loaded than it really is.
#[derive(Debug)]
#[must_use = "the connection count is decremented as soon as the guard is dropped"]
pub struct ConnGuard<'a> {
    active_conns: &'a [AtomicUsize],
    idx: usize,
}

impl Drop for ConnGuard<'_> {
    fn drop(&mut self) {
        // `acquire_conn` only increments in-range indices, so an
        // out-of-range guard has nothing to undo.
        if let Some(counter) = self.active_conns.get(self.idx) {
            counter.fetch_sub(1, Ordering::Relaxed);
        }
    }
}