#[cfg(feature = "axum")]
use axum::{
Router,
extract::{FromRequestParts, Request, State},
http::StatusCode,
http::request::Parts,
response::{IntoResponse, Response},
routing::{delete, get, post},
};
#[cfg(feature = "axum")]
use std::sync::Arc;
#[cfg(feature = "axum")]
use crate::BetterAuth;
#[cfg(feature = "axum")]
use better_auth_core::SessionManager;
#[cfg(feature = "axum")]
use better_auth_core::entity::AuthSession as AuthSessionTrait;
use better_auth_core::{
AuthError, AuthRequest, AuthResponse, DatabaseAdapter, ErrorMessageResponse,
HealthCheckResponse, HttpMethod, OkResponse, core_paths,
};
/// Extension trait that turns a shared [`BetterAuth`] instance into an axum [`Router`].
#[cfg(feature = "axum")]
pub trait AxumIntegration<DB>
where
    DB: DatabaseAdapter,
{
    /// Consumes `self` and returns a router whose state type is `Arc<BetterAuth<DB>>`.
    fn axum_router(self) -> Router<Arc<BetterAuth<DB>>>;
}
#[cfg(feature = "axum")]
impl<DB: DatabaseAdapter> AxumIntegration<DB> for Arc<BetterAuth<DB>> {
    /// Builds the auth router and installs `self` as its state.
    ///
    /// Built-in endpoints from `core_paths` are registered first, followed by every
    /// route declared by the configured plugins. Any path listed in the config's
    /// `disabled_paths` is skipped. Plugin routes are served by the generic plugin
    /// handler under their declared method, including HEAD and OPTIONS; a method this
    /// router cannot map is skipped with a warning rather than dropped silently.
    ///
    /// # Panics
    ///
    /// Per axum's `Router::route` documentation, registering the same method twice for
    /// one path panics, so a plugin route that duplicates a built-in route (or another
    /// plugin's route) with the same method aborts router construction.
    fn axum_router(self) -> Router<Arc<BetterAuth<DB>>> {
        let disabled_paths = self.config().disabled_paths.clone();
        // `disabled_paths` stores owned strings, hence the owned copy for the lookup.
        let is_enabled = |path: &str| !disabled_paths.contains(&path.to_string());

        // Built-in endpoints. DELETE_USER is served for both POST and DELETE.
        let core_routes: Vec<(&str, axum::routing::MethodRouter<Arc<BetterAuth<DB>>>)> = vec![
            (core_paths::OK, get(ok_check)),
            (core_paths::ERROR, get(error_check)),
            (core_paths::HEALTH, get(health_check)),
            (core_paths::OPENAPI_SPEC, get(create_plugin_handler::<DB>())),
            (core_paths::UPDATE_USER, post(create_plugin_handler::<DB>())),
            (
                core_paths::DELETE_USER,
                post(create_plugin_handler::<DB>()).delete(create_plugin_handler::<DB>()),
            ),
            (core_paths::CHANGE_EMAIL, post(create_plugin_handler::<DB>())),
            (
                core_paths::DELETE_USER_CALLBACK,
                get(create_plugin_handler::<DB>()),
            ),
        ];

        let mut router = Router::new();
        for (path, method_router) in core_routes {
            if is_enabled(path) {
                router = router.route(path, method_router);
            }
        }

        for plugin in self.plugins() {
            for route in plugin.routes() {
                if !is_enabled(route.path.as_str()) {
                    continue;
                }
                let handler = create_plugin_handler::<DB>();
                let method_router = match route.method {
                    HttpMethod::Get => get(handler),
                    HttpMethod::Post => post(handler),
                    HttpMethod::Put => axum::routing::put(handler),
                    HttpMethod::Delete => delete(handler),
                    HttpMethod::Patch => axum::routing::patch(handler),
                    HttpMethod::Head => axum::routing::head(handler),
                    HttpMethod::Options => axum::routing::options(handler),
                    // Only reachable if `HttpMethod` has (or gains) further variants; the
                    // allow keeps this arm warning-free while the enum is fully covered.
                    #[allow(unreachable_patterns)]
                    _ => {
                        tracing::warn!(
                            path = %route.path,
                            "Skipping plugin route with an HTTP method the axum router cannot map"
                        );
                        continue;
                    }
                };
                router = router.route(&route.path, method_router);
            }
        }

        router.with_state(self)
    }
}
/// Handler for `core_paths::OK`: answers with an `OkResponse` whose `ok` flag is `true`.
#[cfg(feature = "axum")]
async fn ok_check() -> impl IntoResponse {
    let payload = OkResponse { ok: true };
    axum::Json(payload)
}
/// Handler for `core_paths::ERROR`: answers with an `OkResponse` whose `ok` flag is `false`.
#[cfg(feature = "axum")]
async fn error_check() -> impl IntoResponse {
    let payload = OkResponse { ok: false };
    axum::Json(payload)
}
/// Handler for `core_paths::HEALTH`: reports `status: "ok"` for the `better-auth` service.
#[cfg(feature = "axum")]
async fn health_check() -> impl IntoResponse {
    let report = HealthCheckResponse {
        service: "better-auth",
        status: "ok",
    };
    axum::Json(report)
}
/// Boxed, `Send` future returned by the generic plugin handler.
#[cfg(feature = "axum")]
type PluginResponseFuture =
    std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>;

