use crate::client::HttpClientExt;
use crate::user::ConnectUser;
use async_trait::async_trait;
/// Builds an OAuth 2.0 authorization-request URL by appending query parameters
/// to `base_url`.
///
/// Parameters:
/// - `client_id` and `redirect_uri` are always added.
/// - `scope` is added only when `scopes` is non-empty.
/// - `state` is added when `state` is `Some`.
/// - `code_challenge` plus `code_challenge_method=S256` are added when
///   `pkce_challenge` is `Some`.
///
/// Values are `application/x-www-form-urlencoded`, so a space in `scopes`
/// becomes `+`.
///
/// The separator placed after `base_url` is:
/// - `?` when it has no query yet;
/// - `&` when it already has one;
/// - nothing when it already ends in `?` or `&`, so no empty `?&`/`&&` pair is
///   produced.
///
/// The serializer is returned (rather than a `String`) so callers can append
/// further pairs. Call `finish()` to obtain the URL.
pub fn build_oauth_params<'a>(
    base_url: &str,
    client_id: &'a str,
    redirect_uri: &'a str,
    scopes: &'a str,
    state: Option<&'a str>,
    pkce_challenge: Option<&'a str>,
) -> url::form_urlencoded::Serializer<'a, String> {
    let mut target = String::with_capacity(base_url.len() + 256);
    target.push_str(base_url);
    // Only add a separator if base_url does not already end with one.
    if !(base_url.ends_with('?') || base_url.ends_with('&')) {
        target.push(if base_url.contains('?') { '&' } else { '?' });
    }
    // The serializer writes pairs after this offset and inserts `&` between them.
    let start_position = target.len();
    let mut params = url::form_urlencoded::Serializer::for_suffix(target, start_position);
    params.append_pair("client_id", client_id);
    params.append_pair("redirect_uri", redirect_uri);
    if !scopes.is_empty() {
        params.append_pair("scope", scopes);
    }
    if let Some(state) = state {
        params.append_pair("state", state);
    }
    if let Some(challenge) = pkce_challenge {
        params.append_pair("code_challenge", challenge);
        params.append_pair("code_challenge_method", "S256");
    }
    params
}
/// Inputs for completing an authorization-code login via [`Provider::get_user`].
///
/// `Debug` is implemented by hand rather than derived. The authorization code
/// and PKCE verifier are credentials and must not be written to logs, so they
/// are printed as `[REDACTED]`. Only whether a verifier is present is shown.
#[derive(Default, Clone)]
pub struct ExchangeParams<'a> {
    /// Authorization code received on the redirect.
    pub auth_code: &'a str,
    /// PKCE code verifier, when PKCE was used for the authorization request.
    pub code_verifier: Option<&'a str>,
    /// Expected OpenID Connect nonce, if any. Whether it is checked is up to the
    /// `Provider` implementation.
    pub expected_nonce: Option<&'a str>,
}

impl std::fmt::Debug for ExchangeParams<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExchangeParams")
            .field("auth_code", &"[REDACTED]")
            .field("code_verifier", &self.code_verifier.map(|_| "[REDACTED]"))
            .field("expected_nonce", &self.expected_nonce)
            .finish()
    }
}
/// Form body posted to a provider's token endpoint to redeem an authorization
/// code (see [`fetch_access_token`]).
///
/// Serialized with `serde`. Optional fields that are `None` are left out of the
/// form entirely rather than sent empty.
#[derive(serde::Serialize)]
pub struct TokenExchangeForm<'a> {
    /// OAuth client identifier.
    pub client_id: &'a str,
    /// Client secret. Omitted when `None` — presumably for public clients that
    /// rely on PKCE instead (NOTE(review): confirm against callers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<&'a str>,
    /// Authorization code received on the redirect.
    pub code: &'a str,
    /// `grant_type` form field, for example `authorization_code`. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grant_type: Option<&'a str>,
    /// Redirect URI sent with the exchange.
    pub redirect_uri: &'a str,
    /// PKCE code verifier for the `code_challenge` sent in the authorization
    /// request. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code_verifier: Option<&'a str>,
}
/// Starts a query-string serializer on the end of `base`.
///
/// The separator appended is:
/// - `?` when `base` has no query yet;
/// - `&` when it already has one;
/// - nothing when it already ends in `?` or `&`, so no empty `?&`/`&&` pair is
///   produced.
///
/// `additional` is only a capacity hint for the pairs that will be written.
fn query_suffix_serializer(
    mut base: String,
    additional: usize,
) -> url::form_urlencoded::Serializer<'static, String> {
    base.reserve(additional);
    if !(base.ends_with('?') || base.ends_with('&')) {
        let separator = if base.contains('?') { '&' } else { '?' };
        base.push(separator);
    }
    // Pairs are written after this offset; the serializer inserts `&` between them.
    let start_position = base.len();
    url::form_urlencoded::Serializer::for_suffix(base, start_position)
}

