// relay_core_lib/proxy/outbound.rs

1use async_trait::async_trait;
2use std::net::{IpAddr, SocketAddr};
3use std::sync::Arc;
4
5use crate::interceptor::HttpBody;
6use hyper::body::Incoming;
7use hyper::{Request, Response};
8use hyper_util::rt::TokioIo;
9use relay_core_api::flow::Flow;
10use relay_core_api::policy::UpstreamProxyConfig;
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13use url::Url;
14
15/// Error returned by an outbound connector.
16#[derive(Debug, thiserror::Error)]
17pub enum UpstreamError {
18    #[error("upstream proxy unreachable: {0}")]
19    Unreachable(String),
20    #[error("upstream proxy refused CONNECT: status {status}")]
21    ConnectRefused { status: u16 },
22    #[error("upstream proxy authentication required")]
23    AuthRequired,
24    #[error("upstream TLS error: {0}")]
25    Tls(String),
26    #[error("I/O error: {0}")]
27    Io(#[from] std::io::Error),
28}
29
30/// Outbound connector — sends a RelayCore-processed HTTP request to the target,
31/// optionally through an upstream proxy.
32#[async_trait]
33pub trait OutboundConnector: Send + Sync {
34    async fn send_request(
35        &self,
36        req: Request<HttpBody>,
37        target_host: &str,
38        target_port: u16,
39        flow: &mut Flow,
40    ) -> Result<Response<Incoming>, UpstreamError>;
41
42    /// Returns the upstream proxy URL if this connector routes through one.
43    /// Used for circuit breaker keying and status reporting.
44    fn upstream_proxy_url(&self) -> Option<&str> {
45        None
46    }
47}
48
// ── Bypass rule parsing / matching ──────────────────────

51/// Pre-parsed bypass rule for CIDR / IP / glob hostname matching.
52#[derive(Debug, Clone)]
53pub enum BypassRule {
54    Cidr(ipnetwork::IpNetwork),
55    Ip(IpAddr),
56    Glob(glob::Pattern),
57}
58
59impl BypassRule {
60    /// Parse a raw bypass string.
61    ///
62    /// Supported formats:
63    /// - `cidr:10.0.0.0/8`  → CIDR network
64    /// - `127.0.0.1`        → literal IP
65    /// - `*.internal.corp`  → glob hostname
66    pub fn parse(raw: &str) -> Result<Self, String> {
67        if let Some(cidr) = raw.strip_prefix("cidr:") {
68            let net: ipnetwork::IpNetwork = cidr
69                .parse()
70                .map_err(|e| format!("invalid CIDR '{}': {}", cidr, e))?;
71            return Ok(Self::Cidr(net));
72        }
73        if let Ok(ip) = raw.parse::<IpAddr>() {
74            return Ok(Self::Ip(ip));
75        }
76        glob::Pattern::new(raw)
77            .map(Self::Glob)
78            .map_err(|e| format!("invalid glob '{}': {}", raw, e))
79    }
80
81    /// Returns true when `hostname` (a host name or IP string) should bypass upstream.
82    pub fn matches_host(&self, hostname: &str) -> bool {
83        match self {
84            Self::Cidr(net) => hostname.parse::<IpAddr>().is_ok_and(|ip| net.contains(ip)),
85            Self::Ip(ip) => hostname.parse::<IpAddr>().is_ok_and(|parsed| parsed == *ip),
86            Self::Glob(p) => p.matches(hostname),
87        }
88    }
89
90    /// Returns true when `addr` (a raw IP) should bypass upstream (transparent mode).
91    pub fn matches_ip(&self, addr: &IpAddr) -> bool {
92        match self {
93            Self::Cidr(net) => net.contains(*addr),
94            Self::Ip(ip) => ip == addr,
95            Self::Glob(_) => false,
96        }
97    }
98}
99
100/// Computes the `Proxy-Authorization` header value from upstream auth config.
101pub fn upstream_proxy_authorization(upstream: &UpstreamProxyConfig) -> Option<String> {
102    upstream.auth.as_ref().map(|a| {
103        let creds = format!(
104            "{}:{}",
105            a.username,
106            secrecy::ExposeSecret::expose_secret(&a.password)
107        );
108        format!("Basic {}", data_encoding::BASE64.encode(creds.as_bytes()))
109    })
110}
111
112/// Determines whether a given host should bypass the upstream proxy.
113pub fn should_bypass(upstream: &UpstreamProxyConfig, host: &str, ip: Option<IpAddr>) -> bool {
114    let rules: Vec<BypassRule> = upstream
115        .bypass_hosts
116        .iter()
117        .filter_map(|r| match BypassRule::parse(r) {
118            Ok(rule) => Some(rule),
119            Err(e) => {
120                tracing::warn!("invalid upstream bypass entry '{}': {}", r, e);
121                None
122            }
123        })
124        .collect();
125
126    // Match by hostname first
127    if rules.iter().any(|r| r.matches_host(host)) {
128        return true;
129    }
130
131    // Transparent mode: also match by raw IP
132    if let Some(addr) = ip
133        && rules.iter().any(|r| r.matches_ip(&addr))
134    {
135        return true;
136    }
137
138    false
139}
140
// ── Direct Connector ────────────────────────────────────

143use crate::proxy::http_utils::HttpsClient;
144use hyper_rustls::ConfigBuilderExt;
145
146/// Direct outbound connector (no upstream proxy).
147pub struct DirectConnector {
148    client: Arc<HttpsClient>,
149}
150
151impl DirectConnector {
152    pub fn new(client: Arc<HttpsClient>) -> Self {
153        Self { client }
154    }
155}
156
157#[async_trait]
158impl OutboundConnector for DirectConnector {
159    async fn send_request(
160        &self,
161        req: Request<HttpBody>,
162        _target_host: &str,
163        _target_port: u16,
164        _flow: &mut Flow,
165    ) -> Result<Response<Incoming>, UpstreamError> {
166        self.client
167            .request(req)
168            .await
169            .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))
170    }
171}
172
// ── HTTP Upstream Connector ─────────────────────────────

