fraiseql_server/
extractors.rs1use std::future::Future;
6
7use axum::{
8 extract::{FromRequestParts, rejection::ExtensionRejection},
9 http::request::Parts,
10};
11use fraiseql_core::security::SecurityContext;
12
13use crate::middleware::AuthUser;
14
15#[derive(Debug, Clone)]
36pub struct OptionalSecurityContext(pub Option<SecurityContext>);
37
38impl<S> FromRequestParts<S> for OptionalSecurityContext
39where
40 S: Send + Sync + 'static,
41{
42 type Rejection = ExtensionRejection;
43
44 #[allow(clippy::manual_async_fn)] fn from_request_parts(
46 parts: &mut Parts,
47 _state: &S,
48 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
49 async move {
50 let auth_user: Option<AuthUser> = parts.extensions.get::<AuthUser>().cloned();
52
53 let headers = &parts.headers;
55
56 let security_context = auth_user.map(|auth_user| {
58 let authenticated_user = auth_user.0;
59 let request_id = extract_request_id(headers);
60 let ip_address = extract_ip_address(headers);
61 let tenant_id = extract_tenant_id(headers);
62
63 let mut context = SecurityContext::from_user(&authenticated_user, request_id);
64 context.ip_address = ip_address;
65 context.tenant_id = tenant_id;
66
67 for (key, value) in &authenticated_user.extra_claims {
71 context.attributes.insert(key.clone(), value.clone());
72 }
73
74 if context.tenant_id.is_none() {
78 if let Some(org_id) =
79 authenticated_user.extra_claims.get("org_id").and_then(|v| v.as_str())
80 {
81 context.tenant_id = Some(org_id.to_string());
82 }
83 }
84
85 context
86 });
87
88 Ok(OptionalSecurityContext(security_context))
89 }
90 }
91}
92
93fn extract_request_id(headers: &axum::http::HeaderMap) -> String {
95 headers
96 .get("x-request-id")
97 .and_then(|v| v.to_str().ok())
98 .map_or_else(|| format!("req-{}", uuid::Uuid::new_v4()), |s| s.to_string())
99}
100
101const fn extract_ip_address(_headers: &axum::http::HeaderMap) -> Option<String> {
110 None
114}
115
116const fn extract_tenant_id(_headers: &axum::http::HeaderMap) -> Option<String> {
125 None
129}
130
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cast_precision_loss)]
    #![allow(clippy::cast_sign_loss)]
    #![allow(clippy::cast_possible_truncation)]
    #![allow(clippy::cast_possible_wrap)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    #![allow(clippy::items_after_statements)]

    use super::*;

    /// Polls a future exactly once and returns its output.
    ///
    /// The extractor's future contains no `.await`, so it completes on the
    /// first poll. This lets the tests drive the real extractor without an
    /// async runtime.
    fn poll_once<F: Future>(fut: F) -> F::Output {
        struct NoopWaker;
        impl std::task::Wake for NoopWaker {
            fn wake(self: std::sync::Arc<Self>) {}
        }

        let waker = std::task::Waker::from(std::sync::Arc::new(NoopWaker));
        let mut cx = std::task::Context::from_waker(&waker);
        let mut fut = std::pin::pin!(fut);
        match fut.as_mut().poll(&mut cx) {
            std::task::Poll::Ready(output) => output,
            std::task::Poll::Pending => panic!("extractor future unexpectedly pending"),
        }
    }

    /// Builds request `Parts` with the given headers and optional `AuthUser` extension.
    fn request_parts(headers: &[(&str, &str)], auth_user: Option<AuthUser>) -> Parts {
        let mut builder = axum::http::Request::builder();
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let (mut parts, ()) = builder.body(()).unwrap().into_parts();
        if let Some(user) = auth_user {
            parts.extensions.insert(user);
        }
        parts
    }

    /// Builds an `AuthUser` for "user123" with the given string-valued extra claims.
    fn auth_user_with_claims(claims: &[(&str, &str)]) -> AuthUser {
        AuthUser(fraiseql_core::security::AuthenticatedUser {
            user_id: "user123".to_string(),
            scopes: vec!["read:user".to_string(), "write:post".to_string()],
            expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
            extra_claims: claims
                .iter()
                .map(|&(key, value)| (key.to_string(), value.into()))
                .collect(),
        })
    }

    /// Runs the real `OptionalSecurityContext` extractor against `parts`.
    fn extract(parts: &mut Parts) -> Option<SecurityContext> {
        poll_once(<OptionalSecurityContext as FromRequestParts<()>>::from_request_parts(
            parts,
            &(),
        ))
        .unwrap()
        .0
    }

    #[test]
    fn test_extract_request_id_from_header() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-request-id", "req-12345".parse().unwrap());

        let request_id = extract_request_id(&headers);
        assert_eq!(request_id, "req-12345");
    }

    #[test]
    fn test_extract_request_id_generates_default() {
        let headers = axum::http::HeaderMap::new();
        let request_id = extract_request_id(&headers);
        assert!(request_id.starts_with("req-"));
        // "req-" (4) + hyphenated UUID (36).
        assert_eq!(request_id.len(), 40);
    }

    #[test]
    fn test_extract_ip_ignores_x_forwarded_for() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-forwarded-for", "192.0.2.1, 10.0.0.1".parse().unwrap());

        let ip = extract_ip_address(&headers);
        assert_eq!(ip, None, "Must not trust X-Forwarded-For header");
    }

    #[test]
    fn test_extract_ip_ignores_x_real_ip() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-real-ip", "10.0.0.2".parse().unwrap());

        let ip = extract_ip_address(&headers);
        assert_eq!(ip, None, "Must not trust X-Real-IP header");
    }

    #[test]
    fn test_extract_ip_address_none_when_missing() {
        let headers = axum::http::HeaderMap::new();
        let ip = extract_ip_address(&headers);
        assert_eq!(ip, None);
    }

    #[test]
    fn test_extract_tenant_id_ignores_header() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-tenant-id", "tenant-acme".parse().unwrap());

        let tenant_id = extract_tenant_id(&headers);
        assert_eq!(tenant_id, None, "Must not trust X-Tenant-ID header");
    }

    #[test]
    fn test_extract_tenant_id_none_when_missing() {
        let headers = axum::http::HeaderMap::new();
        let tenant_id = extract_tenant_id(&headers);
        assert_eq!(tenant_id, None);
    }

    #[test]
    fn test_optional_security_context_none_without_auth_user() {
        let mut parts = request_parts(&[("x-request-id", "req-anon")], None);
        assert!(extract(&mut parts).is_none());
    }

    #[test]
    fn test_optional_security_context_creation_from_auth_user() {
        let mut parts = request_parts(
            &[
                ("x-request-id", "req-test-123"),
                ("x-tenant-id", "tenant-acme"),
                ("x-forwarded-for", "192.0.2.100"),
            ],
            Some(auth_user_with_claims(&[])),
        );

        let sec_ctx = extract(&mut parts).unwrap();
        assert_eq!(sec_ctx.user_id, "user123");
        assert_eq!(sec_ctx.scopes, vec!["read:user".to_string(), "write:post".to_string()]);
        // Spoofable headers must not leak into the context.
        assert_eq!(sec_ctx.tenant_id, None);
        assert_eq!(sec_ctx.request_id, "req-test-123");
        assert_eq!(sec_ctx.ip_address, None);
    }

    #[test]
    fn test_optional_security_context_tenant_falls_back_to_org_id_claim() {
        let mut parts = request_parts(
            &[("x-tenant-id", "tenant-spoofed")],
            Some(auth_user_with_claims(&[("org_id", "org-acme")])),
        );

        let sec_ctx = extract(&mut parts).unwrap();
        assert_eq!(sec_ctx.tenant_id.as_deref(), Some("org-acme"));
        // Extra claims are also exposed as attributes.
        assert!(sec_ctx.attributes.contains_key("org_id"));
    }
}