use axum::{
extract::{Request, State},
http::{HeaderMap, StatusCode},
middleware::Next,
response::Response,
};
use base64::{engine::general_purpose, Engine as _};
use jsonwebtoken::decode_header;
use serde::Deserialize;
use std::collections::HashMap;
use crate::db::{service_accounts, users, User};
use crate::server::auth::cookie_helpers;
use crate::server::state::AppState;
/// Zero-sized request-extension marker.
///
/// `auth_middleware` inserts it when the caller authenticated through the
/// service-account path (a token whose issuer is not Rise itself).
/// `platform_access_middleware` looks for it to skip the platform-access check.
#[derive(Clone, Copy, Debug)]
struct IsServiceAccount;
/// Decide whether a token's (not yet verified) `iss` claim designates a Rise-issued JWT.
///
/// Returns `true` when `issuer` equals `public_url` exactly. Otherwise, if `public_url`
/// ends in an ASCII digit or `:`, that single trailing character is dropped and `issuer`
/// is accepted when it starts with what remains. A `public_url` ending in anything else
/// requires an exact match.
///
/// NOTE(review): `strip_suffix` with a char predicate removes exactly ONE character, so
/// for `http://localhost:3000` the accepted prefix is `http://localhost:300`, not
/// `http://localhost`. If the intent was to ignore the whole port, confirm before
/// widening it: a broader prefix would route other same-host issuers (e.g. a local
/// OIDC provider) to the Rise verifier.
///
/// This only selects which verifier runs; the signature is checked afterwards.
fn is_rise_issued_jwt(issuer: &str, public_url: &str) -> bool {
    if issuer == public_url {
        return true;
    }
    let is_port_char = |c: char| c.is_ascii_digit() || c == ':';
    match public_url.strip_suffix(is_port_char) {
        Some(base) => issuer.starts_with(base),
        None => false,
    }
}
/// Extract the credentials from an `Authorization: Bearer <token>` header.
///
/// Returns `None` when the header is absent, is not valid visible-ASCII text, or uses a
/// scheme other than Bearer. The scheme name is compared case-insensitively, as
/// RFC 7235 §2.1 requires, so `bearer <token>` is accepted too. Everything after the
/// single separating space is returned verbatim. For a bare `"Bearer "` header that is
/// the empty string.
fn extract_bearer_token(headers: &HeaderMap) -> Option<String> {
    const SCHEME: &str = "Bearer ";
    let value = headers.get("Authorization")?.to_str().ok()?;
    // `get` instead of indexing: a value shorter than the scheme yields `None`, not a panic.
    let scheme = value.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }
    // In bounds and on a char boundary: the `get` above succeeded for this same index.
    Some(value[SCHEME.len()..].to_string())
}
/// Return the Rise JWT carried in the request's cookies, if any.
///
/// Thin wrapper around `cookie_helpers::extract_rise_jwt_cookie`. Both `auth_middleware`
/// and `optional_auth_middleware` consult this source before falling back to the
/// `Authorization: Bearer` header.
fn extract_rise_jwt_from_cookie(headers: &HeaderMap) -> Option<String> {
cookie_helpers::extract_rise_jwt_cookie(headers)
}
/// The single payload claim read *before* signature verification.
///
/// Only the issuer is needed to decide whether a token goes to the Rise verifier or to
/// service-account validation. Any other claims in the payload are ignored here.
#[derive(Deserialize, Debug)]
struct MinimalClaims {
    /// `iss` claim from the unverified payload. Used for routing only, never trusted alone.
    iss: String,
}
/// Pick the client-facing 401 message when no service account accepted the token.
///
/// This is a substring heuristic over the validator's error texts. Any mention of
/// `'aud'` wins. Otherwise any signature-related text gives the signature message.
/// Anything else falls back to a generic "no match" message.
fn describe_service_account_rejection(errors: &[String]) -> String {
    let any_error_contains = |needle: &str| errors.iter().any(|err| err.contains(needle));
    if any_error_contains("'aud'") {
        "The provided JWT is missing the \"aud\" claim".to_string()
    } else if any_error_contains("validate JWT token")
        || any_error_contains("signature")
        || any_error_contains("InvalidSignature")
    {
        "The provided JWT signature could not be validated".to_string()
    } else {
        "No service account matches the provided claims".to_string()
    }
}

