atproto_oauth_aip/workflow.rs
1//! OAuth 2.0 workflow for AT Protocol identity providers.
2//!
//! Complete authorization code flow implementation with PAR initialization,
//! code exchange, and AT Protocol session establishment.
//!
//! The flow runs in three phases:
//!
//! 1. **Initialization (`oauth_init`)**: Sends a Pushed Authorization Request (PAR)
//!    and returns the resulting `ParResponse`, which is used to send the user to the
//!    authorization server for consent
//! 2. **Completion (`oauth_complete`)**: Exchanges the authorization code for access tokens
//! 3. **Session Exchange (`session_exchange`)**: Converts OAuth tokens to AT Protocol sessions
9//!
10//! ## Security Features
11//!
12//! - **Pushed Authorization Requests (PAR)**: Enhanced security by storing authorization
13//! parameters server-side rather than in redirect URLs
14//! - **PKCE (Proof Key for Code Exchange)**: Protection against authorization code
15//! interception attacks
16//! - **DPoP (Demonstration of Proof-of-Possession)**: Cryptographic binding of tokens
17//! to specific keys for enhanced security
18//!
19//! ## Usage Example
20//!
21//! ```rust,no_run
22//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
23//! use atproto_oauth_aip::workflow::{oauth_init, oauth_complete, session_exchange, OAuthClient};
24//! use atproto_oauth::resources::{AuthorizationServer, OAuthProtectedResource};
25//! use atproto_oauth::workflow::{OAuthRequestState, OAuthRequest};
26//!
27//! let http_client = reqwest::Client::new();
28//!
29//! // 1. Initialize OAuth flow
30//! let oauth_client = OAuthClient {
31//! redirect_uri: "https://myapp.com/callback".to_string(),
32//! client_id: "my_client_id".to_string(),
33//! client_secret: "my_client_secret".to_string(),
34//! };
35//!
36//! # let authorization_server = AuthorizationServer {
37//! # issuer: "https://auth.example.com".to_string(),
38//! # authorization_endpoint: "https://auth.example.com/authorize".to_string(),
39//! # token_endpoint: "https://auth.example.com/token".to_string(),
40//! # pushed_authorization_request_endpoint: "https://auth.example.com/par".to_string(),
41//! # introspection_endpoint: "".to_string(),
42//! # scopes_supported: vec!["atproto".to_string(), "transition:generic".to_string()],
43//! # response_types_supported: vec!["code".to_string()],
44//! # grant_types_supported: vec!["authorization_code".to_string(), "refresh_token".to_string()],
45//! # token_endpoint_auth_methods_supported: vec!["none".to_string(), "private_key_jwt".to_string()],
46//! # token_endpoint_auth_signing_alg_values_supported: vec!["ES256".to_string()],
47//! # require_pushed_authorization_requests: true,
48//! # request_parameter_supported: false,
49//! # code_challenge_methods_supported: vec!["S256".to_string()],
50//! # authorization_response_iss_parameter_supported: true,
51//! # dpop_signing_alg_values_supported: vec!["ES256".to_string()],
52//! # client_id_metadata_document_supported: true,
53//! # };
54//!
55//! let oauth_request_state = OAuthRequestState {
56//! state: "random-state".to_string(),
57//! nonce: "random-nonce".to_string(),
58//! code_challenge: "code-challenge".to_string(),
59//! scope: "atproto transition:generic".to_string(),
60//! };
61//!
62//! let par_response = oauth_init(
63//! &http_client,
64//! &oauth_client,
65//! Some("user.bsky.social"),
66//! &authorization_server.pushed_authorization_request_endpoint,
67//! &oauth_request_state
68//! ).await?;
69//!
70//! // User visits auth_url and grants consent, returns with authorization code
71//!
72//! // 2. Complete OAuth flow
73//! # let oauth_request = OAuthRequest {
74//! # oauth_state: "state".to_string(),
75//! # issuer: "https://auth.example.com".to_string(),
76//! # authorization_server: "https://auth.example.com".to_string(),
77//! # nonce: "nonce".to_string(),
78//! # signing_public_key: "public_key".to_string(),
79//! # pkce_verifier: "verifier".to_string(),
80//! # dpop_private_key: "private_key".to_string(),
81//! # created_at: chrono::Utc::now(),
82//! # expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
83//! # };
84//! let token_response = oauth_complete(
85//! &http_client,
86//! &oauth_client,
87//! &authorization_server.token_endpoint,
88//! "received_auth_code",
89//! &oauth_request
90//! ).await?;
91//!
92//! // 3. Exchange for AT Protocol session
93//! # let protected_resource = OAuthProtectedResource {
94//! # resource: "https://pds.example.com".to_string(),
95//! # scopes_supported: vec!["atproto".to_string()],
96//! # bearer_methods_supported: vec!["header".to_string()],
97//! # authorization_servers: vec!["https://auth.example.com".to_string()],
98//! # };
99//! let session = session_exchange(
100//! &http_client,
101//! &protected_resource.resource,
102//! &token_response.access_token
103//! ).await?;
104//! # Ok(())
105//! # }
106//! ```
107//!
108//! ## Error Handling
109//!
//! All functions return `anyhow::Result<T>`. Failures are reported as
//! `OAuthWorkflowError` values wrapped in `anyhow::Error`, with separate variants
//! for each phase of the OAuth flow covering network failures, parsing errors,
//! and protocol violations.
113
114use anyhow::Result;
115use atproto_identity::url::URLBuilder;
116use atproto_oauth::{
117 jwk::WrappedJsonWebKey,
118 workflow::{OAuthRequest, OAuthRequestState, ParResponse, TokenResponse},
119};
120use serde::Deserialize;
121
122use crate::errors::OAuthWorkflowError;
123
124#[cfg(feature = "zeroize")]
125use zeroize::{Zeroize, ZeroizeOnDrop};
126
/// Credentials and redirect target identifying this application to an
/// authorization server.
///
/// When the `zeroize` feature is enabled, `client_secret` is wiped from memory
/// when the value is dropped; the non-secret fields are skipped.
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
pub struct OAuthClient {
    /// Where the authorization server sends the user once authorization finishes.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub redirect_uri: String,

