1use std::{
2 collections::{HashMap, HashSet},
3 net::SocketAddr,
4 sync::Arc,
5};
6
7use anyhow::{Context, Result, anyhow};
8use axum::{
9 Json, Router,
10 extract::{Form, Query, State},
11 http::{HeaderMap, StatusCode, header},
12 response::{IntoResponse, Response},
13 routing::{get, post},
14};
15use base64::{
16 Engine as _,
17 engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD},
18};
19use once_cell::sync::Lazy;
20use rand::{
21 distr::{Alphanumeric, SampleString},
22 rng,
23};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use thiserror::Error;
27use time::{Duration as TimeDuration, OffsetDateTime};
28use tokio::{
29 net::TcpListener,
30 sync::{RwLock, oneshot},
31};
32use url::Url;
33use uuid::Uuid;
34
35static DEFAULT_SCOPE: Lazy<HashSet<String>> = Lazy::new(|| {
36 ["openid", "profile", "email"]
37 .into_iter()
38 .map(|s| s.to_string())
39 .collect()
40});
41
42#[derive(Debug, Clone, Serialize)]
43struct Jwk {
44 kty: String,
45 use_: String,
46 kid: String,
47 alg: String,
48 n: String,
49 e: String,
50}
51
52#[derive(Clone)]
53struct SigningKeys {
54 encoding: jsonwebtoken::EncodingKey,
55 jwk: Jwk,
56}
57
58impl std::fmt::Debug for SigningKeys {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("SigningKeys")
61 .field("jwk", &self.jwk)
62 .finish_non_exhaustive()
63 }
64}
65
/// A registered OAuth client, configured through `MockServerBuilder`.
///
/// NOTE(review): the derived `Debug` prints `client_secret`.
#[derive(Debug, Clone)]
struct ClientConfig {
    client_id: String,
    /// Secret the client presents via HTTP Basic or the `client_secret` form field.
    client_secret: String,
    /// Exact `redirect_uri` values accepted by `/authorize`.
    redirect_uris: HashSet<String>,
    /// Scopes this client may be granted; requested scopes are intersected with this set.
    allowed_scopes: HashSet<String>,
}
73
/// The single identity the mock provider issues tokens for.
///
/// Its claims appear in ID tokens (`sub`, `email`, `preferred_username`, `groups`) and in
/// the `/userinfo` response.
#[derive(Debug, Clone)]
pub struct MockUser {
    sub: String,
    email: String,
    preferred_username: String,
    groups: Vec<String>,
}

impl MockUser {
    /// Builds a user with the given claims.
    ///
    /// The fields are private, so without this constructor code outside this module
    /// could only ever hand `MockUser::default()` to `with_user`.
    pub fn new(
        sub: impl Into<String>,
        email: impl Into<String>,
        preferred_username: impl Into<String>,
        groups: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        Self {
            sub: sub.into(),
            email: email.into(),
            preferred_username: preferred_username.into(),
            groups: groups.into_iter().map(Into::into).collect(),
        }
    }

    /// Subject identifier (`sub` claim).
    pub fn sub(&self) -> &str {
        &self.sub
    }

    /// `email` claim.
    pub fn email(&self) -> &str {
        &self.email
    }

    /// `preferred_username` claim.
    pub fn preferred_username(&self) -> &str {
        &self.preferred_username
    }

    /// `groups` claim, in insertion order.
    pub fn groups(&self) -> &[String] {
        &self.groups
    }
}

