fraiseql_server/
extractors.rs1use std::future::Future;
6
7use axum::{
8 extract::{FromRequestParts, rejection::ExtensionRejection},
9 http::request::Parts,
10};
11use fraiseql_core::security::SecurityContext;
12
13use crate::middleware::AuthUser;
14
15#[derive(Debug, Clone)]
36pub struct OptionalSecurityContext(pub Option<SecurityContext>);
37
38impl<S> FromRequestParts<S> for OptionalSecurityContext
39where
40 S: Send + Sync + 'static,
41{
42 type Rejection = ExtensionRejection;
43
44 #[allow(clippy::manual_async_fn)] fn from_request_parts(
46 parts: &mut Parts,
47 _state: &S,
48 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
49 async move {
50 let auth_user: Option<AuthUser> = parts.extensions.get::<AuthUser>().cloned();
52
53 let headers = &parts.headers;
55
56 let security_context = auth_user.map(|auth_user| {
58 let authenticated_user = auth_user.0;
59 let request_id = extract_request_id(headers);
60 let ip_address = extract_ip_address(headers);
61 let tenant_id = extract_tenant_id(headers);
62
63 let mut context = SecurityContext::from_user(&authenticated_user, request_id);
64 context.ip_address = ip_address;
65 context.tenant_id = tenant_id;
66 context
67 });
68
69 Ok(OptionalSecurityContext(security_context))
70 }
71 }
72}
73
74fn extract_request_id(headers: &axum::http::HeaderMap) -> String {
76 headers
77 .get("x-request-id")
78 .and_then(|v| v.to_str().ok())
79 .map_or_else(|| format!("req-{}", uuid::Uuid::new_v4()), |s| s.to_string())
80}
81
82const fn extract_ip_address(_headers: &axum::http::HeaderMap) -> Option<String> {
91 None
95}
96
97const fn extract_tenant_id(_headers: &axum::http::HeaderMap) -> Option<String> {
106 None
110}
111
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cast_precision_loss)]
    #![allow(clippy::cast_sign_loss)]
    #![allow(clippy::cast_possible_truncation)]
    #![allow(clippy::cast_possible_wrap)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    #![allow(clippy::items_after_statements)]

    use std::{
        sync::Arc,
        task::{Context, Poll, Wake, Waker},
    };

    use super::*;

    /// Waker that does nothing; enough for futures that complete on their first poll.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Polls `fut` exactly once and returns its output.
    ///
    /// The extractor's future has no `.await` points, so a single poll completes it
    /// without needing an async runtime.
    ///
    /// # Panics
    ///
    /// Panics if the future is still pending after the first poll.
    fn poll_once<F: Future>(fut: F) -> F::Output {
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);
        let mut fut = std::pin::pin!(fut);
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => panic!("future was expected to complete on the first poll"),
        }
    }

    /// Builds an `AuthUser` with two scopes and a one-hour expiry.
    fn sample_auth_user() -> crate::middleware::AuthUser {
        use chrono::Utc;

        crate::middleware::AuthUser(fraiseql_core::security::AuthenticatedUser {
            user_id: "user123".to_string(),
            scopes: vec!["read:user".to_string(), "write:post".to_string()],
            expires_at: Utc::now() + chrono::Duration::hours(1),
        })
    }

    #[test]
    fn test_extract_request_id_from_header() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-request-id", "req-12345".parse().unwrap());

        let request_id = extract_request_id(&headers);
        assert_eq!(request_id, "req-12345");
    }

    #[test]
    fn test_extract_request_id_generates_default() {
        let headers = axum::http::HeaderMap::new();
        let request_id = extract_request_id(&headers);
        assert!(request_id.starts_with("req-"));
        // "req-" (4 chars) + hyphenated UUID (36 chars).
        assert_eq!(request_id.len(), 40);
    }

    #[test]
    fn test_extract_ip_ignores_x_forwarded_for() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-forwarded-for", "192.0.2.1, 10.0.0.1".parse().unwrap());

        let ip = extract_ip_address(&headers);
        assert_eq!(ip, None, "Must not trust X-Forwarded-For header");
    }

    #[test]
    fn test_extract_ip_ignores_x_real_ip() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-real-ip", "10.0.0.2".parse().unwrap());

        let ip = extract_ip_address(&headers);
        assert_eq!(ip, None, "Must not trust X-Real-IP header");
    }

    #[test]
    fn test_extract_ip_address_none_when_missing() {
        let headers = axum::http::HeaderMap::new();
        let ip = extract_ip_address(&headers);
        assert_eq!(ip, None);
    }

    #[test]
    fn test_extract_tenant_id_ignores_header() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-tenant-id", "tenant-acme".parse().unwrap());

        let tenant_id = extract_tenant_id(&headers);
        assert_eq!(tenant_id, None, "Must not trust X-Tenant-ID header");
    }

    #[test]
    fn test_extract_tenant_id_none_when_missing() {
        let headers = axum::http::HeaderMap::new();
        let tenant_id = extract_tenant_id(&headers);
        assert_eq!(tenant_id, None);
    }

    #[test]
    fn test_optional_security_context_creation_from_auth_user() {
        let (mut parts, _body) = axum::http::Request::builder()
            .header("x-request-id", "req-test-123")
            .header("x-tenant-id", "tenant-acme")
            .header("x-forwarded-for", "192.0.2.100")
            .body(())
            .unwrap()
            .into_parts();
        parts.extensions.insert(sample_auth_user());

        // Drive the real extractor rather than re-implementing its logic here, so
        // the test fails if `from_request_parts` changes behavior.
        let OptionalSecurityContext(security_context) =
            poll_once(OptionalSecurityContext::from_request_parts(&mut parts, &())).unwrap();

        let sec_ctx = security_context.unwrap();
        assert_eq!(sec_ctx.user_id, "user123");
        assert_eq!(sec_ctx.scopes, vec!["read:user".to_string(), "write:post".to_string()]);
        // Client-supplied tenant and forwarding headers must be ignored.
        assert_eq!(sec_ctx.tenant_id, None);
        assert_eq!(sec_ctx.request_id, "req-test-123");
        assert_eq!(sec_ctx.ip_address, None);
    }

    #[test]
    fn test_optional_security_context_none_without_auth_user() {
        let (mut parts, _body) = axum::http::Request::builder()
            .header("x-request-id", "req-test-456")
            .body(())
            .unwrap()
            .into_parts();

        let OptionalSecurityContext(security_context) =
            poll_once(OptionalSecurityContext::from_request_parts(&mut parts, &())).unwrap();

        assert!(security_context.is_none());
    }
}