/// Builds the generic handler that forwards an axum request to `BetterAuth::handle_request`.
///
/// The request body is read up to the configured body limit (unbounded when the limit is
/// disabled). Conversion failures and handler errors are both rendered through
/// `convert_auth_error`; successful results go through `convert_auth_response`.
#[cfg(feature = "axum")]
fn create_plugin_handler<DB: DatabaseAdapter>()
-> impl Fn(State<Arc<BetterAuth<DB>>>, Request) -> PluginResponseFuture + Clone {
    |State(auth): State<Arc<BetterAuth<DB>>>, req: Request| {
        Box::pin(async move {
            // A disabled limit means "read the whole body".
            let byte_budget = {
                let limits = auth.body_limit_config();
                if limits.enabled {
                    limits.max_bytes
                } else {
                    usize::MAX
                }
            };
            let auth_req = match convert_axum_request(req, byte_budget).await {
                Ok(converted) => converted,
                Err(rejection) => return convert_auth_error(rejection),
            };
            auth.handle_request(auth_req)
                .await
                .map_or_else(convert_auth_error, convert_auth_response)
        })
    }
}
/// Converts an axum [`Request`] into the framework-agnostic [`AuthRequest`].
///
/// Headers that occur more than once are folded into a single entry instead of the
/// last occurrence silently winning (see `collect_axum_headers`). An empty body is
/// passed on as `None`.
///
/// # Errors
///
/// - `AuthError::InvalidRequest` for any method other than GET, POST, PUT, DELETE,
///   PATCH, OPTIONS or HEAD.
/// - `AuthError::payload_too_large` when the declared `Content-Length` or the body
///   actually read exceeds `max_body_bytes`.
/// - `AuthError::bad_request` when reading the body fails for any other reason.
#[cfg(feature = "axum")]
async fn convert_axum_request(
    req: Request,
    max_body_bytes: usize,
) -> Result<AuthRequest, AuthError> {
    let (parts, body) = req.into_parts();
    let method = convert_axum_method(&parts.method)?;
    let headers = collect_axum_headers(&parts.headers);
    let path = parts.uri.path().to_string();
    let query = parse_axum_query(parts.uri.query());
    // Reject an oversized declared length before touching the body stream.
    ensure_declared_length_within(&parts.headers, max_body_bytes)?;
    let body_bytes = read_axum_body(body, max_body_bytes).await?;
    Ok(AuthRequest::from_parts(
        method, path, headers, body_bytes, query,
    ))
}

/// Maps an HTTP method onto [`HttpMethod`], rejecting methods the auth layer does not model.
#[cfg(feature = "axum")]
fn convert_axum_method(method: &axum::http::Method) -> Result<HttpMethod, AuthError> {
    use axum::http::Method;
    let converted = match *method {
        Method::GET => HttpMethod::Get,
        Method::POST => HttpMethod::Post,
        Method::PUT => HttpMethod::Put,
        Method::DELETE => HttpMethod::Delete,
        Method::PATCH => HttpMethod::Patch,
        Method::OPTIONS => HttpMethod::Options,
        Method::HEAD => HttpMethod::Head,
        _ => {
            return Err(AuthError::InvalidRequest(
                "Unsupported HTTP method".to_string(),
            ));
        }
    };
    Ok(converted)
}

/// Flattens a header map into `name -> value`, skipping values that are not valid `&str`.
///
/// Repeated headers are joined rather than overwritten: `Cookie` with `"; "` because
/// HTTP/2 clients may split cookies across several header fields (RFC 9113 §8.2.3),
/// every other header with `", "` (RFC 9110 §5.3).
#[cfg(feature = "axum")]
fn collect_axum_headers(
    header_map: &axum::http::HeaderMap,
) -> std::collections::HashMap<String, String> {
    use std::collections::HashMap;
    let mut headers: HashMap<String, String> = HashMap::with_capacity(header_map.keys_len());
    for (name, value) in header_map {
        let Ok(value) = value.to_str() else {
            continue;
        };
        let separator = if *name == axum::http::header::COOKIE {
            "; "
        } else {
            ", "
        };
        headers
            .entry(name.as_str().to_owned())
            .and_modify(|joined| {
                joined.push_str(separator);
                joined.push_str(value);
            })
            .or_insert_with(|| value.to_owned());
    }
    headers
}

