1use crate::error::ProxyError;
2use crate::router::UpstreamConfig;
3use crate::router::{ProxyRouter, RouteDecision};
4use crate::upstream::connect_upstream;
5use http::Uri;
6use log::{debug, warn};
7use std::net::SocketAddr;
8use std::sync::{Arc, RwLock};
9#[cfg(feature = "capture")]
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::{TcpListener, TcpStream};
13
14#[cfg(feature = "capture")]
15use {
16 crate::capture::{CapturedSession, handle_http_sessions},
17 crate::mitm::CaConfig,
18 tokio::sync::broadcast,
19};
20
21#[derive(Clone)]
30pub struct LocalProxy {
31 inner: Arc<Inner>,
32}
33
34struct Inner {
35 router: RwLock<Arc<dyn ProxyRouter>>,
36 local_addr: SocketAddr,
37 listener: TcpListener,
38
39 #[cfg(feature = "capture")]
40 ca: RwLock<Option<Arc<CaConfig>>>,
41 #[cfg(feature = "capture")]
42 session_tx: broadcast::Sender<Arc<CapturedSession>>,
43}
44
45impl LocalProxy {
46 pub async fn bind(
48 addr: &str,
49 initial_router: Arc<dyn ProxyRouter>,
50 ) -> Result<Self, ProxyError> {
51 let listener = TcpListener::bind(addr).await?;
52 let local_addr = listener.local_addr()?;
53
54 #[cfg(feature = "capture")]
55 let (session_tx, _) = broadcast::channel(512);
56
57 Ok(Self {
58 inner: Arc::new(Inner {
59 router: RwLock::new(initial_router),
60 local_addr,
61 listener,
62 #[cfg(feature = "capture")]
63 ca: RwLock::new(None),
64 #[cfg(feature = "capture")]
65 session_tx,
66 }),
67 })
68 }
69
70 pub fn local_addr(&self) -> SocketAddr {
71 self.inner.local_addr
72 }
73
74 pub fn set_router(&self, router: Arc<dyn ProxyRouter>) {
76 *self.inner.router.write().unwrap() = router;
77 }
78
79 #[cfg(feature = "capture")]
87 pub fn set_capture_ca(&self, ca: CaConfig) {
88 *self.inner.ca.write().unwrap() = Some(Arc::new(ca));
89 }
90
91 #[cfg(feature = "capture")]
104 pub fn session_receiver(&self) -> broadcast::Receiver<Arc<CapturedSession>> {
105 self.inner.session_tx.subscribe()
106 }
107
108 pub async fn run(&self) {
110 loop {
111 match self.inner.listener.accept().await {
112 Ok((stream, peer)) => {
113 debug!("proxy: accepted {peer}");
114 let proxy = self.clone();
115 tokio::spawn(async move {
116 if let Err(e) = proxy.handle_connection(stream).await {
117 debug!("proxy: {peer} — {e}");
118 }
119 });
120 }
121 Err(e) => {
122 warn!("proxy: accept error: {e}");
123 break;
124 }
125 }
126 }
127 }
128
129 async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), ProxyError> {
132 match read_proxy_request(&mut stream).await? {
133 ProxyRequest::Connect { host, port } => {
134 let upstream_cfg = self.route_upstream(&host, port)?;
135
136 #[cfg(feature = "capture")]
137 {
138 return self
139 .handle_connect_with_capture(stream, host, port, upstream_cfg)
140 .await;
141 }
142
143 #[cfg(not(feature = "capture"))]
144 {
145 let mut up = connect_upstream(&upstream_cfg, &host, port).await?;
146 stream
147 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
148 .await?;
149 tokio::io::copy_bidirectional(&mut stream, &mut up).await?;
150 Ok(())
151 }
152 }
153 ProxyRequest::ForwardHttp {
154 host,
155 port,
156 initial_bytes,
157 } => {
158 let upstream_cfg = self.route_upstream(&host, port)?;
159
160 #[cfg(feature = "capture")]
161 {
162 return self
163 .handle_forward_http_with_capture(
164 stream,
165 initial_bytes,
166 host,
167 port,
168 upstream_cfg,
169 )
170 .await;
171 }
172
173 #[cfg(not(feature = "capture"))]
174 {
175 let mut up = connect_upstream(&upstream_cfg, &host, port).await?;
176 up.write_all(&initial_bytes).await?;
177 tokio::io::copy_bidirectional(&mut stream, &mut up).await?;
178 Ok(())
179 }
180 }
181 }
182 }
183
184 fn route_upstream(&self, host: &str, port: u16) -> Result<UpstreamConfig, ProxyError> {
185 let router = self.inner.router.read().unwrap();
186 match router.route(host, port)? {
187 RouteDecision::Upstream(cfg) => Ok(cfg),
188 RouteDecision::Block => Err(ProxyError::UpstreamConnect(format!(
189 "{host}:{port} blocked by policy"
190 ))),
191 }
192 }
193
194 #[cfg(feature = "capture")]
195 async fn handle_connect_with_capture(
196 &self,
197 mut stream: TcpStream,
198 host: String,
199 port: u16,
200 upstream_cfg: UpstreamConfig,
201 ) -> Result<(), ProxyError> {
202 let tx = self.inner.session_tx.clone();
203 let has_consumer = tx.receiver_count() > 0;
204 let ca = self.inner.ca.read().unwrap().clone(); if port == 443 && has_consumer && ca.is_some() {
208 let up_stream = connect_upstream(&upstream_cfg, &host, port).await?;
209 stream
210 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
211 .await?;
212
213 let (client_tls, server_tls) =
214 crate::mitm::intercept(stream, &host, up_stream, ca.as_deref().unwrap()).await?;
215
216 handle_http_sessions(host, port, true, client_tls, server_tls, tx)
217 .await
218 .map_err(ProxyError::Io)?;
219
220 return Ok(());
221 }
222
223 if port != 443 && has_consumer {
225 let up_stream = connect_upstream(&upstream_cfg, &host, port).await?;
226 stream
227 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
228 .await?;
229
230 handle_http_sessions(host, port, false, stream, up_stream, tx)
231 .await
232 .map_err(ProxyError::Io)?;
233
234 return Ok(());
235 }
236
237 let mut up = connect_upstream(&upstream_cfg, &host, port).await?;
239 stream
240 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
241 .await?;
242 tokio::io::copy_bidirectional(&mut stream, &mut up).await?;
243 Ok(())
244 }
245
246 #[cfg(feature = "capture")]
247 async fn handle_forward_http_with_capture(
248 &self,
249 stream: TcpStream,
250 initial_bytes: Vec<u8>,
251 host: String,
252 port: u16,
253 upstream_cfg: UpstreamConfig,
254 ) -> Result<(), ProxyError> {
255 let tx = self.inner.session_tx.clone();
256 let has_consumer = tx.receiver_count() > 0;
257
258 let up_stream = connect_upstream(&upstream_cfg, &host, port).await?;
259 if has_consumer {
260 let client = PrefixedIo::new(stream, initial_bytes);
261 handle_http_sessions(host, port, false, client, up_stream, tx)
262 .await
263 .map_err(ProxyError::Io)?;
264 return Ok(());
265 }
266
267 let mut stream = stream;
268 let mut up = up_stream;
269 up.write_all(&initial_bytes).await?;
270 tokio::io::copy_bidirectional(&mut stream, &mut up).await?;
271 Ok(())
272 }
273}
274
/// A parsed proxy request head, as produced by `read_proxy_request`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ProxyRequest {
    /// `CONNECT host:port`: open a raw tunnel to the target.
    Connect { host: String, port: u16 },
    /// Any other method: a plain-HTTP request to forward to `host:port`.
    ForwardHttp {
        host: String,
        port: u16,
        /// Bytes to send upstream first. This is the request line rewritten for the origin
        /// server, followed by the client's header lines and the terminating blank line.
        initial_bytes: Vec<u8>,
    },
}
286
287async fn read_proxy_request(stream: &mut TcpStream) -> Result<ProxyRequest, ProxyError> {
290 let mut buf = Vec::with_capacity(512);
291 let mut tmp = [0u8; 1];
292
293 loop {
294 stream.read_exact(&mut tmp).await?;
295 buf.push(tmp[0]);
296 if buf.ends_with(b"\r\n\r\n") {
297 break;
298 }
299 if buf.len() > 8192 {
300 return Err(ProxyError::BadRequest("CONNECT headers too large".into()));
301 }
302 }
303
304 let text =
305 std::str::from_utf8(&buf).map_err(|_| ProxyError::BadRequest("Non-UTF8 CONNECT".into()))?;
306
307 let first_line = text
308 .lines()
309 .next()
310 .ok_or_else(|| ProxyError::BadRequest("Empty request".into()))?;
311
312 let mut parts = first_line.split_whitespace();
313 let method = parts
314 .next()
315 .ok_or_else(|| ProxyError::BadRequest("No method".into()))?;
316 let target = parts
317 .next()
318 .ok_or_else(|| ProxyError::BadRequest("No request target".into()))?;
319 let version = parts
320 .next()
321 .ok_or_else(|| ProxyError::BadRequest("No HTTP version".into()))?;
322
323 if method.eq_ignore_ascii_case("CONNECT") {
324 let (host, port) = parse_authority_host_port(target, None)?;
325 return Ok(ProxyRequest::Connect { host, port });
326 }
327
328 let (host, port, upstream_target) = parse_forward_target(target, text)?;
329 let first_line_end = buf
330 .windows(2)
331 .position(|window| window == b"\r\n")
332 .ok_or_else(|| ProxyError::BadRequest("Malformed request line".into()))?;
333 let mut initial_bytes = format!("{method} {upstream_target} {version}\r\n").into_bytes();
334 initial_bytes.extend_from_slice(&buf[first_line_end + 2..]);
335 Ok(ProxyRequest::ForwardHttp {
336 host,
337 port,
338 initial_bytes,
339 })
340}
341
342fn parse_forward_target(
343 target: &str,
344 request_text: &str,
345) -> Result<(String, u16, String), ProxyError> {
346 if target.starts_with("http://") || target.starts_with("https://") {
347 let uri: Uri = target
348 .parse()
349 .map_err(|e| ProxyError::BadRequest(format!("Bad absolute-form URI: {e}")))?;
350 let scheme = uri
351 .scheme_str()
352 .ok_or_else(|| ProxyError::BadRequest("Absolute-form URI missing scheme".into()))?;
353 if scheme.eq_ignore_ascii_case("https") {
354 return Err(ProxyError::BadRequest(
355 "HTTPS absolute-form requests must use CONNECT".into(),
356 ));
357 }
358 let host = uri
359 .host()
360 .ok_or_else(|| ProxyError::BadRequest("Absolute-form URI missing host".into()))?
361 .to_string();
362 let port = uri.port_u16().unwrap_or(80);
363 let path = uri
364 .path_and_query()
365 .map(|pq| pq.as_str().to_string())
366 .unwrap_or_else(|| "/".to_string());
367 return Ok((host, port, path));
368 }
369
370 let authority = request_text
371 .lines()
372 .skip(1)
373 .find_map(|line| {
374 let (name, value) = line.split_once(':')?;
375 name.trim()
376 .eq_ignore_ascii_case("host")
377 .then(|| value.trim().to_string())
378 })
379 .ok_or_else(|| ProxyError::BadRequest("HTTP proxy request missing Host header".into()))?;
380 let (host, port) = parse_authority_host_port(&authority, Some(80))?;
381 Ok((host, port, target.to_string()))
382}
383
384fn parse_authority_host_port(
385 authority: &str,
386 default_port: Option<u16>,
387) -> Result<(String, u16), ProxyError> {
388 let authority = authority.trim();
389 if authority.is_empty() {
390 return Err(ProxyError::BadRequest("Empty authority".into()));
391 }
392 if let Some(host) = authority.strip_prefix('[')
393 && let Some((host, rest)) = host.split_once(']')
394 {
395 let port = if let Some(rest) = rest.strip_prefix(':') {
396 rest.parse()
397 .map_err(|_| ProxyError::BadRequest(format!("Bad port in '{authority}'")))?
398 } else {
399 default_port
400 .ok_or_else(|| ProxyError::BadRequest(format!("No port in '{authority}'")))?
401 };
402 return Ok((host.to_string(), port));
403 }
404
405 if let Some((host, port)) = authority.rsplit_once(':')
406 && !host.is_empty()
407 && let Ok(port) = port.parse()
408 {
409 return Ok((host.to_string(), port));
410 }
411
412 Ok((
413 authority.to_string(),
414 default_port.ok_or_else(|| ProxyError::BadRequest(format!("No port in '{authority}'")))?,
415 ))
416}
417
/// Stream adapter that yields `prefix` to readers before any data from `inner`.
///
/// It gives the capture layer a client stream whose request head was already consumed while
/// parsing the proxy request. Writes bypass the prefix and go straight to `inner`.
#[cfg(feature = "capture")]
struct PrefixedIo<T> {
    /// Bytes to serve before reading from `inner`.
    prefix: Vec<u8>,
    /// How many `prefix` bytes have already been handed out.
    prefix_pos: usize,
    inner: T,
}

