1use std::{sync::Arc, time::Duration as StdDuration};
4
/// Total per-request time budget given to the HTTP clients that
/// `OAuth2Client` and `OIDCClient` build (passed to reqwest's
/// `ClientBuilder::timeout`).
const OAUTH_REQUEST_TIMEOUT: StdDuration = StdDuration::new(30, 0);
7
8use std::fmt::Write as _;
9
10use serde::{Deserialize, Serialize};
11
12use super::{
13 super::jwks::{JwksCache, JwksError},
14 pkce::PKCEChallenge,
15 types::{IdTokenClaims, TokenResponse, UserInfo},
16};
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
20pub struct OIDCProviderConfig {
21 pub issuer: String,
23 pub authorization_endpoint: String,
25 pub token_endpoint: String,
27 pub userinfo_endpoint: Option<String>,
29 pub jwks_uri: String,
31 pub scopes_supported: Vec<String>,
33 pub response_types_supported: Vec<String>,
35}
36
37impl OIDCProviderConfig {
38 pub fn new(
40 issuer: String,
41 authorization_endpoint: String,
42 token_endpoint: String,
43 jwks_uri: String,
44 ) -> Self {
45 Self {
46 issuer,
47 authorization_endpoint,
48 token_endpoint,
49 userinfo_endpoint: None,
50 jwks_uri,
51 scopes_supported: vec![
52 "openid".to_string(),
53 "profile".to_string(),
54 "email".to_string(),
55 ],
56 response_types_supported: vec!["code".to_string()],
57 }
58 }
59}
60
61#[derive(Debug, Clone)]
68pub struct AuthorizationRequest {
69 pub url: String,
71 pub state: String,
73 pub pkce: Option<PKCEChallenge>,
75 pub nonce: Option<super::pkce::NonceParameter>,
82}
83
84#[derive(Debug, Clone)]
86pub struct OAuth2Client {
87 pub client_id: String,
89 client_secret: String,
91 pub authorization_endpoint: String,
93 token_endpoint: String,
95 pub scopes: Vec<String>,
97 pub use_pkce: bool,
99 http_client: reqwest::Client,
101}
102
103impl OAuth2Client {
104 const MAX_OAUTH_RESPONSE_BYTES: usize = 1024 * 1024;
110
111 pub fn new(
113 client_id: impl Into<String>,
114 client_secret: impl Into<String>,
115 authorization_endpoint: impl Into<String>,
116 token_endpoint: impl Into<String>,
117 ) -> Self {
118 Self {
119 client_id: client_id.into(),
120 client_secret: client_secret.into(),
121 authorization_endpoint: authorization_endpoint.into(),
122 token_endpoint: token_endpoint.into(),
123 scopes: vec![
124 "openid".to_string(),
125 "profile".to_string(),
126 "email".to_string(),
127 ],
128 use_pkce: false,
129 http_client: reqwest::Client::builder()
130 .timeout(OAUTH_REQUEST_TIMEOUT)
131 .build()
132 .unwrap_or_default(),
133 }
134 }
135
136 pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
138 self.scopes = scopes;
139 self
140 }
141
142 pub const fn with_pkce(mut self, enabled: bool) -> Self {
144 self.use_pkce = enabled;
145 self
146 }
147
148 pub fn authorization_url(&self, redirect_uri: &str) -> AuthorizationRequest {
155 let state = uuid::Uuid::new_v4().to_string();
156 let scope = self.scopes.join(" ");
157
158 let mut url = format!(
159 "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}",
160 self.authorization_endpoint,
161 urlencoding::encode(&self.client_id),
162 urlencoding::encode(redirect_uri),
163 urlencoding::encode(&scope),
164 urlencoding::encode(&state),
165 );
166
167 let pkce = if self.use_pkce {
168 let challenge = PKCEChallenge::new();
169 let _ = write!(
170 url,
171 "&code_challenge={}&code_challenge_method=S256",
172 urlencoding::encode(&challenge.code_challenge),
173 );
174 Some(challenge)
175 } else {
176 None
177 };
178
179 AuthorizationRequest {
180 url,
181 state,
182 pkce,
183 nonce: None,
184 }
185 }
186
187 async fn post_token_request(&self, params: &[(&str, &str)]) -> Result<TokenResponse, String> {
191 let response = self
192 .http_client
193 .post(&self.token_endpoint)
194 .form(params)
195 .send()
196 .await
197 .map_err(|e| format!("Token request failed: {e}"))?;
198
199 let status = response.status();
202 let body_bytes = response
203 .bytes()
204 .await
205 .map_err(|e| format!("Failed to read token response body: {e}"))?;
206
207 if !status.is_success() {
208 let capped = &body_bytes[..body_bytes.len().min(Self::MAX_OAUTH_RESPONSE_BYTES)];
209 let body = String::from_utf8_lossy(capped);
210 return Err(format!("Token endpoint returned error: {body}"));
211 }
212
213 if body_bytes.len() > Self::MAX_OAUTH_RESPONSE_BYTES {
214 return Err(format!(
215 "Token response body too large ({} bytes, max {})",
216 body_bytes.len(),
217 Self::MAX_OAUTH_RESPONSE_BYTES
218 ));
219 }
220
221 serde_json::from_slice::<TokenResponse>(&body_bytes)
222 .map_err(|e| format!("Failed to parse token response: {e}"))
223 }
224
225 pub async fn exchange_code(
232 &self,
233 code: &str,
234 redirect_uri: &str,
235 ) -> Result<TokenResponse, String> {
236 let params = [
237 ("grant_type", "authorization_code"),
238 ("code", code),
239 ("client_id", self.client_id.as_str()),
240 ("client_secret", self.client_secret.as_str()),
241 ("redirect_uri", redirect_uri),
242 ];
243 self.post_token_request(¶ms).await
244 }
245
246 pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse, String> {
253 let params = [
254 ("grant_type", "refresh_token"),
255 ("refresh_token", refresh_token),
256 ("client_id", self.client_id.as_str()),
257 ("client_secret", self.client_secret.as_str()),
258 ];
259 self.post_token_request(¶ms).await
260 }
261}
262
263#[derive(Debug)]
265pub struct OIDCClient {
266 pub config: OIDCProviderConfig,
268 pub client_id: String,
270 #[allow(dead_code)] client_secret: String,
273 pub jwks_cache: Arc<JwksCache>,
275 http_client: reqwest::Client,
277}
278
279impl OIDCClient {
280 const MAX_USERINFO_RESPONSE_BYTES: usize = 1024 * 1024;
285
286 pub fn new(
297 config: OIDCProviderConfig,
298 client_id: impl Into<String>,
299 client_secret: impl Into<String>,
300 ) -> Result<Self, JwksError> {
301 let jwks_cache = Arc::new(JwksCache::new(&config.jwks_uri, StdDuration::from_secs(3600))?);
302 Ok(Self {
303 config,
304 client_id: client_id.into(),
305 client_secret: client_secret.into(),
306 jwks_cache,
307 http_client: reqwest::Client::builder()
308 .timeout(OAUTH_REQUEST_TIMEOUT)
309 .build()
310 .unwrap_or_default(),
311 })
312 }
313
314 pub fn with_jwks_cache(
316 config: OIDCProviderConfig,
317 client_id: impl Into<String>,
318 client_secret: impl Into<String>,
319 jwks_cache: Arc<JwksCache>,
320 ) -> Self {
321 Self {
322 config,
323 client_id: client_id.into(),
324 client_secret: client_secret.into(),
325 jwks_cache,
326 http_client: reqwest::Client::builder()
327 .timeout(OAUTH_REQUEST_TIMEOUT)
328 .build()
329 .unwrap_or_default(),
330 }
331 }
332
333 pub fn authorization_url(&self, redirect_uri: &str) -> AuthorizationRequest {
342 let state = uuid::Uuid::new_v4().to_string();
343 let scope = self.config.scopes_supported.join(" ");
344 let nonce = super::pkce::NonceParameter::new();
345 let challenge = PKCEChallenge::new();
346
347 let url = format!(
348 "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}\
349 &nonce={}&code_challenge={}&code_challenge_method=S256",
350 self.config.authorization_endpoint,
351 urlencoding::encode(&self.client_id),
352 urlencoding::encode(redirect_uri),
353 urlencoding::encode(&scope),
354 urlencoding::encode(&state),
355 urlencoding::encode(&nonce.nonce),
356 urlencoding::encode(&challenge.code_challenge),
357 );
358
359 AuthorizationRequest {
360 url,
361 state,
362 pkce: Some(challenge),
363 nonce: Some(nonce),
364 }
365 }
366
367 pub async fn verify_id_token(
390 &self,
391 id_token: &str,
392 expected_nonce: Option<&str>,
393 max_age_secs: Option<u64>,
394 ) -> Result<IdTokenClaims, String> {
395 let header = jsonwebtoken::decode_header(id_token)
397 .map_err(|e| format!("Invalid JWT header: {e}"))?;
398 let kid = header.kid.ok_or("JWT missing 'kid' in header")?;
399
400 let key = self
402 .jwks_cache
403 .get_key(&kid)
404 .await
405 .map_err(|e| format!("JWKS fetch error: {e}"))?
406 .ok_or_else(|| format!("No key found for kid '{kid}'"))?;
407
408 let mut validation = jsonwebtoken::Validation::new(header.alg);
410 validation.set_issuer(&[&self.config.issuer]);
411 validation.set_audience(&[&self.client_id]);
412 validation.set_required_spec_claims(&["exp", "iat", "iss", "aud", "sub"]);
413
414 let token_data = jsonwebtoken::decode::<IdTokenClaims>(id_token, &key, &validation)
416 .map_err(|e| format!("ID token validation failed: {e}"))?;
417 let claims = token_data.claims;
418
419 if let Some(expected) = expected_nonce {
422 super::claims_validator::validate_nonce_claim(&claims, expected)
423 .map_err(|e| e.to_string())?;
424 }
425
426 if let Some(max_age) = max_age_secs {
428 let now_secs = std::time::SystemTime::now()
429 .duration_since(std::time::UNIX_EPOCH)
430 .map_or(i64::MAX, |d| i64::try_from(d.as_secs()).unwrap_or(i64::MAX));
431 super::claims_validator::validate_auth_time_claim(&claims, max_age, now_secs)
432 .map_err(|e| e.to_string())?;
433 }
434
435 Ok(claims)
436 }
437
438 pub async fn get_userinfo(&self, access_token: &str) -> Result<UserInfo, String> {
445 let endpoint = self
446 .config
447 .userinfo_endpoint
448 .as_ref()
449 .ok_or("No userinfo endpoint configured for this provider")?;
450
451 let response = self
452 .http_client
453 .get(endpoint)
454 .bearer_auth(access_token)
455 .send()
456 .await
457 .map_err(|e| format!("Userinfo request failed: {e}"))?;
458
459 if !response.status().is_success() {
460 return Err(format!("Userinfo endpoint returned {}", response.status()));
461 }
462
463 let body = response
464 .bytes()
465 .await
466 .map_err(|e| format!("Failed to read userinfo response: {e}"))?;
467 if body.len() > Self::MAX_USERINFO_RESPONSE_BYTES {
468 return Err(format!(
469 "Userinfo response too large ({} bytes, max {})",
470 body.len(),
471 Self::MAX_USERINFO_RESPONSE_BYTES
472 ));
473 }
474 serde_json::from_slice::<UserInfo>(&body)
475 .map_err(|e| format!("Failed to parse userinfo response: {e}"))
476 }
477}
478
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    #![allow(missing_docs)]

    use super::*;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    /// Provider configuration whose endpoints all live under `base`, with a
    /// userinfo endpoint configured.
    fn provider_config_at(base: &str) -> OIDCProviderConfig {
        OIDCProviderConfig {
            issuer: base.to_owned(),
            authorization_endpoint: format!("{base}/auth"),
            token_endpoint: format!("{base}/token"),
            userinfo_endpoint: Some(format!("{base}/userinfo")),
            jwks_uri: format!("{base}/.well-known/jwks.json"),
            scopes_supported: vec!["openid".to_owned()],
            response_types_supported: vec!["code".to_owned()],
        }
    }

    /// Starts a mock provider whose `GET /userinfo` answers 200 with `body`.
    async fn mock_userinfo(body: Vec<u8>) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(body))
            .mount(&server)
            .await;
        server
    }

    #[test]
    fn oauth_response_cap_constant_is_reasonable() {
        assert_eq!(OAuth2Client::MAX_OAUTH_RESPONSE_BYTES, 1 << 20);
    }

    #[test]
    fn oauth_response_error_body_is_capped() {
        // Slicing a too-long body to the cap must leave exactly `limit` bytes of text.
        let limit = OAuth2Client::MAX_OAUTH_RESPONSE_BYTES;
        let payload = vec![b'e'; limit + 1_000];
        let kept = &payload[..limit.min(payload.len())];
        let rendered = String::from_utf8_lossy(kept).into_owned();
        assert_eq!(rendered.len(), limit, "body must be capped at MAX_OAUTH_RESPONSE_BYTES");
    }

    #[test]
    fn oauth_request_timeout_is_set() {
        let secs = OAUTH_REQUEST_TIMEOUT.as_secs();
        assert!((1..=120).contains(&secs), "OAuth timeout should be 1–120 s, got {secs}");
    }

    #[test]
    fn oauth2_client_new_creates_instance() {
        let client = OAuth2Client::new(
            "client_id",
            "client_secret",
            "https://example.com/auth",
            "https://example.com/token",
        );
        assert_eq!(client.client_id, "client_id");
    }

    #[test]
    fn oidc_client_new_creates_instance() {
        let config = OIDCProviderConfig {
            userinfo_endpoint: None,
            ..provider_config_at("https://example.com")
        };
        let client = OIDCClient::new(config, "client_id", "client_secret").unwrap();
        assert_eq!(client.client_id, "client_id");
    }

    #[test]
    fn oidc_userinfo_cap_constant_is_reasonable() {
        const {
            assert!(OIDCClient::MAX_USERINFO_RESPONSE_BYTES >= 64 * 1024);
            assert!(OIDCClient::MAX_USERINFO_RESPONSE_BYTES <= 100 * 1024 * 1024);
        }
    }

    #[tokio::test]
    async fn oidc_userinfo_oversized_response_is_rejected() {
        let server = mock_userinfo(vec![b'x'; OIDCClient::MAX_USERINFO_RESPONSE_BYTES + 1]).await;
        let client =
            OIDCClient::new(provider_config_at(&server.uri()), "client_id", "secret").unwrap();

        let result = client.get_userinfo("dummy_token").await;
        assert!(result.is_err(), "oversized userinfo response must be rejected, got: {result:?}");
        let msg = result.unwrap_err();
        assert!(msg.contains("too large"), "error must mention size limit: {msg}");
    }

    #[tokio::test]
    async fn oidc_userinfo_within_limit_proceeds_to_parse() {
        let server = mock_userinfo(b"{}".to_vec()).await;
        let client =
            OIDCClient::new(provider_config_at(&server.uri()), "client_id", "secret").unwrap();

        let result = client.get_userinfo("dummy_token").await;
        assert!(
            result.is_err(),
            "expected Err when userinfo JSON is missing required fields, got: {result:?}"
        );
        let msg = result.unwrap_err();
        assert!(
            !msg.contains("too large"),
            "size gate must not trigger for small payload: {msg}"
        );
    }
}