atproto_oauth/
workflow.rs

1//! OAuth workflow implementation for AT Protocol authorization flows.
2//!
3//! This module provides a complete OAuth 2.0 authorization code flow implementation specifically
4//! designed for AT Protocol, including Pushed Authorization Requests (PAR), DPoP security, PKCE
5//! protection, and client assertion handling.
6//!
7//! ## OAuth Flow Components
8//!
//! - **`oauth_init()`**: Initiates OAuth flow with PAR request to authorization server
//! - **`oauth_complete()`**: Completes OAuth flow by exchanging authorization code for tokens
//! - **`oauth_refresh()`**: Exchanges a refresh token for a new set of tokens
11//! - **`OAuthClient`**: Client configuration with credentials and signing keys
12//! - **`OAuthRequest`**: Tracking structure for ongoing authorization requests
13//! - **`OAuthRequestState`**: Security parameters including state, nonce, and PKCE challenge
14//!
15//! ## Security Features
16//!
17//! - **PKCE**: Proof Key for Code Exchange protection against authorization code interception
18//! - **DPoP**: Demonstration of Proof-of-Possession for token binding
19//! - **Client Assertions**: JWT-based client authentication using private key signatures
20//! - **State Parameters**: CSRF protection using random state values
21//! - **Nonce Values**: Additional replay protection
22//!
23//! ## Example Usage
24//!
25//! ```rust,ignore
26//! use atproto_oauth::workflow::{oauth_init, oauth_complete, OAuthClient, OAuthRequestState};
27//! use atproto_oauth::pkce::generate;
//! use atproto_identity::key::{generate_key, KeyType};
29//!
30//! // Generate security parameters
31//! let signing_key = generate_key(KeyType::P256Private)?;
32//! let dpop_key = generate_key(KeyType::P256Private)?;
33//! let (pkce_verifier, code_challenge) = generate();
34//!
35//! // Configure OAuth client
36//! let oauth_client = OAuthClient {
37//!     redirect_uri: "https://app.example.com/callback".to_string(),
38//!     client_id: "https://app.example.com/client-metadata.json".to_string(),
39//!     private_signing_key_data: signing_key,
40//! };
41//!
42//! // Create request state
43//! let oauth_state = OAuthRequestState {
44//!     state: "random-state-value".to_string(),
45//!     nonce: "random-nonce-value".to_string(),
46//!     code_challenge,
47//!     scope: "atproto transition:generic".to_string(),
48//! };
49//!
50//! // Initiate OAuth flow
51//! let par_response = oauth_init(
52//!     &http_client,
53//!     &oauth_client,
54//!     &dpop_key,
55//!     "user.bsky.social",
56//!     &authorization_server,
57//!     &oauth_state,
58//! ).await?;
59//!
60//! // Build authorization URL
61//! let auth_url = format!(
62//!     "{}?client_id={}&request_uri={}",
63//!     authorization_server.authorization_endpoint,
64//!     oauth_client.client_id,
65//!     par_response.request_uri
66//! );
67//!
68//! // After user authorization and callback...
69//! let token_response = oauth_complete(
70//!     &http_client,
71//!     &oauth_client,
72//!     &dpop_key,
73//!     "authorization_code_from_callback",
74//!     &oauth_request,
75//!     &did_document,
76//! ).await?;
77//! ```
78
79use crate::{
80    dpop::{DpopRetry, auth_dpop},
81    errors::OAuthClientError,
82    jwt::{Claims, Header, JoseClaims, mint},
83    resources::{AuthorizationServer, pds_resources},
84};
85use atproto_identity::key::KeyData;
86use chrono::{DateTime, Utc};
87use rand::distributions::{Alphanumeric, DistString};
88use reqwest_chain::ChainMiddleware;
89use reqwest_middleware::ClientBuilder;
90
91use std::collections::HashMap;
92
93use serde::Deserialize;
94
95/// Response from a Pushed Authorization Request (PAR) endpoint.
96///
97/// Contains the request URI and expiration time returned by the authorization
98/// server after successfully processing a pushed authorization request.
99#[derive(Clone, Deserialize)]
100pub struct ParResponse {
101    /// The request URI to use in the authorization request.
102    pub request_uri: String,
103    /// The lifetime of the request URI in seconds.
104    pub expires_in: u64,
105
106    /// Additional fields returned by the authorization server.
107    #[serde(flatten)]
108    pub extra: HashMap<String, serde_json::Value>,
109}
110
/// Per-request security parameters supplied by the caller to `oauth_init`.
///
/// `oauth_init` copies `state`, `code_challenge` and `scope` verbatim into the
/// pushed authorization request form body; `nonce` is carried alongside them
/// but is not sent by `oauth_init`.
pub struct OAuthRequestState {
    /// Random, unguessable value used for CSRF protection of the redirect.
    pub state: String,
    /// Random nonce generated for this request.
    pub nonce: String,
    /// PKCE challenge (sent with `code_challenge_method=S256`) derived from the verifier.
    pub code_challenge: String,
    /// Requested scope string, e.g. `"atproto transition:generic"`.
    pub scope: String,
}
125
126/// OAuth client configuration containing essential client credentials.
127///
128/// This struct holds the client configuration needed for OAuth authorization flows,
129/// including the redirect URI, client identifier, and signing key.
130pub struct OAuthClient {
131    /// The redirect URI where the authorization server will send the user after authorization.
132    pub redirect_uri: String,
133    /// The unique client identifier for this OAuth client.
134    pub client_id: String,
135    /// The private key data used for signing client assertions.
136    pub private_signing_key_data: KeyData,
137}
138
139/// OAuth request tracking information for ongoing authorization flows.
140///
141/// This struct contains all the necessary information to track and complete
142/// an OAuth authorization request, including security parameters and timing.
143#[derive(Clone, PartialEq)]
144pub struct OAuthRequest {
145    /// The OAuth state parameter used to prevent CSRF attacks.
146    pub oauth_state: String,
147    /// The authorization server issuer identifier.
148    pub issuer: String,
149    /// The DID (Decentralized Identifier) of the user.
150    pub did: String,
151    /// The nonce value for additional security.
152    pub nonce: String,
153    /// The PKCE code verifier for this authorization request.
154    pub pkce_verifier: String,
155    /// The public key used for signing (serialized).
156    pub signing_public_key: String,
157    /// The DPoP private key (serialized).
158    pub dpop_private_key: String,
159    /// When this OAuth request was created.
160    pub created_at: DateTime<Utc>,
161    /// When this OAuth request expires.
162    pub expires_at: DateTime<Utc>,
163}
164
165impl std::fmt::Debug for OAuthRequest {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        f.debug_struct("OAuthRequest")
168            .field("oauth_state", &self.oauth_state)
169            .field("issuer", &self.issuer)
170            .field("did", &self.did)
171            .field("nonce", &self.nonce)
172            .field("pkce_verifier", &"[REDACTED]")
173            .field("signing_public_key", &self.signing_public_key)
174            .field("dpop_private_key", &"[REDACTED]")
175            .field("created_at", &self.created_at)
176            .field("expires_at", &self.expires_at)
177            .finish()
178    }
179}
180
181/// Response from the OAuth token endpoint containing access credentials.
182///
183/// This struct represents the successful response from an OAuth token exchange,
184/// containing the access token and related metadata.
185#[derive(Clone, Deserialize)]
186pub struct TokenResponse {
187    /// The access token that can be used to access protected resources.
188    pub access_token: String,
189    /// The type of token, typically "Bearer" or "DPoP".
190    pub token_type: String,
191    /// The refresh token that can be used to obtain new access tokens.
192    pub refresh_token: String,
193    /// The scope of access granted by the access token.
194    pub scope: String,
195    /// The lifetime of the access token in seconds.
196    pub expires_in: u32,
197    /// The subject identifier (usually the user's DID).
198    pub sub: Option<String>,
199
200    /// Additional fields returned by the authorization server.
201    #[serde(flatten)]
202    pub extra: HashMap<String, serde_json::Value>,
203}
204
205/// Initiates the OAuth authorization flow by making a Pushed Authorization Request (PAR).
206///
207/// This function creates a PAR request to the authorization server with the necessary
208/// OAuth parameters, DPoP proof, and client assertion. It handles the complete setup
209/// for the AT Protocol OAuth flow including PKCE and DPoP security mechanisms.
210///
211/// # Arguments
212/// * `http_client` - The HTTP client to use for making requests
213/// * `private_signing_key_data` - The private key for signing client assertions
214/// * `dpop_key_data` - The key data for creating DPoP proofs
215/// * `handle` - The user's handle for the login hint
216/// * `authorization_server` - The authorization server configuration
217/// * `oauth_request_state` - The OAuth state parameters for this request
218///
219/// # Returns
220/// A `ParResponse` containing the request URI and expiration time on success.
221///
222/// # Errors
223/// Returns `OAuthClientError` if the PAR request fails or response parsing fails.
224pub async fn oauth_init(
225    http_client: &reqwest::Client,
226    oauth_client: &OAuthClient,
227    dpop_key_data: &KeyData,
228    handle: &str,
229    authorization_server: &AuthorizationServer,
230    oauth_request_state: &OAuthRequestState,
231) -> Result<ParResponse, OAuthClientError> {
232    let par_url = authorization_server
233        .pushed_authorization_request_endpoint
234        .clone();
235
236    let scope = &oauth_request_state.scope;
237
238    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
239        .try_into()
240        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
241
242    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
243    let client_assertion_claims = Claims::new(JoseClaims {
244        issuer: Some(oauth_client.client_id.clone()),
245        subject: Some(oauth_client.client_id.clone()),
246        audience: Some(authorization_server.issuer.clone()),
247        json_web_token_id: Some(client_assertion_jti),
248        issued_at: Some(chrono::Utc::now().timestamp() as u64),
249        ..Default::default()
250    });
251
252    let client_assertion_token = mint(
253        &oauth_client.private_signing_key_data,
254        &client_assertion_header,
255        &client_assertion_claims,
256    )
257    .map_err(OAuthClientError::MintTokenFailed)?;
258
259    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &par_url)
260        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
261
262    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
263
264    let dpop_retry_client = ClientBuilder::new(http_client.clone())
265        .with(ChainMiddleware::new(dpop_retry.clone()))
266        .build();
267
268    let params = [
269        ("response_type", "code"),
270        ("code_challenge", &oauth_request_state.code_challenge),
271        ("code_challenge_method", "S256"),
272        ("client_id", oauth_client.client_id.as_str()),
273        ("state", oauth_request_state.state.as_str()),
274        ("redirect_uri", oauth_client.redirect_uri.as_str()),
275        ("scope", scope),
276        ("login_hint", handle),
277        (
278            "client_assertion_type",
279            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
280        ),
281        ("client_assertion", client_assertion_token.as_str()),
282    ];
283
284    let response = dpop_retry_client
285        .post(par_url)
286        .header("DPoP", dpop_token.as_str())
287        .form(&params)
288        .send()
289        .await
290        .map_err(OAuthClientError::PARHttpRequestFailed)?
291        .json()
292        .await
293        .map_err(OAuthClientError::PARResponseJsonParsingFailed)?;
294
295    Ok(response)
296}
297
298/// Completes the OAuth authorization flow by exchanging the authorization code for tokens.
299///
300/// This function performs the final step of the OAuth authorization code flow by
301/// exchanging the authorization code received from the callback for access and refresh tokens.
302/// It handles DPoP proof generation and client assertion creation for secure token exchange.
303///
304/// # Arguments
305/// * `http_client` - The HTTP client to use for making requests
306/// * `oauth_client` - The OAuth client configuration
307/// * `dpop_key_data` - The key data for creating DPoP proofs
308/// * `callback_code` - The authorization code received from the callback
309/// * `oauth_request` - The original OAuth request state
310/// * `document` - The identity document containing PDS endpoints
311///
312/// # Returns
313/// A `TokenResponse` containing the access token, refresh token, and metadata on success.
314///
315/// # Errors
316/// Returns `OAuthClientError` if the token exchange fails or response parsing fails.
317pub async fn oauth_complete(
318    http_client: &reqwest::Client,
319    oauth_client: &OAuthClient,
320    dpop_key_data: &KeyData,
321    callback_code: &str,
322    oauth_request: &OAuthRequest,
323    document: &atproto_identity::model::Document,
324) -> Result<TokenResponse, OAuthClientError> {
325    let pds_endpoints = document.pds_endpoints();
326    let pds_endpoint = pds_endpoints
327        .first()
328        .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
329    let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
330
331    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
332        .try_into()
333        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
334
335    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
336    let client_assertion_claims = Claims::new(JoseClaims {
337        issuer: Some(oauth_client.client_id.clone()),
338        subject: Some(oauth_client.client_id.clone()),
339        audience: Some(authorization_server.issuer.clone()),
340        json_web_token_id: Some(client_assertion_jti),
341        issued_at: Some(chrono::Utc::now().timestamp() as u64),
342        ..Default::default()
343    });
344
345    let client_assertion_token = mint(
346        &oauth_client.private_signing_key_data,
347        &client_assertion_header,
348        &client_assertion_claims,
349    )
350    .map_err(OAuthClientError::MintTokenFailed)?;
351
352    let params = [
353        ("client_id", oauth_client.client_id.as_str()),
354        ("redirect_uri", oauth_client.redirect_uri.as_str()),
355        ("grant_type", "authorization_code"),
356        ("code", callback_code),
357        ("code_verifier", &oauth_request.pkce_verifier),
358        (
359            "client_assertion_type",
360            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
361        ),
362        ("client_assertion", client_assertion_token.as_str()),
363    ];
364
365    let token_endpoint = authorization_server.token_endpoint.clone();
366
367    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
368        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
369
370    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
371
372    let dpop_retry_client = ClientBuilder::new(http_client.clone())
373        .with(ChainMiddleware::new(dpop_retry.clone()))
374        .build();
375
376    dpop_retry_client
377        .post(token_endpoint)
378        .header("DPoP", dpop_token.as_str())
379        .form(&params)
380        .send()
381        .await
382        .map_err(OAuthClientError::TokenHttpRequestFailed)?
383        .json()
384        .await
385        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
386}
387
388/// Refreshes OAuth access tokens using a refresh token.
389///
390/// This function exchanges a refresh token for new access and refresh tokens.
391/// It handles DPoP proof generation and client assertion creation for secure
392/// token refresh operations according to AT Protocol OAuth requirements.
393///
394/// # Arguments
395/// * `http_client` - The HTTP client to use for making requests
396/// * `oauth_client` - The OAuth client configuration
397/// * `dpop_key_data` - The key data for creating DPoP proofs
398/// * `refresh_token` - The refresh token to exchange for new tokens
399/// * `document` - The identity document containing PDS endpoints
400///
401/// # Returns
402/// A `TokenResponse` containing the new access token, refresh token, and metadata on success.
403///
404/// # Errors
405/// Returns `OAuthClientError` if the token refresh fails or response parsing fails.
406pub async fn oauth_refresh(
407    http_client: &reqwest::Client,
408    oauth_client: &OAuthClient,
409    dpop_key_data: &KeyData,
410    refresh_token: &str,
411    document: &atproto_identity::model::Document,
412) -> Result<TokenResponse, OAuthClientError> {
413    let pds_endpoints = document.pds_endpoints();
414    let pds_endpoint = pds_endpoints
415        .first()
416        .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
417    let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
418
419    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
420        .try_into()
421        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
422
423    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
424    let client_assertion_claims = Claims::new(JoseClaims {
425        issuer: Some(oauth_client.client_id.clone()),
426        subject: Some(oauth_client.client_id.clone()),
427        audience: Some(authorization_server.issuer.clone()),
428        json_web_token_id: Some(client_assertion_jti),
429        issued_at: Some(chrono::Utc::now().timestamp() as u64),
430        ..Default::default()
431    });
432
433    let client_assertion_token = mint(
434        &oauth_client.private_signing_key_data,
435        &client_assertion_header,
436        &client_assertion_claims,
437    )
438    .map_err(OAuthClientError::MintTokenFailed)?;
439
440    let params = [
441        ("client_id", oauth_client.client_id.as_str()),
442        ("redirect_uri", oauth_client.redirect_uri.as_str()),
443        ("grant_type", "refresh_token"),
444        ("refresh_token", refresh_token),
445        (
446            "client_assertion_type",
447            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
448        ),
449        ("client_assertion", client_assertion_token.as_str()),
450    ];
451
452    let token_endpoint = authorization_server.token_endpoint.clone();
453
454    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
455        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
456
457    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
458
459    let dpop_retry_client = ClientBuilder::new(http_client.clone())
460        .with(ChainMiddleware::new(dpop_retry.clone()))
461        .build();
462
463    dpop_retry_client
464        .post(token_endpoint)
465        .header("DPoP", dpop_token.as_str())
466        .form(&params)
467        .send()
468        .await
469        .map_err(OAuthClientError::TokenHttpRequestFailed)?
470        .json()
471        .await
472        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
473}