use std::sync::Arc;

use axum::extract::rejection::JsonRejection;
use axum::http::{StatusCode, header};
use axum::response::{IntoResponse, Response};
use axum::routing::{MethodRouter, post};
use axum::{Json, Router};
use serde::Serialize;

use crate::OAuthConfig;
use crate::error::OAuthError;
use crate::grants::{self, TokenRequest, TokenResponse};
#[derive(Clone)]
pub struct OAuthHandler {
config: Arc<OAuthConfig>,
}
impl OAuthHandler {
pub fn from_config(config: Arc<OAuthConfig>) -> Self {
Self { config }
}
pub fn token_endpoint(&self) -> MethodRouter {
let config = self.config.clone();
post(move |Json(request): Json<TokenRequest>| async move {
grants::issue_token(&config, request).map(Json)
})
}
pub fn router(&self) -> Router {
Router::new().route("/oauth/token", self.token_endpoint())
}
}
/// Convenience entry point that builds the OAuth [`Router`] directly from `config`.
///
/// Equivalent to constructing an [`OAuthHandler`] and calling its `router` method.
pub fn router(config: Arc<OAuthConfig>) -> Router {
    let handler = OAuthHandler::from_config(config);
    handler.router()
}
/// JSON error body returned by the OAuth endpoints (`error` / `error_description` fields).
#[derive(Serialize)]
struct ErrorBody {
    /// Wire-level error code, e.g. `invalid_grant`.
    error: String,
    /// Optional human-readable detail; omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error_description: Option<String>,
}

impl From<&OAuthError> for ErrorBody {
    /// Maps an [`OAuthError`] to its error code and optional description.
    ///
    /// `UnsupportedGrant` and `NotImplemented` both map to `unsupported_grant_type`.
    /// `Config`, `TokenStore` and `Internal` map to `server_error`. `InvalidClient`,
    /// `Config` and `TokenStore` carry no description.
    fn from(err: &OAuthError) -> Self {
        // A single exhaustive match (no `_` arm): adding a variant to `OAuthError` forces an
        // explicit decision about both its code and how much detail it exposes to clients.
        let (error, error_description): (&str, Option<String>) = match err {
            OAuthError::InvalidClient => ("invalid_client", None),
            OAuthError::UnsupportedGrant(message) => {
                ("unsupported_grant_type", Some(message.clone()))
            }
            OAuthError::NotImplemented(message) => {
                ("unsupported_grant_type", Some((*message).into()))
            }
            OAuthError::InvalidGrant(message) => ("invalid_grant", Some(message.clone())),
            OAuthError::InvalidScope(message) => ("invalid_scope", Some(message.clone())),
            // NOTE(review): unlike Config/TokenStore, the Internal message is sent to the
            // client; confirm it never carries sensitive internal detail.
            OAuthError::Internal(message) => ("server_error", Some(message.clone())),
            OAuthError::Config(_) | OAuthError::TokenStore(_) => ("server_error", None),
        };
        Self {
            error: error.to_string(),
            error_description,
        }
    }
}
impl IntoResponse for OAuthError {
fn into_response(self) -> Response {
let status = StatusCode::from_u16(self.status_code().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let body = Json(ErrorBody::from(&self));
(status, body).into_response()
}
}