use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::{
extract::{FromRequestParts, Request, State},
http::{HeaderMap, Method, StatusCode, request::Parts},
middleware::Next,
response::{IntoResponse, Response},
};
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use jsonwebtoken::{Algorithm, DecodingKey, Validation};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use tokio::sync::RwLock;
use crate::config::GatewayOidcConfig;
use crate::db::Database;
/// The authenticated principal that `auth_middleware` attaches to a request.
///
/// Handlers obtain it through the [`AuthenticatedUser`] / [`AdminUser`] extractors, which
/// read it back out of the request extensions.
#[derive(Debug, Clone)]
pub struct UserIdentity {
    /// User identifier. This is the configured user id for env tokens, the database user id
    /// for DB tokens, or the JWT `sub` claim for OIDC.
    pub user_id: String,
    /// Role name. [`AdminUser`] requires exactly `"admin"`. Single-token env users are
    /// `"admin"`, OIDC users are `"member"`, and DB users carry their stored role.
    pub role: String,
    /// Workspace names this identity is granted read access to. Every constructor in this
    /// file leaves it empty except identities supplied to `MultiAuthState::multi`.
    /// NOTE(review): exact semantics are defined by the consumers; confirm there.
    pub workspace_read_scopes: Vec<String>,
}
/// Returns the SHA-256 digest of `token`'s UTF-8 bytes.
///
/// This digest is the comparison key for the env token table and the lookup key for the
/// database token cache.
pub fn hash_token(token: &str) -> [u8; 32] {
    Sha256::digest(token.as_bytes()).into()
}
/// Bearer tokens known at startup, each mapped to the identity it authenticates.
///
/// Only SHA-256 digests of the tokens are used for comparison. The plaintext token is kept
/// solely in single-token mode, so it can be handed back via [`MultiAuthState::first_token`].
#[derive(Clone)]
pub struct MultiAuthState {
    hashed_tokens: Vec<([u8; 32], UserIdentity)>,
    display_token: Option<String>,
}

impl MultiAuthState {
    /// One token granting `user_id` the `"admin"` role with no extra read scopes.
    pub fn single(token: String, user_id: String) -> Self {
        let admin = UserIdentity {
            user_id,
            role: "admin".to_string(),
            workspace_read_scopes: Vec::new(),
        };
        Self {
            hashed_tokens: vec![(hash_token(&token), admin)],
            display_token: Some(token),
        }
    }

    /// Several tokens, each with its own identity. No plaintext token is retained.
    pub fn multi(tokens: HashMap<String, UserIdentity>) -> Self {
        let mut hashed_tokens = Vec::with_capacity(tokens.len());
        for (token, identity) in tokens {
            hashed_tokens.push((hash_token(&token), identity));
        }
        Self {
            hashed_tokens,
            display_token: None,
        }
    }

    /// Returns the identity whose token hashes to the same digest as `candidate`.
    ///
    /// Every entry is compared with a constant-time equality check and the scan never exits
    /// early. If several entries match, the last one wins.
    pub fn authenticate(&self, candidate: &str) -> Option<&UserIdentity> {
        let digest = hash_token(candidate);
        self.hashed_tokens
            .iter()
            .fold(None, |found, (stored, identity)| {
                if bool::from(digest.ct_eq(stored)) {
                    Some(identity)
                } else {
                    found
                }
            })
    }

    /// The plaintext token, available only for instances built with [`MultiAuthState::single`].
    pub fn first_token(&self) -> Option<&str> {
        self.display_token.as_deref()
    }

    /// The first stored identity.
    ///
    /// For instances built by `multi`, the order comes from the source `HashMap`, so which
    /// identity is "first" is unspecified.
    pub fn first_identity(&self) -> Option<&UserIdentity> {
        self.hashed_tokens.first().map(|entry| &entry.1)
    }
}
/// Authenticates bearer tokens stored in the database.
///
/// Results are kept in a small in-memory LRU cache keyed by the token's SHA-256 digest.
/// Cached identities are served for up to `CACHE_TTL_SECS` seconds, or until
/// [`DbAuthenticator::invalidate_user`] evicts them.
#[derive(Clone)]
// The cache type is written out once here; an alias would only move it elsewhere.
#[allow(clippy::type_complexity)]
pub struct DbAuthenticator {
    store: Arc<dyn Database>,
    /// token digest -> (resolved identity, time the entry was inserted).
    cache: Arc<RwLock<lru::LruCache<[u8; 32], (UserIdentity, Instant)>>>,
    /// Bumped by `invalidate_user` while it holds the cache write lock.
    ///
    /// `authenticate` snapshots it before reading the database and only caches its result
    /// if it is unchanged. Without this, a lookup that read a user's old record and finished
    /// after an invalidation would re-insert the stale identity (e.g. a revoked admin role)
    /// for a full TTL.
    generation: Arc<std::sync::atomic::AtomicU64>,
}

impl DbAuthenticator {
    const CACHE_TTL_SECS: u64 = 60;
    const MAX_CACHE_ENTRIES: NonZeroUsize = match NonZeroUsize::new(1024) {
        Some(v) => v,
        None => unreachable!(),
    };

    /// Creates an authenticator backed by `store` with an empty cache.
    pub fn new(store: Arc<dyn Database>) -> Self {
        Self {
            store,
            cache: Arc::new(RwLock::new(lru::LruCache::new(Self::MAX_CACHE_ENTRIES))),
            generation: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

    /// Evicts every cached entry belonging to `user_id`, so that user's next request is
    /// re-checked against the database.
    pub async fn invalidate_user(&self, user_id: &str) {
        let mut cache = self.cache.write().await;
        // Bump while holding the write lock. Any in-flight `authenticate` that snapshotted the
        // previous value will see the change when it takes this lock to insert, and will skip
        // caching its possibly stale result.
        self.generation
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let keys_to_remove: Vec<[u8; 32]> = cache
            .iter()
            .filter(|(_, (identity, _))| identity.user_id == user_id)
            .map(|(k, _)| *k)
            .collect();
        for key in keys_to_remove {
            cache.pop(&key);
        }
    }

    /// Resolves `candidate` to a user identity.
    ///
    /// Returns `Ok(None)` when the token is unknown. Returns `Err(())` when the database
    /// lookup fails (the error is logged).
    ///
    /// On a database hit, token usage and login are recorded in a detached background task.
    /// Those writes are best-effort, and their results are ignored.
    pub async fn authenticate(&self, candidate: &str) -> Result<Option<UserIdentity>, ()> {
        let hash = hash_token(candidate);
        {
            // A write lock is needed even for lookups because `LruCache::get` updates recency.
            let mut cache = self.cache.write().await;
            if let Some((identity, inserted_at)) = cache.get(&hash) {
                if inserted_at.elapsed().as_secs() < Self::CACHE_TTL_SECS {
                    return Ok(Some(identity.clone()));
                }
                cache.pop(&hash);
            }
        }
        // Snapshot before the database read so a concurrent invalidation is detectable below.
        let generation_before = self
            .generation
            .load(std::sync::atomic::Ordering::SeqCst);
        let (token_record, user_record) = match self.store.authenticate_token(&hash).await {
            Ok(Some(pair)) => pair,
            Ok(None) => return Ok(None),
            Err(e) => {
                tracing::warn!("DB auth lookup failed: {e}");
                return Err(());
            }
        };
        let identity = UserIdentity {
            user_id: user_record.id.clone(),
            role: user_record.role.clone(),
            workspace_read_scopes: Vec::new(),
        };
        let store = self.store.clone();
        let token_id = token_record.id;
        let user_id = user_record.id;
        tokio::spawn(async move {
            // Best-effort bookkeeping: a failure here must not affect authentication.
            let _ = store.record_token_usage(token_id).await;
            let _ = store.record_login(&user_id).await;
        });
        {
            let mut cache = self.cache.write().await;
            // Skip caching if an invalidation ran while we were reading the database.
            let unchanged = self
                .generation
                .load(std::sync::atomic::Ordering::SeqCst)
                == generation_before;
            if unchanged {
                cache.put(hash, (identity.clone(), Instant::now()));
            }
        }
        Ok(Some(identity))
    }
}
/// All authentication backends consulted by `auth_middleware`.
///
/// The `env_auth` token table is always present. The database and OIDC backends are
/// optional.
#[derive(Clone)]
pub struct CombinedAuthState {
    /// Token table checked first for bearer tokens.
    pub env_auth: MultiAuthState,
    /// Database-backed tokens, checked when `env_auth` has no match.
    pub db_auth: Option<DbAuthenticator>,
    /// OIDC JWT validation, tried when no bearer token authenticated the request.
    pub oidc: Option<OidcState>,
}

impl From<MultiAuthState> for CombinedAuthState {
    /// Uses only the given token table; database and OIDC authentication are disabled.
    fn from(env_auth: MultiAuthState) -> Self {
        Self {
            oidc: None,
            db_auth: None,
            env_auth,
        }
    }
}
/// Extractor yielding the [`UserIdentity`] stored in the request extensions by
/// `auth_middleware`.
///
/// Rejects with `401 Not authenticated` when no identity is present.
pub struct AuthenticatedUser(pub UserIdentity);

impl<S> FromRequestParts<S> for AuthenticatedUser
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        match parts.extensions.get::<UserIdentity>() {
            Some(identity) => Ok(Self(identity.clone())),
            None => Err((StatusCode::UNAUTHORIZED, "Not authenticated")),
        }
    }
}
/// Extractor that succeeds only for identities whose role is exactly `"admin"`.
///
/// Rejects with `401 Not authenticated` when no identity is present, and with
/// `403 Admin role required` for any other role.
pub struct AdminUser(pub UserIdentity);

