use std::collections::HashMap;
use async_trait::async_trait;
use url::{Host::Domain, Url};
use crate::{
client::{Client, GrantType, ResponseType, TokenEndpointAuthMethod},
crypto::{decode_base64, random_secure_string, sha256},
error::{ErrorCode, OAuthError},
};
/// PKCE code-challenge transformation (RFC 7636 §4.2).
///
/// `PartialEq`/`Eq` are derived so callers can compare methods directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodeChallengeMethod {
    /// The challenge is the verifier itself.
    Plain,
    /// Per RFC 7636, the challenge is the base64url-encoded SHA-256 digest of the verifier.
    S256,
}
/// Parameters received at the authorization endpoint.
pub struct AuthorizationRequest {
    /// Identifier of the client making the request.
    pub client_id: String,
    /// Requested response type; [`AuthorizationCodeFlow`] accepts only `Code`.
    pub response_type: ResponseType,
    /// Client-supplied `state`, echoed back in the response and attached to errors.
    pub state: Option<String>,
    /// Redirect URI; may be omitted when the client has exactly one registered URI.
    pub redirect_uri: Option<String>,
    /// Space-separated scopes; [`AuthorizationCodeFlow`] uses `"profile email"` when absent.
    pub scope: Option<String>,
    /// PKCE code challenge (RFC 7636).
    pub code_challenge: String,
    /// PKCE transformation; treated as `Plain` when absent.
    pub code_challenge_method: Option<CodeChallengeMethod>,
}

/// Parameters of an authorization-code token request.
pub struct TokenRequest {
    /// Must be `AuthorizationCode` for [`AuthorizationCodeFlow`].
    pub grant_type: GrantType,
    /// Authorization code previously returned by the authorization endpoint.
    pub code: String,
    /// PKCE code verifier checked against the stored challenge.
    pub code_verifier: String,
    /// Redirect URI used when the code was requested; may be omitted when the
    /// client has exactly one registered URI.
    pub redirect_uri: Option<String>,
    /// Identifies a client that has no secret; must be `None` when the client
    /// was already authenticated by other means.
    pub client_id: Option<String>,
}

/// Parameters of a client-credentials token request.
pub struct ClientCredentialsTokenRequest {
    /// Space-separated scopes; `"profile email"` is used when absent.
    pub scope: Option<String>,
    /// Must be `ClientCredentials` for [`ClientCredentialsFlow`].
    pub grant_type: GrantType,
}
/// Successful result of the authorization endpoint.
#[derive(Debug)]
pub struct AuthorizationResponse {
    /// Newly issued authorization code.
    pub code: String,
    /// `state` from the authorization request, echoed back unchanged.
    pub state: Option<String>,
}

/// Successful result of the token endpoint.
///
/// The fields are public so the caller can build the HTTP response from them.
/// Without that, the issued token could only be reached through `Debug` output.
#[derive(Debug)]
pub struct TokenResponse {
    /// Newly issued access token.
    pub access_token: String,
    /// Token type; the flows in this module always use `"Bearer"`.
    pub token_type: String,
    /// Token lifetime (`expires_in`, seconds per RFC 6749 §5.1).
    pub expires_in: u64,
    /// Space-separated scopes granted to the token.
    pub scope: String,
}
/// Authorization request that passed [`AuthorizationFlow::verify_authorization`].
///
/// The fields are private, so code outside this module cannot build one
/// without going through verification.
pub struct VerifiedAuthorizationRequest {
    /// Client the request belongs to.
    client_id: String,
    /// Redirect URI resolved and checked during verification.
    redirect_uri: String,
    /// Validated, space-separated scope string.
    scope: String,
    /// Client `state` to echo back.
    state: Option<String>,
    /// PKCE challenge to bind to the issued code.
    code_challenge: String,
    /// Resolved PKCE method (defaulted to `Plain` when the request had none).
    code_challenge_method: CodeChallengeMethod,
}

/// Authorization-code token request whose code and PKCE verifier were checked.
pub struct VerifiedTokenRequest {
    /// Data recorded when the exchanged code was issued.
    authentication_information: AuthorizationInformation,
}

