http-mitm-proxy 0.16.0

An HTTP proxy server library intended to be used as a backend for applications like Burp Proxy.
Documentation
use std::{
    convert::Infallible,
    sync::{Arc, atomic::AtomicU16},
};

use axum::{
    Router,
    extract::Request,
    response::{Sse, sse::Event},
    routing::get,
};
use bytes::Bytes;
use futures::stream;
use http_mitm_proxy::{DefaultClient, MitmProxy};
use hyper::{
    Uri,
    body::{Body, Incoming},
    service::{HttpService, service_fn},
};
use moka::sync::Cache;

/// Counter holding the next port number that `get_port` hands out; starts at 3666.
static PORT: AtomicU16 = AtomicU16::new(3666);

/// Installs the default `tracing_subscriber::fmt` subscriber so that tracing
/// output is visible while the tests run.
///
/// The `ctor` attribute registers this function to run when the binary is
/// loaded, and `#[cfg(test)]` restricts it to test builds.
#[ctor::ctor]
#[cfg(test)]
fn init_subscriber() {
    // NOTE(review): `fmt::init` is documented to panic if a global subscriber
    // has already been installed — confirm nothing else sets one up earlier.
    tracing_subscriber::fmt::init();
}

/// Returns the current value of the shared `PORT` counter and advances it by
/// one, so that every caller receives a different port number.
fn get_port() -> u16 {
    use std::sync::atomic::Ordering;

    // Only the counter value itself matters here; no other memory is
    // synchronised through it, hence the relaxed ordering.
    PORT.fetch_add(1, Ordering::Relaxed)
}

/// Creates a throw-away certificate-authority issuer for the tests.
///
/// The issuer's distinguished name contains a single common name,
/// `<http-mitm-proxy TEST CA>`; it is flagged as an unconstrained CA that may
/// sign certificates and CRLs, and it is paired with a freshly generated key.
///
/// Panics if the key pair cannot be generated.
fn root_cert() -> rcgen::Issuer<'static, rcgen::KeyPair> {
    // Subject: just the common name of the test CA.
    let mut subject = rcgen::DistinguishedName::new();
    subject.push(
        rcgen::DnType::CommonName,
        rcgen::DnValue::Utf8String("<http-mitm-proxy TEST CA>".to_string()),
    );

    let mut ca_params = rcgen::CertificateParams::default();
    ca_params.distinguished_name = subject;
    ca_params.key_usages = vec![
        rcgen::KeyUsagePurpose::KeyCertSign,
        rcgen::KeyUsagePurpose::CrlSign,
    ];
    ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);

    let ca_key = rcgen::KeyPair::generate().unwrap();
    rcgen::Issuer::new(ca_params, ca_key)
}

/// Binds `app` to an OS-assigned port on `127.0.0.1` and returns that port
/// together with the future that serves the app.
///
/// Binding to port 0 lets the operating system pick a free port, so the origin
/// server cannot fail with "address already in use" when a hard-coded port is
/// occupied by another process or by a test running in a different process.
///
/// The returned future must be polled (for example via `tokio::spawn`) for
/// requests to be served.
///
/// Panics if the listener cannot be bound or its local address cannot be
/// read; the returned future panics if `axum::serve` returns an error.
async fn bind_app(app: Router) -> (u16, impl std::future::Future<Output = ()>) {
    let listener = tokio::net::TcpListener::bind(("127.0.0.1", 0))
        .await
        .expect("failed to bind the test app to a loopback port");
    // Ask the listener which port the OS actually assigned.
    let port = listener
        .local_addr()
        .expect("a bound listener should report its local address")
        .port();

    (port, async move {
        axum::serve(listener, app).await.unwrap();
    })
}

/// Builds a `reqwest::Client` whose `http` and `https` proxy rules both point
/// at `http://127.0.0.1:{proxy_port}`.
///
/// Panics if a proxy rule or the client itself cannot be constructed.
fn client(proxy_port: u16) -> reqwest::Client {
    let proxy_url = format!("http://127.0.0.1:{proxy_port}");

    let http_proxy = reqwest::Proxy::http(proxy_url.clone()).unwrap();
    let https_proxy = reqwest::Proxy::https(proxy_url).unwrap();

    reqwest::Client::builder()
        .proxy(http_proxy)
        .proxy(https_proxy)
        .build()
        .unwrap()
}

