1use crate::{
2 errors::OAuthError, utils, AuthInfo, ExpirationInfo, OAuthConfig, OAuthProvider, RedirectInfo,
3 Tokens, ValidationResult,
4};
5use async_trait::async_trait;
6use oauth2::{
7 basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
8 ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
9};
10use reqwest::Client;
11use serde::Deserialize;
12use std::time::Duration;
13
/// Identifier reported as `AuthInfo::provider_id` for users authenticated through this provider.
const PROVIDER_ID: &str = "github";
15
16#[derive(Clone)]
17pub struct GithubConfig {
18 base: OAuthConfig,
19 redirect_uri: Option<String>,
20}
21
22impl GithubConfig {
23 pub fn new(client_id: String, client_secret: String, scope: Vec<String>) -> Self {
24 let base = OAuthConfig {
25 client_id,
26 client_secret,
27 scope,
28 };
29 Self {
30 base,
31 redirect_uri: None,
32 }
33 }
34
35 pub fn set_redirect_uri(self, redirect_uri: String) -> Self {
36 Self {
37 redirect_uri: Some(redirect_uri),
38 ..self
39 }
40 }
41}
42
/// OAuth provider for GitHub.
///
/// Holds the `oauth2` client used for the authorization-code flow, the scopes attached to every
/// authorization URL, and the HTTP client used to fetch the authenticated user's profile.
#[derive(Clone)]
pub struct GitHubProvider {
    // oauth2 client pointed at GitHub's authorize and access-token endpoints.
    client: BasicClient,
    // Scopes added to each authorization URL.
    scope: Vec<String>,
    // HTTP client used to call the GitHub REST API (user profile lookup).
    web_client: Client,
}
49
50#[async_trait]
51impl OAuthProvider for GitHubProvider {
52 type Config = GithubConfig;
53 type UserInfo = GitHubUser;
54
55 fn get_authorization_url(&self) -> RedirectInfo {
56 let mut req = self.client.authorize_url(CsrfToken::new_random);
57 for scope in &self.scope {
58 req = req.add_scope(Scope::new(scope.to_string()));
59 }
60 let info = req.url();
61 RedirectInfo {
62 url: info.0,
63 csrf_token: info.1.secret().to_string(),
64 }
65 }
66
67 fn new(config: Self::Config) -> Self {
68 let mut client = BasicClient::new(
69 ClientId::new(config.base.client_id),
70 Some(ClientSecret::new(config.base.client_secret)),
71 AuthUrl::new("https://github.com/login/oauth/authorize".to_string()).unwrap(),
72 Some(TokenUrl::new("https://github.com/login/oauth/access_token".to_string()).unwrap()),
73 );
74 if let Some(redirect_uri) = config.redirect_uri {
75 client = client.set_redirect_uri(RedirectUrl::new(redirect_uri).unwrap());
76 }
77 let web_client = Client::builder()
78 .timeout(Duration::from_secs(15))
79 .user_agent("conrad")
80 .build()
81 .unwrap();
82 Self {
83 client,
84 scope: config.base.scope,
85 web_client,
86 }
87 }
88
89 async fn validate_callback(
90 &self,
91 code: String,
92 ) -> Result<ValidationResult<Self::UserInfo>, OAuthError> {
93 let tokens = self.get_tokens(code).await?;
94 let provider_user = utils::get_provider_user::<GitHubUser>(
95 &self.web_client,
96 &tokens.access_token,
97 "https://api.github.com/user",
98 )
99 .await?;
100 let provider_user_id = provider_user.id.to_string();
101 Ok(ValidationResult {
102 tokens,
103 provider_user,
104 auth_info: AuthInfo {
105 provider_id: PROVIDER_ID,
106 provider_user_id,
107 },
108 })
109 }
110}
111
112impl GitHubProvider {
113 async fn get_tokens(&self, code: String) -> Result<Tokens, OAuthError> {
114 let token_result = self
115 .client
116 .exchange_code(AuthorizationCode::new(code))
117 .request_async(async_http_client)
118 .await
119 .map_err(|err| OAuthError::RequestError(Box::new(err)))?;
120 let access_token = token_result.access_token().secret().to_string();
121 Ok(if let Some(expires_in) = token_result.expires_in() {
122 Tokens {
123 access_token,
124 expiration_info: Some(ExpirationInfo {
125 refresh_token: token_result.refresh_token().unwrap().secret().to_string(),
126 expires_in: expires_in.as_millis() as i64,
127 }),
128 scope: None,
129 }
130 } else {
131 Tokens {
132 access_token,
133 expiration_info: None,
134 scope: None,
135 }
136 })
137 }
138}
139
140#[derive(Deserialize, Debug, Clone)]
141pub struct GitHubUser {
142 pub login: String,
143 pub id: i64,
144 pub node_id: String,
145 pub avatar_url: String,
146 pub gravatar_id: String,
147 pub url: String,
148 pub html_url: String,
149 pub followers_url: String,
150 pub following_url: String,
151 pub gists_url: String,
152 pub starred_url: String,
153 pub subscriptions_url: String,
154 pub organizations_url: String,
155 pub repos_url: String,
156 pub events_url: String,
157 pub received_events_url: String,
158 #[serde(rename = "type")]
159 pub account_type: String,
160 pub site_admin: String,
161 pub name: String,
162 pub company: String,
163 pub blog: String,
164 pub location: String,
165 pub email: String,
166 pub hireable: bool,
167 pub bio: String,
168 pub twitter_username: String,
169 pub public_repos: i64,
170 pub public_gists: i64,
171 pub followers: i64,
172 pub following: i64,
173 pub created_at: String,
174 pub updated_at: String,
175 pub private_gists: Option<i64>,
176 pub total_private_repos: Option<i64>,
177 pub owned_private_repos: Option<i64>,
178 pub disk_usage: Option<i64>,
179 pub collaborators: Option<i64>,
180 pub two_factor_authentication: Option<bool>,
181 pub plan: Option<Plan>,
182}
183
/// Plan information, deserialized from the optional `plan` field of [`GitHubUser`].
#[derive(Deserialize, Debug, Clone)]
pub struct Plan {
    pub name: String,
    pub space: i64,
    pub private_repos: i64,
    pub collaborators: i64,
}