/// Client-credentials token request whose requested scopes were checked.
///
/// NOTE(review): unlike the other verified types, `scope` is public, so code
/// outside this module can construct one without verification. Confirm this is intended.
pub struct VerifiedClientCredentialsTokenRequest {
    /// Validated, space-separated scope string.
    pub scope: String,
}
/// Data recorded when an authorization code is issued and checked again when
/// the code is exchanged for a token.
///
/// The authorization-code flow in this module saves it through
/// [`Provider::save_authorization_information`], keyed by the code.
///
/// NOTE(review): the fields are private and this file exposes no constructor.
/// A `Provider` implemented outside this module can therefore only keep and
/// return the value as-is (e.g. via `Clone`); a persistent backend cannot
/// rebuild it. Confirm this is intended.
#[derive(Debug, Clone)]
pub struct AuthorizationInformation {
    // Client the code was issued to; must match the client redeeming it.
    client_id: String,
    // Redirect URI the code was issued for; must match at exchange time.
    redirect_uri: String,
    // PKCE challenge taken from the authorization request.
    code_challenge: String,
    // How `code_challenge` relates to the verifier presented later.
    code_challenge_method: CodeChallengeMethod,
    // Validated, space-separated scope string carried into the token response.
    scope: String,
    // Client `state` from the authorization request.
    state: Option<String>,
    // Token exchange is refused when this is false.
    // NOTE(review): only ever set to `true` in this file. Confirm whether
    // anything else invalidates codes.
    is_valid: bool,
}
/// Public details of a client, returned by
/// [`Provider::verify_authorization_request`] alongside the verified request.
///
/// NOTE(review): presumably intended for rendering a sign-in/consent page.
/// The fields are public so callers outside this module can actually read them.
#[derive(Debug)]
pub struct SigninInformation {
    /// Client display name.
    pub client_name: String,
    /// Client home page URI.
    pub client_uri: String,
    /// Client logo URI.
    pub logo_uri: String,
    /// Scopes the client is registered for.
    pub scopes: Vec<String>,
    /// Client contact addresses.
    pub contacts: Vec<String>,
    /// Terms-of-service URI.
    pub tos_uri: String,
    /// Privacy-policy URI.
    pub policy_uri: String,
}
/// Storage backend for the authorization server.
///
/// Implementors persist clients and issued-code data. The provided methods
/// drive the authorization and token flows on top of that storage.
#[async_trait(?Send)]
pub trait Provider
where
    Self: Sized,
{
    /// Persists a registered client.
    async fn store_client(&self, client: Client) -> Result<(), OAuthError>;

    /// Looks up a client by id, yielding `None` when it is unknown.
    async fn get_client(&self, client_id: &str) -> Option<Client>;

    /// Stores `information` under `id`.
    ///
    /// The authorization-code flow in this module uses the code itself as `id`.
    async fn save_authorization_information(
        &self,
        id: String,
        information: AuthorizationInformation,
    ) -> Result<(), OAuthError>;

    /// Fetches what was stored under `id`, if anything.
    async fn get_authorization_information(&self, id: &str) -> Option<AuthorizationInformation>;

    /// Deletes what was stored under `id`.
    async fn remove_authorization_information(&self, id: &str) -> Result<(), OAuthError>;

    /// Resolves the requesting client and validates `request` with `flow`.
    ///
    /// On success, returns the verified request together with the client's
    /// display details.
    ///
    /// # Errors
    /// Fails with `InvalidRequest` (carrying the request's `state`) when the
    /// client id is unknown. Any error from `flow.verify_authorization` is
    /// propagated.
    async fn verify_authorization_request<T: AuthorizationFlow>(
        &self,
        request: AuthorizationRequest,
        flow: T,
    ) -> Result<(VerifiedAuthorizationRequest, SigninInformation), OAuthError> {
        let client = match self.get_client(&request.client_id).await {
            Some(found) => found,
            None => {
                return Err(OAuthError::new(ErrorCode::InvalidRequest, request.state));
            }
        };
        let verified = flow.verify_authorization(&client, request).await?;
        let signin = SigninInformation {
            client_name: client.name,
            client_uri: client.uri,
            logo_uri: client.logo_uri,
            scopes: client.scopes,
            contacts: client.contacts,
            tos_uri: client.tos_uri,
            policy_uri: client.policy_uri,
        };
        Ok((verified, signin))
    }

    /// Completes an already verified authorization request using `flow`.
    async fn authorize<T: AuthorizationFlow>(
        &self,
        request: VerifiedAuthorizationRequest,
        flow: T,
    ) -> Result<T::Response, OAuthError> {
        flow.perform_authorization(self, request).await
    }

    /// Verifies a token request with `flow` and, if it passes, performs the exchange.
    ///
    /// `authenticated_client` is the client the caller has already
    /// authenticated, if any.
    async fn get_token<T: TokenFlow>(
        &self,
        authenticated_client: Option<Client>,
        request: T::Request,
        flow: T,
    ) -> Result<T::Response, OAuthError> {
        let verified = flow
            .verify_token_request(authenticated_client, request, self)
            .await?;
        flow.perform_token_exchange(verified).await
    }
}
/// A strategy for answering authorization-endpoint requests.
///
/// [`AuthorizationCodeFlow`] is the implementation in this file.
#[async_trait(?Send)]
pub trait AuthorizationFlow {
    /// Value handed back to the caller when authorization succeeds.
    type Response;
    /// Validates `request` against the registered `client`.
    ///
    /// Produces a [`VerifiedAuthorizationRequest`], or an error describing why
    /// the request was rejected.
    async fn verify_authorization(
        &self,
        client: &Client,
        request: AuthorizationRequest,
    ) -> Result<VerifiedAuthorizationRequest, OAuthError>;
    /// Carries out an already verified request, using `provider` for any
    /// storage the flow needs.
    async fn perform_authorization(
        &self,
        provider: &impl Provider,
        request: VerifiedAuthorizationRequest,
    ) -> Result<Self::Response, OAuthError>;
}
/// A token-endpoint grant: first verify a request, then exchange it for a token.
///
/// Implemented in this file by [`AuthorizationCodeFlow`] and [`ClientCredentialsFlow`].
#[async_trait(?Send)]
pub trait TokenFlow {
    /// Raw request accepted by this grant.
    type Request;
    /// Request type produced by successful verification.
    type VerifiedRequest;
    /// Value returned when the exchange succeeds.
    type Response;
    /// Validates `request`.
    ///
    /// `authenticated_client` is the client the caller already authenticated,
    /// if any. `provider` gives access to stored clients and authorization data.
    async fn verify_token_request(
        &self,
        authenticated_client: Option<Client>,
        request: Self::Request,
        provider: &impl Provider,
    ) -> Result<Self::VerifiedRequest, OAuthError>;
    /// Issues the token for a verified request.
    async fn perform_token_exchange(
        &self,
        request: Self::VerifiedRequest,
    ) -> Result<Self::Response, OAuthError>;
}
/// Authorization-code grant with PKCE; implements both [`AuthorizationFlow`]
/// and [`TokenFlow`].
///
/// `Provider` methods take the flow by value, so `Copy` lets one instance be
/// reused across calls.
#[derive(Debug, Clone, Copy, Default)]
pub struct AuthorizationCodeFlow;