/// Authenticate `token` as exactly one service account registered for `issuer`.
///
/// Every service account configured for the issuer is tried. Each stores the claims
/// the token must carry, and validation goes through `state.jwt_validator`.
///
/// # Errors
/// - 500 on database failures, on an undeserializable stored claim set (this aborts the
///   whole attempt), or when the matched account's user row is missing.
/// - 401 when the issuer has no service accounts, or when none of them validates the token.
/// - 409 when more than one service account validates the token (ambiguous configuration).
async fn authenticate_service_account(
    state: &AppState,
    token: &str,
    issuer: &str,
) -> Result<User, (StatusCode, String)> {
    let candidates = service_accounts::find_by_issuer(&state.db_pool, issuer)
        .await
        .map_err(|e| {
            tracing::error!("Failed to find service accounts by issuer: {:#}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Database error".to_string(),
            )
        })?;

    if candidates.is_empty() {
        tracing::warn!("No service accounts found for issuer: {}", issuer);
        return Err((
            StatusCode::UNAUTHORIZED,
            "No service accounts configured for this issuer".to_string(),
        ));
    }

    // Validate against every candidate so ambiguous configurations are detected
    // instead of silently picking the first match.
    let mut matched = Vec::new();
    let mut failures = Vec::new();
    for candidate in &candidates {
        let expected_claims: HashMap<String, String> =
            serde_json::from_value(candidate.claims.clone()).map_err(|e| {
                tracing::error!("Failed to deserialize service account claims: {:#}", e);
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "Invalid service account configuration".to_string(),
                )
            })?;
        match state
            .jwt_validator
            .validate(token, issuer, &expected_claims)
            .await
        {
            Ok(_) => matched.push(candidate),
            Err(e) => failures.push(e.to_string()),
        }
    }

    let sa = match matched.as_slice() {
        [] => {
            let error_msg = describe_service_account_rejection(&failures);
            tracing::warn!("Service account validation failed: {}", error_msg);
            return Err((StatusCode::UNAUTHORIZED, error_msg));
        }
        [only] => *only,
        several => {
            let sa_ids: Vec<String> = several.iter().map(|sa| sa.id.to_string()).collect();
            tracing::error!(
                "Multiple service accounts matched JWT: {:?}. This indicates ambiguous claim configuration.",
                sa_ids
            );
            return Err((
                StatusCode::CONFLICT,
                "Multiple service accounts match the provided claims".to_string(),
            ));
        }
    };

    tracing::info!(
        "Service account authenticated: {} for project {}",
        sa.id,
        sa.project_id
    );

    users::find_by_id(&state.db_pool, sa.user_id)
        .await
        .map_err(|e| {
            tracing::error!("Failed to find user for service account: {:#}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Database error".to_string(),
            )
        })?
        .ok_or_else(|| {
            tracing::error!("Service account user not found: {}", sa.user_id);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Service account user not found".to_string(),
            )
        })
}
pub async fn auth_middleware(
State(state): State<AppState>,
headers: HeaderMap,
mut req: Request,
next: Next,
) -> Result<Response, (StatusCode, String)> {
tracing::debug!(
"Auth middleware: validating request to {}",
req.uri().path()
);
let token = if let Some(cookie_token) = extract_rise_jwt_from_cookie(&headers) {
tracing::debug!(
"Auth middleware: found Rise JWT in cookie (length={})",
cookie_token.len()
);
cookie_token
} else if let Some(bearer_token) = extract_bearer_token(&headers) {
tracing::debug!(
"Auth middleware: found Bearer token in Authorization header (length={})",
bearer_token.len()
);
bearer_token
} else {
tracing::warn!("Auth middleware: no authentication token found");
return Err((
StatusCode::UNAUTHORIZED,
"Missing authentication token (cookie or Authorization header)".to_string(),
));
};
let issuer = {
decode_header(&token).map_err(|e| {
tracing::warn!("Failed to decode JWT header: {:#}", e);
(
StatusCode::UNAUTHORIZED,
format!("Invalid token format: {}", e),
)
})?;
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err((StatusCode::UNAUTHORIZED, "Invalid JWT format".to_string()));
}
let payload = parts[1];
let decoded = general_purpose::URL_SAFE_NO_PAD
.decode(payload)
.map_err(|e| {
tracing::warn!("Failed to decode JWT payload: {:#}", e);
(
StatusCode::UNAUTHORIZED,
"Invalid token encoding".to_string(),
)
})?;
let claims: MinimalClaims = serde_json::from_slice(&decoded).map_err(|e| {
tracing::warn!("Failed to parse JWT claims: {:#}", e);
(StatusCode::UNAUTHORIZED, "Invalid token claims".to_string())
})?;
claims.iss
};
tracing::debug!(
"Auth middleware: token issuer='{}', configured issuer='{}', rise public_url='{}'",
issuer,
state.auth_settings.issuer,
state.public_url
);
let user = if is_rise_issued_jwt(&issuer, &state.public_url) {
tracing::debug!("Auth middleware: authenticating with Rise-issued JWT");
let claims = state.jwt_signer.verify_jwt_skip_aud(&token).map_err(|e| {
tracing::warn!("Auth middleware: Rise JWT validation failed: {:#}", e);
(StatusCode::UNAUTHORIZED, format!("Invalid token: {}", e))
})?;
tracing::debug!("Auth middleware: Rise JWT validation successful");
let email = &claims.email;
tracing::debug!("Rise JWT validated for user: {}", email);
let groups = claims.groups.clone();
req.extensions_mut().insert(groups);
users::find_or_create(&state.db_pool, email)
.await
.map_err(|e| {
tracing::error!("Failed to find/create user: {:#}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
"Database error".to_string(),
)
})?
} else {
tracing::debug!("Authenticating as service account from issuer: {}", issuer);
let user = authenticate_service_account(&state, &token, &issuer).await?;
req.extensions_mut().insert(IsServiceAccount);
user
};
tracing::debug!("User authenticated: {} ({})", user.email, user.id);
req.extensions_mut().insert(user);
Ok(next.run(req).await)
}
#[allow(dead_code)]
pub async fn optional_auth_middleware(
State(state): State<AppState>,
headers: HeaderMap,
mut req: Request,
next: Next,
) -> Response {
let token = extract_rise_jwt_from_cookie(&headers).or_else(|| extract_bearer_token(&headers));
if let Some(token) = token {
if decode_header(&token).is_ok() {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() == 3 {
if let Ok(decoded) = general_purpose::URL_SAFE_NO_PAD.decode(parts[1]) {
if let Ok(claims) = serde_json::from_slice::<MinimalClaims>(&decoded) {
if is_rise_issued_jwt(&claims.iss, &state.public_url) {
if let Ok(rise_claims) = state.jwt_signer.verify_jwt_skip_aud(&token) {
let email = &rise_claims.email;
if let Ok(user) = users::find_or_create(&state.db_pool, email).await
{
req.extensions_mut().insert(user);
}
}
}
}
}
}
}
}
next.run(req).await
}
/// Gate platform (non-application) routes behind the configured platform-access policy.
///
/// This middleware must run after `auth_middleware`. A missing `User` extension is treated
/// as a server-side wiring error (500). Requests tagged with `IsServiceAccount` pass
/// straight through. Every other user is checked by `ConfigBasedAccessChecker` using
/// `auth_settings.platform_access`, the admin-user list, and the Rise JWT groups when
/// present. Denied users get 403.
pub async fn platform_access_middleware(
    State(state): State<AppState>,
    req: Request,
    next: Next,
) -> Result<Response, (StatusCode, String)> {
    use crate::server::auth::platform_access::{ConfigBasedAccessChecker, PlatformAccessChecker};

    let extensions = req.extensions();

    let user = match extensions.get::<User>() {
        Some(user) => user,
        None => {
            tracing::error!("platform_access_middleware called without user in extensions");
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                "Authentication error".to_string(),
            ));
        }
    };

    if extensions.get::<IsServiceAccount>().is_some() {
        tracing::debug!("Skipping platform access check for service account");
        return Ok(next.run(req).await);
    }

    let checker = ConfigBasedAccessChecker {
        config: &state.auth_settings.platform_access,
        admin_users: &state.admin_users,
    };
    let groups = extensions.get::<Option<Vec<String>>>();

    if checker.has_platform_access(user, groups.and_then(|g| g.as_ref()).map(|v| v.as_ref())) {
        return Ok(next.run(req).await);
    }

    tracing::warn!(
        user_id = %user.id,
        user_email = %user.email,
        path = %req.uri().path(),
        "Platform access denied for non-platform user"
    );
    Err((
        StatusCode::FORBIDDEN,
        "You do not have access to Rise platform features. \
         Your account is configured for application access only. \
         Please contact your administrator if you need platform access."
            .to_string(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    #[test]
    fn test_extract_bearer_token_valid() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "Authorization",
            HeaderValue::from_static("Bearer my-token-here"),
        );
        let token = extract_bearer_token(&headers);
        assert_eq!(token, Some("my-token-here".to_string()));
    }

    #[test]
    fn test_extract_bearer_token_missing_header() {
        let headers = HeaderMap::new();
        let result = extract_bearer_token(&headers);
        assert_eq!(result, None);
    }

    #[test]
    fn test_extract_bearer_token_invalid_format() {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", HeaderValue::from_static("Basic user:pass"));
        let result = extract_bearer_token(&headers);
        assert_eq!(result, None);
    }

    /// A bare scheme yields an empty token rather than `None`; the JWT decoding step
    /// downstream is what rejects it.
    #[test]
    fn test_extract_bearer_token_empty_credentials() {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", HeaderValue::from_static("Bearer "));
        assert_eq!(extract_bearer_token(&headers), Some(String::new()));
    }

    /// Header values that are not visible ASCII cannot be read as `&str`.
    #[test]
    fn test_extract_bearer_token_non_ascii_value() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "Authorization",
            HeaderValue::from_bytes(b"Bearer \xfftoken").unwrap(),
        );
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn test_is_rise_issued_jwt_exact_match() {
        assert!(is_rise_issued_jwt(
            "https://rise.example.com",
            "https://rise.example.com"
        ));
        assert!(is_rise_issued_jwt(
            "http://localhost:3000",
            "http://localhost:3000"
        ));
    }

    #[test]
    fn test_is_rise_issued_jwt_rejects_unrelated_issuer() {
        assert!(!is_rise_issued_jwt(
            "https://accounts.google.com",
            "https://rise.example.com"
        ));
    }

    /// Without a trailing port, no prefix matching happens, so look-alike hosts are not
    /// routed to the Rise verifier.
    #[test]
    fn test_is_rise_issued_jwt_portless_url_requires_exact_match() {
        assert!(!is_rise_issued_jwt(
            "https://rise.example.com.evil.net",
            "https://rise.example.com"
        ));
    }
}