Skip to main content

nono_proxy/
forward.rs

1//! Shared L7 upstream-forwarding pipeline.
2//!
3//! Used by both the reverse-proxy path ([`crate::reverse`]) and the
4//! TLS-intercept CONNECT path ([`crate::tls_intercept`]). The two callers
5//! differ in how they parse the inbound request, look up the route, and
6//! transform/inject credentials, but converge on the same wire-level
7//! upstream operation:
8//!
9//! 1. Establish an upstream byte stream — direct TCP (with optional TLS)
10//!    or chained CONNECT through an enterprise proxy (then TLS).
11//! 2. Write the pre-built HTTP/1.1 request bytes + body.
12//! 3. Stream the response back into the inbound sink.
13//! 4. Emit one L7 audit event with the response status.
14//!
15//! ## Why pre-built request bytes
16//!
17//! Each caller has its own rules for header filtering, credential
18//! injection, and path transformation. Asking this module to handle that
19//! would mean smuggling all of that policy through a parameter struct.
20//! Instead, the caller hands in finished bytes: a clean separation
21//! between "build the request" and "speak it on the wire".
22
23use crate::audit;
24use crate::error::{ProxyError, Result};
25use std::net::SocketAddr;
26use std::time::Duration;
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
28use tokio::net::TcpStream;
29use tokio_rustls::TlsConnector;
30use tracing::debug;
31
/// Upper bound for establishing an upstream connection: 30 seconds, the
/// value the reverse proxy has historically used.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
34
/// Transport scheme used to talk to the upstream.
///
/// Invariant enforced by callers, not by this module: `Http` may only be
/// selected for loopback targets (`reverse.rs` checks this via
/// `validate_http_upstream_target`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpstreamScheme {
    /// Plain TCP with no TLS layer; loopback targets only.
    Http,
    /// TLS over TCP, negotiated with `UpstreamSpec::tls_connector`.
    Https,
}
43
/// Strategy for opening the upstream byte stream.
pub enum UpstreamStrategy<'a> {
    /// Dial one of `resolved_addrs` directly. The addresses must already have
    /// passed the host filter; dialing them (rather than re-resolving the
    /// hostname) is what keeps this path safe against DNS rebinding.
    Direct { resolved_addrs: &'a [SocketAddr] },
    /// Tunnel through an enterprise proxy with HTTP CONNECT.
    ///
    /// `proxy_addr` is the corporate proxy's `host:port`. `proxy_auth_header`
    /// is sent verbatim as the `Proxy-Authorization` value (for example
    /// `"Basic …"`). Use `None` when the proxy needs no authentication.
    ExternalProxy {
        proxy_addr: &'a str,
        proxy_auth_header: Option<&'a str>,
    },
}
58
59/// Description of the upstream the caller wants to reach.
60pub struct UpstreamSpec<'a> {
61    pub scheme: UpstreamScheme,
62    pub host: &'a str,
63    pub port: u16,
64    pub strategy: UpstreamStrategy<'a>,
65    /// TLS connector to use for an `Https` scheme. Reverse-proxy callers
66    /// pass either the route's per-route connector (custom CA / mTLS) or
67    /// the shared default; intercept callers do the same.
68    pub tls_connector: &'a TlsConnector,
69}
70
71/// Audit-emission context.
72pub struct AuditCtx<'a> {
73    pub log: Option<&'a audit::SharedAuditLog>,
74    pub mode: audit::ProxyMode,
75    pub event_ctx: audit::EventContext<'a>,
76    /// Logical target string (route prefix for reverse, hostname for intercept).
77    pub target: &'a str,
78    pub method: &'a str,
79    /// Path as it should appear in the audit log (the *inbound* path before
80    /// any rewriting — e.g. `/v1/chat/completions`, not the upstream URL).
81    pub path: &'a str,
82}
83
84/// Connect to the upstream, write `request_bytes + body`, stream the
85/// response back into `inbound`, and emit the L7 audit event.
86///
87/// Returns the response status code (or 502 if the upstream sent something
88/// unparseable).
89pub async fn forward_request<S>(
90    inbound: &mut S,
91    request_bytes: &[u8],
92    body: &[u8],
93    upstream: UpstreamSpec<'_>,
94    audit: AuditCtx<'_>,
95) -> Result<u16>
96where
97    S: AsyncRead + AsyncWrite + Unpin,
98{
99    let status = match upstream.scheme {
100        UpstreamScheme::Https => {
101            let mut tls_stream = open_https_upstream(&upstream).await?;
102            write_request(&mut tls_stream, request_bytes, body).await?;
103            stream_response(&mut tls_stream, inbound).await?
104        }
105        UpstreamScheme::Http => {
106            let mut tcp_stream = open_http_upstream(&upstream).await?;
107            write_request(&mut tcp_stream, request_bytes, body).await?;
108            stream_response(&mut tcp_stream, inbound).await?
109        }
110    };
111
112    audit::log_l7_request(
113        audit.log,
114        audit.mode,
115        &audit.event_ctx,
116        audit.target,
117        audit.method,
118        audit.path,
119        status,
120    );
121    Ok(status)
122}
123
124/// Open an upstream HTTPS connection (Direct TLS or ExternalProxy + TLS).
125async fn open_https_upstream(
126    upstream: &UpstreamSpec<'_>,
127) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
128    let tcp = open_tcp_upstream(upstream).await?;
129    let server_name =
130        rustls::pki_types::ServerName::try_from(upstream.host.to_string()).map_err(|_| {
131            ProxyError::UpstreamConnect {
132                host: upstream.host.to_string(),
133                reason: "invalid server name for TLS".to_string(),
134            }
135        })?;
136    upstream
137        .tls_connector
138        .connect(server_name, tcp)
139        .await
140        .map_err(|e| ProxyError::UpstreamConnect {
141            host: upstream.host.to_string(),
142            reason: format!("TLS handshake failed: {}", e),
143        })
144}
145
146/// Open an upstream HTTP (plain) connection. Caller has already validated
147/// that this is a loopback target.
148async fn open_http_upstream(upstream: &UpstreamSpec<'_>) -> Result<TcpStream> {
149    open_tcp_upstream(upstream).await
150}
151
152/// Establish the TCP layer of the upstream connection (without TLS).
153async fn open_tcp_upstream(upstream: &UpstreamSpec<'_>) -> Result<TcpStream> {
154    match upstream.strategy {
155        UpstreamStrategy::Direct { resolved_addrs } => {
156            if resolved_addrs.is_empty() {
157                let addr = format!("{}:{}", upstream.host, upstream.port);
158                match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr))
159                    .await
160                {
161                    Ok(Ok(s)) => Ok(s),
162                    Ok(Err(e)) => Err(ProxyError::UpstreamConnect {
163                        host: upstream.host.to_string(),
164                        reason: e.to_string(),
165                    }),
166                    Err(_) => Err(ProxyError::UpstreamConnect {
167                        host: upstream.host.to_string(),
168                        reason: "connection timed out".to_string(),
169                    }),
170                }
171            } else {
172                connect_to_resolved(resolved_addrs, upstream.host).await
173            }
174        }
175        UpstreamStrategy::ExternalProxy {
176            proxy_addr,
177            proxy_auth_header,
178        } => crate::external::connect_via_proxy(
179            proxy_addr,
180            upstream.host,
181            upstream.port,
182            proxy_auth_header,
183        )
184        .await
185        .map_err(|e| match e {
186            ProxyError::ExternalProxy(reason) => ProxyError::UpstreamConnect {
187                host: upstream.host.to_string(),
188                reason,
189            },
190            other => other,
191        }),
192    }
193}
194
195/// Connect to one of the pre-resolved socket addresses with timeout.
196///
197/// Tries each address in order until one succeeds. Connecting to the IP
198/// directly (not re-resolving the hostname) prevents DNS rebinding TOCTOU.
199async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
200    let mut last_err = None;
201    for addr in addrs {
202        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
203            Ok(Ok(stream)) => return Ok(stream),
204            Ok(Err(e)) => {
205                debug!("Connect to {} failed: {}", addr, e);
206                last_err = Some(e.to_string());
207            }
208            Err(_) => {
209                debug!("Connect to {} timed out", addr);
210                last_err = Some("connection timed out".to_string());
211            }
212        }
213    }
214    Err(ProxyError::UpstreamConnect {
215        host: host.to_string(),
216        reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
217    })
218}
219
220async fn write_request<S>(stream: &mut S, request: &[u8], body: &[u8]) -> Result<()>
221where
222    S: AsyncWrite + Unpin,
223{
224    stream.write_all(request).await?;
225    if !body.is_empty() {
226        stream.write_all(body).await?;
227    }
228    stream.flush().await?;
229    Ok(())
230}
231
232/// Stream the upstream response back to the inbound sink.
233///
234/// Returns the HTTP status code parsed from the first chunk. Streams
235/// chunked / SSE / HTTP-streaming bodies transparently because we never
236/// buffer the body — each upstream read is mirrored to the inbound write.
237async fn stream_response<U, I>(upstream: &mut U, inbound: &mut I) -> Result<u16>
238where
239    U: AsyncRead + AsyncWrite + Unpin,
240    I: AsyncWrite + Unpin,
241{
242    let mut buf = [0u8; 8192];
243    let mut status_code: u16 = 502;
244    let mut first_chunk = true;
245
246    loop {
247        let n = match upstream.read(&mut buf).await {
248            Ok(0) => break,
249            Ok(n) => n,
250            Err(e) => {
251                debug!("Upstream read error: {}", e);
252                break;
253            }
254        };
255
256        if first_chunk {
257            status_code = parse_response_status(&buf[..n]);
258            first_chunk = false;
259        }
260
261        inbound.write_all(&buf[..n]).await?;
262        inbound.flush().await?;
263    }
264
265    Ok(status_code)
266}
267
/// Parse the HTTP status code from the start of an upstream response.
///
/// Only the first line is examined: everything up to the first CR or LF,
/// capped at 64 bytes. The line is parsed as raw bytes rather than decoded
/// as UTF-8. A reason phrase containing `obs-text` (bytes 0x80-0xFF), or a
/// multi-byte character cut by the 64-byte cap, therefore no longer
/// invalidates an otherwise valid status line.
///
/// The version token must start with `HTTP/`. The status code must be
/// exactly three ASCII digits in 100..=599, the valid range per RFC 9110
/// §15. Anything else returns 502.
fn parse_response_status(data: &[u8]) -> u16 {
    const BAD_GATEWAY: u16 = 502;

    let line_end = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let first_line = &data[..line_end.min(64)];

    let mut parts = first_line
        .split(|b| b.is_ascii_whitespace())
        .filter(|part| !part.is_empty());

    let version = match parts.next() {
        Some(version) => version,
        None => return BAD_GATEWAY,
    };
    if !version.starts_with(b"HTTP/") {
        return BAD_GATEWAY;
    }

    // Explicit digit patterns reject inputs `u16::from_str` would accept (`+20`).
    match parts.next() {
        Some(&[hundreds @ b'1'..=b'5', tens @ b'0'..=b'9', units @ b'0'..=b'9']) => {
            u16::from(hundreds - b'0') * 100 + u16::from(tens - b'0') * 10 + u16::from(units - b'0')
        }
        _ => BAD_GATEWAY,
    }
}
292
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Assert that `raw` (the start of an upstream response) parses to `expected`.
    fn assert_status(raw: &[u8], expected: u16) {
        assert_eq!(parse_response_status(raw), expected, "input: {:?}", raw);
    }

    #[test]
    fn parse_response_status_extracts_code() {
        assert_status(b"HTTP/1.1 200 OK\r\n", 200);
        assert_status(b"HTTP/1.1 404 Not Found\r\n", 404);
        assert_status(b"HTTP/1.1 502 Bad Gateway\r\n", 502);
    }

    #[test]
    fn parse_response_status_handles_garbage() {
        assert_status(b"", 502);
        assert_status(b"garbage", 502);
        assert_status(b"NOT-HTTP 200 OK", 502);
    }
}