Skip to main content

fraiseql_auth/oauth/
client.rs

1//! OAuth2 and OIDC client implementations.
2
3use std::{sync::Arc, time::Duration as StdDuration};
4
/// Upper bound on every outbound OAuth2 / OIDC HTTP request; installed as the
/// `reqwest` client timeout by the client constructors in this module.
const OAUTH_REQUEST_TIMEOUT: StdDuration = StdDuration::new(30, 0);
7
8use std::fmt::Write as _;
9
10use serde::{Deserialize, Serialize};
11
12use super::{
13    super::jwks::{JwksCache, JwksError},
14    pkce::PKCEChallenge,
15    types::{IdTokenClaims, TokenResponse, UserInfo},
16};
17
18/// OIDC provider configuration
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
20pub struct OIDCProviderConfig {
21    /// Provider issuer URL
22    pub issuer:                   String,
23    /// Authorization endpoint
24    pub authorization_endpoint:   String,
25    /// Token endpoint
26    pub token_endpoint:           String,
27    /// Userinfo endpoint
28    pub userinfo_endpoint:        Option<String>,
29    /// JWKS URI for public keys
30    pub jwks_uri:                 String,
31    /// Scopes supported by provider
32    pub scopes_supported:         Vec<String>,
33    /// Response types supported
34    pub response_types_supported: Vec<String>,
35}
36
37impl OIDCProviderConfig {
38    /// Create new provider configuration
39    pub fn new(
40        issuer: String,
41        authorization_endpoint: String,
42        token_endpoint: String,
43        jwks_uri: String,
44    ) -> Self {
45        Self {
46            issuer,
47            authorization_endpoint,
48            token_endpoint,
49            userinfo_endpoint: None,
50            jwks_uri,
51            scopes_supported: vec![
52                "openid".to_string(),
53                "profile".to_string(),
54                "email".to_string(),
55            ],
56            response_types_supported: vec!["code".to_string()],
57        }
58    }
59}
60
61/// Result of [`OAuth2Client::authorization_url`].
62///
63/// The caller MUST store `state` (for CSRF verification at callback), when
64/// present the PKCE `pkce.code_verifier` (for token exchange), and when
65/// present the `nonce` value (must be verified against the ID token at
66/// callback via [`OIDCClient::verify_id_token`]).
67#[derive(Debug, Clone)]
68pub struct AuthorizationRequest {
69    /// The full authorization URL to redirect the user to.
70    pub url:   String,
71    /// CSRF state value — verify this matches the `state` query param at callback.
72    pub state: String,
73    /// PKCE challenge, present only when `use_pkce = true`.
74    pub pkce:  Option<PKCEChallenge>,
75    /// OIDC nonce for replay protection.
76    ///
77    /// Present only when the authorization URL was generated by
78    /// [`OIDCClient::authorization_url`].  The caller must store this value
79    /// and pass `Some(&nonce.nonce)` to [`OIDCClient::verify_id_token`] at
80    /// callback time.
81    pub nonce: Option<super::pkce::NonceParameter>,
82}
83
84/// OAuth2 client for authorization code flow.
85#[derive(Debug, Clone)]
86pub struct OAuth2Client {
87    /// Client ID from provider.
88    pub client_id:              String,
89    /// Client secret from provider.
90    client_secret:              String,
91    /// Authorization endpoint.
92    pub authorization_endpoint: String,
93    /// Token endpoint.
94    token_endpoint:             String,
95    /// Scopes to request.
96    pub scopes:                 Vec<String>,
97    /// Use PKCE for additional security.
98    pub use_pkce:               bool,
99    /// HTTP client for token requests.
100    http_client:                reqwest::Client,
101}
102
103impl OAuth2Client {
104    /// Maximum byte size accepted from an OAuth token endpoint response.
105    ///
106    /// A well-formed token response (access_token, id_token, refresh_token) is a
107    /// few kilobytes at most.  1 MiB prevents a malicious provider from sending a
108    /// response large enough to exhaust server memory.
109    const MAX_OAUTH_RESPONSE_BYTES: usize = 1024 * 1024;
110
111    /// Create new OAuth2 client.
112    pub fn new(
113        client_id: impl Into<String>,
114        client_secret: impl Into<String>,
115        authorization_endpoint: impl Into<String>,
116        token_endpoint: impl Into<String>,
117    ) -> Self {
118        Self {
119            client_id:              client_id.into(),
120            client_secret:          client_secret.into(),
121            authorization_endpoint: authorization_endpoint.into(),
122            token_endpoint:         token_endpoint.into(),
123            scopes:                 vec![
124                "openid".to_string(),
125                "profile".to_string(),
126                "email".to_string(),
127            ],
128            use_pkce:               false,
129            http_client:            reqwest::Client::builder()
130                .timeout(OAUTH_REQUEST_TIMEOUT)
131                .build()
132                .unwrap_or_default(),
133        }
134    }
135
136    /// Set scopes for request.
137    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
138        self.scopes = scopes;
139        self
140    }
141
142    /// Enable PKCE protection.
143    pub const fn with_pkce(mut self, enabled: bool) -> Self {
144        self.use_pkce = enabled;
145        self
146    }
147
148    /// Generate authorization URL.
149    ///
150    /// Returns an [`AuthorizationRequest`] containing the URL, the CSRF state
151    /// value (must be stored and verified at callback), and an optional PKCE
152    /// challenge (when `use_pkce = true`; the `code_verifier` must be stored
153    /// and sent during token exchange).
154    pub fn authorization_url(&self, redirect_uri: &str) -> AuthorizationRequest {
155        let state = uuid::Uuid::new_v4().to_string();
156        let scope = self.scopes.join(" ");
157
158        let mut url = format!(
159            "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}",
160            self.authorization_endpoint,
161            urlencoding::encode(&self.client_id),
162            urlencoding::encode(redirect_uri),
163            urlencoding::encode(&scope),
164            urlencoding::encode(&state),
165        );
166
167        let pkce = if self.use_pkce {
168            let challenge = PKCEChallenge::new();
169            let _ = write!(
170                url,
171                "&code_challenge={}&code_challenge_method=S256",
172                urlencoding::encode(&challenge.code_challenge),
173            );
174            Some(challenge)
175        } else {
176            None
177        };
178
179        AuthorizationRequest {
180            url,
181            state,
182            pkce,
183            nonce: None,
184        }
185    }
186
187    // 1 MiB
188
189    /// Post a form request to the token endpoint and parse the response.
190    async fn post_token_request(&self, params: &[(&str, &str)]) -> Result<TokenResponse, String> {
191        let response = self
192            .http_client
193            .post(&self.token_endpoint)
194            .form(params)
195            .send()
196            .await
197            .map_err(|e| format!("Token request failed: {e}"))?;
198
199        // Read the entire body once so we can apply a size cap regardless of
200        // whether the response is a success or an error.
201        let status = response.status();
202        let body_bytes = response
203            .bytes()
204            .await
205            .map_err(|e| format!("Failed to read token response body: {e}"))?;
206
207        if !status.is_success() {
208            let capped = &body_bytes[..body_bytes.len().min(Self::MAX_OAUTH_RESPONSE_BYTES)];
209            let body = String::from_utf8_lossy(capped);
210            return Err(format!("Token endpoint returned error: {body}"));
211        }
212
213        if body_bytes.len() > Self::MAX_OAUTH_RESPONSE_BYTES {
214            return Err(format!(
215                "Token response body too large ({} bytes, max {})",
216                body_bytes.len(),
217                Self::MAX_OAUTH_RESPONSE_BYTES
218            ));
219        }
220
221        serde_json::from_slice::<TokenResponse>(&body_bytes)
222            .map_err(|e| format!("Failed to parse token response: {e}"))
223    }
224
225    /// Exchange authorization code for tokens.
226    ///
227    /// # Errors
228    ///
229    /// Returns an error if the HTTP request to the token endpoint fails or the response
230    /// cannot be parsed as a `TokenResponse`.
231    pub async fn exchange_code(
232        &self,
233        code: &str,
234        redirect_uri: &str,
235    ) -> Result<TokenResponse, String> {
236        let params = [
237            ("grant_type", "authorization_code"),
238            ("code", code),
239            ("client_id", self.client_id.as_str()),
240            ("client_secret", self.client_secret.as_str()),
241            ("redirect_uri", redirect_uri),
242        ];
243        self.post_token_request(&params).await
244    }
245
246    /// Refresh access token using a refresh token.
247    ///
248    /// # Errors
249    ///
250    /// Propagates errors from the token endpoint request (network failure,
251    /// non-2xx HTTP status, oversized response body, or JSON parse error).
252    pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse, String> {
253        let params = [
254            ("grant_type", "refresh_token"),
255            ("refresh_token", refresh_token),
256            ("client_id", self.client_id.as_str()),
257            ("client_secret", self.client_secret.as_str()),
258        ];
259        self.post_token_request(&params).await
260    }
261}
262
263/// OIDC client for OpenID Connect flow.
264#[derive(Debug)]
265pub struct OIDCClient {
266    /// Provider configuration.
267    pub config:     OIDCProviderConfig,
268    /// Client ID.
269    pub client_id:  String,
270    /// Client secret — retained for token revocation and introspection endpoints.
271    #[allow(dead_code)] // Reason: retained for token revocation and introspection endpoints
272    client_secret: String,
273    /// JWKS key cache for ID token signature verification.
274    pub jwks_cache: Arc<JwksCache>,
275    /// HTTP client for userinfo requests.
276    http_client:    reqwest::Client,
277}
278
279impl OIDCClient {
280    /// Maximum byte size for a userinfo endpoint response.
281    ///
282    /// Userinfo payloads carry a small set of JWT-derived claims.
283    /// 1 `MiB` is generous while blocking allocation-bomb responses.
284    const MAX_USERINFO_RESPONSE_BYTES: usize = 1024 * 1024;
285
286    // 1 MiB
287
288    /// Create new OIDC client with JWKS caching.
289    ///
290    /// The JWKS cache TTL defaults to 1 hour.
291    ///
292    /// # Errors
293    ///
294    /// Returns [`JwksError`] if `config.jwks_uri` is not a valid HTTPS URL
295    /// (HTTP is allowed only for localhost).
296    pub fn new(
297        config: OIDCProviderConfig,
298        client_id: impl Into<String>,
299        client_secret: impl Into<String>,
300    ) -> Result<Self, JwksError> {
301        let jwks_cache = Arc::new(JwksCache::new(&config.jwks_uri, StdDuration::from_secs(3600))?);
302        Ok(Self {
303            config,
304            client_id: client_id.into(),
305            client_secret: client_secret.into(),
306            jwks_cache,
307            http_client: reqwest::Client::builder()
308                .timeout(OAUTH_REQUEST_TIMEOUT)
309                .build()
310                .unwrap_or_default(),
311        })
312    }
313
314    /// Create OIDC client with a pre-built JWKS cache (for testing).
315    pub fn with_jwks_cache(
316        config: OIDCProviderConfig,
317        client_id: impl Into<String>,
318        client_secret: impl Into<String>,
319        jwks_cache: Arc<JwksCache>,
320    ) -> Self {
321        Self {
322            config,
323            client_id: client_id.into(),
324            client_secret: client_secret.into(),
325            jwks_cache,
326            http_client: reqwest::Client::builder()
327                .timeout(OAUTH_REQUEST_TIMEOUT)
328                .build()
329                .unwrap_or_default(),
330        }
331    }
332
333    /// Generate an OIDC authorization URL with a fresh nonce for replay protection.
334    ///
335    /// This extends the standard OAuth2 flow by appending a `nonce` parameter to
336    /// the authorization URL. The returned [`AuthorizationRequest::nonce`] **must**
337    /// be stored (e.g. in the encrypted session state) and passed to
338    /// [`verify_id_token`](Self::verify_id_token) at callback time.
339    ///
340    /// PKCE is always enabled for OIDC flows started via this method.
341    pub fn authorization_url(&self, redirect_uri: &str) -> AuthorizationRequest {
342        let state = uuid::Uuid::new_v4().to_string();
343        let scope = self.config.scopes_supported.join(" ");
344        let nonce = super::pkce::NonceParameter::new();
345        let challenge = PKCEChallenge::new();
346
347        let url = format!(
348            "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}\
349             &nonce={}&code_challenge={}&code_challenge_method=S256",
350            self.config.authorization_endpoint,
351            urlencoding::encode(&self.client_id),
352            urlencoding::encode(redirect_uri),
353            urlencoding::encode(&scope),
354            urlencoding::encode(&state),
355            urlencoding::encode(&nonce.nonce),
356            urlencoding::encode(&challenge.code_challenge),
357        );
358
359        AuthorizationRequest {
360            url,
361            state,
362            pkce: Some(challenge),
363            nonce: Some(nonce),
364        }
365    }
366
367    /// Verify an ID token's JWT signature and claims.
368    ///
369    /// Decodes the JWT header to extract the `kid`, fetches the matching public
370    /// key from the JWKS cache, then validates signature, issuer, audience, and
371    /// required claims.
372    ///
373    /// **Nonce**: when `expected_nonce` is `Some`, the token's `nonce` claim must
374    /// match exactly.  When it is `None` but the token *contains* a `nonce` claim,
375    /// validation still succeeds — callers that generated the authorization URL
376    /// via [`authorization_url`](Self::authorization_url) MUST pass the stored
377    /// nonce here.
378    ///
379    /// **`max_age`**: when `max_age_secs` is `Some`, the token's `auth_time` claim
380    /// is required and must be within `max_age_secs` seconds of the current time.
381    /// This prevents accepting tokens from sessions that were authenticated too
382    /// long ago (RFC 6749 §3.1.2.1 / OIDC Core §3.1.2.1).
383    ///
384    /// # Errors
385    ///
386    /// Returns an error if the token is malformed, the signature is invalid,
387    /// claims validation fails, the nonce doesn't match, or the `auth_time` /
388    /// `max_age` constraint is violated.
389    pub async fn verify_id_token(
390        &self,
391        id_token: &str,
392        expected_nonce: Option<&str>,
393        max_age_secs: Option<u64>,
394    ) -> Result<IdTokenClaims, String> {
395        // 1. Decode header to get kid
396        let header = jsonwebtoken::decode_header(id_token)
397            .map_err(|e| format!("Invalid JWT header: {e}"))?;
398        let kid = header.kid.ok_or("JWT missing 'kid' in header")?;
399
400        // 2. Get key from JWKS cache
401        let key = self
402            .jwks_cache
403            .get_key(&kid)
404            .await
405            .map_err(|e| format!("JWKS fetch error: {e}"))?
406            .ok_or_else(|| format!("No key found for kid '{kid}'"))?;
407
408        // 3. Build validation criteria
409        let mut validation = jsonwebtoken::Validation::new(header.alg);
410        validation.set_issuer(&[&self.config.issuer]);
411        validation.set_audience(&[&self.client_id]);
412        validation.set_required_spec_claims(&["exp", "iat", "iss", "aud", "sub"]);
413
414        // 4. Decode and validate
415        let token_data = jsonwebtoken::decode::<IdTokenClaims>(id_token, &key, &validation)
416            .map_err(|e| format!("ID token validation failed: {e}"))?;
417        let claims = token_data.claims;
418
419        // 5. Verify nonce using constant-time comparison (replay protection — RFC 6749 §10.12 /
420        //    OIDC Core §3.1.3.7).
421        if let Some(expected) = expected_nonce {
422            super::claims_validator::validate_nonce_claim(&claims, expected)
423                .map_err(|e| e.to_string())?;
424        }
425
426        // 6. Validate auth_time against max_age (OIDC Core §3.1.2.1).
427        if let Some(max_age) = max_age_secs {
428            let now_secs = std::time::SystemTime::now()
429                .duration_since(std::time::UNIX_EPOCH)
430                .map_or(i64::MAX, |d| i64::try_from(d.as_secs()).unwrap_or(i64::MAX));
431            super::claims_validator::validate_auth_time_claim(&claims, max_age, now_secs)
432                .map_err(|e| e.to_string())?;
433        }
434
435        Ok(claims)
436    }
437
438    /// Fetch user information from the provider's userinfo endpoint.
439    ///
440    /// # Errors
441    ///
442    /// Returns an error if no userinfo endpoint is configured, the HTTP request
443    /// fails, or the response cannot be parsed.
444    pub async fn get_userinfo(&self, access_token: &str) -> Result<UserInfo, String> {
445        let endpoint = self
446            .config
447            .userinfo_endpoint
448            .as_ref()
449            .ok_or("No userinfo endpoint configured for this provider")?;
450
451        let response = self
452            .http_client
453            .get(endpoint)
454            .bearer_auth(access_token)
455            .send()
456            .await
457            .map_err(|e| format!("Userinfo request failed: {e}"))?;
458
459        if !response.status().is_success() {
460            return Err(format!("Userinfo endpoint returned {}", response.status()));
461        }
462
463        let body = response
464            .bytes()
465            .await
466            .map_err(|e| format!("Failed to read userinfo response: {e}"))?;
467        if body.len() > Self::MAX_USERINFO_RESPONSE_BYTES {
468            return Err(format!(
469                "Userinfo response too large ({} bytes, max {})",
470                body.len(),
471                Self::MAX_USERINFO_RESPONSE_BYTES
472            ));
473        }
474        serde_json::from_slice::<UserInfo>(&body)
475            .map_err(|e| format!("Failed to parse userinfo response: {e}"))
476    }
477}
478
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics acceptable
    #![allow(missing_docs)] // Reason: test helpers

    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;

    /// Provider configuration whose endpoints all live under `base`.
    fn mock_provider_config(base: &str) -> OIDCProviderConfig {
        OIDCProviderConfig {
            issuer:                   base.to_string(),
            authorization_endpoint:   format!("{base}/auth"),
            token_endpoint:           format!("{base}/token"),
            userinfo_endpoint:        Some(format!("{base}/userinfo")),
            jwks_uri:                 format!("{base}/.well-known/jwks.json"),
            scopes_supported:         vec!["openid".to_string()],
            response_types_supported: vec!["code".to_string()],
        }
    }

    /// OAuth2 client whose endpoints all live under `base`.
    fn mock_oauth2_client(base: &str) -> OAuth2Client {
        OAuth2Client::new("client_id", "secret", format!("{base}/auth"), format!("{base}/token"))
    }

    #[test]
    fn oauth_response_cap_constant_is_reasonable() {
        assert_eq!(OAuth2Client::MAX_OAUTH_RESPONSE_BYTES, 1024 * 1024);
    }

    #[tokio::test]
    async fn oauth_response_error_body_is_capped() {
        // Exercise post_token_request through the public API with an oversized
        // error body; the echoed body must be cut at the cap.
        let mock_server = MockServer::start().await;
        let cap = OAuth2Client::MAX_OAUTH_RESPONSE_BYTES;
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(400).set_body_bytes(vec![b'e'; cap + 1_000]))
            .mount(&mock_server)
            .await;

        let client = mock_oauth2_client(&mock_server.uri());
        let msg = client.exchange_code("code", "https://app.example/cb").await.unwrap_err();

        let prefix = "Token endpoint returned error: ";
        assert!(msg.starts_with(prefix), "unexpected error message prefix");
        assert_eq!(msg.len() - prefix.len(), cap, "body must be capped at MAX_OAUTH_RESPONSE_BYTES");
    }

    #[tokio::test]
    async fn oauth_token_oversized_success_response_is_rejected() {
        let mock_server = MockServer::start().await;
        let oversized = vec![b'x'; OAuth2Client::MAX_OAUTH_RESPONSE_BYTES + 1];
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(oversized))
            .mount(&mock_server)
            .await;

        let client = mock_oauth2_client(&mock_server.uri());
        let msg = client.refresh_token("refresh").await.unwrap_err();
        assert!(msg.contains("too large"), "error must mention size limit: {msg}");
    }

    // ── S25-H1: OAuth2/OIDC client timeout ────────────────────────────────────

    #[test]
    fn oauth_request_timeout_is_set() {
        let secs = OAUTH_REQUEST_TIMEOUT.as_secs();
        assert!(secs > 0 && secs <= 120, "OAuth timeout should be 1–120 s, got {secs}");
    }

    #[test]
    fn oauth2_client_new_creates_instance() {
        let client = OAuth2Client::new(
            "client_id",
            "client_secret",
            "https://example.com/auth",
            "https://example.com/token",
        );
        assert_eq!(client.client_id, "client_id");
    }

    #[test]
    fn oauth2_authorization_url_includes_state_and_pkce() {
        let client = mock_oauth2_client("https://example.com").with_pkce(true);
        let request = client.authorization_url("https://app.example/cb");

        assert!(request.url.starts_with("https://example.com/auth?client_id=client_id&"));
        assert!(request.url.contains(&format!("&state={}", request.state)));
        assert!(request.url.ends_with("&code_challenge_method=S256"));
        assert!(request.pkce.is_some());
        assert!(request.nonce.is_none());
    }

    #[test]
    fn oidc_client_new_creates_instance() {
        let config = OIDCProviderConfig {
            issuer:                   "https://example.com".to_string(),
            authorization_endpoint:   "https://example.com/auth".to_string(),
            token_endpoint:           "https://example.com/token".to_string(),
            userinfo_endpoint:        None,
            jwks_uri:                 "https://example.com/.well-known/jwks.json".to_string(),
            scopes_supported:         vec!["openid".to_string()],
            response_types_supported: vec!["code".to_string()],
        };
        let client = OIDCClient::new(config, "client_id", "client_secret").unwrap();
        assert_eq!(client.client_id, "client_id");
    }

    // ── S26: OIDCClient userinfo response size cap ────────────────────────────

    #[test]
    fn oidc_userinfo_cap_constant_is_reasonable() {
        const { assert!(OIDCClient::MAX_USERINFO_RESPONSE_BYTES >= 64 * 1024) }
        const { assert!(OIDCClient::MAX_USERINFO_RESPONSE_BYTES <= 100 * 1024 * 1024) }
    }

    #[tokio::test]
    async fn oidc_userinfo_oversized_response_is_rejected() {
        let mock_server = MockServer::start().await;
        let oversized = vec![b'x'; OIDCClient::MAX_USERINFO_RESPONSE_BYTES + 1];
        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(oversized))
            .mount(&mock_server)
            .await;

        let config = mock_provider_config(&mock_server.uri());
        let client = OIDCClient::new(config, "client_id", "secret").unwrap();

        let result = client.get_userinfo("dummy_token").await;
        assert!(result.is_err(), "oversized userinfo response must be rejected, got: {result:?}");
        let msg = result.unwrap_err();
        assert!(msg.contains("too large"), "error must mention size limit: {msg}");
    }

    #[tokio::test]
    async fn oidc_userinfo_within_limit_proceeds_to_parse() {
        let mock_server = MockServer::start().await;
        // Valid but minimal payload — will fail at JSON parse (missing fields),
        // proving the size gate was passed.
        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(b"{}".to_vec()))
            .mount(&mock_server)
            .await;

        let config = mock_provider_config(&mock_server.uri());
        let client = OIDCClient::new(config, "client_id", "secret").unwrap();

        let result = client.get_userinfo("dummy_token").await;
        // Must fail at JSON parse (missing required fields), not at size gate
        assert!(
            result.is_err(),
            "expected Err when userinfo JSON is missing required fields, got: {result:?}"
        );
        let msg = result.unwrap_err();
        assert!(
            !msg.contains("too large"),
            "size gate must not trigger for small payload: {msg}"
        );
    }
}