1use crate::ca::CertificateManager;
4use crate::error::{Error, Result};
5use crate::interceptor::{InterceptorHandler, MitmRequest, MitmResponse};
6use crate::proxy::MitmConfig;
7use bytes::Bytes;
8use http::Method;
9use slinger::{Client, ClientBuilder, Request};
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
14use tokio::net::{TcpListener, TcpStream};
15use tokio::sync::RwLock;
16use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
17use tokio_rustls::rustls::ServerConfig;
18use tokio_rustls::TlsAcceptor;
19
/// MITM proxy server that accepts plain HTTP, HTTP `CONNECT`, and SOCKS5 clients
/// on a single listener and routes their traffic through the interceptor chain.
pub struct ProxyServer {
    /// Server configuration (upstream proxy, interceptor timeout, ...).
    config: MitmConfig,
    /// Supplies the per-domain server certificates presented during TLS interception.
    cert_manager: Arc<CertificateManager>,
    /// Shared interceptor chain, read-locked while requests/responses are processed.
    interceptor_handler: Arc<RwLock<InterceptorHandler>>,
    /// HTTP client used to forward intercepted HTTP(S) requests upstream.
    client: Client,
}
27
/// State needed to route an established tunnel (SOCKS5 or HTTP `CONNECT`) either
/// into TLS interception or into a raw TCP relay.
struct TunnelRouteContext {
    /// Address of the connecting client; attached to raw-TCP interception events.
    peer_addr: SocketAddr,
    /// Certificate source used if the tunnel turns out to carry TLS.
    cert_manager: Arc<CertificateManager>,
    /// Shared interceptor chain.
    interceptor: Arc<RwLock<InterceptorHandler>>,
    /// HTTP client used to forward decrypted HTTPS requests.
    client: Client,
    /// Optional proxy used when dialing the tunnel target for raw TCP relays.
    upstream_proxy: Option<slinger::Proxy>,
    /// Short label (e.g. "SOCKS5" / "CONNECT") used in log messages.
    protocol_tag: &'static str,
}
36
/// Per-connection copies of the server's shared components, captured when a
/// connection is accepted and before its protocol (SOCKS5 vs HTTP) is known.
struct ConnectionContext {
    /// Address of the connecting client.
    peer_addr: SocketAddr,
    /// Certificate source for TLS interception.
    cert_manager: Arc<CertificateManager>,
    /// Shared interceptor chain.
    interceptor: Arc<RwLock<InterceptorHandler>>,
    /// HTTP client used to forward intercepted requests.
    client: Client,
    /// Optional proxy used when dialing tunnel targets.
    upstream_proxy: Option<slinger::Proxy>,
}
44
45impl ConnectionContext {
46 fn into_tunnel(self, protocol_tag: &'static str) -> TunnelRouteContext {
47 TunnelRouteContext {
48 peer_addr: self.peer_addr,
49 cert_manager: self.cert_manager,
50 interceptor: self.interceptor,
51 client: self.client,
52 upstream_proxy: self.upstream_proxy,
53 protocol_tag,
54 }
55 }
56}
57
58#[derive(Default)]
62pub struct ProxyServerBuilder {
63 config: Option<MitmConfig>,
64 cert_manager: Option<Arc<CertificateManager>>,
65 interceptor_handler: Option<Arc<RwLock<InterceptorHandler>>>,
66 client: Option<Client>,
67 client_config: Option<Box<dyn Fn(ClientBuilder) -> ClientBuilder + Send + Sync>>,
69}
70
71impl ProxyServerBuilder {
72 pub fn from_server(server: &ProxyServer) -> Self {
74 Self {
75 config: Some(server.config.clone()),
76 cert_manager: Some(server.cert_manager.clone()),
77 interceptor_handler: Some(server.interceptor_handler.clone()),
78 client: Some(server.client.clone()),
79 client_config: None,
80 }
81 }
82
83 pub fn config(mut self, config: MitmConfig) -> Self {
85 self.config = Some(config);
86 self
87 }
88
89 pub fn cert_manager(mut self, cert_manager: Arc<CertificateManager>) -> Self {
91 self.cert_manager = Some(cert_manager);
92 self
93 }
94
95 pub fn interceptor_handler(mut self, handler: Arc<RwLock<InterceptorHandler>>) -> Self {
97 self.interceptor_handler = Some(handler);
98 self
99 }
100
101 pub fn client(mut self, client: Client) -> Self {
103 self.client = Some(client);
104 self
105 }
106
107 pub fn configure_client<F>(mut self, f: F) -> Self
110 where
111 F: Fn(ClientBuilder) -> ClientBuilder + Send + Sync + 'static,
112 {
113 self.client_config = Some(Box::new(f));
114 self
115 }
116
117 pub fn build(self) -> Result<ProxyServer> {
125 let config = self.config.unwrap_or_default();
127
128 let cert_manager = match self.cert_manager {
131 Some(c) => c,
132 None => {
133 return Err(Error::proxy_error(
134 "CertificateManager not provided; use ProxyServer::builder().build_async().await to create one automatically".to_string(),
135 ))
136 }
137 };
138
139 let interceptor_handler = self.interceptor_handler.unwrap_or_else(|| {
141 Arc::new(RwLock::new(
142 InterceptorHandler::new().with_timeout(config.interceptor_timeout_secs),
143 ))
144 });
145
146 let client = if let Some(client) = self.client {
148 client
149 } else if let Some(cfg_fn) = self.client_config {
150 let builder = Client::builder();
151 let configured = cfg_fn(builder);
152 configured
153 .build()
154 .map_err(|e| Error::proxy_error(format!("Failed to build client: {}", e)))?
155 } else {
156 if let Some(proxy) = &config.upstream_proxy {
158 Client::builder()
159 .timeout(Some(Duration::from_secs(60)))
160 .keepalive(true)
161 .proxy(proxy.clone())
162 .build()
163 .map_err(|e| {
164 Error::proxy_error(format!(
165 "Failed to build client with proxy {}: {}",
166 proxy.uri(),
167 e
168 ))
169 })?
170 } else {
171 Client::builder()
172 .keepalive(true)
173 .build()
174 .map_err(|e| Error::proxy_error(format!("Failed to build default client: {}", e)))?
175 }
176 };
177
178 Ok(ProxyServer {
179 config,
180 cert_manager,
181 interceptor_handler,
182 client,
183 })
184 }
185}
186
187impl ProxyServer {
188 pub fn new(
190 config: MitmConfig,
191 cert_manager: Arc<CertificateManager>,
192 interceptor_handler: Arc<RwLock<InterceptorHandler>>,
193 ) -> Result<Self> {
194 let client = if let Some(proxy) = &config.upstream_proxy {
195 Client::builder()
197 .timeout(Some(Duration::from_secs(60)))
198 .keepalive(true)
199 .proxy(proxy.clone())
200 .build()
201 .map_err(|e| {
202 Error::proxy_error(format!(
203 "Failed to build client with proxy {}: {}",
204 proxy.uri(),
205 e
206 ))
207 })?
208 } else {
209 Client::builder()
211 .keepalive(true)
212 .build()
213 .map_err(|e| Error::proxy_error(format!("Failed to build default client: {}", e)))?
214 };
215 Ok(Self {
216 config,
217 cert_manager,
218 interceptor_handler,
219 client,
220 })
221 }
222
223 pub async fn run(&self, addr: &str) -> Result<()> {
225 let listener = TcpListener::bind(addr)
226 .await
227 .map_err(|e| Error::proxy_error(format!("Failed to bind to {}: {}", addr, e)))?;
228 loop {
229 match listener.accept().await {
230 Ok((stream, peer_addr)) => {
231 let cert_manager = self.cert_manager.clone();
232 let interceptor = self.interceptor_handler.clone();
233 let client = self.client.clone();
234 let upstream_proxy = self.config.upstream_proxy.clone();
235
236 tokio::spawn(async move {
237 if let Err(e) = Self::handle_connection(
238 stream,
239 peer_addr,
240 cert_manager,
241 interceptor,
242 client,
243 upstream_proxy,
244 )
245 .await
246 {
247 tracing::error!("[MITM] Error handling connection: {}", e);
248 }
249 });
250 }
251 Err(e) => {
252 tracing::error!("[MITM] Failed to accept connection: {}", e);
253 }
254 }
255 }
256 }
257
258 async fn handle_connection(
260 mut stream: TcpStream,
261 peer_addr: SocketAddr,
262 cert_manager: Arc<CertificateManager>,
263 interceptor: Arc<RwLock<InterceptorHandler>>,
264 client: Client,
265 upstream_proxy: Option<slinger::Proxy>,
266 ) -> Result<()> {
267 let mut first_byte = [0u8; 1];
269 stream.peek(&mut first_byte).await?;
270
271 let ctx = ConnectionContext {
272 peer_addr,
273 cert_manager,
274 interceptor,
275 client,
276 upstream_proxy,
277 };
278
279 if first_byte[0] == 0x05 {
281 stream.read_exact(&mut first_byte).await?;
283 return Self::handle_socks5_connection(stream, ctx).await;
284 }
285
286 Self::handle_http_connection(stream, ctx).await
287 }
288
289 async fn handle_socks5_connection(mut stream: TcpStream, ctx: ConnectionContext) -> Result<()> {
290 use crate::socks5::Socks5Server;
291
292 let target_addr = Socks5Server::handle_handshake_with_version(&mut stream).await?;
294 let target_host_port = target_addr.to_host_port();
295 Self::handle_tunnel_route(stream, &target_host_port, ctx.into_tunnel("SOCKS5")).await
296 }
297
298 async fn handle_http_connection(stream: TcpStream, ctx: ConnectionContext) -> Result<()> {
299 let mut reader = BufReader::new(stream);
300 let request = Request::from_http_reader(&mut reader).await?;
301 if request.method() == Method::CONNECT {
302 let uri = request.uri().to_string();
303 let mut stream = reader.into_inner();
304
305 stream
308 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
309 .await?;
310 stream.flush().await.map_err(Error::Io)?;
311
312 return Self::handle_tunnel_route(stream, &uri, ctx.into_tunnel("CONNECT")).await;
313 }
314
315 Self::handle_http_request(request, reader, ctx.interceptor, ctx.client).await
316 }
317
318 async fn handle_tunnel_route(
320 stream: TcpStream,
321 target_addr: &str,
322 ctx: TunnelRouteContext,
323 ) -> Result<()> {
324 if Self::peek_tls_client_hello(&stream, ctx.protocol_tag).await {
325 let (domain, port) = Self::parse_host_port(target_addr)?;
326 return Self::accept_tls_and_handle(
327 stream,
328 &domain,
329 port,
330 false,
331 ctx.cert_manager,
332 ctx.interceptor,
333 ctx.client,
334 )
335 .await;
336 }
337
338 let has_interceptors = ctx.interceptor.read().await.has_interceptors();
339 let socket = slinger::Socket::new(slinger::StreamWrapper::Tcp(stream), None, None);
340
341 if has_interceptors {
342 Self::handle_tcp_tunnel_with_interception(
343 socket,
344 target_addr,
345 ctx.peer_addr,
346 ctx.interceptor,
347 ctx.upstream_proxy,
348 )
349 .await
350 } else {
351 Self::tcp_tunnel(socket, target_addr, false, ctx.upstream_proxy).await
352 }
353 }
354
355 async fn peek_tls_client_hello(stream: &TcpStream, protocol_tag: &str) -> bool {
356 let mut peek_buf = [0u8; 5];
358 let peeked =
359 match tokio::time::timeout(Duration::from_millis(100), stream.peek(&mut peek_buf)).await {
360 Ok(Ok(n)) => n,
361 Ok(Err(e)) => {
362 tracing::debug!(
363 "[MITM {}] Peek failed, defaulting to TCP tunnel: {}",
364 protocol_tag,
365 e
366 );
367 0
368 }
369 Err(_) => {
370 tracing::debug!(
371 "[MITM {}] Peek timed out, defaulting to TCP tunnel",
372 protocol_tag
373 );
374 0
375 }
376 };
377
378 Self::is_tls_client_hello(&peek_buf[..peeked])
379 }
380
381 fn is_tls_client_hello(bytes: &[u8]) -> bool {
391 bytes.len() >= 2 && bytes[0] == 0x16 && bytes[1] == 0x03
392 }
393
394 async fn accept_tls_and_handle(
397 mut client_stream: TcpStream,
398 domain: &str,
399 port: u16,
400 send_response: bool,
401 cert_manager: Arc<CertificateManager>,
402 interceptor: Arc<RwLock<InterceptorHandler>>,
403 slinger_client: Client,
404 ) -> Result<()> {
405 if send_response {
406 client_stream
407 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
408 .await?;
409 client_stream
410 .flush() .await
412 .map_err(Error::Io)?;
413 }
414
415 let (cert_chain, key) = cert_manager.get_server_cert(domain).await?;
417 let tls_config = Self::create_tls_server_config(cert_chain, key)?;
419 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
420 let tls_stream = acceptor
422 .accept(client_stream)
423 .await
424 .map_err(|e| Error::tls_error(format!("TLS handshake failed: {}", e)))?;
425 let domain_with_port = format!("{}:{}", domain, port);
426 Self::handle_https_stream(tls_stream, domain_with_port, interceptor, slinger_client).await
427 }
428
429 async fn tcp_tunnel(
431 mut client_stream: slinger::Socket,
432 uri: &str,
433 send_response: bool,
434 upstream_proxy: Option<slinger::Proxy>,
435 ) -> Result<()> {
436 let target_socket = Self::connect_to_target(uri, upstream_proxy.as_ref()).await?;
438
439 if send_response {
440 client_stream
441 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
442 .await?;
443 }
444
445 let (mut client_read, mut client_write) = tokio::io::split(client_stream);
446 let (mut target_read, mut target_write) = tokio::io::split(target_socket);
447
448 let client_to_target = tokio::io::copy(&mut client_read, &mut target_write);
449 let target_to_client = tokio::io::copy(&mut target_read, &mut client_write);
450
451 tokio::select! {
452 _ = client_to_target => {},
453 _ = target_to_client => {},
454 }
455
456 Ok(())
457 }
458
459 async fn handle_tcp_tunnel_with_interception(
463 client_stream: slinger::Socket,
464 target_addr: &str,
465 peer_addr: SocketAddr,
466 interceptor: Arc<RwLock<InterceptorHandler>>,
467 upstream_proxy: Option<slinger::Proxy>,
468 ) -> Result<()> {
469 use uuid::Uuid;
470
471 let connection_session_id = Uuid::new_v4().as_u128();
474
475 let target_socket = Self::connect_to_target(target_addr, upstream_proxy.as_ref()).await?;
477
478 let (mut client_read, mut client_write) = tokio::io::split(client_stream);
479 let (mut target_read, mut target_write) = tokio::io::split(target_socket);
480
481 let target_addr_clone = target_addr.to_string();
482 let target_addr_clone2 = target_addr.to_string();
483 let interceptor_clone = interceptor.clone();
484
485 let client_to_target = tokio::spawn(async move {
487 let mut buffer = vec![0u8; 8192];
488 loop {
489 match client_read.read(&mut buffer).await {
490 Ok(0) => break, Ok(n) => {
492 let data = Bytes::copy_from_slice(&buffer[..n]);
493 let mut request = MitmRequest::raw_tcp_with_source(peer_addr, &target_addr_clone, data);
494 request.set_session_id(connection_session_id);
496
497 let handler = interceptor_clone.read().await;
499 match handler.process_request(request).await {
500 Ok(Some(modified_request)) => {
501 if let Some(body) = modified_request.body() {
503 if let Err(e) = target_write.write_all(body.as_ref()).await {
504 tracing::error!("[MITM TCP] Error writing to target: {}", e);
505 break;
506 }
507 }
508 }
509 Ok(None) => {
510 tracing::debug!("[MITM TCP] Request blocked by interceptor");
512 }
513 Err(e) => {
514 tracing::error!("[MITM TCP] Error processing request: {}", e);
515 break;
516 }
517 }
518 }
519 Err(e) => {
520 tracing::error!("[MITM TCP] Error reading from client: {}", e);
521 break;
522 }
523 }
524 }
525 });
526
527 let target_to_client = tokio::spawn(async move {
529 let mut buffer = vec![0u8; 8192];
530 loop {
531 match target_read.read(&mut buffer).await {
532 Ok(0) => break, Ok(n) => {
534 let data = Bytes::copy_from_slice(&buffer[..n]);
535 let response = MitmResponse::raw_tcp_with_destination(
537 connection_session_id,
538 &target_addr_clone2,
539 peer_addr,
540 data,
541 );
542
543 let handler = interceptor.read().await;
545 match handler.process_response(response).await {
546 Ok(Some(modified_response)) => {
547 if let Some(body) = modified_response.body() {
549 if let Err(e) = client_write.write_all(body.as_ref()).await {
550 tracing::error!("[MITM TCP] Error writing to client: {}", e);
551 break;
552 }
553 }
554 }
555 Ok(None) => {
556 tracing::debug!("[MITM TCP] Response blocked by interceptor");
558 }
559 Err(e) => {
560 tracing::error!("[MITM TCP] Error processing response: {}", e);
561 break;
562 }
563 }
564 }
565 Err(e) => {
566 tracing::error!("[MITM TCP] Error reading from target: {}", e);
567 break;
568 }
569 }
570 }
571 });
572
573 tokio::select! {
574 _ = client_to_target => {},
575 _ = target_to_client => {},
576 }
577
578 Ok(())
579 }
580
581 async fn connect_to_target(
586 target_addr: &str,
587 upstream_proxy: Option<&slinger::Proxy>,
588 ) -> Result<slinger::Socket> {
589 let uri = format!("http://{}", target_addr)
593 .parse::<http::Uri>()
594 .map_err(|e| {
595 Error::connection_error(format!("Invalid target address '{}': {}", target_addr, e))
596 })?;
597
598 let connector = slinger::ConnectorBuilder::default()
599 .proxy(upstream_proxy.cloned())
600 .build()
601 .map_err(|e| Error::connection_error(format!("Failed to build connector: {}", e)))?;
602
603 connector.connect_with_uri(&uri).await.map_err(Into::into)
604 }
605
606 async fn forward_request_via_client(
609 interceptor: Arc<RwLock<InterceptorHandler>>,
610 client: &Client,
611 request: Request,
612 destination: &str,
613 ) -> Result<Option<Vec<u8>>> {
614 let handler = interceptor.read().await;
615 let mitm_request = MitmRequest::new(destination, request);
616 let session_id = mitm_request.session_id();
618 if let Some(modified_req) = handler.process_request(mitm_request).await? {
619 let inner_req = modified_req.request();
620 let uri = inner_req.uri().clone();
621 let method = inner_req.method().clone();
622 let headers = inner_req.headers().clone();
623 let body_data = if let Some(body) = inner_req.body() {
624 body.to_vec()
625 } else {
626 Vec::new()
627 };
628 let mut req_builder = client.request(method, uri);
629 for (name, value) in headers.iter() {
630 req_builder = req_builder.header(name, value);
631 }
632 req_builder = req_builder.body(body_data);
633 match req_builder.send().await {
634 Ok(response) => {
635 let mitm_response = MitmResponse::new(session_id, destination, response);
637 if let Some(final_response) = handler.process_response(mitm_response).await? {
638 let response_bytes = Bytes::from(final_response.response()).to_vec();
639 return Ok(Some(response_bytes));
640 }
641 }
642 Err(_e) => {
643 return Ok(Some(b"HTTP/1.1 502 Bad Gateway\r\n\r\n".to_vec()));
644 }
645 }
646 }
647 Ok(None)
648 }
649
650 async fn handle_https_stream<S>(
652 tls_stream: S,
653 domain: String,
654 interceptor: Arc<RwLock<InterceptorHandler>>,
655 client: Client,
656 ) -> Result<()>
657 where
658 S: AsyncReadExt + AsyncWriteExt + Unpin,
659 {
660 let mut reader = BufReader::new(tls_stream);
662 let request_result = Request::from_http_reader(&mut reader).await;
663 let mut tls_stream = reader.into_inner();
664
665 let mut request = match request_result {
666 Ok(req) => req,
667 Err(e) => {
668 tracing::debug!("[MITM HTTPS] Failed to parse request: {}", e);
669 return Ok(());
670 }
671 };
672
673 if request.uri().host().is_none() {
675 let pq = request
676 .uri()
677 .path_and_query()
678 .map(|pq| pq.as_str())
679 .unwrap_or("/");
680 let absolute_uri = format!("https://{}{}", domain, pq)
681 .parse::<http::Uri>()
682 .map_err(|e| Error::invalid_request(format!("Invalid URI: {}", e)))?;
683 *request.uri_mut() = absolute_uri;
684 }
685
686 if let Some(response_bytes) =
687 Self::forward_request_via_client(interceptor, &client, request, &domain).await?
688 {
689 tls_stream.write_all(&response_bytes).await?;
690 }
691
692 Ok(())
693 }
694
695 async fn handle_http_request<R>(
697 request: Request,
698 reader: BufReader<R>,
699 interceptor: Arc<RwLock<InterceptorHandler>>,
700 client: Client,
701 ) -> Result<()>
702 where
703 R: AsyncReadExt + AsyncWriteExt + Unpin,
704 {
705 let uri = request.uri().to_string();
706
707 if let Some(response_bytes) =
709 Self::forward_request_via_client(interceptor, &client, request, &uri).await?
710 {
711 let mut stream = reader.into_inner();
712 stream.write_all(&response_bytes).await?;
713 }
714
715 Ok(())
716 }
717
718 fn create_tls_server_config(
720 cert_chain: Vec<CertificateDer<'static>>,
721 key: PrivateKeyDer<'static>,
722 ) -> Result<ServerConfig> {
723 let config = ServerConfig::builder()
724 .with_no_client_auth()
725 .with_single_cert(cert_chain, key)
726 .map_err(|e| Error::tls_error(format!("Failed to create TLS config: {}", e)))?;
727
728 Ok(config)
729 }
730
731 fn parse_host_port(uri: &str) -> Result<(String, u16)> {
733 let parts: Vec<&str> = uri.split(':').collect();
734 if parts.len() != 2 {
735 return Err(Error::invalid_request(format!("Invalid URI: {}", uri)));
736 }
737
738 let host = parts[0].to_string();
739 let port = parts[1]
740 .parse::<u16>()
741 .map_err(|_| Error::invalid_request(format!("Invalid port: {}", parts[1])))?;
742
743 Ok((host, port))
744 }
745}