/// Client-credentials grant; implements [`TokenFlow`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ClientCredentialsFlow;
#[async_trait(?Send)]
impl AuthorizationFlow for AuthorizationCodeFlow {
    type Response = AuthorizationResponse;

    /// Validates an authorization request for the authorization-code grant.
    ///
    /// Checks, in order:
    /// - `response_type` must be `Code`;
    /// - the PKCE `code_challenge` must be non-empty (a missing method means
    ///   `Plain`, RFC 7636 §4.3);
    /// - `redirect_uri` must be one of the client's registered URIs. It may be
    ///   omitted only when the client has exactly one registered URI;
    /// - the redirect URI must use `https`, or `http` with host `localhost`;
    /// - every requested scope must be one of the client's scopes. With no
    ///   requested scope, `"profile email"` is used.
    ///
    /// # Errors
    /// Returns `InvalidScope` for an unknown scope and `InvalidRequest` for
    /// every other failure. Each error carries the request's `state`.
    async fn verify_authorization(
        &self,
        client: &Client,
        request: AuthorizationRequest,
    ) -> Result<VerifiedAuthorizationRequest, OAuthError> {
        let state = request.state.clone();
        let invalid_request = || OAuthError::new(ErrorCode::InvalidRequest, state.clone());

        if request.response_type != ResponseType::Code {
            return Err(invalid_request());
        }
        // PKCE is mandatory here. Under `Plain`, an empty challenge would let an
        // empty verifier redeem the code.
        if request.code_challenge.is_empty() {
            return Err(invalid_request());
        }
        let code_challenge_method = request
            .code_challenge_method
            .unwrap_or(CodeChallengeMethod::Plain);

        let redirect_uri = match request.redirect_uri {
            Some(uri) if client.redirect_uris.contains(&uri) => uri,
            Some(_) => return Err(invalid_request()),
            // Omitting the URI is only unambiguous with exactly one registered URI.
            // This guard also prevents indexing into an empty list.
            None if client.redirect_uris.len() == 1 => client.redirect_uris[0].to_string(),
            None => return Err(invalid_request()),
        };

        // Require https, except plain http to `localhost` (local/loopback redirects).
        let parsed = Url::parse(&redirect_uri).map_err(|_| invalid_request())?;
        let scheme_allowed = match parsed.scheme() {
            "https" => true,
            "http" => parsed.host() == Some(Domain("localhost")),
            _ => false,
        };
        if !scheme_allowed {
            return Err(invalid_request());
        }

        let scope = match request.scope {
            Some(requested) => {
                let all_known = requested
                    .split_ascii_whitespace()
                    .all(|item| client.scopes.iter().any(|known| known == item));
                if !all_known {
                    return Err(OAuthError::new(ErrorCode::InvalidScope, state));
                }
                requested
            }
            None => "profile email".to_string(),
        };

        Ok(VerifiedAuthorizationRequest {
            client_id: request.client_id,
            code_challenge: request.code_challenge,
            code_challenge_method,
            redirect_uri,
            scope,
            state: request.state,
        })
    }

