use crate::config::Config;
use crate::db::Db;
use anyhow::{Context, Result};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use tracing::warn;
/// `reqwest_middleware` middleware that spreads outbound requests across a
/// fixed pool of proxy-bound clients and fails over between them.
struct RoundRobinProxyMiddleware {
    /// Round-robin cursor: every handled request advances it by one and
    /// starts at `counter % clients.len()`.
    counter: Arc<AtomicUsize>,
    /// One client per configured proxy, in configuration order.
    clients: Vec<reqwest::Client>,
    /// Database file path and project id used to record `proxy_failover`
    /// events; `None` disables event recording.
    event_target: Option<(std::path::PathBuf, String)>,
}
#[async_trait::async_trait]
impl reqwest_middleware::Middleware for RoundRobinProxyMiddleware {
async fn handle(
&self,
req: reqwest::Request,
_extensions: &mut http::Extensions,
_next: reqwest_middleware::Next<'_>,
) -> reqwest_middleware::Result<reqwest::Response> {
let n = self.clients.len();
debug_assert!(n > 0, "proxy middleware installed with no clients");
let start = self.counter.fetch_add(1, Ordering::Relaxed) % n;
if req.try_clone().is_none() {
return self.clients[start]
.execute(req)
.await
.map_err(reqwest_middleware::Error::Reqwest);
}
let mut last_err: Option<reqwest::Error> = None;
for attempt in 0..n {
let idx = (start + attempt) % n;
let req_clone = req
.try_clone()
.expect("request body became non-cloneable mid-loop (impossible)");
match self.clients[idx].execute(req_clone).await {
Ok(resp) => {
if attempt > 0 {
warn!(
proxy_index = idx,
attempts = attempt + 1,
"request succeeded after proxy failover"
);
self.emit_failover_event(idx, attempt, None);
}
return Ok(resp);
}
Err(e) if is_transient_proxy_error(&e) => {
warn!(
proxy_index = idx,
error = %e,
attempt = attempt + 1,
total_proxies = n,
"proxy attempt failed; trying next proxy"
);
last_err = Some(e);
continue;
}
Err(e) => {
return Err(reqwest_middleware::Error::Reqwest(e));
}
}
}
warn!(
total_proxies = n,
"all proxies failed for request; returning last transient error"
);
let last = last_err
.expect("loop ran at least once and last_err is set on every transient failure");
self.emit_failover_event(n - 1, n, Some(&last.to_string()));
Err(reqwest_middleware::Error::Reqwest(last))
}
}
impl RoundRobinProxyMiddleware {
fn emit_failover_event(&self, failed_proxy_index: usize, attempt: usize, error: Option<&str>) {
let Some((ref db_path, ref pid)) = self.event_target else {
return;
};
let db = Db::new_file(db_path.clone());
let payload = serde_json::json!({
"failed_proxy_index": failed_proxy_index,
"attempt": attempt,
"error": error.unwrap_or("failover — subsequent proxy succeeded"),
});
if let Err(e) = db.add_event(Some(pid), None, "proxy_failover", payload) {
warn!(error = %e, "failed to write proxy_failover event to db");
}
}
}
/// Returns `true` when `err` looks like a network-level failure that a
/// different proxy in the pool might not hit, so retrying there is useful.
///
/// Transient errors:
/// * connect failures, including an unreachable or refusing proxy;
/// * timeouts;
/// * other errors raised while sending the request.
///
/// Non-transient errors would recur identically on every proxy:
/// * request-builder errors, such as an invalid URL;
/// * redirect-policy violations;
/// * HTTP status errors.
///
/// The previous `status().is_none()` test was true for virtually every error
/// returned by `execute`, which made this predicate almost always `true`.
fn is_transient_proxy_error(err: &reqwest::Error) -> bool {
    if err.is_builder() || err.is_redirect() || err.is_status() {
        return false;
    }
    err.is_connect() || err.is_timeout() || err.is_request()
}
/// Builds the shared outbound HTTP client.
///
/// With no proxies configured, this is a plain [`reqwest::Client`] wrapped in
/// an empty middleware stack. Otherwise, one client is built per entry in
/// `config.core.proxies`, and [`RoundRobinProxyMiddleware`] distributes
/// requests across them round-robin, failing over to the next proxy on
/// transient errors.
///
/// When both `db` and `project_id` are provided, failovers are recorded as
/// `proxy_failover` events in that database's file. If either is `None`, no
/// events are written.
///
/// # Errors
///
/// Returns an error if any proxy URL is invalid or a per-proxy client fails
/// to build. The error context identifies the offending proxy by its index.
pub fn build_client(
    config: &Config,
    db: Option<&Db>,
    project_id: Option<&str>,
) -> Result<ClientWithMiddleware> {
    // Bound the TCP connect to each proxy so a black-holed proxy is abandoned
    // promptly. Otherwise failover to the next proxy waits for the OS-level
    // connect timeout, which is often minutes.
    const PROXY_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    let proxies = &config.core.proxies;
    if proxies.is_empty() {
        return Ok(ClientBuilder::new(reqwest::Client::new()).build());
    }
    tracing::info!(
        count = proxies.len(),
        "outbound HTTP client: enabling round-robin proxy pool"
    );
    let proxy_clients = proxies
        .iter()
        .enumerate()
        .map(|(i, url)| {
            let proxy = reqwest::Proxy::all(url)
                .with_context(|| format!("invalid proxy URL at index {i}"))?;
            reqwest::Client::builder()
                .proxy(proxy)
                .connect_timeout(PROXY_CONNECT_TIMEOUT)
                .build()
                .with_context(|| format!("failed to build client for proxy at index {i}"))
        })
        .collect::<Result<Vec<_>>>()?;
    // Events are only recorded when both a database and a project id are known.
    let event_target: Option<(std::path::PathBuf, String)> = db
        .zip(project_id)
        .map(|(d, pid)| (d.path.clone(), pid.to_owned()));
    let middleware = RoundRobinProxyMiddleware {
        clients: proxy_clients,
        counter: Arc::new(AtomicUsize::new(0)),
        event_target,
    };
    // The base client is never used for proxied requests: the middleware
    // executes each request through one of `proxy_clients` itself.
    Ok(ClientBuilder::new(reqwest::Client::new())
        .with(middleware)
        .build())
}
/// Builds a plain [`reqwest::Client`] with no middleware and no failover.
///
/// Only the first entry of `config.core.proxies` is honoured; any further
/// proxies are ignored. With no proxies configured, the client connects
/// directly.
///
/// # Errors
///
/// Fails if the first proxy URL is invalid or the client cannot be built.
pub fn build_reqwest_client(config: &Config) -> Result<reqwest::Client> {
    let first_proxy = config
        .core
        .proxies
        .first()
        .map(|url| reqwest::Proxy::all(url).context("invalid proxy URL at index 0"))
        .transpose()?;
    let builder = reqwest::Client::builder();
    let builder = match first_proxy {
        Some(proxy) => builder.proxy(proxy),
        None => builder,
    };
    builder.build().context("failed to build reqwest client")
}
// Tests for client construction, transient-error classification, and proxy
// failover. "Dead" proxies use TEST-NET-2 (198.51.100.0/24) addresses, which
// the tests expect to be unreachable.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    // With no proxies configured, build_client takes the plain-client path.
    #[test]
    fn build_client_no_proxies_succeeds() {
        let config = Config::default();
        assert!(config.core.proxies.is_empty());
        let client = build_client(&config, None, None).expect("build_client with no proxies");
        drop(client);
    }

    #[test]
    fn build_reqwest_client_no_proxies_succeeds() {
        let config = Config::default();
        let client = build_reqwest_client(&config).expect("build_reqwest_client with no proxies");
        drop(client);
    }

    // Both an HTTP and a SOCKS5 proxy URL must be accepted by reqwest::Proxy::all.
    #[test]
    fn build_client_with_valid_proxies_succeeds() {
        let mut config = Config::default();
        config.core.proxies = vec![
            "http://proxy1.example.com:8080".to_string(),
            "socks5://proxy2.example.com:1080".to_string(),
        ];
        let client = build_client(&config, None, None).expect("build_client with valid proxy URLs");
        drop(client);
    }

    // Only the transient (connect/timeout) side of the classifier is exercised
    // here; the 500 ms client timeout keeps the test fast even if the
    // unroutable address black-holes packets instead of refusing them.
    #[test]
    fn is_transient_proxy_error_classifies_correctly() {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_millis(500))
            .build()
            .unwrap();
        // A plain #[test] has no runtime, so build a single-threaded one here.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let err = rt.block_on(async {
            client
                .get("http://198.51.100.1:1")
                .send()
                .await
                .expect_err("connection to TEST-NET-2 must fail")
        });
        assert!(
            is_transient_proxy_error(&err),
            "connect error to unroutable host must be classified as transient: {err}"
        );
    }

    // Two dead proxies followed by one live one. A freshly built client starts
    // its round-robin counter at 0, so the first request tries the dead proxies
    // first and must fail over twice before succeeding.
    #[tokio::test]
    async fn round_robin_failover_skips_dead_proxies() {
        use axum::{Router, response::IntoResponse, routing::get};
        use std::net::SocketAddr;
        use tokio::net::TcpListener;
        async fn ok_handler() -> impl IntoResponse {
            "ok-from-live-proxy"
        }
        let app = Router::new().route("/", get(ok_handler));
        // Port 0: let the OS pick a free port for the live server.
        let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
            .await
            .unwrap();
        let live_addr = listener.local_addr().unwrap();
        let server_handle = tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        let mut config = Config::default();
        // The live server doubles as the third "proxy". Presumably it answers
        // the proxied plain-HTTP request itself by routing on its path `/`.
        config.core.proxies = vec![
            "http://198.51.100.1:1".to_string(),
            "http://198.51.100.2:2".to_string(),
            format!("http://{live_addr}"),
        ];
        let client = build_client(&config, None, None).expect("build client with mixed proxies");
        let resp = client
            .get(format!("http://{live_addr}/"))
            .send()
            .await
            .expect("request must succeed via 3rd proxy after 2 failover steps");
        let status = resp.status();
        let body = resp.text().await.unwrap();
        assert!(status.is_success(), "got {status}: {body}");
        assert_eq!(body, "ok-from-live-proxy");
        server_handle.abort();
    }

    // One dead proxy then one live one, with a DB and project id supplied, so
    // the successful failover must leave a `proxy_failover` event behind.
    #[tokio::test]
    async fn failover_writes_proxy_failover_event_to_db() {
        use axum::{Router, response::IntoResponse, routing::get};
        use std::net::SocketAddr;
        use tempfile::tempdir;
        use tokio::net::TcpListener;
        async fn ok_handler() -> impl IntoResponse {
            "ok"
        }
        let app = Router::new().route("/", get(ok_handler));
        let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
            .await
            .unwrap();
        let live_addr = listener.local_addr().unwrap();
        let server_handle = tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        // Temporary database; `tmp` must stay alive until the assertions run.
        let tmp = tempdir().unwrap();
        let db = crate::db::Db::new(tmp.path());
        db.ensure_schema().unwrap();
        let mut config = Config::default();
        config.core.proxies = vec![
            "http://198.51.100.1:1".to_string(),
            format!("http://{live_addr}"),
        ];
        let project_id = "test-proj-failover";
        let client =
            build_client(&config, Some(&db), Some(project_id)).expect("build client for test");
        let resp = client
            .get(format!("http://{live_addr}/"))
            .send()
            .await
            .expect("request must succeed via 2nd proxy");
        assert!(resp.status().is_success());
        // Only the event's presence is checked, not its payload fields.
        let failover_event = db
            .most_recent_event_of_type(project_id, "proxy_failover")
            .unwrap();
        assert!(
            failover_event.is_some(),
            "expected a proxy_failover event in db after failover"
        );
        server_handle.abort();
    }
}