slinger_mitm/
server.rs

1//! MITM Proxy server implementation
2
3use crate::ca::CertificateManager;
4use crate::error::{Error, Result};
5use crate::interceptor::InterceptorHandler;
6use crate::proxy::MitmConfig;
7use bytes::Bytes;
8use http::Version;
9use slinger::{Client, ClientBuilder, Request, Response};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
13use tokio::net::{TcpListener, TcpStream};
14use tokio::sync::RwLock;
15use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
16use tokio_rustls::rustls::ServerConfig;
17use tokio_rustls::TlsAcceptor;
18
19/// Proxy server implementation
20pub struct ProxyServer {
21  config: MitmConfig,
22  cert_manager: Arc<CertificateManager>,
23  interceptor_handler: Arc<RwLock<InterceptorHandler>>,
24  client: Client,
25}
26
27/// Builder for `ProxyServer`.
28///
29/// Allows configuring the server and the inner `slinger::Client`.
30#[derive(Default)]
31pub struct ProxyServerBuilder {
32  config: Option<MitmConfig>,
33  cert_manager: Option<Arc<CertificateManager>>,
34  interceptor_handler: Option<Arc<RwLock<InterceptorHandler>>>,
35  client: Option<Client>,
36  // Optional client configurator: takes a ClientBuilder and returns a configured ClientBuilder
37  client_config: Option<Box<dyn Fn(ClientBuilder) -> ClientBuilder + Send + Sync>>,
38}
39
40impl ProxyServerBuilder {
41  /// Start building from an existing `ProxyServer` configuration.
42  pub fn from_server(server: &ProxyServer) -> Self {
43    Self {
44      config: Some(server.config.clone()),
45      cert_manager: Some(server.cert_manager.clone()),
46      interceptor_handler: Some(server.interceptor_handler.clone()),
47      client: Some(server.client.clone()),
48      client_config: None,
49    }
50  }
51
52  /// Set the `MitmConfig` to use.
53  pub fn config(mut self, config: MitmConfig) -> Self {
54    self.config = Some(config);
55    self
56  }
57
58  /// Set the `CertificateManager` to use.
59  pub fn cert_manager(mut self, cert_manager: Arc<CertificateManager>) -> Self {
60    self.cert_manager = Some(cert_manager);
61    self
62  }
63
64  /// Set the `InterceptorHandler` to use.
65  pub fn interceptor_handler(mut self, handler: Arc<RwLock<InterceptorHandler>>) -> Self {
66    self.interceptor_handler = Some(handler);
67    self
68  }
69
70  /// Provide a fully constructed `slinger::Client` to use.
71  pub fn client(mut self, client: Client) -> Self {
72    self.client = Some(client);
73    self
74  }
75
76  /// Configure the internal `slinger::Client` using a closure that accepts a
77  /// `ClientBuilder` and returns a configured `ClientBuilder`.
78  pub fn configure_client<F>(mut self, f: F) -> Self
79  where
80    F: Fn(ClientBuilder) -> ClientBuilder + Send + Sync + 'static,
81  {
82    self.client_config = Some(Box::new(f));
83    self
84  }
85
86  /// Build the `ProxyServer`.
87  ///
88  /// Priority for creating the inner `Client`:
89  /// 1. If `client` is provided, use it.
90  /// 2. Else if `client_config` is provided, call it with `Client::builder()`.
91  /// 3. Else fall back to default behavior: honor `config.upstream_proxy` if present
92  ///    and otherwise create a default client similar to `ProxyServer::new`.
93  pub fn build(self) -> Result<ProxyServer> {
94    // Resolve config
95    let config = self.config.unwrap_or_default();
96
97    // For synchronous build we require a pre-created CertificateManager because
98    // creation is async. Callers who don't have one should use `build_async()`.
99    let cert_manager = match self.cert_manager {
100      Some(c) => c,
101      None => {
102        return Err(Error::proxy_error(
103          "CertificateManager not provided; use ProxyServer::builder().build_async().await to create one automatically".to_string(),
104        ))
105      }
106    };
107
108    // Resolve interceptor handler
109    let interceptor_handler = self
110      .interceptor_handler
111      .unwrap_or_else(|| Arc::new(RwLock::new(InterceptorHandler::new())));
112
113    // Resolve client
114    let client = if let Some(client) = self.client {
115      client
116    } else if let Some(cfg_fn) = self.client_config {
117      let builder = Client::builder();
118      let configured = cfg_fn(builder);
119      configured
120        .build()
121        .map_err(|e| Error::proxy_error(format!("Failed to build client: {}", e)))?
122    } else {
123      // Fallback: honor upstream_proxy in config similar to ProxyServer::new
124      if let Some(proxy) = &config.upstream_proxy {
125        Client::builder()
126          .timeout(Some(Duration::from_secs(60)))
127          .keepalive(true)
128          .proxy(proxy.clone())
129          .build()
130          .map_err(|e| {
131            Error::proxy_error(format!(
132              "Failed to build client with proxy {}: {}",
133              proxy.uri(),
134              e
135            ))
136          })?
137      } else {
138        Client::builder()
139          .keepalive(true)
140          .build()
141          .map_err(|e| Error::proxy_error(format!("Failed to build default client: {}", e)))?
142      }
143    };
144
145    Ok(ProxyServer {
146      config,
147      cert_manager,
148      interceptor_handler,
149      client,
150    })
151  }
152}
153
154impl ProxyServer {
155  /// Create a new proxy server
156  pub fn new(
157    config: MitmConfig,
158    cert_manager: Arc<CertificateManager>,
159    interceptor_handler: Arc<RwLock<InterceptorHandler>>,
160  ) -> Result<Self> {
161    let client = if let Some(proxy) = &config.upstream_proxy {
162      // Enable HTTP keep-alive so the connector can reuse TCP connections
163      Client::builder()
164        .timeout(Some(Duration::from_secs(60)))
165        .keepalive(true)
166        .proxy(proxy.clone())
167        .build()
168        .map_err(|e| {
169          Error::proxy_error(format!(
170            "Failed to build client with proxy {}: {}",
171            proxy.uri(),
172            e
173          ))
174        })?
175    } else {
176      // Use a client configured to reuse connections (keep-alive)
177      Client::builder()
178        .keepalive(true)
179        .build()
180        .map_err(|e| Error::proxy_error(format!("Failed to build default client: {}", e)))?
181    };
182    Ok(Self {
183      config,
184      cert_manager,
185      interceptor_handler,
186      client,
187    })
188  }
189
190  /// Run the proxy server
191  pub async fn run(&self, addr: &str) -> Result<()> {
192    let listener = TcpListener::bind(addr)
193      .await
194      .map_err(|e| Error::proxy_error(format!("Failed to bind to {}: {}", addr, e)))?;
195    loop {
196      match listener.accept().await {
197        Ok((stream, _peer_addr)) => {
198          let config = self.config.clone();
199          let cert_manager = self.cert_manager.clone();
200          let interceptor = self.interceptor_handler.clone();
201          let client = self.client.clone();
202
203          tokio::spawn(async move {
204            if let Err(e) =
205              Self::handle_connection(stream, config, cert_manager, interceptor, client).await
206            {
207              tracing::error!("[MITM] Error handling connection: {}", e);
208            }
209          });
210        }
211        Err(e) => {
212          tracing::error!("[MITM] Failed to accept connection: {}", e);
213        }
214      }
215    }
216  }
217
218  /// Handle a client connection
219  async fn handle_connection(
220    mut stream: TcpStream,
221    config: MitmConfig,
222    cert_manager: Arc<CertificateManager>,
223    interceptor: Arc<RwLock<InterceptorHandler>>,
224    client: Client,
225  ) -> Result<()> {
226    use crate::socks5::Socks5Server;
227
228    // Read the first byte to determine protocol
229    let mut first_byte = [0u8; 1];
230    stream.read_exact(&mut first_byte).await?;
231
232    // SOCKS5 version is 0x05, HTTP methods start with ASCII letters
233    if first_byte[0] == 0x05 {
234      // Handle as SOCKS5 - we already consumed the version byte
235      // Put it back by handling the rest of the handshake
236      match Socks5Server::handle_handshake_with_version(&mut stream).await {
237        Ok(target_addr) => {
238          let target_host_port = target_addr.to_host_port();
239          match target_addr {
240            crate::socks5::TargetAddr::Domain(_domain, _port) => {
241              if config.enable_https_interception {
242                Self::handle_https_connect_socks5(
243                  stream,
244                  &target_host_port,
245                  cert_manager,
246                  interceptor,
247                  client,
248                )
249                .await
250              } else {
251                Self::handle_tcp_tunnel(stream, &target_host_port).await
252              }
253            }
254            _ => {
255              if config.enable_https_interception {
256                Self::handle_https_connect_socks5(
257                  stream,
258                  &target_host_port,
259                  cert_manager,
260                  interceptor,
261                  client,
262                )
263                .await
264              } else {
265                Self::handle_tcp_tunnel(stream, &target_host_port).await
266              }
267            }
268          }
269        }
270        Err(e) => Err(e),
271      }
272    } else {
273      let mut request_line = vec![first_byte[0]];
274      let mut buffer = [0u8; 1];
275      loop {
276        stream.read_exact(&mut buffer).await?;
277        request_line.push(buffer[0]);
278        if buffer[0] == b'\n' {
279          break;
280        }
281        if request_line.len() > 8192 {
282          return Err(Error::invalid_request("Request line too long".to_string()));
283        }
284      }
285
286      let request_line_str = String::from_utf8_lossy(&request_line);
287      let parts: Vec<&str> = request_line_str.split_whitespace().collect();
288      if parts.len() < 3 {
289        return Err(Error::invalid_request("Invalid request line".to_string()));
290      }
291
292      let method = parts[0].to_string();
293      let uri = parts[1].to_string();
294      if method == "CONNECT" {
295        let mut reader = BufReader::new(stream);
296        const MAX_CONNECT_HEADERS: usize = 16 * 1024; // 16KB max for proxy headers
297        let mut headers_acc = 0usize;
298        loop {
299          let mut line = String::new();
300          let n = reader.read_line(&mut line).await?;
301          // n==0 indicates EOF; bail out
302          if n == 0 {
303            break;
304          }
305          headers_acc += n;
306          if headers_acc > MAX_CONNECT_HEADERS {
307            return Err(Error::invalid_request(
308              "CONNECT headers size exceeds maximum allowed".to_string(),
309            ));
310          }
311          // End of headers is an empty line (CRLF)
312          if line == "\r\n" || line == "\n" || line.is_empty() {
313            break;
314          }
315        }
316        let stream = reader.into_inner();
317        if config.enable_https_interception {
318          Self::handle_https_connect(stream, &uri, cert_manager, interceptor, client).await
319        } else {
320          Self::handle_https_tunnel(stream, &uri).await
321        }
322      } else {
323        let buf_reader = BufReader::new(stream);
324        Self::handle_http_request(&method, &uri, buf_reader, interceptor, client).await
325      }
326    }
327  }
328
329  /// Handle HTTPS CONNECT with MITM interception
330  async fn handle_https_connect(
331    client_stream: TcpStream,
332    uri: &str,
333    cert_manager: Arc<CertificateManager>,
334    interceptor: Arc<RwLock<InterceptorHandler>>,
335    slinger_client: Client,
336  ) -> Result<()> {
337    // Parse domain and port
338    let (domain, port) = Self::parse_host_port(uri)?;
339    // Perform TLS accept + MITM handling, send HTTP 200 before TLS handshake
340    Self::accept_tls_and_handle(
341      client_stream,
342      &domain,
343      port,
344      true,
345      cert_manager,
346      interceptor,
347      slinger_client,
348    )
349    .await
350  }
351
352  /// Handle HTTPS tunnel without interception (transparent proxy)
353  async fn handle_https_tunnel(client_stream: TcpStream, uri: &str) -> Result<()> {
354    Self::tcp_tunnel(client_stream, uri, true).await
355  }
356
357  /// Handle TCP tunnel without interception (for SOCKS5)
358  /// This function doesn't send any response - the SOCKS5 handshake already sent the reply
359  async fn handle_tcp_tunnel(client_stream: TcpStream, target_addr: &str) -> Result<()> {
360    Self::tcp_tunnel(client_stream, target_addr, false).await
361  }
362
363  /// Handle HTTPS CONNECT with MITM interception for SOCKS5
364  /// This function doesn't send HTTP response - the SOCKS5 handshake already sent the reply
365  async fn handle_https_connect_socks5(
366    client_stream: TcpStream,
367    uri: &str,
368    cert_manager: Arc<CertificateManager>,
369    interceptor: Arc<RwLock<InterceptorHandler>>,
370    slinger_client: Client,
371  ) -> Result<()> {
372    // Parse domain and port
373    let (domain, port) = Self::parse_host_port(uri)?;
374
375    // Perform TLS accept + MITM handling, do NOT send HTTP 200 (SOCKS5 already responded)
376    Self::accept_tls_and_handle(
377      client_stream,
378      &domain,
379      port,
380      false,
381      cert_manager,
382      interceptor,
383      slinger_client,
384    )
385    .await
386  }
387
388  /// Accept TLS on an incoming stream (using certs from CertificateManager) and handle HTTPS requests over it.
389  /// If `send_response` is true, send an HTTP/1.1 200 Connection Established before performing the TLS handshake
390  async fn accept_tls_and_handle(
391    mut client_stream: TcpStream,
392    domain: &str,
393    port: u16,
394    send_response: bool,
395    cert_manager: Arc<CertificateManager>,
396    interceptor: Arc<RwLock<InterceptorHandler>>,
397    slinger_client: Client,
398  ) -> Result<()> {
399    if send_response {
400      client_stream
401        .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
402        .await?;
403      client_stream
404        .flush() // ensure response fully sent
405        .await
406        .map_err(Error::Io)?;
407    }
408
409    // Generate server certificate for this domain
410    let (cert_chain, key) = cert_manager.get_server_cert(domain).await?;
411    // Create TLS acceptor
412    let tls_config = Self::create_tls_server_config(cert_chain, key)?;
413    let acceptor = TlsAcceptor::from(Arc::new(tls_config));
414    // Perform TLS handshake with client
415    let tls_stream = acceptor
416      .accept(client_stream)
417      .await
418      .map_err(|e| Error::tls_error(format!("TLS handshake failed: {}", e)))?;
419    let domain_with_port = format!("{}:{}", domain, port);
420    Self::handle_https_stream(tls_stream, domain_with_port, interceptor, slinger_client).await
421  }
422
423  /// Generic TCP tunnel helper. If `send_response` is true, sends HTTP/1.1 200 before tunneling.
424  async fn tcp_tunnel(mut client_stream: TcpStream, uri: &str, send_response: bool) -> Result<()> {
425    let (host, port) = Self::parse_host_port(uri)?;
426    let addr = format!("{}:{}", host, port);
427
428    // Connect to target server
429    let mut target_stream = TcpStream::connect(&addr)
430      .await
431      .map_err(|e| Error::connection_error(format!("Failed to connect to {}: {}", addr, e)))?;
432
433    if send_response {
434      client_stream
435        .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
436        .await?;
437    }
438
439    let (mut client_read, mut client_write) = client_stream.split();
440    let (mut target_read, mut target_write) = target_stream.split();
441
442    let client_to_target = tokio::io::copy(&mut client_read, &mut target_write);
443    let target_to_client = tokio::io::copy(&mut target_read, &mut client_write);
444
445    tokio::select! {
446        _ = client_to_target => {},
447        _ = target_to_client => {},
448    }
449
450    Ok(())
451  }
452
453  /// Forward a prepared `Request` through the inner `slinger::Client` and run interceptors.
454  /// Returns Some(response_bytes) if there is a response to write back to the caller, or None if
455  /// the interceptor chain dropped the request/response.
456  async fn forward_request_via_client(
457    interceptor: Arc<RwLock<InterceptorHandler>>,
458    client: &Client,
459    request: Request,
460  ) -> Result<Option<Vec<u8>>> {
461    let handler = interceptor.read().await;
462    if let Some(modified_req) = handler.process_request(request).await? {
463      let uri = modified_req.uri().clone();
464      let method = modified_req.method().clone();
465      let headers = modified_req.headers().clone();
466      let body_data = if let Some(body) = modified_req.body() {
467        body.to_vec()
468      } else {
469        Vec::new()
470      };
471      let mut req_builder = client.request(method, uri);
472      for (name, value) in headers.iter() {
473        req_builder = req_builder.header(name, value);
474      }
475      req_builder = req_builder.body(body_data);
476      match req_builder.send().await {
477        Ok(response) => {
478          if let Some(final_response) = handler.process_response(response).await? {
479            let response_bytes = Self::serialize_http_response(&final_response)?;
480            return Ok(Some(response_bytes));
481          }
482        }
483        Err(_e) => {
484          return Ok(Some(b"HTTP/1.1 502 Bad Gateway\r\n\r\n".to_vec()));
485        }
486      }
487    }
488    Ok(None)
489  }
490
491  /// Handle HTTPS requests over TLS connection
492  async fn handle_https_stream<S>(
493    mut tls_stream: S,
494    domain: String,
495    interceptor: Arc<RwLock<InterceptorHandler>>,
496    client: Client,
497  ) -> Result<()>
498  where
499    S: AsyncReadExt + AsyncWriteExt + Unpin,
500  {
501    // Read HTTP request from TLS stream with size limit
502    const MAX_REQUEST_SIZE: usize = 1024 * 1024; // 1MB limit
503    let mut buffer = Vec::new();
504    let mut temp_buf = [0u8; 8192];
505
506    loop {
507      match tls_stream.read(&mut temp_buf).await {
508        Ok(0) => break,
509        Ok(n) => {
510          buffer.extend_from_slice(&temp_buf[..n]);
511          if buffer.len() > MAX_REQUEST_SIZE {
512            return Err(Error::invalid_request(
513              "Request size exceeds maximum allowed".to_string(),
514            ));
515          }
516          if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
517            break;
518          }
519        }
520        Err(e) => return Err(Error::Io(e)),
521      }
522    }
523
524    // Parse request
525    if let Ok(request) = Self::parse_http_request(&buffer, &domain) {
526      if let Some(response_bytes) =
527        Self::forward_request_via_client(interceptor, &client, request).await?
528      {
529        tls_stream.write_all(&response_bytes).await?;
530      }
531    }
532
533    Ok(())
534  }
535
536  /// Handle HTTP request (non-HTTPS)
537  async fn handle_http_request<R>(
538    method: &str,
539    uri: &str,
540    mut reader: BufReader<R>,
541    interceptor: Arc<RwLock<InterceptorHandler>>,
542    client: Client,
543  ) -> Result<()>
544  where
545    R: AsyncReadExt + AsyncWriteExt + Unpin,
546  {
547    // Read headers with size limit
548    const MAX_HEADERS_SIZE: usize = 64 * 1024; // 64KB limit for headers
549    let mut headers_buf = Vec::new();
550    loop {
551      let mut line = String::new();
552      reader.read_line(&mut line).await?;
553      if line == "\r\n" || line == "\n" {
554        break;
555      }
556      headers_buf.extend_from_slice(line.as_bytes());
557
558      // Check headers size limit
559      if headers_buf.len() > MAX_HEADERS_SIZE {
560        return Err(Error::invalid_request(
561          "Headers size exceeds maximum allowed".to_string(),
562        ));
563      }
564    }
565
566    // Build request using http::Request::builder, then convert to slinger::Request
567    let mut request_builder = http::Request::builder()
568      .method(method)
569      .uri(uri)
570      .version(Version::HTTP_11);
571
572    // Parse headers
573    for line in String::from_utf8_lossy(&headers_buf).lines() {
574      if let Some(idx) = line.find(':') {
575        let (name, value) = line.split_at(idx);
576        let value = value[1..].trim();
577        request_builder = request_builder.header(name.trim(), value);
578      }
579    }
580
581    let http_request = request_builder.body(Bytes::new())?;
582    let request: Request = http_request.into();
583
584    // Process through interceptors and forward
585    if let Some(response_bytes) =
586      Self::forward_request_via_client(interceptor, &client, request).await?
587    {
588      let mut stream = reader.into_inner();
589      stream.write_all(&response_bytes).await?;
590    }
591
592    Ok(())
593  }
594
595  /// Create TLS server configuration
596  fn create_tls_server_config(
597    cert_chain: Vec<CertificateDer<'static>>,
598    key: PrivateKeyDer<'static>,
599  ) -> Result<ServerConfig> {
600    let config = ServerConfig::builder()
601      .with_no_client_auth()
602      .with_single_cert(cert_chain, key)
603      .map_err(|e| Error::tls_error(format!("Failed to create TLS config: {}", e)))?;
604
605    Ok(config)
606  }
607
608  /// Parse host and port from URI
609  fn parse_host_port(uri: &str) -> Result<(String, u16)> {
610    let parts: Vec<&str> = uri.split(':').collect();
611    if parts.len() != 2 {
612      return Err(Error::invalid_request(format!("Invalid URI: {}", uri)));
613    }
614
615    let host = parts[0].to_string();
616    let port = parts[1]
617      .parse::<u16>()
618      .map_err(|_| Error::invalid_request(format!("Invalid port: {}", parts[1])))?;
619
620    Ok((host, port))
621  }
622
623  /// Parse HTTP request from bytes
624  fn parse_http_request(buffer: &[u8], domain: &str) -> Result<Request> {
625    let request_str = String::from_utf8_lossy(buffer);
626    let mut lines = request_str.lines();
627
628    let request_line = lines
629      .next()
630      .ok_or_else(|| Error::invalid_request("Empty request".to_string()))?;
631    let parts: Vec<&str> = request_line.split_whitespace().collect();
632    if parts.len() < 3 {
633      return Err(Error::invalid_request("Invalid request line".to_string()));
634    }
635
636    let method = parts[0];
637    let path = parts[1];
638    let uri = if path.starts_with("http://") || path.starts_with("https://") {
639      path.to_string()
640    } else {
641      format!("https://{}{}", domain, path)
642    };
643
644    let mut request_builder = http::Request::builder()
645      .method(method)
646      .uri(uri)
647      .version(Version::HTTP_11);
648
649    for line in lines {
650      if line.is_empty() {
651        break;
652      }
653      if let Some(idx) = line.find(':') {
654        let (name, value) = line.split_at(idx);
655        let value = value[1..].trim();
656        request_builder = request_builder.header(name.trim(), value);
657      }
658    }
659
660    let http_request = request_builder.body(Bytes::new())?;
661    Ok(http_request.into())
662  }
663
664  /// Serialize HTTP response to bytes
665  fn serialize_http_response(response: &Response) -> Result<Vec<u8>> {
666    let mut buf = Vec::new();
667
668    // Status line
669    let status = response.status_code();
670    let status_line = format!(
671      "HTTP/1.1 {} {}\r\n",
672      status.as_u16(),
673      status.canonical_reason().unwrap_or("Unknown")
674    );
675    buf.extend_from_slice(status_line.as_bytes());
676
677    // Headers
678    for (name, value) in response.headers() {
679      buf.extend_from_slice(name.as_str().as_bytes());
680      buf.extend_from_slice(b": ");
681      buf.extend_from_slice(value.as_bytes());
682      buf.extend_from_slice(b"\r\n");
683    }
684    // Empty line before body
685    buf.extend_from_slice(b"\r\n");
686    // Body
687    if let Some(body) = response.body() {
688      buf.extend_from_slice(body.as_ref());
689    }
690    Ok(buf)
691  }
692}