175/// Connects through an HTTP upstream proxy (no TLS to the proxy itself).
176pub struct HttpUpstreamConnector {
177    proxy_url: String,
178    proxy_addr: SocketAddr,
179    proxy_authorization: Option<String>,
180    tls_client_config: Arc<rustls::ClientConfig>,
181}
182
183impl HttpUpstreamConnector {
184    pub async fn new(config: &UpstreamProxyConfig) -> Result<Self, UpstreamError> {
185        let url = Url::parse(&config.proxy_url)
186            .map_err(|e| UpstreamError::Unreachable(format!("invalid proxy URL: {}", e)))?;
187        let host = url
188            .host_str()
189            .ok_or_else(|| UpstreamError::Unreachable("proxy URL missing host".into()))?;
190        let port = url.port_or_known_default().unwrap_or(8080);
191
192        let addr = tokio::net::lookup_host((host, port))
193            .await
194            .map_err(|e| UpstreamError::Unreachable(format!("DNS resolution failed: {}", e)))?
195            .next()
196            .ok_or_else(|| UpstreamError::Unreachable("no address resolved".into()))?;
197
198        let proxy_auth = upstream_proxy_authorization(config);
199
200        let tls_config = Arc::new(
201            rustls::ClientConfig::builder()
202                .with_native_roots()
203                .map_err(|e| UpstreamError::Tls(e.to_string()))?
204                .with_no_client_auth(),
205        );
206
207        Ok(Self {
208            proxy_url: config.proxy_url.clone(),
209            proxy_addr: addr,
210            proxy_authorization: proxy_auth,
211            tls_client_config: tls_config,
212        })
213    }
214
215    /// Send a CONNECT request through an async stream and return the HTTP response status.
216    async fn send_connect_inner<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin>(
217        stream: &mut S,
218        host: &str,
219        port: u16,
220        proxy_auth: Option<&str>,
221    ) -> Result<u16, UpstreamError> {
222        let mut req = format!(
223            "CONNECT {}:{} HTTP/1.1\r\nHost: {}:{}\r\n",
224            host, port, host, port
225        );
226        if let Some(auth) = proxy_auth {
227            req.push_str(&format!("Proxy-Authorization: {}\r\n", auth));
228        }
229        req.push_str("\r\n");
230
231        stream
232            .write_all(req.as_bytes())
233            .await
234            .map_err(UpstreamError::Io)?;
235        stream.flush().await.map_err(UpstreamError::Io)?;
236
237        // Read response line (buffered, up to 512 bytes)
238        let mut buf = [0u8; 512];
239        let mut pos = 0;
240        loop {
241            if pos >= buf.len() {
242                return Err(UpstreamError::ConnectRefused { status: 0 });
243            }
244            let n = stream
245                .read(&mut buf[pos..pos + 1])
246                .await
247                .map_err(UpstreamError::Io)?;
248            if n == 0 {
249                return Err(UpstreamError::Unreachable(
250                    "connection closed during CONNECT handshake".into(),
251                ));
252            }
253            pos += 1;
254            if pos >= 2 && buf[pos - 2..pos] == [b'\r', b'\n'] {
255                break;
256            }
257        }
258
259        let status_line = String::from_utf8_lossy(&buf[..pos]).trim().to_string();
260        let parts: Vec<&str> = status_line.split_whitespace().collect();
261        if parts.len() < 2 {
262            return Err(UpstreamError::ConnectRefused { status: 0 });
263        }
264        let status: u16 = parts[1]
265            .parse()
266            .map_err(|_| UpstreamError::ConnectRefused { status: 0 })?;
267
268        // Consume remaining headers until \r\n\r\n
269        let mut header_buf = [0u8; 4096];
270        let mut total = 0;
271        loop {
272            let n = stream
273                .read(&mut header_buf[total..])
274                .await
275                .map_err(UpstreamError::Io)?;
276            if n == 0 {
277                return Err(UpstreamError::Unreachable(
278                    "connection closed during CONNECT response".into(),
279                ));
280            }
281            total += n;
282            if total >= 4 && header_buf[total - 4..total] == [b'\r', b'\n', b'\r', b'\n'] {
283                break;
284            }
285            if total >= header_buf.len() {
286                break; // headers too large, assume end
287            }
288        }
289
290        Ok(status)
291    }
292
293    /// Establish TLS to a target through the given stream.
294    async fn tls_to_target(
295        config: Arc<rustls::ClientConfig>,
296        stream: TcpStream,
297        target_host: &str,
298    ) -> Result<tokio_rustls::client::TlsStream<TcpStream>, UpstreamError> {
299        let connector = tokio_rustls::TlsConnector::from(config);
300        let server_name = rustls::pki_types::ServerName::try_from(target_host.to_string())
301            .map_err(|e| UpstreamError::Tls(format!("invalid server name: {}", e)))?;
302        connector
303            .connect(server_name, stream)
304            .await
305            .map_err(|e| UpstreamError::Tls(e.to_string()))
306    }
307}
308
309#[async_trait]
310impl OutboundConnector for HttpUpstreamConnector {
311    async fn send_request(
312        &self,
313        req: Request<HttpBody>,
314        target_host: &str,
315        target_port: u16,
316        _flow: &mut Flow,
317    ) -> Result<Response<Incoming>, UpstreamError> {
318        let uri_scheme = req.uri().scheme_str().unwrap_or("http");
319
320        if uri_scheme == "https" {
321            return self
322                .send_request_connect(req, target_host, target_port)
323                .await;
324        }
325
326        self.send_request_absolute_uri(req, target_host, target_port)
327            .await
328    }
329
330    fn upstream_proxy_url(&self) -> Option<&str> {
331        Some(&self.proxy_url)
332    }
333}
334
335impl HttpUpstreamConnector {
336    /// Send an HTTP request with absolute URI through the upstream proxy (HTTP target).
337    async fn send_request_absolute_uri(
338        &self,
339        req: Request<HttpBody>,
340        target_host: &str,
341        target_port: u16,
342    ) -> Result<Response<Incoming>, UpstreamError> {
343        let (parts, body) = req.into_parts();
344        let path = parts
345            .uri
346            .path_and_query()
347            .map(|pq| pq.as_str())
348            .unwrap_or("/");
349        let target_url = format!("http://{}:{}{}", target_host, target_port, path);
350        let mut req_builder = Request::builder()
351            .method(parts.method)
352            .uri(&target_url)
353            .version(parts.version);
354        for (name, value) in &parts.headers {
355            if crate::proxy::http_utils::is_hop_by_hop(name.as_str()) {
356                continue;
357            }
358            req_builder = req_builder.header(name, value);
359        }
360        let req = req_builder
361            .body(body)
362            .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
363
364        let stream = TcpStream::connect(self.proxy_addr)
365            .await
366            .map_err(|e| UpstreamError::Unreachable(format!("TCP connect to upstream: {}", e)))?;
367        let io = TokioIo::new(stream);
368
369        let (mut sender, conn) = hyper::client::conn::http1::handshake(io)
370            .await
371            .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
372        tokio::spawn(async move {
373            if let Err(e) = conn.await {
374                tracing::debug!("upstream http1 connection error: {}", e);
375            }
376        });
377
378        let resp = sender
379            .send_request(req)
380            .await
381            .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
382        Ok(resp)
383    }
384
385    /// Send an HTTPS request through the upstream proxy via CONNECT tunnel.
386    async fn send_request_connect(
387        &self,
388        req: Request<HttpBody>,
389        target_host: &str,
390        target_port: u16,
391    ) -> Result<Response<Incoming>, UpstreamError> {
392        let mut stream = TcpStream::connect(self.proxy_addr)
393            .await
394            .map_err(|e| UpstreamError::Unreachable(format!("TCP connect to upstream: {}", e)))?;
395
396        // 1. Send CONNECT to upstream proxy
397        let status = Self::send_connect_inner(
398            &mut stream,
399            target_host,
400            target_port,
401            self.proxy_authorization.as_deref(),
402        )
403        .await?;
404
405        if !(200..300).contains(&status) {
406            return Err(UpstreamError::ConnectRefused { status });
407        }
408
409        // 2. Establish TLS to target through the tunnel
410        let tls_stream =
411            Self::tls_to_target(self.tls_client_config.clone(), stream, target_host).await?;
412        let io = TokioIo::new(tls_stream);
413
414        // 3. Send the actual HTTP request over TLS
415        let (mut sender, conn) = hyper::client::conn::http1::handshake(io)
416            .await
417            .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
418        tokio::spawn(async move {
419            if let Err(e) = conn.await {
420                tracing::debug!("upstream tunnel http1 connection error: {}", e);
421            }
422        });
423
424        let resp = sender
425            .send_request(req)
426            .await
427            .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
428        Ok(resp)
429    }
430}
431
// ── HTTPS Upstream Connector ────────────────────────────

