atproto_oauth_aip/
workflow.rs

1//! OAuth 2.0 workflow for AT Protocol identity providers.
2//!
//! Complete authorization code flow implementation with PAR initialization,
//! code exchange, and AT Protocol session establishment. The flow has three
//! phases:
//!
5//! 1. **Initialization (`oauth_init`)**: Creates a Pushed Authorization Request (PAR)
6//!    and returns the authorization URL for user consent
7//! 2. **Completion (`oauth_complete`)**: Exchanges the authorization code for access tokens
8//! 3. **Session Exchange (`session_exchange`)**: Converts OAuth tokens to AT Protocol sessions
9//!
10//! ## Security Features
11//!
12//! - **Pushed Authorization Requests (PAR)**: Enhanced security by storing authorization
13//!   parameters server-side rather than in redirect URLs
14//! - **PKCE (Proof Key for Code Exchange)**: Protection against authorization code
15//!   interception attacks
16//! - **DPoP (Demonstration of Proof-of-Possession)**: Cryptographic binding of tokens
17//!   to specific keys for enhanced security
18//!
19//! ## Usage Example
20//!
21//! ```rust,no_run
22//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
23//! use atproto_oauth_aip::workflow::{oauth_init, oauth_complete, session_exchange, OAuthClient};
24//! use atproto_oauth::resources::{AuthorizationServer, OAuthProtectedResource};
25//! use atproto_oauth::workflow::{OAuthRequestState, OAuthRequest};
26//!
27//! let http_client = reqwest::Client::new();
28//!
29//! // 1. Initialize OAuth flow
30//! let oauth_client = OAuthClient {
31//!     redirect_uri: "https://myapp.com/callback".to_string(),
32//!     client_id: "my_client_id".to_string(),
33//!     client_secret: "my_client_secret".to_string(),
34//! };
35//!
36//! # let authorization_server = AuthorizationServer {
37//! #     issuer: "https://auth.example.com".to_string(),
38//! #     authorization_endpoint: "https://auth.example.com/authorize".to_string(),
39//! #     token_endpoint: "https://auth.example.com/token".to_string(),
40//! #     pushed_authorization_request_endpoint: "https://auth.example.com/par".to_string(),
41//! #     introspection_endpoint: "".to_string(),
42//! #     scopes_supported: vec!["atproto".to_string(), "transition:generic".to_string()],
43//! #     response_types_supported: vec!["code".to_string()],
44//! #     grant_types_supported: vec!["authorization_code".to_string(), "refresh_token".to_string()],
45//! #     token_endpoint_auth_methods_supported: vec!["none".to_string(), "private_key_jwt".to_string()],
46//! #     token_endpoint_auth_signing_alg_values_supported: vec!["ES256".to_string()],
47//! #     require_pushed_authorization_requests: true,
48//! #     request_parameter_supported: false,
49//! #     code_challenge_methods_supported: vec!["S256".to_string()],
50//! #     authorization_response_iss_parameter_supported: true,
51//! #     dpop_signing_alg_values_supported: vec!["ES256".to_string()],
52//! #     client_id_metadata_document_supported: true,
53//! # };
54//!
55//! let oauth_request_state = OAuthRequestState {
56//!     state: "random-state".to_string(),
57//!     nonce: "random-nonce".to_string(),
58//!     code_challenge: "code-challenge".to_string(),
59//!     scope: "atproto transition:generic".to_string(),
60//! };
61//!
62//! let par_response = oauth_init(
63//!     &http_client,
64//!     &oauth_client,
65//!     Some("user.bsky.social"),
66//!     &authorization_server.pushed_authorization_request_endpoint,
67//!     &oauth_request_state
68//! ).await?;
69//!
//! // Send the user to the authorization server with the request URI from `par_response`;
//! // after granting consent they are redirected back with an authorization code.
71//!
72//! // 2. Complete OAuth flow
73//! # let oauth_request = OAuthRequest {
74//! #     oauth_state: "state".to_string(),
75//! #     issuer: "https://auth.example.com".to_string(),
76//! #     authorization_server: "https://auth.example.com".to_string(),
77//! #     nonce: "nonce".to_string(),
78//! #     signing_public_key: "public_key".to_string(),
79//! #     pkce_verifier: "verifier".to_string(),
80//! #     dpop_private_key: "private_key".to_string(),
81//! #     created_at: chrono::Utc::now(),
82//! #     expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
83//! # };
84//! let token_response = oauth_complete(
85//!     &http_client,
86//!     &oauth_client,
87//!     &authorization_server.token_endpoint,
88//!     "received_auth_code",
89//!     &oauth_request
90//! ).await?;
91//!
92//! // 3. Exchange for AT Protocol session
93//! # let protected_resource = OAuthProtectedResource {
94//! #     resource: "https://pds.example.com".to_string(),
95//! #     scopes_supported: vec!["atproto".to_string()],
96//! #     bearer_methods_supported: vec!["header".to_string()],
97//! #     authorization_servers: vec!["https://auth.example.com".to_string()],
98//! # };
99//! let session = session_exchange(
100//!     &http_client,
101//!     &protected_resource.resource,
102//!     &token_response.access_token
103//! ).await?;
104//! # Ok(())
105//! # }
106//! ```
107//!
108//! ## Error Handling
109//!
//! All functions return `anyhow::Result<T>`. Failures originate as `OAuthWorkflowError`
//! values (retrievable with `anyhow::Error::downcast_ref`) that identify the failing phase
//! of the flow, covering network failures, response parsing errors, and errors reported
//! by the server.
113
114use anyhow::Result;
115use atproto_identity::url::URLBuilder;
116use atproto_oauth::{
117    jwk::WrappedJsonWebKey,
118    workflow::{OAuthRequest, OAuthRequestState, ParResponse, TokenResponse},
119};
120use serde::Deserialize;
121
122use crate::errors::OAuthWorkflowError;
123
124#[cfg(feature = "zeroize")]
125use zeroize::{Zeroize, ZeroizeOnDrop};
126
127/// OAuth client configuration containing essential client credentials.
128#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
129pub struct OAuthClient {
130    /// The redirect URI where the authorization server will send the user after authorization.
131    #[cfg_attr(feature = "zeroize", zeroize(skip))]
132    pub redirect_uri: String,
133
134    /// The unique client identifier for this OAuth client.
135    #[cfg_attr(feature = "zeroize", zeroize(skip))]
136    pub client_id: String,
137
138    /// The client secret used for authenticating with the authorization server.
139    pub client_secret: String,
140}
141
142#[derive(Clone, Deserialize)]
143#[serde(untagged)]
144enum WrappedParResponse {
145    ParResponse(ParResponse),
146    Error {
147        error: String,
148        error_description: Option<String>,
149    },
150}
151
152/// Represents an authenticated AT Protocol session.
153///
154/// This structure contains all the information needed to make authenticated
155/// requests to AT Protocol services after a successful OAuth flow.
156#[derive(Clone, Deserialize)]
157#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
158pub struct ATProtocolSession {
159    /// The Decentralized Identifier (DID) of the authenticated user.
160    #[cfg_attr(feature = "zeroize", zeroize(skip))]
161    pub did: String,
162
163    /// The handle (username) of the authenticated user.
164    #[cfg_attr(feature = "zeroize", zeroize(skip))]
165    pub handle: String,
166
167    /// The OAuth access token for making authenticated requests.
168    #[cfg_attr(feature = "zeroize", zeroize(skip))]
169    pub access_token: String,
170
171    /// The type of token (typically "Bearer").
172    #[cfg_attr(feature = "zeroize", zeroize(skip))]
173    pub token_type: String,
174
175    /// The list of OAuth scopes granted to this session.
176    #[cfg_attr(feature = "zeroize", zeroize(skip))]
177    pub scopes: Vec<String>,
178
179    /// The Personal Data Server (PDS) endpoint URL for this user.
180    #[cfg_attr(feature = "zeroize", zeroize(skip))]
181    pub pds_endpoint: String,
182
183    /// The DPoP (Demonstration of Proof-of-Possession) key in string serialized format.
184    #[cfg_attr(feature = "zeroize", zeroize(skip))]
185    pub dpop_key: Option<String>,
186
187    /// The DPoP (Demonstration of Proof-of-Possession) key in JWK format.
188    #[cfg_attr(feature = "zeroize", zeroize(skip))]
189    pub dpop_jwk: Option<WrappedJsonWebKey>,
190
191    /// Unix timestamp indicating when this session expires.
192    #[cfg_attr(feature = "zeroize", zeroize(skip))]
193    pub expires_at: i64,
194}
195
196#[derive(Deserialize, Clone)]
197#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
198#[serde(untagged)]
199enum WrappedATProtocolSession {
200    ATProtocolSession(Box<ATProtocolSession>),
201
202    #[cfg_attr(feature = "zeroize", zeroize(skip))]
203    Error {
204        error: String,
205        error_description: Option<String>,
206    },
207}
208
209/// Initiates an OAuth authorization flow using Pushed Authorization Request (PAR).
210///
211/// This function starts the OAuth flow by sending a PAR request to the authorization
212/// server. PAR allows the client to push the authorization request parameters to the
213/// authorization server before redirecting the user, providing enhanced security.
214///
215/// # Arguments
216///
217/// * `http_client` - The HTTP client to use for making requests
218/// * `oauth_client` - OAuth client configuration with credentials
219/// * `handle` - Optional user handle to pre-fill in the login form
220/// * `authorization_server` - Authorization server metadata
221/// * `oauth_request_state` - OAuth request state including PKCE challenge and state
222///
223/// # Returns
224///
225/// Returns a `ParResponse` containing the request URI to redirect the user to,
226/// or an error if the PAR request fails.
227///
228/// # Example
229///
230/// ```no_run
231/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
232/// use atproto_oauth_aip::workflow::{oauth_init, OAuthClient};
233/// use atproto_oauth::workflow::OAuthRequestState;
234/// # let http_client = reqwest::Client::new();
235/// let oauth_client = OAuthClient {
236///     redirect_uri: "https://example.com/callback".to_string(),
237///     client_id: "client123".to_string(),
238///     client_secret: "secret456".to_string(),
239/// };
240/// # let authorization_server = "https://auth.example.com/par";
241/// let oauth_request_state = OAuthRequestState {
242///     state: "random-state".to_string(),
243///     nonce: "random-nonce".to_string(),
244///     code_challenge: "code-challenge".to_string(),
245///     scope: "atproto transition:generic".to_string(),
246/// };
247/// let par_response = oauth_init(
248///     &http_client,
249///     &oauth_client,
250///     Some("alice.bsky.social"),
251///     authorization_server,
252///     &oauth_request_state,
253/// ).await?;
254/// # Ok(())
255/// # }
256/// ```
257pub async fn oauth_init(
258    http_client: &reqwest::Client,
259    oauth_client: &OAuthClient,
260    login_hint: Option<&str>,
261    par_url: &str,
262    oauth_request_state: &OAuthRequestState,
263) -> Result<ParResponse> {
264    let scope = &oauth_request_state.scope;
265
266    let mut params = vec![
267        ("client_id", oauth_client.client_id.as_str()),
268        ("code_challenge_method", "S256"),
269        ("code_challenge", &oauth_request_state.code_challenge),
270        ("redirect_uri", oauth_client.redirect_uri.as_str()),
271        ("response_type", "code"),
272        ("scope", scope),
273        ("state", oauth_request_state.state.as_str()),
274    ];
275    if let Some(value) = login_hint {
276        params.push(("login_hint", value));
277    }
278
279    let response: WrappedParResponse = http_client
280        .post(par_url)
281        .form(&params)
282        .basic_auth(
283            oauth_client.client_id.as_str(),
284            Some(oauth_client.client_secret.as_str()),
285        )
286        .send()
287        .await
288        .map_err(OAuthWorkflowError::ParRequestFailed)?
289        .json()
290        .await
291        .map_err(OAuthWorkflowError::ParResponseParseFailed)?;
292
293    match response {
294        WrappedParResponse::ParResponse(value) => Ok(value),
295        WrappedParResponse::Error {
296            error,
297            error_description,
298        } => {
299            let error_message = if let Some(value) = error_description {
300                format!("{error}: {value}")
301            } else {
302                error.to_string()
303            };
304            Err(OAuthWorkflowError::ParResponseInvalid {
305                message: error_message,
306            }
307            .into())
308        }
309    }
310}
311
312/// Completes the OAuth authorization flow by exchanging the authorization code for tokens.
313///
314/// After the user has authorized the application and been redirected back with an
315/// authorization code, this function exchanges that code for access tokens using
316/// the token endpoint.
317///
318/// # Arguments
319///
320/// * `http_client` - The HTTP client to use for making requests
321/// * `oauth_client` - OAuth client configuration with credentials
322/// * `authorization_server` - Authorization server metadata
323/// * `callback_code` - The authorization code received in the callback
324/// * `oauth_request` - The original OAuth request containing the PKCE verifier
325///
326/// # Returns
327///
328/// Returns a `TokenResponse` containing the access token and other token information,
329/// or an error if the token exchange fails.
330///
331/// # Example
332///
333/// ```no_run
334/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
335/// use atproto_oauth_aip::workflow::oauth_complete;
336/// # let http_client = reqwest::Client::new();
337/// # let oauth_client = todo!();
338/// # let token_endpoint = "https://auth.example.com/token";
339/// # let oauth_request = todo!();
340/// let token_response = oauth_complete(
341///     &http_client,
342///     &oauth_client,
343///     token_endpoint,
344///     "auth_code_from_callback",
345///     &oauth_request,
346/// ).await?;
347/// println!("Access token: {}", token_response.access_token);
348/// # Ok(())
349/// # }
350/// ```
351pub async fn oauth_complete(
352    http_client: &reqwest::Client,
353    oauth_client: &OAuthClient,
354    token_endpoint: &str,
355    callback_code: &str,
356    oauth_request: &OAuthRequest,
357) -> Result<TokenResponse> {
358    let params = [
359        ("client_id", oauth_client.client_id.as_str()),
360        ("redirect_uri", oauth_client.redirect_uri.as_str()),
361        ("grant_type", "authorization_code"),
362        ("code", callback_code),
363        ("code_verifier", &oauth_request.pkce_verifier),
364    ];
365
366    http_client
367        .post(token_endpoint)
368        .basic_auth(
369            oauth_client.client_id.as_str(),
370            Some(oauth_client.client_secret.as_str()),
371        )
372        .form(&params)
373        .send()
374        .await
375        .inspect(|value| {
376            println!("{value:?}");
377        })
378        .map_err(OAuthWorkflowError::TokenRequestFailed)?
379        .json()
380        .await
381        .map_err(|e| OAuthWorkflowError::TokenResponseParseFailed(e).into())
382}
383
384/// Exchanges an OAuth access token for an AT Protocol session.
385///
386/// This function takes an OAuth access token and exchanges it for a full
387/// AT Protocol session, which includes additional information like the user's
388/// DID, handle, and PDS endpoint. This is specific to AT Protocol's OAuth
389/// implementation.
390///
391/// This is a convenience function that calls `session_exchange_with_options`
392/// with no additional options.
393///
394/// # Arguments
395///
396/// * `http_client` - The HTTP client to use for making requests
397/// * `protected_resource_base` - The base URL of the protected resource (PDS)
398/// * `access_token` - The OAuth access token to exchange
399///
400/// # Returns
401///
402/// Returns an `ATProtocolSession` with full session information,
403/// or an error if the session exchange fails.
404///
405/// # Example
406///
407/// ```no_run
408/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
409/// use atproto_oauth_aip::workflow::session_exchange;
410/// # let http_client = reqwest::Client::new();
411/// # let protected_resource = "https://pds.example.com";
412/// # let access_token = "example_token";
413/// let session = session_exchange(
414///     &http_client,
415///     protected_resource,
416///     access_token,
417/// ).await?;
418/// println!("Authenticated as {} ({})", session.handle, session.did);
419/// println!("PDS endpoint: {}", session.pds_endpoint);
420/// # Ok(())
421/// # }
422/// ```
423pub async fn session_exchange(
424    http_client: &reqwest::Client,
425    protected_resource_base: &str,
426    access_token: &str,
427) -> Result<ATProtocolSession> {
428    session_exchange_with_options(
429        http_client,
430        protected_resource_base,
431        access_token,
432        &None,
433        &None,
434    )
435    .await
436}
437
438/// Exchanges an OAuth access token for an AT Protocol session with additional options.
439///
440/// This function takes an OAuth access token and exchanges it for a full
441/// AT Protocol session, which includes additional information like the user's
442/// DID, handle, and PDS endpoint. This version allows specifying additional
443/// options for the session exchange.
444///
445/// # Arguments
446///
447/// * `http_client` - The HTTP client to use for making requests
448/// * `protected_resource_base` - The base URL of the protected resource (PDS)
449/// * `access_token` - The OAuth access token to exchange
450/// * `access_token_type` - Optional token type ("oauth_session", "app_password_session", or "best")
451/// * `subject` - Optional subject (DID) to specify which user's session to retrieve
452///
453/// # Returns
454///
455/// Returns an `ATProtocolSession` with full session information,
456/// or an error if the session exchange fails.
457///
458/// # Example
459///
460/// ```no_run
461/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
462/// use atproto_oauth_aip::workflow::session_exchange_with_options;
463/// # let http_client = reqwest::Client::new();
464/// # let protected_resource = "https://pds.example.com";
465/// # let access_token = "example_token";
466/// // Basic usage without options
467/// let session = session_exchange_with_options(
468///     &http_client,
469///     protected_resource,
470///     access_token,
471///     &None,
472///     &None,
473/// ).await?;
474/// # Ok(())
475/// # }
476/// ```
477///
478/// # Example with access_token_type
479///
480/// ```no_run
481/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
482/// use atproto_oauth_aip::workflow::session_exchange_with_options;
483/// # let http_client = reqwest::Client::new();
484/// # let protected_resource = "https://pds.example.com";
485/// # let access_token = "example_token";
486/// // Specify the token type
487/// let session = session_exchange_with_options(
488///     &http_client,
489///     protected_resource,
490///     access_token,
491///     &Some("oauth_session"),
492///     &None,
493/// ).await?;
494/// # Ok(())
495/// # }
496/// ```
497///
498/// # Example with subject
499///
500/// ```no_run
501/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
502/// use atproto_oauth_aip::workflow::session_exchange_with_options;
503/// # let http_client = reqwest::Client::new();
504/// # let protected_resource = "https://pds.example.com";
505/// # let access_token = "example_token";
506/// # let user_did = "did:plc:example123";
507/// // Specify both token type and subject
508/// let session = session_exchange_with_options(
509///     &http_client,
510///     protected_resource,
511///     access_token,
512///     &Some("app_password_session"),
513///     &Some(user_did),
514/// ).await?;
515/// # Ok(())
516/// # }
517/// ```
518pub async fn session_exchange_with_options(
519    http_client: &reqwest::Client,
520    protected_resource_base: &str,
521    access_token: &str,
522    access_token_type: &Option<&str>,
523    subject: &Option<&str>,
524) -> Result<ATProtocolSession> {
525    let mut url_builder = URLBuilder::new(protected_resource_base);
526    url_builder.path("/api/atprotocol/session");
527
528    if let Some(value) = access_token_type {
529        url_builder.param("access_token_type", value);
530    }
531
532    if let Some(value) = subject {
533        url_builder.param("sub", value);
534    }
535
536    let url = url_builder.build();
537
538    let response = http_client
539        .get(url)
540        .bearer_auth(access_token)
541        .send()
542        .await
543        .map_err(OAuthWorkflowError::SessionRequestFailed)?
544        .json()
545        .await
546        .map_err(OAuthWorkflowError::SessionResponseParseFailed)?;
547
548    match response {
549        WrappedATProtocolSession::ATProtocolSession(ref value) => Ok(*value.clone()),
550        WrappedATProtocolSession::Error {
551            ref error,
552            ref error_description,
553        } => {
554            let error_message = if let Some(value) = error_description {
555                format!("{error}: {value}")
556            } else {
557                error.to_string()
558            };
559            Err(OAuthWorkflowError::SessionResponseInvalid {
560                message: error_message,
561            }
562            .into())
563        }
564    }
565}
566
567/// Obtains an access token using OAuth client credentials grant.
568///
569/// This function implements the OAuth 2.0 client credentials flow for obtaining
570/// service-to-service access tokens. This is typically used when a service needs
571/// to authenticate itself rather than acting on behalf of a user.
572///
573/// # Arguments
574///
575/// * `http_client` - The HTTP client to use for making requests
576/// * `aip_hostname` - The hostname of the AT Protocol Identity Provider (AIP)
577/// * `aip_client_id` - The client ID for authenticating with the AIP
578/// * `aip_client_secret` - The client secret for authenticating with the AIP
579///
580/// # Returns
581///
582/// Returns a `TokenResponse` containing the access token and metadata,
583/// or an error if the token request fails.
584///
585/// # Example
586///
587/// ```no_run
588/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
589/// use atproto_oauth_aip::workflow::client_credentials_token;
590/// use atproto_oauth::workflow::TokenResponse;
591/// # let http_client = reqwest::Client::new();
592/// # let aip_hostname = "auth.example.com";
593/// # let client_id = "service-client-id";
594/// # let client_secret = "service-client-secret";
595///
596/// let token = client_credentials_token(
597///     &http_client,
598///     aip_hostname,
599///     client_id,
600///     client_secret,
601/// ).await?;
602///
603/// println!("Access token: {}", token.access_token);
604/// println!("Token type: {}", token.token_type);
605/// println!("Expires in: {} seconds", token.expires_in);
606/// # Ok(())
607/// # }
608/// ```
609pub async fn client_credentials_token(
610    http_client: &reqwest::Client,
611    aip_hostname: &str,
612    aip_client_id: &str,
613    aip_client_secret: &str,
614) -> Result<atproto_oauth::workflow::TokenResponse> {
615    // Construct the token endpoint URL
616    let token_url = format!("https://{}/oauth/token", aip_hostname);
617
618    // Prepare the form data for client credentials grant
619    let params = [("grant_type", "client_credentials")];
620
621    // Send the request with Basic authentication
622    let response = http_client
623        .post(&token_url)
624        .basic_auth(aip_client_id, Some(aip_client_secret))
625        .form(&params)
626        .send()
627        .await
628        .map_err(OAuthWorkflowError::TokenRequestFailed)?;
629
630    // Check if the request was successful
631    if !response.status().is_success() {
632        let status = response.status();
633        let error_text = response
634            .text()
635            .await
636            .unwrap_or_else(|_| "Unknown error".to_string());
637        return Err(OAuthWorkflowError::TokenResponseInvalid {
638            message: format!("Token request failed with status {}: {}", status, error_text),
639        }
640        .into());
641    }
642
643    // Parse the response
644    response
645        .json::<atproto_oauth::workflow::TokenResponse>()
646        .await
647        .map_err(|e| OAuthWorkflowError::TokenResponseParseFailed(e).into())
648}