use bytes::Bytes;
#[cfg(feature = "tls-impersonate")]
use std::time::Duration;
use thiserror::Error;
use wafrift_transport::stealth::ImpersonateProfile;
#[cfg(feature = "tls-impersonate")]
use wafrift_transport::stealth::StealthClient;
/// Errors produced while building an [`UpstreamClient`] or forwarding a request through it.
#[derive(Debug, Error)]
pub enum UpstreamError {
    /// Catch-all failure carrying a stringified cause: client construction, request /
    /// transport errors, an unusable status code, or an invalid stealth pool.
    #[error("upstream request failed: {0}")]
    Request(String),
    /// The method string could not be parsed as an HTTP method.
    #[error("invalid HTTP method: {0}")]
    InvalidMethod(String),
    /// The upstream response exceeded the byte cap.
    ///
    /// NOTE(review): not constructed anywhere in this file (the reqwest path truncates
    /// silently instead) — confirm whether other callers produce it.
    #[error("upstream response too large (cap {cap}): truncated at {got} bytes")]
    BodyTooLarge { got: usize, cap: usize },
    /// A stealth constructor was called in a build without the `tls-impersonate` feature.
    ///
    /// The advice must name *this* crate's feature: `cfg(feature = "tls-impersonate")` in
    /// this crate is only true when wafrift-proxy's own feature is enabled. Enabling only
    /// `wafrift-transport/tls-impersonate` would leave this error in place.
    #[error(
        "stealth mode requires the `tls-impersonate` cargo feature; \
         rebuild wafrift-proxy with `cargo build -p wafrift-proxy \
         --features tls-impersonate`"
    )]
    StealthFeatureDisabled,
}
/// Response returned by `UpstreamClient::send`: status, headers and body.
///
/// On the reqwest path the body holds at most the caller's `max_body` bytes (extra
/// bytes are discarded, not reported as an error). On the stealth path the cap is
/// forwarded to the stealth client.
#[derive(Debug)]
pub struct UpstreamResponse {
    /// Status code reported by the upstream server.
    pub status: http::StatusCode,
    /// Response headers. On the stealth path, entries whose name or value is not a
    /// valid HTTP header are dropped during conversion.
    pub headers: http::HeaderMap,
    /// Response body, possibly truncated to the requested cap.
    pub body: Bytes,
}
/// Client used to forward requests upstream, with one variant per TLS stack.
///
/// `Clone` is derived. The stealth variants hold their clients (and the pool cursor)
/// behind `Arc`, so clones of those variants share the same underlying state.
#[derive(Clone)]
pub enum UpstreamClient {
    /// Plain `reqwest` client; `tls_stack_name` reports it as "rustls (default)".
    Reqwest(reqwest::Client),
    /// A single stealth (BoringSSL) client; only exists with `tls-impersonate`.
    #[cfg(feature = "tls-impersonate")]
    Stealth(std::sync::Arc<StealthClient>),
    /// Several stealth clients used in rotation; only exists with `tls-impersonate`.
    #[cfg(feature = "tls-impersonate")]
    StealthPool {
        /// Clients to rotate through; `stealth_pool` never builds this empty.
        clients: std::sync::Arc<Vec<std::sync::Arc<StealthClient>>>,
        /// Shared request counter; `send` picks `clients[cursor % len]` and increments it.
        cursor: std::sync::Arc<std::sync::atomic::AtomicUsize>,
    },
}
impl UpstreamClient {
    /// Per-request timeout given to every stealth client built by this type.
    #[cfg(feature = "tls-impersonate")]
    const STEALTH_TIMEOUT: Duration = Duration::from_secs(60);

    /// Wraps an existing `reqwest::Client`.
    #[must_use]
    pub fn from_reqwest(client: reqwest::Client) -> Self {
        Self::Reqwest(client)
    }

    /// Builds a single stealth client configured with the given impersonation profile.
    ///
    /// # Errors
    ///
    /// - [`UpstreamError::StealthFeatureDisabled`] when built without `tls-impersonate`.
    /// - [`UpstreamError::Request`] when the underlying `StealthClient` cannot be created.
    pub fn stealth(_profile: ImpersonateProfile) -> Result<Self, UpstreamError> {
        #[cfg(feature = "tls-impersonate")]
        {
            let client = StealthClient::with_timeout(_profile, Self::STEALTH_TIMEOUT)
                .map_err(|e| UpstreamError::Request(e.to_string()))?;
            Ok(Self::Stealth(std::sync::Arc::new(client)))
        }
        #[cfg(not(feature = "tls-impersonate"))]
        {
            Err(UpstreamError::StealthFeatureDisabled)
        }
    }

