atproto_oauth/
workflow.rs

1//! OAuth workflow implementation for AT Protocol authorization flows.
2//!
3//! This module provides a complete OAuth 2.0 authorization code flow implementation specifically
4//! designed for AT Protocol, including Pushed Authorization Requests (PAR), DPoP security, PKCE
5//! protection, and client assertion handling.
6//!
7//! ## OAuth Flow Components
8//!
9//! - **`oauth_init()`**: Initiates OAuth flow with PAR request to authorization server
10//! - **`oauth_complete()`**: Completes OAuth flow by exchanging authorization code for tokens
11//! - **`OAuthClient`**: Client configuration with credentials and signing keys
12//! - **`OAuthRequest`**: Tracking structure for ongoing authorization requests
13//! - **`OAuthRequestState`**: Security parameters including state, nonce, and PKCE challenge
14//!
15//! ## Security Features
16//!
17//! - **PKCE**: Proof Key for Code Exchange protection against authorization code interception
18//! - **DPoP**: Demonstration of Proof-of-Possession for token binding
19//! - **Client Assertions**: JWT-based client authentication using private key signatures
20//! - **State Parameters**: CSRF protection using random state values
21//! - **Nonce Values**: Additional replay protection
22//!
23//! ## Example Usage
24//!
25//! ```rust,ignore
26//! use atproto_oauth::workflow::{oauth_init, oauth_complete, OAuthClient, OAuthRequestState};
27//! use atproto_oauth::pkce::generate;
28//! use atproto_identity::key::{generate_key, KeyType};
29//!
30//! // Generate security parameters
31//! let signing_key = generate_key(KeyType::P256Private)?;
32//! let dpop_key = generate_key(KeyType::P256Private)?;
33//! let (pkce_verifier, code_challenge) = generate();
34//!
35//! // Configure OAuth client
36//! let oauth_client = OAuthClient {
37//!     redirect_uri: "https://app.example.com/callback".to_string(),
38//!     client_id: "https://app.example.com/client-metadata.json".to_string(),
39//!     private_signing_key_data: signing_key,
40//! };
41//!
42//! // Create request state
43//! let oauth_state = OAuthRequestState {
44//!     state: "random-state-value".to_string(),
45//!     nonce: "random-nonce-value".to_string(),
46//!     code_challenge,
47//!     scope: "atproto transition:generic".to_string(),
48//! };
49//!
50//! // Initiate OAuth flow
51//! let par_response = oauth_init(
52//!     &http_client,
53//!     &oauth_client,
54//!     &dpop_key,
55//!     "user.bsky.social",
56//!     &authorization_server,
57//!     &oauth_state,
58//! ).await?;
59//!
60//! // Build authorization URL
61//! let auth_url = format!(
62//!     "{}?client_id={}&request_uri={}",
63//!     authorization_server.authorization_endpoint,
64//!     oauth_client.client_id,
65//!     par_response.request_uri
66//! );
67//!
68//! // After user authorization and callback...
69//! let token_response = oauth_complete(
70//!     &http_client,
71//!     &oauth_client,
72//!     &dpop_key,
73//!     "authorization_code_from_callback",
74//!     &oauth_request,
75//!     &did_document,
76//! ).await?;
77//! ```
78
79use atproto_identity::key::KeyData;
80use chrono::{DateTime, Utc};
81use rand::distributions::{Alphanumeric, DistString};
82use reqwest_chain::ChainMiddleware;
83use reqwest_middleware::ClientBuilder;
84use serde::Deserialize;
85use std::collections::HashMap;
86
87#[cfg(feature = "zeroize")]
88use zeroize::{Zeroize, ZeroizeOnDrop};
89
90use crate::{
91    dpop::{DpopRetry, auth_dpop},
92    errors::OAuthClientError,
93    jwt::{Claims, Header, JoseClaims, mint},
94    resources::{AuthorizationServer, pds_resources},
95};
96
97/// Response from a Pushed Authorization Request (PAR) endpoint.
98///
99/// Contains the request URI and expiration time returned by the authorization
100/// server after successfully processing a pushed authorization request.
101#[derive(Clone, Deserialize)]
102pub struct ParResponse {
103    /// The request URI to use in the authorization request.
104    pub request_uri: String,
105    /// The lifetime of the request URI in seconds.
106    pub expires_in: u64,
107
108    /// Additional fields returned by the authorization server.
109    #[serde(flatten)]
110    pub extra: HashMap<String, serde_json::Value>,
111}
112
/// OAuth request state containing security parameters for the authorization flow.
///
/// This struct holds the security parameters needed to maintain state
/// and prevent attacks during the OAuth authorization code flow.
///
/// `oauth_init` sends `state`, `code_challenge`, and `scope` in the PAR
/// request body; `nonce` is not sent there. The PKCE *verifier* (the secret
/// half of the PKCE pair) is deliberately not part of this struct, so the
/// derived `Debug` output contains no secret key material.
#[derive(Clone, Debug)]
pub struct OAuthRequestState {
    /// Random state parameter to prevent CSRF attacks.
    pub state: String,

