Skip to main content

better_auth_api/plugins/oauth/
mod.rs

1use async_trait::async_trait;
2
3use better_auth_core::AuthResult;
4use better_auth_core::adapters::DatabaseAdapter;
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
6use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
7
8#[cfg(feature = "axum")]
9use better_auth_core::plugin::{AuthState, AxumPlugin};
10
11pub mod encryption;
12mod handlers;
13mod providers;
14mod types;
15
16pub use providers::{OAuthConfig, OAuthProvider, OAuthStateStrategy, OAuthUserInfo};
17
18pub struct OAuthPlugin {
19    config: OAuthConfig,
20}
21
22impl OAuthPlugin {
23    pub fn new() -> Self {
24        Self {
25            config: OAuthConfig::default(),
26        }
27    }
28
29    pub fn with_config(config: OAuthConfig) -> Self {
30        Self { config }
31    }
32
33    pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
34        self.config.providers.insert(name.to_string(), provider);
35        self
36    }
37}
38
39impl Default for OAuthPlugin {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45#[async_trait]
46impl<DB: DatabaseAdapter> AuthPlugin<DB> for OAuthPlugin {
47    fn name(&self) -> &'static str {
48        "oauth"
49    }
50
51    fn routes(&self) -> Vec<AuthRoute> {
52        vec![
53            AuthRoute::post("/sign-in/social", "social_sign_in"),
54            AuthRoute::get("/callback/{provider}", "oauth_callback"),
55            AuthRoute::post("/link-social", "link_social"),
56            AuthRoute::post("/get-access-token", "get_access_token"),
57            AuthRoute::post("/refresh-token", "refresh_token"),
58        ]
59    }
60
61    async fn on_request(
62        &self,
63        req: &AuthRequest,
64        ctx: &AuthContext<DB>,
65    ) -> AuthResult<Option<AuthResponse>> {
66        match (req.method(), req.path()) {
67            (HttpMethod::Post, "/sign-in/social") => Ok(Some(
68                handlers::handle_social_sign_in(&self.config, req, ctx).await?,
69            )),
70            (HttpMethod::Get, path) if path_matches_callback(path) => {
71                let provider = extract_provider_from_callback(path);
72                Ok(Some(
73                    handlers::handle_callback(&self.config, &provider, req, ctx).await?,
74                ))
75            }
76            (HttpMethod::Post, "/link-social") => Ok(Some(
77                handlers::handle_link_social(&self.config, req, ctx).await?,
78            )),
79            (HttpMethod::Post, "/get-access-token") => Ok(Some(
80                handlers::handle_get_access_token(&self.config, req, ctx).await?,
81            )),
82            (HttpMethod::Post, "/refresh-token") => Ok(Some(
83                handlers::handle_refresh_token(&self.config, req, ctx).await?,
84            )),
85            _ => Ok(None),
86        }
87    }
88}
89
/// Returns `true` when `path` addresses the OAuth callback endpoint
/// `/callback/{provider}`, ignoring any query string.
///
/// The provider must be exactly one non-empty path segment, mirroring the
/// declared `/callback/{provider}` route and the Axum `/callback/:provider`
/// route. Paths such as `/callback/`, `/callback/google/` or
/// `/callback/google/extra` therefore do not match.
fn path_matches_callback(path: &str) -> bool {
    // `split_once` yields `None` when there is no query string; keep the whole path then.
    let route = path.split_once('?').map_or(path, |(route, _query)| route);
    matches!(
        route.strip_prefix("/callback/"),
        Some(provider) if !provider.is_empty() && !provider.contains('/')
    )
}
95
/// Extracts the provider name from `/callback/{provider}?...`.
///
/// Returns everything between the `/callback/` prefix and the first `?`.
/// Meant to be called only on paths accepted by `path_matches_callback`.
/// For any other input it returns an empty string rather than panicking
/// (short paths) or returning an unrelated slice (paths without the prefix).
fn extract_provider_from_callback(path: &str) -> String {
    // `split_once` yields `None` when there is no query string; keep the whole path then.
    let route = path.split_once('?').map_or(path, |(route, _query)| route);
    route
        .strip_prefix("/callback/")
        .unwrap_or_default()
        .to_owned()
}
101
102// ---------------------------------------------------------------------------
103// Axum-native routing (feature-gated)
104// ---------------------------------------------------------------------------
105
/// Axum-native routing for the OAuth plugin.
///
/// Each handler pulls its inputs out with Axum extractors and delegates to the
/// matching `*_core` function in `handlers`.
#[cfg(feature = "axum")]
mod axum_impl {
    use super::*;
    use std::sync::Arc;

    use axum::Json;
    use axum::extract::{Extension, Path, Query, State};
    use axum::http::header;
    use axum::response::IntoResponse;
    use better_auth_core::error::AuthError;
    use better_auth_core::extractors::{CurrentSession, ValidatedJson};