/// Builds a `reqwest::Client` like [`client`], additionally registering `cert`
/// as a trusted root certificate.
///
/// Both the `http` and `https` proxy rules point at
/// `http://127.0.0.1:{proxy_port}`.
///
/// Panics if a proxy rule, the certificate conversion, or the client itself
/// cannot be constructed.
fn client_tls(proxy_port: u16, cert: &rcgen::Certificate) -> reqwest::Client {
    let proxy_url = format!("http://127.0.0.1:{proxy_port}");

    let http_proxy = reqwest::Proxy::http(proxy_url.clone()).unwrap();
    let https_proxy = reqwest::Proxy::https(proxy_url).unwrap();
    // Re-encode the rcgen certificate in the form reqwest expects.
    let trusted_root = reqwest::Certificate::from_der(cert.der()).unwrap();

    reqwest::Client::builder()
        .proxy(http_proxy)
        .proxy(https_proxy)
        .add_root_certificate(trusted_root)
        .build()
        .unwrap()
}

/// Creates a `DefaultClient` via `DefaultClient::new()`.
///
/// The tests in this file clone it into the proxy service and call
/// `send_request` on it to forward each intercepted request.
fn proxy_client() -> DefaultClient {
    DefaultClient::new()
}

/// Starts a MITM proxy and the origin `app`, each on its own loopback port.
///
/// The proxy is created with a fresh issuer from `root_cert()` and a
/// `Cache::new(128)`, bound to `127.0.0.1` with `service` as its request
/// handler, and spawned onto the tokio runtime. The app is bound through
/// `bind_app` and spawned as well.
///
/// Returns `(proxy_port, app_port)`.
///
/// Panics if the proxy or the app cannot be bound.
async fn setup<B, E, E2, S>(app: Router, service: S) -> (u16, u16)
where
    B: Body<Data = Bytes, Error = E> + Send + Sync + 'static,
    E: std::error::Error + Send + Sync + 'static,
    E2: std::error::Error + Send + Sync + 'static,
    S: HttpService<Incoming, ResBody = B, Error = E2, Future: Send> + Send + Sync + Clone + 'static,
{
    // Proxy side.
    let mitm = MitmProxy::new(Some(root_cert()), Some(Cache::new(128)));
    let proxy_port = get_port();
    let proxy_task = mitm
        .bind(("127.0.0.1", proxy_port), service)
        .await
        .unwrap();
    tokio::spawn(proxy_task);

    // Origin side.
    let (app_port, app_task) = bind_app(app).await;
    tokio::spawn(app_task);

    (proxy_port, app_port)
}

/// Starts a MITM proxy that uses the caller-supplied `root_cert` issuer, plus
/// the origin `app`, each on its own loopback port.
///
/// The proxy is created with `root_cert` and a `Cache::new(128)`, bound to
/// `127.0.0.1` with `service` as its request handler, and spawned onto the
/// tokio runtime. The app is bound through `bind_app` and spawned as well.
///
/// Returns `(proxy_port, app_port)`.
///
/// Panics if the proxy or the app cannot be bound.
async fn setup_tls<B, E, E2, S>(
    app: Router,
    service: S,
    root_cert: Arc<rcgen::Issuer<'static, rcgen::KeyPair>>,
) -> (u16, u16)
where
    B: Body<Data = Bytes, Error = E> + Send + Sync + 'static,
    E: std::error::Error + Send + Sync + 'static,
    E2: std::error::Error + Send + Sync + 'static,
    S: HttpService<Incoming, ResBody = B, Error = E2, Future: Send> + Send + Sync + Clone + 'static,
{
    // Proxy side, signing with the issuer handed in by the caller.
    let mitm = MitmProxy::new(Some(root_cert), Some(Cache::new(128)));
    let proxy_port = get_port();
    let proxy_task = mitm
        .bind(("127.0.0.1", proxy_port), service)
        .await
        .unwrap();
    tokio::spawn(proxy_task);

    // Origin side.
    let (app_port, app_task) = bind_app(app).await;
    tokio::spawn(app_task);

    (proxy_port, app_port)
}

/// A plain HTTP `GET` sent through the proxy reaches the origin and the
/// origin's body comes back unchanged.
#[tokio::test]
async fn test_simple_http() {
    const BODY: &str = "Hello, World!";

    // Origin server: `GET /` answers with the fixed body.
    let origin = Router::new().route("/", get(|| async move { BODY }));

    // Proxy service: hand every request straight to the upstream client and
    // keep only the response part of what `send_request` returns.
    let upstream = proxy_client();
    let forward = service_fn(move |req| {
        let upstream = upstream.clone();
        async move { upstream.send_request(req).await.map(|reply| reply.0) }
    });

    let (proxy_port, port) = setup(origin, forward).await;

    let http = client(proxy_port);
    let url = format!("http://127.0.0.1:{port}/");
    let res = http.get(url).send().await.unwrap();

    assert_eq!(res.status(), 200);
    assert_eq!(res.text().await.unwrap(), BODY);
}

