oauth_mock/
lib.rs

1use std::{
2    collections::{HashMap, HashSet},
3    net::SocketAddr,
4    sync::Arc,
5};
6
7use anyhow::{Context, Result, anyhow};
8use axum::{
9    Json, Router,
10    extract::{Form, Query, State},
11    http::{HeaderMap, StatusCode, header},
12    response::{IntoResponse, Response},
13    routing::{get, post},
14};
15use base64::{
16    Engine as _,
17    engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD},
18};
19use once_cell::sync::Lazy;
20use rand::{
21    distr::{Alphanumeric, SampleString},
22    rng,
23};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use thiserror::Error;
27use time::{Duration as TimeDuration, OffsetDateTime};
28use tokio::{
29    net::TcpListener,
30    sync::{RwLock, oneshot},
31};
32use url::Url;
33use uuid::Uuid;
34
35static DEFAULT_SCOPE: Lazy<HashSet<String>> = Lazy::new(|| {
36    ["openid", "profile", "email"]
37        .into_iter()
38        .map(|s| s.to_string())
39        .collect()
40});
41
42#[derive(Debug, Clone, Serialize)]
43struct Jwk {
44    kty: String,
45    use_: String,
46    kid: String,
47    alg: String,
48    n: String,
49    e: String,
50}
51
52#[derive(Clone)]
53struct SigningKeys {
54    encoding: jsonwebtoken::EncodingKey,
55    jwk: Jwk,
56}
57
58impl std::fmt::Debug for SigningKeys {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("SigningKeys")
61            .field("jwk", &self.jwk)
62            .finish_non_exhaustive()
63    }
64}
65
/// Registration data for one OAuth client.
#[derive(Debug, Clone)]
struct ClientConfig {
    /// Public client identifier.
    client_id: String,
    /// Shared secret the client authenticates with.
    client_secret: String,
    /// Redirect URIs accepted by `/authorize` (exact match).
    redirect_uris: HashSet<String>,
    /// Scopes this client may be granted.
    allowed_scopes: HashSet<String>,
}
73
/// The single end user the mock provider issues tokens for.
///
/// Its `sub` is used as the access-token subject, and all fields are returned by
/// `/userinfo`. Build a custom user with [`MockUser::new`] or start from
/// [`MockUser::default`].
#[derive(Debug, Clone)]
pub struct MockUser {
    sub: String,
    email: String,
    preferred_username: String,
    groups: Vec<String>,
}

