Skip to main content

better_auth_api/plugins/oauth/
mod.rs

1use async_trait::async_trait;
2
3use better_auth_core::AuthResult;
4use better_auth_core::adapters::DatabaseAdapter;
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
6use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
7
8#[cfg(feature = "axum")]
9use better_auth_core::plugin::{AuthState, AxumPlugin};
10
11pub mod encryption;
12mod handlers;
13mod providers;
14#[cfg(test)]
15mod tests;
16mod types;
17
18pub use providers::{OAuthConfig, OAuthProvider, OAuthStateStrategy, OAuthUserInfo};
19
20pub struct OAuthPlugin {
21    config: OAuthConfig,
22}
23
24impl OAuthPlugin {
25    pub fn new() -> Self {
26        Self {
27            config: OAuthConfig::default(),
28        }
29    }
30
31    pub fn with_config(config: OAuthConfig) -> Self {
32        Self { config }
33    }
34
35    pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
36        self.config.providers.insert(name.to_string(), provider);
37        self
38    }
39}
40
41impl Default for OAuthPlugin {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47#[async_trait]
48impl<DB: DatabaseAdapter> AuthPlugin<DB> for OAuthPlugin {
49    fn name(&self) -> &'static str {
50        "oauth"
51    }
52
53    fn routes(&self) -> Vec<AuthRoute> {
54        vec![
55            AuthRoute::post("/sign-in/social", "social_sign_in"),
56            AuthRoute::get("/callback/{provider}", "oauth_callback"),
57            AuthRoute::post("/link-social", "link_social"),
58            AuthRoute::post("/get-access-token", "get_access_token"),
59            AuthRoute::post("/refresh-token", "refresh_token"),
60        ]
61    }
62
63    async fn on_request(
64        &self,
65        req: &AuthRequest,
66        ctx: &AuthContext<DB>,
67    ) -> AuthResult<Option<AuthResponse>> {
68        match (req.method(), req.path()) {
69            (HttpMethod::Post, "/sign-in/social") => Ok(Some(
70                handlers::handle_social_sign_in(&self.config, req, ctx).await?,
71            )),
72            (HttpMethod::Get, path) if path_matches_callback(path) => {
73                let provider = extract_provider_from_callback(path);
74                Ok(Some(
75                    handlers::handle_callback(&self.config, &provider, req, ctx).await?,
76                ))
77            }
78            (HttpMethod::Post, "/link-social") => Ok(Some(
79                handlers::handle_link_social(&self.config, req, ctx).await?,
80            )),
81            (HttpMethod::Post, "/get-access-token") => Ok(Some(
82                handlers::handle_get_access_token(&self.config, req, ctx).await?,
83            )),
84            (HttpMethod::Post, "/refresh-token") => Ok(Some(
85                handlers::handle_refresh_token(&self.config, req, ctx).await?,
86            )),
87            _ => Ok(None),
88        }
89    }
90}
91
/// Return the provider segment of a `/callback/{provider}` path, if `path` is one.
///
/// Any query string (`?...`) is ignored. The provider must be a single, non-empty path
/// segment, mirroring the `/callback/{provider}` route pattern, so `/callback/`,
/// `/callback/google/` and `/callback/google/extra` are all rejected.
fn callback_provider(path: &str) -> Option<&str> {
    let path_without_query = path
        .split_once('?')
        .map_or(path, |(before_query, _)| before_query);
    path_without_query
        .strip_prefix("/callback/")
        .filter(|provider| !provider.is_empty() && !provider.contains('/'))
}

/// Check if the path matches `/callback/{provider}` (with optional query string).
///
/// `{provider}` must be exactly one non-empty path segment.
fn path_matches_callback(path: &str) -> bool {
    callback_provider(path).is_some()
}

/// Extract the provider name from `/callback/{provider}?...`.
///
/// Callers are expected to check [`path_matches_callback`] first. For a non-matching path
/// this returns an empty string instead of panicking.
fn extract_provider_from_callback(path: &str) -> String {
    callback_provider(path).unwrap_or_default().to_string()
}
103
104// ---------------------------------------------------------------------------
105// Axum-native routing (feature-gated)
106// ---------------------------------------------------------------------------
107
#[cfg(feature = "axum")]
mod axum_impl {
    //! Axum handlers for the OAuth plugin.
    //!
    //! Each handler unpacks its extractors, builds an `AuthContext` from the shared
    //! `AuthState`, and delegates to the matching `*_core` function in `handlers`.

    use super::*;
    use std::sync::Arc;

    use axum::Json;
    use axum::extract::{Extension, Path, Query, State};
    use axum::http::header;
    use axum::response::IntoResponse;
    use better_auth_core::error::AuthError;
    use better_auth_core::extractors::{CurrentSession, ValidatedJson};