impl<S> FromRequestParts<S> for AdminUser
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let Some(identity) = parts.extensions.get::<UserIdentity>() else {
            return Err((StatusCode::UNAUTHORIZED, "Not authenticated"));
        };
        if identity.role == "admin" {
            Ok(Self(identity.clone()))
        } else {
            Err((StatusCode::FORBIDDEN, "Admin role required"))
        }
    }
}
/// A verification key cached under its `kid`.
#[derive(Clone)]
struct CachedKey {
    /// Key handed to signature verification.
    decoding_key: DecodingKey,
    /// Algorithm this key was loaded for; returned alongside the key on cache hits.
    algorithm: Algorithm,
    /// When the key was fetched (or seeded in tests). Entries older than `KEY_CACHE_TTL`
    /// are treated as missing.
    fetched_at: Instant,
}
/// Records that fetching the key for a `kid` failed, so retries can be throttled.
#[derive(Clone)]
struct FailedFetch {
    failed_at: Instant,
}
/// After a failed key fetch for a `kid`, new fetches for that `kid` are refused for this long.
const FETCH_FAILURE_BACKOFF: Duration = Duration::from_secs(10);
/// Validates OIDC JWTs carried in a configurable request header.
///
/// Verification keys are fetched from `config.jwks_url` and cached. The caches sit behind
/// `Arc`, so clones of this state share them.
#[derive(Clone)]
pub struct OidcState {
    /// Gateway OIDC settings: header name, key URL, and optional issuer / audience.
    config: GatewayOidcConfig,
    /// `kid` -> fetched key; entries expire after `KEY_CACHE_TTL`.
    key_cache: Arc<RwLock<HashMap<String, CachedKey>>>,
    /// `kid` -> time of the last failed fetch, used for `FETCH_FAILURE_BACKOFF`.
    fetch_failures: Arc<RwLock<HashMap<String, FailedFetch>>>,
    /// HTTP client used for key fetches.
    http_client: reqwest::Client,
}
/// Reasons an OIDC JWT was rejected. `auth_middleware` logs them; the request then ends in
/// a 401 unless another method authenticated it.
#[derive(Debug, thiserror::Error)]
enum OidcError {
    /// The JWT header has no `kid`, so no verification key can be selected.
    #[error("missing `kid` in JWT header")]
    MissingKid,
    /// The algorithm is missing or not supported (e.g. a JWK without `alg`).
    #[error("unsupported algorithm: {0}")]
    UnsupportedAlgorithm(String),
    /// Fetching or parsing a key failed: HTTP error, size limit, parse error, unknown `kid`,
    /// or an active fetch backoff.
    #[error("key fetch failed: {0}")]
    KeyFetch(String),
    /// The signature was malformed or did not verify.
    #[error("signature verification failed")]
    InvalidSignature,
    /// The header could not be decoded, claim validation by `jsonwebtoken::decode` failed,
    /// or `sub` is missing.
    #[error("claim validation failed: {0}")]
    InvalidClaims(String),
}
/// How long a fetched key is reused before it is fetched again.
const KEY_CACHE_TTL: Duration = Duration::from_secs(3600);
/// Maximum number of cached keys; the oldest entry is evicted to make room.
const KEY_CACHE_MAX_ENTRIES: usize = 64;
impl OidcState {
    /// Builds the OIDC validator from gateway configuration.
    ///
    /// Creates an HTTP client with a 10-second timeout for key fetches, plus empty key and
    /// fetch-failure caches.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if the HTTP client cannot be built.
    pub fn from_config(oidc: &GatewayOidcConfig) -> Result<Self, String> {
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .build()
            .map_err(|e| format!("failed to build OIDC HTTP client: {e}"))?;
        Ok(Self {
            config: oidc.clone(),
            key_cache: Arc::new(RwLock::new(HashMap::new())),
            fetch_failures: Arc::new(RwLock::new(HashMap::new())),
            http_client,
        })
    }

    /// Test helper: caches `key` under `kid` as though it had just been fetched.
    #[cfg(test)]
    pub(crate) async fn seed_key(&self, kid: &str, key: DecodingKey, algorithm: Algorithm) {
        let mut cache = self.key_cache.write().await;
        cache.insert(
            kid.to_string(),
            CachedKey {
                decoding_key: key,
                algorithm,
                fetched_at: Instant::now(),
            },
        );
    }

    /// Name of the request header that carries the JWT.
    fn header_name(&self) -> &str {
        &self.config.header
    }

    /// `true` for RSA algorithms (PKCS#1 v1.5 and PSS), all verified with an RSA public key.
    fn is_rsa_alg(alg: Algorithm) -> bool {
        matches!(
            alg,
            Algorithm::RS256
                | Algorithm::RS384
                | Algorithm::RS512
                | Algorithm::PS256
                | Algorithm::PS384
                | Algorithm::PS512
        )
    }

    /// `true` for ECDSA algorithms.
    fn is_ec_alg(alg: Algorithm) -> bool {
        matches!(alg, Algorithm::ES256 | Algorithm::ES384)
    }

    /// `true` when `alg` is verified with a public key.
    ///
    /// HMAC algorithms return `false` because their "key" is a shared secret. If bytes fetched
    /// from a public key endpoint were ever used as an HMAC secret, anyone able to download
    /// that public key could mint valid tokens (JWT algorithm confusion).
    fn is_public_key_alg(alg: Algorithm) -> bool {
        Self::is_rsa_alg(alg) || Self::is_ec_alg(alg) || alg == Algorithm::EdDSA
    }

    /// `true` when `a` and `b` are verified with the same kind of public key.
    fn same_key_family(a: Algorithm, b: Algorithm) -> bool {
        (Self::is_rsa_alg(a) && Self::is_rsa_alg(b))
            || (Self::is_ec_alg(a) && Self::is_ec_alg(b))
            || (a == Algorithm::EdDSA && b == Algorithm::EdDSA)
    }

    /// `true` when the JWK's key type can verify `alg`.
    ///
    /// Signatures are verified through `jsonwebtoken::crypto::verify`, which does not re-check
    /// that the key type fits the algorithm. A mismatch can therefore be misused, and may even
    /// panic (e.g. an RSA JWK with an ECDSA algorithm in jsonwebtoken 9). Reject it up front.
    fn jwk_supports(jwk: &jsonwebtoken::jwk::Jwk, alg: Algorithm) -> bool {
        use jsonwebtoken::jwk::AlgorithmParameters;
        match &jwk.algorithm {
            AlgorithmParameters::RSA(_) => Self::is_rsa_alg(alg),
            AlgorithmParameters::EllipticCurve(_) => Self::is_ec_alg(alg),
            AlgorithmParameters::OctetKeyPair(_) => alg == Algorithm::EdDSA,
            _ => false,
        }
    }

    /// Fetches a single key from a `{kid}`-templated `url`, for use with `alg`.
    ///
    /// A body starting with `-----BEGIN` is parsed as PEM, using the parser for `alg`'s key
    /// type. Any other body is parsed as one JSON Web Key, whose key type must suit `alg`.
    ///
    /// # Errors
    ///
    /// - `UnsupportedAlgorithm` if `alg` is not a public-key algorithm or does not match the
    ///   JWK's key type.
    /// - `KeyFetch` on HTTP or parse failures.
    async fn fetch_single_key(&self, url: &str, alg: Algorithm) -> Result<DecodingKey, OidcError> {
        // `alg` comes from the untrusted token header. `get_or_fetch_key` already refuses
        // HMAC; repeat the check so this function is safe on its own.
        if !Self::is_public_key_alg(alg) {
            return Err(OidcError::UnsupportedAlgorithm(format!("{alg:?}")));
        }
        let body = self.fetch_url_text(url).await?;
        let trimmed = body.trim();
        if trimmed.starts_with("-----BEGIN") {
            match alg {
                Algorithm::ES256 | Algorithm::ES384 => DecodingKey::from_ec_pem(trimmed.as_bytes())
                    .map_err(|e| OidcError::KeyFetch(format!("EC PEM parse: {e}"))),
                Algorithm::EdDSA => DecodingKey::from_ed_pem(trimmed.as_bytes())
                    .map_err(|e| OidcError::KeyFetch(format!("EdDSA PEM parse: {e}"))),
                // Only RS*/PS* can reach this arm after the public-key check above.
                _ => DecodingKey::from_rsa_pem(trimmed.as_bytes())
                    .map_err(|e| OidcError::KeyFetch(format!("RSA PEM parse: {e}"))),
            }
        } else {
            let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_str(trimmed)
                .map_err(|e| OidcError::KeyFetch(format!("JWK parse: {e}")))?;
            if !Self::jwk_supports(&jwk, alg) {
                return Err(OidcError::UnsupportedAlgorithm(format!(
                    "{alg:?} does not match the fetched JWK's key type"
                )));
            }
            DecodingKey::from_jwk(&jwk).map_err(|e| OidcError::KeyFetch(format!("JWK decode: {e}")))
        }
    }