#[cfg(feature = "capture")]
impl<T> PrefixedIo<T> {
    /// Wraps `inner` so that the first reads return `prefix`.
    fn new(inner: T, prefix: Vec<u8>) -> Self {
        Self {
            prefix,
            prefix_pos: 0,
            inner,
        }
    }
}

#[cfg(feature = "capture")]
impl<T> AsyncRead for PrefixedIo<T>
where
    T: AsyncRead + Unpin,
{
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // Every field is `Unpin` (given `T: Unpin`), so the pin can be dropped safely.
        let this = self.get_mut();
        let pending = this.prefix.get(this.prefix_pos..).unwrap_or_default();
        if pending.is_empty() {
            return std::pin::Pin::new(&mut this.inner).poll_read(cx, buf);
        }
        // Hand out as much of the prefix as fits; the caller polls again for the rest.
        let copied = pending.len().min(buf.remaining());
        buf.put_slice(&pending[..copied]);
        this.prefix_pos += copied;
        std::task::Poll::Ready(Ok(()))
    }
}

#[cfg(feature = "capture")]
impl<T> AsyncWrite for PrefixedIo<T>
where
    T: AsyncWrite + Unpin,
{
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
    }

    // Forward vectored writes so wrapping does not hide the inner stream's support for them.
    // Without this, the trait default would write only the first non-empty buffer per call.
    fn poll_write_vectored(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<std::io::Result<usize>> {
        std::pin::Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
    }
}