use crate::user::ConnectUser;
use async_trait::async_trait;
/// Builds a form-urlencoded serializer pre-filled with OAuth authorization
/// request parameters.
///
/// Pairs are appended in this order, skipping any that do not apply:
/// - `client_id` and `redirect_uri` (always);
/// - `scope`: every entry of `scopes` joined with a single space, added only
///   when `scopes` is non-empty;
/// - `state`: added when `state` is `Some`;
/// - `code_challenge` plus `code_challenge_method=S256`: added when
///   `pkce_challenge` is `Some`.
///
/// The serializer is returned without calling `finish()`, so callers may
/// append further pairs before producing the final query string.
pub fn build_oauth_params<'a>(
    client_id: &'a str,
    redirect_uri: &'a str,
    scopes: &'a [String],
    state: Option<&'a str>,
    pkce_challenge: Option<&'a str>,
) -> url::form_urlencoded::Serializer<'a, String> {
    let mut serializer = url::form_urlencoded::Serializer::new(String::with_capacity(256));
    serializer
        .append_pair("client_id", client_id)
        .append_pair("redirect_uri", redirect_uri);

    // An empty scope list yields no `scope` pair at all (not `scope=`).
    if let Some(scope) = (!scopes.is_empty()).then(|| scopes.join(" ")) {
        serializer.append_pair("scope", &scope);
    }

    if let Some(opaque_state) = state {
        serializer.append_pair("state", opaque_state);
    }

    // The challenge method is always declared as S256.
    if let Some(challenge) = pkce_challenge {
        serializer
            .append_pair("code_challenge", challenge)
            .append_pair("code_challenge_method", "S256");
    }

    serializer
}
#[async_trait]
pub trait Provider: Send + Sync {
fn redirect_url(&self) -> String;
fn redirect_url_with_state(&self, state: &str) -> String {
let url = self.redirect_url();
let separator = if url.contains('?') { "&" } else { "?" };
format!("{url}{separator}state={state}")
}
fn redirect_url_with_pkce(&self, code_challenge: &str) -> String {
let url = self.redirect_url();
let separator = if url.contains('?') { "&" } else { "?" };
format!(
"{}{}code_challenge={}&code_challenge_method=S256",
url, separator, code_challenge
)
}
fn redirect_url_with_pkce_and_state(&self, code_challenge: &str, state: &str) -> String {
let url = self.redirect_url();
let separator = if url.contains('?') { "&" } else { "?" };
format!(
"{}{}code_challenge={}&code_challenge_method=S256&state={}",
url, separator, code_challenge, state
)
}
async fn get_user(&self, auth_code: &str) -> Result<ConnectUser, crate::error::ConnectError>;
async fn get_user_with_pkce(
&self,
auth_code: &str,
_code_verifier: &str,
) -> Result<ConnectUser, crate::error::ConnectError> {
self.get_user(auth_code).await
}
async fn get_user_from_token(
&self,
access_token: &str,
) -> Result<ConnectUser, crate::error::ConnectError>;
fn token_url(&self) -> String;
async fn refresh_token(
&self,
_refresh_token: &str,
) -> Result<ConnectUser, crate::error::ConnectError> {
Err(crate::error::ConnectError::Token(
"Refresh token is not supported by this provider".to_string(),
))
}
async fn revoke_token(&self, _token: &str) -> Result<(), crate::error::ConnectError> {
Err(crate::error::ConnectError::Token(
"Token revocation is not supported by this provider".to_string(),
))
}
async fn request_device_code(
&self,
) -> Result<crate::user::DeviceAuthorizationResponse, crate::error::ConnectError> {
Err(crate::error::ConnectError::Provider(
"Device Authorization is not supported by this provider".into(),
))
}
async fn poll_device_token(
&self,
_device_code: &str,
) -> Result<ConnectUser, crate::error::ConnectError> {
Err(crate::error::ConnectError::Provider(
"Device Authorization is not supported by this provider".into(),
))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ConnectError;
    use crate::user::ConnectUser;
    use async_trait::async_trait;

    /// Minimal `Provider` whose `redirect_url` is a fixed string. Only the
    /// trait's default methods are exercised; the user lookups are stubs.
    struct DummyProvider {
        base_url: String,
    }

    /// Builds a `DummyProvider` rooted at `base_url`.
    fn dummy(base_url: &str) -> DummyProvider {
        DummyProvider {
            base_url: base_url.to_owned(),
        }
    }

    #[async_trait]
    impl Provider for DummyProvider {
        fn redirect_url(&self) -> String {
            self.base_url.clone()
        }

        fn token_url(&self) -> String {
            String::new()
        }

        async fn get_user(&self, _auth_code: &str) -> Result<ConnectUser, ConnectError> {
            unimplemented!()
        }

        async fn get_user_from_token(
            &self,
            _access_token: &str,
        ) -> Result<ConnectUser, ConnectError> {
            unimplemented!()
        }
    }

    /// Base URL with no query string: helpers must start one with `?`.
    const BARE: &str = "https://example.com/auth";
    /// Base URL that already has a query: helpers must continue it with `&`.
    const WITH_QUERY: &str = "https://example.com/auth?client_id=123";

    #[test]
    fn test_redirect_url_with_state() {
        let cases = [
            (BARE, "https://example.com/auth?state=my_state"),
            (
                WITH_QUERY,
                "https://example.com/auth?client_id=123&state=my_state",
            ),
        ];
        for &(base, expected) in &cases {
            assert_eq!(dummy(base).redirect_url_with_state("my_state"), expected);
        }
    }

    #[test]
    fn test_redirect_url_with_pkce() {
        let cases = [
            (
                BARE,
                "https://example.com/auth?code_challenge=my_challenge&code_challenge_method=S256",
            ),
            (
                WITH_QUERY,
                "https://example.com/auth?client_id=123&code_challenge=my_challenge&code_challenge_method=S256",
            ),
        ];
        for &(base, expected) in &cases {
            assert_eq!(dummy(base).redirect_url_with_pkce("my_challenge"), expected);
        }
    }

    #[test]
    fn test_redirect_url_with_pkce_and_state() {
        let cases = [
            (
                BARE,
                "https://example.com/auth?code_challenge=my_challenge&code_challenge_method=S256&state=my_state",
            ),
            (
                WITH_QUERY,
                "https://example.com/auth?client_id=123&code_challenge=my_challenge&code_challenge_method=S256&state=my_state",
            ),
        ];
        for &(base, expected) in &cases {
            assert_eq!(
                dummy(base).redirect_url_with_pkce_and_state("my_challenge", "my_state"),
                expected
            );
        }
    }

    #[tokio::test]
    async fn test_default_revoke_token() {
        let outcome = dummy("").revoke_token("some_token").await;
        assert!(outcome.is_err());
        if let ConnectError::Token(msg) = outcome.unwrap_err() {
            assert_eq!(msg, "Token revocation is not supported by this provider");
        } else {
            panic!("Expected ConnectError::Token");
        }
    }
}