use std::sync::Arc;
use crate::crypto::P256SigningKey;
use crate::identity::Directory;
use crate::syntax::Did;
use crate::oauth::OAuthError;
use crate::oauth::client_auth::{ClientAuth, ConfidentialClientAuth, PublicClientAuth};
use crate::oauth::dpop::{self, NonceStore};
use crate::oauth::metadata::{self, ClientMetadata};
use crate::oauth::pkce::{self, base64url_encode};
use crate::oauth::session::{AuthState, Session, SessionStore, StateStore};
use crate::oauth::token;
/// Everything needed to construct an [`OAuthClient`].
pub struct OAuthClientConfig {
    /// Client metadata; its `client_id` is also used for client authentication.
    pub metadata: ClientMetadata,
    /// Where completed sessions are persisted (keyed by account DID).
    pub session_store: Box<dyn SessionStore>,
    /// Where in-flight authorization state is kept between `authorize` and `callback`.
    pub state_store: Box<dyn StateStore>,
    /// `Some((key, key_id))` makes the client authenticate as a confidential
    /// client with that key; `None` makes it a public client.
    pub signing_key: Option<(P256SigningKey, String)>,
    /// When `true`, `callback` does not re-resolve the returned account to
    /// confirm its PDS names the issuing authorization server.
    pub skip_issuer_verification: bool,
}
/// atproto OAuth client: runs the authorization-code flow (PAR, PKCE, DPoP),
/// persists the resulting sessions and refreshes them on demand.
pub struct OAuthClient {
    /// Client metadata (`client_id`, default `scope`, auth method).
    metadata: ClientMetadata,
    /// Persisted sessions, keyed by account DID.
    sessions: Box<dyn SessionStore>,
    /// In-flight authorization state, keyed by the OAuth `state` value.
    states: Box<dyn StateStore>,
    /// Client authentication (public or confidential) applied to outgoing requests.
    auth: Box<dyn ClientAuth>,
    /// Shared HTTP client for every outbound request.
    http: reqwest::Client,
    /// Server-issued DPoP nonces, shared with the token helpers.
    nonces: Arc<NonceStore>,
    /// Disables the post-exchange check that the account's PDS names the issuer.
    skip_issuer_verification: bool,
    /// Serializes token refreshes. Only the mutex is used as a lock; the map
    /// inside is never populated, so all accounts share one refresh lock.
    refresh_locks: tokio::sync::Mutex<std::collections::HashMap<String, ()>>,
}
/// Inputs for [`OAuthClient::authorize`].
#[derive(Clone, Debug)]
pub struct AuthorizeOptions {
    /// Handle or DID of the account to sign in; also forwarded as `login_hint`.
    pub input: String,
    /// Redirect URI the authorization server should send the user back to.
    pub redirect_uri: String,
    /// Requested scope; `None` falls back to the client metadata's scope.
    pub scope: Option<String>,
    /// OAuth `state` to use; `None` generates a random one.
    pub state: Option<String>,
}
/// Output of [`OAuthClient::authorize`].
#[derive(Clone, Debug)]
pub struct AuthorizeResult {
    /// Authorization URL (carrying `client_id` and the PAR `request_uri`) to send the user to.
    pub url: String,
    /// OAuth `state` under which the flow was recorded.
    pub state: String,
}
/// Query parameters received on the redirect back from the authorization server.
#[derive(Clone, Debug)]
pub struct CallbackParams {
    /// Authorization code to exchange for tokens.
    pub code: String,
    /// OAuth `state`, used to look up the stored flow.
    pub state: String,
    /// Issuer identifier; `callback` rejects the response when it is absent.
    pub iss: Option<String>,
}
impl OAuthClient {
    /// Builds a client from `config`.
    ///
    /// `signing_key: Some((key, key_id))` selects confidential-client
    /// authentication with that key; `None` selects public-client
    /// authentication. A fresh HTTP client and DPoP nonce store are created.
    pub fn new(config: OAuthClientConfig) -> Self {
        let client_id = config.metadata.client_id.clone();
        let auth: Box<dyn ClientAuth> = match config.signing_key {
            Some((key, key_id)) => Box::new(ConfidentialClientAuth {
                client_id,
                key,
                key_id,
            }),
            None => Box::new(PublicClientAuth { client_id }),
        };
        OAuthClient {
            metadata: config.metadata,
            sessions: config.session_store,
            states: config.state_store,
            auth,
            http: reqwest::Client::new(),
            nonces: Arc::new(NonceStore::new()),
            skip_issuer_verification: config.skip_issuer_verification,
            refresh_locks: tokio::sync::Mutex::new(std::collections::HashMap::new()),
        }
    }