    /// Issues a fresh authorization code for a verified request.
    ///
    /// The code comes from `random_secure_string(32)`. The request's details are
    /// saved through `provider` under that code, and the code is returned
    /// together with the client's `state`.
    ///
    /// # Errors
    /// Propagates any error from `Provider::save_authorization_information`.
    async fn perform_authorization(
        &self,
        provider: &impl Provider,
        request: VerifiedAuthorizationRequest,
    ) -> Result<Self::Response, OAuthError> {
        let authorization_code = random_secure_string(32);
        let state = request.state.clone();
        let information = AuthorizationInformation {
            client_id: request.client_id,
            redirect_uri: request.redirect_uri,
            code_challenge: request.code_challenge,
            code_challenge_method: request.code_challenge_method,
            scope: request.scope,
            state: request.state,
            is_valid: true,
        };
        provider
            .save_authorization_information(authorization_code.clone(), information)
            .await?;
        Ok(AuthorizationResponse {
            code: authorization_code,
            state,
        })
    }
}
#[async_trait(?Send)]
impl TokenFlow for AuthorizationCodeFlow {
type Request = TokenRequest;
type VerifiedRequest = VerifiedTokenRequest;
type Response = TokenResponse;
async fn verify_token_request(
&self,
authenticated_client: Option<Client>,
request: Self::Request,
provider: &impl Provider,
) -> Result<Self::VerifiedRequest, OAuthError> {
if authenticated_client.is_some() && request.client_id.is_some()
|| authenticated_client.is_none() && request.client_id.is_none()
{
return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
}
if request.grant_type != GrantType::AuthorizationCode {
return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
}
let client = if let Some(client) = authenticated_client {
client
} else {
let client = provider
.get_client(&request.client_id.unwrap())
.await
.ok_or(OAuthError::new(ErrorCode::InvalidRequest, None))?;
if client.secret.is_some() {
return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
}
client
};
if client.redirect_uris.len() > 1 && request.redirect_uri.is_none() {
return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
}
let authentication_information = provider
.get_authorization_information(&request.code)
.await
.ok_or(OAuthError::new(ErrorCode::AccessDenied, None))?;
if !authentication_information.is_valid {
return Err(OAuthError::new(ErrorCode::AccessDenied, None));
}
if let Some(redirect_uri) = request.redirect_uri {
if authentication_information.redirect_uri != redirect_uri {
return Err(OAuthError::new(ErrorCode::AccessDenied, None));
}
} else if authentication_information.redirect_uri != client.redirect_uris[0].to_string() {
return Err(OAuthError::new(ErrorCode::AccessDenied, None));
}
if authentication_information.client_id != client.id {
return Err(OAuthError::new(ErrorCode::AccessDenied, None));
}
match authentication_information.code_challenge_method {
CodeChallengeMethod::Plain => {
if request.code_verifier != authentication_information.code_challenge {
return Err(OAuthError::new(ErrorCode::AccessDenied, None));
}
}
CodeChallengeMethod::S256 => {
if &sha256(&request.code_verifier)
!= authentication_information.code_challenge.as_bytes()
{
return Err(OAuthError::new(ErrorCode::AccessDenied, None));
}
}
}
provider
.remove_authorization_information(&request.code)
.await?;
Ok(VerifiedTokenRequest {
authentication_information,
})
}
async fn perform_token_exchange(
&self,
request: Self::VerifiedRequest,
) -> Result<Self::Response, OAuthError> {
Ok(TokenResponse {
access_token: random_secure_string(24),
token_type: "Bearer".to_string(),
expires_in: 3600,
scope: request.authentication_information.scope,
})
}
}
#[async_trait(?Send)]
impl TokenFlow for ClientCredentialsFlow {
    type Request = ClientCredentialsTokenRequest;
    type VerifiedRequest = VerifiedClientCredentialsTokenRequest;
    type Response = TokenResponse;

