1use crate::audit;
24use crate::error::{ProxyError, Result};
25use std::net::SocketAddr;
26use std::time::Duration;
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
28use tokio::net::TcpStream;
29use tokio_rustls::TlsConnector;
30use tracing::debug;
31
/// Upper bound on a single direct TCP connect attempt made by this module.
///
/// The bound applies per attempt: when several resolved addresses are tried in
/// turn, each one gets its own full timeout.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
34
/// Transport used to talk to the upstream server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpstreamScheme {
    /// Plain-text HTTP sent directly over the TCP connection.
    Http,
    /// HTTP sent over a TLS session established with `UpstreamSpec::tls_connector`.
    Https,
}
43
/// How the TCP connection to the upstream is established.
pub enum UpstreamStrategy<'a> {
    /// Connect straight to the upstream.
    ///
    /// If `resolved_addrs` is non-empty, each address is tried in order and the
    /// first successful connection is used. If it is empty, the connection is
    /// made to `host:port` and the name is resolved at connect time.
    Direct { resolved_addrs: &'a [SocketAddr] },
    /// Tunnel through another proxy using `crate::external::connect_via_proxy`.
    ExternalProxy {
        /// Address of the external proxy, in whatever form `connect_via_proxy`
        /// expects (presumably `host:port` — confirm).
        proxy_addr: &'a str,
        /// Optional credentials passed through to `connect_via_proxy`
        /// (presumably a ready-made `Proxy-Authorization` value — confirm).
        proxy_auth_header: Option<&'a str>,
    },
}
58
/// Everything needed to open a connection to the upstream server.
pub struct UpstreamSpec<'a> {
    /// Whether to speak plain HTTP or HTTPS to the upstream.
    pub scheme: UpstreamScheme,
    /// Upstream host.
    ///
    /// Used as the connection target, as the TLS server name for HTTPS, and in
    /// connection error messages.
    pub host: &'a str,
    /// Upstream TCP port.
    pub port: u16,
    /// Direct connection or tunnel through an external proxy.
    pub strategy: UpstreamStrategy<'a>,
    /// TLS client connector; only used when `scheme` is `Https`.
    pub tls_connector: &'a TlsConnector,
}
70
/// Values handed to `audit::log_l7_request` after a request has been relayed successfully.
pub struct AuditCtx<'a> {
    /// Audit sink, passed through unchanged (presumably `None` disables logging — confirm).
    pub log: Option<&'a audit::SharedAuditLog>,
    /// Proxy mode recorded with the audit event.
    pub mode: audit::ProxyMode,
    /// Event context recorded with the audit event.
    pub event_ctx: audit::EventContext<'a>,
    /// Target recorded in the audit entry.
    pub target: &'a str,
    /// HTTP method of the relayed request.
    pub method: &'a str,
    /// Path of the relayed request.
    pub path: &'a str,
}
83
84pub async fn forward_request<S>(
90 inbound: &mut S,
91 request_bytes: &[u8],
92 body: &[u8],
93 upstream: UpstreamSpec<'_>,
94 audit: AuditCtx<'_>,
95) -> Result<u16>
96where
97 S: AsyncRead + AsyncWrite + Unpin,
98{
99 let status = match upstream.scheme {
100 UpstreamScheme::Https => {
101 let mut tls_stream = open_https_upstream(&upstream).await?;
102 write_request(&mut tls_stream, request_bytes, body).await?;
103 stream_response(&mut tls_stream, inbound).await?
104 }
105 UpstreamScheme::Http => {
106 let mut tcp_stream = open_http_upstream(&upstream).await?;
107 write_request(&mut tcp_stream, request_bytes, body).await?;
108 stream_response(&mut tcp_stream, inbound).await?
109 }
110 };
111
112 audit::log_l7_request(
113 audit.log,
114 audit.mode,
115 &audit.event_ctx,
116 audit.target,
117 audit.method,
118 audit.path,
119 status,
120 );
121 Ok(status)
122}
123
124async fn open_https_upstream(
126 upstream: &UpstreamSpec<'_>,
127) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
128 let tcp = open_tcp_upstream(upstream).await?;
129 let server_name =
130 rustls::pki_types::ServerName::try_from(upstream.host.to_string()).map_err(|_| {
131 ProxyError::UpstreamConnect {
132 host: upstream.host.to_string(),
133 reason: "invalid server name for TLS".to_string(),
134 }
135 })?;
136 upstream
137 .tls_connector
138 .connect(server_name, tcp)
139 .await
140 .map_err(|e| ProxyError::UpstreamConnect {
141 host: upstream.host.to_string(),
142 reason: format!("TLS handshake failed: {}", e),
143 })
144}
145
146async fn open_http_upstream(upstream: &UpstreamSpec<'_>) -> Result<TcpStream> {
149 open_tcp_upstream(upstream).await
150}
151
152async fn open_tcp_upstream(upstream: &UpstreamSpec<'_>) -> Result<TcpStream> {
154 match upstream.strategy {
155 UpstreamStrategy::Direct { resolved_addrs } => {
156 if resolved_addrs.is_empty() {
157 let addr = format!("{}:{}", upstream.host, upstream.port);
158 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr))
159 .await
160 {
161 Ok(Ok(s)) => Ok(s),
162 Ok(Err(e)) => Err(ProxyError::UpstreamConnect {
163 host: upstream.host.to_string(),
164 reason: e.to_string(),
165 }),
166 Err(_) => Err(ProxyError::UpstreamConnect {
167 host: upstream.host.to_string(),
168 reason: "connection timed out".to_string(),
169 }),
170 }
171 } else {
172 connect_to_resolved(resolved_addrs, upstream.host).await
173 }
174 }
175 UpstreamStrategy::ExternalProxy {
176 proxy_addr,
177 proxy_auth_header,
178 } => crate::external::connect_via_proxy(
179 proxy_addr,
180 upstream.host,
181 upstream.port,
182 proxy_auth_header,
183 )
184 .await
185 .map_err(|e| match e {
186 ProxyError::ExternalProxy(reason) => ProxyError::UpstreamConnect {
187 host: upstream.host.to_string(),
188 reason,
189 },
190 other => other,
191 }),
192 }
193}
194
195async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
200 let mut last_err = None;
201 for addr in addrs {
202 match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
203 Ok(Ok(stream)) => return Ok(stream),
204 Ok(Err(e)) => {
205 debug!("Connect to {} failed: {}", addr, e);
206 last_err = Some(e.to_string());
207 }
208 Err(_) => {
209 debug!("Connect to {} timed out", addr);
210 last_err = Some("connection timed out".to_string());
211 }
212 }
213 }
214 Err(ProxyError::UpstreamConnect {
215 host: host.to_string(),
216 reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
217 })
218}
219
220async fn write_request<S>(stream: &mut S, request: &[u8], body: &[u8]) -> Result<()>
221where
222 S: AsyncWrite + Unpin,
223{
224 stream.write_all(request).await?;
225 if !body.is_empty() {
226 stream.write_all(body).await?;
227 }
228 stream.flush().await?;
229 Ok(())
230}
231
232async fn stream_response<U, I>(upstream: &mut U, inbound: &mut I) -> Result<u16>
238where
239 U: AsyncRead + AsyncWrite + Unpin,
240 I: AsyncWrite + Unpin,
241{
242 let mut buf = [0u8; 8192];
243 let mut status_code: u16 = 502;
244 let mut first_chunk = true;
245
246 loop {
247 let n = match upstream.read(&mut buf).await {
248 Ok(0) => break,
249 Ok(n) => n,
250 Err(e) => {
251 debug!("Upstream read error: {}", e);
252 break;
253 }
254 };
255
256 if first_chunk {
257 status_code = parse_response_status(&buf[..n]);
258 first_chunk = false;
259 }
260
261 inbound.write_all(&buf[..n]).await?;
262 inbound.flush().await?;
263 }
264
265 Ok(status_code)
266}
267
/// Extracts the status code from the start of an HTTP/1.x response.
///
/// Only the first line is examined: bytes up to the first CR or LF, capped at
/// 64 bytes. The line must start with a token beginning `HTTP/`. The next token
/// must be exactly three ASCII digits, as in RFC 9112 `status-code = 3DIGIT`.
/// Anything else yields 502 (Bad Gateway). That includes empty input, a
/// non-HTTP first token, and a missing or malformed code such as `+12` or `2000`.
///
/// Tokenizing is done on raw bytes. A multi-byte character cut in half by the
/// 64-byte cap (typically in the reason phrase) therefore cannot invalidate an
/// otherwise valid status code.
fn parse_response_status(data: &[u8]) -> u16 {
    // Value reported when no well-formed status line is found.
    const FALLBACK: u16 = 502;
    // Upper bound on how much of the first line is scanned.
    const MAX_LINE_LEN: usize = 64;

    let line_end = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let first_line = &data[..line_end.min(MAX_LINE_LEN)];

    let mut tokens = first_line
        .split(|b| b.is_ascii_whitespace())
        .filter(|token| !token.is_empty());

    let is_http = matches!(tokens.next(), Some(version) if version.starts_with(b"HTTP/"));
    if !is_http {
        return FALLBACK;
    }

    match tokens.next() {
        Some(&[h, t, u]) if h.is_ascii_digit() && t.is_ascii_digit() && u.is_ascii_digit() => {
            u16::from(h - b'0') * 100 + u16::from(t - b'0') * 10 + u16::from(u - b'0')
        }
        _ => FALLBACK,
    }
}
292
#[cfg(test)]
// Allow unwrap() in test code.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn parse_response_status_extracts_code() {
        assert_eq!(parse_response_status(b"HTTP/1.1 200 OK\r\n"), 200);
        assert_eq!(parse_response_status(b"HTTP/1.1 404 Not Found\r\n"), 404);
        assert_eq!(parse_response_status(b"HTTP/1.1 502 Bad Gateway\r\n"), 502);
    }

    #[test]
    fn parse_response_status_accepts_status_line_variants() {
        // No reason phrase, bare LF terminator, HTTP/1.0, headers following,
        // and a code at the very end of the input.
        assert_eq!(parse_response_status(b"HTTP/1.1 204\r\n"), 204);
        assert_eq!(parse_response_status(b"HTTP/1.0 301 Moved\n"), 301);
        assert_eq!(
            parse_response_status(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"),
            200
        );
        assert_eq!(parse_response_status(b"HTTP/1.1 418"), 418);
    }

    #[test]
    fn parse_response_status_handles_garbage() {
        assert_eq!(parse_response_status(b""), 502);
        assert_eq!(parse_response_status(b"garbage"), 502);
        assert_eq!(parse_response_status(b"NOT-HTTP 200 OK"), 502);
        // Too short, too long, lowercase version, and a missing code.
        assert_eq!(parse_response_status(b"HTTP/1.1 20"), 502);
        assert_eq!(parse_response_status(b"HTTP/1.1 2000 OK"), 502);
        assert_eq!(parse_response_status(b"http/1.1 200 OK"), 502);
        assert_eq!(parse_response_status(b"HTTP/1.1\r\n"), 502);
    }
}