1use async_trait::async_trait;
2use authly_core::{AuthError, Identity, OAuthProvider, OAuthToken};
3use serde::Deserialize;
4use std::collections::HashMap;
5
6pub struct GithubProvider {
7 client_id: String,
8 client_secret: String,
9 redirect_uri: String,
10 http_client: reqwest::Client,
11 authorization_url: String,
12 token_url: String,
13 user_url: String,
14}
15
16impl GithubProvider {
17 pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
18 Self {
19 client_id,
20 client_secret,
21 redirect_uri,
22 http_client: reqwest::Client::new(),
23 authorization_url: "https://github.com/login/oauth/authorize".to_string(),
24 token_url: "https://github.com/login/oauth/access_token".to_string(),
25 user_url: "https://api.github.com/user".to_string(),
26 }
27 }
28
29 pub fn with_test_urls(
30 mut self,
31 authorization_url: String,
32 token_url: String,
33 user_url: String,
34 ) -> Self {
35 self.authorization_url = authorization_url;
36 self.token_url = token_url;
37 self.user_url = user_url;
38 self
39 }
40
41 pub fn with_authorization_url(mut self, authorization_url: String) -> Self {
42 self.authorization_url = authorization_url;
43 self
44 }
45}
46
/// JSON body returned by the token endpoint, decoded both when exchanging an
/// authorization code and when refreshing a token.
///
/// Field order is left untouched: the serde derive also accepts a positional
/// (sequence) encoding, in which reordering fields would change behavior.
#[derive(Deserialize)]
struct GithubAccessTokenResponse {
    access_token: String,
    /// Falls back to `default_token_type()` ("Bearer") when the key is absent.
    #[serde(default = "default_token_type")]
    token_type: String,
    // Optional fields: a missing key (or JSON null) deserializes to `None`.
    expires_in: Option<u64>,
    refresh_token: Option<String>,
    scope: Option<String>,
    id_token: Option<String>,
}
57
/// Serde default for `GithubAccessTokenResponse::token_type` when the token
/// response carries no `token_type` key.
fn default_token_type() -> String {
    String::from("Bearer")
}
61
/// The subset of the user endpoint's JSON (`user_url`) that this provider reads;
/// any other keys in the response are ignored by the derive.
#[derive(Deserialize)]
struct GithubUserResponse {
    /// Numeric account id; becomes `Identity::external_id`.
    id: u64,
    /// Account login; becomes `Identity::username`.
    login: String,
    /// Email as reported by the user endpoint; `None` when missing or null.
    email: Option<String>,
}
68
69#[async_trait]
70impl OAuthProvider for GithubProvider {
71 fn get_authorization_url(
72 &self,
73 state: &str,
74 scopes: &[&str],
75 _code_challenge: Option<&str>,
76 ) -> String {
77 let scope_param = if scopes.is_empty() {
78 "user:email".to_string()
79 } else {
80 scopes.join(" ")
81 };
82
83 format!(
84 "{}?client_id={}&redirect_uri={}&state={}&scope={}",
85 self.authorization_url, self.client_id, self.redirect_uri, state, scope_param
86 )
87 }
88
89 async fn exchange_code_for_identity(
90 &self,
91 code: &str,
92 _code_verifier: Option<&str>,
93 ) -> Result<(Identity, OAuthToken), AuthError> {
94 let token_response = self
96 .http_client
97 .post(&self.token_url)
98 .header("Accept", "application/json")
99 .form(&[
100 ("client_id", &self.client_id),
101 ("client_secret", &self.client_secret),
102 ("code", &code.to_string()),
103 ("redirect_uri", &self.redirect_uri),
104 ])
105 .send()
106 .await
107 .map_err(|_| AuthError::Network)?
108 .json::<GithubAccessTokenResponse>()
109 .await
110 .map_err(|e| AuthError::Provider(format!("Failed to parse token response: {}", e)))?;
111
112 let user_response = self
114 .http_client
115 .get(&self.user_url)
116 .header(
117 "Authorization",
118 format!("Bearer {}", token_response.access_token),
119 )
120 .header("User-Agent", "authly-rs")
121 .send()
122 .await
123 .map_err(|_| AuthError::Network)?
124 .json::<GithubUserResponse>()
125 .await
126 .map_err(|e| AuthError::Provider(format!("Failed to parse user response: {}", e)))?;
127
128 let identity = Identity {
130 provider_id: "github".to_string(),
131 external_id: user_response.id.to_string(),
132 email: user_response.email,
133 username: Some(user_response.login),
134 attributes: HashMap::new(),
135 };
136
137 let token = OAuthToken {
138 access_token: token_response.access_token,
139 token_type: token_response.token_type,
140 expires_in: token_response.expires_in,
141 refresh_token: token_response.refresh_token,
142 scope: token_response.scope,
143 id_token: token_response.id_token,
144 };
145
146 Ok((identity, token))
147 }
148
149 async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken, AuthError> {
150 let token_response = self
151 .http_client
152 .post(&self.token_url)
153 .header("Accept", "application/json")
154 .form(&[
155 ("client_id", &self.client_id),
156 ("client_secret", &self.client_secret),
157 ("grant_type", &"refresh_token".to_string()),
158 ("refresh_token", &refresh_token.to_string()),
159 ])
160 .send()
161 .await
162 .map_err(|_| AuthError::Network)?
163 .json::<GithubAccessTokenResponse>()
164 .await
165 .map_err(|e| {
166 AuthError::Provider(format!("Failed to parse refresh token response: {}", e))
167 })?;
168
169 Ok(OAuthToken {
170 access_token: token_response.access_token,
171 token_type: token_response.token_type,
172 expires_in: token_response.expires_in,
173 refresh_token: token_response
174 .refresh_token
175 .or_else(|| Some(refresh_token.to_string())),
176 scope: token_response.scope,
177 id_token: token_response.id_token,
178 })
179 }
180
181 async fn revoke_token(&self, token: &str) -> Result<(), AuthError> {
182 let response = self
183 .http_client
184 .delete(format!(
185 "https://api.github.com/applications/{}/token",
186 self.client_id
187 ))
188 .basic_auth(&self.client_id, Some(&self.client_secret))
189 .header("User-Agent", "authly-rs")
190 .json(&serde_json::json!({
191 "access_token": token
192 }))
193 .send()
194 .await
195 .map_err(|_| AuthError::Network)?;
196
197 if response.status().is_success() || response.status() == reqwest::StatusCode::NO_CONTENT {
198 Ok(())
199 } else {
200 let error_text = response
201 .text()
202 .await
203 .unwrap_or_else(|_| "Unknown error".to_string());
204 Err(AuthError::Provider(format!(
205 "Failed to revoke token: {}",
206 error_text
207 )))
208 }
209 }
210}