use reqwest::Url;
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
use oauth2::{
AuthType, AuthUrl, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope,
basic::BasicClient, url::ParseError,
};
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PkceTokenResult {
pub access_token: String,
pub refresh_token: Option<String>,
pub expiration: Option<u64>,
}
#[derive(Deserialize)]
struct OAuthTokenResponse {
access_token: String,
refresh_token: Option<String>,
expires_in: Option<u64>,
}
#[derive(Deserialize)]
struct OAuthErrorResponse {
error: String,
error_description: Option<String>,
}
async fn format_token_error(response: reqwest::Response) -> String {
let status = response.status();
let body = response.text().await.unwrap_or_default();
if let Ok(error_response) = serde_json::from_str::<OAuthErrorResponse>(&body) {
match error_response.error_description {
Some(desc) => format!(
"Unable to retrieve token ({}: {})",
error_response.error, desc
),
None => format!("Unable to retrieve token ({})", error_response.error),
}
} else {
format!("HTTP {}: {}", status, body)
}
}
pub fn generate_authorization(
authorize_uri: &str,
redirect_uri: &str,
client_id: &str,
send_credentials_in_body: bool,
scopes: Option<String>,
audience: Option<String>,
) -> Result<(Url, CsrfToken, String), ParseError> {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let client = BasicClient::new(ClientId::new(client_id.to_string()))
.set_auth_uri(AuthUrl::new(authorize_uri.to_string())?)
.set_redirect_uri(RedirectUrl::new(redirect_uri.to_string())?)
.set_auth_type(if send_credentials_in_body {
AuthType::RequestBody
} else {
AuthType::BasicAuth
});
let mut auth = client.authorize_url(CsrfToken::new_random);
if let Some(scope_value) = &scopes
&& !scope_value.is_empty()
{
auth = auth.add_scope(Scope::new(scope_value.clone()));
}
if let Some(audience_value) = &audience
&& !audience_value.is_empty()
{
auth = auth.add_extra_param("audience", audience_value);
}
let (url, csrf_token) = auth.set_pkce_challenge(pkce_challenge).url();
Ok((url, csrf_token, pkce_verifier.into_secret()))
}
pub async fn retrieve_access_token(
access_token_uri: &str,
redirect_uri: &str,
client_id: &str,
code: &str,
verifier: &str,
enable_trace: bool,
) -> Result<PkceTokenResult, String> {
let http_client = reqwest::Client::builder()
.connection_verbose(enable_trace)
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");
let response = http_client
.post(access_token_uri)
.form(&[
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", redirect_uri),
("client_id", client_id),
("code_verifier", verifier),
])
.send()
.await
.map_err(|e| format!("Request failed: {}", e))?;
if !response.status().is_success() {
return Err(format_token_error(response).await);
}
let body = response
.text()
.await
.map_err(|e| format!("Failed to read token response: {}", e))?;
let token: OAuthTokenResponse = serde_json::from_str(&body)
.map_err(|e| format!("Failed to parse token response: {}", e))?;
let expiration = token.expires_in.map(|e| {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ e
});
Ok(PkceTokenResult {
access_token: token.access_token,
refresh_token: token.refresh_token,
expiration,
})
}
pub async fn refresh_token(
access_token_uri: &str,
refresh_token: &str,
client_id: &str,
) -> Result<PkceTokenResult, String> {
let http_client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");
let response = http_client
.post(access_token_uri)
.form(&[
("grant_type", "refresh_token"),
("refresh_token", refresh_token),
("client_id", client_id),
])
.send()
.await
.map_err(|e| format!("Request failed: {}", e))?;
if !response.status().is_success() {
return Err(format_token_error(response).await);
}
let body = response
.text()
.await
.map_err(|e| format!("Failed to read token response: {}", e))?;
let token: OAuthTokenResponse = serde_json::from_str(&body)
.map_err(|e| format!("Failed to parse token response: {}", e))?;
let expiration = token.expires_in.map(|e| {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ e
});
Ok(PkceTokenResult {
access_token: token.access_token,
refresh_token: token.refresh_token,
expiration,
})
}
#[cfg(test)]
pub mod tests {
use base64::{Engine, prelude::BASE64_URL_SAFE_NO_PAD};
use reqwest::Url;
use sha2::{Digest, Sha256};
use super::generate_authorization;
#[test]
fn test_generate_authorization_url_has_client_id() {
let (url, ..) = generate_authorization(
"https://auth.com/",
"https://localhost:3000/",
"client1",
false,
None,
None,
)
.unwrap();
let parsed = Url::parse(url.as_str()).unwrap();
assert!(
parsed
.query_pairs()
.any(|q| q.0 == "client_id" && q.1 == "client1")
);
}
#[test]
fn test_generate_authorization_url_has_redirct_uri() {
let (url, ..) = generate_authorization(
"https://auth.com/",
"https://localhost:3000/",
"client1",
false,
None,
None,
)
.unwrap();
let parsed = Url::parse(url.as_str()).unwrap();
assert!(
parsed
.query_pairs()
.any(|q| q.0 == "redirect_uri" && q.1 == "https://localhost:3000/")
);
}
#[test]
fn test_generate_authorization_url_has_response_type() {
let (url, ..) = generate_authorization(
"https://auth.com/",
"https://localhost:3000/",
"client1",
false,
None,
None,
)
.unwrap();
let parsed = Url::parse(url.as_str()).unwrap();
assert!(
parsed
.query_pairs()
.any(|q| q.0 == "response_type" && q.1 == "code")
);
}
#[test]
fn test_generate_authorization_url_includes_existing_query_param() {
let (url, ..) = generate_authorization(
"https://auth.com/?abc=123",
"https://localhost:3000/",
"client1",
false,
None,
None,
)
.unwrap();
let parsed = Url::parse(url.as_str()).unwrap();
assert!(parsed.query_pairs().any(|q| q.0 == "abc" && q.1 == "123"));
}
#[test]
fn test_generate_authorization_url_includes_scopes() {
let (url, ..) = generate_authorization(
"https://auth.com/?abc=123",
"https://localhost:3000/",
"client1",
false,
Some("scope1 scope2".to_string()),
None,
)
.unwrap();
let parsed = Url::parse(url.as_str()).unwrap();
assert!(
parsed
.query_pairs()
.any(|q| q.0 == "scope" && q.1 == "scope1 scope2")
);
}
#[test]
fn test_generate_authorization_url_includes_code_challenge_method() {
let (url, ..) = generate_authorization(
"https://auth.com/?abc=123",
"https://localhost:3000/",
"client1",
false,
None,
None,
)
.unwrap();
let parsed = Url::parse(url.as_str()).unwrap();
assert!(
parsed
.query_pairs()
.any(|q| q.0 == "code_challenge_method" && q.1 == "S256")
);
}
#[test]
fn test_generate_authorization_url_includes_valid_challenge_and_verifier() {
let (url, .., verifier) = generate_authorization(
"https://auth.com/?abc=123",
"https://localhost:3000/",
"client1",
false,
None,
None,
)
.unwrap();
let parsed = Url::parse(url.as_str()).unwrap();
let challenge = parsed
.query_pairs()
.find(|p| p.0 == "code_challenge")
.map(|p| p.1)
.unwrap();
let mut hasher = Sha256::new();
hasher.update(verifier);
let hashed_secret = BASE64_URL_SAFE_NO_PAD.encode(hasher.finalize());
assert_eq!(hashed_secret, challenge.to_string());
}
#[test]
fn test_generate_authorization_url_returns_csrf_token() {
let (_url, csrf_token, _verifier) = generate_authorization(
"https://auth.com/?abc=123",
"https://localhost:3000/",
"client1",
false,
None,
None,
)
.unwrap();
assert_ne!(csrf_token.secret().as_str(), "");
}
}