    /// Fetches the JWKS at `url` and returns the key for `kid` with the algorithm the JWK
    /// declares.
    ///
    /// # Errors
    ///
    /// - `KeyFetch` if the fetch fails, the body is not a JWKS, `kid` is absent, or the key
    ///   cannot be decoded.
    /// - `UnsupportedAlgorithm` if the JWK's `alg` is missing, unsupported, or inconsistent
    ///   with its key type.
    async fn fetch_jwks_key(
        &self,
        url: &str,
        kid: &str,
    ) -> Result<(DecodingKey, Algorithm), OidcError> {
        let body = self.fetch_url_text(url).await?;
        let jwks: jsonwebtoken::jwk::JwkSet = serde_json::from_str(&body)
            .map_err(|e| OidcError::KeyFetch(format!("JWKS parse: {e}")))?;
        let jwk = jwks
            .find(kid)
            .ok_or_else(|| OidcError::KeyFetch(format!("kid '{kid}' not found in JWKS")))?;
        let alg = resolve_algorithm(jwk)?;
        if !Self::jwk_supports(jwk, alg) {
            return Err(OidcError::UnsupportedAlgorithm(format!(
                "JWK '{kid}' declares {alg:?}, which does not match its key type"
            )));
        }
        let key = DecodingKey::from_jwk(jwk)
            .map_err(|e| OidcError::KeyFetch(format!("JWK decode: {e}")))?;
        Ok((key, alg))
    }

    const MAX_JWKS_RESPONSE_BYTES: usize = 256 * 1024;

    /// GETs `url` and returns the body as UTF-8 text, capped at `MAX_JWKS_RESPONSE_BYTES`.
    ///
    /// # Errors
    ///
    /// Returns `KeyFetch` on transport or HTTP-status errors, if the body exceeds the cap, or
    /// if it is not valid UTF-8.
    async fn fetch_url_text(&self, url: &str) -> Result<String, OidcError> {
        let mut response = self
            .http_client
            .get(url)
            .send()
            .await
            .map_err(|e| OidcError::KeyFetch(format!("HTTP request: {e}")))?
            .error_for_status()
            .map_err(|e| OidcError::KeyFetch(format!("HTTP error: {e}")))?;
        // Cheap early rejection when the server announces an oversized body. Compare as u64
        // so a huge length cannot slip past via truncation on 32-bit targets.
        if let Some(len) = response.content_length()
            && len > Self::MAX_JWKS_RESPONSE_BYTES as u64
        {
            return Err(OidcError::KeyFetch(format!(
                "JWKS response too large ({len} bytes, max {})",
                Self::MAX_JWKS_RESPONSE_BYTES
            )));
        }
        // Content-Length may be absent (chunked encoding) or wrong. Enforce the cap while
        // streaming instead of buffering an unbounded body first.
        let mut body: Vec<u8> = Vec::new();
        while let Some(chunk) = response
            .chunk()
            .await
            .map_err(|e| OidcError::KeyFetch(format!("reading body: {e}")))?
        {
            let received = body.len() + chunk.len();
            if received > Self::MAX_JWKS_RESPONSE_BYTES {
                return Err(OidcError::KeyFetch(format!(
                    "JWKS response too large (at least {received} bytes, max {})",
                    Self::MAX_JWKS_RESPONSE_BYTES
                )));
            }
            body.extend_from_slice(&chunk);
        }
        String::from_utf8(body)
            .map_err(|e| OidcError::KeyFetch(format!("response not UTF-8: {e}")))
    }