/// Decodes a form-urlencoded query string; for a repeated key the last value wins.
#[cfg(feature = "axum")]
fn parse_axum_query(raw_query: Option<&str>) -> std::collections::HashMap<String, String> {
    raw_query
        .map(|query| {
            url::form_urlencoded::parse(query.as_bytes())
                .into_owned()
                .collect::<std::collections::HashMap<String, String>>()
        })
        .unwrap_or_default()
}

/// Fails fast when a parseable `Content-Length` header already exceeds the limit.
#[cfg(feature = "axum")]
fn ensure_declared_length_within(
    headers: &axum::http::HeaderMap,
    max_body_bytes: usize,
) -> Result<(), AuthError> {
    let declared = headers
        .get(axum::http::header::CONTENT_LENGTH)
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.parse::<usize>().ok());
    match declared {
        Some(len) if len > max_body_bytes => Err(body_limit_error(max_body_bytes)),
        _ => Ok(()),
    }
}

/// Reads at most `max_body_bytes` of the body; an empty body becomes `None`.
#[cfg(feature = "axum")]
async fn read_axum_body(
    body: axum::body::Body,
    max_body_bytes: usize,
) -> Result<Option<Vec<u8>>, AuthError> {
    match axum::body::to_bytes(body, max_body_bytes).await {
        Ok(bytes) if bytes.is_empty() => Ok(None),
        Ok(bytes) => Ok(Some(bytes.to_vec())),
        Err(err) if better_auth_core::extractors::is_body_length_limit_error(&err) => {
            Err(body_limit_error(max_body_bytes))
        }
        Err(err) => {
            tracing::warn!(error = %err, "Failed to read request body");
            Err(AuthError::bad_request("Failed to read request body"))
        }
    }
}