impl Default for MockUser {
    /// `user-123` / `mock.user@example.com` / `mock.user`, in groups `mockers` and `testers`.
    fn default() -> Self {
        Self {
            sub: "user-123".to_string(),
            email: "mock.user@example.com".to_string(),
            preferred_username: "mock.user".to_string(),
            groups: vec!["mockers".into(), "testers".into()],
        }
    }
}
92
/// State remembered for an authorization code between `/authorize` and `/token`.
#[derive(Debug, Clone)]
struct AuthorizationCode {
    /// Client the code was issued to.
    client_id: String,
    /// Must match the `redirect_uri` presented when the code is redeemed.
    redirect_uri: String,
    /// Scopes granted at `/authorize`.
    scope: HashSet<String>,
    /// PKCE challenge from `/authorize`, if the client sent one.
    code_challenge: Option<String>,
    /// Copied into the ID token's `nonce` claim.
    nonce: Option<String>,
    /// When the code was minted.
    _created_at: OffsetDateTime,
}
102
/// Bookkeeping for an outstanding refresh token, keyed by the token string.
#[derive(Debug, Clone)]
struct RefreshTokenEntry {
    /// Client the token was issued to; only that client may redeem it.
    client_id: String,
    /// Scope carried over to tokens minted from this refresh token.
    scope: HashSet<String>,
    /// Snapshot of the user at issuance.
    _subject: MockUser,
    _issued_at: OffsetDateTime,
}
110
/// Lifecycle of a device-authorization request (RFC 8628).
#[derive(Debug, Clone)]
enum DeviceCodeStatus {
    /// Awaiting user action; `poll_count` counts token-endpoint polls so far.
    Pending { poll_count: u32 },
    /// Approved via `MockServer::approve_device_code`; the next poll yields tokens.
    Approved,
    /// Denied via `MockServer::deny_device_code`.
    Denied,
    /// Polled after its `expires_at` had passed.
    Expired,
    /// Tokens have already been issued for this code.
    Completed,
}
119
/// A device code issued by `/device_authorization`, stored keyed by the device code string.
#[derive(Debug, Clone)]
struct DeviceCodeEntry {
    /// Client that requested the code; only it may poll for tokens.
    client_id: String,
    /// Scope granted once the code is approved.
    scope: HashSet<String>,
    _device_code: String,
    /// Short uppercase code shown to the user; matched case-insensitively on approve/deny.
    user_code: String,
    /// After this instant a poll moves the entry to `Expired`.
    expires_at: OffsetDateTime,
    /// Polling interval advertised to the client, in seconds.
    _interval: u64,
    status: DeviceCodeStatus,
}
130
/// Mutable server state shared by all handlers (wrapped in `Arc<RwLock<_>>` as `SharedState`).
#[derive(Debug)]
struct InnerState {
    /// Value of the `iss` claim and base of every advertised endpoint URL.
    issuer: String,
    signing: SigningKeys,
    /// Registered clients keyed by `client_id`.
    clients: HashMap<String, ClientConfig>,
    /// The single identity every token is issued for.
    user: MockUser,
    /// Unredeemed authorization codes.
    authorization_codes: HashMap<String, AuthorizationCode>,
    /// Live refresh tokens; redeeming one removes it.
    refresh_tokens: HashMap<String, RefreshTokenEntry>,
    /// Access tokens accepted by `/userinfo` until revoked.
    access_tokens: HashSet<String>,
    /// Device codes keyed by the device code string.
    device_codes: HashMap<String, DeviceCodeEntry>,
}
142
143impl InnerState {
144 fn generate_code(&self) -> String {
145 let mut rng = rng();
146 Alphanumeric.sample_string(&mut rng, 32)
147 }
148
149 fn client(&self, client_id: &str) -> Option<&ClientConfig> {
150 self.clients.get(client_id)
151 }
152}
153
154type SharedState = Arc<RwLock<InnerState>>;
155
156#[derive(Debug, Clone)]
158pub struct MockServerBuilder {
159 clients: HashMap<String, ClientConfig>,
160 user: MockUser,
161 issuer_suffix: Option<String>,
162}
163
164impl Default for MockServerBuilder {
165 fn default() -> Self {
166 let mut clients = HashMap::new();
167 clients.insert(
168 "mock-client".into(),
169 ClientConfig {
170 client_id: "mock-client".into(),
171 client_secret: "mock-secret".into(),
172 redirect_uris: ["https://example.com/callback".into()]
173 .into_iter()
174 .collect(),
175 allowed_scopes: DEFAULT_SCOPE.clone(),
176 },
177 );
178 Self {
179 clients,
180 user: MockUser::default(),
181 issuer_suffix: None,
182 }
183 }
184}
185
186impl MockServerBuilder {
187 pub fn with_user(mut self, user: MockUser) -> Self {
189 self.user = user;
190 self
191 }
192
193 pub fn with_client(
195 mut self,
196 client_id: impl Into<String>,
197 client_secret: impl Into<String>,
198 redirect_uris: impl IntoIterator<Item = impl Into<String>>,
199 scopes: impl IntoIterator<Item = impl Into<String>>,
200 ) -> Self {
201 let client_id = client_id.into();
202 let secret = client_secret.into();
203 let redirect_uris = redirect_uris.into_iter().map(Into::into).collect();
204 let scopes = scopes.into_iter().map(Into::into).collect();
205 self.clients.insert(
206 client_id.clone(),
207 ClientConfig {
208 client_id,
209 client_secret: secret,
210 redirect_uris,
211 allowed_scopes: scopes,
212 },
213 );
214 self
215 }
216
217 pub fn with_issuer_suffix(mut self, suffix: impl Into<String>) -> Self {
219 self.issuer_suffix = Some(suffix.into());
220 self
221 }
222
223 pub async fn spawn_on_free_port(self) -> Result<MockServer> {
225 let listener = TcpListener::bind(("127.0.0.1", 0))
226 .await
227 .context("failed to bind mock OAuth listener")?;
228 let addr = listener
229 .local_addr()
230 .context("failed to determine listener address")?;
231 self.spawn_with_listener(listener, addr).await
232 }
233
234 async fn spawn_with_listener(
235 self,
236 listener: TcpListener,
237 addr: SocketAddr,
238 ) -> Result<MockServer> {
239 let base_url = format!("http://{addr}");
240 let issuer = if let Some(suffix) = &self.issuer_suffix {
241 format!("{base_url}/{suffix}")
242 } else {
243 base_url.clone()
244 };
245
246 let signing = generate_signing_keys()?;
247 let state = Arc::new(RwLock::new(InnerState {
248 issuer: issuer.clone(),
249 signing: signing.clone(),
250 clients: self.clients.clone(),
251 user: self.user.clone(),
252 authorization_codes: HashMap::new(),
253 refresh_tokens: HashMap::new(),
254 access_tokens: HashSet::new(),
255 device_codes: HashMap::new(),
256 }));
257
258 let jwks = json!({ "keys": [serde_json::to_value(&signing.jwk)?] });
259
260 let (shutdown_tx, shutdown_rx) = oneshot::channel();
261 let app = router(state.clone());
262 let server = axum::serve(listener, app).with_graceful_shutdown(async {
263 let _ = shutdown_rx.await;
264 });
265
266 let handle = tokio::spawn(async move {
267 if let Err(err) = server.await {
268 eprintln!("oauth-mock server error: {err:?}");
269 }
270 });
271
272 Ok(MockServer {
273 base_url,
274 issuer,
275 jwks,
276 state,
277 shutdown: Some(shutdown_tx),
278 _task: handle,
279 })
280 }
281}
282
283fn router(state: SharedState) -> Router {
284 Router::new()
285 .route("/.well-known/openid-configuration", get(discovery))
286 .route("/jwks.json", get(jwks_endpoint))
287 .route("/authorize", get(authorize))
288 .route("/token", post(token))
289 .route("/device_authorization", post(device_authorization))
290 .route("/userinfo", get(userinfo))
291 .route("/introspect", post(introspect))
292 .route("/revoke", post(revoke))
293 .with_state(state)
294}
295
/// Handle to a running mock OAuth 2.0 / OpenID Connect provider.
///
/// Dropping the handle signals the background server task to shut down gracefully.
pub struct MockServer {
    /// `http://{addr}` of the bound listener.
    base_url: String,
    /// `iss` claim value: `base_url`, plus the issuer suffix if one was configured.
    issuer: String,
    /// JWK set computed at startup from the same key served at `/jwks.json`.
    jwks: Value,
    state: SharedState,
    /// Taken and fired in `Drop` to trigger graceful shutdown.
    shutdown: Option<oneshot::Sender<()>>,
    /// Background task running the server.
    _task: tokio::task::JoinHandle<()>,
}
305
306impl MockServer {
307 pub fn builder() -> MockServerBuilder {
308 MockServerBuilder::default()
309 }
310
311 pub async fn spawn_on_free_port() -> Result<Self> {
313 MockServerBuilder::default().spawn_on_free_port().await
314 }
315
316 pub fn base_url(&self) -> &str {
318 &self.base_url
319 }
320
321 pub fn issuer(&self) -> &str {
323 &self.issuer
324 }
325
326 pub fn jwks(&self) -> &Value {
328 &self.jwks
329 }
330
331 pub async fn default_client(&self) -> Option<(String, String)> {
333 let state = self.state.read().await;
334 state
335 .clients
336 .values()
337 .next()
338 .map(|client| (client.client_id.clone(), client.client_secret.clone()))
339 }
340
341 pub async fn approve_device_code(&self, user_code: &str) -> Result<()> {
343 let mut state = self.state.write().await;
344 let entry = state
345 .device_codes
346 .values_mut()
347 .find(|entry| entry.user_code.eq_ignore_ascii_case(user_code))
348 .ok_or_else(|| anyhow!("device code {user_code} not found"))?;
349 entry.status = DeviceCodeStatus::Approved;
350 Ok(())
351 }
352
353 pub async fn deny_device_code(&self, user_code: &str) -> Result<()> {
355 let mut state = self.state.write().await;
356 let entry = state
357 .device_codes
358 .values_mut()
359 .find(|entry| entry.user_code.eq_ignore_ascii_case(user_code))
360 .ok_or_else(|| anyhow!("device code {user_code} not found"))?;
361 entry.status = DeviceCodeStatus::Denied;
362 Ok(())
363 }
364}
365
366impl Drop for MockServer {
367 fn drop(&mut self) {
368 if let Some(tx) = self.shutdown.take() {
369 let _ = tx.send(());
370 }
371 }
372}
373
/// Body of `/.well-known/openid-configuration` (OpenID Connect Discovery metadata).
#[derive(Debug, Serialize)]
struct DiscoveryDocument {
    issuer: String,
    authorization_endpoint: String,
    token_endpoint: String,
    jwks_uri: String,
    device_authorization_endpoint: String,
    userinfo_endpoint: String,
    introspection_endpoint: String,
    revocation_endpoint: String,
    response_types_supported: Vec<String>,
    grant_types_supported: Vec<String>,
    code_challenge_methods_supported: Vec<String>,
    /// Union of the registered clients' allowed scopes.
    scopes_supported: Vec<String>,
}
390
391async fn discovery(State(state): State<SharedState>) -> impl IntoResponse {
392 let state = state.read().await;
393 let issuer = state.issuer.clone();
394 let doc = DiscoveryDocument {
395 issuer: issuer.clone(),
396 authorization_endpoint: format!("{issuer}/authorize"),
397 token_endpoint: format!("{issuer}/token"),
398 jwks_uri: format!("{issuer}/jwks.json"),
399 device_authorization_endpoint: format!("{issuer}/device_authorization"),
400 userinfo_endpoint: format!("{issuer}/userinfo"),
401 introspection_endpoint: format!("{issuer}/introspect"),
402 revocation_endpoint: format!("{issuer}/revoke"),
403 response_types_supported: vec!["code".into(), "token".into()],
404 grant_types_supported: vec![
405 "authorization_code".into(),
406 "refresh_token".into(),
407 "client_credentials".into(),
408 "urn:ietf:params:oauth:grant-type:device_code".into(),
409 "device_code".into(),
410 ],
411 code_challenge_methods_supported: vec!["S256".into()],
412 scopes_supported: state
413 .clients
414 .values()
415 .flat_map(|client| client.allowed_scopes.iter().cloned())
416 .collect(),
417 };
418 Json(doc)
419}
420
421async fn jwks_endpoint(State(state): State<SharedState>) -> impl IntoResponse {
422 let state = state.read().await;
423 Json(json!({ "keys": [serde_json::to_value(&state.signing.jwk).unwrap()] }))
424}
425
/// Query parameters accepted by `GET /authorize`.
#[derive(Debug, Deserialize)]
struct AuthorizeQuery {
    /// Only `code` is accepted.
    response_type: String,
    client_id: String,
    /// Must exactly match one of the client's registered redirect URIs.
    redirect_uri: String,
    /// Space-separated scope list.
    scope: Option<String>,
    /// Opaque value echoed back on the redirect.
    state: Option<String>,
    /// PKCE challenge; mandatory (S256 only) when the `pkce` feature is enabled.
    code_challenge: Option<String>,
    /// Only read when the `pkce` feature is enabled.
    code_challenge_method: Option<String>,
    /// Copied into the ID token's `nonce` claim.
    nonce: Option<String>,
}
437
438async fn authorize(
439 State(state): State<SharedState>,
440 Query(query): Query<AuthorizeQuery>,
441) -> Result<Response, MockError> {
442 if query.response_type != "code" {
443 return Err(MockError::invalid_request("unsupported response_type"));
444 }
445 let mut state_guard = state.write().await;
446 let client = state_guard
447 .client(&query.client_id)
448 .cloned()
449 .ok_or_else(|| MockError::invalid_client("unknown client"))?;
450 if !client.redirect_uris.contains(&query.redirect_uri) {
451 return Err(MockError::invalid_request("redirect_uri mismatch"));
452 }
453
454 let scope_set = parse_scope(&query.scope)?;
455 let allowed: HashSet<_> = scope_set
456 .intersection(&client.allowed_scopes)
457 .cloned()
458 .collect();
459 if allowed.is_empty() {
460 return Err(MockError::invalid_scope("no allowed scopes requested"));
461 }
462
463 #[cfg(feature = "pkce")]
464 {
465 if let Some(method) = &query.code_challenge_method {
466 if method != "S256" {
467 return Err(MockError::invalid_request("only S256 accepted"));
468 }
469 } else {
470 return Err(MockError::invalid_request("missing code_challenge_method"));
471 }
472 if query.code_challenge.is_none() {
473 return Err(MockError::invalid_request("missing code_challenge"));
474 }
475 }
476
477 let code = state_guard.generate_code();
478 state_guard.authorization_codes.insert(
479 code.clone(),
480 AuthorizationCode {
481 client_id: client.client_id.clone(),
482 redirect_uri: query.redirect_uri.clone(),
483 scope: allowed,
484 code_challenge: query.code_challenge.clone(),
485 nonce: query.nonce.clone(),
486 _created_at: OffsetDateTime::now_utc(),
487 },
488 );
489
490 let mut redirect = Url::parse(&query.redirect_uri)
491 .map_err(|_| MockError::invalid_request("invalid redirect_uri"))?;
492 {
493 let mut pairs = redirect.query_pairs_mut();
494 pairs.append_pair("code", &code);
495 if let Some(state) = &query.state {
496 pairs.append_pair("state", state);
497 }
498 }
499
500 let response = (
501 StatusCode::SEE_OTHER,
502 [(header::LOCATION, redirect.to_string())],
503 );
504 Ok(response.into_response())
505}
506
/// Form body of `POST /token`; which fields are required depends on `grant_type`.
#[derive(Debug, Deserialize)]
struct TokenRequest {
    grant_type: String,
    /// Authorization code (`authorization_code` grant).
    code: Option<String>,
    /// Must repeat the `redirect_uri` used at `/authorize` (`authorization_code` grant).
    redirect_uri: Option<String>,
    /// PKCE verifier (`authorization_code` grant).
    code_verifier: Option<String>,
    /// Token to redeem (`refresh_token` grant).
    refresh_token: Option<String>,
    /// Client id when not sent via HTTP Basic.
    client_id: Option<String>,
    /// Client secret when not sent via HTTP Basic.
    client_secret: Option<String>,
    /// Device code being polled (device-code grant).
    device_code: Option<String>,
    /// Requested scope, space-separated.
    scope: Option<String>,
}
519
520async fn token(
521 State(state): State<SharedState>,
522 headers: HeaderMap,
523 Form(request): Form<TokenRequest>,
524) -> Result<Json<Value>, MockError> {
525 let credentials = extract_client_credentials(&headers, &request)?;
526
527 match request.grant_type.as_str() {
528 "authorization_code" => handle_authorization_code(state, credentials, request).await,
529 "client_credentials" => handle_client_credentials(state, credentials, request).await,
530 "refresh_token" => handle_refresh_token(state, credentials, request).await,
531 "urn:ietf:params:oauth:grant-type:device_code" | "device_code" => {
532 handle_device_code(state, credentials, request).await
533 }
534 other => Err(MockError::invalid_request(format!(
535 "unsupported grant_type {other}"
536 ))),
537 }
538}
539
540async fn handle_authorization_code(
541 state: SharedState,
542 credentials: ClientCredentials,
543 request: TokenRequest,
544) -> Result<Json<Value>, MockError> {
545 let code = request
546 .code
547 .as_ref()
548 .ok_or_else(|| MockError::invalid_request("missing code"))?;
549 let redirect_uri = request
550 .redirect_uri
551 .as_ref()
552 .ok_or_else(|| MockError::invalid_request("missing redirect_uri"))?;
553
554 #[cfg(feature = "pkce")]
555 let code_verifier = request
556 .code_verifier
557 .clone()
558 .ok_or_else(|| MockError::invalid_request("PKCE enabled; code_verifier is required"))?;
559
560 let mut state_guard = state.write().await;
561 let entry = state_guard
562 .authorization_codes
563 .remove(code)
564 .ok_or_else(|| MockError::invalid_grant("invalid authorization code"))?;
565
566 if entry.client_id != credentials.client_id {
567 return Err(MockError::invalid_grant(
568 "authorization code client mismatch",
569 ));
570 }
571 if entry.redirect_uri != *redirect_uri {
572 return Err(MockError::invalid_grant("redirect_uri mismatch"));
573 }
574
575 #[cfg(feature = "pkce")]
576 {
577 let expected = entry
578 .code_challenge
579 .ok_or_else(|| MockError::invalid_grant("missing code challenge"))?;
580 let verified = verify_code_challenge(&code_verifier, &expected)?;
581 if !verified {
582 return Err(MockError::invalid_grant("code_verifier mismatch"));
583 }
584 }
585
586 let client = state_guard
587 .client(&credentials.client_id)
588 .cloned()
589 .ok_or_else(|| MockError::invalid_client("unknown client"))?;
590
591 let scope = entry.scope.clone();
592 let issued_at = OffsetDateTime::now_utc();
593 let access_token = issue_access_token(&state_guard, &client, &scope, issued_at)?;
594 let id_token = issue_id_token(&state_guard, &client, &scope, issued_at, entry.nonce)?;
595 let refresh_token = issue_refresh_token(&mut state_guard, &client, &scope, issued_at)?;
596 state_guard.access_tokens.insert(access_token.clone());
597
598 Ok(Json(json!({
599 "token_type": "Bearer",
600 "expires_in": 3600,
601 "access_token": access_token,
602 "id_token": id_token,
603 "scope": scope_to_string(&scope),
604 "refresh_token": refresh_token,
605 })))
606}
607
608async fn handle_client_credentials(
609 state: SharedState,
610 credentials: ClientCredentials,
611 request: TokenRequest,
612) -> Result<Json<Value>, MockError> {
613 let mut state_guard = state.write().await;
614 let client = state_guard
615 .client(&credentials.client_id)
616 .cloned()
617 .ok_or_else(|| MockError::invalid_client("unknown client"))?;
618
619 if client.client_secret != credentials.client_secret {
620 return Err(MockError::invalid_client("invalid client_secret"));
621 }
622
623 let requested_scope = parse_scope(&request.scope)?;
624 let scope = if requested_scope.is_empty() {
625 client.allowed_scopes.clone()
626 } else {
627 requested_scope
628 .intersection(&client.allowed_scopes)
629 .cloned()
630 .collect()
631 };
632
633 let issued_at = OffsetDateTime::now_utc();
634 let access_token = issue_access_token(&state_guard, &client, &scope, issued_at)?;
635 state_guard.access_tokens.insert(access_token.clone());
636
637 Ok(Json(json!({
638 "token_type": "Bearer",
639 "expires_in": 3600,
640 "access_token": access_token,
641 "scope": scope_to_string(&scope),
642 })))
643}
644
645async fn handle_refresh_token(
646 state: SharedState,
647 credentials: ClientCredentials,
648 request: TokenRequest,
649) -> Result<Json<Value>, MockError> {
650 let refresh_token = request
651 .refresh_token
652 .as_ref()
653 .ok_or_else(|| MockError::invalid_request("missing refresh_token"))?;
654
655 let mut state_guard = state.write().await;
656 let entry = state_guard
657 .refresh_tokens
658 .remove(refresh_token)
659 .ok_or_else(|| MockError::invalid_grant("invalid refresh token"))?;
660
661 if entry.client_id != credentials.client_id {
662 return Err(MockError::invalid_grant(
663 "client mismatch for refresh token",
664 ));
665 }
666
667 let client = state_guard
668 .client(&credentials.client_id)
669 .cloned()
670 .ok_or_else(|| MockError::invalid_client("unknown client"))?;
671
672 let scope = entry.scope.clone();
673 let issued_at = OffsetDateTime::now_utc();
674 let access_token = issue_access_token(&state_guard, &client, &entry.scope, issued_at)?;
675 let id_token = issue_id_token(&state_guard, &client, &scope, issued_at, None)?;
676 let new_refresh_token = issue_refresh_token(&mut state_guard, &client, &scope, issued_at)?;
677 state_guard.access_tokens.insert(access_token.clone());
678
679 Ok(Json(json!({
680 "token_type": "Bearer",
681 "expires_in": 3600,
682 "access_token": access_token,
683 "id_token": id_token,
684 "scope": scope_to_string(&scope),
685 "refresh_token": new_refresh_token,
686 })))
687}
688
689async fn handle_device_code(
690 state: SharedState,
691 credentials: ClientCredentials,
692 request: TokenRequest,
693) -> Result<Json<Value>, MockError> {
694 let device_code = request
695 .device_code
696 .as_ref()
697 .ok_or_else(|| MockError::invalid_request("missing device_code"))?;
698
699 let mut state_guard = state.write().await;
700 let mut entry = state_guard
701 .device_codes
702 .remove(device_code)
703 .ok_or_else(|| MockError::invalid_grant("invalid device_code"))?;
704
705 if entry.client_id != credentials.client_id {
706 state_guard.device_codes.insert(device_code.clone(), entry);
707 return Err(MockError::invalid_client("client mismatch for device_code"));
708 }
709
710 if OffsetDateTime::now_utc() > entry.expires_at {
711 entry.status = DeviceCodeStatus::Expired;
712 }
713
714 let result = match &mut entry.status {
715 DeviceCodeStatus::Pending { poll_count } => {
716 *poll_count += 1;
717 if *poll_count % 3 == 0 {
718 Err(MockError::slow_down())
719 } else {
720 Err(MockError::authorization_pending())
721 }
722 }
723 DeviceCodeStatus::Approved => {
724 let client = state_guard
725 .client(&credentials.client_id)
726 .cloned()
727 .ok_or_else(|| MockError::invalid_client("unknown client"))?;
728 let issued_at = OffsetDateTime::now_utc();
729 let scope = entry.scope.clone();
730 let access_token = issue_access_token(&state_guard, &client, &scope, issued_at)?;
731 let id_token = issue_id_token(&state_guard, &client, &scope, issued_at, None)?;
732 let refresh_token =
733 issue_refresh_token(&mut state_guard, &client, &entry.scope, issued_at)?;
734 state_guard.access_tokens.insert(access_token.clone());
735 entry.status = DeviceCodeStatus::Completed;
736 Ok(Json(json!({
737 "token_type": "Bearer",
738 "expires_in": 3600,
739 "access_token": access_token,
740 "id_token": id_token,
741 "scope": scope_to_string(&scope),
742 "refresh_token": refresh_token,
743 })))
744 }
745 DeviceCodeStatus::Denied => Err(MockError::access_denied()),
746 DeviceCodeStatus::Expired => Err(MockError::expired_token()),
747 DeviceCodeStatus::Completed => Err(MockError::invalid_grant("device_code already used")),
748 };
749
750 state_guard.device_codes.insert(device_code.clone(), entry);
751 result
752}
753
/// Form body of `POST /device_authorization`.
#[derive(Debug, Deserialize)]
struct DeviceAuthorizationRequest {
    client_id: String,
    /// Space-separated scope list; empty or absent means the client's full allowed scope.
    scope: Option<String>,
}
759
760async fn device_authorization(
761 State(state): State<SharedState>,
762 Form(request): Form<DeviceAuthorizationRequest>,
763) -> Result<Json<Value>, MockError> {
764 #[cfg(not(feature = "device_code"))]
765 {
766 let _ = state;
767 let _ = request;
768 return Err(MockError::invalid_request("device_code feature disabled"));
769 }
770
771 #[cfg(feature = "device_code")]
772 {
773 let mut state_guard = state.write().await;
774 let client = state_guard
775 .client(&request.client_id)
776 .cloned()
777 .ok_or_else(|| MockError::invalid_client("unknown client"))?;
778
779 let requested_scope = parse_scope(&request.scope)?;
780 let scope = if requested_scope.is_empty() {
781 client.allowed_scopes.clone()
782 } else {
783 requested_scope
784 .intersection(&client.allowed_scopes)
785 .cloned()
786 .collect()
787 };
788
789 let device_code: String = state_guard.generate_code();
790 let mut rng = rng();
791 let user_code: String = Alphanumeric
792 .sample_string(&mut rng, 8)
793 .chars()
794 .map(|ch| ch.to_ascii_uppercase())
795 .collect();
796
797 let entry = DeviceCodeEntry {
798 client_id: client.client_id.clone(),
799 scope: scope.clone(),
800 _device_code: device_code.clone(),
801 user_code: user_code.clone(),
802 expires_at: OffsetDateTime::now_utc() + TimeDuration::minutes(10),
803 _interval: 5,
804 status: DeviceCodeStatus::Pending { poll_count: 0 },
805 };
806 state_guard.device_codes.insert(device_code.clone(), entry);
807
808 Ok(Json(json!({
809 "device_code": device_code,
810 "user_code": user_code,
811 "verification_uri": format!("{}/device", state_guard.issuer),
812 "verification_uri_complete": format!("{}/device?user_code={}", state_guard.issuer, user_code),
813 "expires_in": 600,
814 "interval": 5,
815 })))
816 }
817}
818
819async fn userinfo(
820 State(state): State<SharedState>,
821 headers: HeaderMap,
822) -> Result<Json<Value>, MockError> {
823 let token = extract_bearer_token(&headers)?;
824
825 let state_guard = state.read().await;
826 if !state_guard.access_tokens.contains(token) {
827 return Err(MockError::invalid_token("unknown access token"));
828 }
829
830 let claims = json!({
831 "sub": state_guard.user.sub,
832 "email": state_guard.user.email,
833 "preferred_username": state_guard.user.preferred_username,
834 "groups": state_guard.user.groups,
835 });
836 Ok(Json(claims))
837}
838
839async fn introspect(
840 State(state): State<SharedState>,
841 headers: HeaderMap,
842 Form(body): Form<HashMap<String, String>>,
843) -> Result<Json<Value>, MockError> {
844 let _ = extract_client_credentials(
845 &headers,
846 &TokenRequest {
847 grant_type: "".into(),
848 code: None,
849 redirect_uri: None,
850 code_verifier: None,
851 refresh_token: None,
852 client_id: None,
853 client_secret: None,
854 device_code: None,
855 scope: None,
856 },
857 )?;
858
859 let token = body
860 .get("token")
861 .cloned()
862 .ok_or_else(|| MockError::invalid_request("missing token"))?;
863 let state_guard = state.read().await;
864 let active = state_guard.access_tokens.contains(&token)
865 || state_guard.refresh_tokens.contains_key(&token);
866
867 Ok(Json(json!({
868 "active": active,
869 "iss": state_guard.issuer,
870 "client_id": "mock-client",
871 "scope": scope_to_string(&DEFAULT_SCOPE),
872 "token_type": "Bearer"
873 })))
874}
875
876async fn revoke(
877 State(state): State<SharedState>,
878 headers: HeaderMap,
879 Form(body): Form<HashMap<String, String>>,
880) -> Result<StatusCode, MockError> {
881 let _ = extract_client_credentials(
882 &headers,
883 &TokenRequest {
884 grant_type: "".into(),
885 code: None,
886 redirect_uri: None,
887 code_verifier: None,
888 refresh_token: None,
889 client_id: None,
890 client_secret: None,
891 device_code: None,
892 scope: None,
893 },
894 )?;
895 let token = body
896 .get("token")
897 .cloned()
898 .ok_or_else(|| MockError::invalid_request("missing token"))?;
899 let mut state_guard = state.write().await;
900 state_guard.access_tokens.remove(&token);
901 state_guard.refresh_tokens.remove(&token);
902 Ok(StatusCode::OK)
903}
904
/// Client id and secret presented on a request, via HTTP Basic or form fields.
#[derive(Clone)]
struct ClientCredentials {
    client_id: String,
    client_secret: String,
}

