// fraiseql_server/extractors.rs

//! Custom extractors for GraphQL handlers.
//!
//! Provides extractors for `SecurityContext` and other request-level data.

5use std::future::Future;
6
7use axum::{
8    extract::{FromRequestParts, rejection::ExtensionRejection},
9    http::request::Parts,
10};
11use fraiseql_core::security::SecurityContext;
12
13use crate::middleware::AuthUser;
14
15/// Extractor for optional `SecurityContext` from authenticated user and headers.
16///
17/// When used in a handler, automatically extracts:
18/// 1. `AuthUser` from request extensions (if present)
19/// 2. Request metadata from HTTP headers (request ID, IP, tenant ID)
20/// 3. Creates `SecurityContext` from both
21///
22/// If authentication is not present, returns `None` (optional extraction).
23///
24/// # Example
25///
26/// ```text
27/// // Requires: running Axum server with authentication middleware configured.
28/// async fn graphql_handler(
29///     State(state): State<AppState>,
30///     OptionalSecurityContext(context): OptionalSecurityContext,
31/// ) -> Result<Response> {
32///     // context is Option<SecurityContext>
33/// }
34/// ```
35#[derive(Debug, Clone)]
36pub struct OptionalSecurityContext(pub Option<SecurityContext>);
37
38impl<S> FromRequestParts<S> for OptionalSecurityContext
39where
40    S: Send + Sync + 'static,
41{
42    type Rejection = ExtensionRejection;
43
44    #[allow(clippy::manual_async_fn)] // Reason: axum's FromRequestParts requires explicit Future type in return position
45    fn from_request_parts(
46        parts: &mut Parts,
47        _state: &S,
48    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
49        async move {
50            // Try to extract AuthUser from extensions
51            let auth_user: Option<AuthUser> = parts.extensions.get::<AuthUser>().cloned();
52
53            // Extract request headers
54            let headers = &parts.headers;
55
56            // Create SecurityContext if auth user is present
57            let security_context = auth_user.map(|auth_user| {
58                let authenticated_user = auth_user.0;
59                let request_id = extract_request_id(headers);
60                let ip_address = extract_ip_address(headers);
61                let tenant_id = extract_tenant_id(headers);
62
63                let mut context = SecurityContext::from_user(&authenticated_user, request_id);
64                context.ip_address = ip_address;
65                context.tenant_id = tenant_id;
66
67                // Forward JWT extra_claims to security context attributes.
68                // This makes custom claims (org_id, roles, etc.) available to RLS policies
69                // and session variable injection.
70                for (key, value) in &authenticated_user.extra_claims {
71                    context.attributes.insert(key.clone(), value.clone());
72                }
73
74                // Set tenant_id from org_id JWT claim when not already set from headers.
75                // This is the standard multi-tenant pattern: the JWT org_id claim identifies
76                // which tenant's data the authenticated user may access.
77                if context.tenant_id.is_none() {
78                    if let Some(org_id) =
79                        authenticated_user.extra_claims.get("org_id").and_then(|v| v.as_str())
80                    {
81                        context.tenant_id = Some(org_id.to_string());
82                    }
83                }
84
85                context
86            });
87
88            Ok(OptionalSecurityContext(security_context))
89        }
90    }
91}
92
93/// Extract request ID from headers or generate a new one.
94fn extract_request_id(headers: &axum::http::HeaderMap) -> String {
95    headers
96        .get("x-request-id")
97        .and_then(|v| v.to_str().ok())
98        .map_or_else(|| format!("req-{}", uuid::Uuid::new_v4()), |s| s.to_string())
99}
100
101/// Extract client IP address.
102///
103/// # Security
104///
105/// Does NOT trust X-Forwarded-For or X-Real-IP headers from clients, as these
106/// are trivially spoofable. IP address should be set from `ConnectInfo<SocketAddr>`
107/// at the handler level, or via `ProxyConfig::extract_client_ip()` which validates
108/// the proxy chain before trusting forwarding headers.
109const fn extract_ip_address(_headers: &axum::http::HeaderMap) -> Option<String> {
110    // SECURITY: IP extraction from headers removed. User-supplied X-Forwarded-For
111    // and X-Real-IP headers are trivially spoofable and must not be trusted without
112    // proxy chain validation. Use ConnectInfo<SocketAddr> or ProxyConfig instead.
113    None
114}
115
116/// Extract tenant ID.
117///
118/// # Security
119///
120/// Does NOT trust the X-Tenant-ID header directly. An authenticated user could
121/// set an arbitrary tenant ID to access another organization's data. Tenant ID
122/// should be set from `TenantContext` (populated by the secured `tenant_middleware`
123/// which requires authentication) or from JWT claims.
124const fn extract_tenant_id(_headers: &axum::http::HeaderMap) -> Option<String> {
125    // SECURITY: Tenant ID extraction from headers removed. The X-Tenant-ID header
126    // is user-controlled and could be used for tenant isolation bypass. Tenant context
127    // should come from the authenticated tenant_middleware or JWT claims.
128    None
129}
130
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics acceptable
    #![allow(clippy::panic)] // Reason: test helper fails loudly if a future is not ready
    #![allow(missing_docs)] // Reason: test code

    use std::{
        collections::HashMap,
        pin::pin,
        sync::Arc,
        task::{Context, Poll, Wake, Waker},
    };

    use chrono::Utc;

    use super::*;

    /// Waker that does nothing; used to poll futures that complete immediately.
    struct NoopWake;

    impl Wake for NoopWake {
        fn wake(self: Arc<Self>) {}
    }

    /// Polls `fut` exactly once and returns its output, panicking if it is pending.
    ///
    /// The extractor's future contains no `.await`, so it is ready on the first poll.
    /// This lets the tests drive the real extractor without an async runtime.
    fn poll_once<F: std::future::Future>(fut: F) -> F::Output {
        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);
        let mut fut = pin!(fut);
        match std::future::Future::poll(fut.as_mut(), &mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => panic!("future was not ready on first poll"),
        }
    }

    /// Builds an authenticated user (as the OIDC middleware would) with no extra claims.
    fn make_auth_user() -> AuthUser {
        AuthUser(fraiseql_core::security::AuthenticatedUser {
            user_id:      "user123".to_string(),
            scopes:       vec!["read:user".to_string(), "write:post".to_string()],
            expires_at:   Utc::now() + chrono::Duration::hours(1),
            extra_claims: HashMap::new(),
        })
    }

    /// Builds request parts with the given headers, plus `auth_user` in the extensions.
    fn make_parts(headers: &[(&str, &str)], auth_user: Option<AuthUser>) -> Parts {
        let mut builder = axum::http::Request::builder();
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        let (mut parts, ()) = builder.body(()).unwrap().into_parts();
        if let Some(user) = auth_user {
            parts.extensions.insert(user);
        }
        parts
    }

    /// Runs the real `OptionalSecurityContext` extractor against `parts`.
    fn extract(parts: &mut Parts) -> Option<SecurityContext> {
        let OptionalSecurityContext(context) =
            poll_once(OptionalSecurityContext::from_request_parts(parts, &())).unwrap();
        context
    }

    #[test]
    fn test_extract_request_id_from_header() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-request-id", "req-12345".parse().unwrap());

        let request_id = extract_request_id(&headers);
        assert_eq!(request_id, "req-12345");
    }

    #[test]
    fn test_extract_request_id_generates_default() {
        let headers = axum::http::HeaderMap::new();
        let request_id = extract_request_id(&headers);
        // Should start with "req-"
        assert!(request_id.starts_with("req-"));
        // Should contain a UUID: "req-" (4) + UUID (36) = 40 chars
        assert_eq!(request_id.len(), 40);
    }

    #[test]
    fn test_extract_ip_ignores_x_forwarded_for() {
        // SECURITY: X-Forwarded-For must NOT be trusted without proxy validation
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-forwarded-for", "192.0.2.1, 10.0.0.1".parse().unwrap());

        let ip = extract_ip_address(&headers);
        assert_eq!(ip, None, "Must not trust X-Forwarded-For header");
    }

    #[test]
    fn test_extract_ip_ignores_x_real_ip() {
        // SECURITY: X-Real-IP must NOT be trusted without proxy validation
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-real-ip", "10.0.0.2".parse().unwrap());

        let ip = extract_ip_address(&headers);
        assert_eq!(ip, None, "Must not trust X-Real-IP header");
    }

    #[test]
    fn test_extract_ip_address_none_when_missing() {
        let headers = axum::http::HeaderMap::new();
        let ip = extract_ip_address(&headers);
        assert_eq!(ip, None);
    }

    #[test]
    fn test_extract_tenant_id_ignores_header() {
        // SECURITY: X-Tenant-ID must NOT be trusted from headers
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-tenant-id", "tenant-acme".parse().unwrap());

        let tenant_id = extract_tenant_id(&headers);
        assert_eq!(tenant_id, None, "Must not trust X-Tenant-ID header");
    }

    #[test]
    fn test_extract_tenant_id_none_when_missing() {
        let headers = axum::http::HeaderMap::new();
        let tenant_id = extract_tenant_id(&headers);
        assert_eq!(tenant_id, None);
    }

    #[test]
    fn test_extractor_returns_none_without_auth_user() {
        let mut parts = make_parts(&[("x-request-id", "req-anon")], None);
        assert!(extract(&mut parts).is_none());
    }

    #[test]
    fn test_optional_security_context_creation_from_auth_user() {
        let mut parts = make_parts(
            &[
                ("x-request-id", "req-test-123"),
                ("x-tenant-id", "tenant-acme"),
                ("x-forwarded-for", "192.0.2.100"),
            ],
            Some(make_auth_user()),
        );

        let sec_ctx = extract(&mut parts).unwrap();
        assert_eq!(sec_ctx.user_id, "user123");
        assert_eq!(sec_ctx.scopes, vec!["read:user".to_string(), "write:post".to_string()]);
        assert_eq!(sec_ctx.request_id, "req-test-123");
        // SECURITY: Tenant ID is not extracted from headers (spoofable); with no org_id
        // claim present it must stay unset.
        assert_eq!(sec_ctx.tenant_id, None);
        // SECURITY: IP is not extracted from headers (spoofable).
        assert_eq!(sec_ctx.ip_address, None);
    }

    #[test]
    fn test_extractor_forwards_extra_claims_and_uses_org_id_as_tenant() {
        let mut auth_user = make_auth_user();
        auth_user.0.extra_claims.insert("org_id".to_string(), "org-42".into());

        // The spoofable header must lose to the signed JWT claim.
        let mut parts = make_parts(&[("x-tenant-id", "tenant-acme")], Some(auth_user));
        let sec_ctx = extract(&mut parts).unwrap();

        assert_eq!(sec_ctx.attributes.get("org_id").and_then(|v| v.as_str()), Some("org-42"));
        assert_eq!(sec_ctx.tenant_id.as_deref(), Some("org-42"));
    }
}