replay 0.1.2

A simple HTTP request replay tool
use std::{process::exit, sync::Arc, thread, time::{SystemTime, UNIX_EPOCH}};

use http_body_util::{BodyExt, Full};
use hyper::{body::{Buf, Bytes, Incoming}, server::conn::http1, Request, Response};
use owo_colors::OwoColorize;
use serde::{Deserialize, Serialize};
use tokio::{net::TcpListener, sync::Mutex};
use hyper_util::rt::TokioIo;

use crate::{replay::start_replay_server, store::{save_logs_to_file, LogStore}};

/// File that the proxy's recorded request/response pairs are written to; the
/// proxy flushes its in-memory log to this path every 5 seconds.
pub const PROXY_LOG_FILE: &str = "replay_mocks.json";

/// A request received by the proxy, as recorded in the log store.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct RequestLog {
    /// Time the proxy handled the request, in whole seconds since the Unix epoch.
    pub timestamp: u64,
    /// HTTP method as text, e.g. `GET`.
    pub method: String,
    /// Request path, without the query string.
    pub path: String,
    /// Raw query string (without the leading `?`), or `None` if the URI had none.
    pub query_params: Option<String>,
    /// Request headers as `(name, value)` pairs; a value that cannot be
    /// converted with `HeaderValue::to_str` is recorded as an empty string.
    pub headers: Vec<(String, String)>,
    /// Request body if it is valid UTF-8 (an empty body is `Some("")`),
    /// otherwise `None`.
    pub body: Option<String>,
}

/// The response the target server returned for a proxied request.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ResponseLog {
    /// Numeric HTTP status code returned by the target.
    pub status: u16,
    /// Response headers as `(name, value)` pairs; a value that cannot be
    /// converted with `HeaderValue::to_str` is recorded as an empty string.
    pub headers: Vec<(String, String)>,
    /// Response body if it is valid UTF-8, otherwise `None`.
    pub body: Option<String>,
}

/// One recorded exchange: a request seen by the proxy together with the
/// target's response. A shared list of these is saved to the log file and
/// handed to the replay server.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ProxyLog {
  pub request: RequestLog,
  pub response: ResponseLog,
}

/// Runs the recording proxy on `listen` and, on a separate thread, the replay
/// server on `127.0.0.1:6688`.
///
/// Every request accepted on `listen` is handled by [`proxy_handler`], which
/// forwards it to `target` and records the exchange in a shared in-memory log.
/// A background task writes that log to [`PROXY_LOG_FILE`] every 5 seconds,
/// and the same log is handed to the replay server.
///
/// * `target` – upstream base URL such as `http://localhost:3000`. If it has no
///   scheme (e.g. `localhost:3000`), `http` is assumed; if it has no port, the
///   scheme's default is used (443 for `https`, 80 otherwise).
/// * `listen` – address the proxy binds to, such as `127.0.0.1:8080`.
///
/// # Errors
/// Returns an error if `target` cannot be parsed or has no authority, if
/// binding `listen` fails, or if accepting a connection fails. Otherwise the
/// accept loop never returns. If the replay server stops cleanly, the whole
/// process exits with status 0.
pub async fn start_server(target: &str, listen: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Single source of truth for the replay address (bound and printed below).
    const REPLAY_ADDR: &str = "127.0.0.1:6688";

    let target_uri = target.parse::<hyper::Uri>()?;
    let target_authority = target_uri.authority().ok_or("Invalid target URL")?;
    // A target without a scheme falls back to plain HTTP instead of failing.
    let scheme = target_uri.scheme_str().unwrap_or("http");
    let target_port = target_authority
        .port_u16()
        .unwrap_or(if scheme == "https" { 443 } else { 80 });
    // Shared, immutable per-connection data: cloning an `Arc<str>` is a
    // refcount bump instead of a fresh allocation for every connection.
    let target_scheme: Arc<str> = Arc::from(scheme);
    let target_host: Arc<str> = Arc::from(target_authority.host());

    let logs = Arc::new(Mutex::new(Vec::<ProxyLog>::new()));

    // Periodically flush the in-memory log to disk; failures are reported
    // and retried on the next tick.
    let logs_for_saving = Arc::clone(&logs);
    tokio::spawn(async move {
        loop {
            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
            if let Err(e) = save_logs_to_file(&logs_for_saving, PROXY_LOG_FILE).await {
                eprintln!("Error saving logs to file: {}", e);
            }
        }
    });

    // The replay server runs on a dedicated OS thread with its own
    // multi-threaded Tokio runtime, sharing the same log store.
    let logs_for_replay = Arc::clone(&logs);
    thread::spawn(move || {
        let rt = match tokio::runtime::Builder::new_multi_thread().enable_all().build() {
            Ok(rt) => rt,
            Err(e) => {
                eprintln!("Replay server error: {}", e);
                return;
            }
        };
        rt.block_on(async {
            match start_replay_server(logs_for_replay, REPLAY_ADDR).await {
                Ok(_) => {
                    println!("Replay server stopped");
                    exit(0);
                }
                Err(e) => eprintln!("Replay server error: {}", e),
            }
        });
    });

    let listener = TcpListener::bind(listen).await?;
    println!("Target URL: {}", target.magenta());
    println!("Proxy server is listening on {}", listen.green());
    println!("Replay server is running on {}", REPLAY_ADDR.green());

    loop {
        let (stream, _) = listener.accept().await?;
        let io = TokioIo::new(stream);

        let target_host = Arc::clone(&target_host);
        let target_scheme = Arc::clone(&target_scheme);
        let logs = Arc::clone(&logs);

        tokio::task::spawn(async move {
            let service = hyper::service::service_fn(move |req: Request<Incoming>| {
                let target_host = Arc::clone(&target_host);
                let target_scheme = Arc::clone(&target_scheme);
                let logs = Arc::clone(&logs);

                async move {
                    proxy_handler(req, &target_host, target_port, &target_scheme, logs).await
                }
            });

            // One request per connection; request heads up to 30 MiB are buffered.
            if let Err(err) = http1::Builder::new()
                .keep_alive(false)
                .max_buf_size(30 * 1024 * 1024)
                .serve_connection(io, service)
                .await
            {
                eprintln!("> Connection error: {}", err);
            }
        });
    }
}