/// An OAuth 2.0 identity provider.
///
/// Implementors supply the authorization URL, the token endpoint and the user
/// lookups. The URL helpers and the optional flows (refresh, revocation, device
/// authorization) have default implementations. The optional flows default to
/// returning an "unsupported" error.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Authorization URL the user is sent to. The `redirect_url_with_*`
    /// helpers append extra query parameters to this value.
    fn redirect_url(&self) -> String;

    /// [`Provider::redirect_url`] with a `state` query parameter appended.
    fn redirect_url_with_state(&self, state: &str) -> String {
        let mut query = query_suffix_serializer(self.redirect_url(), 8 + state.len());
        query.append_pair("state", state);
        query.finish()
    }

    /// [`Provider::redirect_url`] with `code_challenge` and
    /// `code_challenge_method=S256` appended.
    fn redirect_url_with_pkce(&self, code_challenge: &str) -> String {
        let mut query = query_suffix_serializer(self.redirect_url(), 44 + code_challenge.len());
        query.append_pair("code_challenge", code_challenge);
        query.append_pair("code_challenge_method", "S256");
        query.finish()
    }

    /// [`Provider::redirect_url`] with the PKCE parameters followed by `state`.
    fn redirect_url_with_pkce_and_state(&self, code_challenge: &str, state: &str) -> String {
        let mut query = query_suffix_serializer(
            self.redirect_url(),
            52 + code_challenge.len() + state.len(),
        );
        query.append_pair("code_challenge", code_challenge);
        query.append_pair("code_challenge_method", "S256");
        query.append_pair("state", state);
        query.finish()
    }

    /// Completes a login from `params` (authorization code, optional PKCE
    /// verifier and nonce) and returns the authenticated user.
    async fn get_user(
        &self,
        params: ExchangeParams<'_>,
    ) -> Result<ConnectUser, crate::error::ConnectError>;

    /// Loads the user that `access_token` belongs to.
    async fn get_user_from_token(
        &self,
        access_token: &str,
    ) -> Result<ConnectUser, crate::error::ConnectError>;

    /// URL of the provider's token endpoint.
    fn token_url(&self) -> String;

    /// Uses a refresh token to obtain a fresh session for the user.
    ///
    /// Default: always returns `ConnectError::Token`; providers that support
    /// refresh override this.
    async fn refresh_token(
        &self,
        _refresh_token: &str,
    ) -> Result<ConnectUser, crate::error::ConnectError> {
        Err(crate::error::ConnectError::Token(
            "Refresh token is not supported by this provider".to_string(),
        ))
    }

    /// Revokes `token` at the provider.
    ///
    /// Default: always returns `ConnectError::Token`.
    async fn revoke_token(&self, _token: &str) -> Result<(), crate::error::ConnectError> {
        Err(crate::error::ConnectError::Token(
            "Token revocation is not supported by this provider".to_string(),
        ))
    }

    /// Starts a device-authorization flow.
    ///
    /// Default: always returns `ConnectError::Provider`.
    async fn request_device_code(
        &self,
    ) -> Result<crate::user::DeviceAuthorizationResponse, crate::error::ConnectError> {
        Err(crate::error::ConnectError::Provider(
            "Device Authorization is not supported by this provider".into(),
        ))
    }

    /// Polls the token endpoint for a pending device-authorization request.
    ///
    /// Default: always returns `ConnectError::Provider`.
    async fn poll_device_token(
        &self,
        _device_code: &str,
    ) -> Result<ConnectUser, crate::error::ConnectError> {
        Err(crate::error::ConnectError::Provider(
            "Device Authorization is not supported by this provider".into(),
        ))
    }
}
/// Tokens extracted from a token-endpoint response.
///
/// `Debug` is implemented by hand so that the access and refresh tokens are
/// never written to logs. Only whether a refresh token is present, and the
/// `expires_in` value, are shown.
#[derive(Clone)]
pub struct Oauth2TokenResponse {
    /// Access token issued by the provider.
    pub access_token: String,
    /// Refresh token, when the provider issued one.
    pub refresh_token: Option<String>,
    /// Token lifetime in seconds as reported in `expires_in`, when present.
    pub expires_in: Option<u64>,
}