434/// Connects through an HTTPS upstream proxy (TLS to the proxy first).
435pub struct HttpsUpstreamConnector {
436    proxy_url: String,
437    proxy_addr: SocketAddr,
438    proxy_host: String,
439    proxy_authorization: Option<String>,
440    tls_client_config: Arc<rustls::ClientConfig>,
441}
442
443impl HttpsUpstreamConnector {
444    pub async fn new(config: &UpstreamProxyConfig) -> Result<Self, UpstreamError> {
445        let url = Url::parse(&config.proxy_url)
446            .map_err(|e| UpstreamError::Unreachable(format!("invalid proxy URL: {}", e)))?;
447        let host = url
448            .host_str()
449            .ok_or_else(|| UpstreamError::Unreachable("proxy URL missing host".into()))?
450            .to_string();
451        let port = url.port_or_known_default().unwrap_or(443);
452
453        let addr = tokio::net::lookup_host((host.as_str(), port))
454            .await
455            .map_err(|e| UpstreamError::Unreachable(format!("DNS resolution failed: {}", e)))?
456            .next()
457            .ok_or_else(|| UpstreamError::Unreachable("no address resolved".into()))?;
458
459        let proxy_auth = upstream_proxy_authorization(config);
460
461        let tls_config = Arc::new(
462            rustls::ClientConfig::builder()
463                .with_native_roots()
464                .map_err(|e| UpstreamError::Tls(e.to_string()))?
465                .with_no_client_auth(),
466        );
467
468        Ok(Self {
469            proxy_url: config.proxy_url.clone(),
470            proxy_addr: addr,
471            proxy_host: host,
472            proxy_authorization: proxy_auth,
473            tls_client_config: tls_config,
474        })
475    }
476}
477
478#[async_trait]
479impl OutboundConnector for HttpsUpstreamConnector {
480    async fn send_request(
481        &self,
482        req: Request<HttpBody>,
483        target_host: &str,
484        target_port: u16,
485        _flow: &mut Flow,
486    ) -> Result<Response<Incoming>, UpstreamError> {
487        let connector = tokio_rustls::TlsConnector::from(self.tls_client_config.clone());
488        let proxy_server_name = rustls::pki_types::ServerName::try_from(self.proxy_host.clone())
489            .map_err(|e| UpstreamError::Tls(format!("invalid proxy server name: {}", e)))?;
490
491        // 1. TCP connect to upstream proxy
492        let stream = TcpStream::connect(self.proxy_addr)
493            .await
494            .map_err(|e| UpstreamError::Unreachable(format!("TCP connect to upstream: {}", e)))?;
495
496        // 2. TLS handshake to upstream proxy
497        let mut proxy_tls = connector
498            .connect(proxy_server_name.clone(), stream)
499            .await
500            .map_err(|e| UpstreamError::Tls(e.to_string()))?;
501
502        // 3. Send CONNECT inside the TLS tunnel
503        let status = HttpUpstreamConnector::send_connect_inner(
504            &mut proxy_tls,
505            target_host,
506            target_port,
507            self.proxy_authorization.as_deref(),
508        )
509        .await?;
510
511        if !(200..300).contains(&status) {
512            return Err(UpstreamError::ConnectRefused { status });
513        }
514
515        // 4. TLS to target through the proxy's TLS tunnel (TLS-in-TLS)
516        let target_server_name =
517            rustls::pki_types::ServerName::try_from(target_host.to_string())
518                .map_err(|e| UpstreamError::Tls(format!("invalid target server name: {}", e)))?;
519        let target_tls = connector
520            .connect(target_server_name, proxy_tls)
521            .await
522            .map_err(|e| UpstreamError::Tls(e.to_string()))?;
523        let io = TokioIo::new(target_tls);
524
525        // 5. Send the actual HTTP request
526        let (mut sender, conn) = hyper::client::conn::http1::handshake(io)
527            .await
528            .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
529        tokio::spawn(async move {
530            if let Err(e) = conn.await {
531                tracing::debug!("https-upstream tunnel http1 connection error: {}", e);
532            }
533        });
534
535        let resp = sender
536            .send_request(req)
537            .await
538            .map_err(|e| UpstreamError::Io(std::io::Error::other(e)))?;
539        Ok(resp)
540    }
541
542    fn upstream_proxy_url(&self) -> Option<&str> {
543        Some(&self.proxy_url)
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Shorthand for an IPv4 address wrapped in `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    #[test]
    fn bypass_rule_parse_cidr() {
        let parsed = BypassRule::parse("cidr:10.0.0.0/8").expect("valid CIDR rule");
        assert!(matches!(parsed, BypassRule::Cidr(_)));
    }

    #[test]
    fn bypass_rule_parse_ip_literal() {
        let parsed = BypassRule::parse("127.0.0.1").expect("valid IP rule");
        assert!(matches!(parsed, BypassRule::Ip(_)));
    }

    #[test]
    fn bypass_rule_parse_glob() {
        let parsed = BypassRule::parse("*.internal.corp").expect("valid glob rule");
        assert!(matches!(parsed, BypassRule::Glob(_)));
    }

    #[test]
    fn bypass_rule_parse_invalid() {
        let outcome = BypassRule::parse("cidr:not-a-cidr");
        assert!(outcome.is_err());
    }

    #[test]
    fn bypass_rule_cidr_matches_host() {
        let rule = BypassRule::parse("cidr:10.0.0.0/8").expect("valid CIDR rule");
        assert!(rule.matches_host("10.1.2.3"));
        for miss in ["192.168.1.1", "example.com"] {
            assert!(!rule.matches_host(miss), "{miss} must not match 10.0.0.0/8");
        }
    }

    #[test]
    fn bypass_rule_ip_matches_host() {
        let rule = BypassRule::parse("127.0.0.1").expect("valid IP rule");
        assert!(rule.matches_host("127.0.0.1"));
        assert!(!rule.matches_host("127.0.0.2"));
    }

    #[test]
    fn bypass_rule_glob_matches_host() {
        let rule = BypassRule::parse("*.internal.corp").expect("valid glob rule");
        for hit in ["svc.internal.corp", "foo.bar.internal.corp"] {
            assert!(rule.matches_host(hit), "{hit} should match");
        }
        for miss in ["external.corp", "10.0.0.1"] {
            assert!(!rule.matches_host(miss), "{miss} should not match");
        }
    }

    #[test]
    fn bypass_rule_cidr_matches_ip() {
        let rule = BypassRule::parse("cidr:10.0.0.0/8").expect("valid CIDR rule");
        assert!(rule.matches_ip(&v4(10, 1, 2, 3)));
        assert!(!rule.matches_ip(&v4(192, 168, 1, 1)));
    }

    #[test]
    fn bypass_rule_glob_never_matches_ip() {
        let rule = BypassRule::parse("*.example.com").expect("valid glob rule");
        assert!(!rule.matches_ip(&v4(10, 0, 0, 1)));
    }

    #[test]
    fn upstream_error_display() {
        let rendered = UpstreamError::ConnectRefused { status: 403 }.to_string();
        assert!(rendered.contains("403"));
    }
}