Skip to main content

allowthem_core/
oauth_github.rs

1use serde::Deserialize;
2
3use crate::auth_client::AuthFuture;
4use crate::error::AuthError;
5use crate::oauth::{OAuthProvider, OAuthUserInfo};
6
7/// GitHub OAuth2 provider.
8///
9/// Implements the authorization code flow with PKCE against GitHub's OAuth endpoints.
10/// Requires a GitHub OAuth App with client_id and client_secret.
11pub struct GitHubProvider {
12    client_id: String,
13    client_secret: String,
14    client: reqwest::Client,
15}
16
17impl GitHubProvider {
18    pub fn new(client_id: String, client_secret: String) -> Self {
19        let client = reqwest::Client::builder()
20            .user_agent("allowthem-oauth")
21            .build()
22            .expect("failed to build HTTP client");
23        Self {
24            client_id,
25            client_secret,
26            client,
27        }
28    }
29}
30
31#[derive(Deserialize)]
32struct TokenResponse {
33    access_token: Option<String>,
34    error: Option<String>,
35    error_description: Option<String>,
36}
37
38#[derive(Deserialize)]
39struct GitHubUser {
40    id: i64,
41    email: Option<String>,
42    name: Option<String>,
43}
44
45#[derive(Deserialize)]
46struct GitHubEmail {
47    email: String,
48    primary: bool,
49    verified: bool,
50}
51
52impl OAuthProvider for GitHubProvider {
53    fn name(&self) -> &str {
54        "github"
55    }
56
57    fn authorize_url(&self, redirect_uri: &str, state: &str, pkce_challenge: &str) -> String {
58        format!(
59            "https://github.com/login/oauth/authorize\
60             ?client_id={}\
61             &redirect_uri={}\
62             &state={}\
63             &scope=user:email\
64             &code_challenge={}\
65             &code_challenge_method=S256",
66            self.client_id, redirect_uri, state, pkce_challenge,
67        )
68    }
69
70    fn exchange_code<'a>(
71        &'a self,
72        code: &'a str,
73        redirect_uri: &'a str,
74        pkce_verifier: &'a str,
75    ) -> AuthFuture<'a, String> {
76        Box::pin(async move {
77            let resp = self
78                .client
79                .post("https://github.com/login/oauth/access_token")
80                .header("Accept", "application/json")
81                .form(&[
82                    ("client_id", self.client_id.as_str()),
83                    ("client_secret", self.client_secret.as_str()),
84                    ("code", code),
85                    ("redirect_uri", redirect_uri),
86                    ("code_verifier", pkce_verifier),
87                ])
88                .send()
89                .await
90                .map_err(|e| AuthError::OAuthHttp(e.to_string()))?;
91
92            let token_resp: TokenResponse = resp
93                .json()
94                .await
95                .map_err(|e| AuthError::OAuthTokenExchange(e.to_string()))?;
96
97            if let Some(err) = token_resp.error {
98                let desc = token_resp.error_description.unwrap_or_default();
99                return Err(AuthError::OAuthTokenExchange(format!("{err}: {desc}")));
100            }
101
102            token_resp
103                .access_token
104                .ok_or_else(|| AuthError::OAuthTokenExchange("missing access_token".into()))
105        })
106    }
107
108    fn user_info<'a>(&'a self, access_token: &'a str) -> AuthFuture<'a, OAuthUserInfo> {
109        Box::pin(async move {
110            let user: GitHubUser = self
111                .client
112                .get("https://api.github.com/user")
113                .bearer_auth(access_token)
114                .send()
115                .await
116                .map_err(|e| AuthError::OAuthHttp(e.to_string()))?
117                .json()
118                .await
119                .map_err(|e| AuthError::OAuthUserInfoFetch(e.to_string()))?;
120
121            // GitHub often returns email: null when the user's email is private.
122            // Fall back to the /user/emails endpoint.
123            let (email, email_verified) = if let Some(ref e) = user.email {
124                // Public email — GitHub only exposes it if the user chose to make it
125                // public, which means it is verified.
126                (e.clone(), true)
127            } else {
128                let emails: Vec<GitHubEmail> = self
129                    .client
130                    .get("https://api.github.com/user/emails")
131                    .bearer_auth(access_token)
132                    .send()
133                    .await
134                    .map_err(|e| AuthError::OAuthHttp(e.to_string()))?
135                    .json()
136                    .await
137                    .map_err(|e| AuthError::OAuthUserInfoFetch(e.to_string()))?;
138
139                let primary = emails
140                    .iter()
141                    .find(|e| e.primary && e.verified)
142                    .or_else(|| emails.iter().find(|e| e.verified));
143
144                match primary {
145                    Some(entry) => (entry.email.clone(), entry.verified),
146                    None => {
147                        return Err(AuthError::OAuthUserInfoFetch(
148                            "no verified email found on GitHub account".into(),
149                        ));
150                    }
151                }
152            };
153
154            Ok(OAuthUserInfo {
155                provider_user_id: user.id.to_string(),
156                email,
157                email_verified,
158                name: user.name,
159            })
160        })
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider with fixed dummy credentials. These tests only exercise URL building
    /// and never send a request.
    fn test_provider() -> GitHubProvider {
        GitHubProvider::new(String::from("test-client-id"), String::from("test-secret"))
    }

    #[test]
    fn name_returns_github() {
        assert_eq!(test_provider().name(), "github");
    }

    #[test]
    fn authorize_url_contains_required_params() {
        let url = test_provider().authorize_url(
            "https://example.com/oauth/github/callback",
            "test-state-value",
            "test-challenge-value",
        );

        assert!(url.starts_with("https://github.com/login/oauth/authorize"));

        let expected_fragments = [
            "client_id=test-client-id",
            "redirect_uri=https://example.com/oauth/github/callback",
            "state=test-state-value",
            "code_challenge=test-challenge-value",
            "code_challenge_method=S256",
            "scope=user:email",
        ];
        for fragment in expected_fragments {
            assert!(
                url.contains(fragment),
                "authorize URL {url:?} is missing {fragment:?}"
            );
        }
    }

    #[test]
    fn authorize_url_does_not_contain_secret() {
        let url = test_provider().authorize_url("https://example.com/cb", "state", "challenge");
        let leaks_secret = url.contains("test-secret");
        assert!(
            !leaks_secret,
            "authorize URL must never contain client_secret"
        );
    }
}