use clap::Parser;
use futures_util::StreamExt;
use std::collections::{HashMap, VecDeque};
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Semaphore};
use tracing::{debug, error, info, trace, warn};
use http_body_util::{BodyExt, Full, Limited};
use hyper::body::{Bytes, Incoming};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::upgrade::Upgraded;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::{TcpListener, TcpStream};
use wafrift_proxy::hop_by_hop::{
collect_connection_header_names, collect_connection_header_names_hyper,
should_strip_proxy_header,
};
use wafrift_proxy::mitm::{CertificateAuthority, tls_server_name_from_authority};
use wafrift_proxy::rate_limit::RateLimiter;
use wafrift_proxy::scope::ScopeFilter;
use wafrift_proxy::upstream_policy::{
BogonFilteringResolver, UpstreamPolicy,
assert_forward_url_allowed,
};
use wafrift_strategy::strategy::{evade, evade_smart};
use wafrift_strategy::{EvasionConfig, HostState};
use wafrift_transport::signal::{BlockClass, ResponseProfileDb};
use wafrift_types::EvasionResult;
/// Body size cap of 16 MiB (16 * 1024 * 1024 bytes).
/// NOTE(review): the enforcement site is outside this view — presumably applied
/// to proxied request bodies; confirm.
const MAX_PROXY_BODY_BYTES: usize = 16 * 1024 * 1024;
use std::sync::OnceLock;
/// Process-wide warning throttle; `main` installs it with a 5-second cooldown.
static WARN_THROTTLE: OnceLock<WarnThrottle> = OnceLock::new();
/// Limits handed to the request-forwarding functions.
///
/// Built once in `main` from the `--max-upstream-response-bytes` and
/// `--max-evade-retries` CLI flags and shared behind an `Arc`.
#[derive(Clone)]
struct ProxyLimits {
/// Taken from `--max-upstream-response-bytes` (validated to be >= 4096).
/// NOTE(review): the code that enforces this cap is outside this view.
max_upstream_response_bytes: usize,
/// Number of extra attempts `forward_with_evade_retry` makes after the first
/// one while the upstream keeps answering 403/406; 0 means a single attempt.
max_evade_retries: u32,
}
// Names of the wafrift-specific HTTP headers.
// NOTE(review): the code that reads/writes these headers is outside this view —
// presumably control/report headers exchanged with the client; confirm.
const X_WAFRIFT_EVADE: &str = "x-wafrift-evade";
const X_WAFRIFT_TECHNIQUES: &str = "x-wafrift-techniques";
const X_WAFRIFT_BLOCKED: &str = "x-wafrift-blocked";
/// Optional shared request/response logger; `main` leaves it `None` when
/// `--log-dir` is not given.
type SharedLogger = Option<Arc<RequestLogger>>;
/// Append-only NDJSON log: each logged entry becomes one JSON line in a file
/// named `wafrift-proxy-<UTC timestamp>.ndjson`.
struct RequestLogger {
    /// Directory the log file was opened in (stored but not read).
    #[allow(dead_code)]
    dir: PathBuf,
    /// Buffered handle to the log file, guarded by an async mutex.
    writer: tokio::sync::Mutex<std::io::BufWriter<std::fs::File>>,
}

impl RequestLogger {
    /// Creates `dir` (and parents) if needed and opens a new timestamped log
    /// file inside it in create+append mode.
    ///
    /// # Errors
    /// Returns the I/O error from creating the directory or opening the file.
    fn open(dir: &std::path::Path) -> std::io::Result<Self> {
        std::fs::create_dir_all(dir)?;

        // File name carries the UTC start time, e.g. 20240131T235959Z.
        let now = time::OffsetDateTime::now_utc();
        let (year, month, day) = (now.year(), now.month() as u8, now.day());
        let (hour, minute, second) = (now.hour(), now.minute(), now.second());
        let ts = format!(
            "{:04}{:02}{:02}T{:02}{:02}{:02}Z",
            year, month, day, hour, minute, second,
        );
        let path = dir.join(format!("wafrift-proxy-{ts}.ndjson"));

        let mut options = std::fs::OpenOptions::new();
        options.create(true).append(true);
        let file = options.open(&path)?;
        info!(path = %path.display(), "request/response log opened");

        let writer = tokio::sync::Mutex::new(std::io::BufWriter::new(file));
        Ok(Self {
            dir: dir.to_path_buf(),
            writer,
        })
    }

    /// Serialises `entry` to a single JSON line, appends it and flushes.
    ///
    /// Logging is best effort: serialisation and write failures are ignored.
    async fn log_entry(&self, entry: &serde_json::Value) {
        use std::io::Write;

        let mut out = self.writer.lock().await;
        let Ok(line) = serde_json::to_string(entry) else {
            return;
        };
        let _ = writeln!(out, "{line}");
        let _ = out.flush();
    }
}
// Command-line options of the proxy, parsed by clap's derive API.
// Plain `//` comments are used on purpose: `///` doc comments on a clap derive
// struct become `--help` text and would change the program's output.
#[derive(Parser, Debug)]
#[command(name = "wafrift-proxy", about = "WAF Evasion Proxy")]
struct Args {
// Socket address to listen on; `main` parses it as a `SocketAddr`.
#[arg(long, default_value = "127.0.0.1:8080")]
listen: String,
// Optional escalation level; `validate_args` accepts only light, medium or heavy.
#[arg(long)]
escalation: Option<String>,
// The next three flags are copied by `main` into the matching `EvasionConfig` fields.
#[arg(long)]
content_type_switching: bool,
#[arg(long)]
fingerprint_rotation: bool,
#[arg(long, default_value_t = false)]
insecure: bool,
// When set, `main` generates a MITM CA, writes it to this directory and exits.
#[arg(long = "write-mitm-ca-dir")]
write_mitm_ca_dir: Option<PathBuf>,
// Enables MITM mode; `main` refuses to start with it on a non-loopback `--listen`.
#[arg(long, default_value_t = false)]
mitm: bool,
// CA directory for MITM mode; `main` falls back to a default directory when absent.
#[arg(long = "mitm-ca-dir")]
mitm_ca_dir: Option<PathBuf>,
// The next two flags are passed unchanged into `UpstreamPolicy`.
#[arg(long, default_value_t = false)]
allow_private_upstream: bool,
#[arg(long = "insecure-open-upstream", default_value_t = false)]
insecure_open_upstream: bool,
// Size of the connection semaphore in `main`; must be >= 1.
#[arg(long, default_value_t = 4096)]
max_concurrent_connections: usize,
// Default is 33554432 bytes (32 MiB); must be >= 4096.
#[arg(long, default_value_t = 33554432)]
max_upstream_response_bytes: usize,
// Extra forwarding attempts on 403/406 responses (see `forward_with_evade_retry`).
#[arg(long, default_value_t = 0)]
max_evade_retries: u32,
// "" selects $HOME/.wafrift/gene-bank.json; "off" or "-" disables persistence
// (see `default_gene_bank_path`).
#[arg(long, default_value = "")]
gene_bank_path: String,
// Period of the background gene-bank flush; 0 disables the periodic flush.
#[arg(long, default_value_t = 60)]
gene_bank_flush_interval_secs: u64,
// The five scope lists below are comma-separated and handed to `ScopeFilter::new`.
#[arg(long, num_args = 1.., value_delimiter = ',')]
only_host: Vec<String>,
#[arg(long, num_args = 1.., value_delimiter = ',')]
skip_host: Vec<String>,
#[arg(long, num_args = 1.., value_delimiter = ',')]
only_path: Vec<String>,
#[arg(long, num_args = 1.., value_delimiter = ',')]
skip_path: Vec<String>,
#[arg(long, num_args = 1.., value_delimiter = ',')]
only_method: Vec<String>,
// Rate and burst passed to `RateLimiter::new`; both must be non-negative.
#[arg(long, default_value_t = 0.0)]
max_rps_per_host: f64,
#[arg(long, default_value_t = 0.0)]
max_rps_per_host_burst: f64,
// Directory for the NDJSON request log (and the TUI log file when `--tui` is set).
#[arg(long = "log-dir")]
log_dir: Option<PathBuf>,
// Single TLS impersonation profile; mutually exclusive with the rotate list.
#[arg(long = "tls-impersonate", conflicts_with = "tls_impersonate_rotate")]
tls_impersonate: Option<String>,
// Comma-separated profiles used round-robin through `StealthPool`.
#[arg(long = "tls-impersonate-rotate", num_args = 1.., value_delimiter = ',')]
tls_impersonate_rotate: Vec<String>,
// Stored in `BODY_PADDING_BYTES` by `main` when greater than 0.
#[arg(long = "body-padding-bytes", default_value_t = 0)]
body_padding_bytes: usize,
// Makes `main` build the reqwest client with `pool_max_idle_per_host(0)`.
#[arg(long = "no-conn-reuse", default_value_t = false)]
no_conn_reuse: bool,
// Starts the terminal dashboard and redirects tracing output to a log file.
#[arg(long = "tui", default_value_t = false)]
tui: bool,
// Sets `MUTATE_URL_ENABLED`.
#[arg(long = "mutate-url", default_value_t = false)]
mutate_url: bool,
// Requires the binary to be built with the `captchaforge` cargo feature.
#[arg(long = "captchaforge", default_value_t = false)]
captchaforge: bool,
}
/// Proxy-wide state shared by all connection tasks behind an async mutex.
type SharedState = Arc<Mutex<ProxyState>>;
/// Single impersonating client; set by `main` when `--tls-impersonate` is given.
static STEALTH_CLIENT: std::sync::OnceLock<wafrift_transport::stealth::StealthClient> =
std::sync::OnceLock::new();
/// Round-robin pool of impersonating clients; set by `main` when
/// `--tls-impersonate-rotate` is given. `stealth()` prefers it over `STEALTH_CLIENT`.
static STEALTH_POOL: std::sync::OnceLock<StealthPool> = std::sync::OnceLock::new();
/// Body padding target in bytes; `main` stores `--body-padding-bytes` here when it is > 0.
static BODY_PADDING_BYTES: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
/// Set to `true` by `main` when `--mutate-url` is given.
static MUTATE_URL_ENABLED: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
/// Lazily created global challenge store; obtain it through `challenge_store()`.
static CHALLENGE_STORE: std::sync::OnceLock<wafrift_transport::challenge::ChallengeStore> =
std::sync::OnceLock::new();
/// Returns the process-wide `ChallengeStore`, creating it on first use.
fn challenge_store() -> &'static wafrift_transport::challenge::ChallengeStore {
    use wafrift_transport::challenge::ChallengeStore;

    CHALLENGE_STORE.get_or_init(ChallengeStore::new)
}
/// Sender half of the TUI event channel; installed by `main` only when `--tui` is active.
static TUI_TX: std::sync::OnceLock<tokio::sync::mpsc::Sender<wafrift_proxy::tui::Event>> =
std::sync::OnceLock::new();
/// Number of TUI events discarded by `emit_tui` because `try_send` failed.
static TUI_DROPPED: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
/// Best-effort, non-blocking delivery of an event to the TUI dashboard.
///
/// Does nothing when no TUI sender is installed. When `try_send` fails the
/// event is discarded and `TUI_DROPPED` is incremented.
#[inline]
fn emit_tui(ev: wafrift_proxy::tui::Event) {
    let Some(sender) = TUI_TX.get() else {
        return;
    };
    if sender.try_send(ev).is_err() {
        TUI_DROPPED.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }
}
/// Set of impersonating clients handed out in round-robin order.
struct StealthPool {
    /// Clients to rotate through; `main` never builds a pool with an empty list.
    clients: Vec<wafrift_transport::stealth::StealthClient>,
    /// Monotonic ticket counter used to choose the next client.
    cursor: std::sync::atomic::AtomicUsize,
}

