Skip to main content

auth_framework/api/
oauth_advanced.rs

1//! OAuth 2.0 Advanced Features API Endpoints
2//!
3//! This module implements OAuth 2.0 advanced features:
4//! - RFC 7662: Token Introspection
5//! - RFC 9126: Pushed Authorization Requests  
6//! - RFC 8628: Device Authorization Grant
7//! - OpenID Connect CIBA (Client Initiated Backchannel Authentication)
8
9use crate::api::ApiState;
10use axum::{
11    Form,
12    extract::State,
13    http::{HeaderMap, StatusCode},
14    response::Json,
15};
16use base64::{Engine as _, engine::general_purpose};
17use serde::{Deserialize, Serialize};
18use serde_json::{Value as JsonValue, json};
19use std::collections::HashMap;
20use tracing::{debug, error};
21use url::Url;
22use uuid::Uuid;
23
24/// Token introspection request (RFC 7662)
25#[derive(Debug, Deserialize)]
26pub struct IntrospectRequest {
27    /// The token to introspect
28    pub token: String,
29
30    /// Optional hint about the token type
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub token_type_hint: Option<String>,
33
34    /// Client ID (if using POST body authentication)
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub client_id: Option<String>,
37
38    /// Client secret (if using POST body authentication)
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub client_secret: Option<String>,
41}
42
43/// Token introspection response (RFC 7662)
44#[derive(Debug, Serialize)]
45pub struct IntrospectResponse {
46    /// Indicates if the token is currently active
47    pub active: bool,
48
49    /// The subject of the token
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub sub: Option<String>,
52
53    /// The client_id associated with the token
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub client_id: Option<String>,
56
57    /// The scopes associated with the token
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub scope: Option<String>,
60
61    /// Token expiration timestamp
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub exp: Option<i64>,
64
65    /// Token issued at timestamp
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub iat: Option<i64>,
68
69    /// Not before timestamp
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub nbf: Option<i64>,
72
73    /// Issuer of the token
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub iss: Option<String>,
76
77    /// Audience of the token
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub aud: Option<JsonValue>,
80
81    /// JWT ID
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub jti: Option<String>,
84
85    /// Token type (e.g., "Bearer")
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub token_type: Option<String>,
88
89    /// Username associated with the token
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub username: Option<String>,
92}
93
94/// Pushed Authorization Request parameters (RFC 9126)
95///
96/// Required fields are non-optional so Axum Form validation returns 422
97/// automatically when they are missing from the request body.
98#[derive(Debug, Deserialize)]
99pub struct PARRequest {
100    /// Required: OAuth 2.0 response type (e.g., "code")
101    pub response_type: String,
102
103    /// Required: The client identifier
104    pub client_id: String,
105
106    /// Required: Redirection URI for the authorization response
107    pub redirect_uri: String,
108
109    /// Optional: Requested scope(s)
110    pub scope: Option<String>,
111
112    /// Optional: Opaque state value for the client
113    pub state: Option<String>,
114
115    /// Optional: Nonce for OIDC requests
116    pub nonce: Option<String>,
117
118    /// Optional: PKCE code challenge
119    pub code_challenge: Option<String>,
120
121    /// Optional: PKCE code challenge method (e.g., "S256")
122    pub code_challenge_method: Option<String>,
123}
124
125/// Pushed Authorization Request response (RFC 9126)
126#[derive(Debug, Serialize)]
127pub struct PARResponse {
128    /// The request URI for the authorization request
129    pub request_uri: String,
130
131    /// Expiration time in seconds
132    pub expires_in: u64,
133}
134
135// ============================================================================
136// Endpoint Handlers
137// ============================================================================
138
139/// Verify OAuth2 client credentials against the registered client record in storage.
140///
141/// Returns `Ok(true)` if the credentials are valid, `Ok(false)` if the client_id
142/// is unknown or the secret does not match, and `Err(...)` on storage failures.
143/// The function uses constant-time comparison to prevent timing oracle attacks on
144/// the client_secret.
145async fn verify_client_credentials(
146    state: &State<ApiState>,
147    client_id: &str,
148    client_secret: &str,
149) -> Result<bool, (StatusCode, Json<JsonValue>)> {
150    let client_key = format!("oauth2_client:{}", client_id);
151    let client_data = match state.auth_framework.storage().get_kv(&client_key).await {
152        Ok(Some(bytes)) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
153            Ok(v) => v,
154            Err(_) => {
155                error!(
156                    "Introspect: failed to deserialize client record for {}",
157                    client_id
158                );
159                return Err((
160                    StatusCode::INTERNAL_SERVER_ERROR,
161                    Json(json!({
162                        "error": "server_error",
163                        "error_description": "Internal server error"
164                    })),
165                ));
166            }
167        },
168        Ok(None) => {
169            // Unknown client — return false (don't reveal whether the client exists)
170            return Ok(false);
171        }
172        Err(e) => {
173            error!(
174                "Introspect: storage error looking up client {}: {}",
175                client_id, e
176            );
177            return Err((
178                StatusCode::INTERNAL_SERVER_ERROR,
179                Json(json!({
180                    "error": "server_error",
181                    "error_description": "Internal server error"
182                })),
183            ));
184        }
185    };
186
187    let stored_secret = client_data["client_secret"].as_str().unwrap_or("");
188    // Constant-time comparison prevents timing oracle on the secret
189    Ok(
190        crate::security::timing_protection::constant_time_string_compare(
191            client_secret,
192            stored_secret,
193        ),
194    )
195}
196
197/// Token introspection endpoint (RFC 7662)
198///
199/// Allows authorized clients to determine the active state and meta-information
200/// about a given token.
201///
202/// Authentication is required via either:
203/// - HTTP Basic Auth header (`Authorization: Basic <base64(client_id:client_secret)>`)
204/// - POST body parameters (`client_id` + `client_secret`)
205///
206/// Bearer token authentication is explicitly rejected per RFC 7662 §2.1.
207pub async fn introspect_token(
208    state: State<ApiState>,
209    headers: HeaderMap,
210    form: Form<IntrospectRequest>,
211) -> Result<Json<IntrospectResponse>, (StatusCode, Json<JsonValue>)> {
212    debug!("Processing token introspection request");
213
214    // --- Authentication enforcement (RFC 7662 §2.1) ---
215    let auth_header = headers.get(axum::http::header::AUTHORIZATION);
216
217    let authenticated = match auth_header {
218        Some(value) => {
219            let value_str = value.to_str().unwrap_or("");
220            if value_str.starts_with("Bearer ") {
221                // Bearer tokens are not a valid authentication method for introspection
222                debug!("Introspect rejected: Bearer auth is not allowed");
223                return Err((
224                    StatusCode::UNAUTHORIZED,
225                    Json(json!({
226                        "error": "invalid_client",
227                        "error_description": "Bearer token authentication is not supported for token introspection"
228                    })),
229                ));
230            } else if let Some(encoded) = value_str.strip_prefix("Basic ") {
231                // Decode Basic auth credentials and verify against registered client.
232                match general_purpose::STANDARD.decode(encoded) {
233                    Ok(decoded_bytes) => {
234                        let decoded = String::from_utf8_lossy(&decoded_bytes);
235                        let mut parts = decoded.splitn(2, ':');
236                        let basic_client_id = parts.next().unwrap_or("").to_string();
237                        let basic_client_secret = parts.next().unwrap_or("").to_string();
238                        verify_client_credentials(&state, &basic_client_id, &basic_client_secret)
239                            .await?
240                    }
241                    Err(_) => {
242                        debug!("Introspect rejected: invalid Basic auth encoding");
243                        return Err((
244                            StatusCode::UNAUTHORIZED,
245                            Json(json!({
246                                "error": "invalid_client",
247                                "error_description": "Invalid Basic authentication encoding"
248                            })),
249                        ));
250                    }
251                }
252            } else {
253                // Unrecognised auth scheme
254                debug!("Introspect rejected: unknown auth scheme");
255                return Err((
256                    StatusCode::UNAUTHORIZED,
257                    Json(json!({
258                        "error": "invalid_client",
259                        "error_description": "Unsupported authentication scheme"
260                    })),
261                ));
262            }
263        }
264        None => {
265            // No Authorization header — verify using POST body client_id + client_secret.
266            match (&form.client_id, &form.client_secret) {
267                (Some(id), Some(secret)) => verify_client_credentials(&state, id, secret).await?,
268                _ => {
269                    debug!("Introspect rejected: missing client credentials");
270                    return Err((
271                        StatusCode::UNAUTHORIZED,
272                        Json(json!({
273                            "error": "invalid_client",
274                            "error_description": "client_id and client_secret are required"
275                        })),
276                    ));
277                }
278            }
279        }
280    };
281
282    if !authenticated {
283        debug!("Introspect rejected: invalid client credentials");
284        return Err((
285            StatusCode::UNAUTHORIZED,
286            Json(json!({
287                "error": "invalid_client",
288                "error_description": "Client authentication failed"
289            })),
290        ));
291    }
292
293    // --- Token validation ---
294    let token_manager = state.auth_framework.token_manager();
295
296    match token_manager.validate_jwt_token(&form.token) {
297        Ok(claims) => {
298            // Cross-check the revocation list before reporting the token as active.
299            // Tokens revoked via POST /oauth/revoke or POST /auth/logout are stored
300            // under revoked_token:{jti}; return active=false without revealing why.
301            let revocation_key = format!("revoked_token:{}", claims.jti);
302            if let Ok(Some(_)) = state.auth_framework.storage().get_kv(&revocation_key).await {
303                debug!(
304                    "Token introspection: token has been revoked (jti: {})",
305                    claims.jti
306                );
307                return Ok(Json(IntrospectResponse {
308                    active: false,
309                    sub: None,
310                    client_id: None,
311                    scope: None,
312                    exp: None,
313                    iat: None,
314                    nbf: None,
315                    iss: None,
316                    aud: None,
317                    jti: None,
318                    token_type: None,
319                    username: None,
320                }));
321            }
322
323            debug_assert!(!claims.sub.is_empty(), "Token subject should not be empty");
324            Ok(Json(IntrospectResponse {
325                active: true,
326                sub: Some(claims.sub.clone()),
327                client_id: claims.client_id.clone(),
328                scope: Some(claims.scope.clone()),
329                exp: Some(claims.exp),
330                iat: Some(claims.iat),
331                nbf: Some(claims.nbf),
332                iss: Some(claims.iss.clone()),
333                aud: Some(JsonValue::String(claims.aud.clone())),
334                jti: Some(claims.jti.clone()),
335                token_type: Some("Bearer".to_string()),
336                username: Some(claims.sub),
337            }))
338        }
339        Err(_e) => {
340            debug!("Token introspection: token is inactive");
341            // Return inactive token (per RFC 7662, don't reveal why it's invalid)
342            Ok(Json(IntrospectResponse {
343                active: false,
344                sub: None,
345                client_id: None,
346                scope: None,
347                exp: None,
348                iat: None,
349                nbf: None,
350                iss: None,
351                aud: None,
352                jti: None,
353                token_type: None,
354                username: None,
355            }))
356        }
357    }
358}
359
360/// Pushed Authorization Request endpoint (RFC 9126)
361///
362/// Clients push authorization request parameters to the server and receive a
363/// `request_uri` they can use at the authorization endpoint. The URI is unique
364/// per request and expires after 90 seconds (RFC 9126 §2.2).
365///
366/// Required form fields: `response_type`, `client_id`, `redirect_uri`
367/// Missing required fields are automatically rejected with 422 by Axum.
368pub async fn pushed_authorization_request(
369    State(state): State<ApiState>,
370    Form(req): Form<PARRequest>,
371) -> (StatusCode, Json<PARResponse>) {
372    debug!("Processing PAR request for client_id={}", req.client_id);
373
374    // Validate redirect_uri is a well-formed URL before accepting it
375    if Url::parse(&req.redirect_uri).is_err() {
376        return (
377            StatusCode::BAD_REQUEST,
378            Json(PARResponse {
379                request_uri: String::new(),
380                expires_in: 0,
381            }),
382        );
383    }
384
385    // Generate a unique request URI per RFC 9126 §2.2
386    let request_id = Uuid::new_v4().to_string();
387    let request_uri = format!("urn:ietf:params:oauth:request_uri:{}", request_id);
388
389    // Persist the authorization request parameters so the /authorize endpoint
390    // can retrieve them when presented with this request_uri (RFC 9126 §4).
391    let par_data = json!({
392        "response_type": req.response_type,
393        "client_id": req.client_id,
394        "redirect_uri": req.redirect_uri,
395        "scope": req.scope,
396        "state": req.state,
397        "nonce": req.nonce,
398        "code_challenge": req.code_challenge,
399        "code_challenge_method": req.code_challenge_method,
400    });
401    let storage_key = format!("par_request:{}", request_id);
402    if let Err(e) = state
403        .auth_framework
404        .storage()
405        .store_kv(
406            &storage_key,
407            par_data.to_string().as_bytes(),
408            Some(std::time::Duration::from_secs(90)),
409        )
410        .await
411    {
412        error!("Failed to store PAR request: {}", e);
413        // The request_uri would be unresolvable; returning a 201 with a URI
414        // that will never resolve is worse than surfacing the error early.
415        return (
416            StatusCode::INTERNAL_SERVER_ERROR,
417            Json(PARResponse {
418                request_uri: String::new(),
419                expires_in: 0,
420            }),
421        );
422    }
423
424    (
425        StatusCode::CREATED,
426        Json(PARResponse {
427            request_uri,
428            expires_in: 90,
429        }),
430    )
431}
432
433/// Device Authorization endpoint (RFC 8628 §3.1)
434///
435/// Initiates a device-authorization flow for input-constrained devices.
436/// Returns a `device_code`, human-friendly `user_code`, `verification_uri`,
437/// `expires_in`, and a polling `interval`.
438pub async fn device_authorization(
439    State(state): State<ApiState>,
440    Form(form): Form<HashMap<String, String>>,
441) -> Result<Json<JsonValue>, (StatusCode, Json<JsonValue>)> {
442    // RFC 8628 §3.1 – client_id is required
443    if form.get("client_id").map(|s| s.is_empty()).unwrap_or(true) {
444        return Err((
445            StatusCode::BAD_REQUEST,
446            Json(json!({
447                "error": "invalid_request",
448                "error_description": "client_id is required"
449            })),
450        ));
451    }
452
453    // Generate a high-entropy device_code and a human-friendly user_code.
454    let device_code = format!("dc_{}", Uuid::new_v4().simple());
455    let user_code = generate_user_code();
456    let verification_uri = "/device";
457    let expires_in: u64 = 600; // RFC 8628 recommends ≥600 s
458
459    // Persist the device authorization request so the token endpoint can poll.
460    let device_data = json!({
461        "client_id": form.get("client_id").cloned().unwrap_or_default(),
462        "scope":     form.get("scope").cloned().unwrap_or_default(),
463        "user_code": user_code,
464        "authorized": false
465    });
466    state
467        .auth_framework
468        .storage()
469        .store_kv(
470            &format!("device:{}", device_code),
471            device_data.to_string().as_bytes(),
472            Some(std::time::Duration::from_secs(expires_in)),
473        )
474        .await
475        .map_err(|e| {
476            error!("Failed to store device authorization request: {}", e);
477            (
478                StatusCode::INTERNAL_SERVER_ERROR,
479                Json(json!({
480                    "error": "server_error",
481                    "error_description": "Failed to initiate device authorization flow"
482                })),
483            )
484        })?;
485
486    debug!(
487        "Device authorization initiated for client_id={}",
488        form.get("client_id")
489            .map(String::as_str)
490            .unwrap_or_default()
491    );
492
493    Ok(Json(json!({
494        "device_code":              device_code,
495        "user_code":               user_code,
496        "verification_uri":        verification_uri,
497        "verification_uri_complete": format!("{}?user_code={}", verification_uri, user_code),
498        "expires_in":              expires_in,
499        "interval":                5
500    })))
501}
502
503/// CIBA (Client Initiated Backchannel Authentication) endpoint (OpenID Connect CIBA Core §7.1)
504///
505/// Initiates a backchannel authentication request.  Exactly one of
506/// `login_hint`, `login_hint_token`, or `id_token_hint` must be present.
507/// Returns an `auth_req_id` that the client polls at the token endpoint.
508pub async fn ciba_backchannel_auth(
509    State(state): State<ApiState>,
510    Form(form): Form<HashMap<String, String>>,
511) -> Result<Json<JsonValue>, (StatusCode, Json<JsonValue>)> {
512    // CIBA Core §7.1 – exactly one user-identification hint is required.
513    let login_hint = form
514        .get("login_hint")
515        .or_else(|| form.get("login_hint_token"))
516        .or_else(|| form.get("id_token_hint"))
517        .cloned()
518        .ok_or_else(|| {
519            (
520                StatusCode::BAD_REQUEST,
521                Json(json!({
522                    "error": "invalid_request",
523                    "error_description":
524                        "One of login_hint, login_hint_token, or id_token_hint is required"
525                })),
526            )
527        })?;
528
529    let auth_req_id = Uuid::new_v4().to_string();
530    let expires_in: u64 = 120; // 2-minute window for the authenticating device
531
532    // Persist the CIBA request so the token-endpoint poll can resolve it.
533    let ciba_data = json!({
534        "login_hint":      login_hint,
535        "client_id":       form.get("client_id").cloned().unwrap_or_default(),
536        "scope":           form.get("scope").cloned().unwrap_or_default(),
537        "binding_message": form.get("binding_message").cloned(),
538        "status":          "pending"
539    });
540    state
541        .auth_framework
542        .storage()
543        .store_kv(
544            &format!("ciba:{}", auth_req_id),
545            ciba_data.to_string().as_bytes(),
546            Some(std::time::Duration::from_secs(expires_in)),
547        )
548        .await
549        .map_err(|e| {
550            error!("Failed to store CIBA request: {}", e);
551            (
552                StatusCode::INTERNAL_SERVER_ERROR,
553                Json(json!({
554                    "error": "server_error",
555                    "error_description": "Failed to initiate backchannel authentication"
556                })),
557            )
558        })?;
559
560    debug!("CIBA request created: auth_req_id={}", auth_req_id);
561
562    Ok(Json(json!({
563        "auth_req_id": auth_req_id,
564        "expires_in":  expires_in,
565        "interval":    5
566    })))
567}
568
569// ---------------------------------------------------------------------------
570// Helpers
571// ---------------------------------------------------------------------------
572
573/// Generate an 8-character, human-friendly user code (RFC 8628 §6.1 charset).
574/// Uses only unambiguous uppercase letters and digits (no O/0, I/1).
575fn generate_user_code() -> String {
576    use rand::RngExt;
577    const CHARS: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789";
578    let mut rng = rand::rng();
579    (0..8)
580        .map(|_| CHARS[rng.random_range(0..CHARS.len())] as char)
581        .collect()
582}