atproto_oauth_aip/
workflow.rs

1//! OAuth 2.0 workflow for AT Protocol identity providers.
2//!
3//! Complete authorization code flow implementation with PAR initialization,
//! code exchange, and AT Protocol session establishment. The flow has three phases:
//!
5//! 1. **Initialization (`oauth_init`)**: Creates a Pushed Authorization Request (PAR)
6//!    and returns the authorization URL for user consent
7//! 2. **Completion (`oauth_complete`)**: Exchanges the authorization code for access tokens
8//! 3. **Session Exchange (`session_exchange`)**: Converts OAuth tokens to AT Protocol sessions
9//!
10//! ## Security Features
11//!
12//! - **Pushed Authorization Requests (PAR)**: Enhanced security by storing authorization
13//!   parameters server-side rather than in redirect URLs
14//! - **PKCE (Proof Key for Code Exchange)**: Protection against authorization code
15//!   interception attacks
16//! - **DPoP (Demonstration of Proof-of-Possession)**: Cryptographic binding of tokens
17//!   to specific keys for enhanced security
18//!
19//! ## Usage Example
20//!
21//! ```rust,no_run
22//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
23//! use atproto_oauth_aip::workflow::{oauth_init, oauth_complete, session_exchange, OAuthClient};
24//! use atproto_oauth::resources::{AuthorizationServer, OAuthProtectedResource};
25//! use atproto_oauth::workflow::{OAuthRequestState, OAuthRequest};
26//!
27//! let http_client = reqwest::Client::new();
28//!
29//! // 1. Initialize OAuth flow
30//! let oauth_client = OAuthClient {
31//!     redirect_uri: "https://myapp.com/callback".to_string(),
32//!     client_id: "my_client_id".to_string(),
33//!     client_secret: "my_client_secret".to_string(),
34//! };
35//!
36//! # let authorization_server = AuthorizationServer {
37//! #     issuer: "https://auth.example.com".to_string(),
38//! #     authorization_endpoint: "https://auth.example.com/authorize".to_string(),
39//! #     token_endpoint: "https://auth.example.com/token".to_string(),
40//! #     pushed_authorization_request_endpoint: "https://auth.example.com/par".to_string(),
41//! #     introspection_endpoint: "".to_string(),
42//! #     scopes_supported: vec!["atproto".to_string(), "transition:generic".to_string()],
43//! #     response_types_supported: vec!["code".to_string()],
44//! #     grant_types_supported: vec!["authorization_code".to_string(), "refresh_token".to_string()],
45//! #     token_endpoint_auth_methods_supported: vec!["none".to_string(), "private_key_jwt".to_string()],
46//! #     token_endpoint_auth_signing_alg_values_supported: vec!["ES256".to_string()],
47//! #     require_pushed_authorization_requests: true,
48//! #     request_parameter_supported: false,
49//! #     code_challenge_methods_supported: vec!["S256".to_string()],
50//! #     authorization_response_iss_parameter_supported: true,
51//! #     dpop_signing_alg_values_supported: vec!["ES256".to_string()],
52//! #     client_id_metadata_document_supported: true,
53//! # };
54//!
55//! let oauth_request_state = OAuthRequestState {
56//!     state: "random-state".to_string(),
57//!     nonce: "random-nonce".to_string(),
58//!     code_challenge: "code-challenge".to_string(),
59//!     scope: "atproto transition:generic".to_string(),
60//! };
61//!
62//! let par_response = oauth_init(
63//!     &http_client,
64//!     &oauth_client,
65//!     Some("user.bsky.social"),
66//!     &authorization_server.pushed_authorization_request_endpoint,
67//!     &oauth_request_state
68//! ).await?;
69//!
//! // Redirect the user using the request URI in `par_response`; after granting
//! // consent they return to the redirect URI with an authorization code.
71//!
72//! // 2. Complete OAuth flow
73//! # let oauth_request = OAuthRequest {
74//! #     oauth_state: "state".to_string(),
75//! #     issuer: "https://auth.example.com".to_string(),
76//! #     authorization_server: "https://auth.example.com".to_string(),
77//! #     nonce: "nonce".to_string(),
78//! #     signing_public_key: "public_key".to_string(),
79//! #     pkce_verifier: "verifier".to_string(),
80//! #     dpop_private_key: "private_key".to_string(),
81//! #     created_at: chrono::Utc::now(),
82//! #     expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
83//! # };
84//! let token_response = oauth_complete(
85//!     &http_client,
86//!     &oauth_client,
87//!     &authorization_server.token_endpoint,
88//!     "received_auth_code",
89//!     &oauth_request
90//! ).await?;
91//!
92//! // 3. Exchange for AT Protocol session
93//! # let protected_resource = OAuthProtectedResource {
94//! #     resource: "https://pds.example.com".to_string(),
95//! #     scopes_supported: vec!["atproto".to_string()],
96//! #     bearer_methods_supported: vec!["header".to_string()],
97//! #     authorization_servers: vec!["https://auth.example.com".to_string()],
98//! # };
99//! let session = session_exchange(
100//!     &http_client,
101//!     &protected_resource.resource,
102//!     &token_response.access_token
103//! ).await?;
104//! # Ok(())
105//! # }
106//! ```
107//!
108//! ## Error Handling
109//!
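//! All functions return `anyhow::Result<T>`. Failures are raised as `OAuthWorkflowError`
//! variants (recoverable with `anyhow::Error::downcast_ref`). Each variant identifies the
//! phase of the flow that failed and the kind of failure: a network error, a response
//! parsing error, or an error object returned by the server.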
113
114use anyhow::Result;
115use atproto_oauth::workflow::{OAuthRequest, OAuthRequestState, ParResponse, TokenResponse};
116use serde::Deserialize;
117
118use crate::errors::OAuthWorkflowError;
119
120#[cfg(feature = "zeroize")]
121use zeroize::{Zeroize, ZeroizeOnDrop};
122
/// Credentials and redirect settings that identify this application to an
/// OAuth authorization server.
///
/// With the `zeroize` feature enabled, `client_secret` is wiped from memory when
/// the value is dropped; `redirect_uri` and `client_id` are not sensitive and are
/// skipped.
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
pub struct OAuthClient {
    /// URI the authorization server sends the user back to once the
    /// authorization step is finished.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub redirect_uri: String,

    /// Identifier under which this client is known to the authorization server.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub client_id: String,

    /// Secret used, together with `client_id`, for HTTP basic authentication
    /// against the authorization server.
    pub client_secret: String,
}
137
138#[derive(Clone, Deserialize)]
139#[serde(untagged)]
140enum WrappedParResponse {
141    ParResponse(ParResponse),
142    Error {
143        error: String,
144        error_description: Option<String>,
145    },
146}
147
148/// Represents an authenticated AT Protocol session.
149///
150/// This structure contains all the information needed to make authenticated
151/// requests to AT Protocol services after a successful OAuth flow.
152#[derive(Clone, Deserialize)]
153#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
154pub struct ATProtocolSession {
155    /// The Decentralized Identifier (DID) of the authenticated user.
156    #[cfg_attr(feature = "zeroize", zeroize(skip))]
157    pub did: String,
158
159    /// The handle (username) of the authenticated user.
160    #[cfg_attr(feature = "zeroize", zeroize(skip))]
161    pub handle: String,
162
163    /// The OAuth access token for making authenticated requests.
164    pub access_token: String,
165
166    /// The type of token (typically "Bearer").
167    pub token_type: String,
168
169    /// The list of OAuth scopes granted to this session.
170    #[cfg_attr(feature = "zeroize", zeroize(skip))]
171    pub scopes: Vec<String>,
172
173    /// The Personal Data Server (PDS) endpoint URL for this user.
174    #[cfg_attr(feature = "zeroize", zeroize(skip))]
175    pub pds_endpoint: String,
176
177    /// The DPoP (Demonstration of Proof-of-Possession) key in JWK format.
178    pub dpop_key: String,
179
180    /// Unix timestamp indicating when this session expires.
181    #[cfg_attr(feature = "zeroize", zeroize(skip))]
182    pub expires_at: i64,
183}
184
185#[derive(Deserialize, Clone)]
186#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
187#[serde(untagged)]
188enum WrappedATProtocolSession {
189    ATProtocolSession(ATProtocolSession),
190
191    #[cfg_attr(feature = "zeroize", zeroize(skip))]
192    Error {
193        error: String,
194        error_description: Option<String>,
195    },
196}
197
198/// Initiates an OAuth authorization flow using Pushed Authorization Request (PAR).
199///
200/// This function starts the OAuth flow by sending a PAR request to the authorization
201/// server. PAR allows the client to push the authorization request parameters to the
202/// authorization server before redirecting the user, providing enhanced security.
203///
204/// # Arguments
205///
206/// * `http_client` - The HTTP client to use for making requests
207/// * `oauth_client` - OAuth client configuration with credentials
208/// * `handle` - Optional user handle to pre-fill in the login form
209/// * `authorization_server` - Authorization server metadata
210/// * `oauth_request_state` - OAuth request state including PKCE challenge and state
211///
212/// # Returns
213///
214/// Returns a `ParResponse` containing the request URI to redirect the user to,
215/// or an error if the PAR request fails.
216///
217/// # Example
218///
219/// ```no_run
220/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
221/// use atproto_oauth_aip::workflow::{oauth_init, OAuthClient};
222/// use atproto_oauth::workflow::OAuthRequestState;
223/// # let http_client = reqwest::Client::new();
224/// let oauth_client = OAuthClient {
225///     redirect_uri: "https://example.com/callback".to_string(),
226///     client_id: "client123".to_string(),
227///     client_secret: "secret456".to_string(),
228/// };
229/// # let authorization_server = "https://auth.example.com/par";
230/// let oauth_request_state = OAuthRequestState {
231///     state: "random-state".to_string(),
232///     nonce: "random-nonce".to_string(),
233///     code_challenge: "code-challenge".to_string(),
234///     scope: "atproto transition:generic".to_string(),
235/// };
236/// let par_response = oauth_init(
237///     &http_client,
238///     &oauth_client,
239///     Some("alice.bsky.social"),
240///     authorization_server,
241///     &oauth_request_state,
242/// ).await?;
243/// # Ok(())
244/// # }
245/// ```
246pub async fn oauth_init(
247    http_client: &reqwest::Client,
248    oauth_client: &OAuthClient,
249    login_hint: Option<&str>,
250    par_url: &str,
251    oauth_request_state: &OAuthRequestState,
252) -> Result<ParResponse> {
253    let scope = &oauth_request_state.scope;
254
255    let mut params = vec![
256        ("client_id", oauth_client.client_id.as_str()),
257        ("code_challenge_method", "S256"),
258        ("code_challenge", &oauth_request_state.code_challenge),
259        ("redirect_uri", oauth_client.redirect_uri.as_str()),
260        ("response_type", "code"),
261        ("scope", scope),
262        ("state", oauth_request_state.state.as_str()),
263    ];
264    if let Some(value) = login_hint {
265        params.push(("login_hint", value));
266    }
267
268    let response: WrappedParResponse = http_client
269        .post(par_url)
270        .form(&params)
271        .basic_auth(
272            oauth_client.client_id.as_str(),
273            Some(oauth_client.client_secret.as_str()),
274        )
275        .send()
276        .await
277        .map_err(OAuthWorkflowError::ParRequestFailed)?
278        .json()
279        .await
280        .map_err(OAuthWorkflowError::ParResponseParseFailed)?;
281
282    match response {
283        WrappedParResponse::ParResponse(value) => Ok(value),
284        WrappedParResponse::Error {
285            error,
286            error_description,
287        } => {
288            let error_message = if let Some(value) = error_description {
289                format!("{error}: {value}")
290            } else {
291                error.to_string()
292            };
293            Err(OAuthWorkflowError::ParResponseInvalid {
294                message: error_message,
295            }
296            .into())
297        }
298    }
299}
300
301/// Completes the OAuth authorization flow by exchanging the authorization code for tokens.
302///
303/// After the user has authorized the application and been redirected back with an
304/// authorization code, this function exchanges that code for access tokens using
305/// the token endpoint.
306///
307/// # Arguments
308///
309/// * `http_client` - The HTTP client to use for making requests
310/// * `oauth_client` - OAuth client configuration with credentials
311/// * `authorization_server` - Authorization server metadata
312/// * `callback_code` - The authorization code received in the callback
313/// * `oauth_request` - The original OAuth request containing the PKCE verifier
314///
315/// # Returns
316///
317/// Returns a `TokenResponse` containing the access token and other token information,
318/// or an error if the token exchange fails.
319///
320/// # Example
321///
322/// ```no_run
323/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
324/// use atproto_oauth_aip::workflow::oauth_complete;
325/// # let http_client = reqwest::Client::new();
326/// # let oauth_client = todo!();
327/// # let token_endpoint = "https://auth.example.com/token";
328/// # let oauth_request = todo!();
329/// let token_response = oauth_complete(
330///     &http_client,
331///     &oauth_client,
332///     token_endpoint,
333///     "auth_code_from_callback",
334///     &oauth_request,
335/// ).await?;
336/// println!("Access token: {}", token_response.access_token);
337/// # Ok(())
338/// # }
339/// ```
340pub async fn oauth_complete(
341    http_client: &reqwest::Client,
342    oauth_client: &OAuthClient,
343    token_endpoint: &str,
344    callback_code: &str,
345    oauth_request: &OAuthRequest,
346) -> Result<TokenResponse> {
347    let params = [
348        ("client_id", oauth_client.client_id.as_str()),
349        ("redirect_uri", oauth_client.redirect_uri.as_str()),
350        ("grant_type", "authorization_code"),
351        ("code", callback_code),
352        ("code_verifier", &oauth_request.pkce_verifier),
353    ];
354
355    http_client
356        .post(token_endpoint)
357        .basic_auth(
358            oauth_client.client_id.as_str(),
359            Some(oauth_client.client_secret.as_str()),
360        )
361        .form(&params)
362        .send()
363        .await
364        .inspect(|value| {
365            println!("{value:?}");
366        })
367        .map_err(OAuthWorkflowError::TokenRequestFailed)?
368        .json()
369        .await
370        .map_err(|e| OAuthWorkflowError::TokenResponseParseFailed(e).into())
371}
372
373/// Exchanges an OAuth access token for an AT Protocol session.
374///
375/// This function takes an OAuth access token and exchanges it for a full
376/// AT Protocol session, which includes additional information like the user's
377/// DID, handle, and PDS endpoint. This is specific to AT Protocol's OAuth
378/// implementation.
379///
380/// # Arguments
381///
382/// * `http_client` - The HTTP client to use for making requests
383/// * `protected_resource` - The protected resource metadata
384/// * `access_token` - The OAuth access token to exchange
385///
386/// # Returns
387///
388/// Returns an `ATProtocolSession` with full session information,
389/// or an error if the session exchange fails.
390///
391/// # Example
392///
393/// ```no_run
394/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
395/// use atproto_oauth_aip::workflow::session_exchange;
396/// # let http_client = reqwest::Client::new();
397/// # let protected_resource = "https://pds.example.com";
398/// # let access_token = "example_token";
399/// let session = session_exchange(
400///     &http_client,
401///     protected_resource,
402///     access_token,
403/// ).await?;
404/// println!("Authenticated as {} ({})", session.handle, session.did);
405/// println!("PDS endpoint: {}", session.pds_endpoint);
406/// # Ok(())
407/// # }
408/// ```
409pub async fn session_exchange(
410    http_client: &reqwest::Client,
411    protected_resource_base: &str,
412    access_token: &str,
413) -> Result<ATProtocolSession> {
414    let response = http_client
415        .get(format!(
416            "{}/api/atprotocol/session",
417            protected_resource_base
418        ))
419        .bearer_auth(access_token)
420        .send()
421        .await
422        .map_err(OAuthWorkflowError::SessionRequestFailed)?
423        .json()
424        .await
425        .map_err(OAuthWorkflowError::SessionResponseParseFailed)?;
426
427    match response {
428        WrappedATProtocolSession::ATProtocolSession(ref value) => Ok(value.clone()),
429        WrappedATProtocolSession::Error {
430            ref error,
431            ref error_description,
432        } => {
433            let error_message = if let Some(value) = error_description {
434                format!("{error}: {value}")
435            } else {
436                error.to_string()
437            };
438            Err(OAuthWorkflowError::SessionResponseInvalid {
439                message: error_message,
440            }
441            .into())
442        }
443    }
444}