conrad_oauth/providers/
github.rs

1use crate::{
2    errors::OAuthError, utils, AuthInfo, ExpirationInfo, OAuthConfig, OAuthProvider, RedirectInfo,
3    Tokens, ValidationResult,
4};
5use async_trait::async_trait;
6use oauth2::{
7    basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
8    ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
9};
10use reqwest::Client;
11use serde::Deserialize;
12use std::time::Duration;
13
/// Provider identifier placed in `AuthInfo::provider_id` for users validated by this provider.
const PROVIDER_ID: &str = "github";
15
16#[derive(Clone)]
17pub struct GithubConfig {
18    base: OAuthConfig,
19    redirect_uri: Option<String>,
20}
21
22impl GithubConfig {
23    pub fn new(client_id: String, client_secret: String, scope: Vec<String>) -> Self {
24        let base = OAuthConfig {
25            client_id,
26            client_secret,
27            scope,
28        };
29        Self {
30            base,
31            redirect_uri: None,
32        }
33    }
34
35    pub fn set_redirect_uri(self, redirect_uri: String) -> Self {
36        Self {
37            redirect_uri: Some(redirect_uri),
38            ..self
39        }
40    }
41}
42
43#[derive(Clone)]
44pub struct GitHubProvider {
45    client: BasicClient,
46    scope: Vec<String>,
47    web_client: Client,
48}
49
50#[async_trait]
51impl OAuthProvider for GitHubProvider {
52    type Config = GithubConfig;
53    type UserInfo = GitHubUser;
54
55    fn get_authorization_url(&self) -> RedirectInfo {
56        let mut req = self.client.authorize_url(CsrfToken::new_random);
57        for scope in &self.scope {
58            req = req.add_scope(Scope::new(scope.to_string()));
59        }
60        let info = req.url();
61        RedirectInfo {
62            url: info.0,
63            csrf_token: info.1.secret().to_string(),
64        }
65    }
66
67    fn new(config: Self::Config) -> Self {
68        let mut client = BasicClient::new(
69            ClientId::new(config.base.client_id),
70            Some(ClientSecret::new(config.base.client_secret)),
71            AuthUrl::new("https://github.com/login/oauth/authorize".to_string()).unwrap(),
72            Some(TokenUrl::new("https://github.com/login/oauth/access_token".to_string()).unwrap()),
73        );
74        if let Some(redirect_uri) = config.redirect_uri {
75            client = client.set_redirect_uri(RedirectUrl::new(redirect_uri).unwrap());
76        }
77        let web_client = Client::builder()
78            .timeout(Duration::from_secs(15))
79            .user_agent("conrad")
80            .build()
81            .unwrap();
82        Self {
83            client,
84            scope: config.base.scope,
85            web_client,
86        }
87    }
88
89    async fn validate_callback(
90        &self,
91        code: String,
92    ) -> Result<ValidationResult<Self::UserInfo>, OAuthError> {
93        let tokens = self.get_tokens(code).await?;
94        let provider_user = utils::get_provider_user::<GitHubUser>(
95            &self.web_client,
96            &tokens.access_token,
97            "https://api.github.com/user",
98        )
99        .await?;
100        let provider_user_id = provider_user.id.to_string();
101        Ok(ValidationResult {
102            tokens,
103            provider_user,
104            auth_info: AuthInfo {
105                provider_id: PROVIDER_ID,
106                provider_user_id,
107            },
108        })
109    }
110}
111
112impl GitHubProvider {
113    async fn get_tokens(&self, code: String) -> Result<Tokens, OAuthError> {
114        let token_result = self
115            .client
116            .exchange_code(AuthorizationCode::new(code))
117            .request_async(async_http_client)
118            .await
119            .map_err(|err| OAuthError::RequestError(Box::new(err)))?;
120        let access_token = token_result.access_token().secret().to_string();
121        Ok(if let Some(expires_in) = token_result.expires_in() {
122            Tokens {
123                access_token,
124                expiration_info: Some(ExpirationInfo {
125                    refresh_token: token_result.refresh_token().unwrap().secret().to_string(),
126                    expires_in: expires_in.as_millis() as i64,
127                }),
128                scope: None,
129            }
130        } else {
131            Tokens {
132                access_token,
133                expiration_info: None,
134                scope: None,
135            }
136        })
137    }
138}
139
140#[derive(Deserialize, Debug, Clone)]
141pub struct GitHubUser {
142    pub login: String,
143    pub id: i64,
144    pub node_id: String,
145    pub avatar_url: String,
146    pub gravatar_id: String,
147    pub url: String,
148    pub html_url: String,
149    pub followers_url: String,
150    pub following_url: String,
151    pub gists_url: String,
152    pub starred_url: String,
153    pub subscriptions_url: String,
154    pub organizations_url: String,
155    pub repos_url: String,
156    pub events_url: String,
157    pub received_events_url: String,
158    #[serde(rename = "type")]
159    pub account_type: String,
160    pub site_admin: String,
161    pub name: String,
162    pub company: String,
163    pub blog: String,
164    pub location: String,
165    pub email: String,
166    pub hireable: bool,
167    pub bio: String,
168    pub twitter_username: String,
169    pub public_repos: i64,
170    pub public_gists: i64,
171    pub followers: i64,
172    pub following: i64,
173    pub created_at: String,
174    pub updated_at: String,
175    pub private_gists: Option<i64>,
176    pub total_private_repos: Option<i64>,
177    pub owned_private_repos: Option<i64>,
178    pub disk_usage: Option<i64>,
179    pub collaborators: Option<i64>,
180    pub two_factor_authentication: Option<bool>,
181    pub plan: Option<Plan>,
182}
183
184#[derive(Deserialize, Debug, Clone)]
185pub struct Plan {
186    pub name: String,
187    pub space: i64,
188    pub private_repos: i64,
189    pub collaborators: i64,
190}