/// A header inserted by the proxy service is visible to the origin server.
#[tokio::test]
async fn test_modify_http() {
    // Origin server: echo the `X-test` request header, or "none" if absent.
    let echo_header = |req: Request| async move {
        match req.headers().get("X-test") {
            Some(value) => value.to_str().unwrap().to_string(),
            None => "none".to_string(),
        }
    };
    let origin = Router::new().route("/", get(echo_header));

    // Proxy service: stamp the request with `X-test: modified`, then forward.
    let upstream = proxy_client();
    let tag_and_forward = service_fn(move |mut req| {
        let upstream = upstream.clone();
        async move {
            req.headers_mut()
                .insert("X-test", "modified".parse().unwrap());
            upstream.send_request(req).await.map(|reply| reply.0)
        }
    });

    let (proxy_port, port) = setup(origin, tag_and_forward).await;

    let http = client(proxy_port);
    let url = format!("http://127.0.0.1:{port}/");
    let res = http.get(url).send().await.unwrap();

    assert_eq!(res.status(), 200);
    assert_eq!(res.text().await.unwrap(), "modified");
}

/// A server-sent-events response passes through the proxy byte-for-byte.
#[tokio::test]
async fn test_sse_http() {
    // The exact wire format expected for three `message` events.
    const EXPECTED: &[u8] =
        b"event: message\ndata: 1\n\nevent: message\ndata: 2\n\nevent: message\ndata: 3\n\n";

    // Origin server: `GET /sse` emits the events "1", "2" and "3".
    let origin = Router::new().route(
        "/sse",
        get(|| async {
            let events = ["1", "2", "3"]
                .into_iter()
                .map(|s| Ok::<Event, Infallible>(Event::default().event("message").data(s)));
            Sse::new(stream::iter(events))
        }),
    );

    // Proxy service: forward every request as-is.
    let upstream = proxy_client();
    let forward = service_fn(move |req| {
        let upstream = upstream.clone();
        async move { upstream.send_request(req).await.map(|reply| reply.0) }
    });

    let (proxy_port, port) = setup(origin, forward).await;

    let http = client(proxy_port);
    let url = format!("http://127.0.0.1:{port}/sse");
    let res = http.get(url).send().await.unwrap();

    assert_eq!(res.bytes().await.unwrap(), EXPECTED[..]);
}

/// An `https://` request sent through the proxy is answered by the plain-HTTP
/// origin once the proxy service rewrites the request scheme to `http`.
///
/// The test creates its own CA so that it holds both the self-signed
/// certificate (trusted by the client) and the issuer (given to the proxy).
#[tokio::test]
async fn test_simple_https() {
    const BODY: &str = "Hello, World!";

    // Origin server: `GET /` answers with the fixed body.
    let origin = Router::new().route("/", get(|| async move { BODY }));

    // Test CA parameters: a single common name, CA key usages, unconstrained.
    let mut subject = rcgen::DistinguishedName::new();
    subject.push(
        rcgen::DnType::CommonName,
        rcgen::DnValue::Utf8String("<http-mitm-proxy TEST CA>".to_string()),
    );

    let mut ca_params = rcgen::CertificateParams::default();
    ca_params.distinguished_name = subject;
    ca_params.key_usages = vec![
        rcgen::KeyUsagePurpose::KeyCertSign,
        rcgen::KeyUsagePurpose::CrlSign,
    ];
    ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);

    let ca_key = rcgen::KeyPair::generate().unwrap();

    // The certificate has to be produced before the parameters and key are
    // moved into the issuer.
    let cert = ca_params.self_signed(&ca_key).unwrap();
    let issuer = Arc::new(rcgen::Issuer::new(ca_params, ca_key));

    // Proxy service: switch the request URI's scheme to `http` and forward.
    let upstream = proxy_client();
    let downgrade_and_forward = service_fn(move |mut req| {
        let upstream = upstream.clone();
        async move {
            let mut uri_parts = req.uri().clone().into_parts();
            uri_parts.scheme = Some(hyper::http::uri::Scheme::HTTP);
            *req.uri_mut() = Uri::from_parts(uri_parts).unwrap();

            upstream.send_request(req).await.map(|reply| reply.0)
        }
    });

    let (proxy_port, port) = setup_tls(origin, downgrade_and_forward, Arc::clone(&issuer)).await;

    // The client trusts the test CA certificate.
    let http = client_tls(proxy_port, &cert);
    let url = format!("https://127.0.0.1:{port}/");
    let res = http.get(url).send().await.unwrap();

    assert_eq!(res.status(), 200);
    assert_eq!(res.text().await.unwrap(), BODY);
}