1use atproto_identity::key::KeyData;
74use chrono::{DateTime, Utc};
75use rand::distributions::{Alphanumeric, DistString};
76use reqwest_chain::ChainMiddleware;
77use reqwest_middleware::ClientBuilder;
78use serde::Deserialize;
79use std::collections::HashMap;
80
81#[cfg(feature = "zeroize")]
82use zeroize::{Zeroize, ZeroizeOnDrop};
83
84use crate::{
85 dpop::{DpopRetry, auth_dpop},
86 errors::OAuthClientError,
87 jwt::{Claims, Header, JoseClaims, mint},
88 resources::{AuthorizationServer, pds_resources},
89};
90
91#[derive(Clone, Deserialize)]
96pub struct ParResponse {
97 pub request_uri: String,
99 pub expires_in: u64,
101
102 #[serde(flatten)]
104 pub extra: HashMap<String, serde_json::Value>,
105}
106
/// Per-request values the caller generates before starting a flow with [`oauth_init`].
///
/// None of these values are secret (the PKCE *challenge* is public by design), so the
/// type derives `Debug` and `Clone` like other public data types.
#[derive(Clone, Debug)]
pub struct OAuthRequestState {
    /// Sent as the `state` parameter of the PAR request.
    pub state: String,
    /// Nonce associated with the request; not sent by `oauth_init` itself.
    /// NOTE(review): consumer of this value is outside this file — confirm its role.
    pub nonce: String,
    /// PKCE code challenge; sent with `code_challenge_method=S256`.
    pub code_challenge: String,
    /// Sent as the `scope` parameter of the PAR request.
    pub scope: String,
}
121
122#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
127pub struct OAuthClient {
128 #[cfg_attr(feature = "zeroize", zeroize(skip))]
130 pub redirect_uri: String,
131
132 #[cfg_attr(feature = "zeroize", zeroize(skip))]
134 pub client_id: String,
135
136 pub private_signing_key_data: KeyData,
138}
139
140#[derive(Clone, PartialEq)]
145#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
146pub struct OAuthRequest {
147 #[cfg_attr(feature = "zeroize", zeroize(skip))]
149 pub oauth_state: String,
150
151 #[cfg_attr(feature = "zeroize", zeroize(skip))]
153 pub issuer: String,
154
155 #[cfg_attr(feature = "zeroize", zeroize(skip))]
157 pub authorization_server: String,
158
159 #[cfg_attr(feature = "zeroize", zeroize(skip))]
161 pub nonce: String,
162
163 #[cfg_attr(feature = "zeroize", zeroize(skip))]
165 pub pkce_verifier: String,
166
167 pub signing_public_key: String,
169
170 #[cfg_attr(feature = "zeroize", zeroize(skip))]
172 pub dpop_private_key: String,
173
174 #[cfg_attr(feature = "zeroize", zeroize(skip))]
176 pub created_at: DateTime<Utc>,
177
178 #[cfg_attr(feature = "zeroize", zeroize(skip))]
180 pub expires_at: DateTime<Utc>,
181}
182
183impl std::fmt::Debug for OAuthRequest {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 f.debug_struct("OAuthRequest")
186 .field("oauth_state", &self.oauth_state)
187 .field("issuer", &self.issuer)
188 .field("authorization_server", &self.authorization_server)
189 .field("nonce", &self.nonce)
190 .field("pkce_verifier", &"[REDACTED]")
191 .field("signing_public_key", &self.signing_public_key)
192 .field("dpop_private_key", &"[REDACTED]")
193 .field("created_at", &self.created_at)
194 .field("expires_at", &self.expires_at)
195 .finish()
196 }
197}
198
199#[derive(Clone, Deserialize)]
204#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
205pub struct TokenResponse {
206 pub access_token: String,
208
209 pub token_type: String,
211
212 pub refresh_token: String,
214
215 #[cfg_attr(feature = "zeroize", zeroize(skip))]
217 pub scope: String,
218
219 #[cfg_attr(feature = "zeroize", zeroize(skip))]
221 pub expires_in: u32,
222
223 #[cfg_attr(feature = "zeroize", zeroize(skip))]
225 pub sub: Option<String>,
226
227 #[serde(flatten)]
229 #[cfg_attr(feature = "zeroize", zeroize(skip))]
230 pub extra: HashMap<String, serde_json::Value>,
231}
232
233pub async fn oauth_init(
253 http_client: &reqwest::Client,
254 oauth_client: &OAuthClient,
255 dpop_key_data: &KeyData,
256 login_hint: Option<&str>,
257 authorization_server: &AuthorizationServer,
258 oauth_request_state: &OAuthRequestState,
259) -> Result<ParResponse, OAuthClientError> {
260 let par_url = authorization_server
261 .pushed_authorization_request_endpoint
262 .clone();
263
264 let scope = &oauth_request_state.scope;
265
266 let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
267 .try_into()
268 .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
269
270 let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
271 let client_assertion_claims = Claims::new(JoseClaims {
272 issuer: Some(oauth_client.client_id.clone()),
273 subject: Some(oauth_client.client_id.clone()),
274 audience: Some(authorization_server.issuer.clone()),
275 json_web_token_id: Some(client_assertion_jti),
276 issued_at: Some(chrono::Utc::now().timestamp() as u64),
277 ..Default::default()
278 });
279
280 let client_assertion_token = mint(
281 &oauth_client.private_signing_key_data,
282 &client_assertion_header,
283 &client_assertion_claims,
284 )
285 .map_err(OAuthClientError::MintTokenFailed)?;
286
287 let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &par_url)
288 .map_err(OAuthClientError::DpopTokenCreationFailed)?;
289
290 let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
291
292 let dpop_retry_client = ClientBuilder::new(http_client.clone())
293 .with(ChainMiddleware::new(dpop_retry.clone()))
294 .build();
295
296 let mut params = vec![
297 ("response_type", "code"),
298 ("code_challenge", &oauth_request_state.code_challenge),
299 ("code_challenge_method", "S256"),
300 ("client_id", oauth_client.client_id.as_str()),
301 ("state", oauth_request_state.state.as_str()),
302 ("redirect_uri", oauth_client.redirect_uri.as_str()),
303 ("scope", scope),
304 (
305 "client_assertion_type",
306 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
307 ),
308 ("client_assertion", client_assertion_token.as_str()),
309 ];
310 if let Some(value) = login_hint {
311 params.push(("login_hint", value));
312 }
313
314 let response = dpop_retry_client
315 .post(par_url)
316 .header("DPoP", dpop_token.as_str())
317 .form(¶ms)
318 .send()
319 .await
320 .map_err(OAuthClientError::PARHttpRequestFailed)?
321 .json()
322 .await
323 .map_err(OAuthClientError::PARResponseJsonParsingFailed)?;
324
325 Ok(response)
326}
327
328pub async fn oauth_complete(
348 http_client: &reqwest::Client,
349 oauth_client: &OAuthClient,
350 dpop_key_data: &KeyData,
351 callback_code: &str,
352 oauth_request: &OAuthRequest,
353 authorization_server: &AuthorizationServer,
354) -> Result<TokenResponse, OAuthClientError> {
355 let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
362 .try_into()
363 .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
364
365 let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
366 let client_assertion_claims = Claims::new(JoseClaims {
367 issuer: Some(oauth_client.client_id.clone()),
368 subject: Some(oauth_client.client_id.clone()),
369 audience: Some(authorization_server.issuer.clone()),
370 json_web_token_id: Some(client_assertion_jti),
371 issued_at: Some(chrono::Utc::now().timestamp() as u64),
372 ..Default::default()
373 });
374
375 let client_assertion_token = mint(
376 &oauth_client.private_signing_key_data,
377 &client_assertion_header,
378 &client_assertion_claims,
379 )
380 .map_err(OAuthClientError::MintTokenFailed)?;
381
382 let params = [
383 ("client_id", oauth_client.client_id.as_str()),
384 ("redirect_uri", oauth_client.redirect_uri.as_str()),
385 ("grant_type", "authorization_code"),
386 ("code", callback_code),
387 ("code_verifier", &oauth_request.pkce_verifier),
388 (
389 "client_assertion_type",
390 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
391 ),
392 ("client_assertion", client_assertion_token.as_str()),
393 ];
394
395 let token_endpoint = authorization_server.token_endpoint.clone();
396
397 let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
398 .map_err(OAuthClientError::DpopTokenCreationFailed)?;
399
400 let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
401
402 let dpop_retry_client = ClientBuilder::new(http_client.clone())
403 .with(ChainMiddleware::new(dpop_retry.clone()))
404 .build();
405
406 dpop_retry_client
407 .post(token_endpoint)
408 .header("DPoP", dpop_token.as_str())
409 .form(¶ms)
410 .send()
411 .await
412 .map_err(OAuthClientError::TokenHttpRequestFailed)?
413 .json()
414 .await
415 .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
416}
417
418pub async fn oauth_refresh(
437 http_client: &reqwest::Client,
438 oauth_client: &OAuthClient,
439 dpop_key_data: &KeyData,
440 refresh_token: &str,
441 document: &atproto_identity::model::Document,
442) -> Result<TokenResponse, OAuthClientError> {
443 let pds_endpoints = document.pds_endpoints();
444 let pds_endpoint = pds_endpoints
445 .first()
446 .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
447 let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
448
449 let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
450 .try_into()
451 .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
452
453 let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
454 let client_assertion_claims = Claims::new(JoseClaims {
455 issuer: Some(oauth_client.client_id.clone()),
456 subject: Some(oauth_client.client_id.clone()),
457 audience: Some(authorization_server.issuer.clone()),
458 json_web_token_id: Some(client_assertion_jti),
459 issued_at: Some(chrono::Utc::now().timestamp() as u64),
460 ..Default::default()
461 });
462
463 let client_assertion_token = mint(
464 &oauth_client.private_signing_key_data,
465 &client_assertion_header,
466 &client_assertion_claims,
467 )
468 .map_err(OAuthClientError::MintTokenFailed)?;
469
470 let params = [
471 ("client_id", oauth_client.client_id.as_str()),
472 ("redirect_uri", oauth_client.redirect_uri.as_str()),
473 ("grant_type", "refresh_token"),
474 ("refresh_token", refresh_token),
475 (
476 "client_assertion_type",
477 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
478 ),
479 ("client_assertion", client_assertion_token.as_str()),
480 ];
481
482 let token_endpoint = authorization_server.token_endpoint.clone();
483
484 let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
485 .map_err(OAuthClientError::DpopTokenCreationFailed)?;
486
487 let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
488
489 let dpop_retry_client = ClientBuilder::new(http_client.clone())
490 .with(ChainMiddleware::new(dpop_retry.clone()))
491 .build();
492
493 dpop_retry_client
494 .post(token_endpoint)
495 .header("DPoP", dpop_token.as_str())
496 .form(¶ms)
497 .send()
498 .await
499 .map_err(OAuthClientError::TokenHttpRequestFailed)?
500 .json()
501 .await
502 .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
503}