    /// Starts an authorization flow for a handle or DID.
    ///
    /// Resolves `opts.input` to a DID and finds the account's PDS and its
    /// authorization server. It then pushes the authorization request (PAR,
    /// DPoP-signed, PKCE S256) and records the in-flight [`AuthState`] under the
    /// OAuth `state`. Finally it returns the URL to send the user to.
    ///
    /// When `opts.state` is `None`, a random 128-bit state is generated. A
    /// caller-supplied state is used verbatim as the OAuth `state`, so it
    /// should itself be unguessable.
    ///
    /// # Errors
    ///
    /// Fails if identity resolution, metadata discovery, the PAR request,
    /// building the authorization URL, or persisting the state fails.
    pub async fn authorize(&self, opts: AuthorizeOptions) -> Result<AuthorizeResult, OAuthError> {
        let did = self.resolve_input_to_did(&opts.input).await?;
        let directory = Directory::new();
        let identity = directory.lookup_did(&did).await?;
        let pds_url = identity
            .pds_endpoint()
            .ok_or_else(|| OAuthError::Identity("no PDS endpoint in DID document".into()))?;
        let pr_meta = metadata::fetch_protected_resource_metadata(pds_url).await?;
        let issuer = pr_meta
            .authorization_servers
            .first()
            .ok_or_else(|| {
                OAuthError::InvalidMetadata("no authorization servers in resource metadata".into())
            })?
            .clone();
        let as_meta = metadata::fetch_auth_server_metadata(&issuer).await?;
        metadata::validate_auth_server_metadata(&as_meta)?;

        // Per-flow secrets: the DPoP key the session will be bound to, and a PKCE pair.
        let dpop_key = P256SigningKey::generate();
        let pkce = pkce::generate_pkce();
        let state = opts.state.unwrap_or_else(|| {
            let mut bytes = [0u8; 16];
            rand::fill(&mut bytes);
            base64url_encode(&bytes)
        });
        let scope = opts.scope.unwrap_or_else(|| self.metadata.scope.clone());
        let auth_state = AuthState {
            issuer: issuer.clone(),
            dpop_key_bytes: base64url_encode(&dpop_key.to_bytes()),
            auth_method: self.metadata.token_endpoint_auth_method.clone(),
            verifier: pkce.verifier.clone(),
            redirect_uri: opts.redirect_uri.clone(),
            app_state: state.clone(),
            token_endpoint: as_meta.token_endpoint.clone(),
            revocation_endpoint: as_meta.revocation_endpoint.clone(),
        };

        let mut params: Vec<(String, String)> = vec![
            ("response_type".into(), "code".into()),
            ("code_challenge".into(), pkce.challenge),
            ("code_challenge_method".into(), "S256".into()),
            ("state".into(), state.clone()),
            ("redirect_uri".into(), opts.redirect_uri),
            ("scope".into(), scope),
            ("login_hint".into(), opts.input),
        ];
        let par_endpoint = &as_meta.pushed_authorization_request_endpoint;
        let par_origin = NonceStore::origin_from_url(par_endpoint)?;
        self.auth.apply(&mut params, &par_origin)?;

        // The AS may reject the first attempt with `use_dpop_nonce`. The nonce it
        // wants arrives in that response's `DPoP-Nonce` header (recorded below),
        // so exactly one retry is made with a proof carrying it.
        let mut retried = false;
        let request_uri = loop {
            let nonce = self.nonces.get(&par_origin);
            let proof =
                dpop::create_dpop_proof(&dpop_key, "POST", par_endpoint, nonce.as_deref(), None)?;
            let resp = self
                .http
                .post(par_endpoint)
                .header("DPoP", &proof)
                .form(&params)
                .send()
                .await?;
            if let Some(new_nonce) = resp
                .headers()
                .get("DPoP-Nonce")
                .and_then(|v| v.to_str().ok())
            {
                self.nonces.set(&par_origin, new_nonce.to_string());
            }
            let status = resp.status();
            let body = match resp.json::<serde_json::Value>().await {
                Ok(body) => body,
                // A non-JSON error page (proxy 5xx, HTML 404, ...) should surface
                // as an HTTP failure with its status, not as a decode error.
                Err(e) if !status.is_success() => {
                    return Err(OAuthError::Http(format!(
                        "PAR request failed with HTTP {status}: {e}"
                    )));
                }
                Err(e) => return Err(OAuthError::from(e)),
            };
            if status.is_success() {
                break extract_request_uri(&body)?;
            }
            let wants_nonce = status == reqwest::StatusCode::BAD_REQUEST
                && body.get("error").and_then(|v| v.as_str()) == Some("use_dpop_nonce");
            if wants_nonce && !retried {
                retried = true;
                continue;
            }
            return Err(oauth_error_from_json(&body));
        };

        let mut auth_url = url::Url::parse(&as_meta.authorization_endpoint)
            .map_err(|e| OAuthError::Http(format!("invalid authorization endpoint URL: {e}")))?;
        auth_url
            .query_pairs_mut()
            .append_pair("client_id", &self.metadata.client_id)
            .append_pair("request_uri", &request_uri);

        // Persist only once the flow can actually be completed, so failed
        // attempts do not leave orphaned entries in the state store. The caller
        // cannot reach `callback` before this returns, so nothing races it.
        self.states.set(&state, &auth_state).await?;
        Ok(AuthorizeResult {
            url: auth_url.to_string(),
            state,
        })
    }

