atproto_oauth_aip/
workflow.rs

1//! # OAuth 2.0 Workflow Implementation for AT Protocol Identity Providers
2//!
3//! This module provides a complete OAuth 2.0 authorization code flow implementation
4//! specifically designed for AT Protocol Identity Providers (AIPs). It handles the
5//! three main phases of OAuth authentication: initialization, completion, and session exchange.
6//!
7//! ## Workflow Overview
8//!
9//! The OAuth workflow consists of three main functions that handle different phases:
10//!
11//! 1. **Initialization (`oauth_init`)**: Creates a Pushed Authorization Request (PAR)
12//!    and returns the authorization URL for user consent
13//! 2. **Completion (`oauth_complete`)**: Exchanges the authorization code for access tokens
14//! 3. **Session Exchange (`session_exchange`)**: Converts OAuth tokens to AT Protocol sessions
15//!
16//! ## Security Features
17//!
18//! - **Pushed Authorization Requests (PAR)**: Enhanced security by storing authorization
19//!   parameters server-side rather than in redirect URLs
20//! - **PKCE (Proof Key for Code Exchange)**: Protection against authorization code
21//!   interception attacks
22//! - **DPoP (Demonstration of Proof-of-Possession)**: Cryptographic binding of tokens
23//!   to specific keys for enhanced security
24//!
25//! ## Usage Example
26//!
27//! ```rust,no_run
28//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
29//! use atproto_oauth_aip::workflow::{oauth_init, oauth_complete, session_exchange, OAuthClient};
30//! use atproto_oauth::resources::{AuthorizationServer, OAuthProtectedResource};
31//! use atproto_oauth::workflow::{OAuthRequestState, OAuthRequest};
32//!
33//! let http_client = reqwest::Client::new();
34//!
35//! // 1. Initialize OAuth flow
36//! let oauth_client = OAuthClient {
37//!     redirect_uri: "https://myapp.com/callback".to_string(),
38//!     client_id: "my_client_id".to_string(),
39//!     client_secret: "my_client_secret".to_string(),
40//! };
41//!
42//! # let authorization_server = AuthorizationServer {
43//! #     issuer: "https://auth.example.com".to_string(),
44//! #     authorization_endpoint: "https://auth.example.com/authorize".to_string(),
45//! #     token_endpoint: "https://auth.example.com/token".to_string(),
46//! #     pushed_authorization_request_endpoint: "https://auth.example.com/par".to_string(),
47//! #     introspection_endpoint: "".to_string(),
48//! #     scopes_supported: vec!["atproto".to_string(), "transition:generic".to_string()],
49//! #     response_types_supported: vec!["code".to_string()],
50//! #     grant_types_supported: vec!["authorization_code".to_string(), "refresh_token".to_string()],
51//! #     token_endpoint_auth_methods_supported: vec!["none".to_string(), "private_key_jwt".to_string()],
52//! #     token_endpoint_auth_signing_alg_values_supported: vec!["ES256".to_string()],
53//! #     require_pushed_authorization_requests: true,
54//! #     request_parameter_supported: false,
55//! #     code_challenge_methods_supported: vec!["S256".to_string()],
56//! #     authorization_response_iss_parameter_supported: true,
57//! #     dpop_signing_alg_values_supported: vec!["ES256".to_string()],
58//! #     client_id_metadata_document_supported: true,
59//! # };
60//!
61//! let oauth_request_state = OAuthRequestState {
62//!     state: "random-state".to_string(),
63//!     nonce: "random-nonce".to_string(),
64//!     code_challenge: "code-challenge".to_string(),
65//!     scope: "atproto transition:generic".to_string(),
66//! };
67//!
68//! let par_response = oauth_init(
69//!     &http_client,
70//!     &oauth_client,
71//!     Some("user.bsky.social"),
72//!     &authorization_server.pushed_authorization_request_endpoint,
73//!     &oauth_request_state
74//! ).await?;
75//!
//! // Redirect the user to the authorization endpoint with the request URI from
//! // `par_response`; after granting consent they return with an authorization code
77//!
78//! // 2. Complete OAuth flow
79//! # let oauth_request = OAuthRequest {
80//! #     oauth_state: "state".to_string(),
81//! #     issuer: "https://auth.example.com".to_string(),
82//! #     authorization_server: "https://auth.example.com".to_string(),
83//! #     nonce: "nonce".to_string(),
84//! #     signing_public_key: "public_key".to_string(),
85//! #     pkce_verifier: "verifier".to_string(),
86//! #     dpop_private_key: "private_key".to_string(),
87//! #     created_at: chrono::Utc::now(),
88//! #     expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
89//! # };
90//! let token_response = oauth_complete(
91//!     &http_client,
92//!     &oauth_client,
93//!     &authorization_server.token_endpoint,
94//!     "received_auth_code",
95//!     &oauth_request
96//! ).await?;
97//!
98//! // 3. Exchange for AT Protocol session
99//! # let protected_resource = OAuthProtectedResource {
100//! #     resource: "https://pds.example.com".to_string(),
101//! #     scopes_supported: vec!["atproto".to_string()],
102//! #     bearer_methods_supported: vec!["header".to_string()],
103//! #     authorization_servers: vec!["https://auth.example.com".to_string()],
104//! # };
105//! let session = session_exchange(
106//!     &http_client,
107//!     &protected_resource.resource,
108//!     &token_response.access_token
109//! ).await?;
110//! # Ok(())
111//! # }
112//! ```
113//!
114//! ## Error Handling
115//!
//! All functions return `anyhow::Result<T>`. Failures are reported as
//! `OAuthWorkflowError` values (converted into `anyhow::Error`) that identify the
//! phase of the flow that failed and whether it was a network failure, a response
//! parsing error, or an error response returned by the server.
119
120use anyhow::Result;
121use atproto_oauth::workflow::{OAuthRequest, OAuthRequestState, ParResponse, TokenResponse};
122use serde::Deserialize;
123
124use crate::errors::OAuthWorkflowError;
125
126#[cfg(feature = "zeroize")]
127use zeroize::{Zeroize, ZeroizeOnDrop};
128
/// OAuth client configuration containing essential client credentials.
///
/// The client id and secret are sent with HTTP Basic authentication on the
/// PAR request made by [`oauth_init`] and the token request made by
/// [`oauth_complete`].
///
/// With the `zeroize` feature enabled, `client_secret` is wiped from memory when
/// the value is dropped. `redirect_uri` and `client_id` are not secret and are
/// skipped.
///
/// The [`Debug`](std::fmt::Debug) implementation redacts `client_secret`, so a
/// client configuration can be logged without leaking the secret.
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
pub struct OAuthClient {
    /// The redirect URI where the authorization server will send the user after authorization.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub redirect_uri: String,