    /// Random nonce value for additional security.
    pub nonce: String,

    /// PKCE code challenge derived from the code verifier.
    pub code_challenge: String,

    /// The scope of access requested for the authorization.
    pub scope: String,
}
127
128/// OAuth client configuration containing essential client credentials.
129///
130/// This struct holds the client configuration needed for OAuth authorization flows,
131/// including the redirect URI, client identifier, and signing key.
132#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
133pub struct OAuthClient {
134    /// The redirect URI where the authorization server will send the user after authorization.
135    #[cfg_attr(feature = "zeroize", zeroize(skip))]
136    pub redirect_uri: String,
137
138    /// The unique client identifier for this OAuth client.
139    #[cfg_attr(feature = "zeroize", zeroize(skip))]
140    pub client_id: String,
141
142    /// The private key data used for signing client assertions.
143    pub private_signing_key_data: KeyData,
144}
145
146/// OAuth request tracking information for ongoing authorization flows.
147///
148/// This struct contains all the necessary information to track and complete
149/// an OAuth authorization request, including security parameters and timing.
150#[derive(Clone, PartialEq)]
151#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
152pub struct OAuthRequest {
153    /// The OAuth state parameter used to prevent CSRF attacks.
154    #[cfg_attr(feature = "zeroize", zeroize(skip))]
155    pub oauth_state: String,
156
157    /// The authorization server issuer identifier.
158    #[cfg_attr(feature = "zeroize", zeroize(skip))]
159    pub issuer: String,
160
161    /// The DID (Decentralized Identifier) of the user.
162    #[cfg_attr(feature = "zeroize", zeroize(skip))]
163    pub did: String,
164
165    /// The nonce value for additional security.
166    #[cfg_attr(feature = "zeroize", zeroize(skip))]
167    pub nonce: String,
168
169    /// The PKCE code verifier for this authorization request.
170    #[cfg_attr(feature = "zeroize", zeroize(skip))]
171    pub pkce_verifier: String,
172
173    /// The public key used for signing (serialized).
174    pub signing_public_key: String,
175
176    /// The DPoP private key (serialized).
177    #[cfg_attr(feature = "zeroize", zeroize(skip))]
178    pub dpop_private_key: String,
179
180    /// When this OAuth request was created.
181    #[cfg_attr(feature = "zeroize", zeroize(skip))]
182    pub created_at: DateTime<Utc>,
183
184    /// When this OAuth request expires.
185    #[cfg_attr(feature = "zeroize", zeroize(skip))]
186    pub expires_at: DateTime<Utc>,
187}
188
189impl std::fmt::Debug for OAuthRequest {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        f.debug_struct("OAuthRequest")
192            .field("oauth_state", &self.oauth_state)
193            .field("issuer", &self.issuer)
194            .field("did", &self.did)
195            .field("nonce", &self.nonce)
196            .field("pkce_verifier", &"[REDACTED]")
197            .field("signing_public_key", &self.signing_public_key)
198            .field("dpop_private_key", &"[REDACTED]")
199            .field("created_at", &self.created_at)
200            .field("expires_at", &self.expires_at)
201            .finish()
202    }
203}
204
205/// Response from the OAuth token endpoint containing access credentials.
206///
207/// This struct represents the successful response from an OAuth token exchange,
208/// containing the access token and related metadata.
209#[derive(Clone, Deserialize)]
210#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
211pub struct TokenResponse {
212    /// The access token that can be used to access protected resources.
213    pub access_token: String,
214
215    /// The type of token, typically "Bearer" or "DPoP".
216    pub token_type: String,
217
218    /// The refresh token that can be used to obtain new access tokens.
219    pub refresh_token: String,
220
221    /// The scope of access granted by the access token.
222    #[cfg_attr(feature = "zeroize", zeroize(skip))]
223    pub scope: String,
224
225    /// The lifetime of the access token in seconds.
226    #[cfg_attr(feature = "zeroize", zeroize(skip))]
227    pub expires_in: u32,
228
229    /// The subject identifier (usually the user's DID).
230    #[cfg_attr(feature = "zeroize", zeroize(skip))]
231    pub sub: Option<String>,
232
233    /// Additional fields returned by the authorization server.
234    #[serde(flatten)]
235    #[cfg_attr(feature = "zeroize", zeroize(skip))]
236    pub extra: HashMap<String, serde_json::Value>,
237}
238
239/// Initiates the OAuth authorization flow by making a Pushed Authorization Request (PAR).
240///
241/// This function creates a PAR request to the authorization server with the necessary
242/// OAuth parameters, DPoP proof, and client assertion. It handles the complete setup
243/// for the AT Protocol OAuth flow including PKCE and DPoP security mechanisms.
244///
245/// # Arguments
246/// * `http_client` - The HTTP client to use for making requests
247/// * `private_signing_key_data` - The private key for signing client assertions
248/// * `dpop_key_data` - The key data for creating DPoP proofs
249/// * `handle` - The user's handle for the login hint
250/// * `authorization_server` - The authorization server configuration
251/// * `oauth_request_state` - The OAuth state parameters for this request
252///
253/// # Returns
254/// A `ParResponse` containing the request URI and expiration time on success.
255///
256/// # Errors
257/// Returns `OAuthClientError` if the PAR request fails or response parsing fails.
258pub async fn oauth_init(
259    http_client: &reqwest::Client,
260    oauth_client: &OAuthClient,
261    dpop_key_data: &KeyData,
262    handle: &str,
263    authorization_server: &AuthorizationServer,
264    oauth_request_state: &OAuthRequestState,
265) -> Result<ParResponse, OAuthClientError> {
266    let par_url = authorization_server
267        .pushed_authorization_request_endpoint
268        .clone();
269
270    let scope = &oauth_request_state.scope;
271
272    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
273        .try_into()
274        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
275
276    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
277    let client_assertion_claims = Claims::new(JoseClaims {
278        issuer: Some(oauth_client.client_id.clone()),
279        subject: Some(oauth_client.client_id.clone()),
280        audience: Some(authorization_server.issuer.clone()),
281        json_web_token_id: Some(client_assertion_jti),
282        issued_at: Some(chrono::Utc::now().timestamp() as u64),
283        ..Default::default()
284    });
285
286    let client_assertion_token = mint(
287        &oauth_client.private_signing_key_data,
288        &client_assertion_header,
289        &client_assertion_claims,
290    )
291    .map_err(OAuthClientError::MintTokenFailed)?;
292
293    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &par_url)
294        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
295
296    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
297
298    let dpop_retry_client = ClientBuilder::new(http_client.clone())
299        .with(ChainMiddleware::new(dpop_retry.clone()))
300        .build();
301
302    let params = [
303        ("response_type", "code"),
304        ("code_challenge", &oauth_request_state.code_challenge),
305        ("code_challenge_method", "S256"),
306        ("client_id", oauth_client.client_id.as_str()),
307        ("state", oauth_request_state.state.as_str()),
308        ("redirect_uri", oauth_client.redirect_uri.as_str()),
309        ("scope", scope),
310        ("login_hint", handle),
311        (
312            "client_assertion_type",
313            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
314        ),
315        ("client_assertion", client_assertion_token.as_str()),
316    ];
317
318    let response = dpop_retry_client
319        .post(par_url)
320        .header("DPoP", dpop_token.as_str())
321        .form(&params)
322        .send()
323        .await
324        .map_err(OAuthClientError::PARHttpRequestFailed)?
325        .json()
326        .await
327        .map_err(OAuthClientError::PARResponseJsonParsingFailed)?;
328
329    Ok(response)
330}
331
332/// Completes the OAuth authorization flow by exchanging the authorization code for tokens.
333///
334/// This function performs the final step of the OAuth authorization code flow by
335/// exchanging the authorization code received from the callback for access and refresh tokens.
336/// It handles DPoP proof generation and client assertion creation for secure token exchange.
337///
338/// # Arguments
339/// * `http_client` - The HTTP client to use for making requests
340/// * `oauth_client` - The OAuth client configuration
341/// * `dpop_key_data` - The key data for creating DPoP proofs
342/// * `callback_code` - The authorization code received from the callback
343/// * `oauth_request` - The original OAuth request state
344/// * `document` - The identity document containing PDS endpoints
345///
346/// # Returns
347/// A `TokenResponse` containing the access token, refresh token, and metadata on success.
348///
349/// # Errors
350/// Returns `OAuthClientError` if the token exchange fails or response parsing fails.
351pub async fn oauth_complete(
352    http_client: &reqwest::Client,
353    oauth_client: &OAuthClient,
354    dpop_key_data: &KeyData,
355    callback_code: &str,
356    oauth_request: &OAuthRequest,
357    document: &atproto_identity::model::Document,
358) -> Result<TokenResponse, OAuthClientError> {
359    let pds_endpoints = document.pds_endpoints();
360    let pds_endpoint = pds_endpoints
361        .first()
362        .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
363    let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
364
365    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
366        .try_into()
367        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
368
369    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
370    let client_assertion_claims = Claims::new(JoseClaims {
371        issuer: Some(oauth_client.client_id.clone()),
372        subject: Some(oauth_client.client_id.clone()),
373        audience: Some(authorization_server.issuer.clone()),
374        json_web_token_id: Some(client_assertion_jti),
375        issued_at: Some(chrono::Utc::now().timestamp() as u64),
376        ..Default::default()
377    });
378
379    let client_assertion_token = mint(
380        &oauth_client.private_signing_key_data,
381        &client_assertion_header,
382        &client_assertion_claims,
383    )
384    .map_err(OAuthClientError::MintTokenFailed)?;
385
386    let params = [
387        ("client_id", oauth_client.client_id.as_str()),
388        ("redirect_uri", oauth_client.redirect_uri.as_str()),
389        ("grant_type", "authorization_code"),
390        ("code", callback_code),
391        ("code_verifier", &oauth_request.pkce_verifier),
392        (
393            "client_assertion_type",
394            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
395        ),
396        ("client_assertion", client_assertion_token.as_str()),
397    ];
398
399    let token_endpoint = authorization_server.token_endpoint.clone();
400
401    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
402        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
403
404    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
405
406    let dpop_retry_client = ClientBuilder::new(http_client.clone())
407        .with(ChainMiddleware::new(dpop_retry.clone()))
408        .build();
409
410    dpop_retry_client
411        .post(token_endpoint)
412        .header("DPoP", dpop_token.as_str())
413        .form(&params)
414        .send()
415        .await
416        .map_err(OAuthClientError::TokenHttpRequestFailed)?
417        .json()
418        .await
419        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
420}
421
422/// Refreshes OAuth access tokens using a refresh token.
423///
424/// This function exchanges a refresh token for new access and refresh tokens.
425/// It handles DPoP proof generation and client assertion creation for secure
426/// token refresh operations according to AT Protocol OAuth requirements.
427///
428/// # Arguments
429/// * `http_client` - The HTTP client to use for making requests
430/// * `oauth_client` - The OAuth client configuration
431/// * `dpop_key_data` - The key data for creating DPoP proofs
432/// * `refresh_token` - The refresh token to exchange for new tokens
433/// * `document` - The identity document containing PDS endpoints
434///
435/// # Returns
436/// A `TokenResponse` containing the new access token, refresh token, and metadata on success.
437///
438/// # Errors
439/// Returns `OAuthClientError` if the token refresh fails or response parsing fails.
440pub async fn oauth_refresh(
441    http_client: &reqwest::Client,
442    oauth_client: &OAuthClient,
443    dpop_key_data: &KeyData,
444    refresh_token: &str,
445    document: &atproto_identity::model::Document,
446) -> Result<TokenResponse, OAuthClientError> {
447    let pds_endpoints = document.pds_endpoints();
448    let pds_endpoint = pds_endpoints
449        .first()
450        .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
451    let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
452
453    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
454        .try_into()
455        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
456
457    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
458    let client_assertion_claims = Claims::new(JoseClaims {
459        issuer: Some(oauth_client.client_id.clone()),
460        subject: Some(oauth_client.client_id.clone()),
461        audience: Some(authorization_server.issuer.clone()),
462        json_web_token_id: Some(client_assertion_jti),
463        issued_at: Some(chrono::Utc::now().timestamp() as u64),
464        ..Default::default()
465    });
466
467    let client_assertion_token = mint(
468        &oauth_client.private_signing_key_data,
469        &client_assertion_header,
470        &client_assertion_claims,
471    )
472    .map_err(OAuthClientError::MintTokenFailed)?;
473
474    let params = [
475        ("client_id", oauth_client.client_id.as_str()),
476        ("redirect_uri", oauth_client.redirect_uri.as_str()),
477        ("grant_type", "refresh_token"),
478        ("refresh_token", refresh_token),
479        (
480            "client_assertion_type",
481            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
482        ),
483        ("client_assertion", client_assertion_token.as_str()),
484    ];
485
486    let token_endpoint = authorization_server.token_endpoint.clone();
487
488    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
489        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
490
491    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
492
493    let dpop_retry_client = ClientBuilder::new(http_client.clone())
494        .with(ChainMiddleware::new(dpop_retry.clone()))
495        .build();
496
497    dpop_retry_client
498        .post(token_endpoint)
499        .header("DPoP", dpop_token.as_str())
500        .form(&params)
501        .send()
502        .await
503        .map_err(OAuthClientError::TokenHttpRequestFailed)?
504        .json()
505        .await
506        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
507}