impl StealthPool {
    /// Returns the next client in rotation.
    ///
    /// # Panics
    /// Panics if `clients` is empty (remainder by zero).
    fn pick(&self) -> &wafrift_transport::stealth::StealthClient {
        use std::sync::atomic::Ordering;

        let ticket = self.cursor.fetch_add(1, Ordering::Relaxed);
        let slot = ticket % self.clients.len();
        &self.clients[slot]
    }
}
/// Selects the impersonating client for the next upstream request.
///
/// The rotation pool takes precedence when installed; otherwise the single
/// client is returned, or `None` when neither is configured.
#[inline]
fn stealth() -> Option<&'static wafrift_transport::stealth::StealthClient> {
    match STEALTH_POOL.get() {
        Some(pool) => Some(pool.pick()),
        None => STEALTH_CLIENT.get(),
    }
}
/// Per-key rate limiter for log warnings: a key may warn at most once per
/// `cooldown`.
struct WarnThrottle {
    /// Minimum time between two accepted warnings for the same key.
    cooldown: Duration,
    /// Time of the last accepted warning per key.
    last: std::sync::Mutex<HashMap<String, Instant>>,
}

impl WarnThrottle {
    /// Once this many keys are tracked, expired entries are dropped before a
    /// new key is added, so the map cannot grow without bound when keys keep
    /// changing (for example one key per host).
    const PRUNE_THRESHOLD: usize = 1024;

    /// Creates a throttle with a cooldown of `cooldown_secs` seconds.
    fn new(cooldown_secs: u64) -> Self {
        Self {
            cooldown: Duration::from_secs(cooldown_secs),
            last: std::sync::Mutex::new(HashMap::new()),
        }
    }

