atproto_oauth_aip/
workflow.rs

1use crate::errors::OAuthWorkflowError;
2use anyhow::Result;
3use atproto_oauth::{
4    resources::{AuthorizationServer, OAuthProtectedResource},
5    workflow::{OAuthRequest, OAuthRequestState, ParResponse, TokenResponse},
6};
7use serde::Deserialize;
8
/// OAuth client configuration containing essential client credentials.
///
/// `client_id` and `client_secret` are sent to the authorization server as HTTP
/// Basic credentials on the PAR and token requests, and `redirect_uri` is sent as
/// a form parameter on both.
///
/// The [`Debug`](std::fmt::Debug) implementation redacts `client_secret` so the
/// credential does not end up in logs or panic messages.
#[derive(Clone)]
pub struct OAuthClient {
    /// The redirect URI where the authorization server will send the user after authorization.
    pub redirect_uri: String,
    /// The unique client identifier for this OAuth client.
    pub client_id: String,

    /// The client secret used for authenticating with the authorization server.
    pub client_secret: String,
}

impl std::fmt::Debug for OAuthClient {
    /// Formats the configuration with `client_secret` replaced by `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthClient")
            .field("redirect_uri", &self.redirect_uri)
            .field("client_id", &self.client_id)
            // Never print the real secret.
            .field("client_secret", &"<redacted>")
            .finish()
    }
}
19
20#[derive(Clone, Deserialize)]
21#[serde(untagged)]
22enum WrappedParResponse {
23    ParResponse(ParResponse),
24    Error {
25        error: String,
26        error_description: Option<String>,
27    },
28}
29
30/// Represents an authenticated AT Protocol session.
31///
32/// This structure contains all the information needed to make authenticated
33/// requests to AT Protocol services after a successful OAuth flow.
34#[derive(Clone, Deserialize)]
35pub struct ATProtocolSession {
36    /// The Decentralized Identifier (DID) of the authenticated user.
37    pub did: String,
38    /// The handle (username) of the authenticated user.
39    pub handle: String,
40    /// The OAuth access token for making authenticated requests.
41    pub access_token: String,
42    /// The type of token (typically "Bearer").
43    pub token_type: String,
44    /// The list of OAuth scopes granted to this session.
45    pub scopes: Vec<String>,
46    /// The Personal Data Server (PDS) endpoint URL for this user.
47    pub pds_endpoint: String,
48    /// The DPoP (Demonstration of Proof-of-Possession) key in JWK format.
49    pub dpop_key: String,
50    /// Unix timestamp indicating when this session expires.
51    pub expires_at: i64,
52}
53
54#[derive(Deserialize, Clone)]
55#[serde(untagged)]
56enum WrappedATProtocolSession {
57    ATProtocolSession(ATProtocolSession),
58    Error {
59        error: String,
60        error_description: Option<String>,
61    },
62}
63
64/// Initiates an OAuth authorization flow using Pushed Authorization Request (PAR).
65///
66/// This function starts the OAuth flow by sending a PAR request to the authorization
67/// server. PAR allows the client to push the authorization request parameters to the
68/// authorization server before redirecting the user, providing enhanced security.
69///
70/// # Arguments
71///
72/// * `http_client` - The HTTP client to use for making requests
73/// * `oauth_client` - OAuth client configuration with credentials
74/// * `handle` - Optional user handle to pre-fill in the login form
75/// * `authorization_server` - Authorization server metadata
76/// * `oauth_request_state` - OAuth request state including PKCE challenge and state
77///
78/// # Returns
79///
80/// Returns a `ParResponse` containing the request URI to redirect the user to,
81/// or an error if the PAR request fails.
82///
83/// # Example
84///
85/// ```no_run
86/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
87/// use atproto_oauth_aip::workflow::{oauth_init, OAuthClient};
88/// use atproto_oauth::workflow::OAuthRequestState;
89/// # let http_client = reqwest::Client::new();
90/// let oauth_client = OAuthClient {
91///     redirect_uri: "https://example.com/callback".to_string(),
92///     client_id: "client123".to_string(),
93///     client_secret: "secret456".to_string(),
94/// };
95/// # let authorization_server = todo!();
96/// let oauth_request_state = OAuthRequestState {
97///     state: "random-state".to_string(),
98///     nonce: "random-nonce".to_string(),
99///     code_challenge: "code-challenge".to_string(),
100///     scope: "atproto transition:generic".to_string(),
101/// };
102/// let par_response = oauth_init(
103///     &http_client,
104///     &oauth_client,
105///     Some("alice.bsky.social"),
106///     &authorization_server,
107///     &oauth_request_state,
108/// ).await?;
109/// # Ok(())
110/// # }
111/// ```
112pub async fn oauth_init(
113    http_client: &reqwest::Client,
114    oauth_client: &OAuthClient,
115    handle: Option<&str>,
116    authorization_server: &AuthorizationServer,
117    oauth_request_state: &OAuthRequestState,
118) -> Result<ParResponse> {
119    let par_url = authorization_server
120        .pushed_authorization_request_endpoint
121        .clone();
122
123    let scope = &oauth_request_state.scope;
124
125    let mut params = vec![
126        ("client_id", oauth_client.client_id.as_str()),
127        ("code_challenge_method", "S256"),
128        ("code_challenge", &oauth_request_state.code_challenge),
129        ("redirect_uri", oauth_client.redirect_uri.as_str()),
130        ("response_type", "code"),
131        ("scope", scope),
132        ("state", oauth_request_state.state.as_str()),
133    ];
134    if let Some(value) = handle {
135        params.push(("login_hint", value));
136    }
137
138    let response: WrappedParResponse = http_client
139        .post(par_url)
140        .form(&params)
141        .basic_auth(
142            oauth_client.client_id.as_str(),
143            Some(oauth_client.client_secret.as_str()),
144        )
145        .send()
146        .await
147        .map_err(OAuthWorkflowError::ParRequestFailed)?
148        .json()
149        .await
150        .map_err(OAuthWorkflowError::ParResponseParseFailed)?;
151
152    match response {
153        WrappedParResponse::ParResponse(value) => Ok(value),
154        WrappedParResponse::Error {
155            error,
156            error_description,
157        } => {
158            let error_message = if let Some(value) = error_description {
159                format!("{error}: {value}")
160            } else {
161                error.to_string()
162            };
163            Err(OAuthWorkflowError::ParResponseInvalid {
164                message: error_message,
165            }
166            .into())
167        }
168    }
169}
170
171/// Completes the OAuth authorization flow by exchanging the authorization code for tokens.
172///
173/// After the user has authorized the application and been redirected back with an
174/// authorization code, this function exchanges that code for access tokens using
175/// the token endpoint.
176///
177/// # Arguments
178///
179/// * `http_client` - The HTTP client to use for making requests
180/// * `oauth_client` - OAuth client configuration with credentials
181/// * `authorization_server` - Authorization server metadata
182/// * `callback_code` - The authorization code received in the callback
183/// * `oauth_request` - The original OAuth request containing the PKCE verifier
184///
185/// # Returns
186///
187/// Returns a `TokenResponse` containing the access token and other token information,
188/// or an error if the token exchange fails.
189///
190/// # Example
191///
192/// ```no_run
193/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
194/// use atproto_oauth_aip::workflow::oauth_complete;
195/// # let http_client = reqwest::Client::new();
196/// # let oauth_client = todo!();
197/// # let authorization_server = todo!();
198/// # let oauth_request = todo!();
199/// let token_response = oauth_complete(
200///     &http_client,
201///     &oauth_client,
202///     &authorization_server,
203///     "auth_code_from_callback",
204///     &oauth_request,
205/// ).await?;
206/// println!("Access token: {}", token_response.access_token);
207/// # Ok(())
208/// # }
209/// ```
210pub async fn oauth_complete(
211    http_client: &reqwest::Client,
212    oauth_client: &OAuthClient,
213    authorization_server: &AuthorizationServer,
214    callback_code: &str,
215    oauth_request: &OAuthRequest,
216) -> Result<TokenResponse> {
217    let params = [
218        ("client_id", oauth_client.client_id.as_str()),
219        ("redirect_uri", oauth_client.redirect_uri.as_str()),
220        ("grant_type", "authorization_code"),
221        ("code", callback_code),
222        ("code_verifier", &oauth_request.pkce_verifier),
223    ];
224
225    http_client
226        .post(&authorization_server.token_endpoint)
227        .basic_auth(
228            oauth_client.client_id.as_str(),
229            Some(oauth_client.client_secret.as_str()),
230        )
231        .form(&params)
232        .send()
233        .await
234        .map_err(OAuthWorkflowError::TokenRequestFailed)?
235        .json()
236        .await
237        .map_err(|e| OAuthWorkflowError::TokenResponseParseFailed(e).into())
238}
239
240/// Exchanges an OAuth access token for an AT Protocol session.
241///
242/// This function takes an OAuth access token and exchanges it for a full
243/// AT Protocol session, which includes additional information like the user's
244/// DID, handle, and PDS endpoint. This is specific to AT Protocol's OAuth
245/// implementation.
246///
247/// # Arguments
248///
249/// * `http_client` - The HTTP client to use for making requests
250/// * `protected_resource` - The protected resource metadata
251/// * `access_token` - The OAuth access token to exchange
252///
253/// # Returns
254///
255/// Returns an `ATProtocolSession` with full session information,
256/// or an error if the session exchange fails.
257///
258/// # Example
259///
260/// ```no_run
261/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
262/// use atproto_oauth_aip::workflow::session_exchange;
263/// # let http_client = reqwest::Client::new();
264/// # let protected_resource = todo!();
265/// # let access_token = "example_token";
266/// let session = session_exchange(
267///     &http_client,
268///     &protected_resource,
269///     access_token,
270/// ).await?;
271/// println!("Authenticated as {} ({})", session.handle, session.did);
272/// println!("PDS endpoint: {}", session.pds_endpoint);
273/// # Ok(())
274/// # }
275/// ```
276pub async fn session_exchange(
277    http_client: &reqwest::Client,
278    protected_resource: &OAuthProtectedResource,
279    access_token: &str,
280) -> Result<ATProtocolSession> {
281    let response = http_client
282        .get(format!(
283            "{}/api/atprotocol/session",
284            protected_resource.resource
285        ))
286        .bearer_auth(access_token)
287        .send()
288        .await
289        .map_err(OAuthWorkflowError::SessionRequestFailed)?
290        .json()
291        .await
292        .map_err(OAuthWorkflowError::SessionResponseParseFailed)?;
293
294    match response {
295        WrappedATProtocolSession::ATProtocolSession(value) => Ok(value),
296        WrappedATProtocolSession::Error {
297            error,
298            error_description,
299        } => {
300            let error_message = if let Some(value) = error_description {
301                format!("{error}: {value}")
302            } else {
303                error.to_string()
304            };
305            Err(OAuthWorkflowError::SessionResponseInvalid {
306                message: error_message,
307            }
308            .into())
309        }
310    }
311}