    /// Identifier of this OAuth client. Sent both as a form parameter and as the
    /// HTTP Basic auth user name.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub client_id: String,

    /// Secret paired with `client_id`. Sent as the HTTP Basic auth password.
    pub client_secret: String,
}
141
142#[derive(Clone, Deserialize)]
143#[serde(untagged)]
144enum WrappedParResponse {
145 ParResponse(ParResponse),
146 Error {
147 error: String,
148 error_description: Option<String>,
149 },
150}
151
152/// Represents an authenticated AT Protocol session.
153///
154/// This structure contains all the information needed to make authenticated
155/// requests to AT Protocol services after a successful OAuth flow.
156#[derive(Clone, Deserialize)]
157#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
158pub struct ATProtocolSession {
159 /// The Decentralized Identifier (DID) of the authenticated user.
160 #[cfg_attr(feature = "zeroize", zeroize(skip))]
161 pub did: String,
162
163 /// The handle (username) of the authenticated user.
164 #[cfg_attr(feature = "zeroize", zeroize(skip))]
165 pub handle: String,
166
167 /// The OAuth access token for making authenticated requests.
168 #[cfg_attr(feature = "zeroize", zeroize(skip))]
169 pub access_token: String,
170
171 /// The type of token (typically "Bearer").
172 #[cfg_attr(feature = "zeroize", zeroize(skip))]
173 pub token_type: String,
174
175 /// The list of OAuth scopes granted to this session.
176 #[cfg_attr(feature = "zeroize", zeroize(skip))]
177 pub scopes: Vec<String>,
178
179 /// The Personal Data Server (PDS) endpoint URL for this user.
180 #[cfg_attr(feature = "zeroize", zeroize(skip))]
181 pub pds_endpoint: String,
182
183 /// The DPoP (Demonstration of Proof-of-Possession) key in string serialized format.
184 #[cfg_attr(feature = "zeroize", zeroize(skip))]
185 pub dpop_key: Option<String>,
186
187 /// The DPoP (Demonstration of Proof-of-Possession) key in JWK format.
188 #[cfg_attr(feature = "zeroize", zeroize(skip))]
189 pub dpop_jwk: Option<WrappedJsonWebKey>,
190
191 /// Unix timestamp indicating when this session expires.
192 #[cfg_attr(feature = "zeroize", zeroize(skip))]
193 pub expires_at: i64,
194}
195
196#[derive(Deserialize, Clone)]
197#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
198#[serde(untagged)]
199enum WrappedATProtocolSession {
200 ATProtocolSession(Box<ATProtocolSession>),
201
202 #[cfg_attr(feature = "zeroize", zeroize(skip))]
203 Error {
204 error: String,
205 error_description: Option<String>,
206 },
207}
208
209/// Initiates an OAuth authorization flow using Pushed Authorization Request (PAR).
210///
211/// This function starts the OAuth flow by sending a PAR request to the authorization
212/// server. PAR allows the client to push the authorization request parameters to the
213/// authorization server before redirecting the user, providing enhanced security.
214///
215/// # Arguments
216///
217/// * `http_client` - The HTTP client to use for making requests
218/// * `oauth_client` - OAuth client configuration with credentials
219/// * `handle` - Optional user handle to pre-fill in the login form
220/// * `authorization_server` - Authorization server metadata
221/// * `oauth_request_state` - OAuth request state including PKCE challenge and state
222///
223/// # Returns
224///
225/// Returns a `ParResponse` containing the request URI to redirect the user to,
226/// or an error if the PAR request fails.
227///
228/// # Example
229///
230/// ```no_run
231/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
232/// use atproto_oauth_aip::workflow::{oauth_init, OAuthClient};
233/// use atproto_oauth::workflow::OAuthRequestState;
234/// # let http_client = reqwest::Client::new();
235/// let oauth_client = OAuthClient {
236/// redirect_uri: "https://example.com/callback".to_string(),
237/// client_id: "client123".to_string(),
238/// client_secret: "secret456".to_string(),
239/// };
240/// # let authorization_server = "https://auth.example.com/par";
241/// let oauth_request_state = OAuthRequestState {
242/// state: "random-state".to_string(),
243/// nonce: "random-nonce".to_string(),
244/// code_challenge: "code-challenge".to_string(),
245/// scope: "atproto transition:generic".to_string(),
246/// };
247/// let par_response = oauth_init(
248/// &http_client,
249/// &oauth_client,
250/// Some("alice.bsky.social"),
251/// authorization_server,
252/// &oauth_request_state,
253/// ).await?;
254/// # Ok(())
255/// # }
256/// ```
257pub async fn oauth_init(
258 http_client: &reqwest::Client,
259 oauth_client: &OAuthClient,
260 login_hint: Option<&str>,
261 par_url: &str,
262 oauth_request_state: &OAuthRequestState,
263) -> Result<ParResponse> {
264 let scope = &oauth_request_state.scope;
265
266 let mut params = vec![
267 ("client_id", oauth_client.client_id.as_str()),
268 ("code_challenge_method", "S256"),
269 ("code_challenge", &oauth_request_state.code_challenge),
270 ("redirect_uri", oauth_client.redirect_uri.as_str()),
271 ("response_type", "code"),
272 ("scope", scope),
273 ("state", oauth_request_state.state.as_str()),
274 ];
275 if let Some(value) = login_hint {
276 params.push(("login_hint", value));
277 }
278
279 let response: WrappedParResponse = http_client
280 .post(par_url)
281 .form(¶ms)
282 .basic_auth(
283 oauth_client.client_id.as_str(),
284 Some(oauth_client.client_secret.as_str()),
285 )
286 .send()
287 .await
288 .map_err(OAuthWorkflowError::ParRequestFailed)?
289 .json()
290 .await
291 .map_err(OAuthWorkflowError::ParResponseParseFailed)?;
292
293 match response {
294 WrappedParResponse::ParResponse(value) => Ok(value),
295 WrappedParResponse::Error {
296 error,
297 error_description,
298 } => {
299 let error_message = if let Some(value) = error_description {
300 format!("{error}: {value}")
301 } else {
302 error.to_string()
303 };
304 Err(OAuthWorkflowError::ParResponseInvalid {
305 message: error_message,
306 }
307 .into())
308 }
309 }
310}
311
312/// Completes the OAuth authorization flow by exchanging the authorization code for tokens.
313///
314/// After the user has authorized the application and been redirected back with an
315/// authorization code, this function exchanges that code for access tokens using
316/// the token endpoint.
317///
318/// # Arguments
319///
320/// * `http_client` - The HTTP client to use for making requests
321/// * `oauth_client` - OAuth client configuration with credentials
322/// * `authorization_server` - Authorization server metadata
323/// * `callback_code` - The authorization code received in the callback
324/// * `oauth_request` - The original OAuth request containing the PKCE verifier
325///
326/// # Returns
327///
328/// Returns a `TokenResponse` containing the access token and other token information,
329/// or an error if the token exchange fails.
330///
331/// # Example
332///
333/// ```no_run
334/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
335/// use atproto_oauth_aip::workflow::oauth_complete;
336/// # let http_client = reqwest::Client::new();
337/// # let oauth_client = todo!();
338/// # let token_endpoint = "https://auth.example.com/token";
339/// # let oauth_request = todo!();
340/// let token_response = oauth_complete(
341/// &http_client,
342/// &oauth_client,
343/// token_endpoint,
344/// "auth_code_from_callback",
345/// &oauth_request,
346/// ).await?;
347/// println!("Access token: {}", token_response.access_token);
348/// # Ok(())
349/// # }
350/// ```
351pub async fn oauth_complete(
352 http_client: &reqwest::Client,
353 oauth_client: &OAuthClient,
354 token_endpoint: &str,
355 callback_code: &str,
356 oauth_request: &OAuthRequest,
357) -> Result<TokenResponse> {
358 let params = [
359 ("client_id", oauth_client.client_id.as_str()),
360 ("redirect_uri", oauth_client.redirect_uri.as_str()),
361 ("grant_type", "authorization_code"),
362 ("code", callback_code),
363 ("code_verifier", &oauth_request.pkce_verifier),
364 ];
365
366 http_client
367 .post(token_endpoint)
368 .basic_auth(
369 oauth_client.client_id.as_str(),
370 Some(oauth_client.client_secret.as_str()),
371 )
372 .form(¶ms)
373 .send()
374 .await
375 .inspect(|value| {
376 println!("{value:?}");
377 })
378 .map_err(OAuthWorkflowError::TokenRequestFailed)?
379 .json()
380 .await
381 .map_err(|e| OAuthWorkflowError::TokenResponseParseFailed(e).into())
382}
383
384/// Exchanges an OAuth access token for an AT Protocol session.
385///
386/// This function takes an OAuth access token and exchanges it for a full
387/// AT Protocol session, which includes additional information like the user's
388/// DID, handle, and PDS endpoint. This is specific to AT Protocol's OAuth
389/// implementation.
390///
391/// This is a convenience function that calls `session_exchange_with_options`
392/// with no additional options.
393///
394/// # Arguments
395///
396/// * `http_client` - The HTTP client to use for making requests
397/// * `protected_resource_base` - The base URL of the protected resource (PDS)
398/// * `access_token` - The OAuth access token to exchange
399///
400/// # Returns
401///
402/// Returns an `ATProtocolSession` with full session information,
403/// or an error if the session exchange fails.
404///
405/// # Example
406///
407/// ```no_run
408/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
409/// use atproto_oauth_aip::workflow::session_exchange;
410/// # let http_client = reqwest::Client::new();
411/// # let protected_resource = "https://pds.example.com";
412/// # let access_token = "example_token";
413/// let session = session_exchange(
414/// &http_client,
415/// protected_resource,
416/// access_token,
417/// ).await?;
418/// println!("Authenticated as {} ({})", session.handle, session.did);
419/// println!("PDS endpoint: {}", session.pds_endpoint);
420/// # Ok(())
421/// # }
422/// ```
423pub async fn session_exchange(
424 http_client: &reqwest::Client,
425 protected_resource_base: &str,
426 access_token: &str,
427) -> Result<ATProtocolSession> {
428 session_exchange_with_options(
429 http_client,
430 protected_resource_base,
431 access_token,
432 &None,
433 &None,
434 )
435 .await
436}
437
438/// Exchanges an OAuth access token for an AT Protocol session with additional options.
439///
440/// This function takes an OAuth access token and exchanges it for a full
441/// AT Protocol session, which includes additional information like the user's
442/// DID, handle, and PDS endpoint. This version allows specifying additional
443/// options for the session exchange.
444///
445/// # Arguments
446///
447/// * `http_client` - The HTTP client to use for making requests
448/// * `protected_resource_base` - The base URL of the protected resource (PDS)
449/// * `access_token` - The OAuth access token to exchange
450/// * `access_token_type` - Optional token type ("oauth_session", "app_password_session", or "best")
451/// * `subject` - Optional subject (DID) to specify which user's session to retrieve
452///
453/// # Returns
454///
455/// Returns an `ATProtocolSession` with full session information,
456/// or an error if the session exchange fails.
457///
458/// # Example
459///
460/// ```no_run
461/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
462/// use atproto_oauth_aip::workflow::session_exchange_with_options;
463/// # let http_client = reqwest::Client::new();
464/// # let protected_resource = "https://pds.example.com";
465/// # let access_token = "example_token";
466/// // Basic usage without options
467/// let session = session_exchange_with_options(
468/// &http_client,
469/// protected_resource,
470/// access_token,
471/// &None,
472/// &None,
473/// ).await?;
474/// # Ok(())
475/// # }
476/// ```
477///
478/// # Example with access_token_type
479///
480/// ```no_run
481/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
482/// use atproto_oauth_aip::workflow::session_exchange_with_options;
483/// # let http_client = reqwest::Client::new();
484/// # let protected_resource = "https://pds.example.com";
485/// # let access_token = "example_token";
486/// // Specify the token type
487/// let session = session_exchange_with_options(
488/// &http_client,
489/// protected_resource,
490/// access_token,
491/// &Some("oauth_session"),
492/// &None,
493/// ).await?;
494/// # Ok(())
495/// # }
496/// ```
497///
498/// # Example with subject
499///
500/// ```no_run
501/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
502/// use atproto_oauth_aip::workflow::session_exchange_with_options;
503/// # let http_client = reqwest::Client::new();
504/// # let protected_resource = "https://pds.example.com";
505/// # let access_token = "example_token";
506/// # let user_did = "did:plc:example123";
507/// // Specify both token type and subject
508/// let session = session_exchange_with_options(
509/// &http_client,
510/// protected_resource,
511/// access_token,
512/// &Some("app_password_session"),
513/// &Some(user_did),
514/// ).await?;
515/// # Ok(())
516/// # }
517/// ```
518pub async fn session_exchange_with_options(
519 http_client: &reqwest::Client,
520 protected_resource_base: &str,
521 access_token: &str,
522 access_token_type: &Option<&str>,
523 subject: &Option<&str>,
524) -> Result<ATProtocolSession> {
525 let mut url_builder = URLBuilder::new(protected_resource_base);
526 url_builder.path("/api/atprotocol/session");
527
528 if let Some(value) = access_token_type {
529 url_builder.param("access_token_type", value);
530 }
531
532 if let Some(value) = subject {
533 url_builder.param("sub", value);
534 }
535
536 let url = url_builder.build();
537
538 let response = http_client
539 .get(url)
540 .bearer_auth(access_token)
541 .send()
542 .await
543 .map_err(OAuthWorkflowError::SessionRequestFailed)?
544 .json()
545 .await
546 .map_err(OAuthWorkflowError::SessionResponseParseFailed)?;
547
548 match response {
549 WrappedATProtocolSession::ATProtocolSession(ref value) => Ok(*value.clone()),
550 WrappedATProtocolSession::Error {
551 ref error,
552 ref error_description,
553 } => {
554 let error_message = if let Some(value) = error_description {
555 format!("{error}: {value}")
556 } else {
557 error.to_string()
558 };
559 Err(OAuthWorkflowError::SessionResponseInvalid {
560 message: error_message,
561 }
562 .into())
563 }
564 }
565}
566
567/// Obtains an access token using OAuth client credentials grant.
568///
569/// This function implements the OAuth 2.0 client credentials flow for obtaining
570/// service-to-service access tokens. This is typically used when a service needs
571/// to authenticate itself rather than acting on behalf of a user.
572///
573/// # Arguments
574///
575/// * `http_client` - The HTTP client to use for making requests
576/// * `aip_hostname` - The hostname of the AT Protocol Identity Provider (AIP)
577/// * `aip_client_id` - The client ID for authenticating with the AIP
578/// * `aip_client_secret` - The client secret for authenticating with the AIP
579///
580/// # Returns
581///
582/// Returns a `TokenResponse` containing the access token and metadata,
583/// or an error if the token request fails.
584///
585/// # Example
586///
587/// ```no_run
588/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
589/// use atproto_oauth_aip::workflow::client_credentials_token;
590/// use atproto_oauth::workflow::TokenResponse;
591/// # let http_client = reqwest::Client::new();
592/// # let aip_hostname = "auth.example.com";
593/// # let client_id = "service-client-id";
594/// # let client_secret = "service-client-secret";
595///
596/// let token = client_credentials_token(
597/// &http_client,
598/// aip_hostname,
599/// client_id,
600/// client_secret,
601/// ).await?;
602///
603/// println!("Access token: {}", token.access_token);
604/// println!("Token type: {}", token.token_type);
605/// println!("Expires in: {} seconds", token.expires_in);
606/// # Ok(())
607/// # }
608/// ```
609pub async fn client_credentials_token(
610 http_client: &reqwest::Client,
611 aip_hostname: &str,
612 aip_client_id: &str,
613 aip_client_secret: &str,
614) -> Result<atproto_oauth::workflow::TokenResponse> {
615 // Construct the token endpoint URL
616 let token_url = format!("https://{}/oauth/token", aip_hostname);
617
618 // Prepare the form data for client credentials grant
619 let params = [("grant_type", "client_credentials")];
620
621 // Send the request with Basic authentication
622 let response = http_client
623 .post(&token_url)
624 .basic_auth(aip_client_id, Some(aip_client_secret))
625 .form(¶ms)
626 .send()
627 .await
628 .map_err(OAuthWorkflowError::TokenRequestFailed)?;
629
630 // Check if the request was successful
631 if !response.status().is_success() {
632 let status = response.status();
633 let error_text = response
634 .text()
635 .await
636 .unwrap_or_else(|_| "Unknown error".to_string());
637 return Err(OAuthWorkflowError::TokenResponseInvalid {
638 message: format!("Token request failed with status {}: {}", status, error_text),
639 }
640 .into());
641 }
642
643 // Parse the response
644 response
645 .json::<atproto_oauth::workflow::TokenResponse>()
646 .await
647 .map_err(|e| OAuthWorkflowError::TokenResponseParseFailed(e).into())
648}