use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::net::Ipv4Addr;
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::{
api_tokens::ApiTokenGenerator,
models::{ApiTokenModel, ServiceModel},
proxy::AuthenticatedProxy,
test_client::TestClient,
types::{
ChallengeRequest, ChallengeResponse, KeyType, ServiceId, VerifyChallengeResponse, headers,
},
validation::hash_user_id,
};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
struct ServiceRequest {
tenant_id: String,
tenant_name: String,
user_tier: String,
request_path: String,
}
#[tokio::test]
async fn multi_tenant_service_isolation() {
let mut rng = blueprint_std::BlueprintRng::new();
let tmp = tempfile::tempdir().unwrap();
let proxy = AuthenticatedProxy::new(tmp.path()).unwrap();
let requests = Arc::new(Mutex::new(Vec::<ServiceRequest>::new()));
let requests_clone = requests.clone();
let multi_tenant_router = axum::Router::new()
.route(
"/api/{user}/data",
axum::routing::post(
|headers: axum::http::HeaderMap,
axum::extract::Path(user): axum::extract::Path<String>,
axum::extract::State(requests): axum::extract::State<
Arc<Mutex<Vec<ServiceRequest>>>,
>| async move {
let tenant_id = headers
.get("x-tenant-id")
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown")
.to_string();
let tenant_name = headers
.get("x-tenant-name")
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown")
.to_string();
let user_tier = headers
.get("x-user-tier")
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown")
.to_string();
let request = ServiceRequest {
tenant_id,
tenant_name,
user_tier,
request_path: format!("/api/{user}/data"),
};
requests.lock().await.push(request.clone());
axum::Json(request)
},
),
)
.with_state(requests_clone.clone());
let (service_handle, service_addr) = {
let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0))
.await
.expect("Failed to bind");
let server = axum::serve(listener, multi_tenant_router);
let addr = server.local_addr().unwrap();
let handle = tokio::spawn(async move {
if let Err(e) = server.await {
eprintln!("Multi-tenant service error: {e}");
}
});
(handle, addr)
};
let service_id = ServiceId::new(100);
let mut service = ServiceModel {
api_key_prefix: "mt_".to_string(),
owners: Vec::new(),
upstream_url: format!("http://localhost:{}", service_addr.port()),
tls_profile: None,
};
let tenant_keys = vec![
("alice@acme.com", "Acme Corp", "enterprise"),
("bob@widgets.io", "Widgets Inc", "startup"),
("carol@megacorp.net", "MegaCorp", "enterprise"),
("dave@indie.dev", "Indie Dev", "free"),
];
let mut signing_keys = Vec::new();
for &(email, _, _) in &tenant_keys {
let signing_key = k256::ecdsa::SigningKey::random(&mut rng);
let public_key = signing_key.verifying_key().to_sec1_bytes();
service.add_owner(KeyType::Ecdsa, public_key.into());
signing_keys.push((email.to_string(), signing_key));
}
service.save(service_id, &proxy.db()).unwrap();
let router = proxy.router();
let client = TestClient::new(router);
let mut tenant_tokens = Vec::new();
for (&(email, company, tier), (_, signing_key)) in tenant_keys.iter().zip(signing_keys.iter()) {
let public_key = signing_key.verifying_key().to_sec1_bytes();
let req = ChallengeRequest {
pub_key: public_key.into(),
key_type: KeyType::Ecdsa,
};
let res = client
.post("/v1/auth/challenge")
.header(headers::X_SERVICE_ID, service_id.to_string())
.json(&req)
.await;
let challenge_res: ChallengeResponse = res.json().await;
let (signature, _) = signing_key
.sign_prehash_recoverable(&challenge_res.challenge)
.unwrap();
let tenant_id = hash_user_id(email);
let mut additional_headers = BTreeMap::new();
additional_headers.insert("X-Tenant-Id".to_string(), tenant_id.clone());
additional_headers.insert("X-Tenant-Name".to_string(), company.to_string());
additional_headers.insert("X-User-Tier".to_string(), tier.to_string());
let verify_req = crate::types::VerifyChallengeRequest {
challenge: challenge_res.challenge,
signature: signature.to_bytes().into(),
challenge_request: req,
expires_at: 0,
additional_headers,
};
let res = client
.post("/v1/auth/verify")
.header(headers::X_SERVICE_ID, service_id.to_string())
.json(&verify_req)
.await;
let verify_res: VerifyChallengeResponse = res.json().await;
let api_key = match verify_res {
VerifyChallengeResponse::Verified { api_key, .. } => api_key,
_ => panic!("Failed to verify tenant {email}"),
};
tenant_tokens.push((
email,
tenant_id,
company.to_string(),
tier.to_string(),
api_key,
));
}
for (email, tenant_id, company, tier, token) in &tenant_tokens {
let res = client
.post(&format!("/api/{}/data", email.replace(['@', '.'], "_")))
.header(headers::AUTHORIZATION, format!("Bearer {token}"))
.await;
if !res.status().is_success() {
eprintln!(
"Request failed for tenant {} with status {}",
email,
res.status()
);
eprintln!("Token: {token}");
eprintln!("Path: /api/{}/data", email.replace(['@', '.'], "_"));
}
assert!(
res.status().is_success(),
"Request failed for tenant {email}"
);
let response: ServiceRequest = res.json().await;
assert_eq!(response.tenant_id, *tenant_id);
assert_eq!(response.tenant_name, *company);
assert_eq!(response.user_tier, *tier);
}
let all_requests = requests.lock().await;
assert_eq!(
all_requests.len(),
4,
"Should have 4 requests from 4 tenants"
);
let unique_tenants: std::collections::HashSet<_> =
all_requests.iter().map(|r| r.tenant_id.clone()).collect();
assert_eq!(unique_tenants.len(), 4, "Should have 4 unique tenant IDs");
for req in all_requests.iter() {
assert!(
!req.tenant_id.contains('@'),
"Tenant ID should be hashed, not raw email"
);
assert_eq!(
req.tenant_id.len(),
32,
"Tenant ID should be 32 chars (16 bytes hex)"
);
}
service_handle.abort();
}
#[tokio::test]
async fn tenant_token_cannot_impersonate_other_tenant() {
let mut rng = blueprint_std::BlueprintRng::new();
let tmp = tempfile::tempdir().unwrap();
let proxy = AuthenticatedProxy::new(tmp.path()).unwrap();
let service_id = ServiceId::new(200);
let mut service = ServiceModel {
api_key_prefix: "sec_".to_string(),
owners: Vec::new(),
upstream_url: "http://localhost:9999".to_string(),
tls_profile: None,
};
let alice_key = k256::ecdsa::SigningKey::random(&mut rng);
let bob_key = k256::ecdsa::SigningKey::random(&mut rng);
service.add_owner(
KeyType::Ecdsa,
alice_key.verifying_key().to_sec1_bytes().into(),
);
service.add_owner(
KeyType::Ecdsa,
bob_key.verifying_key().to_sec1_bytes().into(),
);
service.save(service_id, &proxy.db()).unwrap();
let router = proxy.router();
let client = TestClient::new(router);
let alice_public = alice_key.verifying_key().to_sec1_bytes();
let req = ChallengeRequest {
pub_key: alice_public.into(),
key_type: KeyType::Ecdsa,
};
let res = client
.post("/v1/auth/challenge")
.header(headers::X_SERVICE_ID, service_id.to_string())
.json(&req)
.await;
let challenge: ChallengeResponse = res.json().await;
let (alice_sig, _) = alice_key
.sign_prehash_recoverable(&challenge.challenge)
.unwrap();
let alice_tenant_id = hash_user_id("alice@example.com");
let mut alice_headers = BTreeMap::new();
alice_headers.insert("X-Tenant-Id".to_string(), alice_tenant_id.clone());
let verify_req = crate::types::VerifyChallengeRequest {
challenge: challenge.challenge,
signature: alice_sig.to_bytes().into(),
challenge_request: req,
expires_at: 0,
additional_headers: alice_headers,
};
let res = client
.post("/v1/auth/verify")
.header(headers::X_SERVICE_ID, service_id.to_string())
.json(&verify_req)
.await;
let verify_res: VerifyChallengeResponse = res.json().await;
assert!(matches!(
verify_res,
VerifyChallengeResponse::Verified { .. }
));
let bob_public = bob_key.verifying_key().to_sec1_bytes();
let req = ChallengeRequest {
pub_key: bob_public.into(),
key_type: KeyType::Ecdsa,
};
let res = client
.post("/v1/auth/challenge")
.header(headers::X_SERVICE_ID, service_id.to_string())
.json(&req)
.await;
let challenge: ChallengeResponse = res.json().await;
let (bob_sig, _) = bob_key
.sign_prehash_recoverable(&challenge.challenge)
.unwrap();
let mut bob_headers = BTreeMap::new();
bob_headers.insert("X-Tenant-Id".to_string(), alice_tenant_id.clone());
let verify_req = crate::types::VerifyChallengeRequest {
challenge: challenge.challenge,
signature: bob_sig.to_bytes().into(),
challenge_request: req,
expires_at: 0,
additional_headers: bob_headers,
};
let res = client
.post("/v1/auth/verify")
.header(headers::X_SERVICE_ID, service_id.to_string())
.json(&verify_req)
.await;
let verify_res: VerifyChallengeResponse = res.json().await;
assert!(matches!(
verify_res,
VerifyChallengeResponse::Verified { .. }
));
}
#[tokio::test]
async fn tenant_rate_limiting_by_tier() {
let mut rng = blueprint_std::BlueprintRng::new();
let tmp = tempfile::tempdir().unwrap();
let proxy = AuthenticatedProxy::new(tmp.path()).unwrap();
let request_counts = Arc::new(Mutex::new(std::collections::HashMap::<String, u32>::new()));
let counts_clone = request_counts.clone();
let rate_limit_router = axum::Router::new().route(
"/api/limited",
axum::routing::get({
let counts = counts_clone;
move |headers: axum::http::HeaderMap| {
let counts = counts.clone();
async move {
let tier = headers
.get("x-user-tier")
.and_then(|v| v.to_str().ok())
.unwrap_or("free");
let tenant_id = headers
.get("x-tenant-id")
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown");
let mut counts = counts.lock().await;
let count = counts.entry(tenant_id.to_string()).or_insert(0);
*count += 1;
let limit = match tier {
"enterprise" => 1000,
"startup" => 100,
"free" => 10,
_ => 5,
};
if *count > limit {
axum::response::Response::builder()
.status(429)
.body(axum::body::Body::from("Rate limit exceeded"))
.unwrap()
} else {
axum::response::Response::builder()
.status(200)
.body(axum::body::Body::from(format!(
"Request {count} of {limit}"
)))
.unwrap()
}
}
}
}),
);
let (service_handle, service_addr) = {
let listener = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0))
.await
.expect("Failed to bind");
let server = axum::serve(listener, rate_limit_router);
let addr = server.local_addr().unwrap();
let handle = tokio::spawn(async move {
if let Err(e) = server.await {
eprintln!("Rate limit service error: {e}");
}
});
(handle, addr)
};
let service_id = ServiceId::new(300);
let mut service = ServiceModel {
api_key_prefix: "rl_".to_string(),
owners: Vec::new(),
upstream_url: format!("http://localhost:{}", service_addr.port()),
tls_profile: None,
};
let free_user_key = k256::ecdsa::SigningKey::random(&mut rng);
service.add_owner(
KeyType::Ecdsa,
free_user_key.verifying_key().to_sec1_bytes().into(),
);
service.save(service_id, &proxy.db()).unwrap();
let router = proxy.router();
let client = TestClient::new(router);
let public_key = free_user_key.verifying_key().to_sec1_bytes();
let req = ChallengeRequest {
pub_key: public_key.into(),
key_type: KeyType::Ecdsa,
};
let res = client
.post("/v1/auth/challenge")
.header(headers::X_SERVICE_ID, service_id.to_string())
.json(&req)
.await;
let challenge: ChallengeResponse = res.json().await;
let (signature, _) = free_user_key
.sign_prehash_recoverable(&challenge.challenge)
.unwrap();
let tenant_id = hash_user_id("freeuser@example.com");
let mut headers = BTreeMap::new();
headers.insert("X-Tenant-Id".to_string(), tenant_id);
headers.insert("X-User-Tier".to_string(), "free".to_string());
let verify_req = crate::types::VerifyChallengeRequest {
challenge: challenge.challenge,
signature: signature.to_bytes().into(),
challenge_request: req,
expires_at: 0,
additional_headers: headers,
};
let res = client
.post("/v1/auth/verify")
.header(headers::X_SERVICE_ID, service_id.to_string())
.json(&verify_req)
.await;
let verify_res: VerifyChallengeResponse = res.json().await;
let token = match verify_res {
VerifyChallengeResponse::Verified { api_key, .. } => api_key,
_ => panic!("Failed to get token"),
};
for i in 1..=10 {
let res = client
.get("/api/limited")
.header(headers::AUTHORIZATION, format!("Bearer {token}"))
.await;
let status = res.status();
if status != 200 {
let body = res.text().await;
panic!("Request {i} failed with status {status} and body: {body:?} (token: {token})");
}
}
let res = client
.get("/api/limited")
.header(headers::AUTHORIZATION, format!("Bearer {token}"))
.await;
assert_eq!(
res.status(),
429,
"Should be rate limited after 10 requests"
);
service_handle.abort();
}
#[tokio::test]
async fn tenant_data_isolation_verification() {
let mut rng = blueprint_std::BlueprintRng::new();
let tmp = tempfile::tempdir().unwrap();
let proxy = AuthenticatedProxy::new(tmp.path()).unwrap();
let db = proxy.db();
let tenants = vec![
("company_a", "user1@company-a.com"),
("company_b", "user2@company-b.com"),
("company_c", "admin@company-c.org"),
];
let service_id = ServiceId::new(400);
let mut service = ServiceModel {
api_key_prefix: "iso_".to_string(),
owners: Vec::new(),
upstream_url: "http://localhost:8080".to_string(),
tls_profile: None,
};
let mut tenant_data = Vec::new();
for (company, email) in tenants {
let signing_key = k256::ecdsa::SigningKey::random(&mut rng);
let public_key = signing_key.verifying_key().to_sec1_bytes();
service.add_owner(KeyType::Ecdsa, public_key.clone().into());
let tenant_id = hash_user_id(email);
let mut headers = BTreeMap::new();
headers.insert("X-Tenant-Id".to_string(), tenant_id.clone());
headers.insert("X-Company".to_string(), company.to_string());
let token_gen = ApiTokenGenerator::with_prefix(&service.api_key_prefix);
let token = token_gen.generate_token_with_expiration_and_headers(
service_id,
0,
headers.clone(),
&mut rng,
);
let mut token_model = ApiTokenModel::from(&token);
let token_id = token_model.save(&db).unwrap();
tenant_data.push((company, email, tenant_id, token_id, headers));
}
service.save(service_id, &db).unwrap();
for (company, email, expected_tenant_id, token_id, expected_headers) in tenant_data {
let token_model = ApiTokenModel::find_token_id(token_id, &db)
.unwrap()
.expect("Token should exist");
let headers = token_model.get_additional_headers();
assert_eq!(
headers.get("X-Tenant-Id"),
Some(&expected_tenant_id),
"Token for {email} should have correct tenant ID"
);
assert_eq!(
headers.get("X-Company"),
Some(&company.to_string()),
"Token for {email} should have correct company"
);
assert_eq!(headers, expected_headers);
assert_ne!(expected_tenant_id, email, "Tenant ID should be hashed");
assert_eq!(expected_tenant_id.len(), 32, "Tenant ID should be 32 chars");
}
}