Skip to main content

authkestra_providers_github/
lib.rs

1use async_trait::async_trait;
2use authkestra_core::{AuthError, Identity, OAuthProvider, OAuthToken};
3use serde::Deserialize;
4use std::collections::HashMap;
5
6pub struct GithubProvider {
7    client_id: String,
8    client_secret: String,
9    redirect_uri: String,
10    http_client: reqwest::Client,
11    authorization_url: String,
12    token_url: String,
13    user_url: String,
14}
15
16impl GithubProvider {
17    pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
18        Self {
19            client_id,
20            client_secret,
21            redirect_uri,
22            http_client: reqwest::Client::new(),
23            authorization_url: "https://github.com/login/oauth/authorize".to_string(),
24            token_url: "https://github.com/login/oauth/access_token".to_string(),
25            user_url: "https://api.github.com/user".to_string(),
26        }
27    }
28
29    pub fn with_test_urls(
30        mut self,
31        authorization_url: String,
32        token_url: String,
33        user_url: String,
34    ) -> Self {
35        self.authorization_url = authorization_url;
36        self.token_url = token_url;
37        self.user_url = user_url;
38        self
39    }
40
41    pub fn with_authorization_url(mut self, authorization_url: String) -> Self {
42        self.authorization_url = authorization_url;
43        self
44    }
45}
46
47#[derive(Deserialize)]
48struct GithubAccessTokenResponse {
49    access_token: String,
50    #[serde(default = "default_token_type")]
51    token_type: String,
52    expires_in: Option<u64>,
53    refresh_token: Option<String>,
54    scope: Option<String>,
55    id_token: Option<String>,
56}
57
/// Serde fallback for `GithubAccessTokenResponse::token_type` when the field is missing.
fn default_token_type() -> String {
    String::from("Bearer")
}
61
62#[derive(Deserialize)]
63struct GithubUserResponse {
64    id: u64,
65    login: String,
66    email: Option<String>,
67}
68
69#[async_trait]
70impl OAuthProvider for GithubProvider {
71    fn provider_id(&self) -> &str {
72        "github"
73    }
74
75    fn get_authorization_url(
76        &self,
77        state: &str,
78        scopes: &[&str],
79        _code_challenge: Option<&str>,
80    ) -> String {
81        let scope_param = if scopes.is_empty() {
82            "user:email".to_string()
83        } else {
84            scopes.join(" ")
85        };
86
87        format!(
88            "{}?client_id={}&redirect_uri={}&state={}&scope={}",
89            self.authorization_url, self.client_id, self.redirect_uri, state, scope_param
90        )
91    }
92
93    async fn exchange_code_for_identity(
94        &self,
95        code: &str,
96        _code_verifier: Option<&str>,
97    ) -> Result<(Identity, OAuthToken), AuthError> {
98        // 1. Exchange code for access token
99        let token_response = self
100            .http_client
101            .post(&self.token_url)
102            .header("Accept", "application/json")
103            .form(&[
104                ("client_id", &self.client_id),
105                ("client_secret", &self.client_secret),
106                ("code", &code.to_string()),
107                ("redirect_uri", &self.redirect_uri),
108            ])
109            .send()
110            .await
111            .map_err(|_| AuthError::Network)?
112            .json::<GithubAccessTokenResponse>()
113            .await
114            .map_err(|e| AuthError::Provider(format!("Failed to parse token response: {}", e)))?;
115
116        // 2. Get user information
117        let user_response = self
118            .http_client
119            .get(&self.user_url)
120            .header(
121                "Authorization",
122                format!("Bearer {}", token_response.access_token),
123            )
124            .header("User-Agent", "authkestra-rs")
125            .send()
126            .await
127            .map_err(|_| AuthError::Network)?
128            .json::<GithubUserResponse>()
129            .await
130            .map_err(|e| AuthError::Provider(format!("Failed to parse user response: {}", e)))?;
131
132        // 3. Map to Identity
133        let identity = Identity {
134            provider_id: "github".to_string(),
135            external_id: user_response.id.to_string(),
136            email: user_response.email,
137            username: Some(user_response.login),
138            attributes: HashMap::new(),
139        };
140
141        let token = OAuthToken {
142            access_token: token_response.access_token,
143            token_type: token_response.token_type,
144            expires_in: token_response.expires_in,
145            refresh_token: token_response.refresh_token,
146            scope: token_response.scope,
147            id_token: token_response.id_token,
148        };
149
150        Ok((identity, token))
151    }
152
153    async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken, AuthError> {
154        let token_response = self
155            .http_client
156            .post(&self.token_url)
157            .header("Accept", "application/json")
158            .form(&[
159                ("client_id", &self.client_id),
160                ("client_secret", &self.client_secret),
161                ("grant_type", &"refresh_token".to_string()),
162                ("refresh_token", &refresh_token.to_string()),
163            ])
164            .send()
165            .await
166            .map_err(|_| AuthError::Network)?
167            .json::<GithubAccessTokenResponse>()
168            .await
169            .map_err(|e| {
170                AuthError::Provider(format!("Failed to parse refresh token response: {}", e))
171            })?;
172
173        Ok(OAuthToken {
174            access_token: token_response.access_token,
175            token_type: token_response.token_type,
176            expires_in: token_response.expires_in,
177            refresh_token: token_response
178                .refresh_token
179                .or_else(|| Some(refresh_token.to_string())),
180            scope: token_response.scope,
181            id_token: token_response.id_token,
182        })
183    }
184
185    async fn revoke_token(&self, token: &str) -> Result<(), AuthError> {
186        let response = self
187            .http_client
188            .delete(format!(
189                "https://api.github.com/applications/{}/token",
190                self.client_id
191            ))
192            .basic_auth(&self.client_id, Some(&self.client_secret))
193            .header("User-Agent", "authkestra-rs")
194            .json(&serde_json::json!({
195                "access_token": token
196            }))
197            .send()
198            .await
199            .map_err(|_| AuthError::Network)?;
200
201        if response.status().is_success() || response.status() == reqwest::StatusCode::NO_CONTENT {
202            Ok(())
203        } else {
204            let error_text = response
205                .text()
206                .await
207                .unwrap_or_else(|_| "Unknown error".to_string());
208            Err(AuthError::Provider(format!(
209                "Failed to revoke token: {}",
210                error_text
211            )))
212        }
213    }
214}