    use super::handlers::{
        callback_core, get_access_token_core, link_social_core, refresh_token_core,
        social_sign_in_core,
    };
    use super::types::{
        AccessTokenResponse, GetAccessTokenRequest, LinkSocialRequest, RefreshTokenRequest,
        RefreshTokenResponse, SocialSignInRequest, SocialSignInResponse,
    };

    /// Query parameters of the provider redirect hitting `/callback/{provider}`.
    ///
    /// Both fields are required, so a callback missing either one is rejected
    /// by the `Query` extractor before the handler body runs.
    #[derive(serde::Deserialize)]
    struct CallbackQuery {
        code: String,
        state: String,
    }

    /// Plugin-local state handed to the handlers through an `Extension` layer.
    #[derive(Clone)]
    struct PluginState {
        config: OAuthConfig,
    }

    /// `POST /sign-in/social`: delegates to `social_sign_in_core`.
    async fn handle_social_sign_in<DB: DatabaseAdapter>(
        State(auth): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        ValidatedJson(payload): ValidatedJson<SocialSignInRequest>,
    ) -> Result<Json<SocialSignInResponse>, AuthError> {
        let ctx = auth.to_context();
        Ok(Json(
            social_sign_in_core(&payload, &plugin.config, &ctx).await?,
        ))
    }

    /// `GET /callback/{provider}`: completes the OAuth flow via
    /// `callback_core`, then returns its JSON payload together with a
    /// `Set-Cookie` header built from the returned session token.
    async fn handle_callback<DB: DatabaseAdapter>(
        State(auth): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        Path(provider_name): Path<String>,
        Query(callback): Query<CallbackQuery>,
    ) -> Result<impl IntoResponse, AuthError> {
        let ctx = auth.to_context();
        let CallbackQuery {
            code,
            state: oauth_state,
        } = callback;

        let (payload, session_token) =
            callback_core(&code, &oauth_state, &provider_name, &plugin.config, &ctx).await?;

        let set_cookie = (header::SET_COOKIE, auth.session_cookie(&session_token));
        Ok(([set_cookie], Json(payload)))
    }

    /// `POST /link-social`: links a provider account to the session obtained
    /// from the `CurrentSession` extractor, via `link_social_core`.
    async fn handle_link_social<DB: DatabaseAdapter>(
        State(auth): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(payload): ValidatedJson<LinkSocialRequest>,
    ) -> Result<Json<SocialSignInResponse>, AuthError> {
        let ctx = auth.to_context();
        Ok(Json(
            link_social_core(&payload, &session, &plugin.config, &ctx).await?,
        ))
    }

    /// `POST /get-access-token`: delegates to `get_access_token_core` for the
    /// session from the `CurrentSession` extractor (no plugin config needed).
    async fn handle_get_access_token<DB: DatabaseAdapter>(
        State(auth): State<AuthState<DB>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(payload): ValidatedJson<GetAccessTokenRequest>,
    ) -> Result<Json<AccessTokenResponse>, AuthError> {
        let ctx = auth.to_context();
        Ok(Json(get_access_token_core(&payload, &session, &ctx).await?))
    }

    /// `POST /refresh-token`: delegates to `refresh_token_core` for the
    /// session from the `CurrentSession` extractor.
    async fn handle_refresh_token<DB: DatabaseAdapter>(
        State(auth): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(payload): ValidatedJson<RefreshTokenRequest>,
    ) -> Result<Json<RefreshTokenResponse>, AuthError> {
        let ctx = auth.to_context();
        Ok(Json(
            refresh_token_core(&payload, &session, &plugin.config, &ctx).await?,
        ))
    }

    #[async_trait]
    impl<DB: DatabaseAdapter> AxumPlugin<DB> for OAuthPlugin {
        fn name(&self) -> &'static str {
            "oauth"
        }

        /// Builds the Axum router for the OAuth endpoints.
        ///
        /// The plugin config is cloned once into an `Arc<PluginState>` that is
        /// attached to every route through an `Extension` layer.
        // NOTE(review): `:provider` is the capture syntax of axum 0.7 and
        // earlier; axum 0.8 expects `{provider}` (as declared in
        // `AuthPlugin::routes`). Confirm against the axum version in use.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            use axum::routing::{get, post};

            let shared = Arc::new(PluginState {
                config: self.config.clone(),
            });

            let routes = axum::Router::new()
                .route("/sign-in/social", post(handle_social_sign_in::<DB>))
                .route("/callback/:provider", get(handle_callback::<DB>))
                .route("/link-social", post(handle_link_social::<DB>))
                .route("/get-access-token", post(handle_get_access_token::<DB>))
                .route("/refresh-token", post(handle_refresh_token::<DB>));

            routes.layer(Extension(shared))
        }
    }
}