1use atproto_identity::key::KeyData;
80use chrono::{DateTime, Utc};
81use rand::distributions::{Alphanumeric, DistString};
82use reqwest_chain::ChainMiddleware;
83use reqwest_middleware::ClientBuilder;
84use serde::Deserialize;
85use std::collections::HashMap;
86
87#[cfg(feature = "zeroize")]
88use zeroize::{Zeroize, ZeroizeOnDrop};
89
90use crate::{
91 dpop::{DpopRetry, auth_dpop},
92 errors::OAuthClientError,
93 jwt::{Claims, Header, JoseClaims, mint},
94 resources::{AuthorizationServer, pds_resources},
95};
96
97#[derive(Clone, Deserialize)]
102pub struct ParResponse {
103 pub request_uri: String,
105 pub expires_in: u64,
107
108 #[serde(flatten)]
110 pub extra: HashMap<String, serde_json::Value>,
111}
112
/// Per-attempt values supplied to [`oauth_init`] for one pushed authorization request.
pub struct OAuthRequestState {
    /// Opaque value sent as the `state` request parameter.
    pub state: String,

    /// Nonce generated for this attempt.
    /// NOTE(review): `oauth_init` does not send this value; presumably the caller
    /// persists it alongside the request — confirm.
    pub nonce: String,

    /// PKCE code challenge, sent with `code_challenge_method=S256`.
    pub code_challenge: String,

