Skip to main content

better_auth_api/plugins/oauth/
providers.rs

1use serde_json::Value;
2use std::collections::HashMap;
3
/// Selects where the OAuth `state` value is kept while an authorization
/// flow is in progress.
///
/// [`OAuthStateStrategy::Cookie`] is the [`Default`] variant.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum OAuthStateStrategy {
    /// Stateless mode: the state is kept in an encrypted cookie.
    /// This is the default.
    #[default]
    Cookie,
    /// Stateful mode: the state is kept in the verification table.
    Database,
}
13
14/// Configuration for the OAuth plugin, containing all registered providers.
15#[derive(Debug, Clone, Default)]
16pub struct OAuthConfig {
17    pub providers: HashMap<String, OAuthProvider>,
18    /// Skip state cookie verification (default: false) - SECURITY WARNING
19    pub skip_state_cookie_check: bool,
20    /// Where to store OAuth state: Cookie (stateless) or Database (default: Cookie)
21    pub store_state_strategy: OAuthStateStrategy,
22}
23
/// Normalized profile data produced from a provider's user-info response.
///
/// Each provider supplies a mapping function that converts its own JSON
/// payload into this common shape.
#[derive(Clone, Debug)]
pub struct OAuthUserInfo {
    /// Identifier of the user as reported by the provider.
    pub id: String,
    /// Email address reported by the provider.
    pub email: String,
    /// Display name, when the provider returned one.
    pub name: Option<String>,
    /// URL of the profile picture, when the provider returned one.
    pub image: Option<String>,
    /// Whether the email address is treated as verified.
    pub email_verified: bool,
}
33
34/// Configuration for a single OAuth provider.
35#[derive(Debug, Clone)]
36pub struct OAuthProvider {
37    pub client_id: String,
38    pub client_secret: String,
39    pub auth_url: String,
40    pub token_url: String,
41    pub user_info_url: String,
42    pub scopes: Vec<String>,
43    pub map_user_info: fn(Value) -> Result<OAuthUserInfo, String>,
44}
45
46impl OAuthProvider {
47    pub fn google(client_id: &str, client_secret: &str) -> Self {
48        Self {
49            client_id: client_id.to_string(),
50            client_secret: client_secret.to_string(),
51            auth_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
52            token_url: "https://oauth2.googleapis.com/token".to_string(),
53            user_info_url: "https://www.googleapis.com/oauth2/v3/userinfo".to_string(),
54            scopes: vec![
55                "openid".to_string(),
56                "email".to_string(),
57                "profile".to_string(),
58            ],
59            map_user_info: |v| {
60                Ok(OAuthUserInfo {
61                    id: v["sub"].as_str().ok_or("missing sub")?.to_string(),
62                    email: v["email"].as_str().ok_or("missing email")?.to_string(),
63                    name: v["name"].as_str().map(String::from),
64                    image: v["picture"].as_str().map(String::from),
65                    email_verified: v["email_verified"].as_bool().unwrap_or(false),
66                })
67            },
68        }
69    }
70
71    pub fn github(client_id: &str, client_secret: &str) -> Self {
72        Self {
73            client_id: client_id.to_string(),
74            client_secret: client_secret.to_string(),
75            auth_url: "https://github.com/login/oauth/authorize".to_string(),
76            token_url: "https://github.com/login/oauth/access_token".to_string(),
77            user_info_url: "https://api.github.com/user".to_string(),
78            scopes: vec!["user:email".to_string()],
79            map_user_info: |v| {
80                Ok(OAuthUserInfo {
81                    id: v["id"]
82                        .as_i64()
83                        .map(|i| i.to_string())
84                        .or_else(|| v["id"].as_str().map(String::from))
85                        .ok_or("missing id")?,
86                    email: v["email"].as_str().ok_or("missing email")?.to_string(),
87                    name: v["name"].as_str().map(String::from),
88                    image: v["avatar_url"].as_str().map(String::from),
89                    email_verified: true,
90                })
91            },
92        }
93    }
94
95    pub fn discord(client_id: &str, client_secret: &str) -> Self {
96        Self {
97            client_id: client_id.to_string(),
98            client_secret: client_secret.to_string(),
99            auth_url: "https://discord.com/api/oauth2/authorize".to_string(),
100            token_url: "https://discord.com/api/oauth2/token".to_string(),
101            user_info_url: "https://discord.com/api/users/@me".to_string(),
102            scopes: vec!["identify".to_string(), "email".to_string()],
103            map_user_info: |v| {
104                Ok(OAuthUserInfo {
105                    id: v["id"].as_str().ok_or("missing id")?.to_string(),
106                    email: v["email"].as_str().ok_or("missing email")?.to_string(),
107                    name: v["username"].as_str().map(String::from),
108                    image: v["avatar"].as_str().map(|a| {
109                        format!(
110                            "https://cdn.discordapp.com/avatars/{}/{}.png",
111                            v["id"].as_str().unwrap_or(""),
112                            a
113                        )
114                    }),
115                    email_verified: v["verified"].as_bool().unwrap_or(false),
116                })
117            },
118        }
119    }
120}