wafrift-proxy 0.3.1

HTTP forward proxy with automatic WAF evasion and optional TLS interception support.
Documentation
use std::sync::Arc;

use axum::{
    Router,
    body::Bytes,
    extract::State,
    http::{Method, StatusCode, Uri},
    response::IntoResponse,
    routing::any,
};
use tokio::net::TcpListener;
use tokio::sync::Mutex;

mod common;
use common::{pick_free_port, proxy_client, start_proxy_and_wait, stop_proxy};

#[derive(Debug)]
struct SeenRequest {
    headers: Vec<(String, String)>,
    query: String,
    body: String,
}

fn has_header(headers: &[(String, String)], name: &str, value: &str) -> bool {
    headers
        .iter()
        .any(|(k, v)| k.eq_ignore_ascii_case(name) && v.eq_ignore_ascii_case(value))
}

fn has_header_prefix(headers: &[(String, String)], name: &str, value: &str) -> bool {
    headers.iter().any(|(k, v)| {
        k.eq_ignore_ascii_case(name) && v.to_lowercase().starts_with(&value.to_lowercase())
    })
}

async fn start_upstream_server() -> (
    u16,
    Arc<Mutex<Vec<SeenRequest>>>,
    tokio::task::JoinHandle<()>,
) {
    let captured = Arc::new(Mutex::new(Vec::<SeenRequest>::new()));
    let app = Router::new()
        .route("/*path", any(capture_request))
        .with_state(captured.clone());
    let listener = TcpListener::bind(("127.0.0.1", 0))
        .await
        .expect("bind upstream");
    let port = listener
        .local_addr()
        .expect("upstream listener local addr")
        .port();
    let handle = tokio::spawn(async move {
        axum::serve(listener, app)
            .await
            .expect("upstream server should serve");
    });
    (port, captured, handle)
}

async fn capture_request(
    State(state): State<Arc<Mutex<Vec<SeenRequest>>>>,
    _method: Method,
    uri: Uri,
    headers: axum::http::HeaderMap,
    body: Bytes,
) -> impl IntoResponse {
    let headers = headers
        .iter()
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|value| (name.to_string(), value.to_string()))
        })
        .collect();
    let query = uri.query().unwrap_or_default().to_string();
    let body = String::from_utf8_lossy(&body).to_string();
    let mut buf = state.lock().await;
    buf.push(SeenRequest {
        headers,
        query,
        body,
    });
    (
        StatusCode::OK,
        [("content-type", "text/plain")],
        "upstream-response",
    )
}

#[tokio::test]
async fn forward_path_e2e_must_apply_evasion_and_return_response_unchanged() {
    let (upstream_port, captured, upstream_handle) = start_upstream_server().await;
    let proxy_port = pick_free_port().expect("pick proxy port");
    let mut proxy = start_proxy_and_wait(
        proxy_port,
        &[
            "--allow-private-upstream",
            "--content-type-switching",
            "--escalation",
            "heavy",
        ],
    )
    .await
    .expect("start proxy");

    let client = proxy_client(proxy_port).expect("proxy client");
    let target = format!("http://127.0.0.1:{upstream_port}/search?q=1'+OR+'1");
    let response = client
        .post(target)
        .header("content-type", "application/x-www-form-urlencoded")
        .body("q=1%2BOR%2B1")
        .send()
        .await
        .expect("send through proxy");
    assert!(response.status().is_success());
    let body = response.text().await.expect("read response body");
    assert_eq!(body, "upstream-response");

    let requests = captured.lock().await;
    let request = requests.last().expect("one request");
    assert!(
        request.query.contains("q="),
        "query was forwarded to upstream"
    );
    assert!(
        !has_header(
            &request.headers,
            "content-type",
            "application/x-www-form-urlencoded"
        ),
        "content-type must be mutated by proxy"
    );
    upstream_handle.abort();
    stop_proxy(&mut proxy).await;
}

#[tokio::test]
async fn forward_path_e2e_must_not_apply_evasion_when_off() {
    let (upstream_port, captured, upstream_handle) = start_upstream_server().await;
    let proxy_port = pick_free_port().expect("pick proxy port");
    let mut proxy = start_proxy_and_wait(
        proxy_port,
        &[
            "--allow-private-upstream",
            "--content-type-switching",
            "--escalation",
            "heavy",
        ],
    )
    .await
    .expect("start proxy");

    let client = proxy_client(proxy_port).expect("proxy client");
    let target = format!("http://127.0.0.1:{upstream_port}/search?q=1'+OR+'1");
    let response = client
        .post(target)
        .header("x-wafrift-evade", "off")
        .header("content-type", "application/x-www-form-urlencoded")
        .body("q=1%2BOR%2B1")
        .send()
        .await
        .expect("send through proxy");
    assert!(response.status().is_success());
    assert_eq!(
        response.text().await.expect("read response body"),
        "upstream-response"
    );

    let requests = captured.lock().await;
    let request = requests.last().expect("one request");
    assert!(
        has_header_prefix(
            &request.headers,
            "content-type",
            "application/x-www-form-urlencoded",
        ),
        "evade off must keep original content-type"
    );
    assert_eq!(request.body, "q=1%2BOR%2B1");
    upstream_handle.abort();
    stop_proxy(&mut proxy).await;
}