1use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
4use chrono::{Duration, Utc};
5use rand::Rng;
6use sha2::{Digest, Sha256};
7use tracing::instrument;
8use uuid::Uuid;
9
10use authx_core::{
11 KeyRotationStore,
12 crypto::sha256_hex,
13 error::{AuthError, Result},
14 models::{CreateAuthorizationCode, CreateDeviceCode, CreateOidcToken, OidcTokenType},
15};
16use authx_storage::ports::{
17 AuthorizationCodeRepository, DeviceCodeRepository, OidcClientRepository, OidcTokenRepository,
18 UserRepository,
19};
20
/// Static configuration for [`OidcProviderService`].
///
/// Every `*_ttl_secs` value is a duration in seconds (it is passed to `Duration::seconds`).
#[derive(Clone)]
pub struct OidcProviderConfig {
    /// Issuer identifier written into the `iss` claim of issued tokens and into
    /// refresh-token introspection responses.
    pub issuer: String,
    /// Key store used to sign issued access/ID tokens and to verify presented tokens.
    pub key_store: KeyRotationStore,
    /// Lifetime of access tokens; also reported as `expires_in` in token responses.
    pub access_token_ttl_secs: i64,
    /// Requested lifetime of ID tokens; capped at `access_token_ttl_secs` at issuance.
    pub id_token_ttl_secs: i64,
    /// Lifetime of refresh tokens.
    pub refresh_token_ttl_secs: i64,
    /// Lifetime of authorization codes.
    pub auth_code_ttl_secs: i64,
    /// Lifetime of device codes; also reported as `expires_in` in device authorization
    /// responses.
    pub device_code_ttl_secs: i64,
    /// Initial minimum polling interval, in seconds, stored with each new device code and
    /// reported as `interval`.
    pub device_code_interval_secs: u32,
    /// Verification URI returned as `verification_uri` (and, with the user code attached as
    /// a `user_code` query parameter, as `verification_uri_complete`) in the device flow.
    pub verification_uri: String,
}
37
/// Successful token-endpoint response body, serialized to JSON.
///
/// NOTE(review): the derived `Debug` prints the raw token strings; avoid logging this value.
#[derive(Debug, Clone, serde::Serialize)]
// Field names are already snake_case, so `rename_all` does not rename anything here.
#[serde(rename_all = "snake_case")]
pub struct OidcTokenResponse {
    /// Access token signed by the configured key store.
    pub access_token: String,
    /// Token type; the service always sets `"Bearer"`.
    pub token_type: String,
    /// Access-token lifetime in seconds.
    pub expires_in: i64,
    /// Opaque refresh token; omitted from the JSON when none was issued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Space-separated granted scope; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Signed ID token, issued only when `openid` is among the scopes; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id_token: Option<String>,
}
52
/// Device authorization response returned to a device starting the device flow.
///
/// NOTE(review): the derived `Debug` prints `device_code`, which is a bearer credential for
/// polling; avoid logging this value.
#[derive(Debug, Clone, serde::Serialize)]
pub struct DeviceAuthorizationResponse {
    /// Code the device polls the token endpoint with; only its SHA-256 hash is stored.
    pub device_code: String,
    /// Short code (`XXXX-XXXX`) the user enters at `verification_uri`.
    pub user_code: String,
    /// URI where the user approves or denies the request.
    pub verification_uri: String,
    /// `verification_uri` with the user code attached as a `user_code` query parameter;
    /// omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub verification_uri_complete: Option<String>,
    /// Seconds until the device code expires.
    pub expires_in: i64,
    /// Minimum number of seconds the device must wait between polls.
    pub interval: u32,
}
64
/// Non-success outcome of polling with a device code (RFC 8628 §3.5).
///
/// Use [`DeviceCodeError::error_code`] to obtain the `error` value for the token endpoint's
/// JSON error response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceCodeError {
    /// The user has not yet approved or denied the request; the device should keep polling.
    AuthorizationPending,
    /// The device polled faster than its current interval; the interval has been increased.
    SlowDown,
    /// The device code is unknown, was issued to another client, or a storage operation on
    /// it failed.
    ExpiredToken,
    /// The user denied the request, or tokens could not be issued for it.
    AccessDenied,
}