/// Forwards one incoming request to the target server, records the exchange
/// in `logs`, and relays the upstream response back to the client.
///
/// * `req` – the request received by the proxy; its path and query string are
///   appended unchanged to the target base URL.
/// * `target_host` / `target_port` / `scheme` – where the request is sent.
/// * `logs` – shared store of recorded exchanges. Only the first exchange for
///   a given method + path + query is stored; repeats are forwarded but not
///   recorded again.
///
/// Failures are never returned as `Err`; they become plain-text responses:
/// 500 if the request or upstream response body cannot be read, 400 if the
/// method cannot be expressed to the HTTP client, and 502 if the upstream
/// request fails (including the 30 s timeout).
pub async fn proxy_handler(
    req: Request<Incoming>,
    target_host: &str,
    target_port: u16,
    scheme: &str,
    logs: LogStore,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
    // Seconds since the Unix epoch; a clock set before 1970 records 0 instead
    // of panicking the connection task.
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);

    let (parts, body) = req.into_parts();
    let method = parts.method.clone();
    let is_head = method == hyper::Method::HEAD;
    let path = parts.uri.path().to_string();
    let query = parts.uri.query().map(str::to_string);

    // Copy of the request headers for the log; values `to_str` rejects are
    // logged as empty strings.
    let headers: Vec<(String, String)> = parts
        .headers
        .iter()
        .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
        .collect();

    // Buffer the whole request body. `aggregate()` may yield several
    // non-contiguous chunks, so copy every byte out: reading only `chunk()`
    // would keep just the first chunk and truncate larger bodies.
    let body_bytes: Bytes = match body.collect().await {
        Ok(collected) => {
            let mut buf = collected.aggregate();
            buf.copy_to_bytes(buf.remaining())
        }
        Err(e) => {
            eprintln!("Error collecting request body: {}", e);
            return Ok(plain_response(500, "Internal Server Error"));
        }
    };
    let body_str = std::str::from_utf8(&body_bytes).ok().map(str::to_string);

    let forward_uri = build_forward_uri(scheme, target_host, target_port, &path, query.as_deref());
    println!("{} {} {}", method.yellow(), path, forward_uri.magenta());

    // Convert through the method's text so any method hyper accepted
    // (including extension methods) can be forwarded.
    let forward_method = match reqwest::Method::from_bytes(method.as_str().as_bytes()) {
        Ok(m) => m,
        Err(_) => {
            eprintln!("Unsupported method: {}", method);
            return Ok(plain_response(400, "Bad Request: Unsupported Method"));
        }
    };
    let mut req_builder = proxy_client().request(forward_method, &forward_uri);

    // Forward the client's headers byte-for-byte, minus `host` (derived by the
    // client from the forward URL) and connection-scoped hop-by-hop headers.
    for (name, value) in parts.headers.iter() {
        let name = name.as_str();
        if name.eq_ignore_ascii_case("host") || is_hop_by_hop(name) {
            continue;
        }
        if let (Ok(header_name), Ok(header_value)) = (
            reqwest::header::HeaderName::from_bytes(name.as_bytes()),
            reqwest::header::HeaderValue::from_bytes(value.as_bytes()),
        ) {
            req_builder = req_builder.header(header_name, header_value);
        }
    }

    if !body_bytes.is_empty() {
        req_builder = req_builder.body(body_bytes.to_vec());
    }

    let resp = match req_builder.send().await {
        Ok(resp) => resp,
        Err(e) => {
            eprintln!("Error sending request: {}", e);
            return Ok(plain_response(502, format!("Bad Gateway: {}", e)));
        }
    };

    let status = resp.status().as_u16();
    // Keep our own copy of the headers: reading the body consumes `resp`.
    let upstream_headers = resp.headers().clone();
    let resp_headers: Vec<(String, String)> = upstream_headers
        .iter()
        .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
        .collect();

    let resp_bytes = match resp.bytes().await {
        Ok(bytes) => bytes,
        Err(e) => {
            eprintln!("Error reading response body: {}", e);
            return Ok(plain_response(500, "Internal Server Error"));
        }
    };
    let resp_vec = resp_bytes.to_vec();
    let resp_str = std::str::from_utf8(&resp_vec).ok().map(str::to_string);

    let log_entry = ProxyLog {
        request: RequestLog {
            timestamp,
            method: method.to_string(),
            path,
            query_params: query,
            headers,
            body: body_str,
        },
        response: ResponseLog {
            status,
            headers: resp_headers,
            body: resp_str,
        },
    };

    // First exchange per (method, path, query) wins; the lock is released at
    // the end of this block.
    let recorded = {
        let mut logs_guard = logs.lock().await;
        let already_recorded = logs_guard.iter().any(|log| {
            log.request.method == log_entry.request.method
                && log.request.path == log_entry.request.path
                && log.request.query_params == log_entry.request.query_params
        });
        if !already_recorded {
            logs_guard.push(log_entry);
        }
        !already_recorded
    };
    if recorded {
        println!("Saved request/response to log store {}", PROXY_LOG_FILE.magenta());
    }

    let mut builder = Response::builder().status(status);
    for (name, value) in upstream_headers.iter() {
        let name = name.as_str();
        // The body is re-framed here, so upstream framing headers are dropped
        // and Content-Length is recomputed below. HEAD responses have no body,
        // so their upstream Content-Length is kept as the only one.
        if is_hop_by_hop(name) || (name.eq_ignore_ascii_case("content-length") && !is_head) {
            continue;
        }
        if let (Ok(header_name), Ok(header_value)) = (
            hyper::header::HeaderName::from_bytes(name.as_bytes()),
            hyper::header::HeaderValue::from_bytes(value.as_bytes()),
        ) {
            builder = builder.header(header_name, header_value);
        }
    }

    if !is_head {
        builder = builder.header("content-length", resp_vec.len());
    }
    builder = builder.header("connection", "close");

    Ok(builder
        .body(Full::new(Bytes::from(resp_vec)))
        .unwrap_or_else(|_| plain_response(500, "Internal Server Error")))
}