    /// Completes a flow started by [`authorize`](Self::authorize).
    ///
    /// Consumes the stored state and requires the callback's `iss` to match
    /// the issuer recorded for it. It then exchanges the code for DPoP-bound
    /// tokens and stores the resulting session under the token's `sub`.
    /// Unless `skip_issuer_verification` is set, it first re-resolves `sub` and
    /// checks that its PDS names the same authorization server. On a mismatch
    /// the access token is revoked and an error returned.
    ///
    /// # Errors
    ///
    /// `InvalidState` for an unknown state, `MissingIssuer` / `IssuerMismatch`
    /// for a bad `iss`, `IssuerVerification` when the account's PDS names a
    /// different AS, plus any token-exchange, resolution or store error.
    pub async fn callback(&self, params: CallbackParams) -> Result<Session, OAuthError> {
        // `take` is expected to remove the entry, making each state single-use.
        let auth_state = self
            .states
            .take(&params.state)
            .await?
            .ok_or(OAuthError::InvalidState)?;

        let Some(iss) = params.iss.as_deref() else {
            return Err(OAuthError::MissingIssuer);
        };
        if iss != auth_state.issuer {
            return Err(OAuthError::IssuerMismatch {
                expected: auth_state.issuer.clone(),
                actual: iss.to_string(),
            });
        }

        let dpop_key = auth_state.dpop_key()?;
        let token_set = token::exchange_code(
            &self.http,
            &auth_state.token_endpoint,
            &auth_state.revocation_endpoint,
            &params.code,
            &auth_state.verifier,
            &auth_state.redirect_uri,
            self.auth.as_ref(),
            &dpop_key,
            &self.nonces,
        )
        .await?;

        if !self.skip_issuer_verification {
            let sub_did = Did::try_from(token_set.sub.as_str())
                .map_err(|e| OAuthError::Identity(format!("invalid sub DID: {e}")))?;
            let directory = Directory::new();
            let identity = directory.lookup_did(&sub_did).await?;
            let pds_url = identity.pds_endpoint().ok_or_else(|| {
                OAuthError::IssuerVerification("no PDS endpoint in DID document".into())
            })?;
            let pr_meta = metadata::fetch_protected_resource_metadata(pds_url).await?;
            let actual_issuer = pr_meta.authorization_servers.first().ok_or_else(|| {
                OAuthError::IssuerVerification(
                    "no authorization servers in resource metadata".into(),
                )
            })?;
            if *actual_issuer != auth_state.issuer {
                // The issuing AS is not authoritative for this account: revoke the
                // token before failing (best effort; the outcome is not checked).
                token::revoke_token(
                    &self.http,
                    &auth_state.revocation_endpoint,
                    &token_set.access_token,
                    self.auth.as_ref(),
                    &dpop_key,
                    &self.nonces,
                )
                .await;
                return Err(OAuthError::IssuerVerification(format!(
                    "AS mismatch: expected {}, got {}",
                    auth_state.issuer, actual_issuer
                )));
            }
        }

        // Best-effort removal of any earlier session for this account; a failure
        // is deliberately ignored, the `set` below is what must succeed.
        let _ = self.sessions.delete(&token_set.sub).await;
        let session = Session::from_key_and_tokens(&dpop_key, token_set);
        self.sessions.set(&session.token_set.sub, &session).await?;
        Ok(session)
    }