impl std::fmt::Debug for Oauth2TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Oauth2TokenResponse")
            .field("access_token", &"[REDACTED]")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
            )
            .field("expires_in", &self.expires_in)
            .finish()
    }
}
pub async fn fetch_access_token(
client: &dyn crate::client::HttpClient,
token_url: &str,
form: &TokenExchangeForm<'_>,
) -> Result<Oauth2TokenResponse, crate::error::ConnectError> {
let token_res = client
.post(token_url)
.form(form)
.send()
.await?
.error_for_status()?
.json::<serde_json::Value>()
.await?;
if let Some(err) = token_res["error"].as_str() {
let err_desc = token_res["error_description"].as_str().unwrap_or("");
return Err(crate::error::ConnectError::Token(format!(
"Provider returned error: {} - {}",
err, err_desc
)));
}
let access_token = token_res["access_token"]
.as_str()
.ok_or_else(|| crate::error::ConnectError::Token("Failed to get access_token".to_owned()))?
.to_owned();
let refresh_token = token_res["refresh_token"].as_str().map(String::from);
let expires_in = token_res["expires_in"]
.as_u64()
.or_else(|| token_res["expires_in"].as_i64().map(|v| v as u64));
Ok(Oauth2TokenResponse {
access_token,
refresh_token,
expires_in,
})
}
pub async fn fetch_refresh_token(
client: &dyn crate::client::HttpClient,
token_url: &str,
client_id: &str,
client_secret: &str,
refresh_token: &str,
) -> Result<Oauth2TokenResponse, crate::error::ConnectError> {
let token_res = client
.post(token_url)
.form(&[
("client_id", client_id),
("client_secret", client_secret),
("refresh_token", refresh_token),
("grant_type", "refresh_token"),
])
.send()
.await?
.error_for_status()?
.json::<serde_json::Value>()
.await?;
if let Some(err) = token_res["error"].as_str() {
let err_desc = token_res["error_description"].as_str().unwrap_or("");
return Err(crate::error::ConnectError::Token(format!(
"Provider returned error: {} - {}",
err, err_desc
)));
}
let access_token = token_res["access_token"]
.as_str()
.ok_or_else(|| {
crate::error::ConnectError::Token(
"Failed to get access_token during refresh".to_owned(),
)
})?
.to_owned();
let refresh_token = token_res["refresh_token"].as_str().map(String::from);
let expires_in = token_res["expires_in"]
.as_u64()
.or_else(|| token_res["expires_in"].as_i64().map(|v| v as u64));
Ok(Oauth2TokenResponse {
access_token,
refresh_token,
expires_in,
})
}
/// Exchanges an authorization code for tokens and loads the matching user.
///
/// Steps:
/// 1. Posts `form` to `token_url` via [`fetch_access_token`].
/// 2. Resolves the user with [`Provider::get_user_from_token`].
/// 3. Stores the returned refresh token (as a `secrecy::SecretString`) and
///    `expires_in` on the user.
///
/// NOTE(review): `_expected_nonce` is accepted but never checked here. Nothing
/// in this function validates an OpenID Connect nonce, so confirm that
/// providers needing it validate it themselves.
///
/// # Errors
/// Propagates any error from the token request or from `get_user_from_token`.
pub async fn exchange_and_get_user<P>(
    provider: &P,
    client: &dyn crate::client::HttpClient,
    token_url: &str,
    form: &TokenExchangeForm<'_>,
    _expected_nonce: Option<&str>,
) -> Result<ConnectUser, crate::error::ConnectError>
where
    P: Provider + ?Sized,
{
    let token = fetch_access_token(client, token_url, form).await?;
    let mut user = provider.get_user_from_token(&token.access_token).await?;
    user.refresh_token = token.refresh_token.map(secrecy::SecretString::from);
    user.expires_in = token.expires_in;
    Ok(user)
}
/// Refreshes the access token and loads the user it belongs to.
///
/// Calls [`fetch_refresh_token`] with the client credentials, exposing
/// `client_secret` only to build the request form. It then resolves the user
/// via [`Provider::get_user_from_token`] and attaches the token data.
///
/// When the token endpoint does not return a new refresh token, the
/// `refresh_token` passed in is kept on the user. RFC 6749 §6 only requires
/// discarding the old refresh token when the server issues a replacement.
/// Previously the user was left with `None` and lost a still-valid credential.
///
/// # Errors
/// Propagates any error from the refresh request or from `get_user_from_token`.
pub async fn refresh_and_get_user<P>(
    provider: &P,
    client: &dyn crate::client::HttpClient,
    token_url: &str,
    client_id: &str,
    client_secret: &secrecy::SecretString,
    refresh_token: &str,
) -> Result<ConnectUser, crate::error::ConnectError>
where
    P: Provider + ?Sized,
{
    let token = fetch_refresh_token(
        client,
        token_url,
        client_id,
        secrecy::ExposeSecret::expose_secret(client_secret),
        refresh_token,
    )
    .await?;
    let mut user = provider.get_user_from_token(&token.access_token).await?;
    // Keep the current refresh token unless the server rotated it.
    let next_refresh_token = token
        .refresh_token
        .unwrap_or_else(|| refresh_token.to_owned());
    user.refresh_token = Some(secrecy::SecretString::from(next_refresh_token));
    user.expires_in = token.expires_in;
    Ok(user)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ConnectError;
    use crate::user::ConnectUser;
    use async_trait::async_trait;

    /// Provider stub: only the URL methods are real. User lookups are never
    /// reached by these tests.
    struct DummyProvider {
        base_url: String,
    }

    #[async_trait]
    impl Provider for DummyProvider {
        fn redirect_url(&self) -> String {
            self.base_url.clone()
        }
        fn token_url(&self) -> String {
            "".to_string()
        }
        async fn get_user(&self, _params: ExchangeParams<'_>) -> Result<ConnectUser, ConnectError> {
            unimplemented!()
        }
        async fn get_user_from_token(
            &self,
            _access_token: &str,
        ) -> Result<ConnectUser, ConnectError> {
            unimplemented!()
        }
    }

    /// Token-exchange form shared by the fetch tests.
    fn sample_form() -> TokenExchangeForm<'static> {
        TokenExchangeForm {
            client_id: "client_id",
            client_secret: Some("client_secret"),
            code: "auth_code",
            grant_type: Some("authorization_code"),
            redirect_uri: "https://redirect",
            code_verifier: Some("verifier"),
        }
    }

    #[test]
    fn test_redirect_url_with_state() {
        let provider_no_query = DummyProvider {
            base_url: "https://example.com/auth".to_string(),
        };
        assert_eq!(
            provider_no_query.redirect_url_with_state("my_state"),
            "https://example.com/auth?state=my_state"
        );
        let provider_with_query = DummyProvider {
            base_url: "https://example.com/auth?client_id=123".to_string(),
        };
        assert_eq!(
            provider_with_query.redirect_url_with_state("my_state"),
            "https://example.com/auth?client_id=123&state=my_state"
        );
    }

    #[test]
    fn test_redirect_url_with_pkce() {
        let provider_no_query = DummyProvider {
            base_url: "https://example.com/auth".to_string(),
        };
        assert_eq!(
            provider_no_query.redirect_url_with_pkce("my_challenge"),
            "https://example.com/auth?code_challenge=my_challenge&code_challenge_method=S256"
        );
        let provider_with_query = DummyProvider {
            base_url: "https://example.com/auth?client_id=123".to_string(),
        };
        assert_eq!(
            provider_with_query.redirect_url_with_pkce("my_challenge"),
            "https://example.com/auth?client_id=123&code_challenge=my_challenge&code_challenge_method=S256"
        );
    }

    #[test]
    fn test_redirect_url_with_pkce_and_state() {
        let provider_no_query = DummyProvider {
            base_url: "https://example.com/auth".to_string(),
        };
        assert_eq!(
            provider_no_query.redirect_url_with_pkce_and_state("my_challenge", "my_state"),
            "https://example.com/auth?code_challenge=my_challenge&code_challenge_method=S256&state=my_state"
        );
        let provider_with_query = DummyProvider {
            base_url: "https://example.com/auth?client_id=123".to_string(),
        };
        assert_eq!(
            provider_with_query.redirect_url_with_pkce_and_state("my_challenge", "my_state"),
            "https://example.com/auth?client_id=123&code_challenge=my_challenge&code_challenge_method=S256&state=my_state"
        );
    }

    #[tokio::test]
    async fn test_default_refresh_token() {
        let provider = DummyProvider {
            base_url: "".to_string(),
        };
        match provider.refresh_token("some_refresh").await {
            Err(ConnectError::Token(msg)) => {
                assert_eq!(msg, "Refresh token is not supported by this provider");
            }
            _ => panic!("Expected ConnectError::Token"),
        }
    }

    #[tokio::test]
    async fn test_default_revoke_token() {
        let provider = DummyProvider {
            base_url: "".to_string(),
        };
        let result = provider.revoke_token("some_token").await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ConnectError::Token(msg) => {
                assert_eq!(msg, "Token revocation is not supported by this provider");
            }
            _ => panic!("Expected ConnectError::Token"),
        }
    }

    #[tokio::test]
    async fn test_default_request_device_code() {
        let provider = DummyProvider {
            base_url: "".to_string(),
        };
        match provider.request_device_code().await {
            Err(ConnectError::Provider(msg)) => {
                assert_eq!(
                    msg,
                    "Device Authorization is not supported by this provider"
                );
            }
            _ => panic!("Expected ConnectError::Provider"),
        }
    }

    #[tokio::test]
    async fn test_default_poll_device_token() {
        let provider = DummyProvider {
            base_url: "".to_string(),
        };
        let result = provider.poll_device_token("some_code").await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ConnectError::Provider(msg) => {
                assert_eq!(
                    msg,
                    "Device Authorization is not supported by this provider"
                );
            }
            _ => panic!("Expected ConnectError::Provider"),
        }
    }

    #[test]
    fn test_redirect_url_with_pkce_and_state_multiple_query_params() {
        let provider_multiple_query = DummyProvider {
            base_url: "https://example.com/auth?foo=bar&baz=qux".to_string(),
        };
        assert_eq!(
            provider_multiple_query.redirect_url_with_pkce_and_state("my_challenge", "my_state"),
            "https://example.com/auth?foo=bar&baz=qux&code_challenge=my_challenge&code_challenge_method=S256&state=my_state"
        );
    }

    #[test]
    fn test_build_oauth_params_variations() {
        let mut serializer = build_oauth_params("", "client", "redirect", "", None, None);
        let query = serializer.finish();
        assert!(query.contains("client_id=client"));
        assert!(query.contains("redirect_uri=redirect"));
        assert!(!query.contains("scope"));
        let mut serializer = build_oauth_params("", "client", "redirect", "read", None, None);
        let query = serializer.finish();
        assert!(query.contains("scope=read"));
        let mut serializer = build_oauth_params(
            "",
            "client",
            "redirect",
            "read write",
            Some("state123"),
            Some("pkce_challenge"),
        );
        let query = serializer.finish();
        assert!(query.contains("scope=read+write"));
        assert!(query.contains("state=state123"));
        assert!(query.contains("code_challenge=pkce_challenge"));
        assert!(query.contains("code_challenge_method=S256"));
    }

    #[test]
    fn test_build_oauth_params_appends_to_base_url() {
        let mut serializer =
            build_oauth_params("https://example.com/auth", "client", "redirect", "", None, None);
        assert_eq!(
            serializer.finish(),
            "https://example.com/auth?client_id=client&redirect_uri=redirect"
        );
        let mut serializer = build_oauth_params(
            "https://example.com/auth?foo=bar",
            "client",
            "redirect",
            "",
            None,
            None,
        );
        assert_eq!(
            serializer.finish(),
            "https://example.com/auth?foo=bar&client_id=client&redirect_uri=redirect"
        );
    }

    /// HTTP client stub: URLs containing "error" get an HTTP 400 OAuth error
    /// body; every other URL gets a successful token response.
    struct MockFetchClient;

    #[async_trait]
    impl crate::client::HttpClient for MockFetchClient {
        async fn execute(
            &self,
            req: crate::client::HttpRequest,
        ) -> Result<crate::client::HttpResponse, crate::error::ConnectError> {
            if req.url.contains("error") {
                Ok(crate::client::HttpResponse {
                    status: 400,
                    body: serde_json::json!({
                        "error": "invalid_request",
                        "error_description": "Test error"
                    }),
                })
            } else {
                Ok(crate::client::HttpResponse {
                    status: 200,
                    body: serde_json::json!({
                        "access_token": "mock_access",
                        "refresh_token": "mock_refresh",
                        "expires_in": 3600
                    }),
                })
            }
        }
    }

    #[tokio::test]
    async fn test_fetch_access_token() {
        let client = MockFetchClient;
        let form = sample_form();
        let res = fetch_access_token(&client, "https://example.com/token", &form)
            .await
            .expect("Failed to fetch access token");
        assert_eq!(res.access_token, "mock_access");
        assert_eq!(res.refresh_token.as_deref(), Some("mock_refresh"));
        assert_eq!(res.expires_in, Some(3600));
    }

    #[tokio::test]
    async fn test_fetch_access_token_error_response() {
        let client = MockFetchClient;
        let form = sample_form();
        let res = fetch_access_token(&client, "https://example.com/error", &form).await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_fetch_refresh_token() {
        let client = MockFetchClient;
        let res = fetch_refresh_token(
            &client,
            "https://example.com/token",
            "client_id",
            "client_secret",
            "mock_refresh",
        )
        .await
        .expect("Failed to fetch refresh token");
        assert_eq!(res.access_token, "mock_access");
        assert_eq!(res.refresh_token.as_deref(), Some("mock_refresh"));
        assert_eq!(res.expires_in, Some(3600));
    }

    #[tokio::test]
    async fn test_fetch_refresh_token_error_response() {
        let client = MockFetchClient;
        let res = fetch_refresh_token(
            &client,
            "https://example.com/error",
            "client_id",
            "client_secret",
            "mock_refresh",
        )
        .await;
        assert!(res.is_err());
    }
}