gatel_core/proxy/
upstream.rs1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3
4use hyper_util::client::legacy::Client;
5use hyper_util::client::legacy::connect::HttpConnector;
6use hyper_util::rt::TokioExecutor;
7use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
8use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
9
10use crate::Body;
11use crate::config::ProxyConfig;
12
13#[derive(Debug)]
16struct NoVerifier;
17
18impl ServerCertVerifier for NoVerifier {
19 fn verify_server_cert(
20 &self,
21 _end_entity: &CertificateDer<'_>,
22 _intermediates: &[CertificateDer<'_>],
23 _server_name: &ServerName<'_>,
24 _ocsp_response: &[u8],
25 _now: UnixTime,
26 ) -> Result<ServerCertVerified, rustls::Error> {
27 Ok(ServerCertVerified::assertion())
28 }
29
30 fn verify_tls12_signature(
31 &self,
32 _message: &[u8],
33 _cert: &CertificateDer<'_>,
34 _dss: &rustls::DigitallySignedStruct,
35 ) -> Result<HandshakeSignatureValid, rustls::Error> {
36 Ok(HandshakeSignatureValid::assertion())
37 }
38
39 fn verify_tls13_signature(
40 &self,
41 _message: &[u8],
42 _cert: &CertificateDer<'_>,
43 _dss: &rustls::DigitallySignedStruct,
44 ) -> Result<HandshakeSignatureValid, rustls::Error> {
45 Ok(HandshakeSignatureValid::assertion())
46 }
47
48 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
49 rustls::crypto::ring::default_provider()
50 .signature_verification_algorithms
51 .supported_schemes()
52 }
53}
54
/// A single upstream server that requests can be forwarded to.
///
/// `PartialEq`/`Eq` are derived so backends can be compared, for example when
/// diffing configurations or asserting in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Backend {
    /// Upstream address, copied verbatim from the upstream's config entry.
    pub addr: String,
    /// Relative weight, copied verbatim from the upstream's config entry.
    pub weight: u32,
}
61
62pub struct UpstreamPool {
65 pub backends: Vec<Backend>,
66 pub client: Client<hyper_rustls::HttpsConnector<HttpConnector>, Body>,
67 pub healthy: Vec<AtomicBool>,
69 pub active_conns: Vec<AtomicUsize>,
71 pub max_connections: Option<usize>,
73}
74
75impl UpstreamPool {
76 pub fn from_config(config: &ProxyConfig) -> Self {
77 let backends: Vec<Backend> = config
78 .upstreams
79 .iter()
80 .map(|u| Backend {
81 addr: u.addr.clone(),
82 weight: u.weight,
83 })
84 .collect();
85
86 let n = backends.len();
87 let healthy: Vec<AtomicBool> = (0..n).map(|_| AtomicBool::new(true)).collect();
88 let active_conns: Vec<AtomicUsize> = (0..n).map(|_| AtomicUsize::new(0)).collect();
89
90 let connector = if config.tls_skip_verify {
92 let tls_config = rustls::ClientConfig::builder_with_provider(Arc::new(
94 rustls::crypto::ring::default_provider(),
95 ))
96 .with_safe_default_protocol_versions()
97 .expect("default protocol versions are valid")
98 .dangerous()
99 .with_custom_certificate_verifier(Arc::new(NoVerifier))
100 .with_no_client_auth();
101
102 hyper_rustls::HttpsConnectorBuilder::new()
103 .with_tls_config(tls_config)
104 .https_or_http()
105 .enable_http1()
106 .enable_http2()
107 .build()
108 } else {
109 let root_store = rustls::RootCertStore::empty();
116 let tls_config = rustls::ClientConfig::builder_with_provider(Arc::new(
117 rustls::crypto::ring::default_provider(),
118 ))
119 .with_safe_default_protocol_versions()
120 .expect("default protocol versions are valid")
121 .with_root_certificates(root_store)
122 .with_no_client_auth();
123
124 hyper_rustls::HttpsConnectorBuilder::new()
125 .with_tls_config(tls_config)
126 .https_or_http()
127 .enable_http1()
128 .enable_http2()
129 .build()
130 };
131
132 let mut builder = Client::builder(TokioExecutor::new());
133
134 if config.upstream_http2 {
135 builder.http2_only(true);
136 }
137
138 if let Some(timeout) = config.keepalive_timeout {
139 builder.pool_idle_timeout(timeout);
140 }
141
142 let client = builder.build(connector);
143
144 Self {
145 backends,
146 client,
147 healthy,
148 active_conns,
149 max_connections: config.max_connections,
150 }
151 }
152
153 pub fn is_healthy(&self, idx: usize) -> bool {
155 self.healthy
156 .get(idx)
157 .map(|h| h.load(Ordering::Relaxed))
158 .unwrap_or(false)
159 }
160
161 pub fn set_healthy(&self, idx: usize, val: bool) {
163 if let Some(h) = self.healthy.get(idx) {
164 h.store(val, Ordering::Relaxed);
165 }
166 }
167
168 pub fn acquire_conn(&self, idx: usize) -> ConnGuard<'_> {
171 if let Some(c) = self.active_conns.get(idx) {
172 c.fetch_add(1, Ordering::Relaxed);
173 }
174 ConnGuard {
175 active_conns: &self.active_conns,
176 idx,
177 }
178 }
179
180 pub fn conn_count(&self, idx: usize) -> usize {
182 self.active_conns
183 .get(idx)
184 .map(|c| c.load(Ordering::Relaxed))
185 .unwrap_or(usize::MAX)
186 }
187
188 pub fn total_active_conns(&self) -> usize {
190 self.active_conns
191 .iter()
192 .map(|c| c.load(Ordering::Relaxed))
193 .sum()
194 }
195
196 pub fn len(&self) -> usize {
198 self.backends.len()
199 }
200
201 pub fn is_empty(&self) -> bool {
203 self.backends.is_empty()
204 }
205}
206
/// Guard for one in-flight connection to a backend.
///
/// `UpstreamPool::acquire_conn` increments the backend's counter and returns
/// this guard. Dropping the guard performs the matching decrement, so it must
/// stay alive for as long as the connection is in use. If `idx` has no
/// counter, dropping the guard is a no-op.
#[must_use = "dropping the guard immediately releases the connection count it tracks"]
pub struct ConnGuard<'a> {
    /// Per-backend counters, borrowed from the owning pool.
    active_conns: &'a [AtomicUsize],
    /// Index of the backend whose counter this guard decrements on drop.
    idx: usize,
}

impl Drop for ConnGuard<'_> {
    /// Decrements the tracked backend's connection counter, if it exists.
    fn drop(&mut self) {
        if let Some(counter) = self.active_conns.get(self.idx) {
            counter.fetch_sub(1, Ordering::Relaxed);
        }
    }
}