    /// Builds a pool with one stealth client per profile. `send` rotates through the
    /// clients round-robin.
    ///
    /// # Errors
    ///
    /// - [`UpstreamError::StealthFeatureDisabled`] when built without `tls-impersonate`.
    /// - [`UpstreamError::Request`] when `_profiles` is empty, or when any client fails
    ///   to build. The message is prefixed with the failing profile's name.
    pub fn stealth_pool(_profiles: &[ImpersonateProfile]) -> Result<Self, UpstreamError> {
        #[cfg(feature = "tls-impersonate")]
        {
            if _profiles.is_empty() {
                return Err(UpstreamError::Request(
                    "stealth_pool requires at least one profile".into(),
                ));
            }
            // Collecting into `Result` stops at the first profile that fails to build.
            let clients = _profiles
                .iter()
                .map(|&p| {
                    StealthClient::with_timeout(p, Self::STEALTH_TIMEOUT)
                        .map(std::sync::Arc::new)
                        .map_err(|e| UpstreamError::Request(format!("{}: {e}", p.name())))
                })
                .collect::<Result<Vec<_>, _>>()?;
            Ok(Self::StealthPool {
                clients: std::sync::Arc::new(clients),
                cursor: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            })
        }
        #[cfg(not(feature = "tls-impersonate"))]
        {
            Err(UpstreamError::StealthFeatureDisabled)
        }
    }

    /// Sends one request upstream and returns its status, headers and (capped) body.
    ///
    /// `max_body` bounds the number of body bytes kept. The reqwest path drops anything
    /// beyond it. The stealth paths forward it to `StealthClient::send`.
    ///
    /// # Errors
    ///
    /// - [`UpstreamError::InvalidMethod`] when `method` is not a valid HTTP method
    ///   (reqwest path).
    /// - [`UpstreamError::Request`] for send / stream / status-conversion failures, or
    ///   for a `StealthPool` that holds no clients.
    pub async fn send(
        &self,
        method: &str,
        url: &str,
        headers: &[(String, String)],
        body: Option<Vec<u8>>,
        max_body: usize,
    ) -> Result<UpstreamResponse, UpstreamError> {
        match self {
            Self::Reqwest(client) => {
                Self::send_via_reqwest(client, method, url, headers, body, max_body).await
            }
            #[cfg(feature = "tls-impersonate")]
            Self::Stealth(client) => {
                Self::send_via_stealth(client, method, url, headers, body, max_body).await
            }
            #[cfg(feature = "tls-impersonate")]
            Self::StealthPool { clients, cursor } => {
                // Relaxed is enough: the counter only spreads requests across clients and
                // orders no other memory. `checked_rem` guards against a pool built
                // directly through the public variant with no clients (`% 0` would panic).
                let idx = cursor
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                    .checked_rem(clients.len())
                    .ok_or_else(|| {
                        UpstreamError::Request("stealth pool has no clients".into())
                    })?;
                Self::send_via_stealth(&clients[idx], method, url, headers, body, max_body)
                    .await
            }
        }
    }

    /// Sends through `reqwest` and streams the body, keeping at most `max_body` bytes.
    async fn send_via_reqwest(
        client: &reqwest::Client,
        method: &str,
        url: &str,
        headers: &[(String, String)],
        body: Option<Vec<u8>>,
        max_body: usize,
    ) -> Result<UpstreamResponse, UpstreamError> {
        use futures_util::StreamExt;

        let m = reqwest::Method::from_bytes(method.as_bytes())
            .map_err(|_| UpstreamError::InvalidMethod(method.to_string()))?;
        let mut req = client.request(m, url);
        for (k, v) in headers {
            req = req.header(k.as_str(), v.as_str());
        }
        if let Some(b) = body {
            req = req.body(b);
        }
        let resp = req
            .send()
            .await
            .map_err(|e| UpstreamError::Request(e.to_string()))?;
        // Round-trip through the numeric code so this works even if reqwest's `http`
        // version differs from ours.
        let status = http::StatusCode::from_u16(resp.status().as_u16())
            .map_err(|e| UpstreamError::Request(e.to_string()))?;
        let headers = resp.headers().clone();

        let mut buf = Vec::new();
        let mut stream = resp.bytes_stream();
        // Stop as soon as the cap is filled. This way we never await (or fail on) a
        // chunk whose bytes would be discarded anyway.
        while buf.len() < max_body {
            let chunk = match stream.next().await {
                Some(chunk) => chunk.map_err(|e| UpstreamError::Request(e.to_string()))?,
                None => break,
            };
            let take = chunk.len().min(max_body - buf.len());
            buf.extend_from_slice(&chunk[..take]);
        }
        Ok(UpstreamResponse {
            status,
            headers,
            body: Bytes::from(buf),
        })
    }