// Written by hand so the secret never ends up in logs or panic messages.
impl std::fmt::Debug for ClientCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientCredentials")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .finish()
    }
}
910
911fn extract_client_credentials(
912 headers: &HeaderMap,
913 request: &TokenRequest,
914) -> Result<ClientCredentials, MockError> {
915 if let Some(header_value) = headers.get(header::AUTHORIZATION) {
916 let auth = header_value
917 .to_str()
918 .map_err(|_| MockError::invalid_client("invalid Authorization header"))?;
919 if let Some(encoded) = auth.strip_prefix("Basic ") {
920 let decoded = STANDARD
921 .decode(encoded)
922 .map_err(|_| MockError::invalid_client("invalid basic auth"))?;
923 let decoded = String::from_utf8(decoded)
924 .map_err(|_| MockError::invalid_client("invalid utf8 basic auth"))?;
925 if let Some((id, secret)) = decoded.split_once(':') {
926 return Ok(ClientCredentials {
927 client_id: id.to_string(),
928 client_secret: secret.to_string(),
929 });
930 }
931 }
932 return Err(MockError::invalid_client("invalid Authorization header"));
933 }
934
935 let client_id = request
936 .client_id
937 .clone()
938 .ok_or_else(|| MockError::invalid_client("missing client_id"))?;
939 let client_secret = request
940 .client_secret
941 .clone()
942 .ok_or_else(|| MockError::invalid_client("missing client_secret"))?;
943
944 Ok(ClientCredentials {
945 client_id,
946 client_secret,
947 })
948}
949
950fn extract_bearer_token(headers: &HeaderMap) -> Result<&str, MockError> {
951 let auth = headers
952 .get(header::AUTHORIZATION)
953 .and_then(|value| value.to_str().ok())
954 .ok_or_else(|| MockError::invalid_token("missing Authorization header"))?;
955 auth.strip_prefix("Bearer ")
956 .ok_or_else(|| MockError::invalid_token("invalid bearer token header"))
957}
958
959fn issue_access_token(
960 state: &InnerState,
961 client: &ClientConfig,
962 scope: &HashSet<String>,
963 issued_at: OffsetDateTime,
964) -> Result<String, MockError> {
965 #[derive(Debug, Serialize)]
966 struct AccessClaims<'a> {
967 iss: &'a str,
968 sub: &'a str,
969 aud: &'a str,
970 exp: i64,
971 iat: i64,
972 scope: String,
973 client_id: &'a str,
974 jti: String,
975 }
976
977 let claims = AccessClaims {
978 iss: &state.issuer,
979 sub: &state.user.sub,
980 aud: &client.client_id,
981 exp: (issued_at + TimeDuration::hours(1)).unix_timestamp(),
982 iat: issued_at.unix_timestamp(),
983 scope: scope_to_string(scope),
984 client_id: &client.client_id,
985 jti: Uuid::new_v4().to_string(),
986 };
987
988 jsonwebtoken::encode(
989 &jsonwebtoken::Header {
990 alg: jsonwebtoken::Algorithm::RS256,
991 kid: Some(state.signing.jwk.kid.clone()),
992 ..jsonwebtoken::Header::default()
993 },
994 &claims,
995 &state.signing.encoding,
996 )
997 .map_err(|err| MockError::server_error(format!("encode access token: {err}")))
998}
999
1000fn issue_id_token(
1001 state: &InnerState,
1002 client: &ClientConfig,
1003 scope: &HashSet<String>,
1004 issued_at: OffsetDateTime,
1005 nonce: Option<String>,
1006) -> Result<String, MockError> {
1007 #[derive(Debug, Serialize)]
1008 struct IdClaims<'a> {
1009 iss: &'a str,
1010 sub: &'a str,
1011 aud: &'a str,
1012 exp: i64,
1013 iat: i64,
1014 email: &'a str,
1015 preferred_username: &'a str,
1016 groups: &'a [String],
1017 scope: String,
1018 #[serde(skip_serializing_if = "Option::is_none")]
1019 nonce: Option<String>,
1020 }
1021
1022 let claims = IdClaims {
1023 iss: &state.issuer,
1024 sub: &state.user.sub,
1025 aud: &client.client_id,
1026 exp: (issued_at + TimeDuration::hours(1)).unix_timestamp(),
1027 iat: issued_at.unix_timestamp(),
1028 email: &state.user.email,
1029 preferred_username: &state.user.preferred_username,
1030 groups: &state.user.groups,
1031 scope: scope_to_string(scope),
1032 nonce,
1033 };
1034
1035 jsonwebtoken::encode(
1036 &jsonwebtoken::Header {
1037 alg: jsonwebtoken::Algorithm::RS256,
1038 kid: Some(state.signing.jwk.kid.clone()),
1039 ..jsonwebtoken::Header::default()
1040 },
1041 &claims,
1042 &state.signing.encoding,
1043 )
1044 .map_err(|err| MockError::server_error(format!("encode id token: {err}")))
1045}
1046
1047fn issue_refresh_token(
1048 state: &mut InnerState,
1049 client: &ClientConfig,
1050 scope: &HashSet<String>,
1051 issued_at: OffsetDateTime,
1052) -> Result<String, MockError> {
1053 let refresh_token = state.generate_code();
1054 state.refresh_tokens.insert(
1055 refresh_token.clone(),
1056 RefreshTokenEntry {
1057 client_id: client.client_id.clone(),
1058 scope: scope.clone(),
1059 _subject: state.user.clone(),
1060 _issued_at: issued_at,
1061 },
1062 );
1063 Ok(refresh_token)
1064}
1065
1066fn parse_scope(scope: &Option<String>) -> Result<HashSet<String>, MockError> {
1067 Ok(scope
1068 .as_ref()
1069 .map(|value| {
1070 value
1071 .split_whitespace()
1072 .filter(|part| !part.is_empty())
1073 .map(|part| part.to_string())
1074 .collect()
1075 })
1076 .unwrap_or_default())
1077}
1078
/// Renders a scope set as a space-separated string in sorted order, for deterministic output.
fn scope_to_string(scope: &HashSet<String>) -> String {
    let mut names: Vec<&str> = scope.iter().map(String::as_str).collect();
    // Set elements are distinct, so an unstable sort yields the same order as a stable one.
    names.sort_unstable();
    names.join(" ")
}
1084
1085fn verify_code_challenge(verifier: &str, expected_challenge: &str) -> Result<bool, MockError> {
1086 use sha2::{Digest, Sha256};
1087 let hashed = Sha256::digest(verifier.as_bytes());
1088 let encoded = URL_SAFE_NO_PAD.encode(hashed);
1089 Ok(encoded == expected_challenge)
1090}
1091
1092fn generate_signing_keys() -> Result<SigningKeys> {
1093 use rsa::rand_core::OsRng;
1094 use rsa::traits::PublicKeyParts;
1095 use rsa::{RsaPrivateKey, pkcs1::EncodeRsaPrivateKey};
1096
1097 let mut rng = OsRng;
1098 let private_key = RsaPrivateKey::new(&mut rng, 2048).context("generate RSA key")?;
1099 let public_key = private_key.to_public_key();
1100
1101 let pem = private_key
1102 .to_pkcs1_pem(Default::default())
1103 .context("encode RSA key to PEM")?;
1104 let encoding =
1105 jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).context("create encoding key")?;
1106 let jwk = Jwk {
1107 kty: "RSA".into(),
1108 use_: "sig".into(),
1109 kid: Uuid::new_v4().to_string(),
1110 alg: "RS256".into(),
1111 n: URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be()),
1112 e: URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be()),
1113 };
1114
1115 Ok(SigningKeys { encoding, jwk })
1116}
1117
/// Errors surfaced to HTTP clients as OAuth 2.0 error responses.
///
/// The `IntoResponse` impl maps each variant to an `error` code and HTTP status; the
/// `String` payloads become `error_description`.
#[derive(Debug, Error)]
enum MockError {
    /// Malformed or missing request parameters (400).
    #[error("invalid_request: {0}")]
    InvalidRequest(String),
    /// Unknown client or bad client authentication (401).
    #[error("invalid_client: {0}")]
    InvalidClient(String),
    /// Bad, mismatched or already-used code or refresh token (400).
    #[error("invalid_grant: {0}")]
    InvalidGrant(String),
    /// No grantable scope (400).
    #[error("invalid_scope: {0}")]
    InvalidScope(String),
    /// Missing or unknown bearer token (401).
    #[error("invalid_token: {0}")]
    InvalidToken(String),
    /// Device flow: the user denied the request.
    #[error("access_denied")]
    AccessDenied,
    /// Device flow: the user has not acted yet.
    #[error("authorization_pending")]
    AuthorizationPending,
    /// Device flow: RFC 8628 `slow_down` response.
    #[error("slow_down")]
    SlowDown,
    /// Device flow: the device code has expired.
    #[error("expired_token")]
    ExpiredToken,
    /// Internal failure, e.g. token signing (500).
    #[error("server_error: {0}")]
    ServerError(String),
}
1141
1142impl MockError {
1143 fn invalid_request<T: Into<String>>(msg: T) -> Self {
1144 Self::InvalidRequest(msg.into())
1145 }
1146 fn invalid_client<T: Into<String>>(msg: T) -> Self {
1147 Self::InvalidClient(msg.into())
1148 }
1149 fn invalid_grant<T: Into<String>>(msg: T) -> Self {
1150 Self::InvalidGrant(msg.into())
1151 }
1152 fn invalid_scope<T: Into<String>>(msg: T) -> Self {
1153 Self::InvalidScope(msg.into())
1154 }
1155 fn invalid_token<T: Into<String>>(msg: T) -> Self {
1156 Self::InvalidToken(msg.into())
1157 }
1158 fn server_error<T: Into<String>>(msg: T) -> Self {
1159 Self::ServerError(msg.into())
1160 }
1161 fn authorization_pending() -> Self {
1162 Self::AuthorizationPending
1163 }
1164 fn slow_down() -> Self {
1165 Self::SlowDown
1166 }
1167 fn access_denied() -> Self {
1168 Self::AccessDenied
1169 }
1170 fn expired_token() -> Self {
1171 Self::ExpiredToken
1172 }
1173}
1174
1175impl IntoResponse for MockError {
1176 fn into_response(self) -> Response {
1177 let (status, body) = match self {
1178 MockError::InvalidRequest(msg) => {
1179 (StatusCode::BAD_REQUEST, json_error("invalid_request", msg))
1180 }
1181 MockError::InvalidClient(msg) => {
1182 (StatusCode::UNAUTHORIZED, json_error("invalid_client", msg))
1183 }
1184 MockError::InvalidGrant(msg) => {
1185 (StatusCode::BAD_REQUEST, json_error("invalid_grant", msg))
1186 }
1187 MockError::InvalidScope(msg) => {
1188 (StatusCode::BAD_REQUEST, json_error("invalid_scope", msg))
1189 }
1190 MockError::InvalidToken(msg) => {
1191 (StatusCode::UNAUTHORIZED, json_error("invalid_token", msg))
1192 }
1193 MockError::AccessDenied => (
1194 StatusCode::BAD_REQUEST,
1195 json_error("access_denied", "user denied the request"),
1196 ),
1197 MockError::AuthorizationPending => (
1198 StatusCode::BAD_REQUEST,
1199 json_error("authorization_pending", "authorization pending"),
1200 ),
1201 MockError::SlowDown => (
1202 StatusCode::BAD_REQUEST,
1203 json_error("slow_down", "slow down"),
1204 ),
1205 MockError::ExpiredToken => (
1206 StatusCode::BAD_REQUEST,
1207 json_error("expired_token", "device code expired"),
1208 ),
1209 MockError::ServerError(msg) => (
1210 StatusCode::INTERNAL_SERVER_ERROR,
1211 json_error("server_error", msg),
1212 ),
1213 };
1214 (status, Json(body)).into_response()
1215 }
1216}
1217
1218fn json_error(code: impl Into<String>, description: impl Into<String>) -> Value {
1219 json!({
1220 "error": code.into(),
1221 "error_description": description.into(),
1222 })
1223}