1use std::net::SocketAddr;
4use std::sync::Arc;
5
6use bytes::Bytes;
7use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full};
8use hyper::server::conn::http1;
9use hyper::service::service_fn;
10use hyper::{Method, Request, Response, StatusCode};
11use hyper_util::rt::TokioIo;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::{TcpListener, TcpStream, UnixStream};
14use tokio::sync::oneshot;
15
16use crate::error::SandboxError;
17use crate::proxy::filter::{DomainFilter, FilterDecision};
18
19pub struct HttpProxy {
21 listener: Option<TcpListener>,
22 port: u16,
23 filter: Arc<DomainFilter>,
24 mitm_socket_path: Option<String>,
25 shutdown_tx: Option<oneshot::Sender<()>>,
26}
27
28impl HttpProxy {
29 pub async fn new(
31 filter: DomainFilter,
32 mitm_socket_path: Option<String>,
33 ) -> Result<Self, SandboxError> {
34 let listener = TcpListener::bind("127.0.0.1:0").await?;
36 let port = listener.local_addr()?.port();
37
38 tracing::debug!("HTTP proxy listening on port {}", port);
39
40 Ok(Self {
41 listener: Some(listener),
42 port,
43 filter: Arc::new(filter),
44 mitm_socket_path,
45 shutdown_tx: None,
46 })
47 }
48
49 pub fn port(&self) -> u16 {
51 self.port
52 }
53
54 pub fn start(&mut self) -> Result<(), SandboxError> {
56 let listener = self
57 .listener
58 .take()
59 .ok_or_else(|| SandboxError::Proxy("Proxy already started".to_string()))?;
60
61 let filter = self.filter.clone();
62 let mitm_socket_path = self.mitm_socket_path.clone();
63 let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
64 self.shutdown_tx = Some(shutdown_tx);
65
66 tokio::spawn(async move {
67 loop {
68 tokio::select! {
69 accept_result = listener.accept() => {
70 match accept_result {
71 Ok((stream, addr)) => {
72 let filter = filter.clone();
73 let mitm_socket = mitm_socket_path.clone();
74 tokio::spawn(async move {
75 if let Err(e) = handle_connection(stream, addr, filter, mitm_socket).await {
76 tracing::debug!("Connection error from {}: {}", addr, e);
77 }
78 });
79 }
80 Err(e) => {
81 tracing::error!("Accept error: {}", e);
82 }
83 }
84 }
85 _ = &mut shutdown_rx => {
86 tracing::debug!("HTTP proxy shutting down");
87 break;
88 }
89 }
90 }
91 });
92
93 Ok(())
94 }
95
96 pub fn stop(&mut self) {
98 if let Some(tx) = self.shutdown_tx.take() {
99 let _ = tx.send(());
100 }
101 }
102}
103
104async fn handle_connection(
106 stream: TcpStream,
107 _addr: SocketAddr,
108 filter: Arc<DomainFilter>,
109 mitm_socket_path: Option<String>,
110) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
111 let io = TokioIo::new(stream);
112
113 let filter_clone = filter.clone();
114 let mitm_socket_clone = mitm_socket_path.clone();
115
116 http1::Builder::new()
117 .preserve_header_case(true)
118 .title_case_headers(true)
119 .serve_connection(
120 io,
121 service_fn(move |req| {
122 let filter = filter_clone.clone();
123 let mitm_socket = mitm_socket_clone.clone();
124 async move { handle_request(req, filter, mitm_socket).await }
125 }),
126 )
127 .with_upgrades()
128 .await?;
129
130 Ok(())
131}
132
133async fn handle_request(
135 req: Request<hyper::body::Incoming>,
136 filter: Arc<DomainFilter>,
137 mitm_socket_path: Option<String>,
138) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
139 if req.method() == Method::CONNECT {
140 handle_connect(req, filter, mitm_socket_path).await
141 } else {
142 handle_http(req, filter, mitm_socket_path).await
143 }
144}
145
146async fn handle_connect(
148 req: Request<hyper::body::Incoming>,
149 filter: Arc<DomainFilter>,
150 mitm_socket_path: Option<String>,
151) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
152 let host = req.uri().host().unwrap_or_default().to_string();
153 let port = req.uri().port_u16().unwrap_or(443);
154
155 tracing::debug!("CONNECT {}:{}", host, port);
156
157 let decision = filter.check(&host, port);
159
160 match decision {
161 FilterDecision::Deny => {
162 tracing::debug!("Denied CONNECT to {}:{}", host, port);
163 return Ok(Response::builder()
164 .status(StatusCode::FORBIDDEN)
165 .body(empty_body())
166 .unwrap());
167 }
168 FilterDecision::Mitm => {
169 if let Some(socket_path) = mitm_socket_path {
171 return handle_connect_mitm(req, &socket_path, &host, port).await;
172 }
173 }
174 FilterDecision::Allow => {}
175 }
176
177 tokio::task::spawn(async move {
179 match hyper::upgrade::on(req).await {
180 Ok(upgraded) => {
181 if let Err(e) = tunnel(upgraded, &host, port).await {
182 tracing::debug!("Tunnel error: {}", e);
183 }
184 }
185 Err(e) => {
186 tracing::debug!("Upgrade error: {}", e);
187 }
188 }
189 });
190
191 Ok(Response::new(empty_body()))
192}
193
194async fn handle_connect_mitm(
196 req: Request<hyper::body::Incoming>,
197 socket_path: &str,
198 host: &str,
199 port: u16,
200) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
201 let socket_path = socket_path.to_string();
202 let host = host.to_string();
203
204 tokio::task::spawn(async move {
205 match hyper::upgrade::on(req).await {
206 Ok(upgraded) => {
207 if let Err(e) = tunnel_via_mitm(upgraded, &socket_path, &host, port).await {
208 tracing::debug!("MITM tunnel error: {}", e);
209 }
210 }
211 Err(e) => {
212 tracing::debug!("Upgrade error: {}", e);
213 }
214 }
215 });
216
217 Ok(Response::new(empty_body()))
218}
219
220async fn tunnel(
222 upgraded: hyper::upgrade::Upgraded,
223 host: &str,
224 port: u16,
225) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
226 let target = TcpStream::connect(format!("{}:{}", host, port)).await?;
227
228 let mut upgraded = TokioIo::new(upgraded);
229 let (mut target_read, mut target_write) = target.into_split();
230 let (mut client_read, mut client_write) = tokio::io::split(&mut upgraded);
231
232 let client_to_server = tokio::io::copy(&mut client_read, &mut target_write);
233 let server_to_client = tokio::io::copy(&mut target_read, &mut client_write);
234
235 tokio::try_join!(client_to_server, server_to_client)?;
236
237 Ok(())
238}
239
240async fn tunnel_via_mitm(
242 upgraded: hyper::upgrade::Upgraded,
243 socket_path: &str,
244 host: &str,
245 port: u16,
246) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
247 let mut mitm_stream = UnixStream::connect(socket_path).await?;
248
249 let connect_req = format!("CONNECT {}:{} HTTP/1.1\r\nHost: {}:{}\r\n\r\n", host, port, host, port);
251 mitm_stream.write_all(connect_req.as_bytes()).await?;
252
253 let mut response_buf = [0u8; 1024];
255 let n = mitm_stream.read(&mut response_buf).await?;
256 let response = String::from_utf8_lossy(&response_buf[..n]);
257
258 if !response.contains("200") {
259 return Err(format!("MITM proxy returned: {}", response).into());
260 }
261
262 let mut upgraded = TokioIo::new(upgraded);
264 let (mut mitm_read, mut mitm_write) = mitm_stream.into_split();
265 let (mut client_read, mut client_write) = tokio::io::split(&mut upgraded);
266
267 let client_to_server = tokio::io::copy(&mut client_read, &mut mitm_write);
268 let server_to_client = tokio::io::copy(&mut mitm_read, &mut client_write);
269
270 tokio::try_join!(client_to_server, server_to_client)?;
271
272 Ok(())
273}
274
275async fn handle_http(
277 req: Request<hyper::body::Incoming>,
278 filter: Arc<DomainFilter>,
279 mitm_socket_path: Option<String>,
280) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
281 let host = req
282 .uri()
283 .host()
284 .or_else(|| {
285 req.headers()
286 .get("host")
287 .and_then(|h| h.to_str().ok())
288 .map(|h| h.split(':').next().unwrap_or(h))
289 })
290 .unwrap_or_default()
291 .to_string();
292
293 let port = req.uri().port_u16().unwrap_or(80);
294
295 tracing::debug!("HTTP {} {}:{}", req.method(), host, port);
296
297 let decision = filter.check(&host, port);
299
300 if matches!(decision, FilterDecision::Deny) {
301 tracing::debug!("Denied HTTP to {}:{}", host, port);
302 return Ok(Response::builder()
303 .status(StatusCode::FORBIDDEN)
304 .body(full_body("Access denied by sandbox policy"))
305 .unwrap());
306 }
307
308 if matches!(decision, FilterDecision::Mitm) {
310 if let Some(socket_path) = mitm_socket_path {
311 return forward_http_via_mitm(req, &socket_path).await;
312 }
313 }
314
315 forward_http(req).await
317}
318
319async fn forward_http(
321 req: Request<hyper::body::Incoming>,
322) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
323 let host = req
324 .uri()
325 .host()
326 .unwrap_or_default()
327 .to_string();
328 let port = req.uri().port_u16().unwrap_or(80);
329
330 let stream = match TcpStream::connect(format!("{}:{}", host, port)).await {
332 Ok(s) => s,
333 Err(e) => {
334 tracing::debug!("Failed to connect to {}:{}: {}", host, port, e);
335 return Ok(Response::builder()
336 .status(StatusCode::BAD_GATEWAY)
337 .body(full_body("Failed to connect to target"))
338 .unwrap());
339 }
340 };
341
342 let io = TokioIo::new(stream);
343
344 let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
345 Ok(c) => c,
346 Err(e) => {
347 tracing::debug!("Handshake error: {}", e);
348 return Ok(Response::builder()
349 .status(StatusCode::BAD_GATEWAY)
350 .body(full_body("Handshake failed"))
351 .unwrap());
352 }
353 };
354
355 tokio::spawn(async move {
356 if let Err(e) = conn.await {
357 tracing::debug!("Connection error: {}", e);
358 }
359 });
360
361 match sender.send_request(req).await {
362 Ok(resp) => Ok(resp.map(|b| b.boxed())),
363 Err(e) => {
364 tracing::debug!("Request error: {}", e);
365 Ok(Response::builder()
366 .status(StatusCode::BAD_GATEWAY)
367 .body(full_body("Request failed"))
368 .unwrap())
369 }
370 }
371}
372
373async fn forward_http_via_mitm(
375 _req: Request<hyper::body::Incoming>,
376 _socket_path: &str,
377) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
378 Ok(Response::builder()
380 .status(StatusCode::NOT_IMPLEMENTED)
381 .body(full_body("MITM HTTP forwarding not implemented"))
382 .unwrap())
383}
384
385fn empty_body() -> BoxBody<Bytes, hyper::Error> {
386 Empty::<Bytes>::new()
387 .map_err(|never| match never {})
388 .boxed()
389}
390
391fn full_body(s: &str) -> BoxBody<Bytes, hyper::Error> {
392 Full::new(Bytes::from(s.to_string()))
393 .map_err(|never| match never {})
394 .boxed()
395}