use chrono::Utc;
use openidconnect::PkceCodeChallenge;
use reqwest::header::AUTHORIZATION;
use serde::Deserialize;
use url::Url;
use base64::Engine as _;
use base64::engine::general_purpose::STANDARD as B64_STANDARD;
use crate::delegated::error::DelegatedError;
use axess_factors::ZeroizedString;
use super::credential::StoredDelegation;
use super::provider::DelegatedProvider;
/// State produced by [`begin_grant`] and consumed by [`complete_grant`].
///
/// Holds the values that must be kept between sending the user to the
/// provider's authorization endpoint and handling the callback.
/// NOTE(review): presumably persisted server-side (e.g. in a session) across
/// the redirect — confirm with callers.
#[derive(Clone, Debug)]
pub struct GrantContext {
    /// CSRF `state` value placed in the authorization URL; the callback's
    /// `state` must equal it.
    pub state: String,
    /// PKCE code verifier whose S256 challenge was placed in the
    /// authorization URL; sent back during the code exchange.
    pub code_verifier: ZeroizedString,
    /// Scopes used for the authorization request: the caller's override, or
    /// the provider defaults when no override was given.
    pub requested_scopes: Vec<String>,
}
/// Starts an OAuth 2.0 authorization-code grant protected by PKCE (S256).
///
/// Builds the provider's authorization URL with `response_type=code`,
/// `client_id`, `redirect_uri`, a fresh random `state`, and an S256
/// `code_challenge`. It returns that URL together with the [`GrantContext`]
/// that [`complete_grant`] needs later.
///
/// A non-empty `scopes_override` replaces `provider.default_scopes`. When the
/// effective scope list joins to an empty string, no `scope` parameter is sent.
///
/// # Errors
/// Never fails at present; the `Result` is part of the public signature.
pub fn begin_grant(
    provider: &DelegatedProvider,
    scopes_override: &[String],
) -> Result<(Url, GrantContext), DelegatedError> {
    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
    let csrf = openidconnect::CsrfToken::new_random();

    let requested_scopes: Vec<String> = match scopes_override {
        [] => provider.default_scopes.clone(),
        explicit => explicit.to_vec(),
    };
    let joined_scopes = requested_scopes.join(" ");

    let mut url = provider.authorization_endpoint.clone();
    // The query serializer mutably borrows `url`; scope it so the borrow (and
    // the serializer's finalization on drop) ends before `url` is returned.
    {
        let mut pairs = url.query_pairs_mut();
        pairs
            .append_pair("response_type", "code")
            .append_pair("client_id", &provider.client_id)
            .append_pair("redirect_uri", provider.redirect_uri.as_str())
            .append_pair("state", csrf.secret())
            .append_pair("code_challenge", pkce_challenge.as_str())
            .append_pair("code_challenge_method", "S256");
        if !joined_scopes.is_empty() {
            pairs.append_pair("scope", &joined_scopes);
        }
    }

    let context = GrantContext {
        state: csrf.secret().to_owned(),
        code_verifier: ZeroizedString::from(pkce_verifier.secret().to_owned()),
        requested_scopes,
    };
    Ok((url, context))
}
/// Completes an authorization-code grant started by [`begin_grant`].
///
/// It works in three steps:
/// 1. Checks the callback's `state` against `context.state`.
/// 2. Validates the stored PKCE verifier.
/// 3. Exchanges `code` at the provider's token endpoint, sending `redirect_uri`,
///    `code_verifier` and `client_id`.
///
/// # Errors
/// - [`DelegatedError::StateMismatch`] if `state_from_callback` differs from
///   `context.state`.
/// - [`DelegatedError::PkceVerifier`] if `axess_factors::pkce::is_valid_verifier`
///   rejects the stored verifier.
/// - Any error produced while sending the token request or parsing its response.
pub async fn complete_grant(
    provider: &DelegatedProvider,
    context: &GrantContext,
    code: &str,
    state_from_callback: &str,
    http: &reqwest::Client,
) -> Result<StoredDelegation, DelegatedError> {
    // `state` is a secret CSRF token. A short-circuiting `!=` would let response
    // timing reveal how long a prefix of an attacker-supplied value matched
    // (CWE-208), so compare in constant time.
    if !constant_time_eq(context.state.as_bytes(), state_from_callback.as_bytes()) {
        return Err(DelegatedError::StateMismatch);
    }
    if !axess_factors::pkce::is_valid_verifier(&context.code_verifier) {
        return Err(DelegatedError::PkceVerifier);
    }
    let form: Vec<(&str, String)> = vec![
        ("grant_type", "authorization_code".to_string()),
        ("code", code.to_string()),
        ("redirect_uri", provider.redirect_uri.to_string()),
        ("code_verifier", (*context.code_verifier).to_string()),
        ("client_id", provider.client_id.clone()),
    ];
    let response = post_token_endpoint(http, provider, &form).await?;
    parse_token_response(provider, response).await
}

