use std::sync::Arc;
use axum::body::Body;
use axum::http::{Request, StatusCode};
use tokio::sync::RwLock;
use tower::ServiceExt;
use wardn::config::{CredentialConfig, RateLimitConfig, TimePeriod, WardenConfig};
use wardn::proxy::rate_limit::RateLimiter;
use wardn::proxy::{self, ProxyState};
use wardn::Vault;
/// Builds the shared fixture used by most tests in this file.
///
/// A vault from `Vault::ephemeral()` is seeded with one credential,
/// `OPENAI_KEY`, whose config allows only the agent `researcher` and only the
/// domain `api.openai.com`, with no rate limit. The returned tuple holds the
/// proxy state wrapping that vault (fresh rate limiter, default config, new
/// HTTP client) and the placeholder string the vault hands out for
/// `OPENAI_KEY` / `researcher`.
///
/// Panics if storing the credential or fetching the placeholder fails.
fn test_state() -> (Arc<ProxyState>, String) {
    let openai_access = CredentialConfig {
        allowed_agents: vec![String::from("researcher")],
        allowed_domains: vec![String::from("api.openai.com")],
        rate_limit: None,
    };

    let mut store = Vault::ephemeral();
    store
        .set_with_config("OPENAI_KEY", "sk-proj-real-key-123", &openai_access)
        .unwrap();

    // Grab the placeholder before the vault is moved behind the lock.
    let token = store
        .get_placeholder("OPENAI_KEY", "researcher")
        .unwrap()
        .to_string();

    let proxy_state = ProxyState {
        vault: Arc::new(RwLock::new(store)),
        rate_limiter: Arc::new(tokio::sync::Mutex::new(RateLimiter::new())),
        config: WardenConfig::default(),
        http_client: reqwest::Client::new(),
    };

    (Arc::new(proxy_state), token)
}
/// A request to `/health` on the proxy router must be answered with `200 OK`.
#[tokio::test]
async fn test_proxy_health_endpoint() {
    let (state, _) = test_state();

    // No method is set, so the request builder's default method is used.
    let health_request = Request::builder()
        .uri("/health")
        .body(Body::empty())
        .unwrap();

    let response = proxy::build_router(state)
        .oneshot(health_request)
        .await
        .unwrap();

    assert_eq!(response.status(), StatusCode::OK);
}
/// An allowed agent (`researcher`) targeting an allowed domain
/// (`api.openai.com`) with the vault-issued placeholder must not be rejected
/// with `403 Forbidden`.
///
/// NOTE(review): despite the name, this test only checks the status code; it
/// does not inspect whether `x-warden-agent` is removed from the forwarded
/// request.
#[tokio::test]
async fn test_proxy_strips_warden_agent_header() {
    let (state, placeholder) = test_state();
    let bearer = format!("Bearer {placeholder}");

    let request = Request::builder()
        .method("POST")
        .uri("https://api.openai.com/v1/chat/completions")
        .header("host", "api.openai.com")
        .header("x-warden-agent", "researcher")
        .header("authorization", bearer)
        .body(Body::from(r#"{"model": "gpt-4"}"#))
        .unwrap();

    let router = proxy::build_router(state);
    let response = router.oneshot(request).await.unwrap();
    let status = response.status();

    assert_ne!(
        status,
        StatusCode::FORBIDDEN,
        "Warden should not block an authorized agent+domain combo"
    );
}
/// Sends a request as the agent `hacker` (not in the fixture credential's
/// allow-list) carrying a hand-written bearer token rather than the
/// placeholder issued by the vault, and asserts the proxy answers with neither
/// `403 Forbidden` nor `429 Too Many Requests`.
///
/// NOTE(review): the name says "blocks", yet the assertions require that the
/// request is *not* rejected. Presumably the made-up token matches no vault
/// placeholder, so no credential policy applies; confirm this is the intended
/// contract before renaming or tightening the test.
#[tokio::test]
async fn test_proxy_blocks_unauthorized_agent() {
    let (state, _) = test_state();
    let app = proxy::build_router(state);

    let req = Request::builder()
        .method("POST")
        .uri("https://api.openai.com/v1/chat/completions")
        .header("host", "api.openai.com")
        .header("x-warden-agent", "hacker")
        .header("authorization", "Bearer wdn_placeholder_0000000000000000")
        .body(Body::from(r#"{"model": "gpt-4"}"#))
        .unwrap();

    let resp = app.oneshot(req).await.unwrap();
    let status = resp.status();

    assert_ne!(
        status,
        StatusCode::FORBIDDEN,
        "request with an unrecognised placeholder should not be answered with 403"
    );
    assert_ne!(
        status,
        StatusCode::TOO_MANY_REQUESTS,
        "request with an unrecognised placeholder should not be rate limited"
    );
}
/// Using the vault-issued placeholder against a host that is not in the
/// credential's `allowed_domains` (`evil.com`) must yield `403 Forbidden`
/// with a JSON body whose `error` field is `domain_not_allowed`.
#[tokio::test]
async fn test_proxy_blocks_unauthorized_domain() {
    let (state, placeholder) = test_state();
    let router = proxy::build_router(state);

    let exfil_request = Request::builder()
        .method("POST")
        .uri("https://evil.com/steal")
        .header("host", "evil.com")
        .header("x-warden-agent", "researcher")
        .header("authorization", format!("Bearer {placeholder}"))
        .body(Body::empty())
        .unwrap();

    let response = router.oneshot(exfil_request).await.unwrap();
    assert_eq!(response.status(), StatusCode::FORBIDDEN);

    // Read at most 1 MiB of the error body and decode it as JSON.
    let max_body_bytes = 1024 * 1024;
    let raw = axum::body::to_bytes(response.into_body(), max_body_bytes)
        .await
        .unwrap();
    let payload: serde_json::Value = serde_json::from_slice(&raw).unwrap();

    assert_eq!(payload["error"], "domain_not_allowed");
}
/// Verifies that the proxy enforces a per-credential rate limit.
///
/// A credential `KEY` is stored with empty agent/domain lists and a limit of
/// 2 calls per hour, and the rate limiter is configured with the same limit
/// for the (`KEY`, `bot`) pair. The first `max_calls` requests must not be
/// answered with `429 Too Many Requests`; the next one must be, with a JSON
/// body containing `error == "rate_limit_exceeded"` and a positive
/// `retry_after_seconds`.
#[tokio::test]
async fn test_proxy_rate_limits() {
    // Single source of truth for the limit, so the vault config, the limiter
    // config and the loop bound below cannot drift apart.
    let hourly_limit = || RateLimitConfig {
        max_calls: 2,
        per: TimePeriod::Hour,
    };

    let mut vault = Vault::ephemeral();
    vault
        .set_with_config(
            "KEY",
            "secret-key-value-long",
            &CredentialConfig {
                allowed_agents: vec![],
                allowed_domains: vec![],
                rate_limit: Some(hourly_limit()),
            },
        )
        .unwrap();
    let placeholder = vault.get_placeholder("KEY", "bot").unwrap().to_string();

    let limiter_config = hourly_limit();
    let mut rl = RateLimiter::new();
    rl.configure("KEY", "bot", &limiter_config);

    let state = Arc::new(ProxyState {
        vault: Arc::new(RwLock::new(vault)),
        rate_limiter: Arc::new(tokio::sync::Mutex::new(rl)),
        config: WardenConfig::default(),
        http_client: reqwest::Client::new(),
    });

    // Every call in this test sends the same request; build it in one place.
    let build_request = || {
        Request::builder()
            .method("POST")
            .uri("https://api.example.com/v1")
            .header("host", "api.example.com")
            .header("x-warden-agent", "bot")
            .header("authorization", format!("Bearer {placeholder}"))
            .body(Body::empty())
            .unwrap()
    };

    // Requests up to the configured limit must pass the rate limiter.
    for _ in 0..limiter_config.max_calls {
        let app = proxy::build_router(Arc::clone(&state));
        let resp = app.oneshot(build_request()).await.unwrap();
        assert_ne!(
            resp.status(),
            StatusCode::TOO_MANY_REQUESTS,
            "should not be rate limited yet"
        );
    }

    // The next request exceeds the limit. `state` is moved here because this
    // is its last use.
    let app = proxy::build_router(state);
    let resp = app.oneshot(build_request()).await.unwrap();
    assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);

    let body = axum::body::to_bytes(resp.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(json["error"], "rate_limit_exceeded");

    let retry_after = json["retry_after_seconds"]
        .as_u64()
        .expect("retry_after_seconds should be an unsigned integer");
    assert!(retry_after > 0);
}
/// A request that carries no `authorization` header (and therefore no vault
/// placeholder) must not be rejected with `403 Forbidden`, even though its
/// host is not in the fixture credential's allowed domains.
#[tokio::test]
async fn test_proxy_passthrough_no_placeholders() {
    let (state, _) = test_state();

    let plain_request = Request::builder()
        .method("GET")
        .uri("https://api.example.com/data")
        .header("host", "api.example.com")
        .header("x-warden-agent", "researcher")
        .body(Body::empty())
        .unwrap();

    let response = proxy::build_router(state)
        .oneshot(plain_request)
        .await
        .unwrap();

    assert_ne!(response.status(), StatusCode::FORBIDDEN);
}