use std::net::SocketAddr;
use std::sync::Arc;
use arbiter_audit::RedactionConfig;
use arbiter_metrics::ArbiterMetrics;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode};
use hyper_util::client::legacy::Client;
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::net::TcpListener;
use arbiter_proxy::config::MiddlewareConfig;
use arbiter_proxy::middleware::MiddlewareChain;
use arbiter_proxy::proxy::{ProxyState, handle_request};
/// Binds a TCP listener to an OS-assigned port (port `0`) on `127.0.0.1`.
///
/// Returns the listener together with the concrete address it was bound to,
/// so callers can build URLs that point at it.
///
/// # Panics
///
/// Panics if binding fails or the bound address cannot be queried.
async fn ephemeral_listener() -> (TcpListener, SocketAddr) {
    let bound = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let local = bound.local_addr().unwrap();
    (bound, local)
}
async fn spawn_upstream() -> SocketAddr {
let (listener, addr) = ephemeral_listener().await;
tokio::spawn(async move {
loop {
let (stream, _) = match listener.accept().await {
Ok(v) => v,
Err(_) => break,
};
tokio::spawn(async move {
let io = TokioIo::new(stream);
let _ = http1::Builder::new()
.serve_connection(
io,
service_fn(|req: Request<hyper::body::Incoming>| async move {
let path = req.uri().path().to_string();
let body = format!("upstream saw {path}");
Ok::<_, hyper::Error>(Response::new(Full::new(Bytes::from(body))))
}),
)
.await;
});
}
});
addr
}
/// Starts the proxy under test on an ephemeral loopback port.
///
/// The proxy targets `http://{upstream_addr}` and uses a [`MiddlewareChain`]
/// built from `mw_config`. Each call creates a fresh [`ArbiterMetrics`]
/// instance. One handle to it goes into the proxy state; the other is
/// returned so tests can read counters directly.
///
/// Returns `(proxy_address, metrics_handle)`.
///
/// # Panics
///
/// Panics if binding the listener fails or `ArbiterMetrics::new` returns an error.
async fn spawn_proxy(
    upstream_addr: SocketAddr,
    mw_config: MiddlewareConfig,
) -> (SocketAddr, Arc<ArbiterMetrics>) {
    let (listener, addr) = ephemeral_listener().await;
    let chain = MiddlewareChain::from_config(&mw_config);
    let shared_metrics = Arc::new(ArbiterMetrics::new().unwrap());
    let state = Arc::new(ProxyState::new(
        format!("http://{upstream_addr}"),
        chain,
        // NOTE(review): the meaning of this `None` and of the two numeric
        // arguments below is defined by `ProxyState::new`. Presumably they are
        // an optional component, a 10 MiB size limit and a 30 s timeout;
        // confirm against its signature.
        None,
        RedactionConfig::default(),
        Arc::clone(&shared_metrics),
        10 * 1024 * 1024,
        std::time::Duration::from_secs(30),
    ));

    // Detached accept loop; it ends on the first accept error.
    tokio::spawn(async move {
        while let Ok((stream, _peer)) = listener.accept().await {
            let conn_state = Arc::clone(&state);
            tokio::spawn(async move {
                let svc = service_fn(move |req| handle_request(Arc::clone(&conn_state), req));
                let _ = http1::Builder::new()
                    .serve_connection(TokioIo::new(stream), svc)
                    .await;
            });
        }
    });

    (addr, shared_metrics)
}
/// Sends a `GET` request with an empty body to `url`.
///
/// Returns the response status together with the body decoded as UTF-8.
/// Invalid sequences are replaced with U+FFFD. A fresh client is built on
/// every call.
///
/// # Panics
///
/// Panics if `url` is not a valid URI, the request fails, or the body cannot
/// be read.
async fn get(url: &str) -> (StatusCode, String) {
    let client = Client::builder(TokioExecutor::new()).build_http::<Full<Bytes>>();
    let target: hyper::Uri = url.parse().unwrap();
    let request = Request::get(target).body(Full::new(Bytes::new())).unwrap();
    let response = client.request(request).await.unwrap();
    let status = response.status();
    let payload = response.into_body().collect().await.unwrap().to_bytes();
    (status, String::from_utf8_lossy(&payload).into_owned())
}
/// `/health` is answered by the proxy itself (body `OK`, not the upstream's
/// `upstream saw ...` echo) with `200 OK`.
#[tokio::test]
async fn health_check_returns_200() {
    let upstream = spawn_upstream().await;
    let (proxy, _metrics) = spawn_proxy(upstream, MiddlewareConfig::default()).await;
    let health_url = format!("http://{proxy}/health");
    let (status, body) = get(&health_url).await;
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "OK");
}
/// With the default middleware, a request for an ordinary path reaches the
/// upstream with the path intact.
#[tokio::test]
async fn proxy_forwards_to_upstream() {
    let upstream = spawn_upstream().await;
    let (proxy, _metrics) = spawn_proxy(upstream, MiddlewareConfig::default()).await;
    let target = format!("http://{proxy}/hello/world");
    let (status, body) = get(&target).await;
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "upstream saw /hello/world");
}
/// Every path listed in `blocked_paths` is answered by the proxy with
/// `403 Forbidden` (body `Forbidden`, not the upstream echo). Paths not on
/// the list are still forwarded.
#[tokio::test]
async fn middleware_rejects_blocked_path() {
    let upstream_addr = spawn_upstream().await;
    let mw = MiddlewareConfig {
        blocked_paths: vec!["/admin".to_string(), "/secret".to_string()],
        required_headers: vec![],
    };
    let (proxy_addr, _metrics) = spawn_proxy(upstream_addr, mw).await;

    // Check every configured entry, not just the first, so a chain that only
    // honours one blocked path cannot pass.
    for blocked in ["/admin", "/secret"] {
        let (status, body) = get(&format!("http://{proxy_addr}{blocked}")).await;
        assert_eq!(status, StatusCode::FORBIDDEN, "path {blocked} should be blocked");
        assert_eq!(body, "Forbidden", "unexpected body for blocked path {blocked}");
    }

    let (status, body) = get(&format!("http://{proxy_addr}/ok")).await;
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "upstream saw /ok");
}
/// A request missing a required header gets `400 Bad Request`. The response
/// body must not reveal which header was expected.
#[tokio::test]
async fn middleware_rejects_missing_required_header() {
    let upstream_addr = spawn_upstream().await;
    let config = MiddlewareConfig {
        blocked_paths: vec![],
        required_headers: vec!["x-api-key".to_string()],
    };
    let (proxy_addr, _metrics) = spawn_proxy(upstream_addr, config).await;

    // `get` never sets `x-api-key`, so the requirement is unmet.
    let (status, body) = get(&format!("http://{proxy_addr}/api")).await;
    assert_eq!(status, StatusCode::BAD_REQUEST);

    let leaks_header_name = body.contains("x-api-key");
    assert!(
        !leaks_header_name,
        "response must NOT leak the required header name; got: {body}"
    );
    assert!(body.contains("Bad Request"));
}
/// After one proxied request, `/metrics` returns `200 OK` with a body that
/// mentions both `requests_total` and `request_duration_seconds`.
#[tokio::test]
async fn metrics_endpoint_returns_prometheus_format() {
    let upstream_addr = spawn_upstream().await;
    let (proxy_addr, _metrics) = spawn_proxy(upstream_addr, MiddlewareConfig::default()).await;

    // Warm-up request so the metrics have something recorded. Presumably
    // labelled series only appear in the exposition once observed; confirm.
    // Assert it succeeded, so a broken proxy path is reported here rather than
    // as a confusing missing-metric failure below.
    let (warmup_status, warmup_body) = get(&format!("http://{proxy_addr}/hello")).await;
    assert_eq!(
        warmup_status,
        StatusCode::OK,
        "warm-up request failed: {warmup_body}"
    );

    let (status, body) = get(&format!("http://{proxy_addr}/metrics")).await;
    assert_eq!(status, StatusCode::OK);
    assert!(
        body.contains("requests_total"),
        "missing requests_total in: {body}"
    );
    assert!(
        body.contains("request_duration_seconds"),
        "missing request_duration_seconds in: {body}"
    );
}
/// Each allowed request bumps the `allow` series of `requests_total`. It also
/// bumps `tool_calls_total` once under its request path.
#[tokio::test]
async fn metrics_track_requests() {
    let upstream_addr = spawn_upstream().await;
    let (proxy_addr, metrics) = spawn_proxy(upstream_addr, MiddlewareConfig::default()).await;

    let paths = ["/a", "/b"];
    for path in paths {
        let _ = get(&format!("http://{proxy_addr}{path}")).await;
    }

    let allowed = metrics.requests_total.with_label_values(&["allow"]).get();
    assert_eq!(allowed, 2);
    for path in paths {
        assert_eq!(metrics.tool_calls_total.with_label_values(&[path]).get(), 1);
    }
}
/// A single request to a blocked path bumps the `deny` series of
/// `requests_total` exactly once.
#[tokio::test]
async fn metrics_track_denied_requests() {
    let upstream_addr = spawn_upstream().await;
    let blocked = "/blocked";
    let mw = MiddlewareConfig {
        blocked_paths: vec![blocked.to_string()],
        required_headers: vec![],
    };
    let (proxy_addr, metrics) = spawn_proxy(upstream_addr, mw).await;
    let _ = get(&format!("http://{proxy_addr}{blocked}")).await;

    let denied = metrics.requests_total.with_label_values(&["deny"]).get();
    assert_eq!(denied, 1);
}