/// Builds the absolute upstream URL for a proxied request. The port is written
/// out unless it is the default for `scheme` (80 for `http`, 443 for `https`),
/// so e.g. `http` on port 443 is not silently sent to port 80.
fn build_forward_uri(scheme: &str, host: &str, port: u16, path: &str, query: Option<&str>) -> String {
    let is_default_port = matches!((scheme, port), ("http", 80) | ("https", 443));
    let mut uri = if is_default_port {
        format!("{}://{}{}", scheme, host, path)
    } else {
        format!("{}://{}:{}{}", scheme, host, port, path)
    };
    if let Some(q) = query {
        uri.push('?');
        uri.push_str(q);
    }
    uri
}

/// Returns true for hop-by-hop header names (RFC 9110 §7.6.1), which describe
/// a single connection and must not be copied onto the next one.
fn is_hop_by_hop(name: &str) -> bool {
    matches!(
        name.to_ascii_lowercase().as_str(),
        "connection" | "keep-alive" | "proxy-connection" | "te" | "trailer" | "transfer-encoding" | "upgrade"
    )
}

/// Returns the HTTP client shared by all proxied requests. It is built on
/// first use so its connection pool and TLS setup are reused across requests.
///
/// NOTE(review): reqwest follows redirects by default, so 3xx responses from
/// the target are followed here rather than relayed — confirm this is intended.
fn proxy_client() -> &'static reqwest::Client {
    static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
    CLIENT.get_or_init(|| {
        reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .danger_accept_invalid_certs(true)
            .build()
            .unwrap_or_else(|_| reqwest::Client::new())
    })
}

/// Builds a plain-text response with the given status code and body.
fn plain_response(status: u16, body: impl Into<Bytes>) -> Response<Full<Bytes>> {
    let mut response = Response::new(Full::new(body.into()));
    *response.status_mut() =
        hyper::StatusCode::from_u16(status).unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
    response
}