impl DeviceCodeError {
    /// Returns the RFC 8628 §3.5 `error` code string for this outcome.
    pub fn error_code(self) -> &'static str {
        match self {
            Self::AuthorizationPending => "authorization_pending",
            Self::SlowDown => "slow_down",
            Self::ExpiredToken => "expired_token",
            Self::AccessDenied => "access_denied",
        }
    }
}
73
/// Token introspection response body (RFC 7662 §2.2).
///
/// Every optional member is omitted from the JSON when `None`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct IntrospectionResponse {
    /// Whether the presented token is currently active.
    pub active: bool,
    /// Space-separated scope associated with the token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Client the token was issued to (the `aud` claim for access tokens).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id: Option<String>,
    /// Human-readable username; the service currently never populates it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub username: Option<String>,
    /// `"refresh_token"` or `"access_token"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_type: Option<String>,
    /// Expiry timestamp (from `timestamp()` for refresh tokens; the `exp` claim for access
    /// tokens).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exp: Option<i64>,
    /// Issued-at timestamp (creation time for refresh tokens; the `iat` claim for access
    /// tokens).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iat: Option<i64>,
    /// Subject: the user id the token was issued for.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sub: Option<String>,
    /// Issuer of the token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,
}
95
96impl IntrospectionResponse {
97 pub fn inactive() -> Self {
98 Self {
99 active: false,
100 scope: None,
101 client_id: None,
102 username: None,
103 token_type: None,
104 exp: None,
105 iat: None,
106 sub: None,
107 iss: None,
108 }
109 }
110}
111
/// Borrowed inputs for [`OidcProviderService::create_authorization_code`].
#[derive(Debug, Clone, Copy)]
pub struct CreateAuthorizationCodeRequest<'a> {
    /// User the authorization code is issued for.
    pub user_id: Uuid,
    /// Client requesting the code; must exist and allow the `code` response type.
    pub client_id: &'a str,
    /// Must exactly match one of the client's registered redirect URIs.
    pub redirect_uri: &'a str,
    /// Space-separated requested scopes; all except `openid` must be allowed for the client.
    pub scope: &'a str,
    /// Opaque client value echoed back in the redirect.
    pub state: Option<&'a str>,
    /// Stored with the code and copied into the ID token's `nonce` claim.
    pub nonce: Option<&'a str>,
    /// PKCE challenge; checked at exchange as `BASE64URL(SHA256(code_verifier))`.
    pub code_challenge: Option<&'a str>,
}
123
/// OpenID Connect / OAuth 2.0 provider logic: authorization codes, token issuance and
/// refresh, revocation, introspection and the device authorization flow.
///
/// `S` is the storage backend. The repository bounds it must satisfy are declared on the
/// `impl` block.
pub struct OidcProviderService<S> {
    /// Storage backend used for clients, codes, tokens, device codes and users.
    storage: S,
    /// Issuer, signing keys and lifetimes.
    config: OidcProviderConfig,
}
129
130impl<S> OidcProviderService<S>
131where
132 S: OidcClientRepository
133 + AuthorizationCodeRepository
134 + OidcTokenRepository
135 + DeviceCodeRepository
136 + UserRepository
137 + Clone
138 + Send
139 + Sync
140 + 'static,
141{
142 pub fn new(storage: S, config: OidcProviderConfig) -> Self {
143 Self { storage, config }
144 }
145
146 #[instrument(skip(self))]
148 pub async fn create_authorization_code(
149 &self,
150 request: CreateAuthorizationCodeRequest<'_>,
151 ) -> Result<(String, String)> {
152 let CreateAuthorizationCodeRequest {
153 user_id,
154 client_id,
155 redirect_uri,
156 scope,
157 state,
158 nonce,
159 code_challenge,
160 } = request;
161
162 let client = OidcClientRepository::find_by_client_id(&self.storage, client_id)
163 .await?
164 .ok_or(AuthError::Internal("invalid client_id".into()))?;
165
166 if !client.redirect_uris.iter().any(|u| u == redirect_uri) {
167 return Err(AuthError::Internal("redirect_uri not allowed".into()));
168 }
169 if !client.response_types.contains(&"code".to_string()) {
170 return Err(AuthError::Internal("response_type code not allowed".into()));
171 }
172
173 let allowed: std::collections::HashSet<_> =
174 client.allowed_scopes.split_whitespace().collect();
175 for s in scope.split_whitespace() {
176 if s != "openid" && !allowed.contains(s) {
177 return Err(AuthError::Internal(format!("scope {s} not allowed")));
178 }
179 }
180
181 let raw_code: [u8; 32] = rand::thread_rng().r#gen();
183 let code = URL_SAFE_NO_PAD.encode(raw_code);
184 let code_hash = sha256_hex(code.as_bytes());
185
186 let _auth_code = AuthorizationCodeRepository::create(
187 &self.storage,
188 CreateAuthorizationCode {
189 code_hash: code_hash.clone(),
190 client_id: client_id.to_string(),
191 user_id,
192 redirect_uri: redirect_uri.to_string(),
193 scope: scope.to_string(),
194 nonce: nonce.map(str::to_string),
195 code_challenge: code_challenge.map(str::to_string),
196 expires_at: Utc::now() + Duration::seconds(self.config.auth_code_ttl_secs),
197 },
198 )
199 .await?;
200
201 let redirect = if let Some(st) = state {
202 format!("{redirect_uri}?code={code}&state={st}")
203 } else {
204 format!("{redirect_uri}?code={code}")
205 };
206 Ok((code, redirect))
207 }
208
209 #[instrument(skip(self, client_secret))]
211 pub async fn exchange_code(
212 &self,
213 code: &str,
214 client_id: &str,
215 client_secret: Option<&str>,
216 redirect_uri: &str,
217 code_verifier: Option<&str>,
218 ) -> Result<OidcTokenResponse> {
219 let code_hash = sha256_hex(code.as_bytes());
220 let auth_code = AuthorizationCodeRepository::find_by_code_hash(&self.storage, &code_hash)
221 .await?
222 .ok_or(AuthError::InvalidToken)?;
223
224 if auth_code.client_id != client_id {
225 return Err(AuthError::InvalidToken);
226 }
227 if auth_code.redirect_uri != redirect_uri {
228 return Err(AuthError::InvalidToken);
229 }
230
231 let client = OidcClientRepository::find_by_client_id(&self.storage, client_id)
232 .await?
233 .ok_or(AuthError::InvalidToken)?;
234
235 if !client.secret_hash.is_empty() {
236 let secret = client_secret.ok_or(AuthError::InvalidToken)?;
237 let hash = sha256_hex(secret.as_bytes());
238 use subtle::ConstantTimeEq;
239 if hash
240 .as_bytes()
241 .ct_eq(client.secret_hash.as_bytes())
242 .unwrap_u8()
243 == 0
244 {
245 return Err(AuthError::InvalidToken);
246 }
247 } else if let Some(challenge) = &auth_code.code_challenge {
248 let verifier = code_verifier.ok_or(AuthError::InvalidToken)?;
249 let mut hasher = Sha256::new();
250 hasher.update(verifier.as_bytes());
251 let computed = URL_SAFE_NO_PAD.encode(hasher.finalize());
252 if computed != *challenge {
253 return Err(AuthError::InvalidToken);
254 }
255 }
256
257 AuthorizationCodeRepository::mark_used(&self.storage, auth_code.id).await?;
258
259 self.issue_tokens(
260 auth_code.user_id,
261 client_id,
262 &auth_code.scope,
263 auth_code.nonce.as_deref(),
264 )
265 .await
266 }
267
268 #[instrument(skip(self, client_secret))]
270 pub async fn refresh(
271 &self,
272 refresh_token: &str,
273 client_id: &str,
274 client_secret: Option<&str>,
275 scope: Option<&str>,
276 ) -> Result<OidcTokenResponse> {
277 let token_hash = sha256_hex(refresh_token.as_bytes());
278 let token = OidcTokenRepository::find_by_token_hash(&self.storage, &token_hash)
279 .await?
280 .ok_or(AuthError::InvalidToken)?;
281
282 if token.client_id != client_id || token.token_type != OidcTokenType::Refresh {
283 return Err(AuthError::InvalidToken);
284 }
285
286 let client = OidcClientRepository::find_by_client_id(&self.storage, client_id)
287 .await?
288 .ok_or(AuthError::InvalidToken)?;
289
290 if !client.secret_hash.is_empty() {
291 let secret = client_secret.ok_or(AuthError::InvalidToken)?;
292 let hash = sha256_hex(secret.as_bytes());
293 use subtle::ConstantTimeEq;
294 if hash
295 .as_bytes()
296 .ct_eq(client.secret_hash.as_bytes())
297 .unwrap_u8()
298 == 0
299 {
300 return Err(AuthError::InvalidToken);
301 }
302 }
303
304 OidcTokenRepository::revoke(&self.storage, token.id).await?;
305
306 let token_scope = scope.unwrap_or(&token.scope);
307 self.issue_tokens(token.user_id, client_id, token_scope, None)
308 .await
309 }
310
311 async fn issue_tokens(
312 &self,
313 user_id: Uuid,
314 client_id: &str,
315 scope: &str,
316 nonce: Option<&str>,
317 ) -> Result<OidcTokenResponse> {
318 let user = UserRepository::find_by_id(&self.storage, user_id)
319 .await?
320 .ok_or(AuthError::UserNotFound)?;
321
322 let now = Utc::now();
323 let access_ttl = self.config.access_token_ttl_secs;
324 let id_ttl = self.config.id_token_ttl_secs.min(access_ttl);
325
326 let access_extra = serde_json::json!({
327 "iss": self.config.issuer,
328 "aud": client_id,
329 "scope": scope
330 });
331 let access_token = self
332 .config
333 .key_store
334 .sign(user_id, access_ttl, access_extra)?;
335
336 let id_token = if scope.split_whitespace().any(|s| s == "openid") {
337 let mut id_extra = serde_json::json!({
338 "iss": self.config.issuer,
339 "aud": client_id
340 });
341 if let Some(n) = nonce {
342 id_extra["nonce"] = serde_json::Value::String(n.to_string());
343 }
344 if scope.contains("email") {
345 id_extra["email"] = serde_json::Value::String(user.email.clone());
346 id_extra["email_verified"] = serde_json::Value::Bool(user.email_verified);
347 }
348 if scope.contains("profile") {
349 id_extra["name"] = serde_json::Value::String(user.email.clone());
350 if let Some(ref u) = user.username {
351 id_extra["preferred_username"] = serde_json::Value::String(u.clone());
352 }
353 }
354 Some(self.config.key_store.sign(user_id, id_ttl, id_extra)?)
355 } else {
356 None
357 };
358
359 let refresh_token = if scope.split_whitespace().any(|s| s == "offline_access")
360 || !scope.is_empty()
361 {
362 let raw: [u8; 32] = rand::thread_rng().r#gen();
363 let token = hex::encode(raw);
364 let token_hash = sha256_hex(token.as_bytes());
365
366 OidcTokenRepository::create(
367 &self.storage,
368 CreateOidcToken {
369 token_hash,
370 client_id: client_id.to_string(),
371 user_id,
372 scope: scope.to_string(),
373 token_type: OidcTokenType::Refresh,
374 expires_at: Some(now + Duration::seconds(self.config.refresh_token_ttl_secs)),
375 },
376 )
377 .await?;
378 Some(token)
379 } else {
380 None
381 };
382
383 Ok(OidcTokenResponse {
384 access_token,
385 token_type: "Bearer".into(),
386 expires_in: access_ttl,
387 refresh_token,
388 scope: Some(scope.to_string()),
389 id_token,
390 })
391 }
392
393 pub fn validate_access_token(&self, token: &str) -> Result<Uuid> {
395 let claims = self.config.key_store.verify(token)?;
396 Uuid::parse_str(&claims.sub).map_err(|_| AuthError::InvalidToken)
397 }
398
399 pub async fn userinfo(&self, access_token: &str) -> Result<serde_json::Value> {
401 let user_id = self.validate_access_token(access_token)?;
402 let user = UserRepository::find_by_id(&self.storage, user_id)
403 .await?
404 .ok_or(AuthError::UserNotFound)?;
405
406 let mut claims = serde_json::json!({
407 "sub": user.id.to_string(),
408 "email": user.email,
409 "email_verified": user.email_verified,
410 });
411 if let Some(ref u) = user.username {
412 claims["preferred_username"] = serde_json::Value::String(u.clone());
413 }
414 Ok(claims)
415 }
416
417 #[instrument(skip(self, token, client_secret))]
422 pub async fn revoke_token(
423 &self,
424 token: &str,
425 token_type_hint: Option<&str>,
426 client_id: &str,
427 client_secret: Option<&str>,
428 ) -> Result<()> {
429 self.authenticate_client(client_id, client_secret).await?;
430
431 let try_refresh = token_type_hint.is_none() || token_type_hint == Some("refresh_token");
433 let try_access = token_type_hint.is_none() || token_type_hint == Some("access_token");
434
435 if try_refresh {
436 let token_hash = sha256_hex(token.as_bytes());
437 if let Ok(Some(oidc_token)) =
438 OidcTokenRepository::find_by_token_hash(&self.storage, &token_hash).await
439 {
440 if oidc_token.client_id == client_id {
441 let _ = OidcTokenRepository::revoke(&self.storage, oidc_token.id).await;
442 }
443 return Ok(());
444 }
445 }
446
447 if try_access {
448 if let Ok(claims) = self.config.key_store.verify(token)
451 && let Ok(user_id) = Uuid::parse_str(&claims.sub)
452 {
453 let _ = OidcTokenRepository::revoke_all_for_user_client(
454 &self.storage,
455 user_id,
456 client_id,
457 )
458 .await;
459 }
460 }
461
462 Ok(())
464 }
465
466 #[instrument(skip(self, token, client_secret))]
471 pub async fn introspect_token(
472 &self,
473 token: &str,
474 token_type_hint: Option<&str>,
475 client_id: &str,
476 client_secret: Option<&str>,
477 ) -> Result<IntrospectionResponse> {
478 self.authenticate_client(client_id, client_secret).await?;
479
480 let try_refresh = token_type_hint.is_none() || token_type_hint == Some("refresh_token");
481 let try_access = token_type_hint.is_none() || token_type_hint == Some("access_token");
482
483 if try_refresh {
485 let token_hash = sha256_hex(token.as_bytes());
486 if let Ok(Some(oidc_token)) =
487 OidcTokenRepository::find_by_token_hash(&self.storage, &token_hash).await
488 && oidc_token.client_id == client_id
489 && !oidc_token.revoked
490 {
491 let expired = oidc_token
492 .expires_at
493 .map(|exp| exp < Utc::now())
494 .unwrap_or(false);
495 if !expired {
496 return Ok(IntrospectionResponse {
497 active: true,
498 scope: Some(oidc_token.scope),
499 client_id: Some(oidc_token.client_id),
500 username: None,
501 token_type: Some("refresh_token".into()),
502 exp: oidc_token.expires_at.map(|t| t.timestamp()),
503 iat: Some(oidc_token.created_at.timestamp()),
504 sub: Some(oidc_token.user_id.to_string()),
505 iss: Some(self.config.issuer.clone()),
506 });
507 }
508 }
509 }
510
511 if try_access && let Ok(claims) = self.config.key_store.verify(token) {
513 let extra = claims.extra;
514 return Ok(IntrospectionResponse {
515 active: true,
516 scope: extra
517 .get("scope")
518 .and_then(|v| v.as_str())
519 .map(String::from),
520 client_id: extra.get("aud").and_then(|v| v.as_str()).map(String::from),
521 username: None,
522 token_type: Some("access_token".into()),
523 exp: Some(claims.exp),
524 iat: Some(claims.iat),
525 sub: Some(claims.sub),
526 iss: extra.get("iss").and_then(|v| v.as_str()).map(String::from),
527 });
528 }
529
530 Ok(IntrospectionResponse::inactive())
531 }
532
533 async fn authenticate_client(
535 &self,
536 client_id: &str,
537 client_secret: Option<&str>,
538 ) -> Result<()> {
539 let client = OidcClientRepository::find_by_client_id(&self.storage, client_id)
540 .await?
541 .ok_or(AuthError::InvalidToken)?;
542
543 if !client.secret_hash.is_empty() {
544 let secret = client_secret.ok_or(AuthError::InvalidToken)?;
545 let hash = sha256_hex(secret.as_bytes());
546 use subtle::ConstantTimeEq;
547 if hash
548 .as_bytes()
549 .ct_eq(client.secret_hash.as_bytes())
550 .unwrap_u8()
551 == 0
552 {
553 return Err(AuthError::InvalidToken);
554 }
555 }
556 Ok(())
557 }
558
559 #[instrument(skip(self))]
563 pub async fn request_device_authorization(
564 &self,
565 client_id: &str,
566 scope: &str,
567 ) -> Result<DeviceAuthorizationResponse> {
568 let client = OidcClientRepository::find_by_client_id(&self.storage, client_id)
570 .await?
571 .ok_or(AuthError::Internal("invalid client_id".into()))?;
572
573 let allowed: std::collections::HashSet<_> =
575 client.allowed_scopes.split_whitespace().collect();
576 for s in scope.split_whitespace() {
577 if s != "openid" && !allowed.contains(s) {
578 return Err(AuthError::Internal(format!("scope {s} not allowed")));
579 }
580 }
581
582 let raw_device_code: [u8; 32] = rand::thread_rng().r#gen();
584 let device_code = URL_SAFE_NO_PAD.encode(raw_device_code);
585 let device_code_hash = sha256_hex(device_code.as_bytes());
586
587 let user_code = generate_user_code();
589 let user_code_hash = sha256_hex(user_code.replace('-', "").as_bytes());
590
591 let expires_at = Utc::now() + Duration::seconds(self.config.device_code_ttl_secs);
592
593 DeviceCodeRepository::create(
594 &self.storage,
595 CreateDeviceCode {
596 device_code_hash,
597 user_code_hash,
598 user_code: user_code.clone(),
599 client_id: client_id.to_string(),
600 scope: scope.to_string(),
601 expires_at,
602 interval_secs: self.config.device_code_interval_secs,
603 },
604 )
605 .await?;
606
607 let verification_uri_complete = Some(format!(
608 "{}?user_code={}",
609 self.config.verification_uri, user_code
610 ));
611
612 Ok(DeviceAuthorizationResponse {
613 device_code,
614 user_code,
615 verification_uri: self.config.verification_uri.clone(),
616 verification_uri_complete,
617 expires_in: self.config.device_code_ttl_secs,
618 interval: self.config.device_code_interval_secs,
619 })
620 }
621
622 #[instrument(skip(self))]
624 pub async fn verify_user_code(
625 &self,
626 user_code: &str,
627 user_id: Uuid,
628 approve: bool,
629 ) -> Result<()> {
630 let normalized = user_code.replace('-', "").to_uppercase();
631 let user_code_hash = sha256_hex(normalized.as_bytes());
632
633 let dc = DeviceCodeRepository::find_by_user_code_hash(&self.storage, &user_code_hash)
634 .await?
635 .ok_or(AuthError::Internal("invalid or expired user_code".into()))?;
636
637 if approve {
638 DeviceCodeRepository::authorize(&self.storage, dc.id, user_id).await?;
639 } else {
640 DeviceCodeRepository::deny(&self.storage, dc.id).await?;
641 }
642
643 Ok(())
644 }
645
646 #[instrument(skip(self))]
648 pub async fn poll_device_code(
649 &self,
650 device_code: &str,
651 client_id: &str,
652 ) -> std::result::Result<OidcTokenResponse, DeviceCodeError> {
653 const MAX_INTERVAL_SECS: u32 = 3600;
654
655 let device_code_hash = sha256_hex(device_code.as_bytes());
656
657 let dc = DeviceCodeRepository::find_by_device_code_hash(&self.storage, &device_code_hash)
658 .await
659 .map_err(|_| DeviceCodeError::ExpiredToken)?
660 .ok_or(DeviceCodeError::ExpiredToken)?;
661
662 if dc.client_id != client_id {
663 return Err(DeviceCodeError::ExpiredToken);
664 }
665
666 if let Some(last) = dc.last_polled_at {
668 let elapsed = (Utc::now() - last).num_seconds();
669 if elapsed < dc.interval_secs as i64 {
670 let new_interval = (dc.interval_secs + 5).min(MAX_INTERVAL_SECS);
671 DeviceCodeRepository::update_last_polled(&self.storage, dc.id, new_interval)
672 .await
673 .map_err(|_| DeviceCodeError::ExpiredToken)?;
674 return Err(DeviceCodeError::SlowDown);
675 }
676 }
677
678 DeviceCodeRepository::update_last_polled(&self.storage, dc.id, dc.interval_secs)
680 .await
681 .map_err(|_| DeviceCodeError::ExpiredToken)?;
682
683 if dc.denied {
685 return Err(DeviceCodeError::AccessDenied);
686 }
687
688 if !dc.authorized {
690 return Err(DeviceCodeError::AuthorizationPending);
691 }
692
693 let user_id = dc.user_id.ok_or(DeviceCodeError::AccessDenied)?;
695 let tokens = self
696 .issue_tokens(user_id, client_id, &dc.scope, None)
697 .await
698 .map_err(|_| DeviceCodeError::AccessDenied)?;
699
700 if let Err(e) = DeviceCodeRepository::delete(&self.storage, dc.id).await {
702 tracing::warn!(error = %e, "failed to delete device code after token exchange");
703 }
704
705 Ok(tokens)
706 }
707}
708
709fn generate_user_code() -> String {
712 const CHARSET: &[u8] = b"ABCDEFGHJKMNPQRSTUVWXYZ23456789";
713 let mut rng = rand::thread_rng();
714 let code: String = (0..8)
715 .map(|_| {
716 let idx = rng.gen_range(0..CHARSET.len());
717 CHARSET[idx] as char
718 })
719 .collect();
720 format!("{}-{}", &code[..4], &code[4..])
721}