// atproto_oauth_aip/workflow.rs
use crate::errors::OAuthWorkflowError;
use anyhow::Result;
use atproto_oauth::{
    resources::{AuthorizationServer, OAuthProtectedResource},
    workflow::{OAuthRequest, OAuthRequestState, ParResponse, TokenResponse},
};
use serde::Deserialize;

/// Credentials and redirect target that identify this application to an
/// OAuth authorization server.
pub struct OAuthClient {
    /// URI the authorization server sends the user back to after authorization.
    pub redirect_uri: String,
    /// Unique identifier of this OAuth client.
    pub client_id: String,
    /// Secret presented together with `client_id` when authenticating to the
    /// authorization server.
    pub client_secret: String,
}

20#[derive(Clone, Deserialize)]
21#[serde(untagged)]
22enum WrappedParResponse {
23 ParResponse(ParResponse),
24 Error {
25 error: String,
26 error_description: Option<String>,
27 },
28}
29
30/// Represents an authenticated AT Protocol session.
31///
32/// This structure contains all the information needed to make authenticated
33/// requests to AT Protocol services after a successful OAuth flow.
34#[derive(Clone, Deserialize)]
35pub struct ATProtocolSession {
36 /// The Decentralized Identifier (DID) of the authenticated user.
37 pub did: String,
38 /// The handle (username) of the authenticated user.
39 pub handle: String,
40 /// The OAuth access token for making authenticated requests.
41 pub access_token: String,
42 /// The type of token (typically "Bearer").
43 pub token_type: String,
44 /// The list of OAuth scopes granted to this session.
45 pub scopes: Vec<String>,
46 /// The Personal Data Server (PDS) endpoint URL for this user.
47 pub pds_endpoint: String,
48 /// The DPoP (Demonstration of Proof-of-Possession) key in JWK format.
49 pub dpop_key: String,
50 /// Unix timestamp indicating when this session expires.
51 pub expires_at: i64,
52}
53
54#[derive(Deserialize, Clone)]
55#[serde(untagged)]
56enum WrappedATProtocolSession {
57 ATProtocolSession(ATProtocolSession),
58 Error {
59 error: String,
60 error_description: Option<String>,
61 },
62}
63
64/// Initiates an OAuth authorization flow using Pushed Authorization Request (PAR).
65///
66/// This function starts the OAuth flow by sending a PAR request to the authorization
67/// server. PAR allows the client to push the authorization request parameters to the
68/// authorization server before redirecting the user, providing enhanced security.
69///
70/// # Arguments
71///
72/// * `http_client` - The HTTP client to use for making requests
73/// * `oauth_client` - OAuth client configuration with credentials
74/// * `handle` - Optional user handle to pre-fill in the login form
75/// * `authorization_server` - Authorization server metadata
76/// * `oauth_request_state` - OAuth request state including PKCE challenge and state
77///
78/// # Returns
79///
80/// Returns a `ParResponse` containing the request URI to redirect the user to,
81/// or an error if the PAR request fails.
82///
83/// # Example
84///
85/// ```no_run
86/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
87/// use atproto_oauth_aip::workflow::{oauth_init, OAuthClient};
88/// use atproto_oauth::workflow::OAuthRequestState;
89/// # let http_client = reqwest::Client::new();
90/// let oauth_client = OAuthClient {
91/// redirect_uri: "https://example.com/callback".to_string(),
92/// client_id: "client123".to_string(),
93/// client_secret: "secret456".to_string(),
94/// };
95/// # let authorization_server = todo!();
96/// let oauth_request_state = OAuthRequestState {
97/// state: "random-state".to_string(),
98/// nonce: "random-nonce".to_string(),
99/// code_challenge: "code-challenge".to_string(),
100/// scope: "atproto transition:generic".to_string(),
101/// };
102/// let par_response = oauth_init(
103/// &http_client,
104/// &oauth_client,
105/// Some("alice.bsky.social"),
106/// &authorization_server,
107/// &oauth_request_state,
108/// ).await?;
109/// # Ok(())
110/// # }
111/// ```
112pub async fn oauth_init(
113 http_client: &reqwest::Client,
114 oauth_client: &OAuthClient,
115 handle: Option<&str>,
116 authorization_server: &AuthorizationServer,
117 oauth_request_state: &OAuthRequestState,
118) -> Result<ParResponse> {
119 let par_url = authorization_server
120 .pushed_authorization_request_endpoint
121 .clone();
122
123 let scope = &oauth_request_state.scope;
124
125 let mut params = vec![
126 ("client_id", oauth_client.client_id.as_str()),
127 ("code_challenge_method", "S256"),
128 ("code_challenge", &oauth_request_state.code_challenge),
129 ("redirect_uri", oauth_client.redirect_uri.as_str()),
130 ("response_type", "code"),
131 ("scope", scope),
132 ("state", oauth_request_state.state.as_str()),
133 ];
134 if let Some(value) = handle {
135 params.push(("login_hint", value));
136 }
137
138 let response: WrappedParResponse = http_client
139 .post(par_url)
140 .form(¶ms)
141 .basic_auth(
142 oauth_client.client_id.as_str(),
143 Some(oauth_client.client_secret.as_str()),
144 )
145 .send()
146 .await
147 .map_err(OAuthWorkflowError::ParRequestFailed)?
148 .json()
149 .await
150 .map_err(OAuthWorkflowError::ParResponseParseFailed)?;
151
152 match response {
153 WrappedParResponse::ParResponse(value) => Ok(value),
154 WrappedParResponse::Error {
155 error,
156 error_description,
157 } => {
158 let error_message = if let Some(value) = error_description {
159 format!("{error}: {value}")
160 } else {
161 error.to_string()
162 };
163 Err(OAuthWorkflowError::ParResponseInvalid {
164 message: error_message,
165 }
166 .into())
167 }
168 }
169}
170
171/// Completes the OAuth authorization flow by exchanging the authorization code for tokens.
172///
173/// After the user has authorized the application and been redirected back with an
174/// authorization code, this function exchanges that code for access tokens using
175/// the token endpoint.
176///
177/// # Arguments
178///
179/// * `http_client` - The HTTP client to use for making requests
180/// * `oauth_client` - OAuth client configuration with credentials
181/// * `authorization_server` - Authorization server metadata
182/// * `callback_code` - The authorization code received in the callback
183/// * `oauth_request` - The original OAuth request containing the PKCE verifier
184///
185/// # Returns
186///
187/// Returns a `TokenResponse` containing the access token and other token information,
188/// or an error if the token exchange fails.
189///
190/// # Example
191///
192/// ```no_run
193/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
194/// use atproto_oauth_aip::workflow::oauth_complete;
195/// # let http_client = reqwest::Client::new();
196/// # let oauth_client = todo!();
197/// # let authorization_server = todo!();
198/// # let oauth_request = todo!();
199/// let token_response = oauth_complete(
200/// &http_client,
201/// &oauth_client,
202/// &authorization_server,
203/// "auth_code_from_callback",
204/// &oauth_request,
205/// ).await?;
206/// println!("Access token: {}", token_response.access_token);
207/// # Ok(())
208/// # }
209/// ```
210pub async fn oauth_complete(
211 http_client: &reqwest::Client,
212 oauth_client: &OAuthClient,
213 authorization_server: &AuthorizationServer,
214 callback_code: &str,
215 oauth_request: &OAuthRequest,
216) -> Result<TokenResponse> {
217 let params = [
218 ("client_id", oauth_client.client_id.as_str()),
219 ("redirect_uri", oauth_client.redirect_uri.as_str()),
220 ("grant_type", "authorization_code"),
221 ("code", callback_code),
222 ("code_verifier", &oauth_request.pkce_verifier),
223 ];
224
225 http_client
226 .post(&authorization_server.token_endpoint)
227 .basic_auth(
228 oauth_client.client_id.as_str(),
229 Some(oauth_client.client_secret.as_str()),
230 )
231 .form(¶ms)
232 .send()
233 .await
234 .map_err(OAuthWorkflowError::TokenRequestFailed)?
235 .json()
236 .await
237 .map_err(|e| OAuthWorkflowError::TokenResponseParseFailed(e).into())
238}
239
240/// Exchanges an OAuth access token for an AT Protocol session.
241///
242/// This function takes an OAuth access token and exchanges it for a full
243/// AT Protocol session, which includes additional information like the user's
244/// DID, handle, and PDS endpoint. This is specific to AT Protocol's OAuth
245/// implementation.
246///
247/// # Arguments
248///
249/// * `http_client` - The HTTP client to use for making requests
250/// * `protected_resource` - The protected resource metadata
251/// * `access_token` - The OAuth access token to exchange
252///
253/// # Returns
254///
255/// Returns an `ATProtocolSession` with full session information,
256/// or an error if the session exchange fails.
257///
258/// # Example
259///
260/// ```no_run
261/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
262/// use atproto_oauth_aip::workflow::session_exchange;
263/// # let http_client = reqwest::Client::new();
264/// # let protected_resource = todo!();
265/// # let access_token = "example_token";
266/// let session = session_exchange(
267/// &http_client,
268/// &protected_resource,
269/// access_token,
270/// ).await?;
271/// println!("Authenticated as {} ({})", session.handle, session.did);
272/// println!("PDS endpoint: {}", session.pds_endpoint);
273/// # Ok(())
274/// # }
275/// ```
276pub async fn session_exchange(
277 http_client: &reqwest::Client,
278 protected_resource: &OAuthProtectedResource,
279 access_token: &str,
280) -> Result<ATProtocolSession> {
281 let response = http_client
282 .get(format!(
283 "{}/api/atprotocol/session",
284 protected_resource.resource
285 ))
286 .bearer_auth(access_token)
287 .send()
288 .await
289 .map_err(OAuthWorkflowError::SessionRequestFailed)?
290 .json()
291 .await
292 .map_err(OAuthWorkflowError::SessionResponseParseFailed)?;
293
294 match response {
295 WrappedATProtocolSession::ATProtocolSession(value) => Ok(value),
296 WrappedATProtocolSession::Error {
297 error,
298 error_description,
299 } => {
300 let error_message = if let Some(value) = error_description {
301 format!("{error}: {value}")
302 } else {
303 error.to_string()
304 };
305 Err(OAuthWorkflowError::SessionResponseInvalid {
306 message: error_message,
307 }
308 .into())
309 }
310 }
311}