    /// Checks a client-credentials token request.
    ///
    /// - The client must already be authenticated; otherwise `UnauthorizedClient`.
    /// - The grant type must be `ClientCredentials`; otherwise `InvalidRequest`.
    /// - Every requested scope must be one of the client's scopes; otherwise
    ///   `InvalidScope`. Without a requested scope, `"profile email"` is used.
    ///
    /// The provider is not consulted.
    async fn verify_token_request(
        &self,
        authenticated_client: Option<Client>,
        request: Self::Request,
        _provider: &impl Provider,
    ) -> Result<Self::VerifiedRequest, OAuthError> {
        let client = match authenticated_client {
            Some(client) => client,
            None => return Err(OAuthError::new(ErrorCode::UnauthorizedClient, None)),
        };
        if request.grant_type != GrantType::ClientCredentials {
            return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
        }
        let scope = match request.scope {
            None => String::from("profile email"),
            Some(requested) => {
                let all_allowed = requested
                    .split_ascii_whitespace()
                    .all(|wanted| client.scopes.iter().any(|allowed| allowed == wanted));
                if !all_allowed {
                    return Err(OAuthError::new(ErrorCode::InvalidScope, None));
                }
                requested
            }
        };
        Ok(VerifiedClientCredentialsTokenRequest { scope })
    }