/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// Only the length comparison can end early, so timing reveals at most whether
/// the lengths differ, never which bytes matched.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    // `black_box` discourages the optimizer from reintroducing an early exit.
    std::hint::black_box(diff) == 0
}
/// Exchanges a refresh token for a new token set (RFC 6749 §6).
///
/// Posts `grant_type=refresh_token`, the given `refresh_token` and the
/// provider's `client_id` to the token endpoint via `post_token_endpoint`,
/// which adds HTTP Basic client authentication. The response is then parsed
/// with `parse_token_response`.
///
/// # Errors
/// Propagates transport and token-response errors from those helpers.
pub(super) async fn refresh_token_grant(
    provider: &DelegatedProvider,
    refresh_token: &str,
    http: &reqwest::Client,
) -> Result<StoredDelegation, DelegatedError> {
    let params = [
        ("grant_type", String::from("refresh_token")),
        ("refresh_token", refresh_token.to_owned()),
        ("client_id", provider.client_id.clone()),
    ];
    let reply = post_token_endpoint(http, provider, &params).await?;
    parse_token_response(provider, reply).await
}
async fn post_token_endpoint(
http: &reqwest::Client,
provider: &DelegatedProvider,
form: &[(&str, String)],
) -> Result<reqwest::Response, DelegatedError> {
let creds = format!("{}:{}", provider.client_id, &*provider.client_secret);
let encoded = B64_STANDARD.encode(creds);
http.post(provider.token_endpoint.clone())
.header(AUTHORIZATION, format!("Basic {encoded}"))
.form(form)
.send()
.await
.map_err(|e| DelegatedError::Transport(e.to_string()))
}
/// Successful token-endpoint response body (RFC 6749 §5.1).
///
/// Only `access_token` is required for deserialization to succeed; every other
/// field becomes `None` when absent.
#[derive(Deserialize)]
struct TokenResponse {
    access_token: String,
    #[serde(default)]
    refresh_token: Option<String>,
    /// Access-token lifetime in seconds, relative to when the response was issued.
    #[serde(default)]
    expires_in: Option<u64>,
    #[serde(default)]
    token_type: Option<String>,
    /// Space-delimited list of granted scopes, as sent by the provider.
    #[serde(default)]
    scope: Option<String>,
}

// Hand-written instead of derived so the bearer and refresh tokens never reach
// logs through `{:?}`. The non-secret fields print as the derive would.
impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("access_token", &"<redacted>")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .field("expires_in", &self.expires_in)
            .field("token_type", &self.token_type)
            .field("scope", &self.scope)
            .finish()
    }
}
async fn parse_token_response(
provider: &DelegatedProvider,
response: reqwest::Response,
) -> Result<StoredDelegation, DelegatedError> {
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
if status.as_u16() == 400 && body.contains("invalid_grant") {
return Err(DelegatedError::RefreshRejected);
}
return Err(DelegatedError::TokenEndpoint {
status: status.as_u16(),
body,
});
}
let parsed: TokenResponse = response
.json()
.await
.map_err(|e| DelegatedError::MalformedResponse(e.to_string()))?;
if parsed.access_token.is_empty() {
return Err(DelegatedError::MalformedResponse(
"access_token field is empty".into(),
));
}
let expires_at = parsed
.expires_in
.filter(|s| *s > 0)
.map(|s| Utc::now() + chrono::Duration::seconds(s as i64));
let scopes: Vec<String> = parsed
.scope
.map(|s| s.split_whitespace().map(str::to_string).collect())
.unwrap_or_default();
Ok(StoredDelegation {
provider: provider.name.clone(),
access_token: ZeroizedString::from(parsed.access_token),
refresh_token: parsed.refresh_token.map(ZeroizedString::from),
expires_at,
scopes,
token_type: parsed.token_type.unwrap_or_else(|| "Bearer".to_string()),
})
}