impl MockUser {
    /// Creates a mock user with the given claims.
    ///
    /// The fields are private, so without this constructor code outside this module could
    /// only ever pass `MockUser::default()` to `MockServerBuilder::with_user`.
    pub fn new(
        sub: impl Into<String>,
        email: impl Into<String>,
        preferred_username: impl Into<String>,
        groups: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        Self {
            sub: sub.into(),
            email: email.into(),
            preferred_username: preferred_username.into(),
            groups: groups.into_iter().map(Into::into).collect(),
        }
    }
}
81
82impl Default for MockUser {
83    fn default() -> Self {
84        Self {
85            sub: "user-123".to_string(),
86            email: "mock.user@example.com".to_string(),
87            preferred_username: "mock.user".to_string(),
88            groups: vec!["mockers".into(), "testers".into()],
89        }
90    }
91}
92
93#[derive(Debug, Clone)]
94struct AuthorizationCode {
95    client_id: String,
96    redirect_uri: String,
97    scope: HashSet<String>,
98    code_challenge: Option<String>,
99    nonce: Option<String>,
100    _created_at: OffsetDateTime,
101}
102
103#[derive(Debug, Clone)]
104struct RefreshTokenEntry {
105    client_id: String,
106    scope: HashSet<String>,
107    _subject: MockUser,
108    _issued_at: OffsetDateTime,
109}
110
/// Lifecycle of a device authorization request as observed by the polling client.
#[derive(Debug, Clone)]
enum DeviceCodeStatus {
    /// No decision yet; `poll_count` counts token-endpoint polls so far.
    Pending { poll_count: u32 },
    /// Approved; the next poll receives tokens.
    Approved,
    /// Denied by the user.
    Denied,
    /// Polled after its expiry time.
    Expired,
    /// Tokens were already issued; the code cannot be redeemed again.
    Completed,
}
119
120#[derive(Debug, Clone)]
121struct DeviceCodeEntry {
122    client_id: String,
123    scope: HashSet<String>,
124    _device_code: String,
125    user_code: String,
126    expires_at: OffsetDateTime,
127    _interval: u64,
128    status: DeviceCodeStatus,
129}
130
131#[derive(Debug)]
132struct InnerState {
133    issuer: String,
134    signing: SigningKeys,
135    clients: HashMap<String, ClientConfig>,
136    user: MockUser,
137    authorization_codes: HashMap<String, AuthorizationCode>,
138    refresh_tokens: HashMap<String, RefreshTokenEntry>,
139    access_tokens: HashSet<String>,
140    device_codes: HashMap<String, DeviceCodeEntry>,
141}
142
143impl InnerState {
144    fn generate_code(&self) -> String {
145        let mut rng = rng();
146        Alphanumeric.sample_string(&mut rng, 32)
147    }
148
149    fn client(&self, client_id: &str) -> Option<&ClientConfig> {
150        self.clients.get(client_id)
151    }
152}
153
154type SharedState = Arc<RwLock<InnerState>>;
155
156/// Builder for configuring the mock server.
157#[derive(Debug, Clone)]
158pub struct MockServerBuilder {
159    clients: HashMap<String, ClientConfig>,
160    user: MockUser,
161    issuer_suffix: Option<String>,
162}
163
164impl Default for MockServerBuilder {
165    fn default() -> Self {
166        let mut clients = HashMap::new();
167        clients.insert(
168            "mock-client".into(),
169            ClientConfig {
170                client_id: "mock-client".into(),
171                client_secret: "mock-secret".into(),
172                redirect_uris: ["https://example.com/callback".into()]
173                    .into_iter()
174                    .collect(),
175                allowed_scopes: DEFAULT_SCOPE.clone(),
176            },
177        );
178        Self {
179            clients,
180            user: MockUser::default(),
181            issuer_suffix: None,
182        }
183    }
184}
185
186impl MockServerBuilder {
187    /// Overrides the default mock user.
188    pub fn with_user(mut self, user: MockUser) -> Self {
189        self.user = user;
190        self
191    }
192
193    /// Adds or replaces a client configuration.
194    pub fn with_client(
195        mut self,
196        client_id: impl Into<String>,
197        client_secret: impl Into<String>,
198        redirect_uris: impl IntoIterator<Item = impl Into<String>>,
199        scopes: impl IntoIterator<Item = impl Into<String>>,
200    ) -> Self {
201        let client_id = client_id.into();
202        let secret = client_secret.into();
203        let redirect_uris = redirect_uris.into_iter().map(Into::into).collect();
204        let scopes = scopes.into_iter().map(Into::into).collect();
205        self.clients.insert(
206            client_id.clone(),
207            ClientConfig {
208                client_id,
209                client_secret: secret,
210                redirect_uris,
211                allowed_scopes: scopes,
212            },
213        );
214        self
215    }
216
217    /// Customises the issuer suffix (useful when sharing base URLs across tests).
218    pub fn with_issuer_suffix(mut self, suffix: impl Into<String>) -> Self {
219        self.issuer_suffix = Some(suffix.into());
220        self
221    }
222
223    /// Spawns the server using a random free port.
224    pub async fn spawn_on_free_port(self) -> Result<MockServer> {
225        let listener = TcpListener::bind(("127.0.0.1", 0))
226            .await
227            .context("failed to bind mock OAuth listener")?;
228        let addr = listener
229            .local_addr()
230            .context("failed to determine listener address")?;
231        self.spawn_with_listener(listener, addr).await
232    }
233
234    async fn spawn_with_listener(
235        self,
236        listener: TcpListener,
237        addr: SocketAddr,
238    ) -> Result<MockServer> {
239        let base_url = format!("http://{addr}");
240        let issuer = if let Some(suffix) = &self.issuer_suffix {
241            format!("{base_url}/{suffix}")
242        } else {
243            base_url.clone()
244        };
245
246        let signing = generate_signing_keys()?;
247        let state = Arc::new(RwLock::new(InnerState {
248            issuer: issuer.clone(),
249            signing: signing.clone(),
250            clients: self.clients.clone(),
251            user: self.user.clone(),
252            authorization_codes: HashMap::new(),
253            refresh_tokens: HashMap::new(),
254            access_tokens: HashSet::new(),
255            device_codes: HashMap::new(),
256        }));
257
258        let jwks = json!({ "keys": [serde_json::to_value(&signing.jwk)?] });
259
260        let (shutdown_tx, shutdown_rx) = oneshot::channel();
261        let app = router(state.clone());
262        let server = axum::serve(listener, app).with_graceful_shutdown(async {
263            let _ = shutdown_rx.await;
264        });
265
266        let handle = tokio::spawn(async move {
267            if let Err(err) = server.await {
268                eprintln!("oauth-mock server error: {err:?}");
269            }
270        });
271
272        Ok(MockServer {
273            base_url,
274            issuer,
275            jwks,
276            state,
277            shutdown: Some(shutdown_tx),
278            _task: handle,
279        })
280    }
281}
282
283fn router(state: SharedState) -> Router {
284    Router::new()
285        .route("/.well-known/openid-configuration", get(discovery))
286        .route("/jwks.json", get(jwks_endpoint))
287        .route("/authorize", get(authorize))
288        .route("/token", post(token))
289        .route("/device_authorization", post(device_authorization))
290        .route("/userinfo", get(userinfo))
291        .route("/introspect", post(introspect))
292        .route("/revoke", post(revoke))
293        .with_state(state)
294}
295
296/// Running instance of the mock server.
297pub struct MockServer {
298    base_url: String,
299    issuer: String,
300    jwks: Value,
301    state: SharedState,
302    shutdown: Option<oneshot::Sender<()>>,
303    _task: tokio::task::JoinHandle<()>,
304}
305
306impl MockServer {
307    pub fn builder() -> MockServerBuilder {
308        MockServerBuilder::default()
309    }
310
311    /// Convenience helper to spawn using defaults.
312    pub async fn spawn_on_free_port() -> Result<Self> {
313        MockServerBuilder::default().spawn_on_free_port().await
314    }
315
316    /// Returns the base URL (http://host:port).
317    pub fn base_url(&self) -> &str {
318        &self.base_url
319    }
320
321    /// Returns the configured issuer URL.
322    pub fn issuer(&self) -> &str {
323        &self.issuer
324    }
325
326    /// Returns the JWKS document.
327    pub fn jwks(&self) -> &Value {
328        &self.jwks
329    }
330
331    /// Retrieves the default client credentials (first configured client).
332    pub async fn default_client(&self) -> Option<(String, String)> {
333        let state = self.state.read().await;
334        state
335            .clients
336            .values()
337            .next()
338            .map(|client| (client.client_id.clone(), client.client_secret.clone()))
339    }
340
341    /// Marks a device user code as approved.
342    pub async fn approve_device_code(&self, user_code: &str) -> Result<()> {
343        let mut state = self.state.write().await;
344        let entry = state
345            .device_codes
346            .values_mut()
347            .find(|entry| entry.user_code.eq_ignore_ascii_case(user_code))
348            .ok_or_else(|| anyhow!("device code {user_code} not found"))?;
349        entry.status = DeviceCodeStatus::Approved;
350        Ok(())
351    }
352
353    /// Denies a device code request.
354    pub async fn deny_device_code(&self, user_code: &str) -> Result<()> {
355        let mut state = self.state.write().await;
356        let entry = state
357            .device_codes
358            .values_mut()
359            .find(|entry| entry.user_code.eq_ignore_ascii_case(user_code))
360            .ok_or_else(|| anyhow!("device code {user_code} not found"))?;
361        entry.status = DeviceCodeStatus::Denied;
362        Ok(())
363    }
364}
365
366impl Drop for MockServer {
367    fn drop(&mut self) {
368        if let Some(tx) = self.shutdown.take() {
369            let _ = tx.send(());
370        }
371    }
372}
373
374/// Metadata returned by the discovery endpoint.
375#[derive(Debug, Serialize)]
376struct DiscoveryDocument {
377    issuer: String,
378    authorization_endpoint: String,
379    token_endpoint: String,
380    jwks_uri: String,
381    device_authorization_endpoint: String,
382    userinfo_endpoint: String,
383    introspection_endpoint: String,
384    revocation_endpoint: String,
385    response_types_supported: Vec<String>,
386    grant_types_supported: Vec<String>,
387    code_challenge_methods_supported: Vec<String>,
388    scopes_supported: Vec<String>,
389}
390
391async fn discovery(State(state): State<SharedState>) -> impl IntoResponse {
392    let state = state.read().await;
393    let issuer = state.issuer.clone();
394    let doc = DiscoveryDocument {
395        issuer: issuer.clone(),
396        authorization_endpoint: format!("{issuer}/authorize"),
397        token_endpoint: format!("{issuer}/token"),
398        jwks_uri: format!("{issuer}/jwks.json"),
399        device_authorization_endpoint: format!("{issuer}/device_authorization"),
400        userinfo_endpoint: format!("{issuer}/userinfo"),
401        introspection_endpoint: format!("{issuer}/introspect"),
402        revocation_endpoint: format!("{issuer}/revoke"),
403        response_types_supported: vec!["code".into(), "token".into()],
404        grant_types_supported: vec![
405            "authorization_code".into(),
406            "refresh_token".into(),
407            "client_credentials".into(),
408            "urn:ietf:params:oauth:grant-type:device_code".into(),
409            "device_code".into(),
410        ],
411        code_challenge_methods_supported: vec!["S256".into()],
412        scopes_supported: state
413            .clients
414            .values()
415            .flat_map(|client| client.allowed_scopes.iter().cloned())
416            .collect(),
417    };
418    Json(doc)
419}
420
421async fn jwks_endpoint(State(state): State<SharedState>) -> impl IntoResponse {
422    let state = state.read().await;
423    Json(json!({ "keys": [serde_json::to_value(&state.signing.jwk).unwrap()] }))
424}
425
426#[derive(Debug, Deserialize)]
427struct AuthorizeQuery {
428    response_type: String,
429    client_id: String,
430    redirect_uri: String,
431    scope: Option<String>,
432    state: Option<String>,
433    code_challenge: Option<String>,
434    code_challenge_method: Option<String>,
435    nonce: Option<String>,
436}
437
438async fn authorize(
439    State(state): State<SharedState>,
440    Query(query): Query<AuthorizeQuery>,
441) -> Result<Response, MockError> {
442    if query.response_type != "code" {
443        return Err(MockError::invalid_request("unsupported response_type"));
444    }
445    let mut state_guard = state.write().await;
446    let client = state_guard
447        .client(&query.client_id)
448        .cloned()
449        .ok_or_else(|| MockError::invalid_client("unknown client"))?;
450    if !client.redirect_uris.contains(&query.redirect_uri) {
451        return Err(MockError::invalid_request("redirect_uri mismatch"));
452    }
453
454    let scope_set = parse_scope(&query.scope)?;
455    let allowed: HashSet<_> = scope_set
456        .intersection(&client.allowed_scopes)
457        .cloned()
458        .collect();
459    if allowed.is_empty() {
460        return Err(MockError::invalid_scope("no allowed scopes requested"));
461    }
462
463    #[cfg(feature = "pkce")]
464    {
465        if let Some(method) = &query.code_challenge_method {
466            if method != "S256" {
467                return Err(MockError::invalid_request("only S256 accepted"));
468            }
469        } else {
470            return Err(MockError::invalid_request("missing code_challenge_method"));
471        }
472        if query.code_challenge.is_none() {
473            return Err(MockError::invalid_request("missing code_challenge"));
474        }
475    }
476
477    let code = state_guard.generate_code();
478    state_guard.authorization_codes.insert(
479        code.clone(),
480        AuthorizationCode {
481            client_id: client.client_id.clone(),
482            redirect_uri: query.redirect_uri.clone(),
483            scope: allowed,
484            code_challenge: query.code_challenge.clone(),
485            nonce: query.nonce.clone(),
486            _created_at: OffsetDateTime::now_utc(),
487        },
488    );
489
490    let mut redirect = Url::parse(&query.redirect_uri)
491        .map_err(|_| MockError::invalid_request("invalid redirect_uri"))?;
492    {
493        let mut pairs = redirect.query_pairs_mut();
494        pairs.append_pair("code", &code);
495        if let Some(state) = &query.state {
496            pairs.append_pair("state", state);
497        }
498    }
499
500    let response = (
501        StatusCode::SEE_OTHER,
502        [(header::LOCATION, redirect.to_string())],
503    );
504    Ok(response.into_response())
505}
506
507#[derive(Debug, Deserialize)]
508struct TokenRequest {
509    grant_type: String,
510    code: Option<String>,
511    redirect_uri: Option<String>,
512    code_verifier: Option<String>,
513    refresh_token: Option<String>,
514    client_id: Option<String>,
515    client_secret: Option<String>,
516    device_code: Option<String>,
517    scope: Option<String>,
518}
519
520async fn token(
521    State(state): State<SharedState>,
522    headers: HeaderMap,
523    Form(request): Form<TokenRequest>,
524) -> Result<Json<Value>, MockError> {
525    let credentials = extract_client_credentials(&headers, &request)?;
526
527    match request.grant_type.as_str() {
528        "authorization_code" => handle_authorization_code(state, credentials, request).await,
529        "client_credentials" => handle_client_credentials(state, credentials, request).await,
530        "refresh_token" => handle_refresh_token(state, credentials, request).await,
531        "urn:ietf:params:oauth:grant-type:device_code" | "device_code" => {
532            handle_device_code(state, credentials, request).await
533        }
534        other => Err(MockError::invalid_request(format!(
535            "unsupported grant_type {other}"
536        ))),
537    }
538}
539
540async fn handle_authorization_code(
541    state: SharedState,
542    credentials: ClientCredentials,
543    request: TokenRequest,
544) -> Result<Json<Value>, MockError> {
545    let code = request
546        .code
547        .as_ref()
548        .ok_or_else(|| MockError::invalid_request("missing code"))?;
549    let redirect_uri = request
550        .redirect_uri
551        .as_ref()
552        .ok_or_else(|| MockError::invalid_request("missing redirect_uri"))?;
553
554    #[cfg(feature = "pkce")]
555    let code_verifier = request
556        .code_verifier
557        .clone()
558        .ok_or_else(|| MockError::invalid_request("PKCE enabled; code_verifier is required"))?;
559
560    let mut state_guard = state.write().await;
561    let entry = state_guard
562        .authorization_codes
563        .remove(code)
564        .ok_or_else(|| MockError::invalid_grant("invalid authorization code"))?;
565
566    if entry.client_id != credentials.client_id {
567        return Err(MockError::invalid_grant(
568            "authorization code client mismatch",
569        ));
570    }
571    if entry.redirect_uri != *redirect_uri {
572        return Err(MockError::invalid_grant("redirect_uri mismatch"));
573    }
574
575    #[cfg(feature = "pkce")]
576    {
577        let expected = entry
578            .code_challenge
579            .ok_or_else(|| MockError::invalid_grant("missing code challenge"))?;
580        let verified = verify_code_challenge(&code_verifier, &expected)?;
581        if !verified {
582            return Err(MockError::invalid_grant("code_verifier mismatch"));
583        }
584    }
585
586    let client = state_guard
587        .client(&credentials.client_id)
588        .cloned()
589        .ok_or_else(|| MockError::invalid_client("unknown client"))?;
590
591    let scope = entry.scope.clone();
592    let issued_at = OffsetDateTime::now_utc();
593    let access_token = issue_access_token(&state_guard, &client, &scope, issued_at)?;
594    let id_token = issue_id_token(&state_guard, &client, &scope, issued_at, entry.nonce)?;
595    let refresh_token = issue_refresh_token(&mut state_guard, &client, &scope, issued_at)?;
596    state_guard.access_tokens.insert(access_token.clone());
597
598    Ok(Json(json!({
599        "token_type": "Bearer",
600        "expires_in": 3600,
601        "access_token": access_token,
602        "id_token": id_token,
603        "scope": scope_to_string(&scope),
604        "refresh_token": refresh_token,
605    })))
606}
607
608async fn handle_client_credentials(
609    state: SharedState,
610    credentials: ClientCredentials,
611    request: TokenRequest,
612) -> Result<Json<Value>, MockError> {
613    let mut state_guard = state.write().await;
614    let client = state_guard
615        .client(&credentials.client_id)
616        .cloned()
617        .ok_or_else(|| MockError::invalid_client("unknown client"))?;
618
619    if client.client_secret != credentials.client_secret {
620        return Err(MockError::invalid_client("invalid client_secret"));
621    }
622
623    let requested_scope = parse_scope(&request.scope)?;
624    let scope = if requested_scope.is_empty() {
625        client.allowed_scopes.clone()
626    } else {
627        requested_scope
628            .intersection(&client.allowed_scopes)
629            .cloned()
630            .collect()
631    };
632
633    let issued_at = OffsetDateTime::now_utc();
634    let access_token = issue_access_token(&state_guard, &client, &scope, issued_at)?;
635    state_guard.access_tokens.insert(access_token.clone());
636
637    Ok(Json(json!({
638        "token_type": "Bearer",
639        "expires_in": 3600,
640        "access_token": access_token,
641        "scope": scope_to_string(&scope),
642    })))
643}
644
645async fn handle_refresh_token(
646    state: SharedState,
647    credentials: ClientCredentials,
648    request: TokenRequest,
649) -> Result<Json<Value>, MockError> {
650    let refresh_token = request
651        .refresh_token
652        .as_ref()
653        .ok_or_else(|| MockError::invalid_request("missing refresh_token"))?;
654
655    let mut state_guard = state.write().await;
656    let entry = state_guard
657        .refresh_tokens
658        .remove(refresh_token)
659        .ok_or_else(|| MockError::invalid_grant("invalid refresh token"))?;
660
661    if entry.client_id != credentials.client_id {
662        return Err(MockError::invalid_grant(
663            "client mismatch for refresh token",
664        ));
665    }
666
667    let client = state_guard
668        .client(&credentials.client_id)
669        .cloned()
670        .ok_or_else(|| MockError::invalid_client("unknown client"))?;
671
672    let scope = entry.scope.clone();
673    let issued_at = OffsetDateTime::now_utc();
674    let access_token = issue_access_token(&state_guard, &client, &entry.scope, issued_at)?;
675    let id_token = issue_id_token(&state_guard, &client, &scope, issued_at, None)?;
676    let new_refresh_token = issue_refresh_token(&mut state_guard, &client, &scope, issued_at)?;
677    state_guard.access_tokens.insert(access_token.clone());
678
679    Ok(Json(json!({
680        "token_type": "Bearer",
681        "expires_in": 3600,
682        "access_token": access_token,
683        "id_token": id_token,
684        "scope": scope_to_string(&scope),
685        "refresh_token": new_refresh_token,
686    })))
687}
688
689async fn handle_device_code(
690    state: SharedState,
691    credentials: ClientCredentials,
692    request: TokenRequest,
693) -> Result<Json<Value>, MockError> {
694    let device_code = request
695        .device_code
696        .as_ref()
697        .ok_or_else(|| MockError::invalid_request("missing device_code"))?;
698
699    let mut state_guard = state.write().await;
700    let mut entry = state_guard
701        .device_codes
702        .remove(device_code)
703        .ok_or_else(|| MockError::invalid_grant("invalid device_code"))?;
704
705    if entry.client_id != credentials.client_id {
706        state_guard.device_codes.insert(device_code.clone(), entry);
707        return Err(MockError::invalid_client("client mismatch for device_code"));
708    }
709
710    if OffsetDateTime::now_utc() > entry.expires_at {
711        entry.status = DeviceCodeStatus::Expired;
712    }
713
714    let result = match &mut entry.status {
715        DeviceCodeStatus::Pending { poll_count } => {
716            *poll_count += 1;
717            if *poll_count % 3 == 0 {
718                Err(MockError::slow_down())
719            } else {
720                Err(MockError::authorization_pending())
721            }
722        }
723        DeviceCodeStatus::Approved => {
724            let client = state_guard
725                .client(&credentials.client_id)
726                .cloned()
727                .ok_or_else(|| MockError::invalid_client("unknown client"))?;
728            let issued_at = OffsetDateTime::now_utc();
729            let scope = entry.scope.clone();
730            let access_token = issue_access_token(&state_guard, &client, &scope, issued_at)?;
731            let id_token = issue_id_token(&state_guard, &client, &scope, issued_at, None)?;
732            let refresh_token =
733                issue_refresh_token(&mut state_guard, &client, &entry.scope, issued_at)?;
734            state_guard.access_tokens.insert(access_token.clone());
735            entry.status = DeviceCodeStatus::Completed;
736            Ok(Json(json!({
737                "token_type": "Bearer",
738                "expires_in": 3600,
739                "access_token": access_token,
740                "id_token": id_token,
741                "scope": scope_to_string(&scope),
742                "refresh_token": refresh_token,
743            })))
744        }
745        DeviceCodeStatus::Denied => Err(MockError::access_denied()),
746        DeviceCodeStatus::Expired => Err(MockError::expired_token()),
747        DeviceCodeStatus::Completed => Err(MockError::invalid_grant("device_code already used")),
748    };
749
750    state_guard.device_codes.insert(device_code.clone(), entry);
751    result
752}
753
754#[derive(Debug, Deserialize)]
755struct DeviceAuthorizationRequest {
756    client_id: String,
757    scope: Option<String>,
758}
759
760async fn device_authorization(
761    State(state): State<SharedState>,
762    Form(request): Form<DeviceAuthorizationRequest>,
763) -> Result<Json<Value>, MockError> {
764    #[cfg(not(feature = "device_code"))]
765    {
766        let _ = state;
767        let _ = request;
768        return Err(MockError::invalid_request("device_code feature disabled"));
769    }
770
771    #[cfg(feature = "device_code")]
772    {
773        let mut state_guard = state.write().await;
774        let client = state_guard
775            .client(&request.client_id)
776            .cloned()
777            .ok_or_else(|| MockError::invalid_client("unknown client"))?;
778
779        let requested_scope = parse_scope(&request.scope)?;
780        let scope = if requested_scope.is_empty() {
781            client.allowed_scopes.clone()
782        } else {
783            requested_scope
784                .intersection(&client.allowed_scopes)
785                .cloned()
786                .collect()
787        };
788
789        let device_code: String = state_guard.generate_code();
790        let mut rng = rng();
791        let user_code: String = Alphanumeric
792            .sample_string(&mut rng, 8)
793            .chars()
794            .map(|ch| ch.to_ascii_uppercase())
795            .collect();
796
797        let entry = DeviceCodeEntry {
798            client_id: client.client_id.clone(),
799            scope: scope.clone(),
800            _device_code: device_code.clone(),
801            user_code: user_code.clone(),
802            expires_at: OffsetDateTime::now_utc() + TimeDuration::minutes(10),
803            _interval: 5,
804            status: DeviceCodeStatus::Pending { poll_count: 0 },
805        };
806        state_guard.device_codes.insert(device_code.clone(), entry);
807
808        Ok(Json(json!({
809            "device_code": device_code,
810            "user_code": user_code,
811            "verification_uri": format!("{}/device", state_guard.issuer),
812            "verification_uri_complete": format!("{}/device?user_code={}", state_guard.issuer, user_code),
813            "expires_in": 600,
814            "interval": 5,
815        })))
816    }
817}
818
819async fn userinfo(
820    State(state): State<SharedState>,
821    headers: HeaderMap,
822) -> Result<Json<Value>, MockError> {
823    let token = extract_bearer_token(&headers)?;
824
825    let state_guard = state.read().await;
826    if !state_guard.access_tokens.contains(token) {
827        return Err(MockError::invalid_token("unknown access token"));
828    }
829
830    let claims = json!({
831        "sub": state_guard.user.sub,
832        "email": state_guard.user.email,
833        "preferred_username": state_guard.user.preferred_username,
834        "groups": state_guard.user.groups,
835    });
836    Ok(Json(claims))
837}
838
839async fn introspect(
840    State(state): State<SharedState>,
841    headers: HeaderMap,
842    Form(body): Form<HashMap<String, String>>,
843) -> Result<Json<Value>, MockError> {
844    let _ = extract_client_credentials(
845        &headers,
846        &TokenRequest {
847            grant_type: "".into(),
848            code: None,
849            redirect_uri: None,
850            code_verifier: None,
851            refresh_token: None,
852            client_id: None,
853            client_secret: None,
854            device_code: None,
855            scope: None,
856        },
857    )?;
858
859    let token = body
860        .get("token")
861        .cloned()
862        .ok_or_else(|| MockError::invalid_request("missing token"))?;
863    let state_guard = state.read().await;
864    let active = state_guard.access_tokens.contains(&token)
865        || state_guard.refresh_tokens.contains_key(&token);
866
867    Ok(Json(json!({
868        "active": active,
869        "iss": state_guard.issuer,
870        "client_id": "mock-client",
871        "scope": scope_to_string(&DEFAULT_SCOPE),
872        "token_type": "Bearer"
873    })))
874}
875
876async fn revoke(
877    State(state): State<SharedState>,
878    headers: HeaderMap,
879    Form(body): Form<HashMap<String, String>>,
880) -> Result<StatusCode, MockError> {
881    let _ = extract_client_credentials(
882        &headers,
883        &TokenRequest {
884            grant_type: "".into(),
885            code: None,
886            redirect_uri: None,
887            code_verifier: None,
888            refresh_token: None,
889            client_id: None,
890            client_secret: None,
891            device_code: None,
892            scope: None,
893        },
894    )?;
895    let token = body
896        .get("token")
897        .cloned()
898        .ok_or_else(|| MockError::invalid_request("missing token"))?;
899    let mut state_guard = state.write().await;
900    state_guard.access_tokens.remove(&token);
901    state_guard.refresh_tokens.remove(&token);
902    Ok(StatusCode::OK)
903}
904
/// Client id and secret presented by a caller, from HTTP Basic auth or the form body.
#[derive(Debug, Clone)]
struct ClientCredentials {
    /// Claimed client id.
    client_id: String,
    /// Claimed client secret (not yet verified).
    client_secret: String,
}
910
911fn extract_client_credentials(
912    headers: &HeaderMap,
913    request: &TokenRequest,
914) -> Result<ClientCredentials, MockError> {
915    if let Some(header_value) = headers.get(header::AUTHORIZATION) {
916        let auth = header_value
917            .to_str()
918            .map_err(|_| MockError::invalid_client("invalid Authorization header"))?;
919        if let Some(encoded) = auth.strip_prefix("Basic ") {
920            let decoded = STANDARD
921                .decode(encoded)
922                .map_err(|_| MockError::invalid_client("invalid basic auth"))?;
923            let decoded = String::from_utf8(decoded)
924                .map_err(|_| MockError::invalid_client("invalid utf8 basic auth"))?;
925            if let Some((id, secret)) = decoded.split_once(':') {
926                return Ok(ClientCredentials {
927                    client_id: id.to_string(),
928                    client_secret: secret.to_string(),
929                });
930            }
931        }
932        return Err(MockError::invalid_client("invalid Authorization header"));
933    }
934
935    let client_id = request
936        .client_id
937        .clone()
938        .ok_or_else(|| MockError::invalid_client("missing client_id"))?;
939    let client_secret = request
940        .client_secret
941        .clone()
942        .ok_or_else(|| MockError::invalid_client("missing client_secret"))?;
943
944    Ok(ClientCredentials {
945        client_id,
946        client_secret,
947    })
948}
949
950fn extract_bearer_token(headers: &HeaderMap) -> Result<&str, MockError> {
951    let auth = headers
952        .get(header::AUTHORIZATION)
953        .and_then(|value| value.to_str().ok())
954        .ok_or_else(|| MockError::invalid_token("missing Authorization header"))?;
955    auth.strip_prefix("Bearer ")
956        .ok_or_else(|| MockError::invalid_token("invalid bearer token header"))
957}
958
959fn issue_access_token(
960    state: &InnerState,
961    client: &ClientConfig,
962    scope: &HashSet<String>,
963    issued_at: OffsetDateTime,
964) -> Result<String, MockError> {
965    #[derive(Debug, Serialize)]
966    struct AccessClaims<'a> {
967        iss: &'a str,
968        sub: &'a str,
969        aud: &'a str,
970        exp: i64,
971        iat: i64,
972        scope: String,
973        client_id: &'a str,
974        jti: String,
975    }
976
977    let claims = AccessClaims {
978        iss: &state.issuer,
979        sub: &state.user.sub,
980        aud: &client.client_id,
981        exp: (issued_at + TimeDuration::hours(1)).unix_timestamp(),
982        iat: issued_at.unix_timestamp(),
983        scope: scope_to_string(scope),
984        client_id: &client.client_id,
985        jti: Uuid::new_v4().to_string(),
986    };
987
988    jsonwebtoken::encode(
989        &jsonwebtoken::Header {
990            alg: jsonwebtoken::Algorithm::RS256,
991            kid: Some(state.signing.jwk.kid.clone()),
992            ..jsonwebtoken::Header::default()
993        },
994        &claims,
995        &state.signing.encoding,
996    )
997    .map_err(|err| MockError::server_error(format!("encode access token: {err}")))
998}
999
1000fn issue_id_token(
1001    state: &InnerState,
1002    client: &ClientConfig,
1003    scope: &HashSet<String>,
1004    issued_at: OffsetDateTime,
1005    nonce: Option<String>,
1006) -> Result<String, MockError> {
1007    #[derive(Debug, Serialize)]
1008    struct IdClaims<'a> {
1009        iss: &'a str,
1010        sub: &'a str,
1011        aud: &'a str,
1012        exp: i64,
1013        iat: i64,
1014        email: &'a str,
1015        preferred_username: &'a str,
1016        groups: &'a [String],
1017        scope: String,
1018        #[serde(skip_serializing_if = "Option::is_none")]
1019        nonce: Option<String>,
1020    }
1021
1022    let claims = IdClaims {
1023        iss: &state.issuer,
1024        sub: &state.user.sub,
1025        aud: &client.client_id,
1026        exp: (issued_at + TimeDuration::hours(1)).unix_timestamp(),
1027        iat: issued_at.unix_timestamp(),
1028        email: &state.user.email,
1029        preferred_username: &state.user.preferred_username,
1030        groups: &state.user.groups,
1031        scope: scope_to_string(scope),
1032        nonce,
1033    };
1034
1035    jsonwebtoken::encode(
1036        &jsonwebtoken::Header {
1037            alg: jsonwebtoken::Algorithm::RS256,
1038            kid: Some(state.signing.jwk.kid.clone()),
1039            ..jsonwebtoken::Header::default()
1040        },
1041        &claims,
1042        &state.signing.encoding,
1043    )
1044    .map_err(|err| MockError::server_error(format!("encode id token: {err}")))
1045}
1046
1047fn issue_refresh_token(
1048    state: &mut InnerState,
1049    client: &ClientConfig,
1050    scope: &HashSet<String>,
1051    issued_at: OffsetDateTime,
1052) -> Result<String, MockError> {
1053    let refresh_token = state.generate_code();
1054    state.refresh_tokens.insert(
1055        refresh_token.clone(),
1056        RefreshTokenEntry {
1057            client_id: client.client_id.clone(),
1058            scope: scope.clone(),
1059            _subject: state.user.clone(),
1060            _issued_at: issued_at,
1061        },
1062    );
1063    Ok(refresh_token)
1064}
1065
1066fn parse_scope(scope: &Option<String>) -> Result<HashSet<String>, MockError> {
1067    Ok(scope
1068        .as_ref()
1069        .map(|value| {
1070            value
1071                .split_whitespace()
1072                .filter(|part| !part.is_empty())
1073                .map(|part| part.to_string())
1074                .collect()
1075        })
1076        .unwrap_or_default())
1077}
1078
/// Renders a scope set as one space-delimited string.
///
/// Tokens are sorted, so the output is deterministic regardless of `HashSet`
/// iteration order. An empty set yields `""`.
fn scope_to_string(scope: &HashSet<String>) -> String {
    // Sort borrowed `&str`s instead of cloning every `String`. Set members are unique,
    // so an unstable sort produces exactly the same order as a stable one.
    let mut parts: Vec<&str> = scope.iter().map(String::as_str).collect();
    parts.sort_unstable();
    parts.join(" ")
}
1084
1085fn verify_code_challenge(verifier: &str, expected_challenge: &str) -> Result<bool, MockError> {
1086    use sha2::{Digest, Sha256};
1087    let hashed = Sha256::digest(verifier.as_bytes());
1088    let encoded = URL_SAFE_NO_PAD.encode(hashed);
1089    Ok(encoded == expected_challenge)
1090}
1091
1092fn generate_signing_keys() -> Result<SigningKeys> {
1093    use rsa::rand_core::OsRng;
1094    use rsa::traits::PublicKeyParts;
1095    use rsa::{RsaPrivateKey, pkcs1::EncodeRsaPrivateKey};
1096
1097    let mut rng = OsRng;
1098    let private_key = RsaPrivateKey::new(&mut rng, 2048).context("generate RSA key")?;
1099    let public_key = private_key.to_public_key();
1100
1101    let pem = private_key
1102        .to_pkcs1_pem(Default::default())
1103        .context("encode RSA key to PEM")?;
1104    let encoding =
1105        jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).context("create encoding key")?;
1106    let jwk = Jwk {
1107        kty: "RSA".into(),
1108        use_: "sig".into(),
1109        kid: Uuid::new_v4().to_string(),
1110        alg: "RS256".into(),
1111        n: URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be()),
1112        e: URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be()),
1113    };
1114
1115    Ok(SigningKeys { encoding, jwk })
1116}
1117
/// Errors produced by the mock OAuth / OpenID Connect endpoints.
///
/// Each variant's `Display` text begins with the OAuth error code it stands for.
/// The `IntoResponse` impl for this type turns each variant into an HTTP status plus
/// a JSON body with `error` and `error_description` fields.
#[derive(Debug, Error)]
enum MockError {
    /// Malformed or incomplete request (e.g. a missing parameter); rendered as 400.
    #[error("invalid_request: {0}")]
    InvalidRequest(String),
    /// Client authentication failed (bad `Authorization` header or missing client
    /// credentials); rendered as 401.
    #[error("invalid_client: {0}")]
    InvalidClient(String),
    /// The presented grant was rejected; rendered as 400.
    #[error("invalid_grant: {0}")]
    InvalidGrant(String),
    /// The requested scope was rejected; rendered as 400.
    #[error("invalid_scope: {0}")]
    InvalidScope(String),
    /// A bearer token was missing, malformed or not accepted; rendered as 401.
    #[error("invalid_token: {0}")]
    InvalidToken(String),
    /// The user denied the request; rendered as 400 with a fixed description.
    #[error("access_denied")]
    AccessDenied,
    /// Device-authorization polling code `authorization_pending`; rendered as 400.
    #[error("authorization_pending")]
    AuthorizationPending,
    /// Device-authorization polling code `slow_down`; rendered as 400.
    #[error("slow_down")]
    SlowDown,
    /// The device code has expired; rendered as 400.
    #[error("expired_token")]
    ExpiredToken,
    /// Internal failure (e.g. JWT encoding failed); rendered as 500.
    #[error("server_error: {0}")]
    ServerError(String),
}
1141
1142impl MockError {
1143    fn invalid_request<T: Into<String>>(msg: T) -> Self {
1144        Self::InvalidRequest(msg.into())
1145    }
1146    fn invalid_client<T: Into<String>>(msg: T) -> Self {
1147        Self::InvalidClient(msg.into())
1148    }
1149    fn invalid_grant<T: Into<String>>(msg: T) -> Self {
1150        Self::InvalidGrant(msg.into())
1151    }
1152    fn invalid_scope<T: Into<String>>(msg: T) -> Self {
1153        Self::InvalidScope(msg.into())
1154    }
1155    fn invalid_token<T: Into<String>>(msg: T) -> Self {
1156        Self::InvalidToken(msg.into())
1157    }
1158    fn server_error<T: Into<String>>(msg: T) -> Self {
1159        Self::ServerError(msg.into())
1160    }
1161    fn authorization_pending() -> Self {
1162        Self::AuthorizationPending
1163    }
1164    fn slow_down() -> Self {
1165        Self::SlowDown
1166    }
1167    fn access_denied() -> Self {
1168        Self::AccessDenied
1169    }
1170    fn expired_token() -> Self {
1171        Self::ExpiredToken
1172    }
1173}
1174
1175impl IntoResponse for MockError {
1176    fn into_response(self) -> Response {
1177        let (status, body) = match self {
1178            MockError::InvalidRequest(msg) => {
1179                (StatusCode::BAD_REQUEST, json_error("invalid_request", msg))
1180            }
1181            MockError::InvalidClient(msg) => {
1182                (StatusCode::UNAUTHORIZED, json_error("invalid_client", msg))
1183            }
1184            MockError::InvalidGrant(msg) => {
1185                (StatusCode::BAD_REQUEST, json_error("invalid_grant", msg))
1186            }
1187            MockError::InvalidScope(msg) => {
1188                (StatusCode::BAD_REQUEST, json_error("invalid_scope", msg))
1189            }
1190            MockError::InvalidToken(msg) => {
1191                (StatusCode::UNAUTHORIZED, json_error("invalid_token", msg))
1192            }
1193            MockError::AccessDenied => (
1194                StatusCode::BAD_REQUEST,
1195                json_error("access_denied", "user denied the request"),
1196            ),
1197            MockError::AuthorizationPending => (
1198                StatusCode::BAD_REQUEST,
1199                json_error("authorization_pending", "authorization pending"),
1200            ),
1201            MockError::SlowDown => (
1202                StatusCode::BAD_REQUEST,
1203                json_error("slow_down", "slow down"),
1204            ),
1205            MockError::ExpiredToken => (
1206                StatusCode::BAD_REQUEST,
1207                json_error("expired_token", "device code expired"),
1208            ),
1209            MockError::ServerError(msg) => (
1210                StatusCode::INTERNAL_SERVER_ERROR,
1211                json_error("server_error", msg),
1212            ),
1213        };
1214        (status, Json(body)).into_response()
1215    }
1216}
1217
1218fn json_error(code: impl Into<String>, description: impl Into<String>) -> Value {
1219    json!({
1220        "error": code.into(),
1221        "error_description": description.into(),
1222    })
1223}