use ati::core::jwt::{self, JwtConfig};
use axum::body::Body;
use axum::http::{Request, StatusCode};
use tower::ServiceExt;
/// Shared signing secret passed to `jwt::config_from_secret` everywhere in
/// this file: by `multi_aud_config`, `single_aud_config`, and the issuing
/// config inside `issue_test_token`. Issuer and proxy therefore use the same
/// secret, and the tests vary only the audience values. The literal is exactly
/// 32 bytes long.
const SECRET: &[u8] = b"per-provider-test-secret-32bytes";
/// Proxy-side JWT config whose audience allowlist holds two entries: the
/// default `ati-proxy` and the per-provider `parcha-custom-tools`.
fn multi_aud_config() -> JwtConfig {
    let accepted_audiences = vec!["ati-proxy".into(), "parcha-custom-tools".into()];
    jwt::config_from_secret(SECRET, None, accepted_audiences)
}
/// Proxy-side JWT config whose audience allowlist contains only `ati-proxy`.
fn single_aud_config() -> JwtConfig {
    let accepted_audiences = vec!["ati-proxy".into()];
    jwt::config_from_secret(SECRET, None, accepted_audiences)
}
/// Mints a signed test token with audience `aud` and the given `scope`.
///
/// The subject is fixed to `sandbox:e2e-121`. The token expires 300 seconds
/// after `jwt::now_secs()`, carries a freshly generated UUID v4 as `jti`, and
/// has an `ati` namespace at version 1 with an empty rate map. It is signed
/// with a config built from [`SECRET`] whose only audience is `aud`.
///
/// Panics if `jwt::issue` returns an error.
fn issue_test_token(aud: &str, scope: &str) -> String {
    use ati::core::jwt::{AtiNamespace, TokenClaims};
    use std::collections::HashMap;

    let issued_at = jwt::now_secs();
    let claims = TokenClaims {
        sub: "sandbox:e2e-121".into(),
        aud: aud.into(),
        scope: scope.into(),
        iat: issued_at,
        exp: issued_at + 300,
        jti: Some(uuid::Uuid::new_v4().to_string()),
        iss: None,
        job_id: None,
        sandbox_id: None,
        ati: Some(AtiNamespace {
            v: 1,
            rate: HashMap::new(),
        }),
    };

    let signing_config = jwt::config_from_secret(SECRET, None, vec![aud.into()]);
    jwt::issue(&claims, &signing_config).expect("issue token")
}
/// Builds the proxy router around a minimal state that verifies JWTs with
/// `jwt_config`.
///
/// The state is assembled as follows:
/// - One `echo` provider with a `ping` tool (`GET /ping`) is written as
///   `echo.toml` into a temporary `manifests` directory and loaded into a
///   `ManifestRegistry`. The temporary directory is deleted immediately after
///   loading.
/// - The provider's `base_url` is `http://127.0.0.1:9`. NOTE(review):
///   presumably an address where nothing answers, so no real upstream is
///   reached. Confirm.
/// - Skills are loaded from `/nonexistent`, and the call is unwrapped.
/// - The keyring is empty, there is no JWKS, and the upstream URL allowlist
///   map starts empty.
fn build_app(jwt_config: JwtConfig) -> axum::Router {
    use ati::core::auth_generator::AuthCache;
    use ati::core::keyring::Keyring;
    use ati::core::manifest::ManifestRegistry;
    use ati::core::skill::SkillRegistry;
    use ati::proxy::server::{build_router, ProxyState};
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    const ECHO_MANIFEST: &str = r#"
[provider]
name = "echo"
description = "Test echo provider"
base_url = "http://127.0.0.1:9"
[[tools]]
name = "ping"
description = "ping"
endpoint = "/ping"
method = "GET"
"#;

    let scratch = tempfile::tempdir().expect("tempdir");
    let manifests_dir = scratch.path().join("manifests");
    std::fs::create_dir_all(&manifests_dir).expect("manifests dir");
    std::fs::write(manifests_dir.join("echo.toml"), ECHO_MANIFEST).expect("write manifest");
    let registry = ManifestRegistry::load(&manifests_dir).expect("load manifest");
    // The manifest has already been loaded into `registry`, so the scratch
    // directory is removed eagerly instead of at the end of the function.
    drop(scratch);

    let skill_registry = SkillRegistry::load(std::path::Path::new("/nonexistent")).unwrap();

    let state = ProxyState {
        registry,
        skill_registry,
        keyring: Keyring::empty(),
        jwt_config: Some(jwt_config),
        jwks_json: None,
        auth_cache: AuthCache::new(),
        upstream_url_allowlists: Arc::new(Mutex::new(HashMap::new())),
    };
    build_router(Arc::new(state))
}
/// Drives one request through `app` and returns only the response status.
///
/// The request is `POST /call` with a JSON body invoking tool `ping` with
/// empty args, authenticated via `authorization: Bearer {token}`. It is sent
/// with `oneshot`, which consumes the router. Panics if request building,
/// body serialization, or the service call fails.
async fn call_with_token(app: axum::Router, token: &str) -> StatusCode {
    let payload =
        serde_json::to_vec(&serde_json::json!({"tool_name": "ping", "args": {}})).unwrap();
    let bearer = format!("Bearer {token}");

    let request = Request::builder()
        .method("POST")
        .uri("/call")
        .header("content-type", "application/json")
        .header("authorization", bearer)
        .body(Body::from(payload))
        .unwrap();

    let response = app.oneshot(request).await.expect("oneshot");
    response.status()
}
/// A token whose audience is the allowlisted alternate `parcha-custom-tools`
/// must get past both the auth gate (no 401) and the scope gate (no 403).
#[tokio::test]
async fn per_provider_audience_jwt_accepted_when_in_allowlist() {
    let proxy = build_app(multi_aud_config());
    let alt_aud_token = issue_test_token("parcha-custom-tools", "tool:ping");

    let status = call_with_token(proxy, &alt_aud_token).await;

    assert_ne!(
        status,
        StatusCode::UNAUTHORIZED,
        "alt-audience JWT must pass the auth gate when in allowlist"
    );
    assert_ne!(
        status,
        StatusCode::FORBIDDEN,
        "alt-audience JWT with tool:ping scope must pass the scope gate"
    );
}
/// A proxy that only accepts `ati-proxy` must answer 401 for a token minted
/// for `parcha-custom-tools`.
#[tokio::test]
async fn per_provider_audience_jwt_rejected_when_not_in_allowlist() {
    let status = call_with_token(
        build_app(single_aud_config()),
        &issue_test_token("parcha-custom-tools", "tool:ping"),
    )
    .await;

    assert_eq!(
        status,
        StatusCode::UNAUTHORIZED,
        "JWT with aud=parcha-custom-tools must be rejected by a proxy whose \
         accepted_audiences = [\"ati-proxy\"]"
    );
}
/// Adding a second accepted audience must not break the default one: an
/// `ati-proxy` token still clears both the auth gate (no 401) and the scope
/// gate (no 403).
#[tokio::test]
async fn default_audience_jwt_still_accepted_in_multi_aud_mode() {
    let proxy = build_app(multi_aud_config());
    let default_aud_token = issue_test_token("ati-proxy", "tool:ping");

    let status = call_with_token(proxy, &default_aud_token).await;

    assert_ne!(
        status,
        StatusCode::UNAUTHORIZED,
        "ati-proxy-aud JWT must still pass when allowlist contains it"
    );
    assert_ne!(
        status,
        StatusCode::FORBIDDEN,
        "ati-proxy-aud JWT with tool:ping scope must pass the scope gate"
    );
}
/// With several audiences configured, a token for an audience outside the
/// allowlist (`evil-aud`) must still be rejected with 401.
#[tokio::test]
async fn unknown_audience_rejected_under_multi_aud() {
    let status = call_with_token(
        build_app(multi_aud_config()),
        &issue_test_token("evil-aud", "tool:ping"),
    )
    .await;

    assert_eq!(
        status,
        StatusCode::UNAUTHORIZED,
        "aud=evil-aud must be rejected; multi-audience mode is an allowlist, not a bypass"
    );
}
/// Checks that `auth_session_token_env` declared in a provider's TOML survives
/// loading into `ManifestRegistry`. The tool, declared with the name
/// `parcha_custom_tools:recall`, must resolve via `get_tool` under that same
/// name.
#[test]
fn manifest_auth_session_token_env_field_round_trip() {
    use ati::core::manifest::ManifestRegistry;

    let manifest_toml = r#"
[provider]
name = "parcha_custom_tools"
description = "MCP that needs a per-audience JWT"
base_url = ""
handler = "mcp"
auth_type = "bearer"
auth_session_token_env = "PARCHA_TOOLS_SESSION_TOKEN"
[[tools]]
name = "parcha_custom_tools:recall"
description = "stub"
endpoint = "/recall"
method = "GET"
"#;

    // Keep the temp dir alive for the whole test.
    let scratch = tempfile::tempdir().expect("tempdir");
    let manifests_dir = scratch.path().join("manifests");
    std::fs::create_dir_all(&manifests_dir).expect("manifests dir");
    std::fs::write(manifests_dir.join("p.toml"), manifest_toml).expect("write");

    let registry = ManifestRegistry::load(&manifests_dir).expect("load manifest");
    let (provider, _tool) = registry
        .get_tool("parcha_custom_tools:recall")
        .expect("tool resolves under provider:tool form");

    let token_env = provider.auth_session_token_env.as_deref();
    assert_eq!(
        token_env,
        Some("PARCHA_TOOLS_SESSION_TOKEN"),
        "manifest field must survive TOML → Provider deserialization"
    );
}
/// Backwards-compatibility check: a provider whose TOML omits
/// `auth_session_token_env` must load with that field set to `None`.
#[test]
fn manifest_without_auth_session_token_env_defaults_to_none() {
    use ati::core::manifest::ManifestRegistry;

    let manifest_toml = r#"
[provider]
name = "echo"
description = "no per-provider token"
base_url = "http://example.test"
[[tools]]
name = "ping"
description = "stub"
endpoint = "/ping"
method = "GET"
"#;

    // Keep the temp dir alive for the whole test.
    let scratch = tempfile::tempdir().expect("tempdir");
    let manifests_dir = scratch.path().join("manifests");
    std::fs::create_dir_all(&manifests_dir).expect("manifests dir");
    std::fs::write(manifests_dir.join("p.toml"), manifest_toml).expect("write");

    let registry = ManifestRegistry::load(&manifests_dir).expect("load manifest");
    let (provider, _) = registry.get_tool("ping").expect("tool resolves");

    assert_eq!(
        provider.auth_session_token_env, None,
        "absent field must default to None — backwards compat guarantee"
    );
}