    /// Returns `true` when a warning for `key` should be emitted now, and
    /// records the current time for that key.
    ///
    /// Returns `false` while the previous accepted warning for `key` is
    /// younger than the cooldown. A poisoned mutex is recovered rather than
    /// propagated, so this never panics because another thread panicked.
    fn should_warn(&self, key: &str) -> bool {
        let mut map = self
            .last
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let now = Instant::now();

        if let Some(seen) = map.get_mut(key) {
            if now.duration_since(*seen) < self.cooldown {
                return false;
            }
            *seen = now;
            return true;
        }

        // New key: an expired entry behaves exactly like a missing one, so
        // removing expired entries here does not change any answer.
        if map.len() >= Self::PRUNE_THRESHOLD {
            let cooldown = self.cooldown;
            map.retain(|_, seen| now.duration_since(*seen) < cooldown);
        }
        map.insert(key.to_string(), now);
        true
    }
}
/// Mutable proxy-wide bookkeeping, shared between connection tasks as `SharedState`.
#[derive(Default)]
struct ProxyState {
/// Per-host evasion state, keyed by host string.
hosts: HashMap<String, HostState>,
/// Hosts in insertion order; the forwarding path pops from the front to evict
/// an entry once `hosts` already holds 10_000 hosts and a new one arrives.
host_fifo: VecDeque<String>,
/// Incremented (saturating) once per call to `forward_wafrift_request`.
total_scanned: u32,
/// NOTE(review): updated outside this view — presumably a count of blocked responses; confirm.
total_blocks: u32,
/// NOTE(review): updated outside this view — presumably technique name -> use count; confirm.
techniques_used: HashMap<String, u32>,
}
/// On-disk subset of a host's `HostState`, as written by `save_gene_bank`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)]
struct PersistedHostState {
/// Copied to/from `HostState::proven_winners`; a non-empty list marks
/// discovery as complete when restored.
proven_winners: Vec<String>,
/// Copied to/from `HostState::blocklisted`.
blocklisted: Vec<String>,
/// Copied to/from `HostState::waf_name`; restoring a `Some` value also sets
/// `waf_confirmed`.
waf_name: Option<String>,
}
/// Root object of the gene-bank JSON file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)]
struct PersistedGeneBank {
/// Format version; `save_gene_bank` writes 1 and `load_gene_bank` warns on anything greater.
schema: u32,
/// Persisted state per host, keyed by host string.
hosts: HashMap<String, PersistedHostState>,
}
/// Resolves the `--gene-bank-path` argument to a concrete file path.
///
/// * `""` selects `$HOME/.wafrift/gene-bank.json`, or `None` when `HOME` is unset.
/// * `"off"` and `"-"` disable persistence (`None`).
/// * Any other value is used verbatim as the path.
fn default_gene_bank_path(supplied: &str) -> Option<std::path::PathBuf> {
    use std::path::PathBuf;

    match supplied {
        "" => {
            let home = std::env::var_os("HOME")?;
            let mut path = PathBuf::from(home);
            path.push(".wafrift");
            path.push("gene-bank.json");
            Some(path)
        }
        "off" | "-" => None,
        explicit => Some(PathBuf::from(explicit)),
    }
}
/// Reads the gene bank from `path`, never failing.
///
/// A missing, unreadable, empty or malformed file yields an empty default bank
/// and a log line explaining why. A file with `schema` greater than 1 is still
/// returned, with a warning.
fn load_gene_bank(path: &std::path::Path) -> PersistedGeneBank {
    let raw = match std::fs::read_to_string(path) {
        Ok(text) => text,
        Err(err) => {
            if err.kind() == std::io::ErrorKind::NotFound {
                info!(path = %path.display(), "gene bank not found; starting fresh");
            } else {
                warn!(
                    path = %path.display(),
                    error = %err,
                    "gene bank unreadable; starting fresh. Fix: check file permissions."
                );
            }
            return PersistedGeneBank::default();
        }
    };

    if raw.trim().is_empty() {
        info!(path = %path.display(), "gene bank file is empty; starting fresh");
        return PersistedGeneBank::default();
    }

    let bank = match serde_json::from_str::<PersistedGeneBank>(&raw) {
        Ok(parsed) => parsed,
        Err(err) => {
            warn!(
                path = %path.display(),
                error = %err,
                "gene bank malformed (invalid JSON); starting fresh. Fix: inspect the file and fix the JSON syntax, or delete it to start over."
            );
            return PersistedGeneBank::default();
        }
    };

    if bank.schema > 1 {
        warn!(
            path = %path.display(),
            schema = bank.schema,
            "gene bank has newer schema than expected (1); data may be incomplete"
        );
    }
    bank
}
/// Atomically writes the persistable part of `state` to `path` as pretty JSON.
///
/// The data is written to a uniquely named temporary file next to `path`,
/// synced, and then renamed over `path`, so readers never see a partial file.
/// Hosts with no winners, no blocklist, no WAF name and zero blocks are omitted.
///
/// # Errors
/// Returns the I/O (or JSON serialisation) error of the failing step. On any
/// failure after the temporary file was created, that file is removed so
/// repeated failed flushes do not litter the directory.
fn save_gene_bank(state: &ProxyState, path: &std::path::Path) -> std::io::Result<()> {
    if let Some(parent) = path.parent() {
        // Best effort: a real problem surfaces when the temp file is created below.
        let _ = std::fs::create_dir_all(parent);
    }

    let mut bank = PersistedGeneBank {
        schema: 1,
        hosts: HashMap::new(),
    };
    for (host, hs) in &state.hosts {
        let nothing_learned = hs.proven_winners.is_empty()
            && hs.blocklisted.is_empty()
            && hs.waf_name.is_none()
            && hs.blocks == 0;
        if nothing_learned {
            continue;
        }
        bank.hosts.insert(
            host.clone(),
            PersistedHostState {
                proven_winners: hs.proven_winners.clone(),
                blocklisted: hs.blocklisted.clone(),
                waf_name: hs.waf_name.clone(),
            },
        );
    }
    let json = serde_json::to_string_pretty(&bank)?;

    // pid + nanosecond timestamp keep concurrent writers from sharing a temp file.
    let pid = std::process::id();
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |d| d.as_nanos());
    let tmp = path.with_extension(format!("json.tmp.{pid}.{nanos}"));

    let publish = || -> std::io::Result<()> {
        use std::io::Write;
        let mut f = std::fs::File::create(&tmp)?;
        f.write_all(json.as_bytes())?;
        f.sync_all()?;
        // Close the handle before renaming.
        drop(f);
        std::fs::rename(&tmp, path)
    };
    if let Err(err) = publish() {
        // Do not leave a stale temp file behind; the original error is what matters.
        let _ = std::fs::remove_file(&tmp);
        return Err(err);
    }

    // Sync the directory so the rename itself is durable (best effort).
    if let Some(parent) = path.parent() {
        if let Ok(dir) = std::fs::OpenOptions::new().read(true).open(parent) {
            let _ = dir.sync_all();
        }
    }
    Ok(())
}
fn restore_gene_bank(state: &mut ProxyState, bank: PersistedGeneBank) -> usize {
let mut restored = 0usize;
for (host, persisted) in bank.hosts {
let hs = state.hosts.entry(host.clone()).or_default();
if !persisted.proven_winners.is_empty() {
hs.proven_winners = persisted.proven_winners;
hs.discovery_complete = true;
restored += 1;
}
if !persisted.blocklisted.is_empty() {
hs.blocklisted = persisted.blocklisted;
}
if persisted.waf_name.is_some() {
hs.waf_name = persisted.waf_name;
hs.waf_confirmed = true;
}
if !state.host_fifo.contains(&host) {
state.host_fifo.push_back(host);
}
}
restored
}
use wafrift_proxy::extract_host_from_header;
/// Checks CLI values that clap's type parsing alone cannot validate.
///
/// # Errors
/// Returns a human-readable message naming the offending flag when
/// * `--max-concurrent-connections` is 0,
/// * `--max-upstream-response-bytes` is below 4096,
/// * `--max-rps-per-host` or `--max-rps-per-host-burst` is negative or NaN,
/// * `--escalation` is given but is not `light`, `medium` or `heavy`.
fn validate_args(args: &Args) -> Result<(), String> {
    if args.max_concurrent_connections == 0 {
        return Err("--max-concurrent-connections must be >= 1, got 0".into());
    }
    if args.max_upstream_response_bytes < 4096 {
        return Err(format!(
            "--max-upstream-response-bytes must be >= 4096 (4 KiB), got {}",
            args.max_upstream_response_bytes
        ));
    }
    // f64 parsing accepts "NaN", and `NaN < 0.0` is false, so NaN has to be
    // rejected explicitly or it would slip through into the rate limiter.
    if args.max_rps_per_host.is_nan() || args.max_rps_per_host < 0.0 {
        return Err(format!(
            "--max-rps-per-host must be a non-negative number, got {}",
            args.max_rps_per_host
        ));
    }
    if args.max_rps_per_host_burst.is_nan() || args.max_rps_per_host_burst < 0.0 {
        return Err(format!(
            "--max-rps-per-host-burst must be a non-negative number, got {}",
            args.max_rps_per_host_burst
        ));
    }
    if let Some(esc) = &args.escalation
        && !matches!(esc.as_str(), "light" | "medium" | "heavy")
    {
        return Err(format!(
            "--escalation must be one of: light, medium, heavy. Got: {esc}"
        ));
    }
    Ok(())
}
/// Program entry point: parses and validates the CLI, initialises logging,
/// optional MITM CA, TLS impersonation, rate limiting, scope filtering, the
/// persistent gene bank and the optional TUI, then accepts connections forever
/// and serves each one on its own task through `proxy`.
///
/// Fatal configuration problems terminate the process with `std::process::exit`
/// (code 1, or 2 for TLS-impersonation setup errors) or by returning `Err`.
/// Errors from `accept` are logged and retried after a short pause so that a
/// transient failure (for example running out of file descriptors) does not
/// take the whole proxy down.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    #[cfg(unix)]
    {
        // SAFETY: `signal` is called with a valid signal number and the
        // predefined `SIG_DFL` disposition; no Rust handler is installed.
        #[allow(unsafe_code)]
        unsafe {
            libc::signal(libc::SIGPIPE, libc::SIG_DFL);
        }
    }
    let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    let mut args = Args::parse();

    // ---- logging -----------------------------------------------------------
    use tracing_subscriber::EnvFilter;
    let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
    if args.tui {
        // With the TUI on, logs go to a file instead of the terminal.
        let log_path = match &args.log_dir {
            Some(dir) => {
                std::fs::create_dir_all(dir).ok();
                dir.join("wafrift-proxy-tui.log")
            }
            None => std::path::PathBuf::from("/tmp/wafrift-proxy-tui.log"),
        };
        match std::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(&log_path)
        {
            Ok(f) => {
                tracing_subscriber::fmt()
                    .with_env_filter(env_filter)
                    .with_writer(std::sync::Mutex::new(f))
                    .with_ansi(false)
                    .init();
                eprintln!("(--tui) logs writing to {}", log_path.display());
            }
            Err(e) => {
                eprintln!(
                    "(--tui) could not open log file {}: {} — disabling --tui to keep stdout logs",
                    log_path.display(),
                    e
                );
                args.tui = false;
                tracing_subscriber::fmt().with_env_filter(env_filter).init();
            }
        }
    } else {
        tracing_subscriber::fmt().with_env_filter(env_filter).init();
    }

    if let Err(msg) = validate_args(&args) {
        error!("{msg}");
        std::process::exit(1);
    }

    // ---- MITM CA -----------------------------------------------------------
    // `--write-mitm-ca-dir` is a one-shot action: generate, write, exit.
    if let Some(dir) = &args.write_mitm_ca_dir {
        let ca = CertificateAuthority::generate()?;
        ca.write_to_dir(dir)?;
        info!(
            "Wrote MITM CA to {} — install {} in your client, then run with --mitm --mitm-ca-dir ...",
            dir.display(),
            dir.join("wafrift-mitm-ca.pem").display()
        );
        println!(
            "MITM CA written to:\n {}\n {}\n\nTrust the CA in your OS or browser, then:\n wafrift-proxy --mitm --mitm-ca-dir {}",
            dir.join("wafrift-mitm-ca.pem").display(),
            dir.join("wafrift-mitm-ca-key.pem").display(),
            dir.display()
        );
        return Ok(());
    }
    if args.mitm && args.mitm_ca_dir.is_none() {
        let Some(default_dir) = wafrift_proxy::mitm::default_mitm_ca_dir() else {
            error!(
                "cannot determine home directory for MITM CA storage \
                 (no $HOME / dirs::config_dir on this OS). Pass --mitm-ca-dir \
                 explicitly or unset --mitm."
            );
            std::process::exit(1);
        };
        info!(
            "No --mitm-ca-dir specified; using default: {}",
            default_dir.display()
        );
        args.mitm_ca_dir = Some(default_dir);
    }
    let mitm_ca: Option<Arc<CertificateAuthority>> = if args.mitm {
        let dir = args
            .mitm_ca_dir
            .as_ref()
            .ok_or("internal error: mitm_ca_dir was not set")?;
        let ca = wafrift_proxy::mitm::ensure_ca(dir)?;
        let cert_path = dir.join("wafrift-mitm-ca.pem");
        match wafrift_proxy::mitm::install_ca_trust(&cert_path) {
            wafrift_proxy::mitm::TrustResult::Installed { method } => {
                info!("MITM CA auto-trusted via {method}");
            }
            wafrift_proxy::mitm::TrustResult::ManualRequired { instructions } => {
                println!("\n{instructions}\n");
                info!("CA generated at: {}", cert_path.display());
            }
            wafrift_proxy::mitm::TrustResult::Failed {
                error,
                instructions,
            } => {
                warn!("Auto-trust failed: {error}");
                println!("\n{instructions}\n");
            }
        }
        Some(Arc::new(ca))
    } else {
        None
    };

    // ---- listener ----------------------------------------------------------
    let addr: SocketAddr = args.listen.parse().unwrap_or_else(|e| {
        error!("--listen must be a valid socket address (e.g. 127.0.0.1:8080, [::1]:8080), got '{}': {}", args.listen, e);
        std::process::exit(1);
    });
    let listener = TcpListener::bind(addr).await.unwrap_or_else(|e| {
        error!("Failed to bind to {addr}: {e}");
        std::process::exit(1);
    });
    info!("Listening on http://{}", addr);
    // Status endpoints are only exposed when bound to loopback.
    let expose_wafrift_status = addr.ip().is_loopback();
    if !expose_wafrift_status {
        warn!(
            "--listen is bound to a non-loopback address ({}). /_wafrift/status and /_wafrift/findings.md are disabled to prevent information leakage.",
            addr
        );
        if args.mitm {
            error!(
                "REFUSING TO START: --mitm + non-loopback --listen ({}) is a CA-private-key-exposure risk. \
                 Anyone on the network can route HTTPS through this proxy and have it re-signed with your MITM CA. \
                 If you really want this (lab-only), bind to a loopback address and front-end with your own ACL'd reverse proxy.",
                addr
            );
            std::process::exit(1);
        }
    }

    // ---- evasion configuration and shared state ----------------------------
    let mut config = EvasionConfig::default();
    if args.content_type_switching {
        config.content_type_switching = true;
    }
    if args.fingerprint_rotation {
        config.fingerprint_rotation = true;
    }
    if args.insecure {
        config.insecure_tls = true;
    }
    let shared_state = Arc::new(Mutex::new(ProxyState::default()));
    let config = Arc::new(config);
    let default_escalation = args.escalation.clone();
    let mitm_enabled = args.mitm;
    // Response profiles: a rules/responses directory next to the binary wins,
    // then one in the working directory, then the compiled-in set.
    let response_profiles = {
        let next_to_binary = std::env::current_exe()
            .ok()
            .and_then(|p| p.parent().map(|d| d.join("rules/responses")))
            .filter(|d| d.is_dir());
        let cwd_dir = std::path::Path::new("rules/responses");
        if let Some(dir) = next_to_binary {
            ResponseProfileDb::load_dir(&dir)
        } else if cwd_dir.is_dir() {
            ResponseProfileDb::load_dir(cwd_dir)
        } else {
            info!(
                "no rules/responses/ directory found — using compiled-in profiles \
                 (override with a rules/responses/ dir next to the binary)"
            );
            ResponseProfileDb::compiled_in()
        }
    };
    let response_profiles = Arc::new(response_profiles);
    let policy = Arc::new(UpstreamPolicy {
        allow_private_upstream: args.allow_private_upstream,
        insecure_open_upstream: args.insecure_open_upstream,
    });
    let _ = WARN_THROTTLE.set(WarnThrottle::new(5));
    if args.insecure_open_upstream && args.allow_private_upstream {
        warn!(
            "--insecure-open-upstream makes --allow-private-upstream redundant; all upstream checks are disabled"
        );
    }

    // ---- TLS impersonation: single profile ---------------------------------
    if let Some(profile_str) = &args.tls_impersonate {
        use wafrift_transport::stealth::{ImpersonateProfile, StealthClient};
        let profile = match ImpersonateProfile::parse(profile_str) {
            Ok(p) => p,
            Err(e) => {
                error!("--tls-impersonate: {}", e);
                std::process::exit(2);
            }
        };
        let client = match StealthClient::with_timeout(
            profile,
            std::time::Duration::from_secs(wafrift_types::DEFAULT_REQUEST_TIMEOUT_SECS),
        ) {
            Ok(c) => c,
            Err(e) => {
                error!(
                    "--tls-impersonate {}: {}\nhint: rebuild with `cargo build --features wafrift-transport/tls-impersonate` (pulls in boring-sys)",
                    profile.name(),
                    e
                );
                std::process::exit(2);
            }
        };
        if STEALTH_CLIENT.set(client).is_err() {
            warn!("STEALTH_CLIENT was already initialised; ignoring duplicate set");
        }
        info!(
            "TLS impersonation active: every upstream forward will wear {}'s ClientHello",
            profile.name()
        );
    }

    // ---- TLS impersonation: rotating pool ----------------------------------
    if !args.tls_impersonate_rotate.is_empty() {
        use wafrift_transport::stealth::{ImpersonateProfile, StealthClient};
        let mut clients = Vec::with_capacity(args.tls_impersonate_rotate.len());
        let mut names = Vec::with_capacity(args.tls_impersonate_rotate.len());
        for raw in &args.tls_impersonate_rotate {
            let raw = raw.trim();
            if raw.is_empty() {
                continue;
            }
            let profile = match ImpersonateProfile::parse(raw) {
                Ok(p) => p,
                Err(e) => {
                    error!("--tls-impersonate-rotate: {}", e);
                    std::process::exit(2);
                }
            };
            let c = match StealthClient::with_timeout(
                profile,
                std::time::Duration::from_secs(wafrift_types::DEFAULT_REQUEST_TIMEOUT_SECS),
            ) {
                Ok(c) => c,
                Err(e) => {
                    error!(
                        "--tls-impersonate-rotate {}: {}\nhint: rebuild with `cargo build --features wafrift-transport/tls-impersonate`",
                        profile.name(),
                        e
                    );
                    std::process::exit(2);
                }
            };
            clients.push(c);
            names.push(profile.name());
        }
        // `StealthPool::pick` divides by the pool size, so an empty pool must never be installed.
        if clients.is_empty() {
            error!("--tls-impersonate-rotate: empty profile list after trimming");
            std::process::exit(2);
        }
        let pool = StealthPool {
            clients,
            cursor: std::sync::atomic::AtomicUsize::new(0),
        };
        if STEALTH_POOL.set(pool).is_err() {
            warn!("STEALTH_POOL was already initialised; ignoring duplicate set");
        }
        info!(
            "TLS impersonation rotation active: every upstream forward picks round-robin from {:?}",
            names
        );
    }

    // ---- optional request mutations ----------------------------------------
    if args.mutate_url {
        MUTATE_URL_ENABLED.store(true, std::sync::atomic::Ordering::Relaxed);
        warn!(
            "--mutate-url: every upstream URL's query parameter values will be aggressively \
             percent-encoded. This changes routing semantics (cache keys, log entries) — \
             ensure the upstream is robust to encoded query bytes."
        );
    }
    if args.captchaforge {
        #[cfg(feature = "captchaforge")]
        {
            if let Err(e) = wafrift_captchaforge_bridge::install_global_solver().await {
                error!("--captchaforge: solver install failed: {e}");
                return Err(format!("captchaforge install failed: {e}").into());
            }
            info!(
                "--captchaforge: headless-browser solver installed into ChallengeStore. \
                 Cloudflare/Turnstile/hCaptcha responses will be auto-solved via captchaforge."
            );
        }
        #[cfg(not(feature = "captchaforge"))]
        {
            error!(
                "--captchaforge requires the binary to be built with `--features captchaforge`. \
                 Rebuild with `cargo build --release --features captchaforge` and retry."
            );
            return Err("--captchaforge requires the captchaforge feature".into());
        }
    }
    if args.body_padding_bytes > 0 {
        BODY_PADDING_BYTES.store(
            args.body_padding_bytes,
            std::sync::atomic::Ordering::Relaxed,
        );
        if args.body_padding_bytes < wafrift_evolution::body_padding::MIN_USEFUL_PAD {
            warn!(
                "--body-padding-bytes {} is below the {}-byte useful minimum; padding will be skipped",
                args.body_padding_bytes,
                wafrift_evolution::body_padding::MIN_USEFUL_PAD
            );
        } else {
            info!(
                "Body padding active: every JSON / form / multipart request body gets {} bytes of inert leading filler",
                args.body_padding_bytes
            );
        }
    }

    // ---- upstream HTTP client ----------------------------------------------
    let mut client_builder = reqwest::Client::builder()
        .danger_accept_invalid_certs(config.insecure_tls)
        .timeout(std::time::Duration::from_secs(
            wafrift_types::DEFAULT_REQUEST_TIMEOUT_SECS,
        ))
        .dns_resolver(Arc::new(BogonFilteringResolver {
            policy: policy.clone(),
        }));
    if args.no_conn_reuse {
        client_builder = client_builder.pool_max_idle_per_host(0);
        info!(
            "Connection re-use disabled: every upstream forward opens a fresh TCP connection (new source port per request)"
        );
    }
    let global_client = client_builder.build().unwrap_or_else(|e| {
        error!("reqwest client build failed: {e}");
        std::process::exit(1);
    });

    // ---- limits, scope, rate limiting, request log -------------------------
    let limits = Arc::new(ProxyLimits {
        max_upstream_response_bytes: args.max_upstream_response_bytes,
        max_evade_retries: args.max_evade_retries,
    });
    let scope = Arc::new(ScopeFilter::new(
        args.only_host.clone(),
        args.skip_host.clone(),
        args.only_path.clone(),
        args.skip_path.clone(),
        args.only_method.clone(),
    ));
    if !scope.is_empty() {
        info!(
            only_host = ?args.only_host,
            skip_host = ?args.skip_host,
            only_path = ?args.only_path,
            skip_path = ?args.skip_path,
            only_method = ?args.only_method,
            "scope filter active — out-of-scope requests pass through unchanged"
        );
    }
    let rate_limiter = RateLimiter::new(args.max_rps_per_host, args.max_rps_per_host_burst);
    if !rate_limiter.is_unlimited() {
        info!(
            rps = args.max_rps_per_host,
            burst = if args.max_rps_per_host_burst > 0.0 {
                args.max_rps_per_host_burst
            } else {
                args.max_rps_per_host
            },
            "per-host rate limiter active"
        );
    }
    let conn_sem = Arc::new(Semaphore::new(args.max_concurrent_connections));
    let logger: SharedLogger = if let Some(dir) = &args.log_dir {
        match RequestLogger::open(dir) {
            Ok(l) => Some(Arc::new(l)),
            Err(e) => {
                error!(dir = %dir.display(), error = %e, "failed to open log directory");
                std::process::exit(1);
            }
        }
    } else {
        None
    };
    if args.insecure_open_upstream {
        warn!("--insecure-open-upstream: upstream DNS/literal policy checks are disabled");
    }
    if args.insecure {
        warn!(
            "--insecure: upstream TLS certificate verification is disabled — \
             do NOT use on untrusted networks; an on-path attacker can MITM \
             every HTTPS connection wafrift makes"
        );
    }

    // ---- persistent gene bank: load + periodic flush ------------------------
    let gene_bank_path = default_gene_bank_path(&args.gene_bank_path);
    if let Some(path) = &gene_bank_path {
        let restored = {
            let mut st = shared_state.lock().await;
            let bank = load_gene_bank(path);
            restore_gene_bank(&mut st, bank)
        };
        if restored > 0 {
            info!(
                path = %path.display(),
                hosts_restored = restored,
                "loaded persistent gene bank"
            );
        } else {
            info!(path = %path.display(), "starting with empty gene bank");
        }
        if args.gene_bank_flush_interval_secs > 0 {
            let flush_path = path.clone();
            let flush_state = shared_state.clone();
            let interval = args.gene_bank_flush_interval_secs;
            tokio::spawn(async move {
                let mut tick = tokio::time::interval(std::time::Duration::from_secs(interval));
                // The first tick of a tokio interval completes immediately; consume
                // it so the first flush happens one full period after startup.
                tick.tick().await;
                loop {
                    tick.tick().await;
                    // NOTE(review): the state lock is held across blocking file I/O here.
                    let st = flush_state.lock().await;
                    if let Err(e) = save_gene_bank(&st, &flush_path) {
                        warn!(error = %e, "periodic gene bank flush failed");
                    }
                }
            });
        }
    }

    // ---- graceful shutdown on SIGTERM / SIGINT ------------------------------
    // NOTE(review): `tokio::signal::unix` exists only on Unix targets while the
    // SIGPIPE block above is `cfg(unix)`-gated — confirm non-Unix targets are
    // out of scope.
    let shutdown_state = shared_state.clone();
    let shutdown_path = gene_bank_path.clone();
    tokio::spawn(async move {
        use tokio::signal::unix::{SignalKind, signal};
        let mut sigterm = match signal(SignalKind::terminate()) {
            Ok(s) => s,
            Err(e) => {
                warn!(error = %e, "SIGTERM handler setup failed; graceful shutdown disabled");
                return;
            }
        };
        let mut sigint = match signal(SignalKind::interrupt()) {
            Ok(s) => s,
            Err(e) => {
                warn!(error = %e, "SIGINT handler setup failed; graceful shutdown disabled");
                return;
            }
        };
        tokio::select! {
            _ = sigterm.recv() => info!("received SIGTERM"),
            _ = sigint.recv() => info!("received SIGINT"),
        };
        if let Some(path) = &shutdown_path {
            let st = shutdown_state.lock().await;
            match save_gene_bank(&st, path) {
                Ok(()) => info!(path = %path.display(), "gene bank flushed on shutdown"),
                Err(e) => {
                    warn!(path = %path.display(), error = %e, "gene bank flush on shutdown failed");
                }
            }
        }
        info!("shutting down");
        std::process::exit(0);
    });

    // ---- optional TUI dashboard ---------------------------------------------
    if args.tui {
        let (tx, rx) = tokio::sync::mpsc::channel(10_000);
        if TUI_TX.set(tx).is_err() {
            warn!("TUI_TX was already initialised; skipping TUI startup");
        } else {
            let tls_label = if !args.tls_impersonate_rotate.is_empty() {
                format!("rotate({})", args.tls_impersonate_rotate.join(","))
            } else if let Some(p) = &args.tls_impersonate {
                format!("single({p})")
            } else {
                "off".to_string()
            };
            let cfg = wafrift_proxy::tui::DashboardConfig {
                bind_addr: addr.to_string(),
                mode: default_escalation
                    .clone()
                    .unwrap_or_else(|| "evade".to_string()),
                tls_stack_label: tls_label,
                body_padding_bytes: args.body_padding_bytes,
                conn_reuse: !args.no_conn_reuse,
            };
            let (quit_tx, quit_rx) = tokio::sync::oneshot::channel();
            tokio::spawn(async move {
                if let Err(e) = wafrift_proxy::tui::run(cfg, rx, quit_tx).await {
                    eprintln!("TUI exited with error: {e}");
                }
            });
            // When the TUI signals quit, flush the gene bank and exit.
            let quit_state = shared_state.clone();
            let quit_path = gene_bank_path.clone();
            tokio::spawn(async move {
                if quit_rx.await.is_ok() {
                    if let Some(path) = &quit_path {
                        let st = quit_state.lock().await;
                        if let Err(e) = save_gene_bank(&st, path) {
                            warn!(path = %path.display(), error = %e, "gene bank flush from TUI quit failed");
                        }
                    }
                    std::process::exit(0);
                }
            });
        }
    }

    // ---- accept loop --------------------------------------------------------
    loop {
        // Take a permit before accepting so at most
        // `--max-concurrent-connections` connections are served at once.
        let permit = match conn_sem.clone().acquire_owned().await {
            Ok(p) => p,
            Err(_) => continue,
        };
        let (stream, peer) = match listener.accept().await {
            Ok(conn) => conn,
            Err(e) => {
                // An accept error (aborted handshake, fd exhaustion, ...) must not
                // terminate the proxy: release the permit, pause briefly so a
                // persistent error cannot spin the CPU, and try again.
                drop(permit);
                let log_it = WARN_THROTTLE
                    .get()
                    .map_or(true, |t| t.should_warn("listener-accept"));
                if log_it {
                    warn!(error = %e, "accept failed; backing off before retrying");
                }
                tokio::time::sleep(Duration::from_millis(50)).await;
                continue;
            }
        };
        let io = TokioIo::new(stream);
        let shared_state = shared_state.clone();
        let config = config.clone();
        let default_escalation = default_escalation.clone();
        let client = global_client.clone();
        let mitm_ca = mitm_ca.clone();
        let policy = policy.clone();
        let limits = limits.clone();
        let scope = scope.clone();
        let rate_limiter = rate_limiter.clone();
        let response_profiles = response_profiles.clone();
        // Status endpoints additionally require the peer itself to be loopback.
        let expose_status_per_conn = expose_wafrift_status && peer.ip().is_loopback();
        let logger = logger.clone();
        tokio::task::spawn(async move {
            // Held for the lifetime of the connection; released on drop.
            let _permit = permit;
            if let Err(err) = http1::Builder::new()
                .preserve_header_case(true)
                .title_case_headers(true)
                .serve_connection(
                    io,
                    service_fn(move |req| {
                        proxy(
                            req,
                            shared_state.clone(),
                            config.clone(),
                            default_escalation.clone(),
                            client.clone(),
                            mitm_enabled,
                            mitm_ca.clone(),
                            policy.clone(),
                            limits.clone(),
                            scope.clone(),
                            rate_limiter.clone(),
                            expose_status_per_conn,
                            logger.clone(),
                            response_profiles.clone(),
                        )
                    }),
                )
                .with_upgrades()
                .await
            {
                warn!("failed to serve connection: {:?}", err);
            }
        });
    }
}
/// Converts a header value to an owned `String`.
///
/// Valid UTF-8 is copied as is. Otherwise the bytes are converted lossily
/// (invalid sequences become U+FFFD) and a warning naming the header is logged.
fn header_value_to_string(name: &str, value: &hyper::header::HeaderValue) -> String {
    let raw = value.as_bytes();
    if let Ok(text) = std::str::from_utf8(raw) {
        return text.to_owned();
    }
    let lossy = String::from_utf8_lossy(raw).to_string();
    tracing::warn!(header = %name, "header value contains invalid UTF-8; using lossy conversion");
    lossy
}
/// Splits an absolute URL into `(scheme + "://" + authority, path-and-rest)`.
///
/// The second element starts at the first `/` after the authority and runs to
/// the end of the URL (path, query and fragment), so concatenating both parts
/// reproduces `url`.
///
/// Returns `None` when `url` has no `://` or no path. A URL whose authority is
/// followed directly by `?` or `#` (for example `http://h?next=/x`) has no
/// path, so it also yields `None`. A `/` inside the query or fragment is never
/// taken for the start of the path.
fn split_url_for_mutation(url: &str) -> Option<(String, String)> {
    let authority_start = url.find("://")? + 3;
    let rest = &url[authority_start..];

    // The authority ends at the first of '/', '?' or '#'.
    let authority_len = rest.find(|c: char| matches!(c, '/' | '?' | '#'))?;
    if !rest[authority_len..].starts_with('/') {
        return None;
    }

    let (prefix, path_and_rest) = url.split_at(authority_start + authority_len);
    Some((prefix.to_string(), path_and_rest.to_string()))
}
fn error_response(status: StatusCode, message: &str) -> Response<Full<Bytes>> {
Response::builder()
.status(status)
.body(Full::new(Bytes::from(message.to_string())))
.unwrap_or_else(|_| {
let mut resp = Response::new(Full::new(Bytes::from("internal error")));
*resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
resp
})
}
/// Forwards a request through `forward_wafrift_request`, retrying while the
/// result is a 403 or 406.
///
/// Up to `limits.max_evade_retries` extra attempts follow the first one. The
/// first response with any other status is returned immediately, with an info
/// log when it came from a retry. If every attempt is blocked, the last
/// blocked response is returned.
///
/// # Errors
/// Propagates the `hyper::Error` of a failing attempt without retrying.
#[allow(clippy::too_many_arguments)]
async fn forward_with_evade_retry(
    wafrift_req: wafrift_types::Request,
    host: String,
    request_log_uri: String,
    state: SharedState,
    config: Arc<EvasionConfig>,
    default_escalation: Option<String>,
    client: &reqwest::Client,
    policy: Arc<UpstreamPolicy>,
    limits: Arc<ProxyLimits>,
    response_profiles: Arc<ResponseProfileDb>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
    let retry_budget = limits.max_evade_retries;
    let mut blocked: Option<Response<Full<Bytes>>> = None;

    for attempt in 0..=retry_budget {
        let response = forward_wafrift_request(
            wafrift_req.clone(),
            host.clone(),
            request_log_uri.clone(),
            Arc::clone(&state),
            Arc::clone(&config),
            default_escalation.clone(),
            client,
            Arc::clone(&policy),
            Arc::clone(&limits),
            Arc::clone(&response_profiles),
        )
        .await?;

        let code = response.status().as_u16();
        if matches!(code, 403 | 406) {
            // Still blocked: remember this response and try again if budget remains.
            blocked = Some(response);
            continue;
        }

        if attempt > 0 {
            info!(
                host = %host,
                attempt,
                status = code,
                "evade retry landed a bypass"
            );
        }
        return Ok(response);
    }

    match blocked {
        Some(response) => Ok(response),
        // The range above is never empty, so this arm is a defensive fallback.
        None => {
            let mut fallback = Response::new(Full::new(Bytes::from("no attempt completed")));
            *fallback.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
            Ok(fallback)
        }
    }
}
#[allow(clippy::too_many_arguments)]
async fn forward_wafrift_request(
wafrift_req: wafrift_types::Request,
host: String,
request_log_uri: String,
state: SharedState,
config: Arc<EvasionConfig>,
default_escalation: Option<String>,
client: &reqwest::Client,
policy: Arc<UpstreamPolicy>,
limits: Arc<ProxyLimits>,
response_profiles: Arc<ResponseProfileDb>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
enum EvadePlan {
Replay {
replay_state: HostState,
winner_name: String,
},
Discovery {
host_state: HostState,
},
}
let plan = {
let mut st = state.lock().await;
st.total_scanned = st.total_scanned.saturating_add(1);
if st.hosts.len() >= 10_000 && !st.hosts.contains_key(&host) {
while let Some(key_to_remove) = st.host_fifo.pop_front() {
if st.hosts.remove(&key_to_remove).is_some() {
break;
}
}
}
let is_new = !st.hosts.contains_key(&host);
if is_new {
st.host_fifo.push_back(host.clone());
}
let hs = st.hosts.entry(host.clone()).or_default();
if let Some(esc) = &default_escalation {
match esc.as_str() {
"heavy" if hs.blocks < 6 => hs.blocks = 6,
"medium" if hs.blocks < 3 => hs.blocks = 3,
"light" if hs.blocks < 1 => hs.blocks = 1,
_ => {}
}
}
if hs.has_winners() {
let winner_name = hs.next_winner().unwrap_or_default();
info!(
host = %host,
technique = %winner_name,
pool_size = hs.proven_winners.len(),
"rotating proven winner"
);
let replay_state = HostState {
proven_winners: vec![winner_name.clone()],
discovery_complete: true,
..HostState::default()
};
EvadePlan::Replay {
replay_state,
winner_name,
}
} else {
if hs.discovery_complete {
info!(host = %host, "all winners pruned, re-entering discovery");
}
EvadePlan::Discovery {
host_state: hs.clone(),
}
}
};
let req_headers_pre = wafrift_req.headers.clone();
let req_body_pre_excerpt: Vec<u8> = wafrift_req
.body
.as_deref()
.map(|b| b[..b.len().min(wafrift_proxy::tui::MAX_BODY_EXCERPT)].to_vec())
.unwrap_or_default();
let (mut evasion_result, technique_keys) = match plan {
EvadePlan::Replay {
replay_state,
winner_name,
} => {
let req = wafrift_req.clone();
let req_fallback = req.clone();
let state = replay_state.clone();
let cfg = (*config).clone();
let result = tokio::task::spawn_blocking(move || evade(&req, &state, &cfg))
.await
.unwrap_or_else(|e| {
tracing::error!(error = %e, "evade task panicked");
EvasionResult::new(req_fallback, vec![], String::new())
});
let mut keys: Vec<String> = result
.techniques
.iter()
.map(std::string::ToString::to_string)
.collect();
if keys.is_empty() {
keys.push(winner_name);
}
(result, keys)
}
EvadePlan::Discovery { host_state } => {
let req = wafrift_req.clone();
let req_fallback = req.clone();
let state = host_state.clone();
let cfg = (*config).clone();
let result = tokio::task::spawn_blocking(move || evade_smart(&req, &state, &cfg))
.await
.unwrap_or_else(|e| {
tracing::error!(error = %e, "evade_smart task panicked");
EvasionResult::new(req_fallback, vec![], String::new())
});
let keys: Vec<String> = result
.techniques
.iter()
.map(std::string::ToString::to_string)
.collect();
if !result.techniques.is_empty() {
info!(
uri = %request_log_uri,
techniques = %result.description,
"discovery: evading WAF"
);
}
(result, keys)
}
};
if let Err(msg) = assert_forward_url_allowed(&evasion_result.request.url, &policy).await {
warn!(host = %host, url = %evasion_result.request.url, "{}", msg);
return Ok(error_response(StatusCode::FORBIDDEN, &msg));
}
if MUTATE_URL_ENABLED.load(std::sync::atomic::Ordering::Relaxed)
&& let Some((scheme_authority, path_and_query)) =
split_url_for_mutation(&evasion_result.request.url)
{
let cfg = wafrift_encoding::url_mutate::UrlMutateConfig::default();
let (mutated_pq, _techniques) =
wafrift_encoding::url_mutate::mutate_url(&path_and_query, &cfg);
if mutated_pq != path_and_query {
let new_url = format!("{scheme_authority}{mutated_pq}");
debug!(
host = %host,
from = %path_and_query,
to = %mutated_pq,
"url mutation applied"
);
evasion_result.request.url = new_url;
}
}
let pad_target = BODY_PADDING_BYTES.load(std::sync::atomic::Ordering::Relaxed);
if pad_target >= wafrift_evolution::body_padding::MIN_USEFUL_PAD {
let ct = evasion_result
.request
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("content-type")).map_or_else(|| "application/octet-stream".to_string(), |(_, v)| v.clone());
let original = evasion_result.request.body.clone().unwrap_or_default();
match wafrift_evolution::body_padding::pad(&original, &ct, pad_target) {
wafrift_evolution::body_padding::PadOutcome::Padded { bytes, added } => {
evasion_result.request.body = Some(bytes);
debug!(
host = %host,
added,
target = pad_target,
"body padding applied"
);
}
wafrift_evolution::body_padding::PadOutcome::SkippedOpaque => {
trace!(host = %host, content_type = %ct, "body padding skipped: opaque content-type");
}
wafrift_evolution::body_padding::PadOutcome::SkippedTooSmall => {
}
}
}
if wafrift_proxy::intercept::intercept_mode_enabled() {
let store = wafrift_proxy::intercept::global_store();
let path_for_intercept = evasion_result
.request
.url
.splitn(4, '/')
.nth(3).map_or_else(|| "/".into(), |s| format!("/{s}"));
let (id, rx) = store.register(
host.clone(),
evasion_result.request.method.as_str(),
path_for_intercept,
);
let decision = tokio::select! {
d = rx => d.unwrap_or(wafrift_proxy::intercept::InterceptDecision::Release),
_ = tokio::time::sleep(wafrift_proxy::intercept::INTERCEPT_TIMEOUT) => {
store.cancel(id);
warn!(
host = %host,
"intercept default-allow after {} secs (operator did not act)",
wafrift_proxy::intercept::INTERCEPT_TIMEOUT.as_secs()
);
wafrift_proxy::intercept::InterceptDecision::Release
}
};
if matches!(decision, wafrift_proxy::intercept::InterceptDecision::Kill) {
return Ok(error_response(
StatusCode::FORBIDDEN,
"killed by operator from intercept tab",
));
}
}
if let Some(clearance) = challenge_store().get(&host) {
let mut found = false;
for (k, v) in &mut evasion_result.request.headers {
if k.eq_ignore_ascii_case("cookie") {
if !v.contains(&clearance) {
if v.is_empty() {
*v = clearance.clone();
} else {
*v = format!("{v}; {clearance}");
}
}
found = true;
break;
}
}
if !found {
evasion_result
.request
.headers
.push(("Cookie".into(), clearance));
}
}
let conn_fwd = collect_connection_header_names(&evasion_result.request.headers);
let max = limits.max_upstream_response_bytes;
let upstream_start = Instant::now();
let status_code: u16;
let (mut response_builder, buf) = if let Some(sc) = stealth() {
let mut filtered_headers = Vec::with_capacity(evasion_result.request.headers.len());
for (k, v) in &evasion_result.request.headers {
if k.eq_ignore_ascii_case("host")
|| k.eq_ignore_ascii_case("content-length")
|| should_strip_proxy_header(k, &conn_fwd)
{
continue;
}
filtered_headers.push((k.clone(), v.clone()));
}
let stealth_resp = match sc
.send(
evasion_result.request.method.as_str(),
&evasion_result.request.url,
&filtered_headers,
evasion_result.request.body.as_deref(),
max,
)
.await
{
Ok(r) => r,
Err(e) => {
if let Some(throttle) = WARN_THROTTLE.get()
&& throttle.should_warn(&format!("forward:{host}"))
{
warn!(host = %host, error = %e, stack = "stealth", "forwarding failed");
}
return Ok(error_response(StatusCode::BAD_GATEWAY, "forwarding error"));
}
};
status_code = stealth_resp.status;
let mut response_builder = Response::builder().status(stealth_resp.status);
let conn_resp: std::collections::HashSet<String> = stealth_resp
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("connection"))
.map(|(_, v)| {
v.split(',')
.map(|t| t.trim().to_ascii_lowercase())
.filter(|t| !t.is_empty())
.collect()
})
.unwrap_or_default();
for (k, v) in &stealth_resp.headers {
if should_strip_proxy_header(k, &conn_resp) {
continue;
}
response_builder = response_builder.header(k.as_str(), v.as_str());
}
(response_builder, stealth_resp.body.clone())
} else {
let method = match reqwest::Method::from_bytes(
evasion_result.request.method.as_str().as_bytes(),
) {
Ok(m) => m,
Err(e) => {
warn!(host = %host, error = %e, method = %evasion_result.request.method.as_str(), "invalid HTTP method");
return Ok(error_response(
StatusCode::BAD_REQUEST,
"invalid HTTP method",
));
}
};
let mut builder = client.request(method, &evasion_result.request.url);
for (k, v) in &evasion_result.request.headers {
if k.eq_ignore_ascii_case("host")
|| k.eq_ignore_ascii_case("content-length")
|| should_strip_proxy_header(k, &conn_fwd)
{
continue;
}
builder = builder.header(k.as_str(), v.as_str());
}
if let Some(b) = evasion_result.request.body.clone() {
builder = builder.body(b);
}
let resp = match builder.send().await {
Ok(r) => r,
Err(e) => {
if let Some(throttle) = WARN_THROTTLE.get()
&& throttle.should_warn(&format!("forward:{host}"))
{
warn!(host = %host, error = %e, "forwarding failed");
}
return Ok(error_response(StatusCode::BAD_GATEWAY, "forwarding error"));
}
};
let status = resp.status();
status_code = status.as_u16();
let conn_resp = collect_connection_header_names_hyper(resp.headers());
let mut response_builder = Response::builder().status(status.as_u16());
for (k, v) in resp.headers() {
if should_strip_proxy_header(k.as_str(), &conn_resp) {
continue;
}
response_builder = response_builder.header(k, v);
}
let mut stream = resp.bytes_stream();
let mut buf = Vec::new();
while let Some(item) = stream.next().await {
let chunk = match item {
Ok(c) => c,
Err(e) => {
warn!(host = %host, error = %e, "upstream body read failed");
return Ok(error_response(
StatusCode::BAD_GATEWAY,
"upstream read error",
));
}
};
if buf.len().saturating_add(chunk.len()) > max {
return Ok(error_response(
StatusCode::PAYLOAD_TOO_LARGE,
"upstream response too large",
));
}
buf.extend_from_slice(&chunk);
}
(response_builder, buf)
};
let header_pairs: Vec<(String, String)> = response_builder
.headers_ref()
.map(|hm| {
hm.iter()
.map(|(k, v)| {
(
k.as_str().to_string(),
v.to_str().unwrap_or_default().to_string(),
)
})
.collect()
})
.unwrap_or_default();
let signal = response_profiles.classify(status_code, &header_pairs, &buf);
let is_block = signal.classification.is_blocked();
{
let store = challenge_store();
let set_cookie_values: Vec<&str> = header_pairs
.iter()
.filter(|(k, _)| k.eq_ignore_ascii_case("set-cookie"))
.map(|(_, v)| v.as_str())
.collect();
if let Some((cookie, kind)) =
wafrift_transport::challenge::extract_clearance_cookie(&set_cookie_values)
{
store.record(host.clone(), cookie, kind, None);
info!(
host = %host,
kind = %kind.label(),
"challenge clearance cookie captured"
);
}
if status_code == 503 || status_code == 403 {
let body_slice = &buf[..buf.len().min(8192)];
let kind = wafrift_transport::challenge::classify_with_status(
body_slice,
&header_pairs,
status_code,
);
if !matches!(kind, wafrift_transport::challenge::ChallengeKind::Unknown)
&& store.get(&host).is_none()
&& store.should_prompt_operator(&host)
{
warn!(
host = %host,
kind = %kind.label(),
"managed challenge detected and no clearance cookie on file — clear the \
challenge in a browser; the cookie will be captured on the next response"
);
}
}
}
let detected_waf = {
let st = state.lock().await;
st.hosts.get(&host).and_then(|h| h.waf_name.clone())
};
if detected_waf.is_none() {
if let Some(ref waf_name) = signal.matched_waf {
let mut st = state.lock().await;
if let Some(hs) = st.hosts.get_mut(&host) {
hs.confirm_waf(Some(waf_name.clone()));
info!(
host = %host,
waf = %waf_name,
source = "response_profile",
"WAF identified"
);
}
} else {
let body_slice = &buf[..buf.len().min(8192)];
let detections =
wafrift_detect::waf_detect::detect(status_code, &header_pairs, body_slice);
if let Some(top) = detections.first()
&& top.confidence >= wafrift_detect::waf_detect::ACTIONABLE_CONFIDENCE_THRESHOLD
{
let mut st = state.lock().await;
if let Some(hs) = st.hosts.get_mut(&host) {
hs.confirm_waf(Some(top.name.clone()));
info!(
host = %host,
waf = %top.name,
confidence = top.confidence,
source = "wafrift_detect",
"WAF identified"
);
}
}
}
}
{
let mut st = state.lock().await;
if let Some(hs) = st.hosts.get_mut(&host) {
hs.record_signal(
signal.classification == BlockClass::HardBlock,
signal.classification == BlockClass::SoftBlock,
signal.classification == BlockClass::RateLimit,
signal.classification == BlockClass::Challenge,
signal.matched_waf.as_deref(),
&signal.prioritize,
&signal.avoid,
signal.inspection_model.as_deref(),
&technique_keys,
);
if signal.classification == BlockClass::Pass {
if evasion_result.techniques.is_empty() {
let parsed: Vec<wafrift_types::Technique> = technique_keys
.iter()
.filter_map(|k| wafrift_types::Technique::from_pool_key(k))
.collect();
if !parsed.is_empty() {
hs.record_success_for_many(&parsed);
}
} else {
hs.record_success_for_many(&evasion_result.techniques);
}
}
if signal.classification.should_backoff() {
info!(
host = %host,
classification = ?signal.classification,
"WAF rate limit / challenge — backing off, not changing technique"
);
}
}
if is_block {
st.total_blocks = st.total_blocks.saturating_add(1);
} else {
for t in &evasion_result.techniques {
let name = t.to_string();
*st.techniques_used.entry(name).or_insert(0) += 1;
}
}
}
if !technique_keys.is_empty() {
response_builder = response_builder.header(X_WAFRIFT_TECHNIQUES, technique_keys.join(", "));
}
response_builder =
response_builder.header(X_WAFRIFT_BLOCKED, if is_block { "true" } else { "false" });
{
let path_only = request_log_uri
.split('?')
.next()
.unwrap_or(&request_log_uri)
.to_string();
let body_padded = wafrift_evolution::body_padding::looks_padded(
evasion_result.request.body.as_deref().unwrap_or(&[]),
);
let tls_profile = stealth().map(|sc| sc.profile().name().to_string());
let bypassed = !is_block && !evasion_result.techniques.is_empty();
let upstream_latency_ms =
u64::try_from(upstream_start.elapsed().as_millis()).unwrap_or(u64::MAX);
let cap = wafrift_proxy::tui::MAX_BODY_EXCERPT;
let req_body_excerpt = evasion_result
.request
.body
.as_deref()
.map(|b| b[..b.len().min(cap)].to_vec())
.unwrap_or_default();
let resp_body_excerpt = buf[..buf.len().min(cap)].to_vec();
let resp_body_total = buf.len() as u64;
let waf_name = {
let st = state.lock().await;
st.hosts.get(&host).and_then(|h| h.waf_name.clone())
};
emit_tui(wafrift_proxy::tui::Event::Request {
host: host.clone(),
method: evasion_result.request.method.as_str().to_string(),
path: path_only,
status: status_code,
bypassed,
blocked: is_block,
techniques: technique_keys.join(", "),
tls_profile,
body_padded,
upstream_latency_ms,
waf_name,
req_headers: evasion_result.request.headers.clone(),
req_body_excerpt,
req_headers_pre,
req_body_pre_excerpt,
resp_headers: header_pairs.clone(),
resp_body_excerpt,
resp_body_total,
attempts: 0,
});
}
Ok(response_builder
.body(Full::new(Bytes::from(buf)))
.unwrap_or_else(|_| {
error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"failed to build response",
)
}))
}
#[allow(clippy::too_many_arguments)]
async fn mitm_plaintext_request(
mut req: Request<Incoming>,
connect_authority: String,
state: SharedState,
config: Arc<EvasionConfig>,
default_escalation: Option<String>,
client: reqwest::Client,
policy: Arc<UpstreamPolicy>,
limits: Arc<ProxyLimits>,
scope: Arc<ScopeFilter>,
rate_limiter: Arc<RateLimiter>,
response_profiles: Arc<ResponseProfileDb>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
let sni_host = tls_server_name_from_authority(&connect_authority);
if let Some(h) = req
.headers()
.get(hyper::header::HOST)
.and_then(|x| x.to_str().ok())
{
let inner = extract_host_from_header(h);
if !inner.eq_ignore_ascii_case(&sni_host) {
warn!(inner = %inner, expected = %sni_host, "mitm Host header does not match CONNECT");
return Ok(error_response(
StatusCode::BAD_REQUEST,
"Host header does not match CONNECT target",
));
}
}
let authority = connect_authority.trim();
let path_and_q = req
.uri()
.path_and_query()
.map_or("/", hyper::http::uri::PathAndQuery::as_str);
let url = format!("https://{authority}{path_and_q}");
let host = sni_host;
let limited = Limited::new(req.body_mut(), MAX_PROXY_BODY_BYTES);
let body_bytes = match limited.collect().await {
Ok(b) => b.to_bytes().to_vec(),
Err(_) => {
if let Some(throttle) = WARN_THROTTLE.get()
&& throttle.should_warn(&format!("body-limit:{host}"))
{
warn!(host = %host, limit = MAX_PROXY_BODY_BYTES, "request body exceeded size limit");
}
return Ok(error_response(
StatusCode::PAYLOAD_TOO_LARGE,
"request body too large",
));
}
};
let raw_headers: Vec<(String, String)> = req
.headers()
.iter()
.map(|(k, v)| {
(
k.as_str().to_string(),
header_value_to_string(k.as_str(), v),
)
})
.collect();
let conn = collect_connection_header_names(&raw_headers);
let headers: Vec<(String, String)> = raw_headers
.into_iter()
.filter(|(k, _)| !should_strip_proxy_header(k, &conn))
.collect();
let mut wafrift_req = wafrift_types::Request::with_method(
wafrift_types::Method::from(req.method().as_str()),
url,
);
wafrift_req.headers = headers;
if !body_bytes.is_empty() {
wafrift_req.body = Some(body_bytes);
}
let log_uri = wafrift_req.url.clone();
rate_limiter.acquire(&host).await;
let path_for_scope = req
.uri()
.path_and_query().map_or_else(|| "/".to_string(), |p| p.path().to_string());
if !scope.allows(&host, &path_for_scope, &wafrift_req.method) {
return forward_passthrough(wafrift_req, host, &client, policy, limits).await;
}
forward_with_evade_retry(
wafrift_req,
host,
log_uri,
state,
config,
default_escalation,
&client,
policy,
limits,
response_profiles,
)
.await
}
/// Terminates TLS on an upgraded CONNECT stream and serves the decrypted
/// requests over HTTP/1 via [`mitm_plaintext_request`].
///
/// A TLS acceptor is obtained from `ca` for the server name derived from
/// `connect_authority`, the handshake is performed on `upgraded`, and every
/// request on the resulting connection is dispatched with its own clones of
/// the shared handles. The HTTP/1 connection is configured to preserve header
/// case and to title-case headers.
///
/// # Errors
/// Returns an error when the acceptor cannot be created, the TLS handshake
/// fails, or serving the connection fails.
#[allow(clippy::too_many_arguments)]
async fn mitm_https_session(
    upgraded: Upgraded,
    connect_authority: String,
    ca: Arc<CertificateAuthority>,
    state: SharedState,
    config: Arc<EvasionConfig>,
    default_escalation: Option<String>,
    client: reqwest::Client,
    policy: Arc<UpstreamPolicy>,
    limits: Arc<ProxyLimits>,
    scope: Arc<ScopeFilter>,
    rate_limiter: Arc<RateLimiter>,
    response_profiles: Arc<ResponseProfileDb>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let server_name = tls_server_name_from_authority(&connect_authority);
    let tls_acceptor = ca.create_tls_acceptor(&server_name)?;
    let client_tls = tls_acceptor.accept(TokioIo::new(upgraded)).await?;

    // The closure owns the session-wide handles; each incoming request gets
    // its own clone because the handler takes its arguments by value.
    let per_request = service_fn(move |request: Request<Incoming>| {
        mitm_plaintext_request(
            request,
            connect_authority.clone(),
            state.clone(),
            Arc::clone(&config),
            default_escalation.clone(),
            client.clone(),
            Arc::clone(&policy),
            Arc::clone(&limits),
            Arc::clone(&scope),
            Arc::clone(&rate_limiter),
            Arc::clone(&response_profiles),
        )
    });

    http1::Builder::new()
        .preserve_header_case(true)
        .title_case_headers(true)
        .serve_connection(TokioIo::new(client_tls), per_request)
        .await?;
    Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn proxy(
mut req: Request<Incoming>,
state: SharedState,
config: Arc<EvasionConfig>,
default_escalation: Option<String>,
client: reqwest::Client,
mitm_enabled: bool,
mitm_ca: Option<Arc<CertificateAuthority>>,
policy: Arc<UpstreamPolicy>,
limits: Arc<ProxyLimits>,
scope: Arc<ScopeFilter>,
rate_limiter: Arc<RateLimiter>,
expose_wafrift_status: bool,
logger: SharedLogger,
response_profiles: Arc<ResponseProfileDb>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
if req.method() == Method::CONNECT {
if let Some(addr) = host_addr(req.uri()) {
let resolved = match wafrift_proxy::upstream_policy::resolve_connect_target_allowed(
&addr, &policy,
)
.await
{
Ok(v) => v,
Err(msg) => {
warn!("CONNECT rejected: {}", msg);
return Ok(error_response(StatusCode::FORBIDDEN, &msg));
}
};
if let (true, Some(ca)) = (mitm_enabled, mitm_ca.as_ref()) {
let ca = ca.clone();
let state = state.clone();
let config = config.clone();
let default_escalation = default_escalation.clone();
let client = client.clone();
let policy = policy.clone();
let limits = limits.clone();
let scope = scope.clone();
let rate_limiter = rate_limiter.clone();
let response_profiles = response_profiles.clone();
tokio::task::spawn(async move {
match hyper::upgrade::on(req).await {
Ok(upgraded) => {
if let Err(e) = mitm_https_session(
upgraded,
addr,
ca,
state,
config,
default_escalation,
client,
policy,
limits,
scope,
rate_limiter,
response_profiles,
)
.await
{
warn!("mitm session error: {e:?}");
}
}
Err(e) => warn!("upgrade error: {}", e),
}
});
} else {
if let Some(throttle) = WARN_THROTTLE.get()
&& throttle.should_warn(&format!("connect-passthrough:{addr}"))
{
info!(
target = %addr,
"CONNECT pass-through (no MITM): bytes are TLS-encrypted, evasion engine inactive. \
Pass `--mitm` to terminate TLS and apply evasion to HTTPS request bodies."
);
}
let resolved_for_tunnel = resolved.clone();
tokio::task::spawn(async move {
match hyper::upgrade::on(req).await {
Ok(upgraded) => {
if let Err(e) =
tunnel(upgraded, resolved_for_tunnel).await
{
warn!("server io error: {}", e);
};
}
Err(e) => warn!("upgrade error: {}", e),
}
});
}
return Ok(Response::new(Full::new(Bytes::new())));
}
return Ok(error_response(
StatusCode::BAD_REQUEST,
"CONNECT must be to a socket address",
));
}
if req.uri().path() == "/_wafrift/findings.md" {
if !expose_wafrift_status {
return Ok(error_response(StatusCode::NOT_FOUND, "not found"));
}
let st = state.lock().await;
let md = render_live_findings(&st);
return Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/markdown; charset=utf-8")
.body(Full::new(Bytes::from(md)))
.unwrap_or_else(|_| {
error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"failed to build findings response",
)
}));
}
if req.uri().path() == "/_wafrift/status" {
if !expose_wafrift_status {
return Ok(error_response(StatusCode::NOT_FOUND, "not found"));
}
let st = state.lock().await;
let response = serde_json::json!({
"status_schema_version": 1,
"hosts_scanned": st.hosts.len(),
"total_scanned": st.total_scanned,
"total_blocks": st.total_blocks,
"techniques_used": st.techniques_used,
"hosts": st.hosts.iter().map(|(host, hs)| {
serde_json::json!({
"host": host,
"blocks": hs.blocks,
"successes": hs.successes,
"discovery_complete": hs.discovery_complete,
"proven_winners": hs.proven_winners,
"blocklisted": hs.blocklisted,
})
}).collect::<Vec<_>>(),
});
return Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(Full::new(Bytes::from(response.to_string())))
.unwrap_or_else(|_| {
error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"failed to build status response",
)
}));
}
let host = req
.uri()
.host()
.map(std::string::ToString::to_string)
.or_else(|| {
req.headers()
.get(hyper::header::HOST)
.and_then(|h| h.to_str().ok().map(extract_host_from_header))
})
.unwrap_or_else(|| "unknown".to_string());
let limited = Limited::new(req.body_mut(), MAX_PROXY_BODY_BYTES);
let body_bytes = match limited.collect().await {
Ok(b) => b.to_bytes().to_vec(),
Err(_) => {
if let Some(throttle) = WARN_THROTTLE.get()
&& throttle.should_warn(&format!("body-limit:{host}"))
{
warn!(host = %host, limit = MAX_PROXY_BODY_BYTES, "request body exceeded size limit");
}
return Ok(error_response(
StatusCode::PAYLOAD_TOO_LARGE,
"request body too large",
));
}
};
let raw_headers: Vec<(String, String)> = req
.headers()
.iter()
.map(|(k, v)| {
(
k.as_str().to_string(),
header_value_to_string(k.as_str(), v),
)
})
.collect();
let conn = collect_connection_header_names(&raw_headers);
let headers: Vec<(String, String)> = raw_headers
.into_iter()
.filter(|(k, _)| !should_strip_proxy_header(k, &conn))
.collect();
let mut wafrift_req = wafrift_types::Request::with_method(
wafrift_types::Method::from(req.method().as_str()),
req.uri().to_string(),
);
wafrift_req.headers = headers;
if !body_bytes.is_empty() {
wafrift_req.body = Some(body_bytes);
}
let log_uri = req.uri().to_string();
let evade_override = wafrift_req
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case(X_WAFRIFT_EVADE))
.map(|(_, v)| v.to_ascii_lowercase());
wafrift_req
.headers
.retain(|(k, _)| !k.eq_ignore_ascii_case(X_WAFRIFT_EVADE));
rate_limiter.acquire(&host).await;
let path_for_scope = req
.uri()
.path_and_query().map_or_else(|| "/".to_string(), |p| p.path().to_string());
let skip_evasion = evade_override.as_deref() == Some("off");
if skip_evasion || !scope.allows(&host, &path_for_scope, &wafrift_req.method) {
debug!(host = %host, uri = %log_uri, "evasion skipped (off/out-of-scope)");
let resp = forward_passthrough(wafrift_req, host.clone(), &client, policy, limits).await;
if let (Ok(r), Some(log)) = (&resp, &logger) {
log.log_entry(&serde_json::json!({
"ts": time::OffsetDateTime::now_utc()
.format(&time::format_description::well_known::Rfc3339)
.unwrap_or_default(),
"host": host,
"method": req.method().as_str(),
"url": log_uri,
"evaded": false,
"status": r.status().as_u16(),
}))
.await;
}
return resp;
}
let effective_escalation = match evade_override.as_deref() {
Some("light" | "medium" | "heavy") => evade_override,
_ => default_escalation,
};
let resp = forward_with_evade_retry(
wafrift_req,
host.clone(),
log_uri.clone(),
state,
config,
effective_escalation,
&client,
policy,
limits,
response_profiles,
)
.await;
if let (Ok(r), Some(log)) = (&resp, &logger) {
let techniques: Vec<&str> = r
.headers()
.get(X_WAFRIFT_TECHNIQUES)
.and_then(|v| v.to_str().ok())
.map(|s| s.split(", ").collect())
.unwrap_or_default();
let blocked = r
.headers()
.get(X_WAFRIFT_BLOCKED)
.and_then(|v| v.to_str().ok())
== Some("true");
log.log_entry(&serde_json::json!({
"ts": time::OffsetDateTime::now_utc()
.format(&time::format_description::well_known::Rfc3339)
.unwrap_or_default(),
"host": host,
"method": req.method().as_str(),
"url": log_uri,
"evaded": true,
"techniques": techniques,
"status": r.status().as_u16(),
"blocked": blocked,
}))
.await;
}
resp
}
/// Forwards `req` upstream without applying any evasion and relays the reply.
///
/// The target URL is first checked against `policy` (403 on rejection). The
/// request is then sent through the stealth client when `stealth()` provides
/// one, otherwise through `client`. In both cases the `Host` and
/// `Content-Length` headers and any hop-by-hop headers are left out of the
/// outgoing request, hop-by-hop headers are removed from the response, and the
/// response body is limited to `limits.max_upstream_response_bytes`.
///
/// Upstream failures are mapped to HTTP error responses (502 for send/read
/// errors, 413 when the streamed body exceeds the limit, 400 for a method the
/// HTTP client rejects); this function itself always returns `Ok`.
async fn forward_passthrough(
    req: wafrift_types::Request,
    host: String,
    client: &reqwest::Client,
    policy: Arc<UpstreamPolicy>,
    limits: Arc<ProxyLimits>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
    if let Err(msg) = assert_forward_url_allowed(&req.url, &policy).await {
        warn!(host = %host, url = %req.url, "{}", msg);
        return Ok(error_response(StatusCode::FORBIDDEN, &msg));
    }

    let conn_fwd = collect_connection_header_names(&req.headers);
    let max = limits.max_upstream_response_bytes;
    // Request headers that must not be copied onto the upstream request.
    let not_forwarded = |name: &str| {
        name.eq_ignore_ascii_case("host")
            || name.eq_ignore_ascii_case("content-length")
            || should_strip_proxy_header(name, &conn_fwd)
    };

    let (response_builder, buf) = if let Some(sc) = stealth() {
        let outbound_headers: Vec<(String, String)> = req
            .headers
            .iter()
            .filter(|(k, _)| !not_forwarded(k.as_str()))
            .cloned()
            .collect();
        let sent = sc
            .send(
                req.method.as_str(),
                &req.url,
                &outbound_headers,
                req.body.as_deref(),
                max,
            )
            .await;
        let stealth_resp = match sent {
            Ok(r) => r,
            Err(e) => {
                if let Some(throttle) = WARN_THROTTLE.get()
                    && throttle.should_warn(&format!("passthrough:{host}"))
                {
                    warn!(host = %host, error = %e, stack = "stealth", "passthrough forwarding failed");
                }
                return Ok(error_response(StatusCode::BAD_GATEWAY, "forwarding error"));
            }
        };

        // Tokens named by the response's Connection header are hop-by-hop too.
        let conn_resp: std::collections::HashSet<String> = stealth_resp
            .headers
            .iter()
            .find(|(k, _)| k.eq_ignore_ascii_case("connection"))
            .map_or_else(std::collections::HashSet::new, |(_, v)| {
                v.split(',')
                    .map(|token| token.trim().to_ascii_lowercase())
                    .filter(|token| !token.is_empty())
                    .collect()
            });
        let mut relayed = Response::builder().status(stealth_resp.status);
        for (k, v) in &stealth_resp.headers {
            if !should_strip_proxy_header(k, &conn_resp) {
                relayed = relayed.header(k.as_str(), v.as_str());
            }
        }
        (relayed, stealth_resp.body.clone())
    } else {
        let Ok(method) = reqwest::Method::from_bytes(req.method.as_str().as_bytes()) else {
            return Ok(error_response(
                StatusCode::BAD_REQUEST,
                "invalid HTTP method",
            ));
        };

        let mut outbound = client.request(method, &req.url);
        for (k, v) in req
            .headers
            .iter()
            .filter(|(k, _)| !not_forwarded(k.as_str()))
        {
            outbound = outbound.header(k.as_str(), v.as_str());
        }
        if let Some(payload) = req.body {
            outbound = outbound.body(payload);
        }

        let resp = match outbound.send().await {
            Ok(r) => r,
            Err(e) => {
                if let Some(throttle) = WARN_THROTTLE.get()
                    && throttle.should_warn(&format!("passthrough:{host}"))
                {
                    warn!(host = %host, error = %e, "passthrough forwarding failed");
                }
                return Ok(error_response(StatusCode::BAD_GATEWAY, "forwarding error"));
            }
        };

        let conn_resp = collect_connection_header_names_hyper(resp.headers());
        let relayed = resp
            .headers()
            .iter()
            .filter(|(name, _)| !should_strip_proxy_header(name.as_str(), &conn_resp))
            .fold(
                Response::builder().status(resp.status().as_u16()),
                |builder, (name, value)| builder.header(name, value),
            );

        // Stream the body so the size cap is enforced before buffering a chunk.
        let mut chunks = resp.bytes_stream();
        let mut body: Vec<u8> = Vec::new();
        while let Some(piece) = chunks.next().await {
            match piece {
                Err(e) => {
                    warn!(host = %host, error = %e, "upstream body read failed");
                    return Ok(error_response(
                        StatusCode::BAD_GATEWAY,
                        "upstream read error",
                    ));
                }
                Ok(chunk) if body.len().saturating_add(chunk.len()) > max => {
                    return Ok(error_response(
                        StatusCode::PAYLOAD_TOO_LARGE,
                        "upstream response too large",
                    ));
                }
                Ok(chunk) => body.extend_from_slice(&chunk),
            }
        }
        (relayed, body)
    };

    Ok(response_builder
        .body(Full::new(Bytes::from(buf)))
        .unwrap_or_else(|_| {
            error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "failed to build response",
            )
        }))
}
/// Renders the current proxy state as a Markdown findings report.
///
/// The report always starts with a title and the aggregate counters. It then
/// contains either a "nothing proxied yet" note, a "no bypasses yet" note, or
/// one section per host that has at least one proven winner. Hosts are listed
/// in ascending name order, and every host, WAF and technique name is passed
/// through `sanitize_for_markdown` before being embedded.
fn render_live_findings(state: &ProxyState) -> String {
    use std::fmt::Write as _;

    // Formatting into a `String` cannot fail, so the `fmt::Result`s returned
    // by `write!` below are deliberately discarded.
    let mut report = String::from("# wafrift live findings\n\n");
    let _ = write!(
        report,
        "Total proxied: {} · Total WAF blocks observed: {} · Hosts seen: {}\n\n",
        state.total_scanned,
        state.total_blocks,
        state.hosts.len(),
    );

    if state.total_scanned == 0 {
        report.push_str("No requests have been proxied yet. Send traffic through the proxy to begin evasion discovery.\n");
        return report;
    }

    let mut bypassed_hosts: Vec<(&String, &HostState)> = state
        .hosts
        .iter()
        .filter(|(_, hs)| !hs.proven_winners.is_empty())
        .collect();
    if bypassed_hosts.is_empty() {
        report.push_str("_No bypasses discovered yet — keep traffic flowing through the proxy. Blocks are being recorded and will inform technique selection._\n");
        return report;
    }
    bypassed_hosts.sort_by(|(left, _), (right, _)| left.cmp(right));

    report.push_str("## Hosts with proven bypasses\n\n");
    for (host, hs) in bypassed_hosts {
        let host_md = sanitize_for_markdown(host);
        let _ = write!(report, "### `{host_md}`\n\n");
        if let Some(waf) = hs.waf_name.as_deref().map(sanitize_for_markdown) {
            let _ = write!(report, "**Identified WAF:** {waf}\n\n");
        }
        report.push_str("**Working techniques:**\n\n");
        for technique in &hs.proven_winners {
            let _ = write!(report, "- `{}`\n", sanitize_for_markdown(technique));
        }
        report.push('\n');
        let _ = write!(
            report,
            "**Reproduce:** `wafrift replay --target 'https://{host_md}/<PATH>' --param q --payload '<PAYLOAD>' --from-host '{host_md}'`\n\n",
        );
    }
    report
}
/// Returns `s` with every character outside a small allow-list replaced by `_`.
///
/// ASCII letters and digits are kept, as are `.`, `-`, `_`, `:`, `/`, `+` and
/// `@`. Everything else — including whitespace, backticks, quotes and all
/// non-ASCII characters — becomes a single underscore, so the result has the
/// same number of characters as the input.
fn sanitize_for_markdown(s: &str) -> String {
    const KEPT_PUNCTUATION: &str = ".-_:/+@";

    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        let keep = ch.is_ascii_alphanumeric() || KEPT_PUNCTUATION.contains(ch);
        cleaned.push(if keep { ch } else { '_' });
    }
    cleaned
}
/// Returns the authority component of `uri` as an owned string, or `None`
/// when the URI has no authority.
fn host_addr(uri: &hyper::Uri) -> Option<String> {
    let authority = uri.authority()?;
    Some(authority.to_string())
}
/// Maximum number of bytes relayed in each direction of one CONNECT tunnel
/// (2 GiB). Exceeding it aborts the tunnel with an I/O error.
const MAX_TUNNEL_BYTES_PER_DIRECTION: u64 = 2 * 1024 * 1024 * 1024;