    /// Signs `did` out: revokes its access token (best effort, only when the
    /// stored DPoP key can be decoded) and deletes the stored session.
    ///
    /// # Errors
    ///
    /// Propagates session-store read/delete errors; revocation failures are not reported.
    pub async fn sign_out(&self, did: &str) -> Result<(), OAuthError> {
        if let Some(session) = self.sessions.get(did).await?
            && let Ok(dpop_key) = session.dpop_key()
        {
            token::revoke_token(
                &self.http,
                &session.token_set.revocation_endpoint,
                &session.token_set.access_token,
                self.auth.as_ref(),
                &dpop_key,
                &self.nonces,
            )
            .await;
        }
        self.sessions.delete(did).await?;
        Ok(())
    }

    /// Returns the stored session for `did`, refreshing it first if it is
    /// stale and has a refresh token.
    ///
    /// Refreshes are serialized through `refresh_locks` (one lock shared by all
    /// accounts). After acquiring it the session is re-read, so a waiter picks
    /// up a session another task already refreshed instead of refreshing twice.
    ///
    /// # Errors
    ///
    /// `NoSession` when nothing is stored for `did`, plus store, key-decoding
    /// and refresh errors.
    pub async fn get_session(&self, did: &str) -> Result<Session, OAuthError> {
        let session = self.load_session(did).await?;
        if !Self::needs_refresh(&session) {
            return Ok(session);
        }
        let _lock = self.refresh_locks.lock().await;
        // Re-check with the same predicate as the fast path: the session may have
        // been refreshed (or replaced by one without a refresh token) while waiting.
        let session = self.load_session(did).await?;
        if !Self::needs_refresh(&session) {
            return Ok(session);
        }
        let dpop_key = session.dpop_key()?;
        let new_tokens = token::refresh_token(
            &self.http,
            &session.token_set,
            self.auth.as_ref(),
            &dpop_key,
            &self.nonces,
        )
        .await?;
        let new_session = Session::from_key_and_tokens(&dpop_key, new_tokens);
        self.sessions.set(did, &new_session).await?;
        Ok(new_session)
    }

    /// Reads the session for `did`, mapping "not stored" to `NoSession`.
    async fn load_session(&self, did: &str) -> Result<Session, OAuthError> {
        self.sessions
            .get(did)
            .await?
            .ok_or_else(|| OAuthError::NoSession(did.to_string()))
    }

    /// True when `session` is stale and can actually be refreshed.
    fn needs_refresh(session: &Session) -> bool {
        session.token_set.is_stale() && session.token_set.refresh_token.is_some()
    }

    /// Resolves user input (a DID, or a handle with optional leading `@`) to a DID.
    ///
    /// A DID is returned as-is. A handle is first tried via
    /// `https://<handle>/.well-known/atproto-did`. If that fails, the public
    /// `com.atproto.identity.resolveHandle` endpoint at `public.api.bsky.app` is
    /// used. NOTE(review): DNS `_atproto` TXT resolution is not attempted here.
    ///
    /// # Errors
    ///
    /// `OAuthError::Identity` for malformed handles or failed resolution.
    async fn resolve_input_to_did(&self, input: &str) -> Result<Did, OAuthError> {
        let input = input.trim();
        if let Ok(did) = Did::try_from(input) {
            return Ok(did);
        }
        // Validated before it is interpolated into any URL (user input).
        let handle = Self::normalize_handle(input)?;

        let well_known = format!("https://{handle}/.well-known/atproto-did");
        if let Ok(resp) = self.http.get(&well_known).send().await
            && resp.status().is_success()
            && let Ok(body) = resp.text().await
            && let Ok(did) = Did::try_from(body.trim())
        {
            return Ok(did);
        }

        let resolve_url = url::Url::parse_with_params(
            "https://public.api.bsky.app/xrpc/com.atproto.identity.resolveHandle",
            &[("handle", handle.as_str())],
        )
        .map_err(|e| OAuthError::Identity(format!("invalid handle resolution URL: {e}")))?;
        let resp = self.http.get(resolve_url).send().await.map_err(|e| {
            OAuthError::Identity(format!("failed to resolve handle '{handle}': {e}"))
        })?;
        if !resp.status().is_success() {
            return Err(OAuthError::Identity(format!(
                "handle resolution failed for '{handle}': HTTP {}",
                resp.status()
            )));
        }
        let json: serde_json::Value = resp
            .json()
            .await
            .map_err(|e| OAuthError::Identity(format!("failed to parse resolve response: {e}")))?;
        let did_str = json["did"]
            .as_str()
            .ok_or_else(|| OAuthError::Identity("resolveHandle response missing 'did'".into()))?;
        Did::try_from(did_str)
            .map_err(|e| OAuthError::Identity(format!("invalid DID from resolution: {e}")))
    }

