atproto_oauth/
workflow.rs

1//! OAuth workflow for AT Protocol authorization.
2//!
//! Complete OAuth 2.0 authorization code flow with PAR, DPoP, PKCE,
//! and client assertion support for AT Protocol authentication.
//!
//! ## Core Types
//!
//! - **`OAuthClient`**: Client configuration with credentials and signing keys
//! - **`OAuthRequest`**: Tracking structure for ongoing authorization requests
//! - **`OAuthRequestState`**: Security parameters including state, nonce, PKCE challenge, and scope
//!
9//! ## Security Features
10//!
11//! - **PKCE**: Proof Key for Code Exchange protection against authorization code interception
12//! - **DPoP**: Demonstration of Proof-of-Possession for token binding
13//! - **Client Assertions**: JWT-based client authentication using private key signatures
14//! - **State Parameters**: CSRF protection using random state values
15//! - **Nonce Values**: Additional replay protection
16//!
17//! ## Example Usage
18//!
19//! ```rust,ignore
20//! use atproto_oauth::workflow::{oauth_init, oauth_complete, OAuthClient, OAuthRequestState};
21//! use atproto_oauth::pkce::generate;
//! use atproto_identity::key::{generate_key, KeyType};
23//!
24//! // Generate security parameters
25//! let signing_key = generate_key(KeyType::P256Private)?;
26//! let dpop_key = generate_key(KeyType::P256Private)?;
27//! let (pkce_verifier, code_challenge) = generate();
28//!
29//! // Configure OAuth client
30//! let oauth_client = OAuthClient {
31//!     redirect_uri: "https://app.example.com/callback".to_string(),
32//!     client_id: "https://app.example.com/client-metadata.json".to_string(),
33//!     private_signing_key_data: signing_key,
34//! };
35//!
36//! // Create request state
37//! let oauth_state = OAuthRequestState {
38//!     state: "random-state-value".to_string(),
39//!     nonce: "random-nonce-value".to_string(),
40//!     code_challenge,
41//!     scope: "atproto transition:generic".to_string(),
42//! };
43//!
44//! // Initiate OAuth flow
45//! let par_response = oauth_init(
46//!     &http_client,
47//!     &oauth_client,
48//!     &dpop_key,
//!     Some("user.bsky.social"),
50//!     &authorization_server,
51//!     &oauth_state,
52//! ).await?;
53//!
54//! // Build authorization URL
55//! let auth_url = format!(
56//!     "{}?client_id={}&request_uri={}",
57//!     authorization_server.authorization_endpoint,
58//!     oauth_client.client_id,
59//!     par_response.request_uri
60//! );
61//!
62//! // After user authorization and callback...
63//! let token_response = oauth_complete(
64//!     &http_client,
65//!     &oauth_client,
66//!     &dpop_key,
67//!     "authorization_code_from_callback",
68//!     &oauth_request,
//!     &authorization_server,
70//! ).await?;
71//! ```
72
73use atproto_identity::key::KeyData;
74use chrono::{DateTime, Utc};
75use rand::distributions::{Alphanumeric, DistString};
76use reqwest_chain::ChainMiddleware;
77use reqwest_middleware::ClientBuilder;
78use serde::Deserialize;
79use std::collections::HashMap;
80
81#[cfg(feature = "zeroize")]
82use zeroize::{Zeroize, ZeroizeOnDrop};
83
84use crate::{
85    dpop::{DpopRetry, auth_dpop},
86    errors::OAuthClientError,
87    jwt::{Claims, Header, JoseClaims, mint},
88    resources::{AuthorizationServer, pds_resources},
89};
90
91/// Response from a Pushed Authorization Request (PAR) endpoint.
92///
93/// Contains the request URI and expiration time returned by the authorization
94/// server after successfully processing a pushed authorization request.
95#[derive(Clone, Deserialize)]
96pub struct ParResponse {
97    /// The request URI to use in the authorization request.
98    pub request_uri: String,
99    /// The lifetime of the request URI in seconds.
100    pub expires_in: u64,
101
102    /// Additional fields returned by the authorization server.
103    #[serde(flatten)]
104    pub extra: HashMap<String, serde_json::Value>,
105}
106
/// OAuth request state containing security parameters for the authorization flow.
///
/// This struct holds the per-request values that `oauth_init` places into the
/// pushed authorization request (`state`, `code_challenge`, `scope`), plus a
/// `nonce` kept alongside them. None of these values is a long-lived secret, so
/// the type is freely cloneable, comparable and printable.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OAuthRequestState {
    /// Random state parameter to prevent CSRF attacks.
    pub state: String,
    /// Random nonce value for additional security.
    pub nonce: String,
    /// PKCE code challenge derived from the code verifier (sent with method `S256`).
    pub code_challenge: String,
    /// The scope of access requested for the authorization.
    pub scope: String,
}
121
122/// OAuth client configuration containing essential client credentials.
123///
124/// This struct holds the client configuration needed for OAuth authorization flows,
125/// including the redirect URI, client identifier, and signing key.
126#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
127pub struct OAuthClient {
128    /// The redirect URI where the authorization server will send the user after authorization.
129    #[cfg_attr(feature = "zeroize", zeroize(skip))]
130    pub redirect_uri: String,
131
132    /// The unique client identifier for this OAuth client.
133    #[cfg_attr(feature = "zeroize", zeroize(skip))]
134    pub client_id: String,
135
136    /// The private key data used for signing client assertions.
137    pub private_signing_key_data: KeyData,
138}
139
140/// OAuth request tracking information for ongoing authorization flows.
141///
142/// This struct contains all the necessary information to track and complete
143/// an OAuth authorization request, including security parameters and timing.
144#[derive(Clone, PartialEq)]
145#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
146pub struct OAuthRequest {
147    /// The OAuth state parameter used to prevent CSRF attacks.
148    #[cfg_attr(feature = "zeroize", zeroize(skip))]
149    pub oauth_state: String,
150
151    /// The authorization server issuer identifier.
152    #[cfg_attr(feature = "zeroize", zeroize(skip))]
153    pub issuer: String,
154
155    /// The authorization server identifier.
156    #[cfg_attr(feature = "zeroize", zeroize(skip))]
157    pub authorization_server: String,
158
159    /// The nonce value for additional security.
160    #[cfg_attr(feature = "zeroize", zeroize(skip))]
161    pub nonce: String,
162
163    /// The PKCE code verifier for this authorization request.
164    #[cfg_attr(feature = "zeroize", zeroize(skip))]
165    pub pkce_verifier: String,
166
167    /// The public key used for signing (serialized).
168    pub signing_public_key: String,
169
170    /// The DPoP private key (serialized).
171    #[cfg_attr(feature = "zeroize", zeroize(skip))]
172    pub dpop_private_key: String,
173
174    /// When this OAuth request was created.
175    #[cfg_attr(feature = "zeroize", zeroize(skip))]
176    pub created_at: DateTime<Utc>,
177
178    /// When this OAuth request expires.
179    #[cfg_attr(feature = "zeroize", zeroize(skip))]
180    pub expires_at: DateTime<Utc>,
181}
182
183impl std::fmt::Debug for OAuthRequest {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        f.debug_struct("OAuthRequest")
186            .field("oauth_state", &self.oauth_state)
187            .field("issuer", &self.issuer)
188            .field("authorization_server", &self.authorization_server)
189            .field("nonce", &self.nonce)
190            .field("pkce_verifier", &"[REDACTED]")
191            .field("signing_public_key", &self.signing_public_key)
192            .field("dpop_private_key", &"[REDACTED]")
193            .field("created_at", &self.created_at)
194            .field("expires_at", &self.expires_at)
195            .finish()
196    }
197}
198
199/// Response from the OAuth token endpoint containing access credentials.
200///
201/// This struct represents the successful response from an OAuth token exchange,
202/// containing the access token and related metadata.
203#[derive(Clone, Deserialize)]
204#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
205pub struct TokenResponse {
206    /// The access token that can be used to access protected resources.
207    pub access_token: String,
208
209    /// The type of token, typically "Bearer" or "DPoP".
210    pub token_type: String,
211
212    /// The refresh token that can be used to obtain new access tokens.
213    pub refresh_token: String,
214
215    /// The scope of access granted by the access token.
216    #[cfg_attr(feature = "zeroize", zeroize(skip))]
217    pub scope: String,
218
219    /// The lifetime of the access token in seconds.
220    #[cfg_attr(feature = "zeroize", zeroize(skip))]
221    pub expires_in: u32,
222
223    /// The subject identifier (usually the user's DID).
224    #[cfg_attr(feature = "zeroize", zeroize(skip))]
225    pub sub: Option<String>,
226
227    /// Additional fields returned by the authorization server.
228    #[serde(flatten)]
229    #[cfg_attr(feature = "zeroize", zeroize(skip))]
230    pub extra: HashMap<String, serde_json::Value>,
231}
232
233/// Initiates the OAuth authorization flow by making a Pushed Authorization Request (PAR).
234///
235/// This function creates a PAR request to the authorization server with the necessary
236/// OAuth parameters, DPoP proof, and client assertion. It handles the complete setup
237/// for the AT Protocol OAuth flow including PKCE and DPoP security mechanisms.
238///
239/// # Arguments
240/// * `http_client` - The HTTP client to use for making requests
241/// * `private_signing_key_data` - The private key for signing client assertions
242/// * `dpop_key_data` - The key data for creating DPoP proofs
243/// * `handle` - The user's handle for the login hint
244/// * `authorization_server` - The authorization server configuration
245/// * `oauth_request_state` - The OAuth state parameters for this request
246///
247/// # Returns
248/// A `ParResponse` containing the request URI and expiration time on success.
249///
250/// # Errors
251/// Returns `OAuthClientError` if the PAR request fails or response parsing fails.
252pub async fn oauth_init(
253    http_client: &reqwest::Client,
254    oauth_client: &OAuthClient,
255    dpop_key_data: &KeyData,
256    login_hint: Option<&str>,
257    authorization_server: &AuthorizationServer,
258    oauth_request_state: &OAuthRequestState,
259) -> Result<ParResponse, OAuthClientError> {
260    let par_url = authorization_server
261        .pushed_authorization_request_endpoint
262        .clone();
263
264    let scope = &oauth_request_state.scope;
265
266    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
267        .try_into()
268        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
269
270    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
271    let client_assertion_claims = Claims::new(JoseClaims {
272        issuer: Some(oauth_client.client_id.clone()),
273        subject: Some(oauth_client.client_id.clone()),
274        audience: Some(authorization_server.issuer.clone()),
275        json_web_token_id: Some(client_assertion_jti),
276        issued_at: Some(chrono::Utc::now().timestamp() as u64),
277        ..Default::default()
278    });
279
280    let client_assertion_token = mint(
281        &oauth_client.private_signing_key_data,
282        &client_assertion_header,
283        &client_assertion_claims,
284    )
285    .map_err(OAuthClientError::MintTokenFailed)?;
286
287    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &par_url)
288        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
289
290    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
291
292    let dpop_retry_client = ClientBuilder::new(http_client.clone())
293        .with(ChainMiddleware::new(dpop_retry.clone()))
294        .build();
295
296    let mut params = vec![
297        ("response_type", "code"),
298        ("code_challenge", &oauth_request_state.code_challenge),
299        ("code_challenge_method", "S256"),
300        ("client_id", oauth_client.client_id.as_str()),
301        ("state", oauth_request_state.state.as_str()),
302        ("redirect_uri", oauth_client.redirect_uri.as_str()),
303        ("scope", scope),
304        (
305            "client_assertion_type",
306            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
307        ),
308        ("client_assertion", client_assertion_token.as_str()),
309    ];
310    if let Some(value) = login_hint {
311        params.push(("login_hint", value));
312    }
313
314    let response = dpop_retry_client
315        .post(par_url)
316        .header("DPoP", dpop_token.as_str())
317        .form(&params)
318        .send()
319        .await
320        .map_err(OAuthClientError::PARHttpRequestFailed)?
321        .json()
322        .await
323        .map_err(OAuthClientError::PARResponseJsonParsingFailed)?;
324
325    Ok(response)
326}
327
328/// Completes the OAuth authorization flow by exchanging the authorization code for tokens.
329///
330/// This function performs the final step of the OAuth authorization code flow by
331/// exchanging the authorization code received from the callback for access and refresh tokens.
332/// It handles DPoP proof generation and client assertion creation for secure token exchange.
333///
334/// # Arguments
335/// * `http_client` - The HTTP client to use for making requests
336/// * `oauth_client` - The OAuth client configuration
337/// * `dpop_key_data` - The key data for creating DPoP proofs
338/// * `callback_code` - The authorization code received from the callback
339/// * `oauth_request` - The original OAuth request state
340/// * `document` - The identity document containing PDS endpoints
341///
342/// # Returns
343/// A `TokenResponse` containing the access token, refresh token, and metadata on success.
344///
345/// # Errors
346/// Returns `OAuthClientError` if the token exchange fails or response parsing fails.
347pub async fn oauth_complete(
348    http_client: &reqwest::Client,
349    oauth_client: &OAuthClient,
350    dpop_key_data: &KeyData,
351    callback_code: &str,
352    oauth_request: &OAuthRequest,
353    authorization_server: &AuthorizationServer,
354) -> Result<TokenResponse, OAuthClientError> {
355    // let pds_endpoints = document.pds_endpoints();
356    // let pds_endpoint = pds_endpoints
357    //     .first()
358    //     .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
359    // let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
360
361    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
362        .try_into()
363        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
364
365    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
366    let client_assertion_claims = Claims::new(JoseClaims {
367        issuer: Some(oauth_client.client_id.clone()),
368        subject: Some(oauth_client.client_id.clone()),
369        audience: Some(authorization_server.issuer.clone()),
370        json_web_token_id: Some(client_assertion_jti),
371        issued_at: Some(chrono::Utc::now().timestamp() as u64),
372        ..Default::default()
373    });
374
375    let client_assertion_token = mint(
376        &oauth_client.private_signing_key_data,
377        &client_assertion_header,
378        &client_assertion_claims,
379    )
380    .map_err(OAuthClientError::MintTokenFailed)?;
381
382    let params = [
383        ("client_id", oauth_client.client_id.as_str()),
384        ("redirect_uri", oauth_client.redirect_uri.as_str()),
385        ("grant_type", "authorization_code"),
386        ("code", callback_code),
387        ("code_verifier", &oauth_request.pkce_verifier),
388        (
389            "client_assertion_type",
390            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
391        ),
392        ("client_assertion", client_assertion_token.as_str()),
393    ];
394
395    let token_endpoint = authorization_server.token_endpoint.clone();
396
397    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
398        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
399
400    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
401
402    let dpop_retry_client = ClientBuilder::new(http_client.clone())
403        .with(ChainMiddleware::new(dpop_retry.clone()))
404        .build();
405
406    dpop_retry_client
407        .post(token_endpoint)
408        .header("DPoP", dpop_token.as_str())
409        .form(&params)
410        .send()
411        .await
412        .map_err(OAuthClientError::TokenHttpRequestFailed)?
413        .json()
414        .await
415        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
416}
417
418/// Refreshes OAuth access tokens using a refresh token.
419///
420/// This function exchanges a refresh token for new access and refresh tokens.
421/// It handles DPoP proof generation and client assertion creation for secure
422/// token refresh operations according to AT Protocol OAuth requirements.
423///
424/// # Arguments
425/// * `http_client` - The HTTP client to use for making requests
426/// * `oauth_client` - The OAuth client configuration
427/// * `dpop_key_data` - The key data for creating DPoP proofs
428/// * `refresh_token` - The refresh token to exchange for new tokens
429/// * `document` - The identity document containing PDS endpoints
430///
431/// # Returns
432/// A `TokenResponse` containing the new access token, refresh token, and metadata on success.
433///
434/// # Errors
435/// Returns `OAuthClientError` if the token refresh fails or response parsing fails.
436pub async fn oauth_refresh(
437    http_client: &reqwest::Client,
438    oauth_client: &OAuthClient,
439    dpop_key_data: &KeyData,
440    refresh_token: &str,
441    document: &atproto_identity::model::Document,
442) -> Result<TokenResponse, OAuthClientError> {
443    let pds_endpoints = document.pds_endpoints();
444    let pds_endpoint = pds_endpoints
445        .first()
446        .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
447    let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
448
449    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
450        .try_into()
451        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
452
453    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
454    let client_assertion_claims = Claims::new(JoseClaims {
455        issuer: Some(oauth_client.client_id.clone()),
456        subject: Some(oauth_client.client_id.clone()),
457        audience: Some(authorization_server.issuer.clone()),
458        json_web_token_id: Some(client_assertion_jti),
459        issued_at: Some(chrono::Utc::now().timestamp() as u64),
460        ..Default::default()
461    });
462
463    let client_assertion_token = mint(
464        &oauth_client.private_signing_key_data,
465        &client_assertion_header,
466        &client_assertion_claims,
467    )
468    .map_err(OAuthClientError::MintTokenFailed)?;
469
470    let params = [
471        ("client_id", oauth_client.client_id.as_str()),
472        ("redirect_uri", oauth_client.redirect_uri.as_str()),
473        ("grant_type", "refresh_token"),
474        ("refresh_token", refresh_token),
475        (
476            "client_assertion_type",
477            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
478        ),
479        ("client_assertion", client_assertion_token.as_str()),
480    ];
481
482    let token_endpoint = authorization_server.token_endpoint.clone();
483
484    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
485        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
486
487    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
488
489    let dpop_retry_client = ClientBuilder::new(http_client.clone())
490        .with(ChainMiddleware::new(dpop_retry.clone()))
491        .build();
492
493    dpop_retry_client
494        .post(token_endpoint)
495        .header("DPoP", dpop_token.as_str())
496        .form(&params)
497        .send()
498        .await
499        .map_err(OAuthClientError::TokenHttpRequestFailed)?
500        .json()
501        .await
502        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
503}