/// Relays raw bytes between an upgraded client connection and an upstream
/// TCP connection opened to one of `addrs`.
///
/// Both directions are copied concurrently through 16 KiB buffers, each with
/// its own byte counter checked against [`MAX_TUNNEL_BYTES_PER_DIRECTION`].
/// When one side reaches end-of-stream, the write half of the opposite side is
/// shut down so the peer observes the half-close and can finish. Without that
/// step the other direction would keep the tunnel and both sockets alive until
/// the remote end timed out on its own.
///
/// # Errors
/// Returns an error when the upstream connection cannot be established, when
/// a read or write fails, or when either direction exceeds the byte cap.
async fn tunnel(
    upgraded: Upgraded,
    addrs: Vec<std::net::SocketAddr>,
) -> std::io::Result<()> {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    let mut server = TcpStream::connect(addrs.as_slice()).await?;
    let mut upgraded = TokioIo::new(upgraded);
    let (mut up_r, mut up_w) = tokio::io::split(&mut upgraded);
    let (mut sv_r, mut sv_w) = server.split();

    let to_server = async {
        let mut buf = vec![0u8; 16 * 1024];
        let mut total: u64 = 0;
        loop {
            let n = up_r.read(&mut buf).await?;
            if n == 0 {
                break;
            }
            total = total.saturating_add(n as u64);
            if total > MAX_TUNNEL_BYTES_PER_DIRECTION {
                return Err(std::io::Error::other(
                    "tunnel exceeded byte cap (client→server)",
                ));
            }
            sv_w.write_all(&buf[..n]).await?;
        }
        // Client finished sending: propagate the half-close upstream.
        // Best-effort — a failure here means the upstream is already gone.
        let _ = sv_w.shutdown().await;
        Ok::<(), std::io::Error>(())
    };

    let to_client = async {
        let mut buf = vec![0u8; 16 * 1024];
        let mut total: u64 = 0;
        loop {
            let n = sv_r.read(&mut buf).await?;
            if n == 0 {
                break;
            }
            total = total.saturating_add(n as u64);
            if total > MAX_TUNNEL_BYTES_PER_DIRECTION {
                return Err(std::io::Error::other(
                    "tunnel exceeded byte cap (server→client)",
                ));
            }
            up_w.write_all(&buf[..n]).await?;
        }
        // Upstream finished sending: propagate the half-close to the client.
        // Best-effort — a failure here means the client is already gone.
        let _ = up_w.shutdown().await;
        Ok::<(), std::io::Error>(())
    };

    // The first error aborts the tunnel; otherwise wait for both directions.
    tokio::try_join!(to_server, to_client)?;
    Ok(())
}
// Unit tests for this file are kept in `proxy_tests.rs` and are compiled only
// for test builds.
#[cfg(test)]
#[path = "proxy_tests.rs"]
mod tests;