    /// Returns the verification key and algorithm for `kid`, from cache or by fetching.
    ///
    /// `alg` is the algorithm named in the token header. It only matters in `{kid}`-URL mode,
    /// where the endpoint serves a bare key with no algorithm of its own. In JWKS mode the
    /// JWK's declared algorithm is used.
    ///
    /// # Errors
    ///
    /// - `UnsupportedAlgorithm` for an HMAC (or otherwise non-public-key) header algorithm in
    ///   `{kid}`-URL mode.
    /// - `KeyFetch` during the failure backoff for `kid`.
    /// - Any error from the fetch itself.
    async fn get_or_fetch_key(
        &self,
        kid: &str,
        alg: Algorithm,
    ) -> Result<(DecodingKey, Algorithm), OidcError> {
        let per_kid_url = self.config.jwks_url.contains("{kid}");
        {
            let cache = self.key_cache.read().await;
            if let Some(cached) = cache.get(kid)
                && cached.fetched_at.elapsed() < KEY_CACHE_TTL
            {
                // In `{kid}`-URL mode, the cached algorithm is whatever the first token for this
                // kid asked for. Honour a different header algorithm of the same key family, so
                // one crafted token cannot pin (say) PS256 onto an RS256 key and lock real users
                // out. Any other combination keeps the cached algorithm, as before.
                let verify_alg = if per_kid_url && Self::same_key_family(alg, cached.algorithm) {
                    alg
                } else {
                    cached.algorithm
                };
                return Ok((cached.decoding_key.clone(), verify_alg));
            }
        }
        // The header alg is attacker-controlled. Never let it select HMAC over public key bytes.
        if per_kid_url && !Self::is_public_key_alg(alg) {
            return Err(OidcError::UnsupportedAlgorithm(format!(
                "{alg:?} is not allowed for keys fetched per kid"
            )));
        }
        {
            let failures = self.fetch_failures.read().await;
            if let Some(failed) = failures.get(kid)
                && failed.failed_at.elapsed() < FETCH_FAILURE_BACKOFF
            {
                return Err(OidcError::KeyFetch(
                    "JWKS fetch recently failed, backing off".to_string(),
                ));
            }
        }
        let fetch_result = if per_kid_url {
            // Percent-encode the kid so it cannot alter the URL's path or query.
            let encoded_kid: String =
                url::form_urlencoded::byte_serialize(kid.as_bytes()).collect();
            let url = self.config.jwks_url.replace("{kid}", &encoded_kid);
            self.fetch_single_key(&url, alg).await.map(|key| (key, alg))
        } else {
            self.fetch_jwks_key(&self.config.jwks_url, kid).await
        };
        let (key, resolved_alg) = match fetch_result {
            Ok(result) => {
                self.fetch_failures.write().await.remove(kid);
                result
            }
            Err(e) => {
                let mut failures = self.fetch_failures.write().await;
                // Drop entries whose backoff has lapsed, so attacker-chosen kids cannot grow
                // the map without bound.
                failures.retain(|_, failed| failed.failed_at.elapsed() < FETCH_FAILURE_BACKOFF);
                failures.insert(
                    kid.to_string(),
                    FailedFetch {
                        failed_at: Instant::now(),
                    },
                );
                return Err(e);
            }
        };
        let mut cache = self.key_cache.write().await;
        cache.retain(|_, v| v.fetched_at.elapsed() < KEY_CACHE_TTL);
        if cache.len() >= KEY_CACHE_MAX_ENTRIES {
            if let Some(oldest_kid) = cache
                .iter()
                .min_by_key(|(_, v)| v.fetched_at)
                .map(|(k, _)| k.clone())
            {
                cache.remove(&oldest_kid);
            }
        }
        cache.insert(
            kid.to_string(),
            CachedKey {
                decoding_key: key.clone(),
                algorithm: resolved_alg,
                fetched_at: Instant::now(),
            },
        );
        Ok((key, resolved_alg))
    }
}
/// Maps the `alg` declared by a JWK to the matching signing [`Algorithm`].
///
/// # Errors
///
/// Returns `UnsupportedAlgorithm` when the JWK declares no `alg`, or one without a mapping
/// here.
fn resolve_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Result<Algorithm, OidcError> {
    use jsonwebtoken::jwk::KeyAlgorithm;

    let Some(declared) = jwk.common.key_algorithm else {
        return Err(OidcError::UnsupportedAlgorithm(
            "missing alg in JWK".to_string(),
        ));
    };
    let algorithm = match declared {
        KeyAlgorithm::ES256 => Algorithm::ES256,
        KeyAlgorithm::ES384 => Algorithm::ES384,
        KeyAlgorithm::RS256 => Algorithm::RS256,
        KeyAlgorithm::RS384 => Algorithm::RS384,
        KeyAlgorithm::RS512 => Algorithm::RS512,
        KeyAlgorithm::PS256 => Algorithm::PS256,
        KeyAlgorithm::PS384 => Algorithm::PS384,
        KeyAlgorithm::PS512 => Algorithm::PS512,
        KeyAlgorithm::EdDSA => Algorithm::EdDSA,
        other => return Err(OidcError::UnsupportedAlgorithm(format!("{other:?}"))),
    };
    Ok(algorithm)
}
/// Verifies the signature of `original_jwt` (exactly as received) with `key` and `alg`.
///
/// The signing input is the first two dot-separated segments as transmitted. Trailing `=`
/// padding on the signature segment is tolerated. For ES256/ES384, a DER-encoded signature
/// is converted to raw `r || s` first; if it does not parse as DER it is used unchanged.
///
/// # Errors
///
/// Returns `InvalidSignature` if the token does not have exactly three segments, the
/// signature is not base64url, or verification fails or errors.
fn verify_signature(
    original_jwt: &str,
    key: &DecodingKey,
    alg: Algorithm,
) -> Result<(), OidcError> {
    let mut segments = original_jwt.split('.');
    let (Some(header_b64), Some(payload_b64), Some(signature_b64), None) = (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) else {
        return Err(OidcError::InvalidSignature);
    };
    let signing_input = format!("{header_b64}.{payload_b64}");

    let decoded = URL_SAFE_NO_PAD
        .decode(signature_b64.trim_end_matches('='))
        .map_err(|_| OidcError::InvalidSignature)?;
    let signature = match alg {
        Algorithm::ES256 | Algorithm::ES384 => try_der_to_raw(&decoded, alg).unwrap_or(decoded),
        _ => decoded,
    };

    let reencoded = URL_SAFE_NO_PAD.encode(&signature);
    match jsonwebtoken::crypto::verify(&reencoded, signing_input.as_bytes(), key, alg) {
        Ok(true) => Ok(()),
        Ok(false) | Err(_) => Err(OidcError::InvalidSignature),
    }
}
/// Returns `seg` with any trailing `=` padding removed. Interior characters are untouched.
fn normalize_b64_segment(seg: &str) -> String {
    String::from(seg.trim_end_matches('='))
}
/// Strips trailing `=` padding from each segment of a three-segment JWT.
///
/// Input that does not have exactly three dot-separated segments is returned unchanged.
fn normalize_jwt_for_claims(jwt: &str) -> String {
    let segments: Vec<&str> = jwt.split('.').collect();
    if let [header, payload, signature] = segments.as_slice() {
        format!(
            "{}.{}.{}",
            normalize_b64_segment(header),
            normalize_b64_segment(payload),
            normalize_b64_segment(signature),
        )
    } else {
        jwt.to_string()
    }
}
fn try_der_to_raw(der: &[u8], alg: Algorithm) -> Option<Vec<u8>> {
let component_len = match alg {
Algorithm::ES256 => 32,
Algorithm::ES384 => 48,
_ => return None,
};
if der.len() < 6 || der[0] != 0x30 {
return None;
}
let mut pos = 1;
let _seq_len = parse_der_length(der, &mut pos)?;
if pos >= der.len() || der[pos] != 0x02 {
return None;
}
pos += 1;
let r_len = parse_der_length(der, &mut pos)?;
if r_len > component_len + 1 {
return None;
}
let r_bytes = der.get(pos..pos + r_len)?;
pos += r_len;
if pos >= der.len() || der[pos] != 0x02 {
return None;
}
pos += 1;
let s_len = parse_der_length(der, &mut pos)?;
if s_len > component_len + 1 {
return None;
}
let s_bytes = der.get(pos..pos + s_len)?;
let r = strip_der_leading_zero(r_bytes);
let s = strip_der_leading_zero(s_bytes);
if r.len() > component_len || s.len() > component_len {
return None;
}
let mut raw = vec![0u8; component_len * 2];
raw[component_len - r.len()..component_len].copy_from_slice(r);
raw[component_len * 2 - s.len()..].copy_from_slice(s);
Some(raw)
}
/// Reads a DER length at `*pos` and advances `*pos` past it.
///
/// Accepts the short form (one byte < 0x80) and the long form with one or two length bytes.
/// Returns `None` for a zero-byte or longer-than-two-byte long form, or on truncated input.
/// `*pos` may have advanced partway in that case.
fn parse_der_length(der: &[u8], pos: &mut usize) -> Option<usize> {
    let first = *der.get(*pos)?;
    *pos += 1;
    if first & 0x80 == 0 {
        return Some(usize::from(first));
    }
    let count = usize::from(first & 0x7F);
    if !(1..=2).contains(&count) {
        return None;
    }
    let mut value: usize = 0;
    for _ in 0..count {
        value = value.checked_mul(256)?;
        let byte = *der.get(*pos)?;
        value = value.checked_add(usize::from(byte))?;
        *pos += 1;
    }
    Some(value)
}
/// Drops one leading `0x00` byte, as DER adds to keep a high-bit-set INTEGER positive.
///
/// A lone `0x00` is returned as-is, since it encodes the value zero.
fn strip_der_leading_zero(bytes: &[u8]) -> &[u8] {
    match bytes {
        [0x00, rest @ ..] if !rest.is_empty() => rest,
        _ => bytes,
    }
}
/// Validates an OIDC JWT and returns its `sub` claim.
///
/// Steps:
/// 1. Look up the key for the header's `kid`.
/// 2. Verify the signature over the token exactly as received.
/// 3. Run `jsonwebtoken`'s claim checks on the padding-stripped token (signature checking
///    is disabled there because it was already done). Issuer and audience are enforced only
///    when configured.
///
/// # Errors
///
/// Returns `InvalidClaims` for a malformed header, failed claim validation, or a missing
/// `sub`. Returns `MissingKid` when the header has no `kid`. Key lookup and signature errors
/// are propagated.
async fn validate_oidc_jwt(oidc: &OidcState, jwt: &str) -> Result<String, OidcError> {
    let normalized = normalize_jwt_for_claims(jwt);
    let header = match jsonwebtoken::decode_header(&normalized) {
        Ok(header) => header,
        Err(e) => return Err(OidcError::InvalidClaims(format!("malformed header: {e}"))),
    };
    let Some(kid) = header.kid else {
        return Err(OidcError::MissingKid);
    };
    let (key, alg) = oidc.get_or_fetch_key(&kid, header.alg).await?;
    verify_signature(jwt, &key, alg)?;

    let mut validation = Validation::new(alg);
    validation.insecure_disable_signature_validation();
    if let Some(iss) = oidc.config.issuer.as_ref() {
        validation.set_issuer(&[iss]);
    }
    match oidc.config.audience.as_ref() {
        Some(aud) => validation.set_audience(&[aud]),
        None => validation.validate_aud = false,
    }

    let claims = jsonwebtoken::decode::<serde_json::Value>(&normalized, &key, &validation)
        .map_err(|e| OidcError::InvalidClaims(format!("{e}")))?
        .claims;
    match claims.get("sub").and_then(serde_json::Value::as_str) {
        Some(sub) => Ok(sub.to_owned()),
        None => Err(OidcError::InvalidClaims("missing `sub` claim".to_string())),
    }
}
/// Whether a `?token=` query parameter may be used to authenticate this request.
///
/// Only GET requests to the listed event/WebSocket endpoints qualify. The likely reason is
/// that browser EventSource/WebSocket clients cannot set an `Authorization` header; confirm
/// with the frontend.
fn allows_query_token_auth(request: &Request) -> bool {
    const QUERY_TOKEN_PATHS: [&str; 3] = ["/api/chat/events", "/api/logs/events", "/api/chat/ws"];
    request.method() == Method::GET && QUERY_TOKEN_PATHS.contains(&request.uri().path())
}
/// Returns the URL-decoded value of the first `token` query parameter, if any.
fn query_token(request: &Request) -> Option<String> {
    let raw_query = request.uri().query()?;
    url::form_urlencoded::parse(raw_query.as_bytes())
        .find(|(key, _)| key == "token")
        .map(|(_, value)| value.into_owned())
}
/// Extracts a bearer token from the request.
///
/// The `Authorization: Bearer <token>` header is checked first. The scheme is matched
/// case-insensitively, and the token is taken verbatim (not trimmed). If there is no usable
/// header, the `token` query parameter is used, but only on endpoints permitted by
/// `allows_query_token_auth`.
fn extract_token(headers: &HeaderMap, request: &Request) -> Option<String> {
    const BEARER: &str = "Bearer ";
    let from_header = headers
        .get("authorization")
        .and_then(|value| value.to_str().ok())
        .filter(|&value| {
            value.len() > BEARER.len() && value[..BEARER.len()].eq_ignore_ascii_case(BEARER)
        })
        .map(|value| value[BEARER.len()..].to_string());
    match from_header {
        Some(token) => Some(token),
        None if allows_query_token_auth(request) => query_token(request),
        None => None,
    }
}
pub async fn auth_middleware(
State(auth): State<CombinedAuthState>,
headers: HeaderMap,
mut request: Request,
next: Next,
) -> Response {
let token = extract_token(&headers, &request);
if let Some(ref tok) = token {
if let Some(identity) = auth.env_auth.authenticate(tok) {
request.extensions_mut().insert(identity.clone());
return next.run(request).await;
}
if let Some(ref db_auth) = auth.db_auth {
match db_auth.authenticate(tok).await {
Ok(Some(identity)) => {
request.extensions_mut().insert(identity);
return next.run(request).await;
}
Err(()) => {
return (StatusCode::SERVICE_UNAVAILABLE, "Database unavailable")
.into_response();
}
Ok(None) => {}
}
}
}
if let Some(ref oidc) = auth.oidc
&& let Some(jwt_header) = headers.get(oidc.header_name())
&& let Ok(jwt) = jwt_header.to_str()
{
match validate_oidc_jwt(oidc, jwt).await {
Ok(sub) => {
tracing::debug!(sub = %sub, "OIDC auth succeeded");
let identity = UserIdentity {
user_id: sub,
role: "member".to_string(),
workspace_read_scopes: Vec::new(),
};
request.extensions_mut().insert(identity);
return next.run(request).await;
}
Err(e) => {
tracing::warn!(error = %e, "OIDC auth failed");
}
}
}
(StatusCode::UNAUTHORIZED, "Invalid or missing auth token").into_response()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::testing::credentials::TEST_AUTH_SECRET_TOKEN;
#[test]
fn test_multi_auth_state_single() {
let state = MultiAuthState::single("tok-123".to_string(), "alice".to_string());
let identity = state.authenticate("tok-123");
assert!(identity.is_some());
assert_eq!(identity.unwrap().user_id, "alice");
}
#[test]
fn test_multi_auth_state_reject_wrong_token() {
let state = MultiAuthState::single("tok-123".to_string(), "alice".to_string());
assert!(state.authenticate("wrong-token").is_none());
}
#[test]
fn test_multi_auth_state_multi_users() {
let mut tokens = HashMap::new();
tokens.insert(
"tok-alice".to_string(),
UserIdentity {
user_id: "alice".to_string(),
role: "admin".to_string(),
workspace_read_scopes: Vec::new(),
},
);
tokens.insert(
"tok-bob".to_string(),
UserIdentity {
user_id: "bob".to_string(),
role: "admin".to_string(),
workspace_read_scopes: Vec::new(),
},
);
let state = MultiAuthState::multi(tokens);
let alice = state.authenticate("tok-alice").unwrap();
assert_eq!(alice.user_id, "alice");
let bob = state.authenticate("tok-bob").unwrap();
assert_eq!(bob.user_id, "bob");
assert!(state.authenticate("tok-charlie").is_none());
}
#[test]
fn test_multi_auth_state_first_token() {
let state = MultiAuthState::single("my-token".to_string(), "user1".to_string());
assert_eq!(state.first_token(), Some("my-token"));
}
#[test]
fn test_multi_auth_state_first_identity() {
let state = MultiAuthState::single("my-token".to_string(), "user1".to_string());
let identity = state.first_identity().unwrap();
assert_eq!(identity.user_id, "user1");
}
use axum::Router;
use axum::body::Body;
use axum::middleware;
use axum::routing::{get, post};
use tower::ServiceExt;
async fn dummy_handler() -> &'static str {
"ok"
}
fn test_app(token: &str) -> Router {
let state = CombinedAuthState::from(MultiAuthState::single(
token.to_string(),
"test-user".to_string(),
));
Router::new()
.route("/api/chat/events", get(dummy_handler))
.route("/api/logs/events", get(dummy_handler))
.route("/api/chat/ws", get(dummy_handler))
.route("/api/chat/history", get(dummy_handler))
.route("/api/chat/send", post(dummy_handler))
.layer(middleware::from_fn_with_state(state, auth_middleware))
}
#[tokio::test]
async fn test_valid_bearer_token_passes() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri("/api/chat/events")
.header("Authorization", format!("Bearer {TEST_AUTH_SECRET_TOKEN}"))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_invalid_bearer_token_rejected() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri("/api/chat/events")
.header("Authorization", "Bearer wrong-token")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_query_token_allowed_for_chat_events() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri(format!("/api/chat/events?token={TEST_AUTH_SECRET_TOKEN}"))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_query_token_allowed_for_logs_events() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri(format!("/api/logs/events?token={TEST_AUTH_SECRET_TOKEN}"))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_query_token_allowed_for_ws_upgrade() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri(format!("/api/chat/ws?token={TEST_AUTH_SECRET_TOKEN}"))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_query_token_url_encoded() {
let raw_token = "tok+en/with spaces";
let app = test_app(raw_token);
let req = Request::builder()
.uri("/api/chat/events?token=tok%2Ben%2Fwith%20spaces")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_query_token_url_encoded_mismatch() {
let app = test_app("real-token");
let req = Request::builder()
.uri("/api/chat/events?token=wrong%2Dtoken")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_query_token_rejected_for_non_sse_get() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri(format!("/api/chat/history?token={TEST_AUTH_SECRET_TOKEN}"))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_query_token_rejected_for_post() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.method(Method::POST)
.uri(format!("/api/chat/send?token={TEST_AUTH_SECRET_TOKEN}"))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_query_token_invalid_rejected() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri("/api/chat/events?token=wrong-token")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_no_auth_at_all_rejected() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri("/api/chat/events")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_bearer_header_works_for_post() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.method(Method::POST)
.uri("/api/chat/send")
.header("Authorization", format!("Bearer {TEST_AUTH_SECRET_TOKEN}"))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_bearer_prefix_case_insensitive() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri("/api/chat/events")
.header("Authorization", format!("bearer {TEST_AUTH_SECRET_TOKEN}"))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_bearer_prefix_mixed_case() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri("/api/chat/events")
.header("Authorization", format!("BEARER {TEST_AUTH_SECRET_TOKEN}"))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_empty_bearer_token_rejected() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri("/api/chat/events")
.header("Authorization", "Bearer ")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_token_with_whitespace_rejected() {
let app = test_app(TEST_AUTH_SECRET_TOKEN);
let req = Request::builder()
.uri("/api/chat/events")
.header("Authorization", format!("Bearer {TEST_AUTH_SECRET_TOKEN}"))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[test]
fn test_normalize_jwt_noop_for_rfc_compliant() {
let jwt = "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.sig";
assert_eq!(normalize_jwt_for_claims(jwt), jwt);
}
#[test]
fn test_normalize_jwt_strips_padding() {
let jwt = "eyJhbGciOiJIUzI1NiJ9==.eyJzdWIiOiJ0ZXN0In0=.c2ln";
let normalized = normalize_jwt_for_claims(jwt);
assert!(!normalized.contains('='));
assert!(normalized.starts_with("eyJhbGciOiJIUzI1NiJ9."));
}
#[test]
fn test_normalize_b64_segment_no_padding() {
assert_eq!(normalize_b64_segment("abc"), "abc");
}
#[test]
fn test_normalize_b64_segment_with_padding() {
assert_eq!(normalize_b64_segment("abc=="), "abc");
}
#[test]
fn test_try_der_to_raw_non_der_passthrough() {
let raw = vec![0x01; 64];
assert!(try_der_to_raw(&raw, Algorithm::ES256).is_none());
}
#[test]
fn test_try_der_to_raw_valid_der() {
let r = vec![0x01; 32];
let s = vec![0x02; 32];
let mut der = vec![0x30, 68]; der.push(0x02);
der.push(32);
der.extend_from_slice(&r);
der.push(0x02);
der.push(32);
der.extend_from_slice(&s);
let raw = try_der_to_raw(&der, Algorithm::ES256).expect("should parse DER");
assert_eq!(raw.len(), 64);
assert_eq!(&raw[..32], &r[..]);
assert_eq!(&raw[32..], &s[..]);
}
#[test]
fn test_try_der_to_raw_with_leading_zero() {
let r = {
let mut v = vec![0x00]; v.extend_from_slice(&[0x80; 32]); v
};
let s = vec![0x01; 32];
let mut der = vec![0x30, 69]; der.push(0x02);
der.push(33); der.extend_from_slice(&r);
der.push(0x02);
der.push(32);
der.extend_from_slice(&s);
let raw = try_der_to_raw(&der, Algorithm::ES256).expect("should parse DER");
assert_eq!(raw.len(), 64);
assert_eq!(raw[0], 0x80);
}
#[test]
fn test_strip_der_leading_zero() {
assert_eq!(strip_der_leading_zero(&[0x00, 0x80, 0x01]), &[0x80, 0x01]);
assert_eq!(strip_der_leading_zero(&[0x80, 0x01]), &[0x80, 0x01]);
assert_eq!(strip_der_leading_zero(&[0x00]), &[0x00]); }
#[test]
fn test_parse_der_length_short_form() {
let data = [0x20]; let mut pos = 0;
assert_eq!(parse_der_length(&data, &mut pos), Some(32));
assert_eq!(pos, 1);
}
#[test]
fn test_parse_der_length_long_form_one_byte() {
let data = [0x81, 0x80];
let mut pos = 0;
assert_eq!(parse_der_length(&data, &mut pos), Some(128));
assert_eq!(pos, 2);
}
#[test]
fn test_parse_der_length_long_form_two_bytes() {
let data = [0x82, 0x01, 0x00];
let mut pos = 0;
assert_eq!(parse_der_length(&data, &mut pos), Some(256));
assert_eq!(pos, 3);
}
#[test]
fn test_try_der_to_raw_long_form_sequence_length() {
let r = {
let mut v = vec![0x00]; v.extend_from_slice(&[0xFF; 48]); v
};
let s = {
let mut v = vec![0x00]; v.extend_from_slice(&[0xAA; 48]);
v
};
let content_len = 2 + r.len() + 2 + s.len(); assert!(content_len < 128);
let mut der = vec![0x30, 0x81, content_len as u8];
der.push(0x02);
der.push(r.len() as u8);
der.extend_from_slice(&r);
der.push(0x02);
der.push(s.len() as u8);
der.extend_from_slice(&s);
let raw = try_der_to_raw(&der, Algorithm::ES384)
.expect("should parse DER with long-form sequence length");
assert_eq!(raw.len(), 96); assert_eq!(raw[0], 0xFF);
assert_eq!(raw[48], 0xAA);
}
#[test]
fn test_kid_url_encoded_in_jwks_url() {
let encoded: String = url::form_urlencoded::byte_serialize(b"../../evil?x=1").collect();
let url = "https://example.com/keys/{kid}".replace("{kid}", &encoded);
assert!(!url.contains("../"));
assert!(url.contains("%2F"));
}
#[test]
fn test_verify_signature_rejects_tampered_payload() {
use jsonwebtoken::{EncodingKey, Header};
let secret = b"test-secret-at-least-256-bits!!!";
let header = Header::new(Algorithm::HS256);
let claims = serde_json::json!({"sub": "alice", "exp": 9999999999u64});
let token =
jsonwebtoken::encode(&header, &claims, &EncodingKey::from_secret(secret)).unwrap();
let key = DecodingKey::from_secret(secret);
assert!(verify_signature(&token, &key, Algorithm::HS256).is_ok());
let parts: Vec<&str> = token.split('.').collect();
let tampered = format!("{}.{}.{}", parts[0], "dGFtcGVyZWQ", parts[2]);
assert!(verify_signature(&tampered, &key, Algorithm::HS256).is_err());
}
#[tokio::test]
async fn test_validate_oidc_jwt_rejects_missing_sub() {
use jsonwebtoken::{EncodingKey, Header};
let secret = b"test-secret-at-least-256-bits!!!";
let mut header = Header::new(Algorithm::HS256);
header.kid = Some("test-kid".to_string());
let claims = serde_json::json!({"exp": 9999999999u64, "name": "alice"});
let token =
jsonwebtoken::encode(&header, &claims, &EncodingKey::from_secret(secret)).unwrap();
let mut validation = Validation::new(Algorithm::HS256);
validation.insecure_disable_signature_validation();
validation.validate_aud = false;
let data = jsonwebtoken::decode::<serde_json::Value>(
&token,
&DecodingKey::from_secret(secret),
&validation,
)
.unwrap();
let result = data
.claims
.get("sub")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.ok_or_else(|| OidcError::InvalidClaims("missing `sub` claim".to_string()));
assert!(result.is_err());
assert!(
result.unwrap_err().to_string().contains("sub"),
"error should mention missing sub claim"
);
}
#[test]
fn test_issuer_validation_disabled_when_not_configured() {
let mut validation = Validation::new(Algorithm::HS256);
validation.insecure_disable_signature_validation();
validation.validate_aud = false;
use jsonwebtoken::{EncodingKey, Header};
let secret = b"test-secret-at-least-256-bits!!!";
let claims =
serde_json::json!({"sub": "alice", "exp": 9999999999u64, "iss": "https://example.com"});
let token = jsonwebtoken::encode(
&Header::new(Algorithm::HS256),
&claims,
&EncodingKey::from_secret(secret),
)
.unwrap();
let result = jsonwebtoken::decode::<serde_json::Value>(
&token,
&DecodingKey::from_secret(secret),
&validation,
);
assert!(
result.is_ok(),
"token with any issuer should pass when issuer validation is disabled"
);
}
async fn identity_handler(AuthenticatedUser(identity): AuthenticatedUser) -> String {
identity.user_id
}
async fn scopes_handler(AuthenticatedUser(identity): AuthenticatedUser) -> String {
serde_json::to_string(&identity.workspace_read_scopes).unwrap()
}
fn multi_user_app(tokens: HashMap<String, UserIdentity>) -> Router {
let state = CombinedAuthState::from(MultiAuthState::multi(tokens));
Router::new()
.route("/api/chat/events", get(identity_handler))
.route("/api/chat/send", post(identity_handler))
.route("/api/scopes", get(scopes_handler))
.layer(middleware::from_fn_with_state(state, auth_middleware))
}
fn two_user_tokens() -> HashMap<String, UserIdentity> {
let mut tokens = HashMap::new();
tokens.insert(
"tok-alice".to_string(),
UserIdentity {
user_id: "alice".to_string(),
role: "admin".to_string(),
workspace_read_scopes: vec!["shared".to_string()],
},
);
tokens.insert(
"tok-bob".to_string(),
UserIdentity {
user_id: "bob".to_string(),
role: "admin".to_string(),
workspace_read_scopes: vec!["shared".to_string(), "alice".to_string()],
},
);
tokens
}
#[tokio::test]
async fn test_multi_user_alice_token_resolves_to_alice() {
let app = multi_user_app(two_user_tokens());
let req = Request::builder()
.uri("/api/chat/events")
.header("Authorization", "Bearer tok-alice")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
assert_eq!(body, "alice");
}
/// Bob's static token authenticates and is echoed back as his `user_id`.
#[tokio::test]
async fn test_multi_user_bob_token_resolves_to_bob() {
    let app = multi_user_app(two_user_tokens());
    let response = app
        .oneshot(
            Request::builder()
                .uri("/api/chat/events")
                .header("Authorization", "Bearer tok-bob")
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let payload = axum::body::to_bytes(response.into_body(), 1024).await.unwrap();
    assert_eq!(payload, "bob");
}
#[tokio::test]
async fn test_multi_user_sequential_tokens_resolve_independently() {
let tokens = two_user_tokens();
let app1 = multi_user_app(tokens.clone());
let req = Request::builder()
.uri("/api/chat/events")
.header("Authorization", "Bearer tok-alice")
.body(Body::empty())
.unwrap();
let resp = app1.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
assert_eq!(body, "alice");
let app2 = multi_user_app(tokens);
let req = Request::builder()
.uri("/api/chat/events")
.header("Authorization", "Bearer tok-bob")
.body(Body::empty())
.unwrap();
let resp = app2.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
assert_eq!(body, "bob");
}
/// A bearer token that is not in the table is rejected with 401.
#[tokio::test]
async fn test_multi_user_unknown_token_rejected() {
    let unknown = Request::builder()
        .uri("/api/chat/events")
        .header("Authorization", "Bearer tok-charlie")
        .body(Body::empty())
        .unwrap();
    let status = multi_user_app(two_user_tokens())
        .oneshot(unknown)
        .await
        .unwrap()
        .status();
    assert_eq!(status, StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_multi_user_workspace_read_scopes_propagated() {
let app = multi_user_app(two_user_tokens());
let req = Request::builder()
.uri("/api/scopes")
.header("Authorization", "Bearer tok-alice")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
let scopes: Vec<String> = serde_json::from_slice(&body).unwrap();
assert_eq!(scopes, vec!["shared"]);
}
#[tokio::test]
async fn test_multi_user_bob_has_two_scopes() {
let app = multi_user_app(two_user_tokens());
let req = Request::builder()
.uri("/api/scopes")
.header("Authorization", "Bearer tok-bob")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
let scopes: Vec<String> = serde_json::from_slice(&body).unwrap();
assert_eq!(scopes, vec!["shared", "alice"]);
}
/// A token supplied via the `?token=` query parameter resolves to the matching user.
#[tokio::test]
async fn test_multi_user_query_param_resolves_correct_identity() {
    let request = Request::builder()
        .uri("/api/chat/events?token=tok-bob")
        .body(Body::empty())
        .unwrap();
    let app = multi_user_app(two_user_tokens());
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let echoed = axum::body::to_bytes(response.into_body(), 1024).await.unwrap();
    assert_eq!(echoed, "bob");
}
/// Bearer auth also resolves the caller on a POST route.
#[tokio::test]
async fn test_multi_user_post_with_bearer_resolves_identity() {
    let request = Request::builder()
        .method(Method::POST)
        .uri("/api/chat/send")
        .header("Authorization", "Bearer tok-alice")
        .body(Body::empty())
        .unwrap();
    let response = multi_user_app(two_user_tokens())
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let echoed = axum::body::to_bytes(response.into_body(), 1024)
        .await
        .unwrap();
    assert_eq!(echoed, "alice");
}
#[tokio::test]
async fn test_multi_user_empty_scopes_for_single_user() {
let state = CombinedAuthState::from(MultiAuthState::single(
"tok-only".to_string(),
"solo".to_string(),
));
let app = Router::new()
.route("/api/scopes", get(scopes_handler))
.layer(middleware::from_fn_with_state(state, auth_middleware));
let req = Request::builder()
.uri("/api/scopes")
.header("Authorization", "Bearer tok-only")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
let scopes: Vec<String> = serde_json::from_slice(&body).unwrap();
assert!(scopes.is_empty());
}
/// Token matching is exact. A strict prefix, an extension, or an empty candidate
/// must not authenticate, while the exact token still does. The positive check keeps
/// the negative ones from passing vacuously against an `authenticate` that rejects
/// everything.
///
/// Plain `#[test]`: `MultiAuthState::authenticate` is synchronous, so no runtime is needed.
#[test]
fn test_prefix_and_extension_tokens_rejected() {
    let state = MultiAuthState::single("long-secret-token".to_string(), "user".to_string());
    assert_eq!(
        state
            .authenticate("long-secret-token")
            .map(|identity| identity.user_id.as_str()),
        Some("user")
    );
    assert!(state.authenticate("long-secret").is_none());
    assert!(state.authenticate("long-secret-token-extra").is_none());
    assert!(state.authenticate("").is_none());
}
/// HMAC secret used to sign and verify test JWTs: 32 bytes, i.e. 256 bits.
const OIDC_SECRET: &[u8] = "test-secret-at-least-256-bits!!!".as_bytes();
/// `kid` under which the test decoding key is seeded.
const OIDC_KID: &str = "test-kid";
/// Request header the test OIDC config reads the JWT from.
const OIDC_HEADER_NAME: &str = "x-oidc-data";
/// Signs `claims` as an HS256 JWT with `OIDC_SECRET`.
///
/// When `kid` is `Some`, it is written into the JWT header; `None` produces a header
/// without a key id.
fn encode_test_jwt(claims: serde_json::Value, kid: Option<&str>) -> String {
    use jsonwebtoken::{EncodingKey, Header};
    let signing_key = EncodingKey::from_secret(OIDC_SECRET);
    let mut header = Header::new(Algorithm::HS256);
    header.kid = kid.map(str::to_owned);
    jsonwebtoken::encode(&header, &claims, &signing_key).unwrap()
}
/// Baseline OIDC config for tests: JWT read from `OIDC_HEADER_NAME`, with no issuer
/// or audience configured.
///
/// The JWKS URL is a placeholder. Tests seed keys directly, and the fetch-path tests
/// rely on fetches from it failing.
fn test_oidc_config() -> crate::config::GatewayOidcConfig {
    let header = OIDC_HEADER_NAME.to_string();
    let jwks_url = String::from("https://unused.example.com/keys");
    crate::config::GatewayOidcConfig {
        header,
        jwks_url,
        issuer: None,
        audience: None,
    }
}
/// `OidcState` built from `test_oidc_config()` with the HS256 test key seeded under `OIDC_KID`.
async fn test_oidc_state() -> OidcState {
    let config = test_oidc_config();
    test_oidc_state_with_config(config).await
}
/// Builds an `OidcState` from `config` and seeds the HS256 test key (`OIDC_SECRET`)
/// under `OIDC_KID`.
///
/// The seeded key is what lets tokens from `encode_test_jwt(.., Some(OIDC_KID))`
/// validate without a real JWKS endpoint.
async fn test_oidc_state_with_config(config: crate::config::GatewayOidcConfig) -> OidcState {
    let oidc = OidcState::from_config(&config).unwrap();
    let test_key = DecodingKey::from_secret(OIDC_SECRET);
    oidc.seed_key(OIDC_KID, test_key, Algorithm::HS256).await;
    oidc
}
/// Auth state with two methods enabled:
/// - a static bearer token (`bearer-token-123` → `bearer-user`), and
/// - OIDC header auth using the seeded test key.
///
/// Database-backed auth is disabled.
async fn oidc_auth_state() -> CombinedAuthState {
    let env_auth = MultiAuthState::single(
        "bearer-token-123".to_string(),
        "bearer-user".to_string(),
    );
    let oidc = test_oidc_state().await;
    CombinedAuthState {
        env_auth,
        db_auth: None,
        oidc: Some(oidc),
    }
}
/// Router with the identity-echo handler on `GET /api/chat/events` and
/// `POST /api/chat/send`, guarded by `auth_middleware` using `state`.
fn oidc_test_app(state: CombinedAuthState) -> Router {
    let auth = middleware::from_fn_with_state(state, auth_middleware);
    let routes = Router::new()
        .route("/api/chat/events", get(identity_handler))
        .route("/api/chat/send", post(identity_handler));
    routes.layer(auth)
}
/// HS256 JWT for subject `sub`, tagged with `OIDC_KID` and expiring far in the future
/// (`exp` = 9999999999, in the year 2286).
fn valid_oidc_jwt(sub: &str) -> String {
    let claims = serde_json::json!({ "sub": sub, "exp": 9999999999u64 });
    encode_test_jwt(claims, Some(OIDC_KID))
}
/// A valid OIDC header alone authenticates, and the JWT `sub` reaches the
/// `AuthenticatedUser` extractor.
#[tokio::test]
async fn test_oidc_auth_inserts_user_identity_for_handler() {
    let jwt = valid_oidc_jwt("oidc-alice");
    let request = Request::builder()
        .uri("/api/chat/events")
        .header(OIDC_HEADER_NAME, jwt)
        .body(Body::empty())
        .unwrap();
    let app = oidc_test_app(oidc_auth_state().await);
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(
        response.status(),
        StatusCode::OK,
        "OIDC auth must insert UserIdentity so AuthenticatedUser extractor succeeds"
    );
    let echoed = axum::body::to_bytes(response.into_body(), 1024).await.unwrap();
    assert_eq!(echoed, "oidc-alice");
}
/// Users authenticated via OIDC are assigned the `member` role.
#[tokio::test]
async fn test_oidc_auth_user_gets_member_role() {
    async fn role_handler(user: AuthenticatedUser) -> String {
        user.0.role
    }
    let auth = middleware::from_fn_with_state(oidc_auth_state().await, auth_middleware);
    let app = Router::new()
        .route("/api/chat/events", get(role_handler))
        .layer(auth);
    let request = Request::builder()
        .uri("/api/chat/events")
        .header(OIDC_HEADER_NAME, valid_oidc_jwt("oidc-bob"))
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let role = axum::body::to_bytes(response.into_body(), 1024).await.unwrap();
    assert_eq!(role, "member");
}
/// With OIDC configured, a plain bearer token still authenticates when the OIDC header is absent.
#[tokio::test]
async fn test_bearer_works_when_oidc_configured_but_header_absent() {
    let state = oidc_auth_state().await;
    let response = oidc_test_app(state)
        .oneshot(
            Request::builder()
                .uri("/api/chat/events")
                .header("Authorization", "Bearer bearer-token-123")
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let echoed = axum::body::to_bytes(response.into_body(), 1024).await.unwrap();
    assert_eq!(echoed, "bearer-user");
}
/// When both a bearer token and a valid OIDC header are sent, the bearer identity is used.
#[tokio::test]
async fn test_bearer_takes_priority_over_oidc_when_both_present() {
    let oidc_jwt = valid_oidc_jwt("oidc-alice");
    let request = Request::builder()
        .uri("/api/chat/events")
        .header("Authorization", "Bearer bearer-token-123")
        .header(OIDC_HEADER_NAME, oidc_jwt)
        .body(Body::empty())
        .unwrap();
    let response = oidc_test_app(oidc_auth_state().await)
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let echoed = axum::body::to_bytes(response.into_body(), 1024).await.unwrap();
    assert_eq!(
        echoed, "bearer-user",
        "bearer should win when both auth methods are present"
    );
}
/// A JWT with the right `kid` but signed with a different secret is rejected with
/// 401, not surfaced as a server error.
#[tokio::test]
async fn test_oidc_bad_signature_returns_401_not_500() {
    let forged_jwt = {
        let mut header = jsonwebtoken::Header::new(Algorithm::HS256);
        header.kid = Some(OIDC_KID.to_string());
        let claims = serde_json::json!({"sub": "attacker", "exp": 9999999999u64});
        let wrong_key =
            jsonwebtoken::EncodingKey::from_secret(b"wrong-secret-at-least-256-bits!!");
        jsonwebtoken::encode(&header, &claims, &wrong_key).unwrap()
    };
    let app = oidc_test_app(oidc_auth_state().await);
    let request = Request::builder()
        .uri("/api/chat/events")
        .header(OIDC_HEADER_NAME, forged_jwt)
        .body(Body::empty())
        .unwrap();
    let status = app.oneshot(request).await.unwrap().status();
    assert_eq!(
        status,
        StatusCode::UNAUTHORIZED,
        "bad OIDC sig should yield 401, not 500"
    );
}
/// A garbage OIDC header does not prevent a valid bearer token from authenticating.
#[tokio::test]
async fn test_invalid_oidc_does_not_block_bearer() {
    let request = Request::builder()
        .uri("/api/chat/events")
        .header("Authorization", "Bearer bearer-token-123")
        .header(OIDC_HEADER_NAME, "not.a.jwt")
        .body(Body::empty())
        .unwrap();
    let state = oidc_auth_state().await;
    let response = oidc_test_app(state).oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let echoed = axum::body::to_bytes(response.into_body(), 1024).await.unwrap();
    assert_eq!(echoed, "bearer-user");
}
/// With OIDC configured, a request carrying no credentials at all is rejected with 401.
#[tokio::test]
async fn test_no_auth_with_oidc_configured() {
    let anonymous = Request::builder()
        .uri("/api/chat/events")
        .body(Body::empty())
        .unwrap();
    let status = oidc_test_app(oidc_auth_state().await)
        .oneshot(anonymous)
        .await
        .unwrap()
        .status();
    assert_eq!(status, StatusCode::UNAUTHORIZED);
}
/// A correctly signed JWT whose `exp` is in the past (1000000000 = September 2001)
/// is rejected with 401.
#[tokio::test]
async fn test_oidc_expired_jwt_rejected() {
    let app = oidc_test_app(oidc_auth_state().await);
    let expired_claims = serde_json::json!({"sub": "alice", "exp": 1000000000u64});
    let jwt = encode_test_jwt(expired_claims, Some(OIDC_KID));
    let request = Request::builder()
        .uri("/api/chat/events")
        .header(OIDC_HEADER_NAME, jwt)
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
/// A JWT whose header carries no `kid` is rejected with 401.
#[tokio::test]
async fn test_oidc_jwt_without_kid_rejected() {
    let app = oidc_test_app(oidc_auth_state().await);
    let claims = serde_json::json!({"sub": "alice", "exp": 9999999999u64});
    let kidless_jwt = encode_test_jwt(claims, None);
    let request = Request::builder()
        .uri("/api/chat/events")
        .header(OIDC_HEADER_NAME, kidless_jwt)
        .body(Body::empty())
        .unwrap();
    let status = app.oneshot(request).await.unwrap().status();
    assert_eq!(status, StatusCode::UNAUTHORIZED);
}
/// Structurally broken tokens in the OIDC header are all rejected with 401.
#[tokio::test]
async fn test_oidc_malformed_jwt_rejected() {
    const CASES: [&str; 5] = ["", "abc", "a.b", "a.b.c.d", "not-base64.not-base64.sig"];
    let app = oidc_test_app(oidc_auth_state().await);
    for malformed in CASES {
        let status = app
            .clone()
            .oneshot(
                Request::builder()
                    .uri("/api/chat/events")
                    .header(OIDC_HEADER_NAME, malformed)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap()
            .status();
        assert_eq!(
            status,
            StatusCode::UNAUTHORIZED,
            "malformed JWT '{malformed}' should be rejected"
        );
    }
}
/// `validate_oidc_jwt` rejects a token whose `sub` claim is a number, not a string.
#[tokio::test]
async fn test_oidc_jwt_sub_not_string_rejected() {
    let claims = serde_json::json!({"sub": 12345, "exp": 9999999999u64});
    let jwt = encode_test_jwt(claims, Some(OIDC_KID));
    let oidc = test_oidc_state().await;
    let result = validate_oidc_jwt(&oidc, &jwt).await;
    assert!(
        result.is_err(),
        "non-string sub should be rejected: {result:?}"
    );
}
/// Pins current behavior: a signed JWT with `"sub": ""` authenticates and the
/// handler sees an empty `user_id`.
///
/// NOTE(review): an empty subject is a questionable identity. Distinct tokens with
/// empty `sub` would all map to the same user. Confirm this is intended rather than
/// something `validate_oidc_jwt` should reject.
#[tokio::test]
async fn test_oidc_jwt_empty_sub_passes_auth() {
    let claims = serde_json::json!({"sub": "", "exp": 9999999999u64});
    let request = Request::builder()
        .uri("/api/chat/events")
        .header(OIDC_HEADER_NAME, encode_test_jwt(claims, Some(OIDC_KID)))
        .body(Body::empty())
        .unwrap();
    let response = oidc_test_app(oidc_auth_state().await)
        .oneshot(request)
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let echoed = axum::body::to_bytes(response.into_body(), 1024).await.unwrap();
    assert_eq!(echoed, "");
}
/// Through the full middleware stack, a signed JWT with no `sub` claim is rejected with 401.
#[tokio::test]
async fn test_oidc_jwt_missing_sub_rejected_through_middleware() {
    let claims = serde_json::json!({"name": "alice", "exp": 9999999999u64});
    let subless_jwt = encode_test_jwt(claims, Some(OIDC_KID));
    let app = oidc_test_app(oidc_auth_state().await);
    let request = Request::builder()
        .uri("/api/chat/events")
        .header(OIDC_HEADER_NAME, subless_jwt)
        .body(Body::empty())
        .unwrap();
    let status = app.oneshot(request).await.unwrap().status();
    assert_eq!(status, StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_oidc_issuer_match_accepted() {
let mut config = test_oidc_config();
config.issuer = Some("https://idp.example.com".to_string());
let oidc = test_oidc_state_with_config(config).await;
let jwt = encode_test_jwt(
serde_json::json!({
"sub": "alice",
"iss": "https://idp.example.com",
"exp": 9999999999u64,
}),
Some(OIDC_KID),
);
let result = validate_oidc_jwt(&oidc, &jwt).await;
assert!(result.is_ok(), "matching issuer should pass: {result:?}");
assert_eq!(result.unwrap(), "alice");
}
#[tokio::test]
async fn test_oidc_issuer_mismatch_rejected() {
let mut config = test_oidc_config();
config.issuer = Some("https://idp.example.com".to_string());
let oidc = test_oidc_state_with_config(config).await;
let jwt = encode_test_jwt(
serde_json::json!({
"sub": "alice",
"iss": "https://evil.example.com",
"exp": 9999999999u64,
}),
Some(OIDC_KID),
);
let result = validate_oidc_jwt(&oidc, &jwt).await;
assert!(result.is_err(), "wrong issuer should be rejected");
}
#[tokio::test]
async fn test_oidc_issuer_configured_but_missing_in_jwt_passes() {
let mut config = test_oidc_config();
config.issuer = Some("https://idp.example.com".to_string());
let oidc = test_oidc_state_with_config(config).await;
let jwt = encode_test_jwt(
serde_json::json!({"sub": "alice", "exp": 9999999999u64}),
Some(OIDC_KID),
);
let result = validate_oidc_jwt(&oidc, &jwt).await;
assert!(
result.is_ok(),
"missing iss is not rejected by jsonwebtoken: {result:?}"
);
}
#[tokio::test]
async fn test_oidc_audience_match_accepted() {
let mut config = test_oidc_config();
config.audience = Some("my-client-id".to_string());
let oidc = test_oidc_state_with_config(config).await;
let jwt = encode_test_jwt(
serde_json::json!({
"sub": "alice",
"aud": "my-client-id",
"exp": 9999999999u64,
}),
Some(OIDC_KID),
);
let result = validate_oidc_jwt(&oidc, &jwt).await;
assert!(result.is_ok(), "matching audience should pass: {result:?}");
}
#[tokio::test]
async fn test_oidc_audience_mismatch_rejected() {
let mut config = test_oidc_config();
config.audience = Some("my-client-id".to_string());
let oidc = test_oidc_state_with_config(config).await;
let jwt = encode_test_jwt(
serde_json::json!({
"sub": "alice",
"aud": "wrong-client",
"exp": 9999999999u64,
}),
Some(OIDC_KID),
);
let result = validate_oidc_jwt(&oidc, &jwt).await;
assert!(result.is_err(), "wrong audience should be rejected");
}
#[tokio::test]
async fn test_oidc_audience_configured_but_missing_in_jwt_passes() {
let mut config = test_oidc_config();
config.audience = Some("my-client-id".to_string());
let oidc = test_oidc_state_with_config(config).await;
let jwt = encode_test_jwt(
serde_json::json!({"sub": "alice", "exp": 9999999999u64}),
Some(OIDC_KID),
);
let result = validate_oidc_jwt(&oidc, &jwt).await;
assert!(
result.is_ok(),
"missing aud is not rejected by jsonwebtoken: {result:?}"
);
}
/// A cache entry older than `KEY_CACHE_TTL` must not be used for validation.
///
/// The planted entry holds the real test key, so if it were served the JWT would
/// validate. Instead a refetch is attempted, and it fails against the placeholder
/// JWKS URL.
#[tokio::test]
async fn test_oidc_key_cache_evicts_expired_entries() {
    let oidc = test_oidc_state().await;
    let stale_at = Instant::now() - KEY_CACHE_TTL - Duration::from_secs(1);
    // Each lock guard below is a temporary, released at the end of its statement,
    // well before validation runs.
    oidc.key_cache.write().await.insert(
        "stale-kid".to_string(),
        CachedKey {
            decoding_key: DecodingKey::from_secret(OIDC_SECRET),
            algorithm: Algorithm::HS256,
            fetched_at: stale_at,
        },
    );
    let stale_age = oidc
        .key_cache
        .read()
        .await
        .get("stale-kid")
        .unwrap()
        .fetched_at
        .elapsed();
    assert!(stale_age > KEY_CACHE_TTL, "entry should be expired");

    let claims = serde_json::json!({"sub": "stale-user", "exp": 9999999999u64});
    let jwt = encode_test_jwt(claims, Some("stale-kid"));
    let outcome = validate_oidc_jwt(&oidc, &jwt).await;
    assert!(
        outcome.is_err(),
        "expired cache entry should not be served; fetch fails since URL is unreachable"
    );
}
/// The key-cache bound is 64 entries, and seeding that many extra keys stays within
/// one entry of it.
///
/// NOTE(review): `test_oidc_state` already seeds `OIDC_KID`, so this loop leaves at
/// most 64 + 1 = 65 distinct seeded keys. The `<= KEY_CACHE_MAX_ENTRIES + 1` check
/// therefore holds even if seeding never evicts, so it does not on its own prove the
/// cache is bounded. Confirm whether `seed_key` applies the eviction policy before
/// tightening the assertion.
#[tokio::test]
async fn test_oidc_key_cache_max_entries_constant() {
    assert_eq!(
        KEY_CACHE_MAX_ENTRIES, 64,
        "cache should be bounded to 64 keys"
    );
    let oidc = test_oidc_state().await;
    let mut index = 0;
    while index < KEY_CACHE_MAX_ENTRIES {
        let kid = format!("kid-{index}");
        oidc.seed_key(&kid, DecodingKey::from_secret(OIDC_SECRET), Algorithm::HS256)
            .await;
        index += 1;
    }
    let cached = oidc.key_cache.read().await.len();
    assert!(
        cached <= KEY_CACHE_MAX_ENTRIES + 1,
        "cache should be near capacity"
    );
}
#[tokio::test]
async fn test_oidc_fetch_failure_backoff() {
let oidc = test_oidc_state().await;
{
let mut failures = oidc.fetch_failures.write().await;
failures.insert(
"bad-kid".to_string(),
FailedFetch {
failed_at: Instant::now(),
},
);
}
let result = oidc.get_or_fetch_key("bad-kid", Algorithm::HS256).await;
let err_msg = match result {
Err(e) => format!("{e}"),
Ok(_) => panic!("expected backoff error"),
};
assert!(
err_msg.contains("backing off"),
"should mention backoff: {err_msg}"
);
}
#[tokio::test]
async fn test_oidc_fetch_failure_backoff_expires() {
let oidc = test_oidc_state().await;
{
let mut failures = oidc.fetch_failures.write().await;
failures.insert(
"expired-kid".to_string(),
FailedFetch {
failed_at: Instant::now() - FETCH_FAILURE_BACKOFF - Duration::from_secs(1),
},
);
}
let result = oidc.get_or_fetch_key("expired-kid", Algorithm::HS256).await;
let err_msg = match result {
Err(e) => format!("{e}"),
Ok(_) => panic!("expected fetch error (URL unreachable), not success"),
};
assert!(
!err_msg.contains("backing off"),
"should attempt fetch, not backoff: {err_msg}"
);
}
}