use atproto_identity::key::KeyData;
use chrono::{DateTime, Utc};
use rand::distributions::{Alphanumeric, DistString};
use reqwest_chain::ChainMiddleware;
use reqwest_middleware::ClientBuilder;
use serde::Deserialize;
use std::collections::HashMap;
#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::{
dpop::{DpopRetry, auth_dpop},
errors::OAuthClientError,
jwt::{Claims, Header, JoseClaims, mint},
resources::{AuthorizationServer, pds_resources},
};
/// Body of a successful pushed authorization request (PAR) response.
///
/// Only `request_uri` and `expires_in` are modeled explicitly; every other member of the JSON
/// object is kept in `extra` so callers can still inspect it.
#[derive(Clone, Deserialize)]
pub struct ParResponse {
    /// Request URI returned by the authorization server for use in the authorization redirect.
    pub request_uri: String,
    /// The `expires_in` member of the response (RFC 9126 defines it in seconds).
    pub expires_in: u64,
    /// Any response members not captured by the named fields above.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
/// Per-flow values chosen by the caller before starting an authorization request.
///
/// `state`, `code_challenge` and `scope` become parameters of the pushed authorization request
/// built by [`oauth_init_with_prompt`].
pub struct OAuthRequestState {
    /// Sent as the OAuth `state` parameter.
    pub state: String,
    /// Not included in the pushed authorization request built in this module.
    /// NOTE(review): presumably verified by the caller later in the flow — confirm.
    pub nonce: String,
    /// PKCE code challenge, sent together with `code_challenge_method=S256`.
    pub code_challenge: String,
    /// Sent verbatim as the `scope` parameter.
    pub scope: String,
}
/// Identity and signing material of this OAuth client application.
///
/// With the `zeroize` feature enabled, `private_signing_key_data` is wiped on drop; the two
/// URI fields are skipped because they are not secret.
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
pub struct OAuthClient {
    /// Sent as the `redirect_uri` parameter of authorization and token requests.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub redirect_uri: String,
    /// Sent as `client_id`, and used as both `iss` and `sub` of client assertion JWTs.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub client_id: String,
    /// Key that signs the `private_key_jwt` client assertions.
    pub private_signing_key_data: KeyData,
}
/// State kept between starting an OAuth authorization flow and handling its callback.
///
/// `pkce_verifier` and `dpop_private_key` are secret material: the `Debug` implementation
/// redacts them, and with the `zeroize` feature enabled they are wiped from memory on drop.
/// The remaining fields are not secret and are skipped by zeroization.
#[derive(Clone, PartialEq)]
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
pub struct OAuthRequest {
    /// NOTE(review): presumably the `state` value sent in the authorization request — confirm.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub oauth_state: String,
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub issuer: String,
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub authorization_server: String,
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub nonce: String,
    /// PKCE code verifier, sent as `code_verifier` by [`oauth_complete`]. Secret: zeroized on
    /// drop (previously skipped by mistake).
    pub pkce_verifier: String,
    /// Public key material; not secret, so it is not zeroized.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub signing_public_key: String,
    /// DPoP private key for this flow. Secret: zeroized on drop (previously skipped by mistake).
    pub dpop_private_key: String,
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub created_at: DateTime<Utc>,
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub expires_at: DateTime<Utc>,
}
/// Debug output for [`OAuthRequest`] that never prints the PKCE verifier or DPoP private key.
impl std::fmt::Debug for OAuthRequest {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder shown instead of secret values.
        const REDACTED: &str = "[REDACTED]";