    /// Value sent as the `scope` request parameter.
    pub scope: String,
}
127
128#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
133pub struct OAuthClient {
134 #[cfg_attr(feature = "zeroize", zeroize(skip))]
136 pub redirect_uri: String,
137
138 #[cfg_attr(feature = "zeroize", zeroize(skip))]
140 pub client_id: String,
141
142 pub private_signing_key_data: KeyData,
144}
145
146#[derive(Clone, PartialEq)]
151#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
152pub struct OAuthRequest {
153 #[cfg_attr(feature = "zeroize", zeroize(skip))]
155 pub oauth_state: String,
156
157 #[cfg_attr(feature = "zeroize", zeroize(skip))]
159 pub issuer: String,
160
161 #[cfg_attr(feature = "zeroize", zeroize(skip))]
163 pub authorization_server: String,
164
165 #[cfg_attr(feature = "zeroize", zeroize(skip))]
167 pub nonce: String,
168
169 #[cfg_attr(feature = "zeroize", zeroize(skip))]
171 pub pkce_verifier: String,
172
173 pub signing_public_key: String,
175
176 #[cfg_attr(feature = "zeroize", zeroize(skip))]
178 pub dpop_private_key: String,
179
180 #[cfg_attr(feature = "zeroize", zeroize(skip))]
182 pub created_at: DateTime<Utc>,
183
184 #[cfg_attr(feature = "zeroize", zeroize(skip))]
186 pub expires_at: DateTime<Utc>,
187}
188
189impl std::fmt::Debug for OAuthRequest {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 f.debug_struct("OAuthRequest")
192 .field("oauth_state", &self.oauth_state)
193 .field("issuer", &self.issuer)
194 .field("authorization_server", &self.authorization_server)
195 .field("nonce", &self.nonce)
196 .field("pkce_verifier", &"[REDACTED]")
197 .field("signing_public_key", &self.signing_public_key)
198 .field("dpop_private_key", &"[REDACTED]")
199 .field("created_at", &self.created_at)
200 .field("expires_at", &self.expires_at)
201 .finish()
202 }
203}
204
205#[derive(Clone, Deserialize)]
210#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
211pub struct TokenResponse {
212 pub access_token: String,
214
215 pub token_type: String,
217
218 pub refresh_token: String,
220
221 #[cfg_attr(feature = "zeroize", zeroize(skip))]
223 pub scope: String,
224
225 #[cfg_attr(feature = "zeroize", zeroize(skip))]
227 pub expires_in: u32,
228
229 #[cfg_attr(feature = "zeroize", zeroize(skip))]
231 pub sub: Option<String>,
232
233 #[serde(flatten)]
235 #[cfg_attr(feature = "zeroize", zeroize(skip))]
236 pub extra: HashMap<String, serde_json::Value>,
237}
238
239pub async fn oauth_init(
259 http_client: &reqwest::Client,
260 oauth_client: &OAuthClient,
261 dpop_key_data: &KeyData,
262 login_hint: Option<&str>,
263 authorization_server: &AuthorizationServer,
264 oauth_request_state: &OAuthRequestState,
265) -> Result<ParResponse, OAuthClientError> {
266 let par_url = authorization_server
267 .pushed_authorization_request_endpoint
268 .clone();
269
270 let scope = &oauth_request_state.scope;
271
272 let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
273 .try_into()
274 .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
275
276 let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
277 let client_assertion_claims = Claims::new(JoseClaims {
278 issuer: Some(oauth_client.client_id.clone()),
279 subject: Some(oauth_client.client_id.clone()),
280 audience: Some(authorization_server.issuer.clone()),
281 json_web_token_id: Some(client_assertion_jti),
282 issued_at: Some(chrono::Utc::now().timestamp() as u64),
283 ..Default::default()
284 });
285
286 let client_assertion_token = mint(
287 &oauth_client.private_signing_key_data,
288 &client_assertion_header,
289 &client_assertion_claims,
290 )
291 .map_err(OAuthClientError::MintTokenFailed)?;
292
293 let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &par_url)
294 .map_err(OAuthClientError::DpopTokenCreationFailed)?;
295
296 let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
297
298 let dpop_retry_client = ClientBuilder::new(http_client.clone())
299 .with(ChainMiddleware::new(dpop_retry.clone()))
300 .build();
301
302 let mut params = vec![
303 ("response_type", "code"),
304 ("code_challenge", &oauth_request_state.code_challenge),
305 ("code_challenge_method", "S256"),
306 ("client_id", oauth_client.client_id.as_str()),
307 ("state", oauth_request_state.state.as_str()),
308 ("redirect_uri", oauth_client.redirect_uri.as_str()),
309 ("scope", scope),
310 (
311 "client_assertion_type",
312 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
313 ),
314 ("client_assertion", client_assertion_token.as_str()),
315 ];
316 if let Some(value) = login_hint {
317 params.push(("login_hint", value));
318 }
319
320 let response = dpop_retry_client
321 .post(par_url)
322 .header("DPoP", dpop_token.as_str())
323 .form(¶ms)
324 .send()
325 .await
326 .map_err(OAuthClientError::PARHttpRequestFailed)?
327 .json()
328 .await
329 .map_err(OAuthClientError::PARResponseJsonParsingFailed)?;
330
331 Ok(response)
332}
333
334pub async fn oauth_complete(
354 http_client: &reqwest::Client,
355 oauth_client: &OAuthClient,
356 dpop_key_data: &KeyData,
357 callback_code: &str,
358 oauth_request: &OAuthRequest,
359 authorization_server: &AuthorizationServer,
360) -> Result<TokenResponse, OAuthClientError> {
361 let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
368 .try_into()
369 .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
370
371 let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
372 let client_assertion_claims = Claims::new(JoseClaims {
373 issuer: Some(oauth_client.client_id.clone()),
374 subject: Some(oauth_client.client_id.clone()),
375 audience: Some(authorization_server.issuer.clone()),
376 json_web_token_id: Some(client_assertion_jti),
377 issued_at: Some(chrono::Utc::now().timestamp() as u64),
378 ..Default::default()
379 });
380
381 let client_assertion_token = mint(
382 &oauth_client.private_signing_key_data,
383 &client_assertion_header,
384 &client_assertion_claims,
385 )
386 .map_err(OAuthClientError::MintTokenFailed)?;
387
388 let params = [
389 ("client_id", oauth_client.client_id.as_str()),
390 ("redirect_uri", oauth_client.redirect_uri.as_str()),
391 ("grant_type", "authorization_code"),
392 ("code", callback_code),
393 ("code_verifier", &oauth_request.pkce_verifier),
394 (
395 "client_assertion_type",
396 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
397 ),
398 ("client_assertion", client_assertion_token.as_str()),
399 ];
400
401 let token_endpoint = authorization_server.token_endpoint.clone();
402
403 let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
404 .map_err(OAuthClientError::DpopTokenCreationFailed)?;
405
406 let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
407
408 let dpop_retry_client = ClientBuilder::new(http_client.clone())
409 .with(ChainMiddleware::new(dpop_retry.clone()))
410 .build();
411
412 dpop_retry_client
413 .post(token_endpoint)
414 .header("DPoP", dpop_token.as_str())
415 .form(¶ms)
416 .send()
417 .await
418 .map_err(OAuthClientError::TokenHttpRequestFailed)?
419 .json()
420 .await
421 .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
422}
423
424pub async fn oauth_refresh(
443 http_client: &reqwest::Client,
444 oauth_client: &OAuthClient,
445 dpop_key_data: &KeyData,
446 refresh_token: &str,
447 document: &atproto_identity::model::Document,
448) -> Result<TokenResponse, OAuthClientError> {
449 let pds_endpoints = document.pds_endpoints();
450 let pds_endpoint = pds_endpoints
451 .first()
452 .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
453 let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
454
455 let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
456 .try_into()
457 .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
458
459 let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
460 let client_assertion_claims = Claims::new(JoseClaims {
461 issuer: Some(oauth_client.client_id.clone()),
462 subject: Some(oauth_client.client_id.clone()),
463 audience: Some(authorization_server.issuer.clone()),
464 json_web_token_id: Some(client_assertion_jti),
465 issued_at: Some(chrono::Utc::now().timestamp() as u64),
466 ..Default::default()
467 });
468
469 let client_assertion_token = mint(
470 &oauth_client.private_signing_key_data,
471 &client_assertion_header,
472 &client_assertion_claims,
473 )
474 .map_err(OAuthClientError::MintTokenFailed)?;
475
476 let params = [
477 ("client_id", oauth_client.client_id.as_str()),
478 ("redirect_uri", oauth_client.redirect_uri.as_str()),
479 ("grant_type", "refresh_token"),
480 ("refresh_token", refresh_token),
481 (
482 "client_assertion_type",
483 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
484 ),
485 ("client_assertion", client_assertion_token.as_str()),
486 ];
487
488 let token_endpoint = authorization_server.token_endpoint.clone();
489
490 let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
491 .map_err(OAuthClientError::DpopTokenCreationFailed)?;
492
493 let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
494
495 let dpop_retry_client = ClientBuilder::new(http_client.clone())
496 .with(ChainMiddleware::new(dpop_retry.clone()))
497 .build();
498
499 dpop_retry_client
500 .post(token_endpoint)
501 .header("DPoP", dpop_token.as_str())
502 .form(¶ms)
503 .send()
504 .await
505 .map_err(OAuthClientError::TokenHttpRequestFailed)?
506 .json()
507 .await
508 .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
509}