use std::future::Future;
use axum::{
extract::{FromRequestParts, rejection::ExtensionRejection},
http::request::Parts,
};
use fraiseql_core::security::SecurityContext;
use crate::middleware::AuthUser;
/// Per-request security context, present only when the caller is authenticated.
///
/// Extracting this type never fails: the wrapped value is `Some` when an `AuthUser`
/// extension has been attached to the request, and `None` otherwise.
#[derive(Clone, Debug)]
pub struct OptionalSecurityContext(pub Option<SecurityContext>);
impl<S> FromRequestParts<S> for OptionalSecurityContext
where
    S: Send + Sync + 'static,
{
    /// Never actually returned: extraction always succeeds, yielding `None` when the
    /// request carries no `AuthUser` extension.
    type Rejection = ExtensionRejection;

    /// Builds a [`SecurityContext`] from the `AuthUser` request extension, if present.
    ///
    /// For an authenticated request the context:
    /// - takes its request ID from `x-request-id` (or a generated one),
    /// - takes IP address and tenant ID from the header helpers,
    /// - falls back to the string-valued `org_id` claim when no tenant ID was obtained,
    /// - copies every `extra_claims` entry into `attributes`.
    // Clippy's `manual_async_fn` is allowed because the signature deliberately mirrors the
    // trait's `fn ... -> impl Future<...> + Send` form.
    #[allow(clippy::manual_async_fn)]
    fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
        async move {
            let headers = &parts.headers;
            // Borrow the extension in place: every field is only read (or cloned
            // individually), so cloning the whole `AuthUser` is unnecessary.
            let security_context = parts.extensions.get::<AuthUser>().map(|auth_user| {
                let user = &auth_user.0;
                let mut context = SecurityContext::from_user(user, extract_request_id(headers));
                context.ip_address = extract_ip_address(headers);
                context.tenant_id = extract_tenant_id(headers).or_else(|| {
                    user.extra_claims
                        .get("org_id")
                        .and_then(|v| v.as_str())
                        .map(|org_id| org_id.to_string())
                });
                for (key, value) in &user.extra_claims {
                    context.attributes.insert(key.clone(), value.clone());
                }
                context
            });
            Ok(OptionalSecurityContext(security_context))
        }
    }
}
/// Returns the request ID supplied in the `x-request-id` header.
///
/// When the header is absent, is not valid visible-ASCII text, or is empty or blank, a
/// fresh `req-<uuid v4>` identifier is generated instead.
fn extract_request_id(headers: &axum::http::HeaderMap) -> String {
    headers
        .get("x-request-id")
        .and_then(|v| v.to_str().ok())
        // A blank ID is useless for correlating logs/traces; treat it as missing.
        .filter(|s| !s.trim().is_empty())
        .map_or_else(|| format!("req-{}", uuid::Uuid::new_v4()), |s| s.to_string())
}
/// Client IP address to record in the security context.
///
/// Deliberately always `None`: forwarding headers such as `X-Forwarded-For` and
/// `X-Real-IP` are client-controlled, so they are not trusted here.
const fn extract_ip_address(_headers: &axum::http::HeaderMap) -> Option<String> {
    Option::None
}
/// Tenant ID taken from request headers.
///
/// Deliberately always `None`: an `X-Tenant-ID` header is client-controlled, so it is
/// not trusted here.
const fn extract_tenant_id(_headers: &axum::http::HeaderMap) -> Option<String> {
    Option::None
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cast_precision_loss)]
    #![allow(clippy::cast_sign_loss)]
    #![allow(clippy::cast_possible_truncation)]
    #![allow(clippy::cast_possible_wrap)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    #![allow(clippy::items_after_statements)]

    use super::*;

    /// Builds a `HeaderMap` from `(name, value)` pairs.
    fn headers_from(pairs: &[(&'static str, &str)]) -> axum::http::HeaderMap {
        let mut map = axum::http::HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    #[test]
    fn test_extract_request_id_from_header() {
        let map = headers_from(&[("x-request-id", "req-12345")]);
        assert_eq!(extract_request_id(&map), "req-12345");
    }

    #[test]
    fn test_extract_request_id_generates_default() {
        let generated = extract_request_id(&axum::http::HeaderMap::new());
        // "req-" prefix (4 chars) + hyphenated UUID (36 chars).
        assert!(generated.starts_with("req-"));
        assert_eq!(generated.len(), 40);
    }

    #[test]
    fn test_extract_ip_ignores_x_forwarded_for() {
        let map = headers_from(&[("x-forwarded-for", "192.0.2.1, 10.0.0.1")]);
        assert!(extract_ip_address(&map).is_none(), "Must not trust X-Forwarded-For header");
    }

    #[test]
    fn test_extract_ip_ignores_x_real_ip() {
        let map = headers_from(&[("x-real-ip", "10.0.0.2")]);
        assert!(extract_ip_address(&map).is_none(), "Must not trust X-Real-IP header");
    }

    #[test]
    fn test_extract_ip_address_none_when_missing() {
        assert!(extract_ip_address(&axum::http::HeaderMap::new()).is_none());
    }

    #[test]
    fn test_extract_tenant_id_ignores_header() {
        let map = headers_from(&[("x-tenant-id", "tenant-acme")]);
        assert!(extract_tenant_id(&map).is_none(), "Must not trust X-Tenant-ID header");
    }

    #[test]
    fn test_extract_tenant_id_none_when_missing() {
        assert!(extract_tenant_id(&axum::http::HeaderMap::new()).is_none());
    }

    #[test]
    fn test_optional_security_context_creation_from_auth_user() {
        // NOTE(review): this assembles the context by hand instead of calling
        // `OptionalSecurityContext::from_request_parts`, and it skips the extra-claims copy
        // and the `org_id` tenant fallback, so it does not pin the extractor's real behavior.
        use chrono::Utc;

        let auth_user = crate::middleware::AuthUser(fraiseql_core::security::AuthenticatedUser {
            user_id: "user123".to_string(),
            scopes: vec!["read:user".to_string(), "write:post".to_string()],
            expires_at: Utc::now() + chrono::Duration::hours(1),
            extra_claims: std::collections::HashMap::new(),
        });
        let map = headers_from(&[
            ("x-request-id", "req-test-123"),
            ("x-tenant-id", "tenant-acme"),
            ("x-forwarded-for", "192.0.2.100"),
        ]);

        let user = auth_user.0;
        let mut ctx =
            fraiseql_core::security::SecurityContext::from_user(&user, extract_request_id(&map));
        ctx.ip_address = extract_ip_address(&map);
        ctx.tenant_id = extract_tenant_id(&map);

        assert_eq!(ctx.user_id, "user123");
        assert_eq!(ctx.scopes, vec!["read:user".to_string(), "write:post".to_string()]);
        assert_eq!(ctx.tenant_id, None);
        assert_eq!(ctx.request_id, "req-test-123");
        assert_eq!(ctx.ip_address, None);
    }
}