        let mut out = formatter.debug_struct("OAuthRequest");
        out.field("oauth_state", &self.oauth_state);
        out.field("issuer", &self.issuer);
        out.field("authorization_server", &self.authorization_server);
        out.field("nonce", &self.nonce);
        out.field("pkce_verifier", &REDACTED);
        out.field("signing_public_key", &self.signing_public_key);
        out.field("dpop_private_key", &REDACTED);
        out.field("created_at", &self.created_at);
        out.field("expires_at", &self.expires_at);
        out.finish()
    }
}
/// Body of a successful token endpoint response, as returned by [`oauth_complete`] and
/// [`oauth_refresh`].
///
/// With the `zeroize` feature enabled, `access_token`, `token_type` and `refresh_token` are
/// wiped on drop; the remaining fields are skipped. Unmodeled members are kept in `extra`.
#[derive(Clone, Deserialize)]
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
pub struct TokenResponse {
    /// The issued access token.
    pub access_token: String,
    /// The `token_type` member of the response.
    pub token_type: String,
    /// Refresh token, when the server issued one.
    pub refresh_token: Option<String>,
    /// Granted scope string.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub scope: String,
    /// The `expires_in` member of the response.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub expires_in: u32,
    /// The `sub` member of the response, when present.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    pub sub: Option<String>,
    /// Any response members not captured by the named fields above.
    #[cfg_attr(feature = "zeroize", zeroize(skip))]
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
/// Starts an OAuth authorization flow by sending a pushed authorization request (PAR).
///
/// This is [`oauth_init_with_prompt`] without a `prompt` parameter; see that function for
/// details of the request that is sent.
///
/// # Errors
///
/// Returns whatever error [`oauth_init_with_prompt`] returns.
pub async fn oauth_init(
    http_client: &reqwest::Client,
    oauth_client: &OAuthClient,
    dpop_key_data: &KeyData,
    login_hint: Option<&str>,
    authorization_server: &AuthorizationServer,
    oauth_request_state: &OAuthRequestState,
) -> Result<ParResponse, OAuthClientError> {
    // No `prompt` parameter is sent for the plain entry point.
    let prompt: Option<&str> = None;
    oauth_init_with_prompt(
        http_client,
        oauth_client,
        dpop_key_data,
        login_hint,
        prompt,
        authorization_server,
        oauth_request_state,
    )
    .await
}
/// Sends a pushed authorization request (PAR) to start an OAuth authorization flow.
///
/// The form-encoded request is POSTed to the server's
/// `pushed_authorization_request_endpoint`. It carries:
/// - the PKCE `code_challenge` (method `S256`), `state`, `scope`, `client_id` and
///   `redirect_uri`;
/// - a `private_key_jwt` client assertion signed with `oauth_client.private_signing_key_data`;
/// - a DPoP proof signed with `dpop_key_data`.
///
/// `login_hint` and `prompt` are added as parameters only when they are `Some`.
///
/// # Errors
///
/// - [`OAuthClientError::JWTHeaderCreationFailed`] if no JWT header can be derived from the
///   client signing key.
/// - [`OAuthClientError::MintTokenFailed`] if the client assertion cannot be signed.
/// - [`OAuthClientError::DpopTokenCreationFailed`] if the DPoP proof cannot be created.
/// - [`OAuthClientError::PARHttpRequestFailed`] if sending the request fails.
/// - [`OAuthClientError::PARResponseJsonParsingFailed`] if the body is not a [`ParResponse`].
pub async fn oauth_init_with_prompt(
    http_client: &reqwest::Client,
    oauth_client: &OAuthClient,
    dpop_key_data: &KeyData,
    login_hint: Option<&str>,
    prompt: Option<&str>,
    authorization_server: &AuthorizationServer,
    oauth_request_state: &OAuthRequestState,
) -> Result<ParResponse, OAuthClientError> {
    let par_url = authorization_server
        .pushed_authorization_request_endpoint
        .clone();

    // Client authentication: a JWT issued by and about this client (`iss` = `sub` =
    // client_id), addressed to the authorization server, with a random 30-char `jti`.
    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
        .try_into()
        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
    let client_assertion_claims = Claims::new(JoseClaims {
        issuer: Some(oauth_client.client_id.clone()),
        subject: Some(oauth_client.client_id.clone()),
        audience: Some(authorization_server.issuer.clone()),
        json_web_token_id: Some(client_assertion_jti),
        issued_at: Some(chrono::Utc::now().timestamp().cast_unsigned()),
        ..Default::default()
    });
    let client_assertion_token = mint(
        &oauth_client.private_signing_key_data,
        &client_assertion_header,
        &client_assertion_claims,
    )
    .map_err(OAuthClientError::MintTokenFailed)?;

    // The DPoP header/claims and key are handed to `DpopRetry` middleware.
    // NOTE(review): presumably so it can re-sign and retry, e.g. when the server supplies a
    // DPoP nonce — confirm against `DpopRetry`.
    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &par_url)
        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
    let dpop_retry_client = ClientBuilder::new(http_client.clone())
        .with(ChainMiddleware::new(dpop_retry))
        .build();