    /// Issues a bearer token carrying the verified scope.
    ///
    /// The token comes from `random_secure_string(24)` and `expires_in` is 3600.
    async fn perform_token_exchange(
        &self,
        request: Self::VerifiedRequest,
    ) -> Result<Self::Response, OAuthError> {
        let VerifiedClientCredentialsTokenRequest { scope } = request;
        let response = TokenResponse {
            access_token: random_secure_string(24),
            token_type: "Bearer".to_string(),
            expires_in: 3600,
            scope,
        };
        Ok(response)
    }
}
/// Read-only access to the parts of an HTTP request that client authentication needs.
pub trait HttpRequestDetails {
    /// Returns the request headers as a name → value map.
    ///
    /// As a `HashMap<String, String>`, it holds one value per name. How repeated
    /// headers and name casing are handled is up to the implementor.
    fn get_headers(&self) -> HashMap<String, String>;
    /// Returns the request's form parameters as a name → value map.
    ///
    /// [`ClientSecretPost`] reads `client_id` and `client_secret` from it.
    /// NOTE(review): presumably the decoded `application/x-www-form-urlencoded`
    /// body. Confirm with implementors.
    fn get_form_values(&self) -> HashMap<String, String>;
}
/// Strategy for authenticating a client at the token endpoint.
///
/// [`ClientSecretBasic`] and [`ClientSecretPost`] are the implementations in this file.
#[async_trait(?Send)]
pub trait ClientAuthenticator {
    /// Identifies the client behind the request described by `details`, looks
    /// it up through `provider`, and checks its credentials.
    ///
    /// # Errors
    /// Returns an [`OAuthError`] when the client cannot be authenticated.
    async fn authenticate_client(
        &self,
        provider: &impl Provider,
        details: &impl HttpRequestDetails,
    ) -> Result<Client, OAuthError>;
}
/// Authenticates clients with HTTP Basic credentials (`client_secret_basic`).
#[derive(Debug, Clone, Copy, Default)]
pub struct ClientSecretBasic;
#[async_trait(?Send)]
impl ClientAuthenticator for ClientSecretBasic {
    /// Authenticates a client from an HTTP Basic `Authorization` header
    /// (RFC 6749 §2.3.1, RFC 7617).
    ///
    /// Header handling:
    /// - The header name is matched case-insensitively; an exact
    ///   `"Authorization"` key is preferred.
    /// - The value is expected to be `Basic <base64(client_id:client_secret)>`,
    ///   with a case-insensitive scheme.
    /// - A bare base64 payload without the scheme is still accepted, for
    ///   `get_headers` implementations that strip it.
    /// - The decoded credentials are split at the first `:`, so secrets may
    ///   contain colons.
    ///
    /// NOTE(review): RFC 6749 §2.3.1 form-urlencodes id and secret before base64
    /// encoding; no percent-decoding is done here.
    ///
    /// # Errors
    /// `InvalidRequest` when:
    /// - the header is missing or malformed;
    /// - the client is unknown or its secret does not match;
    /// - the client is not registered for `ClientSecretBasic`.
    async fn authenticate_client(
        &self,
        provider: &impl Provider,
        details: &impl HttpRequestDetails,
    ) -> Result<Client, OAuthError> {
        let invalid = || OAuthError::new(ErrorCode::InvalidRequest, None);
        let headers = details.get_headers();
        // HTTP header names are case-insensitive (RFC 9110 §5.1).
        let header_value = headers
            .get("Authorization")
            .or_else(|| {
                headers
                    .iter()
                    .find(|(name, _)| name.eq_ignore_ascii_case("Authorization"))
                    .map(|(_, value)| value)
            })
            .ok_or_else(invalid)?
            .trim();
        let encoded = match header_value.split_once(' ') {
            Some((scheme, credentials)) if scheme.eq_ignore_ascii_case("Basic") => {
                credentials.trim_start()
            }
            _ => header_value,
        };
        let decoded = decode_base64(encoded).map_err(|_| invalid())?;
        // RFC 7617 forbids ':' in the user-id but allows it in the password.
        let (client_id, client_secret) = decoded.split_once(':').ok_or_else(invalid)?;
        match provider.get_client(client_id).await {
            Some(client)
                if client.secret.as_deref() == Some(client_secret)
                    && client.token_endpoint_auth_method
                        == TokenEndpointAuthMethod::ClientSecretBasic =>
            {
                Ok(client)
            }
            _ => Err(invalid()),
        }
    }
}
/// Authenticates clients from `client_id`/`client_secret` form fields (`client_secret_post`).
#[derive(Debug, Clone, Copy, Default)]
pub struct ClientSecretPost;
#[async_trait(?Send)]
impl ClientAuthenticator for ClientSecretPost {
    /// Authenticates a client from the `client_id` and `client_secret` form fields.
    ///
    /// # Errors
    /// `InvalidRequest` when:
    /// - either field is missing;
    /// - the client is unknown or the secret does not match;
    /// - the client is not registered for `ClientSecretPost`.
    async fn authenticate_client(
        &self,
        provider: &impl Provider,
        details: &impl HttpRequestDetails,
    ) -> Result<Client, OAuthError> {
        let form = details.get_form_values();
        let (client_id, client_secret) =
            match form.get("client_id").zip(form.get("client_secret")) {
                Some(pair) => pair,
                None => return Err(OAuthError::new(ErrorCode::InvalidRequest, None)),
            };
        let client = match provider.get_client(client_id).await {
            Some(found) => found,
            None => return Err(OAuthError::new(ErrorCode::InvalidRequest, None)),
        };
        let accepted = client.secret.as_deref() == Some(client_secret.as_str())
            && client.token_endpoint_auth_method == TokenEndpointAuthMethod::ClientSecretPost;
        if accepted {
            Ok(client)
        } else {
            Err(OAuthError::new(ErrorCode::InvalidRequest, None))
        }
    }
}