1use crate::ca::CertificateManager;
4use crate::error::{Error, Result};
5use crate::interceptor::{InterceptorHandler, MitmRequest, MitmResponse};
6use crate::proxy::MitmConfig;
7use bytes::Bytes;
8use http::Version;
9use slinger::{Client, ClientBuilder, Request, Response};
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
14use tokio::net::{TcpListener, TcpStream};
15use tokio::sync::RwLock;
16use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
17use tokio_rustls::rustls::ServerConfig;
18use tokio_rustls::TlsAcceptor;
19
20pub struct ProxyServer {
22 config: MitmConfig,
23 cert_manager: Arc<CertificateManager>,
24 interceptor_handler: Arc<RwLock<InterceptorHandler>>,
25 client: Client,
26}
27
28#[derive(Default)]
32pub struct ProxyServerBuilder {
33 config: Option<MitmConfig>,
34 cert_manager: Option<Arc<CertificateManager>>,
35 interceptor_handler: Option<Arc<RwLock<InterceptorHandler>>>,
36 client: Option<Client>,
37 client_config: Option<Box<dyn Fn(ClientBuilder) -> ClientBuilder + Send + Sync>>,
39}
40
41impl ProxyServerBuilder {
42 pub fn from_server(server: &ProxyServer) -> Self {
44 Self {
45 config: Some(server.config.clone()),
46 cert_manager: Some(server.cert_manager.clone()),
47 interceptor_handler: Some(server.interceptor_handler.clone()),
48 client: Some(server.client.clone()),
49 client_config: None,
50 }
51 }
52
53 pub fn config(mut self, config: MitmConfig) -> Self {
55 self.config = Some(config);
56 self
57 }
58
59 pub fn cert_manager(mut self, cert_manager: Arc<CertificateManager>) -> Self {
61 self.cert_manager = Some(cert_manager);
62 self
63 }
64
65 pub fn interceptor_handler(mut self, handler: Arc<RwLock<InterceptorHandler>>) -> Self {
67 self.interceptor_handler = Some(handler);
68 self
69 }
70
71 pub fn client(mut self, client: Client) -> Self {
73 self.client = Some(client);
74 self
75 }
76
77 pub fn configure_client<F>(mut self, f: F) -> Self
80 where
81 F: Fn(ClientBuilder) -> ClientBuilder + Send + Sync + 'static,
82 {
83 self.client_config = Some(Box::new(f));
84 self
85 }
86
87 pub fn build(self) -> Result<ProxyServer> {
95 let config = self.config.unwrap_or_default();
97
98 let cert_manager = match self.cert_manager {
101 Some(c) => c,
102 None => {
103 return Err(Error::proxy_error(
104 "CertificateManager not provided; use ProxyServer::builder().build_async().await to create one automatically".to_string(),
105 ))
106 }
107 };
108
109 let interceptor_handler = self
111 .interceptor_handler
112 .unwrap_or_else(|| Arc::new(RwLock::new(InterceptorHandler::new())));
113
114 let client = if let Some(client) = self.client {
116 client
117 } else if let Some(cfg_fn) = self.client_config {
118 let builder = Client::builder();
119 let configured = cfg_fn(builder);
120 configured
121 .build()
122 .map_err(|e| Error::proxy_error(format!("Failed to build client: {}", e)))?
123 } else {
124 if let Some(proxy) = &config.upstream_proxy {
126 Client::builder()
127 .timeout(Some(Duration::from_secs(60)))
128 .keepalive(true)
129 .proxy(proxy.clone())
130 .build()
131 .map_err(|e| {
132 Error::proxy_error(format!(
133 "Failed to build client with proxy {}: {}",
134 proxy.uri(),
135 e
136 ))
137 })?
138 } else {
139 Client::builder()
140 .keepalive(true)
141 .build()
142 .map_err(|e| Error::proxy_error(format!("Failed to build default client: {}", e)))?
143 }
144 };
145
146 Ok(ProxyServer {
147 config,
148 cert_manager,
149 interceptor_handler,
150 client,
151 })
152 }
153}
154
155impl ProxyServer {
156 pub fn new(
158 config: MitmConfig,
159 cert_manager: Arc<CertificateManager>,
160 interceptor_handler: Arc<RwLock<InterceptorHandler>>,
161 ) -> Result<Self> {
162 let client = if let Some(proxy) = &config.upstream_proxy {
163 Client::builder()
165 .timeout(Some(Duration::from_secs(60)))
166 .keepalive(true)
167 .proxy(proxy.clone())
168 .build()
169 .map_err(|e| {
170 Error::proxy_error(format!(
171 "Failed to build client with proxy {}: {}",
172 proxy.uri(),
173 e
174 ))
175 })?
176 } else {
177 Client::builder()
179 .keepalive(true)
180 .build()
181 .map_err(|e| Error::proxy_error(format!("Failed to build default client: {}", e)))?
182 };
183 Ok(Self {
184 config,
185 cert_manager,
186 interceptor_handler,
187 client,
188 })
189 }
190
191 pub async fn run(&self, addr: &str) -> Result<()> {
193 let listener = TcpListener::bind(addr)
194 .await
195 .map_err(|e| Error::proxy_error(format!("Failed to bind to {}: {}", addr, e)))?;
196 loop {
197 match listener.accept().await {
198 Ok((stream, peer_addr)) => {
199 let config = self.config.clone();
200 let cert_manager = self.cert_manager.clone();
201 let interceptor = self.interceptor_handler.clone();
202 let client = self.client.clone();
203
204 tokio::spawn(async move {
205 if let Err(e) =
206 Self::handle_connection(stream, peer_addr, config, cert_manager, interceptor, client)
207 .await
208 {
209 tracing::error!("[MITM] Error handling connection: {}", e);
210 }
211 });
212 }
213 Err(e) => {
214 tracing::error!("[MITM] Failed to accept connection: {}", e);
215 }
216 }
217 }
218 }
219
220 async fn handle_connection(
222 mut stream: TcpStream,
223 peer_addr: SocketAddr,
224 config: MitmConfig,
225 cert_manager: Arc<CertificateManager>,
226 interceptor: Arc<RwLock<InterceptorHandler>>,
227 client: Client,
228 ) -> Result<()> {
229 use crate::socks5::Socks5Server;
230
231 let mut first_byte = [0u8; 1];
233 stream.read_exact(&mut first_byte).await?;
234
235 if first_byte[0] == 0x05 {
237 match Socks5Server::handle_handshake_with_version(&mut stream).await {
240 Ok(target_addr) => {
241 let target_host_port = target_addr.to_host_port();
242
243 let has_interceptors = interceptor.read().await.has_interceptors();
245
246 if config.enable_tcp_interception && has_interceptors {
247 Self::handle_tcp_tunnel_with_interception(
249 stream,
250 &target_host_port,
251 peer_addr,
252 interceptor,
253 )
254 .await
255 } else if config.enable_https_interception {
256 Self::handle_https_connect_socks5(
257 stream,
258 &target_host_port,
259 cert_manager,
260 interceptor,
261 client,
262 )
263 .await
264 } else {
265 Self::handle_tcp_tunnel(stream, &target_host_port).await
266 }
267 }
268 Err(e) => Err(e),
269 }
270 } else {
271 let mut request_line = vec![first_byte[0]];
272 let mut buffer = [0u8; 1];
273 loop {
274 stream.read_exact(&mut buffer).await?;
275 request_line.push(buffer[0]);
276 if buffer[0] == b'\n' {
277 break;
278 }
279 if request_line.len() > 8192 {
280 return Err(Error::invalid_request("Request line too long".to_string()));
281 }
282 }
283
284 let request_line_str = String::from_utf8_lossy(&request_line);
285 let parts: Vec<&str> = request_line_str.split_whitespace().collect();
286 if parts.len() < 3 {
287 return Err(Error::invalid_request("Invalid request line".to_string()));
288 }
289
290 let method = parts[0].to_string();
291 let uri = parts[1].to_string();
292 if method == "CONNECT" {
293 let mut reader = BufReader::new(stream);
294 const MAX_CONNECT_HEADERS: usize = 16 * 1024; let mut headers_acc = 0usize;
296 loop {
297 let mut line = String::new();
298 let n = reader.read_line(&mut line).await?;
299 if n == 0 {
301 break;
302 }
303 headers_acc += n;
304 if headers_acc > MAX_CONNECT_HEADERS {
305 return Err(Error::invalid_request(
306 "CONNECT headers size exceeds maximum allowed".to_string(),
307 ));
308 }
309 if line == "\r\n" || line == "\n" || line.is_empty() {
311 break;
312 }
313 }
314 let stream = reader.into_inner();
315 if config.enable_https_interception {
316 Self::handle_https_connect(stream, &uri, cert_manager, interceptor, client).await
317 } else {
318 Self::handle_https_tunnel(stream, &uri).await
319 }
320 } else {
321 let buf_reader = BufReader::new(stream);
322 Self::handle_http_request(&method, &uri, buf_reader, interceptor, client).await
323 }
324 }
325 }
326
327 async fn handle_https_connect(
329 client_stream: TcpStream,
330 uri: &str,
331 cert_manager: Arc<CertificateManager>,
332 interceptor: Arc<RwLock<InterceptorHandler>>,
333 slinger_client: Client,
334 ) -> Result<()> {
335 let (domain, port) = Self::parse_host_port(uri)?;
337 Self::accept_tls_and_handle(
339 client_stream,
340 &domain,
341 port,
342 true,
343 cert_manager,
344 interceptor,
345 slinger_client,
346 )
347 .await
348 }
349
350 async fn handle_https_tunnel(client_stream: TcpStream, uri: &str) -> Result<()> {
352 Self::tcp_tunnel(client_stream, uri, true).await
353 }
354
355 async fn handle_tcp_tunnel(client_stream: TcpStream, target_addr: &str) -> Result<()> {
358 Self::tcp_tunnel(client_stream, target_addr, false).await
359 }
360
361 async fn handle_https_connect_socks5(
364 client_stream: TcpStream,
365 uri: &str,
366 cert_manager: Arc<CertificateManager>,
367 interceptor: Arc<RwLock<InterceptorHandler>>,
368 slinger_client: Client,
369 ) -> Result<()> {
370 let (domain, port) = Self::parse_host_port(uri)?;
372
373 Self::accept_tls_and_handle(
375 client_stream,
376 &domain,
377 port,
378 false,
379 cert_manager,
380 interceptor,
381 slinger_client,
382 )
383 .await
384 }
385
386 async fn accept_tls_and_handle(
389 mut client_stream: TcpStream,
390 domain: &str,
391 port: u16,
392 send_response: bool,
393 cert_manager: Arc<CertificateManager>,
394 interceptor: Arc<RwLock<InterceptorHandler>>,
395 slinger_client: Client,
396 ) -> Result<()> {
397 if send_response {
398 client_stream
399 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
400 .await?;
401 client_stream
402 .flush() .await
404 .map_err(Error::Io)?;
405 }
406
407 let (cert_chain, key) = cert_manager.get_server_cert(domain).await?;
409 let tls_config = Self::create_tls_server_config(cert_chain, key)?;
411 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
412 let tls_stream = acceptor
414 .accept(client_stream)
415 .await
416 .map_err(|e| Error::tls_error(format!("TLS handshake failed: {}", e)))?;
417 let domain_with_port = format!("{}:{}", domain, port);
418 Self::handle_https_stream(tls_stream, domain_with_port, interceptor, slinger_client).await
419 }
420
421 async fn tcp_tunnel(mut client_stream: TcpStream, uri: &str, send_response: bool) -> Result<()> {
423 let (host, port) = Self::parse_host_port(uri)?;
424 let addr = format!("{}:{}", host, port);
425
426 let mut target_stream = TcpStream::connect(&addr)
428 .await
429 .map_err(|e| Error::connection_error(format!("Failed to connect to {}: {}", addr, e)))?;
430
431 if send_response {
432 client_stream
433 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
434 .await?;
435 }
436
437 let (mut client_read, mut client_write) = client_stream.split();
438 let (mut target_read, mut target_write) = target_stream.split();
439
440 let client_to_target = tokio::io::copy(&mut client_read, &mut target_write);
441 let target_to_client = tokio::io::copy(&mut target_read, &mut client_write);
442
443 tokio::select! {
444 _ = client_to_target => {},
445 _ = target_to_client => {},
446 }
447
448 Ok(())
449 }
450
451 async fn handle_tcp_tunnel_with_interception(
455 client_stream: TcpStream,
456 target_addr: &str,
457 peer_addr: SocketAddr,
458 interceptor: Arc<RwLock<InterceptorHandler>>,
459 ) -> Result<()> {
460 let (host, port) = Self::parse_host_port(target_addr)?;
461 let addr = format!("{}:{}", host, port);
462
463 let target_stream = TcpStream::connect(&addr)
465 .await
466 .map_err(|e| Error::connection_error(format!("Failed to connect to {}: {}", addr, e)))?;
467
468 let (mut client_read, mut client_write) = client_stream.into_split();
469 let (mut target_read, mut target_write) = target_stream.into_split();
470
471 let target_addr_clone = addr.clone();
472 let interceptor_clone = interceptor.clone();
473
474 let client_to_target = tokio::spawn(async move {
476 let mut buffer = vec![0u8; 8192];
477 loop {
478 match client_read.read(&mut buffer).await {
479 Ok(0) => break, Ok(n) => {
481 let data = Bytes::copy_from_slice(&buffer[..n]);
482 let request = MitmRequest::raw_tcp_with_source(peer_addr, &target_addr_clone, data);
483
484 let handler = interceptor_clone.read().await;
486 match handler.process_request(request).await {
487 Ok(Some(modified_request)) => {
488 if let Some(body) = modified_request.body() {
490 if let Err(e) = target_write.write_all(body.as_ref()).await {
491 tracing::error!("[MITM TCP] Error writing to target: {}", e);
492 break;
493 }
494 }
495 }
496 Ok(None) => {
497 tracing::debug!("[MITM TCP] Request blocked by interceptor");
499 }
500 Err(e) => {
501 tracing::error!("[MITM TCP] Error processing request: {}", e);
502 break;
503 }
504 }
505 }
506 Err(e) => {
507 tracing::error!("[MITM TCP] Error reading from client: {}", e);
508 break;
509 }
510 }
511 }
512 });
513
514 let target_to_client = tokio::spawn(async move {
516 let mut buffer = vec![0u8; 8192];
517 loop {
518 match target_read.read(&mut buffer).await {
519 Ok(0) => break, Ok(n) => {
521 let data = Bytes::copy_from_slice(&buffer[..n]);
522 let response = MitmResponse::raw_tcp_with_destination(&addr, peer_addr, data);
523
524 let handler = interceptor.read().await;
526 match handler.process_response(response).await {
527 Ok(Some(modified_response)) => {
528 if let Some(body) = modified_response.body() {
530 if let Err(e) = client_write.write_all(body.as_ref()).await {
531 tracing::error!("[MITM TCP] Error writing to client: {}", e);
532 break;
533 }
534 }
535 }
536 Ok(None) => {
537 tracing::debug!("[MITM TCP] Response blocked by interceptor");
539 }
540 Err(e) => {
541 tracing::error!("[MITM TCP] Error processing response: {}", e);
542 break;
543 }
544 }
545 }
546 Err(e) => {
547 tracing::error!("[MITM TCP] Error reading from target: {}", e);
548 break;
549 }
550 }
551 }
552 });
553
554 tokio::select! {
555 _ = client_to_target => {},
556 _ = target_to_client => {},
557 }
558
559 Ok(())
560 }
561
562 async fn forward_request_via_client(
566 interceptor: Arc<RwLock<InterceptorHandler>>,
567 client: &Client,
568 request: Request,
569 destination: &str,
570 ) -> Result<Option<Vec<u8>>> {
571 let handler = interceptor.read().await;
572 let mitm_request = MitmRequest::new(destination, request);
573 if let Some(modified_req) = handler.process_request(mitm_request).await? {
574 let inner_req = modified_req.request();
575 let uri = inner_req.uri().clone();
576 let method = inner_req.method().clone();
577 let headers = inner_req.headers().clone();
578 let body_data = if let Some(body) = inner_req.body() {
579 body.to_vec()
580 } else {
581 Vec::new()
582 };
583 let mut req_builder = client.request(method, uri);
584 for (name, value) in headers.iter() {
585 req_builder = req_builder.header(name, value);
586 }
587 req_builder = req_builder.body(body_data);
588 match req_builder.send().await {
589 Ok(response) => {
590 let mitm_response = MitmResponse::new(destination, response);
591 if let Some(final_response) = handler.process_response(mitm_response).await? {
592 let response_bytes = Self::serialize_http_response(final_response.response())?;
593 return Ok(Some(response_bytes));
594 }
595 }
596 Err(_e) => {
597 return Ok(Some(b"HTTP/1.1 502 Bad Gateway\r\n\r\n".to_vec()));
598 }
599 }
600 }
601 Ok(None)
602 }
603
604 async fn handle_https_stream<S>(
606 mut tls_stream: S,
607 domain: String,
608 interceptor: Arc<RwLock<InterceptorHandler>>,
609 client: Client,
610 ) -> Result<()>
611 where
612 S: AsyncReadExt + AsyncWriteExt + Unpin,
613 {
614 const MAX_REQUEST_SIZE: usize = 1024 * 1024; let mut buffer = Vec::new();
617 let mut temp_buf = [0u8; 8192];
618
619 loop {
620 match tls_stream.read(&mut temp_buf).await {
621 Ok(0) => break,
622 Ok(n) => {
623 buffer.extend_from_slice(&temp_buf[..n]);
624 if buffer.len() > MAX_REQUEST_SIZE {
625 return Err(Error::invalid_request(
626 "Request size exceeds maximum allowed".to_string(),
627 ));
628 }
629 if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
630 break;
631 }
632 }
633 Err(e) => return Err(Error::Io(e)),
634 }
635 }
636
637 if let Ok(request) = Self::parse_http_request(&buffer, &domain) {
639 if let Some(response_bytes) =
640 Self::forward_request_via_client(interceptor, &client, request, &domain).await?
641 {
642 tls_stream.write_all(&response_bytes).await?;
643 }
644 }
645
646 Ok(())
647 }
648
649 async fn handle_http_request<R>(
651 method: &str,
652 uri: &str,
653 mut reader: BufReader<R>,
654 interceptor: Arc<RwLock<InterceptorHandler>>,
655 client: Client,
656 ) -> Result<()>
657 where
658 R: AsyncReadExt + AsyncWriteExt + Unpin,
659 {
660 const MAX_HEADERS_SIZE: usize = 64 * 1024; let mut headers_buf = Vec::new();
663 loop {
664 let mut line = String::new();
665 reader.read_line(&mut line).await?;
666 if line == "\r\n" || line == "\n" {
667 break;
668 }
669 headers_buf.extend_from_slice(line.as_bytes());
670
671 if headers_buf.len() > MAX_HEADERS_SIZE {
673 return Err(Error::invalid_request(
674 "Headers size exceeds maximum allowed".to_string(),
675 ));
676 }
677 }
678
679 let mut request_builder = http::Request::builder()
681 .method(method)
682 .uri(uri)
683 .version(Version::HTTP_11);
684
685 for line in String::from_utf8_lossy(&headers_buf).lines() {
687 if let Some(idx) = line.find(':') {
688 let (name, value) = line.split_at(idx);
689 let value = value[1..].trim();
690 request_builder = request_builder.header(name.trim(), value);
691 }
692 }
693
694 let http_request = request_builder.body(Bytes::new())?;
695 let request: Request = http_request.into();
696
697 if let Some(response_bytes) =
699 Self::forward_request_via_client(interceptor, &client, request, uri).await?
700 {
701 let mut stream = reader.into_inner();
702 stream.write_all(&response_bytes).await?;
703 }
704
705 Ok(())
706 }
707
708 fn create_tls_server_config(
710 cert_chain: Vec<CertificateDer<'static>>,
711 key: PrivateKeyDer<'static>,
712 ) -> Result<ServerConfig> {
713 let config = ServerConfig::builder()
714 .with_no_client_auth()
715 .with_single_cert(cert_chain, key)
716 .map_err(|e| Error::tls_error(format!("Failed to create TLS config: {}", e)))?;
717
718 Ok(config)
719 }
720
721 fn parse_host_port(uri: &str) -> Result<(String, u16)> {
723 let parts: Vec<&str> = uri.split(':').collect();
724 if parts.len() != 2 {
725 return Err(Error::invalid_request(format!("Invalid URI: {}", uri)));
726 }
727
728 let host = parts[0].to_string();
729 let port = parts[1]
730 .parse::<u16>()
731 .map_err(|_| Error::invalid_request(format!("Invalid port: {}", parts[1])))?;
732
733 Ok((host, port))
734 }
735
736 fn parse_http_request(buffer: &[u8], domain: &str) -> Result<Request> {
738 let request_str = String::from_utf8_lossy(buffer);
739 let mut lines = request_str.lines();
740
741 let request_line = lines
742 .next()
743 .ok_or_else(|| Error::invalid_request("Empty request".to_string()))?;
744 let parts: Vec<&str> = request_line.split_whitespace().collect();
745 if parts.len() < 3 {
746 return Err(Error::invalid_request("Invalid request line".to_string()));
747 }
748
749 let method = parts[0];
750 let path = parts[1];
751 let uri = if path.starts_with("http://") || path.starts_with("https://") {
752 path.to_string()
753 } else {
754 format!("https://{}{}", domain, path)
755 };
756
757 let mut request_builder = http::Request::builder()
758 .method(method)
759 .uri(uri)
760 .version(Version::HTTP_11);
761
762 for line in lines {
763 if line.is_empty() {
764 break;
765 }
766 if let Some(idx) = line.find(':') {
767 let (name, value) = line.split_at(idx);
768 let value = value[1..].trim();
769 request_builder = request_builder.header(name.trim(), value);
770 }
771 }
772
773 let http_request = request_builder.body(Bytes::new())?;
774 Ok(http_request.into())
775 }
776
777 fn serialize_http_response(response: &Response) -> Result<Vec<u8>> {
779 let mut buf = Vec::new();
780
781 let status = response.status_code();
783 let status_line = format!(
784 "HTTP/1.1 {} {}\r\n",
785 status.as_u16(),
786 status.canonical_reason().unwrap_or("Unknown")
787 );
788 buf.extend_from_slice(status_line.as_bytes());
789
790 for (name, value) in response.headers() {
792 buf.extend_from_slice(name.as_str().as_bytes());
793 buf.extend_from_slice(b": ");
794 buf.extend_from_slice(value.as_bytes());
795 buf.extend_from_slice(b"\r\n");
796 }
797 buf.extend_from_slice(b"\r\n");
799 if let Some(body) = response.body() {
801 buf.extend_from_slice(body.as_ref());
802 }
803 Ok(buf)
804 }
805}