    use super::handlers::{
        callback_core, get_access_token_core, link_social_core, refresh_token_core,
        social_sign_in_core,
    };
    use super::types::{
        AccessTokenResponse, GetAccessTokenRequest, LinkSocialRequest, RefreshTokenRequest,
        RefreshTokenResponse, SocialSignInRequest, SocialSignInResponse,
    };

    /// Query parameters read on the provider redirect to `/callback/:provider`.
    ///
    /// Both fields are required, so a request missing either one is rejected by the
    /// `Query` extractor before the handler body runs.
    #[derive(serde::Deserialize)]
    struct CallbackQuery {
        code: String,
        state: String,
    }

    /// Router-level extension holding a copy of the plugin's OAuth configuration.
    #[derive(Clone)]
    struct PluginState {
        config: OAuthConfig,
    }

    /// `POST /sign-in/social`: delegates to `social_sign_in_core`.
    async fn handle_social_sign_in<DB: DatabaseAdapter>(
        State(auth): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        ValidatedJson(request): ValidatedJson<SocialSignInRequest>,
    ) -> Result<Json<SocialSignInResponse>, AuthError> {
        let ctx = auth.to_context();
        Ok(Json(
            social_sign_in_core(&request, &plugin.config, &ctx).await?,
        ))
    }

    /// `GET /callback/:provider`: completes the OAuth flow via `callback_core`.
    ///
    /// Responds with the JSON body from `callback_core` and a `Set-Cookie` header built
    /// from the returned token by `AuthState::session_cookie`.
    async fn handle_callback<DB: DatabaseAdapter>(
        State(auth): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        Path(provider): Path<String>,
        Query(query): Query<CallbackQuery>,
    ) -> Result<impl IntoResponse, AuthError> {
        let ctx = auth.to_context();
        let (body, session_token) =
            callback_core(&query.code, &query.state, &provider, &plugin.config, &ctx).await?;
        let set_cookie = (header::SET_COOKIE, auth.session_cookie(&session_token));
        Ok(([set_cookie], Json(body)))
    }

    /// `POST /link-social`: requires a current session and delegates to `link_social_core`.
    async fn handle_link_social<DB: DatabaseAdapter>(
        State(auth): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(request): ValidatedJson<LinkSocialRequest>,
    ) -> Result<Json<SocialSignInResponse>, AuthError> {
        let ctx = auth.to_context();
        Ok(Json(
            link_social_core(&request, &session, &plugin.config, &ctx).await?,
        ))
    }

    /// `POST /get-access-token`: requires a current session and delegates to
    /// `get_access_token_core`. This is the only handler that does not read the plugin
    /// configuration.
    async fn handle_get_access_token<DB: DatabaseAdapter>(
        State(auth): State<AuthState<DB>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(request): ValidatedJson<GetAccessTokenRequest>,
    ) -> Result<Json<AccessTokenResponse>, AuthError> {
        let ctx = auth.to_context();
        Ok(Json(get_access_token_core(&request, &session, &ctx).await?))
    }

    /// `POST /refresh-token`: requires a current session and delegates to
    /// `refresh_token_core`.
    async fn handle_refresh_token<DB: DatabaseAdapter>(
        State(auth): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(request): ValidatedJson<RefreshTokenRequest>,
    ) -> Result<Json<RefreshTokenResponse>, AuthError> {
        let ctx = auth.to_context();
        Ok(Json(
            refresh_token_core(&request, &session, &plugin.config, &ctx).await?,
        ))
    }

    #[async_trait]
    impl<DB: DatabaseAdapter> AxumPlugin<DB> for OAuthPlugin {
        fn name(&self) -> &'static str {
            "oauth"
        }

        /// Build the plugin's axum router.
        ///
        /// A clone of the plugin configuration is shared with every handler through an
        /// `Extension<Arc<PluginState>>` layer.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            use axum::routing::{get, post};

            let shared_config = Extension(Arc::new(PluginState {
                config: self.config.clone(),
            }));

            axum::Router::new()
                .route("/sign-in/social", post(handle_social_sign_in::<DB>))
                // NOTE(review): `:provider` is axum 0.7 capture syntax; axum 0.8 requires
                // `{provider}` — confirm against the pinned axum version.
                .route("/callback/:provider", get(handle_callback::<DB>))
                .route("/link-social", post(handle_link_social::<DB>))
                .route("/get-access-token", post(handle_get_access_token::<DB>))
                .route("/refresh-token", post(handle_refresh_token::<DB>))
                .layer(shared_config)
        }
    }
}