    /// Normalizes a typed handle and checks its domain-style syntax.
    ///
    /// Normalization removes surrounding whitespace and one leading `@`, then
    /// ASCII-lowercases. The syntax check requires at least two dot-separated
    /// labels, each 1-63 chars of `[a-z0-9-]` with no leading/trailing `-`, and
    /// at most 253 chars total. Because the result is placed into request URLs,
    /// this rejects `/`, `?`, `#`, `@`, `:` and similar characters. Those could
    /// otherwise change which host or path gets requested.
    fn normalize_handle(input: &str) -> Result<String, OAuthError> {
        let trimmed = input.trim();
        let handle = trimmed
            .strip_prefix('@')
            .unwrap_or(trimmed)
            .to_ascii_lowercase();
        let labels_ok = handle.split('.').all(|label| {
            !label.is_empty()
                && label.len() <= 63
                && !label.starts_with('-')
                && !label.ends_with('-')
                && label.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-')
        });
        if handle.len() > 253 || !handle.contains('.') || !labels_ok {
            return Err(OAuthError::Identity(format!("invalid handle: '{trimmed}'")));
        }
        Ok(handle)
    }
}
/// Pulls a non-empty `request_uri` string out of a successful PAR response body.
///
/// # Errors
///
/// `OAuthError::OAuthResponse` with code `invalid_response` when the field is
/// absent, not a string, or empty.
fn extract_request_uri(body: &serde_json::Value) -> Result<String, OAuthError> {
    match body.get("request_uri").and_then(serde_json::Value::as_str) {
        Some(uri) if !uri.is_empty() => Ok(uri.to_owned()),
        _ => Err(OAuthError::OAuthResponse {
            code: "invalid_response".into(),
            description: "missing or empty request_uri in PAR response".into(),
        }),
    }
}
/// Converts an OAuth error response body (`error` / `error_description`)
/// into `OAuthError::OAuthResponse`.
///
/// A missing or non-string `error` becomes `"unknown"`; a missing or
/// non-string `error_description` becomes the empty string.
fn oauth_error_from_json(body: &serde_json::Value) -> OAuthError {
    let error_code = body
        .get("error")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("unknown");
    let error_text = body
        .get("error_description")
        .and_then(serde_json::Value::as_str)
        .unwrap_or_default();
    OAuthError::OAuthResponse {
        code: error_code.to_owned(),
        description: error_text.to_owned(),
    }
}
#[cfg(test)]
// Tests may panic on failure, so the panicking clippy lints (presumably
// denied crate-wide) are allowed in this module.
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::unreachable
)]
mod tests {
    use super::*;
    use crate::oauth::session::{MemorySessionStore, MemoryStateStore};

    /// Client ID shared by the metadata fixture and the assertions.
    const CLIENT_ID: &str = "https://example.com/client-metadata.json";

    /// Minimal public-client metadata used by every test.
    fn make_client_metadata() -> ClientMetadata {
        ClientMetadata {
            client_id: CLIENT_ID.into(),
            redirect_uris: vec!["http://127.0.0.1:8080/callback".into()],
            scope: "atproto transition:generic".into(),
            token_endpoint_auth_method: "none".into(),
            application_type: "web".into(),
            grant_types: vec!["authorization_code".into(), "refresh_token".into()],
            response_types: vec!["code".into()],
            dpop_bound_access_tokens: true,
            client_name: "Test App".into(),
            client_uri: "https://example.com".into(),
        }
    }

    /// Config backed by in-memory stores, with the given optional signing key.
    fn config_with_key(signing_key: Option<(P256SigningKey, String)>) -> OAuthClientConfig {
        OAuthClientConfig {
            metadata: make_client_metadata(),
            session_store: Box::new(MemorySessionStore::new()),
            state_store: Box::new(MemoryStateStore::new()),
            signing_key,
            skip_issuer_verification: false,
        }
    }

    #[test]
    fn client_new_public() {
        let client = OAuthClient::new(config_with_key(None));
        assert_eq!(client.metadata.client_id, CLIENT_ID);
    }

    #[test]
    fn client_new_confidential() {
        let key = P256SigningKey::generate();
        let client = OAuthClient::new(config_with_key(Some((key, "key-1".into()))));
        assert_eq!(client.metadata.client_id, CLIENT_ID);
    }
}