/// The single payload-too-large error used for both the declared and the actual body size.
#[cfg(feature = "axum")]
fn body_limit_error(max_body_bytes: usize) -> AuthError {
    AuthError::payload_too_large(format!(
        "Request body exceeds the {}-byte limit",
        max_body_bytes
    ))
}
/// Converts a framework-agnostic [`AuthResponse`] into an axum [`Response`].
///
/// A status code that `StatusCode::from_u16` rejects is replaced by
/// `500 Internal Server Error`. A header whose name or value is not valid HTTP is
/// dropped. Both cases are logged at `warn` level. Header values are never logged,
/// because they may carry a `Set-Cookie` session token. Headers are appended, so
/// repeated names are all kept.
#[cfg(feature = "axum")]
fn convert_auth_response(auth_response: AuthResponse) -> Response {
    let status = StatusCode::from_u16(auth_response.status).unwrap_or_else(|_| {
        tracing::warn!(
            status = auth_response.status,
            "Auth response carried an invalid status code; sending 500 instead"
        );
        StatusCode::INTERNAL_SERVER_ERROR
    });

    // Building the response directly is infallible, so no fallback path is needed.
    let mut response = Response::new(axum::body::Body::from(auth_response.body));
    *response.status_mut() = status;

    let headers = response.headers_mut();
    for (name, value) in auth_response.headers {
        match (
            axum::http::HeaderName::from_bytes(name.as_bytes()),
            axum::http::HeaderValue::from_str(&value),
        ) {
            (Ok(header_name), Ok(header_value)) => {
                headers.append(header_name, header_value);
            }
            _ => tracing::warn!(
                header = %name,
                "Dropping auth response header with an invalid name or value"
            ),
        }
    }
    response
}
/// Converts an [`AuthError`] into a JSON `ErrorMessageResponse` with the error's status.
///
/// A status code that `StatusCode::from_u16` rejects becomes
/// `500 Internal Server Error`. For any server-error status (5xx), the client only
/// sees `"Internal server error"`; the real error is logged at `error` level. All
/// other errors are sent with their `Display` text.
#[cfg(feature = "axum")]
fn convert_auth_error(err: AuthError) -> Response {
    let status_code =
        StatusCode::from_u16(err.status_code()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    // Decide on the *resolved* status: this masks every 5xx, and also errors whose raw
    // code was out of range and fell back to 500 (previously their text leaked).
    let message = if status_code.is_server_error() {
        tracing::error!(
            error = %err,
            status = status_code.as_u16(),
            "Auth request failed with a server error"
        );
        "Internal server error".to_string()
    } else {
        err.to_string()
    };
    let body = ErrorMessageResponse { message };
    (status_code, axum::Json(body)).into_response()
}
/// The authenticated user and session for the current request.
///
/// Obtained as an axum extractor through its `FromRequestParts<Arc<BetterAuth<DB>>>`
/// implementation, which rejects the request when no valid session can be resolved.
#[cfg(feature = "axum")]
pub struct CurrentSession<DB>
where
    DB: DatabaseAdapter,
{
    /// The user loaded by the session's `user_id`.
    pub user: DB::User,
    /// The session matched by the request's bearer token or session cookie.
    pub session: DB::Session,
}
/// Like [`CurrentSession`], but never rejects: holds `None` whenever session resolution fails.
#[cfg(feature = "axum")]
pub struct OptionalSession<DB>(pub Option<CurrentSession<DB>>)
where
    DB: DatabaseAdapter;
/// Extracts the session token from the request head.
///
/// An `Authorization: Bearer <token>` header takes precedence. The scheme is matched
/// case-insensitively (RFC 9110 §11.1), and whitespace around the token is ignored.
/// An empty bearer token does not count, so the request then falls back to the cookie.
///
/// Otherwise every `Cookie` header is scanned, because HTTP/2 clients may split
/// cookies across several header fields. The first non-empty `cookie_name=value`
/// pair is used.
#[cfg(feature = "axum")]
fn extract_token_from_parts(parts: &Parts, cookie_name: &str) -> Option<String> {
    bearer_token_from_parts(parts)
        .or_else(|| cookie_token_from_parts(parts, cookie_name))
        .map(str::to_owned)
}

/// Returns the non-empty token of the first `Authorization` header if it uses the Bearer scheme.
#[cfg(feature = "axum")]
fn bearer_token_from_parts(parts: &Parts) -> Option<&str> {
    let header = parts
        .headers
        .get(axum::http::header::AUTHORIZATION)?
        .to_str()
        .ok()?;
    let (scheme, token) = header.split_once(' ')?;
    let token = token.trim();
    (scheme.eq_ignore_ascii_case("Bearer") && !token.is_empty()).then_some(token)
}

/// Returns the first non-empty value of cookie `cookie_name` across all `Cookie` headers.
#[cfg(feature = "axum")]
fn cookie_token_from_parts<'a>(parts: &'a Parts, cookie_name: &str) -> Option<&'a str> {
    parts
        .headers
        .get_all(axum::http::header::COOKIE)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|cookies| cookies.split(';'))
        .filter_map(|pair| pair.trim().split_once('='))
        .find_map(|(name, value)| (name == cookie_name && !value.is_empty()).then_some(value))
}
/// Resolves the request's token into a stored session and the user it belongs to.
///
/// Rejections are rendered through `convert_auth_error`:
/// - no token in the `Authorization` header or session cookie: `AuthError::Unauthenticated`
/// - `SessionManager::get_session` returns nothing: `AuthError::SessionNotFound`
/// - no user exists for the session's `user_id`: `AuthError::UserNotFound`
/// - errors returned by either lookup are passed through unchanged.
#[cfg(feature = "axum")]
impl<DB> FromRequestParts<Arc<BetterAuth<DB>>> for CurrentSession<DB>
where
    DB: DatabaseAdapter,
{
    type Rejection = Response;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &Arc<BetterAuth<DB>>,
    ) -> Result<Self, Self::Rejection> {
        let config = state.config();
        let Some(token) = extract_token_from_parts(parts, &config.session.cookie_name) else {
            return Err(convert_auth_error(AuthError::Unauthenticated));
        };

        // NOTE(review): this clones the whole config on every extraction; if `BetterAuth`
        // can hand out a shared `SessionManager` or an `Arc`'d config, reuse that instead
        // — TODO confirm against the `BetterAuth` API.
        let manager = SessionManager::new(Arc::new(config.clone()), state.database().clone());
        let session = match manager.get_session(&token).await {
            Ok(Some(found)) => found,
            Ok(None) => return Err(convert_auth_error(AuthError::SessionNotFound)),
            Err(lookup_err) => return Err(convert_auth_error(lookup_err)),
        };

        let user = match state.database().get_user_by_id(session.user_id()).await {
            Ok(Some(found)) => found,
            Ok(None) => return Err(convert_auth_error(AuthError::UserNotFound)),
            Err(lookup_err) => return Err(convert_auth_error(lookup_err)),
        };

        Ok(Self { user, session })
    }
}
/// Best-effort session extraction that never rejects the request.
///
/// Delegates to the [`CurrentSession`] extractor. Every failure — a missing token,
/// an unknown session or user, and also backend lookup errors — becomes
/// `OptionalSession(None)`.
#[cfg(feature = "axum")]
impl<DB> FromRequestParts<Arc<BetterAuth<DB>>> for OptionalSession<DB>
where
    DB: DatabaseAdapter,
{
    type Rejection = Response;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &Arc<BetterAuth<DB>>,
    ) -> Result<Self, Self::Rejection> {
        let resolved = <CurrentSession<DB> as FromRequestParts<Arc<BetterAuth<DB>>>>::from_request_parts(
            parts, state,
        )
        .await;
        Ok(Self(resolved.ok()))
    }
}