1use async_trait::async_trait;
2use std::net::{IpAddr, SocketAddr};
3use std::sync::Arc;
4
5use crate::interceptor::HttpBody;
6use hyper::body::Incoming;
7use hyper::{Request, Response};
8use hyper_util::rt::TokioIo;
9use relay_core_api::flow::Flow;
10use relay_core_api::policy::UpstreamProxyConfig;
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13use url::Url;
14
15#[derive(Debug, thiserror::Error)]
17pub enum UpstreamError {
18 #[error("upstream proxy unreachable: {0}")]
19 Unreachable(String),
20 #[error("upstream proxy refused CONNECT: status {status}")]
21 ConnectRefused { status: u16 },
22 #[error("upstream proxy authentication required")]
23 AuthRequired,
24 #[error("upstream TLS error: {0}")]
25 Tls(String),
26 #[error("I/O error: {0}")]
27 Io(#[from] std::io::Error),
28}
29
30#[async_trait]
33pub trait OutboundConnector: Send + Sync {
34 async fn send_request(
35 &self,
36 req: Request<HttpBody>,
37 target_host: &str,
38 target_port: u16,
39 flow: &mut Flow,
40 ) -> Result<Response<Incoming>, UpstreamError>;
41
42 fn upstream_proxy_url(&self) -> Option<&str> {
45 None
46 }
47}
48
49#[derive(Debug, Clone)]
53pub enum BypassRule {
54 Cidr(ipnetwork::IpNetwork),
55 Ip(IpAddr),
56 Glob(glob::Pattern),
57}
58
59impl BypassRule {
60 pub fn parse(raw: &str) -> Result<Self, String> {
67 if let Some(cidr) = raw.strip_prefix("cidr:") {
68 let net: ipnetwork::IpNetwork = cidr
69 .parse()
70 .map_err(|e| format!("invalid CIDR '{}': {}", cidr, e))?;
71 return Ok(Self::Cidr(net));
72 }
73 if let Ok(ip) = raw.parse::<IpAddr>() {
74 return Ok(Self::Ip(ip));
75 }
76 glob::Pattern::new(raw)
77 .map(Self::Glob)
78 .map_err(|e| format!("invalid glob '{}': {}", raw, e))
79 }
80
81 pub fn matches_host(&self, hostname: &str) -> bool {
83 match self {
84 Self::Cidr(net) => hostname.parse::<IpAddr>().is_ok_and(|ip| net.contains(ip)),
85 Self::Ip(ip) => hostname.parse::<IpAddr>().is_ok_and(|parsed| parsed == *ip),
86 Self::Glob(p) => p.matches(hostname),
87 }
88 }
89
90 pub fn matches_ip(&self, addr: &IpAddr) -> bool {
92 match self {
93 Self::Cidr(net) => net.contains(*addr),
94 Self::Ip(ip) => ip == addr,
95 Self::Glob(_) => false,
96 }
97 }
98}
99
100pub fn upstream_proxy_authorization(upstream: &UpstreamProxyConfig) -> Option<String> {
102 upstream.auth.as_ref().map(|a| {
103 let creds = format!(
104 "{}:{}",
105 a.username,
106 secrecy::ExposeSecret::expose_secret(&a.password)
107 );
108 format!("Basic {}", data_encoding::BASE64.encode(creds.as_bytes()))
109 })
110}
111
112pub fn should_bypass(upstream: &UpstreamProxyConfig, host: &str, ip: Option<IpAddr>) -> bool {
114 let rules: Vec<BypassRule> = upstream
115 .bypass_hosts
116 .iter()
117 .filter_map(|r| match BypassRule::parse(r) {
118 Ok(rule) => Some(rule),
119 Err(e) => {
120 tracing::warn!("invalid upstream bypass entry '{}': {}", r, e);
121 None
122 }
123 })
124 .collect();
125
126 if rules.iter().any(|r| r.matches_host(host)) {
128 return true;
129 }
130
131 if let Some(addr) = ip
133 && rules.iter().any(|r| r.matches_ip(&addr))
134 {
135 return true;
136 }
137
138 false
139}
140
141use crate::proxy::http_utils::HttpsClient;
144use hyper_rustls::ConfigBuilderExt;
145
146pub struct DirectConnector {
148 client: Arc<HttpsClient>,
149}
150
151impl DirectConnector {
152 pub fn new(client: Arc<HttpsClient>) -> Self {
153 Self { client }
154 }
155}
156
157#[async_trait]
158impl OutboundConnector for DirectConnector {
159 async fn send_request(
160 &self,
161 req: Request<HttpBody>,
162 _target_host: &str,
163 _target_port: u16,
164 _flow: &mut Flow,
165 ) -> Result<Response<Incoming>, UpstreamError> {
166 self.client
167 .request(req)
168 .await
169 .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))
170 }
171}
172
173pub struct HttpUpstreamConnector {
177 proxy_url: String,
178 proxy_addr: SocketAddr,
179 proxy_authorization: Option<String>,
180 tls_client_config: Arc<rustls::ClientConfig>,
181}
182
183impl HttpUpstreamConnector {
184 pub async fn new(config: &UpstreamProxyConfig) -> Result<Self, UpstreamError> {
185 let url = Url::parse(&config.proxy_url)
186 .map_err(|e| UpstreamError::Unreachable(format!("invalid proxy URL: {}", e)))?;
187 let host = url
188 .host_str()
189 .ok_or_else(|| UpstreamError::Unreachable("proxy URL missing host".into()))?;
190 let port = url.port_or_known_default().unwrap_or(8080);
191
192 let addr = tokio::net::lookup_host((host, port))
193 .await
194 .map_err(|e| UpstreamError::Unreachable(format!("DNS resolution failed: {}", e)))?
195 .next()
196 .ok_or_else(|| UpstreamError::Unreachable("no address resolved".into()))?;
197
198 let proxy_auth = upstream_proxy_authorization(config);
199
200 let tls_config = Arc::new(
201 rustls::ClientConfig::builder()
202 .with_native_roots()
203 .map_err(|e| UpstreamError::Tls(e.to_string()))?
204 .with_no_client_auth(),
205 );
206
207 Ok(Self {
208 proxy_url: config.proxy_url.clone(),
209 proxy_addr: addr,
210 proxy_authorization: proxy_auth,
211 tls_client_config: tls_config,
212 })
213 }
214
215 async fn send_connect_inner<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin>(
217 stream: &mut S,
218 host: &str,
219 port: u16,
220 proxy_auth: Option<&str>,
221 ) -> Result<u16, UpstreamError> {
222 let mut req = format!(
223 "CONNECT {}:{} HTTP/1.1\r\nHost: {}:{}\r\n",
224 host, port, host, port
225 );
226 if let Some(auth) = proxy_auth {
227 req.push_str(&format!("Proxy-Authorization: {}\r\n", auth));
228 }
229 req.push_str("\r\n");
230
231 stream
232 .write_all(req.as_bytes())
233 .await
234 .map_err(UpstreamError::Io)?;
235 stream.flush().await.map_err(UpstreamError::Io)?;
236
237 let mut buf = [0u8; 512];
239 let mut pos = 0;
240 loop {
241 if pos >= buf.len() {
242 return Err(UpstreamError::ConnectRefused { status: 0 });
243 }
244 let n = stream
245 .read(&mut buf[pos..pos + 1])
246 .await
247 .map_err(UpstreamError::Io)?;
248 if n == 0 {
249 return Err(UpstreamError::Unreachable(
250 "connection closed during CONNECT handshake".into(),
251 ));
252 }
253 pos += 1;
254 if pos >= 2 && buf[pos - 2..pos] == [b'\r', b'\n'] {
255 break;
256 }
257 }
258
259 let status_line = String::from_utf8_lossy(&buf[..pos]).trim().to_string();
260 let parts: Vec<&str> = status_line.split_whitespace().collect();
261 if parts.len() < 2 {
262 return Err(UpstreamError::ConnectRefused { status: 0 });
263 }
264 let status: u16 = parts[1]
265 .parse()
266 .map_err(|_| UpstreamError::ConnectRefused { status: 0 })?;
267
268 let mut header_buf = [0u8; 4096];
270 let mut total = 0;
271 loop {
272 let n = stream
273 .read(&mut header_buf[total..])
274 .await
275 .map_err(UpstreamError::Io)?;
276 if n == 0 {
277 return Err(UpstreamError::Unreachable(
278 "connection closed during CONNECT response".into(),
279 ));
280 }
281 total += n;
282 if total >= 4 && header_buf[total - 4..total] == [b'\r', b'\n', b'\r', b'\n'] {
283 break;
284 }
285 if total >= header_buf.len() {
286 break; }
288 }
289
290 Ok(status)
291 }
292
293 async fn tls_to_target(
295 config: Arc<rustls::ClientConfig>,
296 stream: TcpStream,
297 target_host: &str,
298 ) -> Result<tokio_rustls::client::TlsStream<TcpStream>, UpstreamError> {
299 let connector = tokio_rustls::TlsConnector::from(config);
300 let server_name = rustls::pki_types::ServerName::try_from(target_host.to_string())
301 .map_err(|e| UpstreamError::Tls(format!("invalid server name: {}", e)))?;
302 connector
303 .connect(server_name, stream)
304 .await
305 .map_err(|e| UpstreamError::Tls(e.to_string()))
306 }
307}
308
309#[async_trait]
310impl OutboundConnector for HttpUpstreamConnector {
311 async fn send_request(
312 &self,
313 req: Request<HttpBody>,
314 target_host: &str,
315 target_port: u16,
316 _flow: &mut Flow,
317 ) -> Result<Response<Incoming>, UpstreamError> {
318 let uri_scheme = req.uri().scheme_str().unwrap_or("http");
319
320 if uri_scheme == "https" {
321 return self
322 .send_request_connect(req, target_host, target_port)
323 .await;
324 }
325
326 self.send_request_absolute_uri(req, target_host, target_port)
327 .await
328 }
329
330 fn upstream_proxy_url(&self) -> Option<&str> {
331 Some(&self.proxy_url)
332 }
333}
334
335impl HttpUpstreamConnector {
336 async fn send_request_absolute_uri(
338 &self,
339 req: Request<HttpBody>,
340 target_host: &str,
341 target_port: u16,
342 ) -> Result<Response<Incoming>, UpstreamError> {
343 let (parts, body) = req.into_parts();
344 let path = parts
345 .uri
346 .path_and_query()
347 .map(|pq| pq.as_str())
348 .unwrap_or("/");
349 let target_url = format!("http://{}:{}{}", target_host, target_port, path);
350 let mut req_builder = Request::builder()
351 .method(parts.method)
352 .uri(&target_url)
353 .version(parts.version);
354 for (name, value) in &parts.headers {
355 if crate::proxy::http_utils::is_hop_by_hop(name.as_str()) {
356 continue;
357 }
358 req_builder = req_builder.header(name, value);
359 }
360 let req = req_builder
361 .body(body)
362 .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
363
364 let stream = TcpStream::connect(self.proxy_addr)
365 .await
366 .map_err(|e| UpstreamError::Unreachable(format!("TCP connect to upstream: {}", e)))?;
367 let io = TokioIo::new(stream);
368
369 let (mut sender, conn) = hyper::client::conn::http1::handshake(io)
370 .await
371 .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
372 tokio::spawn(async move {
373 if let Err(e) = conn.await {
374 tracing::debug!("upstream http1 connection error: {}", e);
375 }
376 });
377
378 let resp = sender
379 .send_request(req)
380 .await
381 .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
382 Ok(resp)
383 }
384
385 async fn send_request_connect(
387 &self,
388 req: Request<HttpBody>,
389 target_host: &str,
390 target_port: u16,
391 ) -> Result<Response<Incoming>, UpstreamError> {
392 let mut stream = TcpStream::connect(self.proxy_addr)
393 .await
394 .map_err(|e| UpstreamError::Unreachable(format!("TCP connect to upstream: {}", e)))?;
395
396 let status = Self::send_connect_inner(
398 &mut stream,
399 target_host,
400 target_port,
401 self.proxy_authorization.as_deref(),
402 )
403 .await?;
404
405 if !(200..300).contains(&status) {
406 return Err(UpstreamError::ConnectRefused { status });
407 }
408
409 let tls_stream =
411 Self::tls_to_target(self.tls_client_config.clone(), stream, target_host).await?;
412 let io = TokioIo::new(tls_stream);
413
414 let (mut sender, conn) = hyper::client::conn::http1::handshake(io)
416 .await
417 .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
418 tokio::spawn(async move {
419 if let Err(e) = conn.await {
420 tracing::debug!("upstream tunnel http1 connection error: {}", e);
421 }
422 });
423
424 let resp = sender
425 .send_request(req)
426 .await
427 .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
428 Ok(resp)
429 }
430}
431
432pub struct HttpsUpstreamConnector {
436 proxy_url: String,
437 proxy_addr: SocketAddr,
438 proxy_host: String,
439 proxy_authorization: Option<String>,
440 tls_client_config: Arc<rustls::ClientConfig>,
441}
442
443impl HttpsUpstreamConnector {
444 pub async fn new(config: &UpstreamProxyConfig) -> Result<Self, UpstreamError> {
445 let url = Url::parse(&config.proxy_url)
446 .map_err(|e| UpstreamError::Unreachable(format!("invalid proxy URL: {}", e)))?;
447 let host = url
448 .host_str()
449 .ok_or_else(|| UpstreamError::Unreachable("proxy URL missing host".into()))?
450 .to_string();
451 let port = url.port_or_known_default().unwrap_or(443);
452
453 let addr = tokio::net::lookup_host((host.as_str(), port))
454 .await
455 .map_err(|e| UpstreamError::Unreachable(format!("DNS resolution failed: {}", e)))?
456 .next()
457 .ok_or_else(|| UpstreamError::Unreachable("no address resolved".into()))?;
458
459 let proxy_auth = upstream_proxy_authorization(config);
460
461 let tls_config = Arc::new(
462 rustls::ClientConfig::builder()
463 .with_native_roots()
464 .map_err(|e| UpstreamError::Tls(e.to_string()))?
465 .with_no_client_auth(),
466 );
467
468 Ok(Self {
469 proxy_url: config.proxy_url.clone(),
470 proxy_addr: addr,
471 proxy_host: host,
472 proxy_authorization: proxy_auth,
473 tls_client_config: tls_config,
474 })
475 }
476}
477
478#[async_trait]
479impl OutboundConnector for HttpsUpstreamConnector {
480 async fn send_request(
481 &self,
482 req: Request<HttpBody>,
483 target_host: &str,
484 target_port: u16,
485 _flow: &mut Flow,
486 ) -> Result<Response<Incoming>, UpstreamError> {
487 let connector = tokio_rustls::TlsConnector::from(self.tls_client_config.clone());
488 let proxy_server_name = rustls::pki_types::ServerName::try_from(self.proxy_host.clone())
489 .map_err(|e| UpstreamError::Tls(format!("invalid proxy server name: {}", e)))?;
490
491 let stream = TcpStream::connect(self.proxy_addr)
493 .await
494 .map_err(|e| UpstreamError::Unreachable(format!("TCP connect to upstream: {}", e)))?;
495
496 let mut proxy_tls = connector
498 .connect(proxy_server_name.clone(), stream)
499 .await
500 .map_err(|e| UpstreamError::Tls(e.to_string()))?;
501
502 let status = HttpUpstreamConnector::send_connect_inner(
504 &mut proxy_tls,
505 target_host,
506 target_port,
507 self.proxy_authorization.as_deref(),
508 )
509 .await?;
510
511 if !(200..300).contains(&status) {
512 return Err(UpstreamError::ConnectRefused { status });
513 }
514
515 let target_server_name =
517 rustls::pki_types::ServerName::try_from(target_host.to_string())
518 .map_err(|e| UpstreamError::Tls(format!("invalid target server name: {}", e)))?;
519 let target_tls = connector
520 .connect(target_server_name, proxy_tls)
521 .await
522 .map_err(|e| UpstreamError::Tls(e.to_string()))?;
523 let io = TokioIo::new(target_tls);
524
525 let (mut sender, conn) = hyper::client::conn::http1::handshake(io)
527 .await
528 .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
529 tokio::spawn(async move {
530 if let Err(e) = conn.await {
531 tracing::debug!("https-upstream tunnel http1 connection error: {}", e);
532 }
533 });
534
535 let resp = sender
536 .send_request(req)
537 .await
538 .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
539 Ok(resp)
540 }
541
542 fn upstream_proxy_url(&self) -> Option<&str> {
543 Some(&self.proxy_url)
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Parses a bypass entry that the test expects to be valid.
    fn rule(raw: &str) -> BypassRule {
        BypassRule::parse(raw).expect("bypass entry should parse")
    }

    /// Shorthand for an IPv4 address wrapped in `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    #[test]
    fn bypass_rule_parse_cidr() {
        assert!(matches!(rule("cidr:10.0.0.0/8"), BypassRule::Cidr(_)));
    }

    #[test]
    fn bypass_rule_parse_ip_literal() {
        assert!(matches!(rule("127.0.0.1"), BypassRule::Ip(_)));
    }

    #[test]
    fn bypass_rule_parse_glob() {
        assert!(matches!(rule("*.internal.corp"), BypassRule::Glob(_)));
    }

    #[test]
    fn bypass_rule_parse_invalid() {
        let outcome = BypassRule::parse("cidr:not-a-cidr");
        assert!(outcome.is_err());
    }

    #[test]
    fn bypass_rule_cidr_matches_host() {
        let net = rule("cidr:10.0.0.0/8");
        let cases = [
            ("10.1.2.3", true),
            ("192.168.1.1", false),
            ("example.com", false),
        ];
        for (host, expected) in cases {
            assert_eq!(net.matches_host(host), expected, "host {}", host);
        }
    }

    #[test]
    fn bypass_rule_ip_matches_host() {
        let exact = rule("127.0.0.1");
        assert!(exact.matches_host("127.0.0.1"));
        assert!(!exact.matches_host("127.0.0.2"));
    }

    #[test]
    fn bypass_rule_glob_matches_host() {
        let pattern = rule("*.internal.corp");
        let cases = [
            ("svc.internal.corp", true),
            ("foo.bar.internal.corp", true),
            ("external.corp", false),
            ("10.0.0.1", false),
        ];
        for (host, expected) in cases {
            assert_eq!(pattern.matches_host(host), expected, "host {}", host);
        }
    }

    #[test]
    fn bypass_rule_cidr_matches_ip() {
        let net = rule("cidr:10.0.0.0/8");
        assert!(net.matches_ip(&v4(10, 1, 2, 3)));
        assert!(!net.matches_ip(&v4(192, 168, 1, 1)));
    }

    #[test]
    fn bypass_rule_glob_never_matches_ip() {
        assert!(!rule("*.example.com").matches_ip(&v4(10, 0, 0, 1)));
    }

    #[test]
    fn upstream_error_display() {
        let message = UpstreamError::ConnectRefused { status: 403 }.to_string();
        assert!(message.contains("403"));
    }
}