atproto_oauth/
workflow.rs

1//! OAuth workflow for AT Protocol authorization.
2//!
//! Complete OAuth 2.0 authorization code flow with PAR, DPoP, PKCE,
//! and client assertion support for AT Protocol authentication.
//!
//! ## Key Components
//!
//! - **`OAuthClient`**: Client configuration with credentials and signing keys
//! - **`OAuthRequest`**: Tracking structure for ongoing authorization requests
//! - **`OAuthRequestState`**: Security parameters including state, nonce, and PKCE challenge
//! - **`oauth_init`** / **`oauth_complete`** / **`oauth_refresh`**: Pushed authorization
//!   request, authorization code exchange, and refresh token exchange
8//!
9//! ## Security Features
10//!
11//! - **PKCE**: Proof Key for Code Exchange protection against authorization code interception
12//! - **DPoP**: Demonstration of Proof-of-Possession for token binding
13//! - **Client Assertions**: JWT-based client authentication using private key signatures
14//! - **State Parameters**: CSRF protection using random state values
15//! - **Nonce Values**: Additional replay protection
16//!
17//! ## Example Usage
18//!
19//! ```rust,ignore
20//! use atproto_oauth::workflow::{oauth_init, oauth_complete, OAuthClient, OAuthRequestState};
21//! use atproto_oauth::pkce::generate;
//! use atproto_identity::key::{generate_key, KeyType};
23//!
24//! // Generate security parameters
25//! let signing_key = generate_key(KeyType::P256Private)?;
26//! let dpop_key = generate_key(KeyType::P256Private)?;
27//! let (pkce_verifier, code_challenge) = generate();
28//!
29//! // Configure OAuth client
30//! let oauth_client = OAuthClient {
31//!     redirect_uri: "https://app.example.com/callback".to_string(),
32//!     client_id: "https://app.example.com/client-metadata.json".to_string(),
33//!     private_signing_key_data: signing_key,
34//! };
35//!
36//! // Create request state
37//! let oauth_state = OAuthRequestState {
38//!     state: "random-state-value".to_string(),
39//!     nonce: "random-nonce-value".to_string(),
40//!     code_challenge,
41//!     scope: "atproto transition:generic".to_string(),
42//! };
43//!
44//! // Initiate OAuth flow
45//! let par_response = oauth_init(
46//!     &http_client,
47//!     &oauth_client,
48//!     &dpop_key,
//!     Some("user.bsky.social"),
50//!     &authorization_server,
51//!     &oauth_state,
52//! ).await?;
53//!
//! // Build authorization URL (percent-encode the query values in real code)
55//! let auth_url = format!(
56//!     "{}?client_id={}&request_uri={}",
57//!     authorization_server.authorization_endpoint,
58//!     oauth_client.client_id,
59//!     par_response.request_uri
60//! );
61//!
62//! // After user authorization and callback...
63//! let token_response = oauth_complete(
64//!     &http_client,
65//!     &oauth_client,
66//!     &dpop_key,
67//!     "authorization_code_from_callback",
68//!     &oauth_request,
//!     &authorization_server,
70//! ).await?;
71//! ```
72
73use atproto_identity::key::KeyData;
74use chrono::{DateTime, Utc};
75use rand::distributions::{Alphanumeric, DistString};
76use reqwest_chain::ChainMiddleware;
77use reqwest_middleware::ClientBuilder;
78use serde::Deserialize;
79use std::collections::HashMap;
80
81#[cfg(feature = "zeroize")]
82use zeroize::{Zeroize, ZeroizeOnDrop};
83
84use crate::{
85    dpop::{DpopRetry, auth_dpop},
86    errors::OAuthClientError,
87    jwt::{Claims, Header, JoseClaims, mint},
88    resources::{AuthorizationServer, pds_resources},
89};
90
91/// Response from a Pushed Authorization Request (PAR) endpoint.
92///
93/// Contains the request URI and expiration time returned by the authorization
94/// server after successfully processing a pushed authorization request.
95#[derive(Clone, Deserialize)]
96pub struct ParResponse {
97    /// The request URI to use in the authorization request.
98    pub request_uri: String,
99    /// The lifetime of the request URI in seconds.
100    pub expires_in: u64,
101
102    /// Additional fields returned by the authorization server.
103    #[serde(flatten)]
104    pub extra: HashMap<String, serde_json::Value>,
105}
106
/// Per-attempt security parameters handed to [`oauth_init`].
///
/// `oauth_init` copies `state`, `code_challenge`, and `scope` into the pushed
/// authorization request form. `nonce` is carried alongside them but is not read
/// by `oauth_init`.
pub struct OAuthRequestState {
    /// Opaque anti-CSRF value sent as the `state` form parameter.
    pub state: String,
    /// Random nonce kept with this attempt.
    pub nonce: String,
    /// S256 PKCE challenge computed from the verifier the caller keeps secret.
    pub code_challenge: String,
    /// Space-separated scopes requested from the authorization server.
    pub scope: String,
}
121
122/// OAuth client configuration containing essential client credentials.
123///
124/// This struct holds the client configuration needed for OAuth authorization flows,
125/// including the redirect URI, client identifier, and signing key.
126#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
127pub struct OAuthClient {
128    /// The redirect URI where the authorization server will send the user after authorization.
129    #[cfg_attr(feature = "zeroize", zeroize(skip))]
130    pub redirect_uri: String,
131
132    /// The unique client identifier for this OAuth client.
133    #[cfg_attr(feature = "zeroize", zeroize(skip))]
134    pub client_id: String,
135
136    /// The private key data used for signing client assertions.
137    pub private_signing_key_data: KeyData,
138}
139
140/// OAuth request tracking information for ongoing authorization flows.
141///
142/// This struct contains all the necessary information to track and complete
143/// an OAuth authorization request, including security parameters and timing.
144#[derive(Clone, PartialEq)]
145#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
146pub struct OAuthRequest {
147    /// The OAuth state parameter used to prevent CSRF attacks.
148    #[cfg_attr(feature = "zeroize", zeroize(skip))]
149    pub oauth_state: String,
150
151    /// The authorization server issuer identifier.
152    #[cfg_attr(feature = "zeroize", zeroize(skip))]
153    pub issuer: String,
154
155    /// The authorization server identifier.
156    #[cfg_attr(feature = "zeroize", zeroize(skip))]
157    pub authorization_server: String,
158
159    /// The nonce value for additional security.
160    #[cfg_attr(feature = "zeroize", zeroize(skip))]
161    pub nonce: String,
162
163    /// The PKCE code verifier for this authorization request.
164    #[cfg_attr(feature = "zeroize", zeroize(skip))]
165    pub pkce_verifier: String,
166
167    /// The public key used for signing (serialized).
168    pub signing_public_key: String,
169
170    /// The DPoP private key (serialized).
171    #[cfg_attr(feature = "zeroize", zeroize(skip))]
172    pub dpop_private_key: String,
173
174    /// When this OAuth request was created.
175    #[cfg_attr(feature = "zeroize", zeroize(skip))]
176    pub created_at: DateTime<Utc>,
177
178    /// When this OAuth request expires.
179    #[cfg_attr(feature = "zeroize", zeroize(skip))]
180    pub expires_at: DateTime<Utc>,
181}
182
183impl std::fmt::Debug for OAuthRequest {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        f.debug_struct("OAuthRequest")
186            .field("oauth_state", &self.oauth_state)
187            .field("issuer", &self.issuer)
188            .field("authorization_server", &self.authorization_server)
189            .field("nonce", &self.nonce)
190            .field("pkce_verifier", &"[REDACTED]")
191            .field("signing_public_key", &self.signing_public_key)
192            .field("dpop_private_key", &"[REDACTED]")
193            .field("created_at", &self.created_at)
194            .field("expires_at", &self.expires_at)
195            .finish()
196    }
197}
198
199/// Response from the OAuth token endpoint containing access credentials.
200///
201/// This struct represents the successful response from an OAuth token exchange,
202/// containing the access token and related metadata.
203#[derive(Clone, Deserialize)]
204#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
205pub struct TokenResponse {
206    /// The access token that can be used to access protected resources.
207    pub access_token: String,
208
209    /// The type of token, typically "Bearer" or "DPoP".
210    pub token_type: String,
211
212    /// The refresh token that can be used to obtain new access tokens.
213    /// Not all token responses include a refresh token.
214    pub refresh_token: Option<String>,
215
216    /// The scope of access granted by the access token.
217    #[cfg_attr(feature = "zeroize", zeroize(skip))]
218    pub scope: String,
219
220    /// The lifetime of the access token in seconds.
221    #[cfg_attr(feature = "zeroize", zeroize(skip))]
222    pub expires_in: u32,
223
224    /// The subject identifier (usually the user's DID).
225    #[cfg_attr(feature = "zeroize", zeroize(skip))]
226    pub sub: Option<String>,
227
228    /// Additional fields returned by the authorization server.
229    #[serde(flatten)]
230    #[cfg_attr(feature = "zeroize", zeroize(skip))]
231    pub extra: HashMap<String, serde_json::Value>,
232}
233
234/// Initiates the OAuth authorization flow by making a Pushed Authorization Request (PAR).
235///
236/// This function creates a PAR request to the authorization server with the necessary
237/// OAuth parameters, DPoP proof, and client assertion. It handles the complete setup
238/// for the AT Protocol OAuth flow including PKCE and DPoP security mechanisms.
239///
240/// # Arguments
241/// * `http_client` - The HTTP client to use for making requests
242/// * `private_signing_key_data` - The private key for signing client assertions
243/// * `dpop_key_data` - The key data for creating DPoP proofs
244/// * `handle` - The user's handle for the login hint
245/// * `authorization_server` - The authorization server configuration
246/// * `oauth_request_state` - The OAuth state parameters for this request
247///
248/// # Returns
249/// A `ParResponse` containing the request URI and expiration time on success.
250///
251/// # Errors
252/// Returns `OAuthClientError` if the PAR request fails or response parsing fails.
253pub async fn oauth_init(
254    http_client: &reqwest::Client,
255    oauth_client: &OAuthClient,
256    dpop_key_data: &KeyData,
257    login_hint: Option<&str>,
258    authorization_server: &AuthorizationServer,
259    oauth_request_state: &OAuthRequestState,
260) -> Result<ParResponse, OAuthClientError> {
261    let par_url = authorization_server
262        .pushed_authorization_request_endpoint
263        .clone();
264
265    let scope = &oauth_request_state.scope;
266
267    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
268        .try_into()
269        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
270
271    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
272    let client_assertion_claims = Claims::new(JoseClaims {
273        issuer: Some(oauth_client.client_id.clone()),
274        subject: Some(oauth_client.client_id.clone()),
275        audience: Some(authorization_server.issuer.clone()),
276        json_web_token_id: Some(client_assertion_jti),
277        issued_at: Some(chrono::Utc::now().timestamp() as u64),
278        ..Default::default()
279    });
280
281    let client_assertion_token = mint(
282        &oauth_client.private_signing_key_data,
283        &client_assertion_header,
284        &client_assertion_claims,
285    )
286    .map_err(OAuthClientError::MintTokenFailed)?;
287
288    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &par_url)
289        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
290
291    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
292
293    let dpop_retry_client = ClientBuilder::new(http_client.clone())
294        .with(ChainMiddleware::new(dpop_retry.clone()))
295        .build();
296
297    let mut params = vec![
298        ("response_type", "code"),
299        ("code_challenge", &oauth_request_state.code_challenge),
300        ("code_challenge_method", "S256"),
301        ("client_id", oauth_client.client_id.as_str()),
302        ("state", oauth_request_state.state.as_str()),
303        ("redirect_uri", oauth_client.redirect_uri.as_str()),
304        ("scope", scope),
305        (
306            "client_assertion_type",
307            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
308        ),
309        ("client_assertion", client_assertion_token.as_str()),
310    ];
311    if let Some(value) = login_hint {
312        params.push(("login_hint", value));
313    }
314
315    let response = dpop_retry_client
316        .post(par_url)
317        .header("DPoP", dpop_token.as_str())
318        .form(&params)
319        .send()
320        .await
321        .map_err(OAuthClientError::PARHttpRequestFailed)?
322        .json()
323        .await
324        .map_err(OAuthClientError::PARResponseJsonParsingFailed)?;
325
326    Ok(response)
327}
328
329/// Completes the OAuth authorization flow by exchanging the authorization code for tokens.
330///
331/// This function performs the final step of the OAuth authorization code flow by
332/// exchanging the authorization code received from the callback for access and refresh tokens.
333/// It handles DPoP proof generation and client assertion creation for secure token exchange.
334///
335/// # Arguments
336/// * `http_client` - The HTTP client to use for making requests
337/// * `oauth_client` - The OAuth client configuration
338/// * `dpop_key_data` - The key data for creating DPoP proofs
339/// * `callback_code` - The authorization code received from the callback
340/// * `oauth_request` - The original OAuth request state
341/// * `document` - The identity document containing PDS endpoints
342///
343/// # Returns
344/// A `TokenResponse` containing the access token, refresh token, and metadata on success.
345///
346/// # Errors
347/// Returns `OAuthClientError` if the token exchange fails or response parsing fails.
348pub async fn oauth_complete(
349    http_client: &reqwest::Client,
350    oauth_client: &OAuthClient,
351    dpop_key_data: &KeyData,
352    callback_code: &str,
353    oauth_request: &OAuthRequest,
354    authorization_server: &AuthorizationServer,
355) -> Result<TokenResponse, OAuthClientError> {
356    // let pds_endpoints = document.pds_endpoints();
357    // let pds_endpoint = pds_endpoints
358    //     .first()
359    //     .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
360    // let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
361
362    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
363        .try_into()
364        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
365
366    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
367    let client_assertion_claims = Claims::new(JoseClaims {
368        issuer: Some(oauth_client.client_id.clone()),
369        subject: Some(oauth_client.client_id.clone()),
370        audience: Some(authorization_server.issuer.clone()),
371        json_web_token_id: Some(client_assertion_jti),
372        issued_at: Some(chrono::Utc::now().timestamp() as u64),
373        ..Default::default()
374    });
375
376    let client_assertion_token = mint(
377        &oauth_client.private_signing_key_data,
378        &client_assertion_header,
379        &client_assertion_claims,
380    )
381    .map_err(OAuthClientError::MintTokenFailed)?;
382
383    let params = [
384        ("client_id", oauth_client.client_id.as_str()),
385        ("redirect_uri", oauth_client.redirect_uri.as_str()),
386        ("grant_type", "authorization_code"),
387        ("code", callback_code),
388        ("code_verifier", &oauth_request.pkce_verifier),
389        (
390            "client_assertion_type",
391            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
392        ),
393        ("client_assertion", client_assertion_token.as_str()),
394    ];
395
396    let token_endpoint = authorization_server.token_endpoint.clone();
397
398    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
399        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
400
401    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
402
403    let dpop_retry_client = ClientBuilder::new(http_client.clone())
404        .with(ChainMiddleware::new(dpop_retry.clone()))
405        .build();
406
407    dpop_retry_client
408        .post(token_endpoint)
409        .header("DPoP", dpop_token.as_str())
410        .form(&params)
411        .send()
412        .await
413        .map_err(OAuthClientError::TokenHttpRequestFailed)?
414        .json()
415        .await
416        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
417}
418
419/// Refreshes OAuth access tokens using a refresh token.
420///
421/// This function exchanges a refresh token for new access and refresh tokens.
422/// It handles DPoP proof generation and client assertion creation for secure
423/// token refresh operations according to AT Protocol OAuth requirements.
424///
425/// # Arguments
426/// * `http_client` - The HTTP client to use for making requests
427/// * `oauth_client` - The OAuth client configuration
428/// * `dpop_key_data` - The key data for creating DPoP proofs
429/// * `refresh_token` - The refresh token to exchange for new tokens
430/// * `document` - The identity document containing PDS endpoints
431///
432/// # Returns
433/// A `TokenResponse` containing the new access token, refresh token, and metadata on success.
434///
435/// # Errors
436/// Returns `OAuthClientError` if the token refresh fails or response parsing fails.
437pub async fn oauth_refresh(
438    http_client: &reqwest::Client,
439    oauth_client: &OAuthClient,
440    dpop_key_data: &KeyData,
441    refresh_token: &str,
442    document: &atproto_identity::model::Document,
443) -> Result<TokenResponse, OAuthClientError> {
444    let pds_endpoints = document.pds_endpoints();
445    let pds_endpoint = pds_endpoints
446        .first()
447        .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
448    let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
449
450    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
451        .try_into()
452        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
453
454    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
455    let client_assertion_claims = Claims::new(JoseClaims {
456        issuer: Some(oauth_client.client_id.clone()),
457        subject: Some(oauth_client.client_id.clone()),
458        audience: Some(authorization_server.issuer.clone()),
459        json_web_token_id: Some(client_assertion_jti),
460        issued_at: Some(chrono::Utc::now().timestamp() as u64),
461        ..Default::default()
462    });
463
464    let client_assertion_token = mint(
465        &oauth_client.private_signing_key_data,
466        &client_assertion_header,
467        &client_assertion_claims,
468    )
469    .map_err(OAuthClientError::MintTokenFailed)?;
470
471    let params = [
472        ("client_id", oauth_client.client_id.as_str()),
473        ("redirect_uri", oauth_client.redirect_uri.as_str()),
474        ("grant_type", "refresh_token"),
475        ("refresh_token", refresh_token),
476        (
477            "client_assertion_type",
478            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
479        ),
480        ("client_assertion", client_assertion_token.as_str()),
481    ];
482
483    let token_endpoint = authorization_server.token_endpoint.clone();
484
485    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
486        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
487
488    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
489
490    let dpop_retry_client = ClientBuilder::new(http_client.clone())
491        .with(ChainMiddleware::new(dpop_retry.clone()))
492        .build();
493
494    dpop_retry_client
495        .post(token_endpoint)
496        .header("DPoP", dpop_token.as_str())
497        .form(&params)
498        .send()
499        .await
500        .map_err(OAuthClientError::TokenHttpRequestFailed)?
501        .json()
502        .await
503        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
504}