atproto_oauth/
workflow.rs

1//! OAuth workflow implementation for AT Protocol authorization flows.
2//!
3//! This module provides a complete OAuth 2.0 authorization code flow implementation specifically
4//! designed for AT Protocol, including Pushed Authorization Requests (PAR), DPoP security, PKCE
5//! protection, and client assertion handling.
6//!
7//! ## OAuth Flow Components
8//!
9//! - **`oauth_init()`**: Initiates OAuth flow with PAR request to authorization server
//! - **`oauth_complete()`**: Completes OAuth flow by exchanging authorization code for tokens
//! - **`oauth_refresh()`**: Exchanges a refresh token for a new set of tokens
11//! - **`OAuthClient`**: Client configuration with credentials and signing keys
12//! - **`OAuthRequest`**: Tracking structure for ongoing authorization requests
13//! - **`OAuthRequestState`**: Security parameters including state, nonce, and PKCE challenge
14//!
15//! ## Security Features
16//!
17//! - **PKCE**: Proof Key for Code Exchange protection against authorization code interception
18//! - **DPoP**: Demonstration of Proof-of-Possession for token binding
19//! - **Client Assertions**: JWT-based client authentication using private key signatures
20//! - **State Parameters**: CSRF protection using random state values
21//! - **Nonce Values**: Additional replay protection
22//!
23//! ## Example Usage
24//!
25//! ```rust,ignore
26//! use atproto_oauth::workflow::{oauth_init, oauth_complete, OAuthClient, OAuthRequestState};
27//! use atproto_oauth::pkce::generate;
//! use atproto_identity::key::{generate_key, KeyType};
29//!
30//! // Generate security parameters
31//! let signing_key = generate_key(KeyType::P256Private)?;
32//! let dpop_key = generate_key(KeyType::P256Private)?;
33//! let (pkce_verifier, code_challenge) = generate();
34//!
35//! // Configure OAuth client
36//! let oauth_client = OAuthClient {
37//!     redirect_uri: "https://app.example.com/callback".to_string(),
38//!     client_id: "https://app.example.com/client-metadata.json".to_string(),
39//!     private_signing_key_data: signing_key,
40//! };
41//!
42//! // Create request state
43//! let oauth_state = OAuthRequestState {
44//!     state: "random-state-value".to_string(),
45//!     nonce: "random-nonce-value".to_string(),
46//!     code_challenge,
47//!     scope: "atproto transition:generic".to_string(),
48//! };
49//!
50//! // Initiate OAuth flow
51//! let par_response = oauth_init(
52//!     &http_client,
53//!     &oauth_client,
54//!     &dpop_key,
//!     Some("user.bsky.social"),
56//!     &authorization_server,
57//!     &oauth_state,
58//! ).await?;
59//!
60//! // Build authorization URL
61//! let auth_url = format!(
62//!     "{}?client_id={}&request_uri={}",
63//!     authorization_server.authorization_endpoint,
64//!     oauth_client.client_id,
65//!     par_response.request_uri
66//! );
67//!
68//! // After user authorization and callback...
69//! let token_response = oauth_complete(
70//!     &http_client,
71//!     &oauth_client,
72//!     &dpop_key,
73//!     "authorization_code_from_callback",
74//!     &oauth_request,
//!     &authorization_server,
76//! ).await?;
77//! ```
78
79use atproto_identity::key::KeyData;
80use chrono::{DateTime, Utc};
81use rand::distributions::{Alphanumeric, DistString};
82use reqwest_chain::ChainMiddleware;
83use reqwest_middleware::ClientBuilder;
84use serde::Deserialize;
85use std::collections::HashMap;
86
87#[cfg(feature = "zeroize")]
88use zeroize::{Zeroize, ZeroizeOnDrop};
89
90use crate::{
91    dpop::{DpopRetry, auth_dpop},
92    errors::OAuthClientError,
93    jwt::{Claims, Header, JoseClaims, mint},
94    resources::{AuthorizationServer, pds_resources},
95};
96
97/// Response from a Pushed Authorization Request (PAR) endpoint.
98///
99/// Contains the request URI and expiration time returned by the authorization
100/// server after successfully processing a pushed authorization request.
101#[derive(Clone, Deserialize)]
102pub struct ParResponse {
103    /// The request URI to use in the authorization request.
104    pub request_uri: String,
105    /// The lifetime of the request URI in seconds.
106    pub expires_in: u64,
107
108    /// Additional fields returned by the authorization server.
109    #[serde(flatten)]
110    pub extra: HashMap<String, serde_json::Value>,
111}
112
/// OAuth request state containing security parameters for the authorization flow.
///
/// [`oauth_init`] sends `state`, `code_challenge` (with method `S256`) and
/// `scope` in the pushed authorization request. `nonce` is carried for the
/// caller's own use and is not read by `oauth_init`.
///
/// Only the PKCE *challenge* lives here. The verifier it was derived from is
/// tracked separately in [`OAuthRequest::pkce_verifier`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OAuthRequestState {
    /// Random state parameter to prevent CSRF attacks.
    pub state: String,
    /// Random nonce value for additional security.
    pub nonce: String,
    /// PKCE code challenge derived from the code verifier.
    pub code_challenge: String,
    /// The scope of access requested for the authorization.
    pub scope: String,
}
127
128/// OAuth client configuration containing essential client credentials.
129///
130/// This struct holds the client configuration needed for OAuth authorization flows,
131/// including the redirect URI, client identifier, and signing key.
132#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
133pub struct OAuthClient {
134    /// The redirect URI where the authorization server will send the user after authorization.
135    #[cfg_attr(feature = "zeroize", zeroize(skip))]
136    pub redirect_uri: String,
137
138    /// The unique client identifier for this OAuth client.
139    #[cfg_attr(feature = "zeroize", zeroize(skip))]
140    pub client_id: String,
141
142    /// The private key data used for signing client assertions.
143    pub private_signing_key_data: KeyData,
144}
145
146/// OAuth request tracking information for ongoing authorization flows.
147///
148/// This struct contains all the necessary information to track and complete
149/// an OAuth authorization request, including security parameters and timing.
150#[derive(Clone, PartialEq)]
151#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
152pub struct OAuthRequest {
153    /// The OAuth state parameter used to prevent CSRF attacks.
154    #[cfg_attr(feature = "zeroize", zeroize(skip))]
155    pub oauth_state: String,
156
157    /// The authorization server issuer identifier.
158    #[cfg_attr(feature = "zeroize", zeroize(skip))]
159    pub issuer: String,
160
161    /// The authorization server identifier.
162    #[cfg_attr(feature = "zeroize", zeroize(skip))]
163    pub authorization_server: String,
164
165    /// The nonce value for additional security.
166    #[cfg_attr(feature = "zeroize", zeroize(skip))]
167    pub nonce: String,
168
169    /// The PKCE code verifier for this authorization request.
170    #[cfg_attr(feature = "zeroize", zeroize(skip))]
171    pub pkce_verifier: String,
172
173    /// The public key used for signing (serialized).
174    pub signing_public_key: String,
175
176    /// The DPoP private key (serialized).
177    #[cfg_attr(feature = "zeroize", zeroize(skip))]
178    pub dpop_private_key: String,
179
180    /// When this OAuth request was created.
181    #[cfg_attr(feature = "zeroize", zeroize(skip))]
182    pub created_at: DateTime<Utc>,
183
184    /// When this OAuth request expires.
185    #[cfg_attr(feature = "zeroize", zeroize(skip))]
186    pub expires_at: DateTime<Utc>,
187}
188
189impl std::fmt::Debug for OAuthRequest {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        f.debug_struct("OAuthRequest")
192            .field("oauth_state", &self.oauth_state)
193            .field("issuer", &self.issuer)
194            .field("authorization_server", &self.authorization_server)
195            .field("nonce", &self.nonce)
196            .field("pkce_verifier", &"[REDACTED]")
197            .field("signing_public_key", &self.signing_public_key)
198            .field("dpop_private_key", &"[REDACTED]")
199            .field("created_at", &self.created_at)
200            .field("expires_at", &self.expires_at)
201            .finish()
202    }
203}
204
205/// Response from the OAuth token endpoint containing access credentials.
206///
207/// This struct represents the successful response from an OAuth token exchange,
208/// containing the access token and related metadata.
209#[derive(Clone, Deserialize)]
210#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
211pub struct TokenResponse {
212    /// The access token that can be used to access protected resources.
213    pub access_token: String,
214
215    /// The type of token, typically "Bearer" or "DPoP".
216    pub token_type: String,
217
218    /// The refresh token that can be used to obtain new access tokens.
219    pub refresh_token: String,
220
221    /// The scope of access granted by the access token.
222    #[cfg_attr(feature = "zeroize", zeroize(skip))]
223    pub scope: String,
224
225    /// The lifetime of the access token in seconds.
226    #[cfg_attr(feature = "zeroize", zeroize(skip))]
227    pub expires_in: u32,
228
229    /// The subject identifier (usually the user's DID).
230    #[cfg_attr(feature = "zeroize", zeroize(skip))]
231    pub sub: Option<String>,
232
233    /// Additional fields returned by the authorization server.
234    #[serde(flatten)]
235    #[cfg_attr(feature = "zeroize", zeroize(skip))]
236    pub extra: HashMap<String, serde_json::Value>,
237}
238
239/// Initiates the OAuth authorization flow by making a Pushed Authorization Request (PAR).
240///
241/// This function creates a PAR request to the authorization server with the necessary
242/// OAuth parameters, DPoP proof, and client assertion. It handles the complete setup
243/// for the AT Protocol OAuth flow including PKCE and DPoP security mechanisms.
244///
245/// # Arguments
246/// * `http_client` - The HTTP client to use for making requests
247/// * `private_signing_key_data` - The private key for signing client assertions
248/// * `dpop_key_data` - The key data for creating DPoP proofs
249/// * `handle` - The user's handle for the login hint
250/// * `authorization_server` - The authorization server configuration
251/// * `oauth_request_state` - The OAuth state parameters for this request
252///
253/// # Returns
254/// A `ParResponse` containing the request URI and expiration time on success.
255///
256/// # Errors
257/// Returns `OAuthClientError` if the PAR request fails or response parsing fails.
258pub async fn oauth_init(
259    http_client: &reqwest::Client,
260    oauth_client: &OAuthClient,
261    dpop_key_data: &KeyData,
262    login_hint: Option<&str>,
263    authorization_server: &AuthorizationServer,
264    oauth_request_state: &OAuthRequestState,
265) -> Result<ParResponse, OAuthClientError> {
266    let par_url = authorization_server
267        .pushed_authorization_request_endpoint
268        .clone();
269
270    let scope = &oauth_request_state.scope;
271
272    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
273        .try_into()
274        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
275
276    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
277    let client_assertion_claims = Claims::new(JoseClaims {
278        issuer: Some(oauth_client.client_id.clone()),
279        subject: Some(oauth_client.client_id.clone()),
280        audience: Some(authorization_server.issuer.clone()),
281        json_web_token_id: Some(client_assertion_jti),
282        issued_at: Some(chrono::Utc::now().timestamp() as u64),
283        ..Default::default()
284    });
285
286    let client_assertion_token = mint(
287        &oauth_client.private_signing_key_data,
288        &client_assertion_header,
289        &client_assertion_claims,
290    )
291    .map_err(OAuthClientError::MintTokenFailed)?;
292
293    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &par_url)
294        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
295
296    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
297
298    let dpop_retry_client = ClientBuilder::new(http_client.clone())
299        .with(ChainMiddleware::new(dpop_retry.clone()))
300        .build();
301
302    let mut params = vec![
303        ("response_type", "code"),
304        ("code_challenge", &oauth_request_state.code_challenge),
305        ("code_challenge_method", "S256"),
306        ("client_id", oauth_client.client_id.as_str()),
307        ("state", oauth_request_state.state.as_str()),
308        ("redirect_uri", oauth_client.redirect_uri.as_str()),
309        ("scope", scope),
310        (
311            "client_assertion_type",
312            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
313        ),
314        ("client_assertion", client_assertion_token.as_str()),
315    ];
316    if let Some(value) = login_hint {
317        params.push(("login_hint", value));
318    }
319
320    let response = dpop_retry_client
321        .post(par_url)
322        .header("DPoP", dpop_token.as_str())
323        .form(&params)
324        .send()
325        .await
326        .map_err(OAuthClientError::PARHttpRequestFailed)?
327        .json()
328        .await
329        .map_err(OAuthClientError::PARResponseJsonParsingFailed)?;
330
331    Ok(response)
332}
333
334/// Completes the OAuth authorization flow by exchanging the authorization code for tokens.
335///
336/// This function performs the final step of the OAuth authorization code flow by
337/// exchanging the authorization code received from the callback for access and refresh tokens.
338/// It handles DPoP proof generation and client assertion creation for secure token exchange.
339///
340/// # Arguments
341/// * `http_client` - The HTTP client to use for making requests
342/// * `oauth_client` - The OAuth client configuration
343/// * `dpop_key_data` - The key data for creating DPoP proofs
344/// * `callback_code` - The authorization code received from the callback
345/// * `oauth_request` - The original OAuth request state
346/// * `document` - The identity document containing PDS endpoints
347///
348/// # Returns
349/// A `TokenResponse` containing the access token, refresh token, and metadata on success.
350///
351/// # Errors
352/// Returns `OAuthClientError` if the token exchange fails or response parsing fails.
353pub async fn oauth_complete(
354    http_client: &reqwest::Client,
355    oauth_client: &OAuthClient,
356    dpop_key_data: &KeyData,
357    callback_code: &str,
358    oauth_request: &OAuthRequest,
359    authorization_server: &AuthorizationServer,
360) -> Result<TokenResponse, OAuthClientError> {
361    // let pds_endpoints = document.pds_endpoints();
362    // let pds_endpoint = pds_endpoints
363    //     .first()
364    //     .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
365    // let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
366
367    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
368        .try_into()
369        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
370
371    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
372    let client_assertion_claims = Claims::new(JoseClaims {
373        issuer: Some(oauth_client.client_id.clone()),
374        subject: Some(oauth_client.client_id.clone()),
375        audience: Some(authorization_server.issuer.clone()),
376        json_web_token_id: Some(client_assertion_jti),
377        issued_at: Some(chrono::Utc::now().timestamp() as u64),
378        ..Default::default()
379    });
380
381    let client_assertion_token = mint(
382        &oauth_client.private_signing_key_data,
383        &client_assertion_header,
384        &client_assertion_claims,
385    )
386    .map_err(OAuthClientError::MintTokenFailed)?;
387
388    let params = [
389        ("client_id", oauth_client.client_id.as_str()),
390        ("redirect_uri", oauth_client.redirect_uri.as_str()),
391        ("grant_type", "authorization_code"),
392        ("code", callback_code),
393        ("code_verifier", &oauth_request.pkce_verifier),
394        (
395            "client_assertion_type",
396            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
397        ),
398        ("client_assertion", client_assertion_token.as_str()),
399    ];
400
401    let token_endpoint = authorization_server.token_endpoint.clone();
402
403    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
404        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
405
406    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
407
408    let dpop_retry_client = ClientBuilder::new(http_client.clone())
409        .with(ChainMiddleware::new(dpop_retry.clone()))
410        .build();
411
412    dpop_retry_client
413        .post(token_endpoint)
414        .header("DPoP", dpop_token.as_str())
415        .form(&params)
416        .send()
417        .await
418        .map_err(OAuthClientError::TokenHttpRequestFailed)?
419        .json()
420        .await
421        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
422}
423
424/// Refreshes OAuth access tokens using a refresh token.
425///
426/// This function exchanges a refresh token for new access and refresh tokens.
427/// It handles DPoP proof generation and client assertion creation for secure
428/// token refresh operations according to AT Protocol OAuth requirements.
429///
430/// # Arguments
431/// * `http_client` - The HTTP client to use for making requests
432/// * `oauth_client` - The OAuth client configuration
433/// * `dpop_key_data` - The key data for creating DPoP proofs
434/// * `refresh_token` - The refresh token to exchange for new tokens
435/// * `document` - The identity document containing PDS endpoints
436///
437/// # Returns
438/// A `TokenResponse` containing the new access token, refresh token, and metadata on success.
439///
440/// # Errors
441/// Returns `OAuthClientError` if the token refresh fails or response parsing fails.
442pub async fn oauth_refresh(
443    http_client: &reqwest::Client,
444    oauth_client: &OAuthClient,
445    dpop_key_data: &KeyData,
446    refresh_token: &str,
447    document: &atproto_identity::model::Document,
448) -> Result<TokenResponse, OAuthClientError> {
449    let pds_endpoints = document.pds_endpoints();
450    let pds_endpoint = pds_endpoints
451        .first()
452        .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
453    let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
454
455    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
456        .try_into()
457        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
458
459    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
460    let client_assertion_claims = Claims::new(JoseClaims {
461        issuer: Some(oauth_client.client_id.clone()),
462        subject: Some(oauth_client.client_id.clone()),
463        audience: Some(authorization_server.issuer.clone()),
464        json_web_token_id: Some(client_assertion_jti),
465        issued_at: Some(chrono::Utc::now().timestamp() as u64),
466        ..Default::default()
467    });
468
469    let client_assertion_token = mint(
470        &oauth_client.private_signing_key_data,
471        &client_assertion_header,
472        &client_assertion_claims,
473    )
474    .map_err(OAuthClientError::MintTokenFailed)?;
475
476    let params = [
477        ("client_id", oauth_client.client_id.as_str()),
478        ("redirect_uri", oauth_client.redirect_uri.as_str()),
479        ("grant_type", "refresh_token"),
480        ("refresh_token", refresh_token),
481        (
482            "client_assertion_type",
483            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
484        ),
485        ("client_assertion", client_assertion_token.as_str()),
486    ];
487
488    let token_endpoint = authorization_server.token_endpoint.clone();
489
490    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
491        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
492
493    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
494
495    let dpop_retry_client = ClientBuilder::new(http_client.clone())
496        .with(ChainMiddleware::new(dpop_retry.clone()))
497        .build();
498
499    dpop_retry_client
500        .post(token_endpoint)
501        .header("DPoP", dpop_token.as_str())
502        .form(&params)
503        .send()
504        .await
505        .map_err(OAuthClientError::TokenHttpRequestFailed)?
506        .json()
507        .await
508        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
509}