Skip to main content

authkestra_providers_github/
lib.rs

1use async_trait::async_trait;
2use authkestra_core::{error::AuthError, state::Identity, state::OAuthToken, OAuthProvider};
3use serde::Deserialize;
4use std::collections::HashMap;
5
6pub struct GithubProvider {
7    client_id: String,
8    client_secret: String,
9    redirect_uri: String,
10    http_client: reqwest::Client,
11    authorization_url: String,
12    token_url: String,
13    user_url: String,
14}
15
16impl GithubProvider {
17    pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
18        Self {
19            client_id,
20            client_secret,
21            redirect_uri,
22            http_client: reqwest::Client::new(),
23            authorization_url: "https://github.com/login/oauth/authorize".to_string(),
24            token_url: "https://github.com/login/oauth/access_token".to_string(),
25            user_url: "https://api.github.com/user".to_string(),
26        }
27    }
28
29    pub fn with_test_urls(
30        mut self,
31        authorization_url: String,
32        token_url: String,
33        user_url: String,
34    ) -> Self {
35        self.authorization_url = authorization_url;
36        self.token_url = token_url;
37        self.user_url = user_url;
38        self
39    }
40
41    pub fn with_authorization_url(mut self, authorization_url: String) -> Self {
42        self.authorization_url = authorization_url;
43        self
44    }
45}
46
47#[derive(Deserialize)]
48struct GithubAccessTokenResponse {
49    access_token: String,
50    token_type: String,
51    expires_in: Option<u64>,
52    refresh_token: Option<String>,
53    scope: Option<String>,
54    id_token: Option<String>,
55}
56
57#[derive(Deserialize)]
58struct GithubUserResponse {
59    id: u64,
60    login: String,
61    email: Option<String>,
62}
63
64#[async_trait]
65impl OAuthProvider for GithubProvider {
66    fn provider_id(&self) -> &str {
67        "github"
68    }
69
70    fn get_authorization_url(
71        &self,
72        state: &str,
73        scopes: &[&str],
74        _code_challenge: Option<&str>,
75    ) -> String {
76        let scope_param = if scopes.is_empty() {
77            "user:email".to_string()
78        } else {
79            scopes.join(" ")
80        };
81
82        format!(
83            "{}?client_id={}&redirect_uri={}&state={}&scope={}",
84            self.authorization_url, self.client_id, self.redirect_uri, state, scope_param
85        )
86    }
87
88    async fn exchange_code_for_identity(
89        &self,
90        code: &str,
91        _code_verifier: Option<&str>,
92    ) -> Result<(Identity, OAuthToken), AuthError> {
93        // 1. Exchange code for access token
94        let token_response = self
95            .http_client
96            .post(&self.token_url)
97            .header("Accept", "application/json")
98            .form(&[
99                ("client_id", &self.client_id),
100                ("client_secret", &self.client_secret),
101                ("code", &code.to_string()),
102                ("redirect_uri", &self.redirect_uri),
103            ])
104            .send()
105            .await
106            .map_err(|_| AuthError::Network)?
107            .json::<GithubAccessTokenResponse>()
108            .await
109            .map_err(|e| AuthError::Provider(format!("Failed to parse token response: {}", e)))?;
110
111        // 2. Get user information
112        let user_response = self
113            .http_client
114            .get(&self.user_url)
115            .header(
116                "Authorization",
117                format!("Bearer {}", token_response.access_token),
118            )
119            .header("User-Agent", "authkestra")
120            .send()
121            .await
122            .map_err(|_| AuthError::Network)?
123            .json::<GithubUserResponse>()
124            .await
125            .map_err(|e| AuthError::Provider(format!("Failed to parse user response: {}", e)))?;
126
127        // 3. Map to Identity
128        let identity = Identity {
129            provider_id: "github".to_string(),
130            external_id: user_response.id.to_string(),
131            email: user_response.email,
132            username: Some(user_response.login),
133            attributes: HashMap::new(),
134        };
135
136        let token = OAuthToken {
137            access_token: token_response.access_token,
138            token_type: token_response.token_type,
139            expires_in: token_response.expires_in,
140            refresh_token: token_response.refresh_token,
141            scope: token_response.scope,
142            id_token: token_response.id_token,
143        };
144
145        Ok((identity, token))
146    }
147
148    async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken, AuthError> {
149        let token_response = self
150            .http_client
151            .post(&self.token_url)
152            .header("Accept", "application/json")
153            .form(&[
154                ("client_id", &self.client_id),
155                ("client_secret", &self.client_secret),
156                ("grant_type", &"refresh_token".to_string()),
157                ("refresh_token", &refresh_token.to_string()),
158            ])
159            .send()
160            .await
161            .map_err(|_| AuthError::Network)?
162            .json::<GithubAccessTokenResponse>()
163            .await
164            .map_err(|e| {
165                AuthError::Provider(format!("Failed to parse refresh token response: {}", e))
166            })?;
167
168        Ok(OAuthToken {
169            access_token: token_response.access_token,
170            token_type: token_response.token_type,
171            expires_in: token_response.expires_in,
172            refresh_token: token_response
173                .refresh_token
174                .or_else(|| Some(refresh_token.to_string())),
175            scope: token_response.scope,
176            id_token: token_response.id_token,
177        })
178    }
179
180    async fn revoke_token(&self, token: &str) -> Result<(), AuthError> {
181        let response = self
182            .http_client
183            .delete(format!(
184                "https://api.github.com/applications/{}/token",
185                self.client_id
186            ))
187            .basic_auth(&self.client_id, Some(&self.client_secret))
188            .header("User-Agent", "authkestra")
189            .json(&serde_json::json!({
190                "access_token": token
191            }))
192            .send()
193            .await
194            .map_err(|_| AuthError::Network)?;
195
196        if response.status().is_success() || response.status() == reqwest::StatusCode::NO_CONTENT {
197            Ok(())
198        } else {
199            let error_text = response
200                .text()
201                .await
202                .unwrap_or_else(|_| "Unknown error".to_string());
203            Err(AuthError::Provider(format!(
204                "Failed to revoke token: {}",
205                error_text
206            )))
207        }
208    }
209}