    /// Sends through a stealth client and converts its response into `http` types.
    ///
    /// Header pairs that are not a valid HTTP name/value are skipped (best effort)
    /// rather than failing the whole response.
    #[cfg(feature = "tls-impersonate")]
    async fn send_via_stealth(
        client: &StealthClient,
        method: &str,
        url: &str,
        headers: &[(String, String)],
        body: Option<Vec<u8>>,
        max_body: usize,
    ) -> Result<UpstreamResponse, UpstreamError> {
        let stealth_resp = client
            .send(method, url, headers, body.as_deref(), max_body)
            .await
            .map_err(|e| UpstreamError::Request(e.to_string()))?;
        let status = http::StatusCode::from_u16(stealth_resp.status)
            .map_err(|e| UpstreamError::Request(e.to_string()))?;
        let mut header_map = http::HeaderMap::with_capacity(stealth_resp.headers.len());
        for (k, v) in &stealth_resp.headers {
            if let (Ok(name), Ok(val)) = (
                http::HeaderName::from_bytes(k.as_bytes()),
                http::HeaderValue::from_bytes(v.as_bytes()),
            ) {
                // `append` keeps repeated headers (e.g. multiple Set-Cookie) instead of
                // overwriting them.
                header_map.append(name, val);
            }
        }
        Ok(UpstreamResponse {
            status,
            headers: header_map,
            body: Bytes::from(stealth_resp.body),
        })
    }

    /// Human-readable name of the TLS stack behind this client, for logs and diagnostics.
    #[must_use]
    pub fn tls_stack_name(&self) -> &'static str {
        match self {
            Self::Reqwest(_) => "rustls (default)",
            #[cfg(feature = "tls-impersonate")]
            Self::Stealth(_) => "boringssl (stealth)",
            #[cfg(feature = "tls-impersonate")]
            Self::StealthPool { .. } => "boringssl (stealth pool, rotating)",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every needle occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(haystack.contains(needle), "{haystack:?} lacks {needle:?}");
        }
    }

    #[test]
    fn from_reqwest_wraps_client() {
        let upstream = UpstreamClient::from_reqwest(reqwest::Client::new());
        assert_eq!(upstream.tls_stack_name(), "rustls (default)");
    }

    #[test]
    fn upstream_error_messages_are_actionable() {
        let invalid = UpstreamError::InvalidMethod("FUBAR".into()).to_string();
        assert_contains_all(&invalid, &["FUBAR"]);

        let too_large = UpstreamError::BodyTooLarge {
            got: 5_000_000,
            cap: 1_000_000,
        }
        .to_string();
        assert_contains_all(&too_large, &["5000000", "1000000"]);

        let msg = UpstreamError::StealthFeatureDisabled.to_string();
        assert!(
            msg.contains("tls-impersonate") && msg.contains("cargo build"),
            "feature-disabled error must name the cargo flag, got: {msg}"
        );
    }

    #[cfg(not(feature = "tls-impersonate"))]
    #[test]
    fn stealth_constructor_errors_when_feature_off() {
        match UpstreamClient::stealth(ImpersonateProfile::Chrome131) {
            Err(UpstreamError::StealthFeatureDisabled) => {}
            Err(other) => panic!("expected StealthFeatureDisabled, got {other}"),
            Ok(_) => panic!("expected error, got Ok variant"),
        }
    }

    #[cfg(feature = "tls-impersonate")]
    #[test]
    fn stealth_constructor_builds_when_feature_on() {
        let upstream = UpstreamClient::stealth(ImpersonateProfile::Chrome131).unwrap();
        assert_eq!(upstream.tls_stack_name(), "boringssl (stealth)");
    }

    #[cfg(feature = "tls-impersonate")]
    #[test]
    fn stealth_pool_rotates_round_robin() {
        use std::sync::atomic::Ordering;

        let profiles = [
            ImpersonateProfile::Chrome131,
            ImpersonateProfile::Firefox133,
            ImpersonateProfile::Safari18,
        ];
        let pool = UpstreamClient::stealth_pool(&profiles).unwrap();
        assert_eq!(pool.tls_stack_name(), "boringssl (stealth pool, rotating)");
        if let UpstreamClient::StealthPool { clients, cursor } = &pool {
            assert_eq!(clients.len(), 3);
            // Draw four indices from a fresh cursor; the fourth wraps back to 0.
            let picks: Vec<usize> = (0..4)
                .map(|_| cursor.fetch_add(1, Ordering::Relaxed) % clients.len())
                .collect();
            assert_eq!(picks, [0, 1, 2, 0]);
        } else {
            panic!("expected StealthPool variant");
        }
    }

    #[cfg(feature = "tls-impersonate")]
    #[test]
    fn stealth_pool_rejects_empty_profiles() {
        // A `match` instead of `unwrap_err()`: `unwrap_err` requires the Ok type
        // (`UpstreamClient`) to implement `Debug`, which it does not derive.
        match UpstreamClient::stealth_pool(&[]) {
            Err(UpstreamError::Request(msg)) => assert!(msg.contains("at least one")),
            Err(other) => panic!("expected Request error, got {other:?}"),
            Ok(_) => panic!("expected error for empty profile list, got Ok variant"),
        }
    }

    #[cfg(not(feature = "tls-impersonate"))]
    #[test]
    fn stealth_pool_errors_when_feature_off() {
        match UpstreamClient::stealth_pool(&[ImpersonateProfile::Chrome131]) {
            Err(UpstreamError::StealthFeatureDisabled) => {}
            Err(other) => panic!("expected StealthFeatureDisabled, got {other}"),
            Ok(_) => panic!("expected error, got Ok variant"),
        }
    }
}