    let mut params = vec![
        ("response_type", "code"),
        ("code_challenge", oauth_request_state.code_challenge.as_str()),
        ("code_challenge_method", "S256"),
        ("client_id", oauth_client.client_id.as_str()),
        ("state", oauth_request_state.state.as_str()),
        ("redirect_uri", oauth_client.redirect_uri.as_str()),
        ("scope", oauth_request_state.scope.as_str()),
        (
            "client_assertion_type",
            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        ),
        ("client_assertion", client_assertion_token.as_str()),
    ];
    // Optional parameters are only sent when provided.
    params.extend(login_hint.map(|value| ("login_hint", value)));
    params.extend(prompt.map(|value| ("prompt", value)));

    dpop_retry_client
        .post(par_url)
        .header("DPoP", dpop_token.as_str())
        .form(&params)
        .send()
        .await
        .map_err(OAuthClientError::PARHttpRequestFailed)?
        .json()
        .await
        .map_err(OAuthClientError::PARResponseJsonParsingFailed)
}
/// Exchanges an authorization code for tokens at the authorization server's token endpoint.
///
/// Sends an `authorization_code` grant with `callback_code` and the PKCE `code_verifier`
/// stored in `oauth_request.pkce_verifier`. The request is authenticated with a
/// `private_key_jwt` client assertion signed by `oauth_client.private_signing_key_data`, and
/// carries a DPoP proof signed with `dpop_key_data`.
///
/// # Errors
///
/// - [`OAuthClientError::JWTHeaderCreationFailed`] if no JWT header can be derived from the
///   client signing key.
/// - [`OAuthClientError::MintTokenFailed`] if the client assertion cannot be signed.
/// - [`OAuthClientError::DpopTokenCreationFailed`] if the DPoP proof cannot be created.
/// - [`OAuthClientError::TokenHttpRequestFailed`] if sending the request fails.
/// - [`OAuthClientError::TokenResponseJsonParsingFailed`] if the body is not a
///   [`TokenResponse`].
pub async fn oauth_complete(
    http_client: &reqwest::Client,
    oauth_client: &OAuthClient,
    dpop_key_data: &KeyData,
    callback_code: &str,
    oauth_request: &OAuthRequest,
    authorization_server: &AuthorizationServer,
) -> Result<TokenResponse, OAuthClientError> {
    // Client authentication: a JWT issued by and about this client (`iss` = `sub` =
    // client_id), addressed to the authorization server, with a random 30-char `jti`.
    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
        .try_into()
        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
    let client_assertion_claims = Claims::new(JoseClaims {
        issuer: Some(oauth_client.client_id.clone()),
        subject: Some(oauth_client.client_id.clone()),
        audience: Some(authorization_server.issuer.clone()),
        json_web_token_id: Some(client_assertion_jti),
        issued_at: Some(chrono::Utc::now().timestamp().cast_unsigned()),
        ..Default::default()
    });
    let client_assertion_token = mint(
        &oauth_client.private_signing_key_data,
        &client_assertion_header,
        &client_assertion_claims,
    )
    .map_err(OAuthClientError::MintTokenFailed)?;

    let params = [
        ("client_id", oauth_client.client_id.as_str()),
        ("redirect_uri", oauth_client.redirect_uri.as_str()),
        ("grant_type", "authorization_code"),
        ("code", callback_code),
        ("code_verifier", oauth_request.pkce_verifier.as_str()),
        (
            "client_assertion_type",
            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        ),
        ("client_assertion", client_assertion_token.as_str()),
    ];

    // The DPoP header/claims and key are handed to `DpopRetry` middleware.
    // NOTE(review): presumably so it can re-sign and retry, e.g. when the server supplies a
    // DPoP nonce — confirm against `DpopRetry`.
    let token_endpoint = authorization_server.token_endpoint.clone();
    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
    let dpop_retry_client = ClientBuilder::new(http_client.clone())
        .with(ChainMiddleware::new(dpop_retry))
        .build();

    dpop_retry_client
        .post(token_endpoint)
        .header("DPoP", dpop_token.as_str())
        .form(&params)
        .send()
        .await
        .map_err(OAuthClientError::TokenHttpRequestFailed)?
        .json()
        .await
        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
}
/// Obtains a new token set using a refresh token.
///
/// The authorization server is discovered from the first PDS endpoint listed in `document`,
/// via [`pds_resources`]. A `refresh_token` grant is then POSTed to that server's token
/// endpoint. The request is authenticated with a `private_key_jwt` client assertion signed by
/// `oauth_client.private_signing_key_data`, and carries a DPoP proof signed with
/// `dpop_key_data`.
///
/// # Errors
///
/// - [`OAuthClientError::InvalidOAuthProtectedResource`] if `document` lists no PDS endpoint.
/// - Any error from [`pds_resources`] while discovering the authorization server.
/// - [`OAuthClientError::JWTHeaderCreationFailed`] if no JWT header can be derived from the
///   client signing key.
/// - [`OAuthClientError::MintTokenFailed`] if the client assertion cannot be signed.
/// - [`OAuthClientError::DpopTokenCreationFailed`] if the DPoP proof cannot be created.
/// - [`OAuthClientError::TokenHttpRequestFailed`] if sending the request fails.
/// - [`OAuthClientError::TokenResponseJsonParsingFailed`] if the body is not a
///   [`TokenResponse`].
pub async fn oauth_refresh(
    http_client: &reqwest::Client,
    oauth_client: &OAuthClient,
    dpop_key_data: &KeyData,
    refresh_token: &str,
    document: &atproto_identity::model::Document,
) -> Result<TokenResponse, OAuthClientError> {
    // Re-discover the authorization server from the account's (first) PDS.
    let pds_endpoints = document.pds_endpoints();
    let pds_endpoint = pds_endpoints
        .first()
        .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
    let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;

    // Client authentication: a JWT issued by and about this client (`iss` = `sub` =
    // client_id), addressed to the authorization server, with a random 30-char `jti`.
    let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
        .try_into()
        .map_err(OAuthClientError::JWTHeaderCreationFailed)?;
    let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
    let client_assertion_claims = Claims::new(JoseClaims {
        issuer: Some(oauth_client.client_id.clone()),
        subject: Some(oauth_client.client_id.clone()),
        audience: Some(authorization_server.issuer.clone()),
        json_web_token_id: Some(client_assertion_jti),
        issued_at: Some(chrono::Utc::now().timestamp().cast_unsigned()),
        ..Default::default()
    });
    let client_assertion_token = mint(
        &oauth_client.private_signing_key_data,
        &client_assertion_header,
        &client_assertion_claims,
    )
    .map_err(OAuthClientError::MintTokenFailed)?;

    // NOTE(review): RFC 6749 §6 does not define `redirect_uri` for refresh requests. Servers
    // must ignore unrecognized parameters, so it is kept here unchanged for compatibility.
    let params = [
        ("client_id", oauth_client.client_id.as_str()),
        ("redirect_uri", oauth_client.redirect_uri.as_str()),
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        (
            "client_assertion_type",
            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        ),
        ("client_assertion", client_assertion_token.as_str()),
    ];

    // The DPoP header/claims and key are handed to `DpopRetry` middleware.
    // NOTE(review): presumably so it can re-sign and retry, e.g. when the server supplies a
    // DPoP nonce — confirm against `DpopRetry`.
    let token_endpoint = authorization_server.token_endpoint.clone();
    let (dpop_token, dpop_header, dpop_claims) = auth_dpop(dpop_key_data, "POST", &token_endpoint)
        .map_err(OAuthClientError::DpopTokenCreationFailed)?;
    let dpop_retry = DpopRetry::new(dpop_header, dpop_claims, dpop_key_data.clone(), true);
    let dpop_retry_client = ClientBuilder::new(http_client.clone())
        .with(ChainMiddleware::new(dpop_retry))
        .build();

    dpop_retry_client
        .post(token_endpoint)
        .header("DPoP", dpop_token.as_str())
        .form(&params)
        .send()
        .await
        .map_err(OAuthClientError::TokenHttpRequestFailed)?
        .json()
        .await
        .map_err(OAuthClientError::TokenResponseJsonParsingFailed)
}