use std::collections::BTreeMap;
use std::sync::Arc;

use axum::body::Body;
use axum::http::{Request, StatusCode};
use bytes::Bytes;
use http_body_util::BodyExt;
use tower::util::ServiceExt;
use wiremock::matchers::{
    body_string_contains, header, header_exists, method, path, query_param,
};
use wiremock::{Mock, MockServer, ResponseTemplate};

use clawshell::config::{Config, DlpAction, DlpPattern, Provider};
use clawshell::dlp::DlpScanner;
use clawshell::keys::KeyManager;
use clawshell::proxy::ProxyClient;
use clawshell::{AppState, build_router};
/// Builds a router for the OpenAI-centric tests.
///
/// Both provider upstreams point at `upstream_url`. Two OpenAI virtual keys
/// (`vk-test-1`, `vk-test-2`) are registered, and three DLP rules (SSN,
/// e-mail, credit card) all use `DlpAction::Block`.
fn make_app(upstream_url: &str) -> axum::Router {
    let key_map = BTreeMap::from([
        (
            "vk-test-1".to_string(),
            ("sk-real-1".to_string(), Provider::Openai),
        ),
        (
            "vk-test-2".to_string(),
            ("sk-real-2".to_string(), Provider::Openai),
        ),
    ]);

    // Every rule in this fixture blocks; only name and regex vary.
    let block = |name: &str, regex: &str| DlpPattern {
        name: name.to_string(),
        regex: regex.to_string(),
        action: DlpAction::Block,
    };
    let patterns = vec![
        block("ssn", r"\b\d{3}-\d{2}-\d{4}\b"),
        block(
            "email",
            r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
        ),
        block("credit_card", r"\b(?:\d[ -]*?){13,19}\b"),
    ];

    let upstream_urls = BTreeMap::from([
        (Provider::Openai, upstream_url.to_string()),
        (Provider::Anthropic, upstream_url.to_string()),
    ]);

    build_router(AppState {
        key_manager: Arc::new(KeyManager::new(key_map)),
        dlp_scanner: Arc::new(DlpScanner::new(&patterns).unwrap()),
        proxy_client: Arc::new(ProxyClient::with_upstream_urls(
            upstream_urls,
            "2023-06-01".to_string(),
        )),
    })
}
/// Builds a router with one OpenAI key (`vk-test-1`) and one Anthropic key
/// (`vk-ant-1`).
///
/// Both provider upstreams point at `upstream_url`, and DLP scanning is
/// configured with no patterns.
fn make_app_with_anthropic(upstream_url: &str) -> axum::Router {
    let keys = BTreeMap::from([
        (
            "vk-test-1".to_string(),
            ("sk-real-1".to_string(), Provider::Openai),
        ),
        (
            "vk-ant-1".to_string(),
            ("sk-ant-real-1".to_string(), Provider::Anthropic),
        ),
    ]);
    let urls = BTreeMap::from([
        (Provider::Openai, upstream_url.to_string()),
        (Provider::Anthropic, upstream_url.to_string()),
    ]);
    let state = AppState {
        key_manager: Arc::new(KeyManager::new(keys)),
        dlp_scanner: Arc::new(DlpScanner::new(&[]).unwrap()),
        proxy_client: Arc::new(ProxyClient::with_upstream_urls(
            urls,
            "2023-06-01".to_string(),
        )),
    };
    build_router(state)
}
/// A chat completion sent with `vk-test-1` must reach the upstream as
/// `Bearer sk-real-1`, and the upstream JSON must come back to the caller.
#[tokio::test]
async fn test_proxy_forward_success() {
    let upstream = MockServer::start().await;
    let reply = serde_json::json!({
        "id": "chatcmpl-abc",
        "choices": [{"message": {"content": "Hello!"}}]
    });
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .and(header("authorization", "Bearer sk-real-1"))
        .respond_with(ResponseTemplate::new(200).set_body_json(reply))
        .mount(&upstream)
        .await;

    let payload = r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let bytes = response.into_body().collect().await.unwrap().to_bytes();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    assert_eq!(parsed["id"], "chatcmpl-abc");
}
#[tokio::test]
async fn test_proxy_preserves_query_params() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/v1/models"))
.and(header("authorization", "Bearer sk-real-1"))
.respond_with(
ResponseTemplate::new(200)
.set_body_json(serde_json::json!({"data": [{"id": "gpt-4"}]})),
)
.mount(&mock_server)
.await;
let app = make_app(&mock_server.uri());
let req = Request::builder()
.method("GET")
.uri("/v1/models?limit=10")
.header("authorization", "Bearer vk-test-1")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
/// An upstream 400 with a JSON error body must be relayed to the caller:
/// the same status, and a body still containing the upstream message.
#[tokio::test]
async fn test_proxy_forwards_upstream_errors() {
    let upstream = MockServer::start().await;
    let error_body = serde_json::json!({
        "error": {"message": "Invalid model", "type": "invalid_request_error"}
    });
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(400).set_body_json(error_body))
        .mount(&upstream)
        .await;

    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"model":"nonexistent"}"#))
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);

    let bytes = response.into_body().collect().await.unwrap().to_bytes();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    let message = parsed["error"]["message"].as_str().unwrap();
    assert!(message.contains("Invalid model"));
}
/// A request authenticated with `vk-test-2` must reach the upstream carrying
/// `Bearer sk-real-2`. The mock only answers for that exact header.
#[tokio::test]
async fn test_real_key_injected_correctly() {
    let upstream = MockServer::start().await;
    let ok = ResponseTemplate::new(200).set_body_json(serde_json::json!({"ok": true}));
    Mock::given(method("GET"))
        .and(path("/v1/models"))
        .and(header("authorization", "Bearer sk-real-2"))
        .respond_with(ok)
        .mount(&upstream)
        .await;

    let app = make_app(&upstream.uri());
    let request = Request::builder()
        .method("GET")
        .uri("/v1/models")
        .header("authorization", "Bearer vk-test-2")
        .body(Body::empty())
        .unwrap();
    let status = app.oneshot(request).await.unwrap().status();
    assert_eq!(status, StatusCode::OK);
}
/// A request with no `authorization` header is rejected with 401.
#[tokio::test]
async fn test_missing_auth_header() {
    let upstream = MockServer::start().await;
    let request = Request::builder()
        .uri("/v1/chat/completions")
        .body(Body::empty())
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
/// A non-Bearer `authorization` scheme (here `Basic`) is rejected with 401.
#[tokio::test]
async fn test_invalid_auth_format() {
    let upstream = MockServer::start().await;
    let request = Request::builder()
        .uri("/v1/chat/completions")
        .header("authorization", "Basic abc123")
        .body(Body::empty())
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
/// A well-formed Bearer token that is not a registered virtual key is
/// rejected with 401.
#[tokio::test]
async fn test_unknown_virtual_key() {
    let upstream = MockServer::start().await;
    let request = Request::builder()
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-unknown")
        .body(Body::empty())
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
/// `Bearer ` followed by nothing is rejected with 401.
#[tokio::test]
async fn test_empty_bearer_token() {
    let upstream = MockServer::start().await;
    let request = Request::builder()
        .uri("/v1/models")
        .header("authorization", "Bearer ")
        .body(Body::empty())
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
/// A request containing an SSN-shaped string is refused with 400 by the
/// blocking `ssn` rule in `make_app`.
#[tokio::test]
async fn test_dlp_blocks_ssn() {
    let upstream = MockServer::start().await;
    let payload = r#"{"messages":[{"role":"user","content":"My SSN is 123-45-6789"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
/// A request containing an e-mail address is refused with 400 by the
/// blocking `email` rule in `make_app`.
#[tokio::test]
async fn test_dlp_blocks_email() {
    let upstream = MockServer::start().await;
    let payload = r#"{"messages":[{"role":"user","content":"Email me at user@example.com"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
/// A card number in the request is refused with 400.
///
/// The JSON error names the triggering rule but must not echo the sensitive
/// value back to the caller.
#[tokio::test]
async fn test_dlp_blocks_credit_card() {
    let upstream = MockServer::start().await;
    let payload = r#"{"messages":[{"role":"user","content":"My card is 4111111111111111"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);

    let bytes = response.into_body().collect().await.unwrap().to_bytes();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    let message = parsed["error"].as_str().unwrap();
    assert!(message.contains("sensitive data detected"));
    assert!(message.contains("credit_card"));
    // The offending value itself must never be reflected in the error.
    assert!(!message.contains("4111111111111111"));
}
/// A request with no DLP matches is forwarded and the upstream 200 returned.
#[tokio::test]
async fn test_clean_request_passes_through() {
    let upstream = MockServer::start().await;
    let ok = ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "chatcmpl-ok"}));
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ok)
        .mount(&upstream)
        .await;

    let payload = r#"{"messages":[{"role":"user","content":"Hello, how are you?"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// An upstream `text/event-stream` reply is relayed. The collected body
/// still contains the first event id and the `[DONE]` terminator.
#[tokio::test]
async fn test_proxy_streaming_response() {
    let upstream = MockServer::start().await;
    let events =
        "data: {\"id\":\"chatcmpl-1\"}\n\ndata: {\"id\":\"chatcmpl-2\"}\n\ndata: [DONE]\n\n";
    let template = ResponseTemplate::new(200)
        .append_header("content-type", "text/event-stream")
        .set_body_string(events);
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(template)
        .mount(&upstream)
        .await;

    let payload = r#"{"model":"gpt-4","stream":true,"messages":[{"role":"user","content":"Hi"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let bytes = response.into_body().collect().await.unwrap().to_bytes();
    let text = std::str::from_utf8(&bytes).unwrap();
    assert!(text.contains("chatcmpl-1"));
    assert!(text.contains("[DONE]"));
}
/// A path other than chat completions (`/v1/embeddings`) is proxied as well.
#[tokio::test]
async fn test_multiple_endpoints() {
    let upstream = MockServer::start().await;
    let embedding = serde_json::json!({"data": [{"embedding": [0.1, 0.2]}]});
    Mock::given(method("POST"))
        .and(path("/v1/embeddings"))
        .respond_with(ResponseTemplate::new(200).set_body_json(embedding))
        .mount(&upstream)
        .await;

    let payload = r#"{"model":"text-embedding-ada-002","input":"Hello"}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/embeddings")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// DELETE requests are forwarded with their method and path intact.
#[tokio::test]
async fn test_delete_method() {
    let upstream = MockServer::start().await;
    let deleted = serde_json::json!({"id": "file-abc", "deleted": true});
    Mock::given(method("DELETE"))
        .and(path("/v1/files/file-abc"))
        .respond_with(ResponseTemplate::new(200).set_body_json(deleted))
        .mount(&upstream)
        .await;

    let request = Request::builder()
        .method("DELETE")
        .uri("/v1/files/file-abc")
        .header("authorization", "Bearer vk-test-1")
        .body(Body::empty())
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// PUT requests with a JSON body are forwarded with their method preserved.
#[tokio::test]
async fn test_put_method() {
    let upstream = MockServer::start().await;
    let ok = ResponseTemplate::new(200).set_body_json(serde_json::json!({"ok": true}));
    Mock::given(method("PUT"))
        .and(path("/v1/files/file-abc"))
        .respond_with(ok)
        .mount(&upstream)
        .await;

    let request = Request::builder()
        .method("PUT")
        .uri("/v1/files/file-abc")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from("{}"))
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// PATCH requests with a JSON body are forwarded with their method preserved.
#[tokio::test]
async fn test_patch_method() {
    let upstream = MockServer::start().await;
    let ok = ResponseTemplate::new(200).set_body_json(serde_json::json!({"ok": true}));
    Mock::given(method("PATCH"))
        .and(path("/v1/assistants/asst-abc"))
        .respond_with(ok)
        .mount(&upstream)
        .await;

    let request = Request::builder()
        .method("PATCH")
        .uri("/v1/assistants/asst-abc")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from("{}"))
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// HEAD requests are forwarded; the mock answers 200 with no body.
#[tokio::test]
async fn test_head_method() {
    let upstream = MockServer::start().await;
    Mock::given(method("HEAD"))
        .and(path("/v1/models"))
        .respond_with(ResponseTemplate::new(200))
        .mount(&upstream)
        .await;

    let request = Request::builder()
        .method("HEAD")
        .uri("/v1/models")
        .header("authorization", "Bearer vk-test-1")
        .body(Body::empty())
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// OPTIONS requests are forwarded; the mock answers 204 No Content and the
/// caller must see a 2xx status.
#[tokio::test]
async fn test_options_method() {
    let mock_server = MockServer::start().await;
    Mock::given(method("OPTIONS"))
        .and(path("/v1/models"))
        .respond_with(ResponseTemplate::new(204))
        .mount(&mock_server)
        .await;
    let app = make_app(&mock_server.uri());
    let req = Request::builder()
        .method("OPTIONS")
        .uri("/v1/models")
        .header("authorization", "Bearer vk-test-1")
        .body(Body::empty())
        .unwrap();
    let resp = app.oneshot(req).await.unwrap();
    // 204 No Content is itself a 2xx status, so `is_success()` alone covers
    // it. The former `|| status == NO_CONTENT` clause was dead code.
    assert!(
        resp.status().is_success(),
        "expected a 2xx status for OPTIONS, got {}",
        resp.status()
    );
}
/// A config that sets the server, upstream, two keys, one DLP pattern and
/// the log level parses with those values reflected in the struct.
#[tokio::test]
async fn test_config_parsing() {
    const TOML: &str = r#"
log_level = "debug"
[server]
host = "0.0.0.0"
port = 8080
[upstream]
base_url = "https://api.openai.com"
[[keys]]
virtual_key = "vk-1"
real_key = "sk-real-1"
[[keys]]
virtual_key = "vk-2"
real_key = "sk-real-1"
[dlp]
patterns = [
{ name = "ssn", regex = '\\b\\d{3}-\\d{2}-\\d{4}\\b' },
]
"#;
    let parsed = Config::parse(TOML).unwrap();
    let server = &parsed.server;
    assert_eq!(server.host, "0.0.0.0");
    assert_eq!(server.port, 8080);
    assert_eq!(parsed.keys.len(), 2);
    assert_eq!(parsed.dlp.patterns.len(), 1);
    assert_eq!(parsed.log_level, "debug");
}
/// With only empty `[server]` and `[upstream]` tables, every field falls
/// back to its default.
#[tokio::test]
async fn test_config_defaults() {
    const TOML: &str = r#"
[server]
[upstream]
"#;
    let parsed = Config::parse(TOML).unwrap();
    assert_eq!(parsed.server.host, "127.0.0.1");
    assert_eq!(parsed.server.port, 18790);
    assert_eq!(parsed.upstream.base_url, "https://api.openai.com");
    assert_eq!(parsed.log_level, "info");
    assert!(parsed.dlp.patterns.is_empty());
}
/// A DLP pattern whose regex fails to compile makes `Config::parse` fail.
#[tokio::test]
async fn test_config_invalid_regex() {
    const TOML: &str = r#"
[server]
[upstream]
[dlp]
patterns = [
{ name = "bad", regex = '[invalid' },
]
"#;
    assert!(Config::parse(TOML).is_err());
}
/// `key_map` exposes each virtual key.
///
/// Two virtual keys may share one real key, and a key without an explicit
/// provider resolves to OpenAI.
#[tokio::test]
async fn test_config_key_map() {
    const TOML: &str = r#"
[server]
[upstream]
[[keys]]
virtual_key = "vk-a"
real_key = "sk-1"
[[keys]]
virtual_key = "vk-b"
real_key = "sk-1"
"#;
    let config = Config::parse(TOML).unwrap();
    let map = config.key_map();
    for virtual_key in ["vk-a", "vk-b"] {
        let (real_key, provider) = map.get(virtual_key).unwrap();
        assert_eq!(real_key, "sk-1");
        assert_eq!(*provider, Provider::Openai);
    }
}
/// `listen_addr` joins the configured host and port as `host:port`.
#[tokio::test]
async fn test_config_listen_addr() {
    const TOML: &str = r#"
[server]
host = "0.0.0.0"
port = 9090
[upstream]
"#;
    let addr = Config::parse(TOML).unwrap().listen_addr();
    assert_eq!(addr, "0.0.0.0:9090");
}
/// `Config::from_file` reads and parses a TOML file from disk.
#[tokio::test]
async fn test_config_from_file() {
    let dir = std::env::temp_dir().join("clawshell_test");
    std::fs::create_dir_all(&dir).unwrap();
    // Include the process id so concurrent `cargo test` runs that share the
    // system temp dir cannot overwrite or delete each other's fixture.
    let path = dir.join(format!("test_config_{}.toml", std::process::id()));
    std::fs::write(
        &path,
        r#"
[server]
host = "127.0.0.1"
port = 4000
[upstream]
base_url = "https://api.openai.com"
"#,
    )
    .unwrap();
    let result = Config::from_file(&path);
    // Clean up before asserting, so a failing assertion does not leave the
    // fixture behind in the temp dir.
    std::fs::remove_file(&path).ok();
    let config = result.unwrap();
    assert_eq!(config.server.port, 4000);
}
/// `AppState::from_config` wires configured keys into the key manager.
///
/// A known virtual key resolves to its real key with the OpenAI provider;
/// an unknown key resolves to `None`.
#[tokio::test]
async fn test_app_state_from_config() {
    const TOML: &str = r#"
[server]
host = "127.0.0.1"
port = 3000
[upstream]
base_url = "https://api.openai.com"
[[keys]]
virtual_key = "vk-1"
real_key = "sk-real-1"
"#;
    let config = Config::parse(TOML).unwrap();
    let keys = AppState::from_config(&config).key_manager;

    let hit = keys.resolve("vk-1").unwrap();
    assert_eq!(hit.real_key, "sk-real-1");
    assert_eq!(hit.provider, Provider::Openai);

    assert!(keys.resolve("vk-unknown").is_none());
}
/// Keys with explicit `provider` values resolve to the matching provider and
/// real key after `AppState::from_config`.
#[tokio::test]
async fn test_app_state_from_config_with_anthropic() {
    const TOML: &str = r#"
[server]
host = "127.0.0.1"
port = 3000
[upstream]
base_url = "https://api.openai.com"
anthropic_base_url = "https://api.anthropic.com"
[[keys]]
virtual_key = "vk-oai"
real_key = "sk-oai-key"
provider = "openai"
[[keys]]
virtual_key = "vk-ant"
real_key = "sk-ant-key"
provider = "anthropic"
"#;
    let config = Config::parse(TOML).unwrap();
    let state = AppState::from_config(&config);
    let expectations = [
        ("vk-oai", "sk-oai-key", Provider::Openai),
        ("vk-ant", "sk-ant-key", Provider::Anthropic),
    ];
    for (virtual_key, real_key, provider) in expectations {
        let resolved = state.key_manager.resolve(virtual_key).unwrap();
        assert_eq!(resolved.real_key, real_key);
        assert_eq!(resolved.provider, provider);
    }
}
/// When both upstreams point at `http://127.0.0.1:1`, forwarding fails and
/// the proxy answers 502 Bad Gateway.
///
/// The test presumes nothing is listening on loopback port 1.
#[tokio::test]
async fn test_proxy_error_on_unreachable_upstream() {
    let dead_end = "http://127.0.0.1:1";
    let keys = BTreeMap::from([(
        "vk-1".to_string(),
        ("sk-1".to_string(), Provider::Openai),
    )]);
    let urls = BTreeMap::from([
        (Provider::Openai, dead_end.to_string()),
        (Provider::Anthropic, dead_end.to_string()),
    ]);
    let app = build_router(AppState {
        key_manager: Arc::new(KeyManager::new(keys)),
        dlp_scanner: Arc::new(DlpScanner::new(&[]).unwrap()),
        proxy_client: Arc::new(ProxyClient::with_upstream_urls(
            urls,
            "2023-06-01".to_string(),
        )),
    });

    let request = Request::builder()
        .uri("/v1/models")
        .header("authorization", "Bearer vk-1")
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
}
/// An Anthropic virtual key is translated into `x-api-key` plus the
/// configured `anthropic-version` header, and the upstream JSON is relayed.
#[tokio::test]
async fn test_anthropic_forward_uses_x_api_key() {
    let upstream = MockServer::start().await;
    let reply = serde_json::json!({
        "id": "msg-abc",
        "type": "message",
        "content": [{"type": "text", "text": "Hello!"}]
    });
    Mock::given(method("POST"))
        .and(path("/v1/messages"))
        .and(header("x-api-key", "sk-ant-real-1"))
        .and(header("anthropic-version", "2023-06-01"))
        .respond_with(ResponseTemplate::new(200).set_body_json(reply))
        .mount(&upstream)
        .await;

    let payload = r#"{"model":"claude-sonnet-4-5-20250929","max_tokens":1024,"messages":[{"role":"user","content":"Hi"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/messages")
        .header("authorization", "Bearer vk-ant-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app_with_anthropic(&upstream.uri())
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let bytes = response.into_body().collect().await.unwrap().to_bytes();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    assert_eq!(parsed["id"], "msg-abc");
}
#[tokio::test]
async fn test_anthropic_no_bearer_header_sent_upstream() {
let mock_server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/messages"))
.and(header("x-api-key", "sk-ant-real-1"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({"ok": true})))
.mount(&mock_server)
.await;
let app = make_app_with_anthropic(&mock_server.uri());
let body = r#"{"model":"claude-sonnet-4-5-20250929","max_tokens":1024,"messages":[{"role":"user","content":"Hi"}]}"#;
let req = Request::builder()
.method("POST")
.uri("/v1/messages")
.header("authorization", "Bearer vk-ant-1")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
/// An Anthropic SSE stream is relayed. The collected body still contains
/// the `message_start` and `message_stop` events.
#[tokio::test]
async fn test_anthropic_streaming_response() {
    let upstream = MockServer::start().await;
    let events = "event: message_start\ndata: {\"type\":\"message_start\"}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}\n\nevent: message_stop\ndata: {\"type\":\"message_stop\"}\n\n";
    let template = ResponseTemplate::new(200)
        .append_header("content-type", "text/event-stream")
        .set_body_string(events);
    Mock::given(method("POST"))
        .and(path("/v1/messages"))
        .and(header("x-api-key", "sk-ant-real-1"))
        .respond_with(template)
        .mount(&upstream)
        .await;

    let payload = r#"{"model":"claude-sonnet-4-5-20250929","max_tokens":1024,"stream":true,"messages":[{"role":"user","content":"Hi"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/messages")
        .header("authorization", "Bearer vk-ant-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app_with_anthropic(&upstream.uri())
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let bytes = response.into_body().collect().await.unwrap().to_bytes();
    let text = std::str::from_utf8(&bytes).unwrap();
    assert!(text.contains("message_start"));
    assert!(text.contains("message_stop"));
}
/// DLP blocking also applies to requests made with an Anthropic key.
///
/// An SSN in a `/v1/messages` body is refused with 400.
#[tokio::test]
async fn test_anthropic_dlp_blocks_sensitive_data() {
    let upstream = MockServer::start().await;
    let keys = BTreeMap::from([(
        "vk-ant-dlp".to_string(),
        ("sk-ant-key".to_string(), Provider::Anthropic),
    )]);
    let ssn_rule = vec![DlpPattern {
        name: "ssn".to_string(),
        regex: r"\b\d{3}-\d{2}-\d{4}\b".to_string(),
        action: DlpAction::Block,
    }];
    let urls = BTreeMap::from([
        (Provider::Openai, upstream.uri()),
        (Provider::Anthropic, upstream.uri()),
    ]);
    let app = build_router(AppState {
        key_manager: Arc::new(KeyManager::new(keys)),
        dlp_scanner: Arc::new(DlpScanner::new(&ssn_rule).unwrap()),
        proxy_client: Arc::new(ProxyClient::with_upstream_urls(
            urls,
            "2023-06-01".to_string(),
        )),
    });

    let payload = r#"{"messages":[{"role":"user","content":"My SSN is 123-45-6789"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/messages")
        .header("authorization", "Bearer vk-ant-dlp")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
/// In a router that also knows Anthropic keys, an OpenAI key is still sent
/// upstream as `Bearer <real key>`.
#[tokio::test]
async fn test_openai_still_uses_bearer_auth() {
    let upstream = MockServer::start().await;
    let ok = ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "chatcmpl-ok"}));
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .and(header("authorization", "Bearer sk-real-1"))
        .respond_with(ok)
        .mount(&upstream)
        .await;

    let payload = r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app_with_anthropic(&upstream.uri())
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// Per-key `provider` values and `anthropic_base_url` are parsed from TOML.
#[tokio::test]
async fn test_config_parsing_with_anthropic() {
    const TOML: &str = r#"
log_level = "debug"
[server]
host = "0.0.0.0"
port = 8080
[upstream]
base_url = "https://api.openai.com"
anthropic_base_url = "https://api.anthropic.com"
[[keys]]
virtual_key = "vk-oai"
real_key = "sk-oai-1"
provider = "openai"
[[keys]]
virtual_key = "vk-ant"
real_key = "sk-ant-1"
provider = "anthropic"
[dlp]
patterns = [
{ name = "ssn", regex = '\\b\\d{3}-\\d{2}-\\d{4}\\b' },
]
"#;
    let parsed = Config::parse(TOML).unwrap();
    let keys = &parsed.keys;
    assert_eq!(keys.len(), 2);
    assert_eq!(keys[0].provider, Provider::Openai);
    assert_eq!(keys[1].provider, Provider::Anthropic);
    let expected_url = Some("https://api.anthropic.com".to_string());
    assert_eq!(parsed.upstream.anthropic_base_url, expected_url);
}
/// A key entry without a `provider` field defaults to OpenAI.
#[tokio::test]
async fn test_config_provider_defaults_to_openai() {
    const TOML: &str = r#"
[server]
[upstream]
[[keys]]
virtual_key = "vk-1"
real_key = "sk-1"
"#;
    let parsed = Config::parse(TOML).unwrap();
    let first = &parsed.keys[0];
    assert_eq!(first.provider, Provider::Openai);
}
/// `upstream_url` returns the custom base URL configured for each provider.
#[tokio::test]
async fn test_config_upstream_url_resolution() {
    const TOML: &str = r#"
[server]
[upstream]
base_url = "https://custom-openai.example.com"
anthropic_base_url = "https://custom-anthropic.example.com"
"#;
    let parsed = Config::parse(TOML).unwrap();
    let openai = parsed.upstream_url(Provider::Openai);
    let anthropic = parsed.upstream_url(Provider::Anthropic);
    assert_eq!(openai, "https://custom-openai.example.com");
    assert_eq!(anthropic, "https://custom-anthropic.example.com");
}
/// Without `anthropic_base_url`, the Anthropic upstream is the public API.
#[tokio::test]
async fn test_config_anthropic_url_defaults() {
    const TOML: &str = r#"
[server]
[upstream]
"#;
    let resolved = Config::parse(TOML)
        .unwrap()
        .upstream_url(Provider::Anthropic);
    assert_eq!(resolved, "https://api.anthropic.com");
}
fn make_app_with_redact(upstream_url: &str) -> axum::Router {
let mut key_map = BTreeMap::new();
key_map.insert(
"vk-test-1".to_string(),
("sk-real-1".to_string(), Provider::Openai),
);
let patterns = vec![
DlpPattern {
name: "ssn".to_string(),
regex: r"\b\d{3}-\d{2}-\d{4}\b".to_string(),
action: DlpAction::Block,
},
DlpPattern {
name: "email".to_string(),
regex: r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b".to_string(),
action: DlpAction::Redact,
},
DlpPattern {
name: "phone_number".to_string(),
regex: r"\b(?:\+?1[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b".to_string(),
action: DlpAction::Redact,
},
];
let mut upstream_urls = BTreeMap::new();
upstream_urls.insert(Provider::Openai, upstream_url.to_string());
upstream_urls.insert(Provider::Anthropic, upstream_url.to_string());
let state = AppState {
key_manager: Arc::new(KeyManager::new(key_map)),
dlp_scanner: Arc::new(DlpScanner::with_response_scanning(&patterns, true).unwrap()),
proxy_client: Arc::new(ProxyClient::with_upstream_urls(
upstream_urls,
"2023-06-01".to_string(),
)),
};
build_router(state)
}
#[tokio::test]
async fn test_request_redact_email_before_forwarding() {
let mock_server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/chat/completions"))
.and(header("authorization", "Bearer sk-real-1"))
.respond_with(
ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "chatcmpl-ok"})),
)
.mount(&mock_server)
.await;
let app = make_app_with_redact(&mock_server.uri());
let body = r#"{"messages":[{"role":"user","content":"Contact me at user@example.com"}]}"#;
let req = Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("authorization", "Bearer vk-test-1")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
/// With redact rules present, the blocking `ssn` rule still refuses the
/// request with 400.
#[tokio::test]
async fn test_request_block_ssn_still_works() {
    let upstream = MockServer::start().await;
    let payload = r#"{"messages":[{"role":"user","content":"My SSN is 123-45-6789"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app_with_redact(&upstream.uri())
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
/// With response scanning on, an e-mail and a phone number in the upstream
/// reply are replaced by their `[REDACTED:<rule>]` markers before reaching
/// the caller.
#[tokio::test]
async fn test_response_dlp_redacts_pii() {
    let upstream = MockServer::start().await;
    let leaky_reply = serde_json::json!({
        "id": "chatcmpl-abc",
        "choices": [{
            "message": {
                "content": "Here is your info: email user@example.com, phone 555-123-4567"
            }
        }]
    });
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(leaky_reply))
        .mount(&upstream)
        .await;

    let payload = r#"{"model":"gpt-4","messages":[{"role":"user","content":"What is my info?"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app_with_redact(&upstream.uri())
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let bytes = response.into_body().collect().await.unwrap().to_bytes();
    let text = std::str::from_utf8(&bytes).unwrap();
    for marker in ["[REDACTED:email]", "[REDACTED:phone_number]"] {
        assert!(text.contains(marker));
    }
    for leaked in ["user@example.com", "555-123-4567"] {
        assert!(!text.contains(leaked));
    }
}
/// An SSN in the upstream reply comes back as `[REDACTED:ssn]`.
///
/// The `ssn` rule blocks on requests, but the reply is expected to be
/// redacted rather than rejected.
#[tokio::test]
async fn test_response_dlp_redacts_ssn_in_response() {
    let upstream = MockServer::start().await;
    let leaky_reply = serde_json::json!({
        "choices": [{
            "message": {
                "content": "Your SSN is 123-45-6789"
            }
        }]
    });
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(leaky_reply))
        .mount(&upstream)
        .await;

    let payload = r#"{"model":"gpt-4","messages":[{"role":"user","content":"What is my SSN?"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app_with_redact(&upstream.uri())
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let bytes = response.into_body().collect().await.unwrap().to_bytes();
    let text = std::str::from_utf8(&bytes).unwrap();
    assert!(text.contains("[REDACTED:ssn]"));
    assert!(!text.contains("123-45-6789"));
}
/// A reply with nothing sensitive passes through response scanning still
/// valid JSON, with no `REDACTED` marker anywhere.
#[tokio::test]
async fn test_response_dlp_clean_response_untouched() {
    let upstream = MockServer::start().await;
    let clean_reply = serde_json::json!({
        "id": "chatcmpl-clean",
        "choices": [{"message": {"content": "Hello! How can I help?"}}]
    });
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(clean_reply))
        .mount(&upstream)
        .await;

    let payload = r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = make_app_with_redact(&upstream.uri())
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let bytes = response.into_body().collect().await.unwrap().to_bytes();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    assert_eq!(parsed["id"], "chatcmpl-clean");
    let rendered = parsed.to_string();
    assert!(!rendered.contains("REDACTED"));
}
/// With response scanning disabled, a `Redact` e-mail rule leaves upstream
/// replies untouched.
#[tokio::test]
async fn test_response_dlp_disabled() {
    let upstream = MockServer::start().await;
    let reply = serde_json::json!({
        "choices": [{"message": {"content": "Email: user@example.com"}}]
    });
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(reply))
        .mount(&upstream)
        .await;

    let keys = BTreeMap::from([(
        "vk-test-1".to_string(),
        ("sk-real-1".to_string(), Provider::Openai),
    )]);
    let email_rule = vec![DlpPattern {
        name: "email".to_string(),
        regex: r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b".to_string(),
        action: DlpAction::Redact,
    }];
    let urls = BTreeMap::from([
        (Provider::Openai, upstream.uri()),
        (Provider::Anthropic, upstream.uri()),
    ]);
    let app = build_router(AppState {
        key_manager: Arc::new(KeyManager::new(keys)),
        // `false`: only requests are scanned, responses are passed through.
        dlp_scanner: Arc::new(DlpScanner::with_response_scanning(&email_rule, false).unwrap()),
        proxy_client: Arc::new(ProxyClient::with_upstream_urls(
            urls,
            "2023-06-01".to_string(),
        )),
    });

    let payload = r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let bytes = response.into_body().collect().await.unwrap().to_bytes();
    let text = std::str::from_utf8(&bytes).unwrap();
    assert!(text.contains("user@example.com"));
    assert!(!text.contains("REDACTED"));
}
/// Parses a `[dlp]` section that covers every action form:
/// - an explicit `block`,
/// - an explicit `redact`,
/// - an omitted `action`, which must come back as `DlpAction::Block`,
///
/// plus an explicit `scan_responses = true`.
///
/// The regexes are TOML *literal* strings (single-quoted). Backslashes are not
/// escape characters there, so `\b` / `\d` are written with a single backslash.
/// Doubling them would produce a regex that matches a literal backslash.
#[tokio::test]
async fn test_config_dlp_action_parsing() {
    let toml_str = r#"
[server]
[upstream]
[dlp]
scan_responses = true
patterns = [
{ name = "ssn", regex = '\b\d{3}-\d{2}-\d{4}\b', action = "block" },
{ name = "email", regex = '\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', action = "redact" },
{ name = "phone", regex = '\b\d{3}-\d{3}-\d{4}\b' },
]
"#;
    let config = Config::parse(toml_str).unwrap();
    let patterns = &config.dlp.patterns;
    assert_eq!(patterns.len(), 3);

    assert_eq!(patterns[0].name, "ssn");
    // The literal string must reach the config without any escape processing.
    assert_eq!(patterns[0].regex, r"\b\d{3}-\d{2}-\d{4}\b");
    assert_eq!(patterns[0].action, DlpAction::Block);

    assert_eq!(patterns[1].name, "email");
    assert_eq!(patterns[1].action, DlpAction::Redact);

    // No `action` key: the parser's default applies.
    assert_eq!(patterns[2].name, "phone");
    assert_eq!(patterns[2].action, DlpAction::Block);

    assert!(config.dlp.scan_responses);
}
/// A config with no `[dlp]` table at all still parses, and
/// `dlp.scan_responses` comes back `true`.
#[tokio::test]
async fn test_config_dlp_scan_responses_default() {
    let config = Config::parse(
        r#"
[server]
[upstream]
"#,
    )
    .unwrap();
    assert!(config.dlp.scan_responses);
}
/// An explicit `scan_responses = false` under `[dlp]` is honoured by the parser.
#[tokio::test]
async fn test_config_dlp_scan_responses_disabled() {
    let config = Config::parse(
        r#"
[server]
[upstream]
[dlp]
scan_responses = false
"#,
    )
    .unwrap();
    let scanning_enabled = config.dlp.scan_responses;
    assert!(!scanning_enabled);
}
/// The client sends a `content-length` matching its *original* body. The
/// email address inside that body is expected to be rewritten to
/// `[REDACTED:email]` by the app from `make_app_with_redact`.
///
/// The mock only matches a request that carries the real key and whose body
/// contains the redaction marker. A 200 therefore shows the rewritten body was
/// forwarded intact, rather than truncated or rejected because of the stale
/// length.
#[tokio::test]
async fn test_redacted_body_content_length_not_stale() {
    let prompt =
        r#"{"messages":[{"role":"user","content":"Contact me at user@example.com please"}]}"#;

    let upstream = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .and(header("authorization", "Bearer sk-real-1"))
        .and(body_string_contains("[REDACTED:email]"))
        .respond_with(
            ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "chatcmpl-ok"})),
        )
        .mount(&upstream)
        .await;

    // Length of the pre-redaction body; the proxy must not forward it as-is.
    let stale_length = prompt.len().to_string();
    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .header("content-length", stale_length)
        .body(Body::from(prompt))
        .unwrap();
    let response = make_app_with_redact(&upstream.uri())
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let raw = response.into_body().collect().await.unwrap().to_bytes();
    let payload: serde_json::Value = serde_json::from_slice(&raw).unwrap();
    assert_eq!(payload["id"], "chatcmpl-ok");
}
/// Without a `[dlp]` table, the parsed config carries no DLP patterns.
#[tokio::test]
async fn test_default_patterns_empty() {
    let config = Config::parse(
        r#"
[server]
[upstream]
"#,
    )
    .unwrap();
    let patterns = &config.dlp.patterns;
    assert!(patterns.is_empty());
}
/// A `TRACE` request carrying a valid virtual key is answered with
/// 405 Method Not Allowed.
#[tokio::test]
async fn test_unsupported_method_returns_405() {
    let upstream = MockServer::start().await;
    let request = Request::builder()
        .method("TRACE")
        .uri("/v1/models")
        .header("authorization", "Bearer vk-test-1")
        .body(Body::empty())
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
}
/// `GET /` without an `authorization` header is rejected with 401.
#[tokio::test]
async fn test_root_route_requires_auth() {
    let upstream = MockServer::start().await;
    let anonymous = Request::builder().uri("/").body(Body::empty()).unwrap();
    let response = make_app(&upstream.uri()).oneshot(anonymous).await.unwrap();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
/// `GET /` with a known virtual key is proxied to the upstream's `/`, and the
/// upstream's 200 is relayed.
#[tokio::test]
async fn test_root_route_with_auth() {
    let upstream = MockServer::start().await;
    let ok_reply = ResponseTemplate::new(200).set_body_json(serde_json::json!({"ok": true}));
    Mock::given(method("GET"))
        .and(path("/"))
        .respond_with(ok_reply)
        .mount(&upstream)
        .await;

    let request = Request::builder()
        .uri("/")
        .header("authorization", "Bearer vk-test-1")
        .body(Body::empty())
        .unwrap();
    let response = make_app(&upstream.uri()).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// When `[upstream]` sets no `anthropic_version`, the parsed value is
/// `"2023-06-01"`.
#[tokio::test]
async fn test_config_anthropic_version_default() {
    let config = Config::parse(
        r#"
[server]
[upstream]
"#,
    )
    .unwrap();
    let version = &config.upstream.anthropic_version;
    assert_eq!(version, "2023-06-01");
}
/// An explicit `anthropic_version` under `[upstream]` replaces the default.
#[tokio::test]
async fn test_config_anthropic_version_custom() {
    let config = Config::parse(
        r#"
[server]
[upstream]
anthropic_version = "2024-01-01"
"#,
    )
    .unwrap();
    let version = &config.upstream.anthropic_version;
    assert_eq!(version, "2024-01-01");
}
/// An SSE (`text/event-stream`) reply from upstream is relayed through the
/// default app: the client sees the event-stream content type, the event
/// payload, and the terminating `[DONE]` event.
///
/// The upstream body is registered with `set_body_raw(.., "text/event-stream")`.
/// `set_body_string` would impose its own `text/plain` MIME, so the mock would
/// not advertise a stream even with an extra `content-type` header appended.
#[tokio::test]
async fn test_streaming_response_dlp_bypass_passes_through() {
    let mock_server = MockServer::start().await;
    let sse_body = "data: {\"content\":\"secret-token-xyz\"}\n\ndata: [DONE]\n\n";
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_raw(sse_body, "text/event-stream"))
        .mount(&mock_server)
        .await;
    let app = make_app(&mock_server.uri());
    let body = r#"{"model":"gpt-4","stream":true,"messages":[{"role":"user","content":"Hi"}]}"#;
    let req = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(body))
        .unwrap();
    let resp = app.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    // The stream's content type must survive the proxy, otherwise clients
    // would not treat the body as SSE.
    let ct = resp
        .headers()
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("");
    assert!(
        ct.contains("text/event-stream"),
        "Expected text/event-stream content-type, got: {}",
        ct
    );
    let body = resp.into_body().collect().await.unwrap().to_bytes();
    let body_str = std::str::from_utf8(&body).unwrap();
    assert!(body_str.contains("secret-token-xyz"));
    assert!(body_str.contains("[DONE]"));
}
/// A body-less `GET /v1/models` goes through the DLP-configured app and the
/// upstream's 200 is relayed.
#[tokio::test]
async fn test_empty_body_passes_dlp() {
    let upstream = MockServer::start().await;
    let models = ResponseTemplate::new(200).set_body_json(serde_json::json!({"data": []}));
    Mock::given(method("GET"))
        .and(path("/v1/models"))
        .respond_with(models)
        .mount(&upstream)
        .await;

    let request = Request::builder()
        .method("GET")
        .uri("/v1/models")
        .header("authorization", "Bearer vk-test-1")
        .body(Body::empty())
        .unwrap();
    let status = make_app(&upstream.uri())
        .oneshot(request)
        .await
        .unwrap()
        .status();
    assert_eq!(status, StatusCode::OK);
}
/// An upstream JSON reply containing an email address is expected to be
/// redacted by the app from `make_app_with_redact`. If the proxy emits a
/// `content-length` header, it must describe the *redacted* body, not the
/// upstream's original one.
///
/// The upstream body is registered with `set_body_raw(.., "application/json")`.
/// `set_body_string` would impose its own `text/plain` MIME, so the mock would
/// not return the JSON content type this test means to cover.
#[tokio::test]
async fn test_response_content_length_updated_after_redaction() {
    let mock_server = MockServer::start().await;
    let original_body = r#"{"content":"Contact user@example.com for info"}"#;
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_raw(original_body, "application/json"))
        .mount(&mock_server)
        .await;
    let app = make_app_with_redact(&mock_server.uri());
    let body = r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}"#;
    let req = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(body))
        .unwrap();
    let resp = app.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    // Read the header before consuming the response body.
    let content_length = resp
        .headers()
        .get("content-length")
        .map(|v| v.to_str().unwrap().parse::<usize>().unwrap());
    let body = resp.into_body().collect().await.unwrap().to_bytes();
    let body_str = std::str::from_utf8(&body).unwrap();
    assert!(body_str.contains("[REDACTED:email]"));
    // The marker alone is not enough: the original address must be gone.
    assert!(!body_str.contains("user@example.com"));
    // The proxy may legitimately drop the header (e.g. chunked transfer); only
    // check it when present.
    if let Some(cl) = content_length {
        assert_eq!(
            cl,
            body.len(),
            "content-length header must match actual body size after redaction"
        );
    }
}
#[tokio::test]
async fn test_non_utf8_body_passes_through() {
let mock_server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/audio/transcriptions"))
.respond_with(
ResponseTemplate::new(200)
.set_body_json(serde_json::json!({"text": "hello world"})),
)
.mount(&mock_server)
.await;
let app = make_app(&mock_server.uri());
let binary_body: Bytes = Bytes::from(vec![0xFF, 0xFE, 0x00, 0x01, 0x80, 0x81]);
let req = Request::builder()
.method("POST")
.uri("/v1/audio/transcriptions")
.header("authorization", "Bearer vk-test-1")
.header("content-type", "multipart/form-data")
.body(Body::from(binary_body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
/// Using the app from `make_app_with_redact`, an SSE reply from upstream still
/// passes through. The client sees a `text/event-stream` content type and the
/// closing `[DONE]` event.
#[tokio::test]
async fn test_streaming_response_with_dlp_enabled_passes_through() {
    let upstream = MockServer::start().await;
    let events = "data: {\"content\":\"hello world\"}\n\ndata: [DONE]\n\n";
    let stream_reply = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(events, "text/event-stream");
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(stream_reply)
        .mount(&upstream)
        .await;

    let request = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("authorization", "Bearer vk-test-1")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4","stream":true,"messages":[{"role":"user","content":"Hi"}]}"#,
        ))
        .unwrap();
    let response = make_app_with_redact(&upstream.uri())
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    // A missing or non-ASCII header both count as "no content type".
    let content_type = match response.headers().get("content-type") {
        Some(value) => value.to_str().unwrap_or(""),
        None => "",
    };
    assert!(
        content_type.contains("text/event-stream"),
        "Expected text/event-stream content-type, got: {}",
        content_type
    );

    let raw = response.into_body().collect().await.unwrap().to_bytes();
    let text = std::str::from_utf8(&raw).unwrap();
    assert!(text.contains("[DONE]"));
}