use async_trait::async_trait;
use better_auth_core::AuthResult;
use better_auth_core::adapters::DatabaseAdapter;
use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
#[cfg(feature = "axum")]
use better_auth_core::plugin::{AuthState, AxumPlugin};
pub mod encryption;
mod handlers;
mod providers;
#[cfg(test)]
mod tests;
mod types;
pub use providers::{OAuthConfig, OAuthProvider, OAuthStateStrategy, OAuthUserInfo};
pub struct OAuthPlugin {
config: OAuthConfig,
}
impl OAuthPlugin {
pub fn new() -> Self {
Self {
config: OAuthConfig::default(),
}
}
pub fn with_config(config: OAuthConfig) -> Self {
Self { config }
}
pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
self.config.providers.insert(name.to_string(), provider);
self
}
}
impl Default for OAuthPlugin {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl<DB: DatabaseAdapter> AuthPlugin<DB> for OAuthPlugin {
fn name(&self) -> &'static str {
"oauth"
}
fn routes(&self) -> Vec<AuthRoute> {
vec![
AuthRoute::post("/sign-in/social", "social_sign_in"),
AuthRoute::get("/callback/{provider}", "oauth_callback"),
AuthRoute::post("/link-social", "link_social"),
AuthRoute::post("/get-access-token", "get_access_token"),
AuthRoute::post("/refresh-token", "refresh_token"),
]
}
async fn on_request(
&self,
req: &AuthRequest,
ctx: &AuthContext<DB>,
) -> AuthResult<Option<AuthResponse>> {
match (req.method(), req.path()) {
(HttpMethod::Post, "/sign-in/social") => Ok(Some(
handlers::handle_social_sign_in(&self.config, req, ctx).await?,
)),
(HttpMethod::Get, path) if path_matches_callback(path) => {
let provider = extract_provider_from_callback(path);
Ok(Some(
handlers::handle_callback(&self.config, &provider, req, ctx).await?,
))
}
(HttpMethod::Post, "/link-social") => Ok(Some(
handlers::handle_link_social(&self.config, req, ctx).await?,
)),
(HttpMethod::Post, "/get-access-token") => Ok(Some(
handlers::handle_get_access_token(&self.config, req, ctx).await?,
)),
(HttpMethod::Post, "/refresh-token") => Ok(Some(
handlers::handle_refresh_token(&self.config, req, ctx).await?,
)),
_ => Ok(None),
}
}
}
fn path_matches_callback(path: &str) -> bool {
let path_without_query = path.split('?').next().unwrap_or(path);
path_without_query.starts_with("/callback/") && path_without_query.len() > "/callback/".len()
}
fn extract_provider_from_callback(path: &str) -> String {
let path_without_query = path.split('?').next().unwrap_or(path);
path_without_query["/callback/".len()..].to_string()
}
#[cfg(feature = "axum")]
mod axum_impl {
use super::*;
use std::sync::Arc;
use axum::Json;
use axum::extract::{Extension, Path, Query, State};
use axum::http::header;
use axum::response::IntoResponse;
use better_auth_core::error::AuthError;
use better_auth_core::extractors::{CurrentSession, ValidatedJson};
use super::handlers::{
callback_core, get_access_token_core, link_social_core, refresh_token_core,
social_sign_in_core,
};
use super::types::{
AccessTokenResponse, GetAccessTokenRequest, LinkSocialRequest, RefreshTokenRequest,
RefreshTokenResponse, SocialSignInRequest, SocialSignInResponse,
};
#[derive(serde::Deserialize)]
struct CallbackQuery {
code: String,
state: String,
}
#[derive(Clone)]
struct PluginState {
config: OAuthConfig,
}
async fn handle_social_sign_in<DB: DatabaseAdapter>(
State(state): State<AuthState<DB>>,
Extension(ps): Extension<Arc<PluginState>>,
ValidatedJson(body): ValidatedJson<SocialSignInRequest>,
) -> Result<Json<SocialSignInResponse>, AuthError> {
let ctx = state.to_context();
let result = social_sign_in_core(&body, &ps.config, &ctx).await?;
Ok(Json(result))
}
async fn handle_callback<DB: DatabaseAdapter>(
State(state): State<AuthState<DB>>,
Extension(ps): Extension<Arc<PluginState>>,
Path(provider): Path<String>,
Query(params): Query<CallbackQuery>,
) -> Result<impl IntoResponse, AuthError> {
let ctx = state.to_context();
let (response, token) =
callback_core(¶ms.code, ¶ms.state, &provider, &ps.config, &ctx).await?;
let cookie = state.session_cookie(&token);
Ok(([(header::SET_COOKIE, cookie)], Json(response)))
}
async fn handle_link_social<DB: DatabaseAdapter>(
State(state): State<AuthState<DB>>,
Extension(ps): Extension<Arc<PluginState>>,
CurrentSession { session, .. }: CurrentSession<DB>,
ValidatedJson(body): ValidatedJson<LinkSocialRequest>,
) -> Result<Json<SocialSignInResponse>, AuthError> {
let ctx = state.to_context();
let result = link_social_core(&body, &session, &ps.config, &ctx).await?;
Ok(Json(result))
}
async fn handle_get_access_token<DB: DatabaseAdapter>(
State(state): State<AuthState<DB>>,
CurrentSession { session, .. }: CurrentSession<DB>,
ValidatedJson(body): ValidatedJson<GetAccessTokenRequest>,
) -> Result<Json<AccessTokenResponse>, AuthError> {
let ctx = state.to_context();
let result = get_access_token_core(&body, &session, &ctx).await?;
Ok(Json(result))
}
async fn handle_refresh_token<DB: DatabaseAdapter>(
State(state): State<AuthState<DB>>,
Extension(ps): Extension<Arc<PluginState>>,
CurrentSession { session, .. }: CurrentSession<DB>,
ValidatedJson(body): ValidatedJson<RefreshTokenRequest>,
) -> Result<Json<RefreshTokenResponse>, AuthError> {
let ctx = state.to_context();
let result = refresh_token_core(&body, &session, &ps.config, &ctx).await?;
Ok(Json(result))
}
#[async_trait]
impl<DB: DatabaseAdapter> AxumPlugin<DB> for OAuthPlugin {
fn name(&self) -> &'static str {
"oauth"
}
fn router(&self) -> axum::Router<AuthState<DB>> {
use axum::routing::{get, post};
let plugin_state = Arc::new(PluginState {
config: self.config.clone(),
});
axum::Router::new()
.route("/sign-in/social", post(handle_social_sign_in::<DB>))
.route("/callback/:provider", get(handle_callback::<DB>))
.route("/link-social", post(handle_link_social::<DB>))
.route("/get-access-token", post(handle_get_access_token::<DB>))
.route("/refresh-token", post(handle_refresh_token::<DB>))
.layer(Extension(plugin_state))
}
}
}