    /// The unique client identifier for this OAuth client.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub client_id: String,

    /// The client secret used for authenticating with the authorization server.
    pub client_secret: String,
}

impl std::fmt::Debug for OAuthClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthClient")
            .field("redirect_uri", &self.redirect_uri)
            .field("client_id", &self.client_id)
            // Never render the secret itself.
            .field("client_secret", &"<redacted>")
            .finish()
    }
}
143
144#[derive(Clone, Deserialize)]
145#[serde(untagged)]
146enum WrappedParResponse {
147    ParResponse(ParResponse),
148    Error {
149        error: String,
150        error_description: Option<String>,
151    },
152}
153
154/// Represents an authenticated AT Protocol session.
155///
156/// This structure contains all the information needed to make authenticated
157/// requests to AT Protocol services after a successful OAuth flow.
158#[derive(Clone, Deserialize)]
159#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
160pub struct ATProtocolSession {
161    /// The Decentralized Identifier (DID) of the authenticated user.
162    #[cfg_attr(feature = "zeroize", zeroize(skip))]
163    pub did: String,
164
165    /// The handle (username) of the authenticated user.
166    #[cfg_attr(feature = "zeroize", zeroize(skip))]
167    pub handle: String,
168
169    /// The OAuth access token for making authenticated requests.
170    pub access_token: String,
171
172    /// The type of token (typically "Bearer").
173    pub token_type: String,
174
175    /// The list of OAuth scopes granted to this session.
176    #[cfg_attr(feature = "zeroize", zeroize(skip))]
177    pub scopes: Vec<String>,
178
179    /// The Personal Data Server (PDS) endpoint URL for this user.
180    #[cfg_attr(feature = "zeroize", zeroize(skip))]
181    pub pds_endpoint: String,
182
183    /// The DPoP (Demonstration of Proof-of-Possession) key in JWK format.
184    pub dpop_key: String,
185
186    /// Unix timestamp indicating when this session expires.
187    #[cfg_attr(feature = "zeroize", zeroize(skip))]
188    pub expires_at: i64,
189}
190
191#[derive(Deserialize, Clone)]
192#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
193#[serde(untagged)]
194enum WrappedATProtocolSession {
195    ATProtocolSession(ATProtocolSession),
196
197    #[cfg_attr(feature = "zeroize", zeroize(skip))]
198    Error {
199        error: String,
200        error_description: Option<String>,
201    },
202}
203
204/// Initiates an OAuth authorization flow using Pushed Authorization Request (PAR).
205///
206/// This function starts the OAuth flow by sending a PAR request to the authorization
207/// server. PAR allows the client to push the authorization request parameters to the
208/// authorization server before redirecting the user, providing enhanced security.
209///
210/// # Arguments
211///
212/// * `http_client` - The HTTP client to use for making requests
213/// * `oauth_client` - OAuth client configuration with credentials
214/// * `handle` - Optional user handle to pre-fill in the login form
215/// * `authorization_server` - Authorization server metadata
216/// * `oauth_request_state` - OAuth request state including PKCE challenge and state
217///
218/// # Returns
219///
220/// Returns a `ParResponse` containing the request URI to redirect the user to,
221/// or an error if the PAR request fails.
222///
223/// # Example
224///
225/// ```no_run
226/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
227/// use atproto_oauth_aip::workflow::{oauth_init, OAuthClient};
228/// use atproto_oauth::workflow::OAuthRequestState;
229/// # let http_client = reqwest::Client::new();
230/// let oauth_client = OAuthClient {
231///     redirect_uri: "https://example.com/callback".to_string(),
232///     client_id: "client123".to_string(),
233///     client_secret: "secret456".to_string(),
234/// };
235/// # let authorization_server = "https://auth.example.com/par";
236/// let oauth_request_state = OAuthRequestState {
237///     state: "random-state".to_string(),
238///     nonce: "random-nonce".to_string(),
239///     code_challenge: "code-challenge".to_string(),
240///     scope: "atproto transition:generic".to_string(),
241/// };
242/// let par_response = oauth_init(
243///     &http_client,
244///     &oauth_client,
245///     Some("alice.bsky.social"),
246///     authorization_server,
247///     &oauth_request_state,
248/// ).await?;
249/// # Ok(())
250/// # }
251/// ```
252pub async fn oauth_init(
253    http_client: &reqwest::Client,
254    oauth_client: &OAuthClient,
255    login_hint: Option<&str>,
256    par_url: &str,
257    oauth_request_state: &OAuthRequestState,
258) -> Result<ParResponse> {
259    let scope = &oauth_request_state.scope;
260
261    let mut params = vec![
262        ("client_id", oauth_client.client_id.as_str()),
263        ("code_challenge_method", "S256"),
264        ("code_challenge", &oauth_request_state.code_challenge),
265        ("redirect_uri", oauth_client.redirect_uri.as_str()),
266        ("response_type", "code"),
267        ("scope", scope),
268        ("state", oauth_request_state.state.as_str()),
269    ];
270    if let Some(value) = login_hint {
271        params.push(("login_hint", value));
272    }
273
274    let response: WrappedParResponse = http_client
275        .post(par_url)
276        .form(&params)
277        .basic_auth(
278            oauth_client.client_id.as_str(),
279            Some(oauth_client.client_secret.as_str()),
280        )
281        .send()
282        .await
283        .map_err(OAuthWorkflowError::ParRequestFailed)?
284        .json()
285        .await
286        .map_err(OAuthWorkflowError::ParResponseParseFailed)?;
287
288    match response {
289        WrappedParResponse::ParResponse(value) => Ok(value),
290        WrappedParResponse::Error {
291            error,
292            error_description,
293        } => {
294            let error_message = if let Some(value) = error_description {
295                format!("{error}: {value}")
296            } else {
297                error.to_string()
298            };
299            Err(OAuthWorkflowError::ParResponseInvalid {
300                message: error_message,
301            }
302            .into())
303        }
304    }
305}
306
307/// Completes the OAuth authorization flow by exchanging the authorization code for tokens.
308///
309/// After the user has authorized the application and been redirected back with an
310/// authorization code, this function exchanges that code for access tokens using
311/// the token endpoint.
312///
313/// # Arguments
314///
315/// * `http_client` - The HTTP client to use for making requests
316/// * `oauth_client` - OAuth client configuration with credentials
317/// * `authorization_server` - Authorization server metadata
318/// * `callback_code` - The authorization code received in the callback
319/// * `oauth_request` - The original OAuth request containing the PKCE verifier
320///
321/// # Returns
322///
323/// Returns a `TokenResponse` containing the access token and other token information,
324/// or an error if the token exchange fails.
325///
326/// # Example
327///
328/// ```no_run
329/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
330/// use atproto_oauth_aip::workflow::oauth_complete;
331/// # let http_client = reqwest::Client::new();
332/// # let oauth_client = todo!();
333/// # let token_endpoint = "https://auth.example.com/token";
334/// # let oauth_request = todo!();
335/// let token_response = oauth_complete(
336///     &http_client,
337///     &oauth_client,
338///     token_endpoint,
339///     "auth_code_from_callback",
340///     &oauth_request,
341/// ).await?;
342/// println!("Access token: {}", token_response.access_token);
343/// # Ok(())
344/// # }
345/// ```
346pub async fn oauth_complete(
347    http_client: &reqwest::Client,
348    oauth_client: &OAuthClient,
349    token_endpoint: &str,
350    callback_code: &str,
351    oauth_request: &OAuthRequest,
352) -> Result<TokenResponse> {
353    let params = [
354        ("client_id", oauth_client.client_id.as_str()),
355        ("redirect_uri", oauth_client.redirect_uri.as_str()),
356        ("grant_type", "authorization_code"),
357        ("code", callback_code),
358        ("code_verifier", &oauth_request.pkce_verifier),
359    ];
360
361    http_client
362        .post(token_endpoint)
363        .basic_auth(
364            oauth_client.client_id.as_str(),
365            Some(oauth_client.client_secret.as_str()),
366        )
367        .form(&params)
368        .send()
369        .await
370        .inspect(|value| {
371            println!("{value:?}");
372        })
373        .map_err(OAuthWorkflowError::TokenRequestFailed)?
374        .json()
375        .await
376        .map_err(|e| OAuthWorkflowError::TokenResponseParseFailed(e).into())
377}
378
379/// Exchanges an OAuth access token for an AT Protocol session.
380///
381/// This function takes an OAuth access token and exchanges it for a full
382/// AT Protocol session, which includes additional information like the user's
383/// DID, handle, and PDS endpoint. This is specific to AT Protocol's OAuth
384/// implementation.
385///
386/// # Arguments
387///
388/// * `http_client` - The HTTP client to use for making requests
389/// * `protected_resource` - The protected resource metadata
390/// * `access_token` - The OAuth access token to exchange
391///
392/// # Returns
393///
394/// Returns an `ATProtocolSession` with full session information,
395/// or an error if the session exchange fails.
396///
397/// # Example
398///
399/// ```no_run
400/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
401/// use atproto_oauth_aip::workflow::session_exchange;
402/// # let http_client = reqwest::Client::new();
403/// # let protected_resource = "https://pds.example.com";
404/// # let access_token = "example_token";
405/// let session = session_exchange(
406///     &http_client,
407///     protected_resource,
408///     access_token,
409/// ).await?;
410/// println!("Authenticated as {} ({})", session.handle, session.did);
411/// println!("PDS endpoint: {}", session.pds_endpoint);
412/// # Ok(())
413/// # }
414/// ```
415pub async fn session_exchange(
416    http_client: &reqwest::Client,
417    protected_resource_base: &str,
418    access_token: &str,
419) -> Result<ATProtocolSession> {
420    let response = http_client
421        .get(format!(
422            "{}/api/atprotocol/session",
423            protected_resource_base
424        ))
425        .bearer_auth(access_token)
426        .send()
427        .await
428        .map_err(OAuthWorkflowError::SessionRequestFailed)?
429        .json()
430        .await
431        .map_err(OAuthWorkflowError::SessionResponseParseFailed)?;
432
433    match response {
434        WrappedATProtocolSession::ATProtocolSession(ref value) => Ok(value.clone()),
435        WrappedATProtocolSession::Error {
436            ref error,
437            ref error_description,
438        } => {
439            let error_message = if let Some(value) = error_description {
440                format!("{error}: {value}")
441            } else {
442                error.to_string()
443            };
444            Err(OAuthWorkflowError::SessionResponseInvalid {
445                message: error_message,
446            }
447            .into())
448        }
449    }
450}