just_auth/
github.rs

//! GitHub OAuth App provider.
//!
//! Implements the web application flow described at
//! <https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps>.
2use crate::error::Result;
3use crate::{
4    auth_server_builder, AuthAction, AuthConfig, AuthUrlProvider, AuthUser, GenericAuthAction,
5};
6use async_trait::async_trait;
7use reqwest::header::ACCEPT;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use serde_with::{formats::SpaceSeparator, serde_as, StringWithSeparator};
11use std::collections::HashMap;
12
13pub struct AuthorizationServer {
14    config: AuthConfig,
15}
16
17auth_server_builder!();
18
19impl AuthUrlProvider for AuthorizationServer {
20    type AuthRequest = AuthRequest;
21    type TokenRequest = GetTokenRequest;
22    type UserInfoRequest = GetUserInfoRequest;
23
24    fn authorize_url(request: Self::AuthRequest) -> Result<String> {
25        let query = serde_urlencoded::to_string(request)?;
26        Ok(format!("https://github.com/login/oauth/authorize?{query}"))
27    }
28
29    fn access_token_url(request: Self::TokenRequest) -> Result<String> {
30        let query = serde_urlencoded::to_string(request)?;
31        Ok(format!(
32            "https://github.com/login/oauth/access_token?token_type=bearer&{query}"
33        ))
34    }
35
36    fn user_info_url(_request: Self::UserInfoRequest) -> Result<String> {
37        Ok(format!("https://api.github.com/user"))
38    }
39}
40
41#[async_trait]
42impl AuthAction for AuthorizationServer {
43    type AuthCallback = AuthCallback;
44    type AuthToken = TokenResponse;
45    type AuthUser = UserInfoResponse;
46
47    async fn get_access_token(&self, callback: Self::AuthCallback) -> Result<Self::AuthToken> {
48        let AuthConfig {
49            client_id,
50            client_secret,
51            redirect_uri,
52            ..
53        } = &self.config;
54        let access_token_url = Self::access_token_url(GetTokenRequest {
55            client_id: client_id.to_string(),
56            client_secret: client_secret.clone().expect("client_secret is empty"),
57            code: callback.code,
58            redirect_uri: redirect_uri.clone(),
59        })?;
60        Ok(reqwest::Client::default()
61            .get(access_token_url)
62            .header(ACCEPT, "application/json")
63            .send()
64            .await?
65            .json()
66            .await?)
67    }
68
69    async fn get_user_info(&self, token: Self::AuthToken) -> Result<Self::AuthUser> {
70        let user_info_url = Self::user_info_url(GetUserInfoRequest {})?;
71        Ok(reqwest::Client::default()
72            .get(user_info_url)
73            .bearer_auth(token.access_token)
74            .send()
75            .await?
76            .json()
77            .await?)
78    }
79}
80
81#[async_trait]
82impl GenericAuthAction for AuthorizationServer {
83    async fn authorize<S: Into<String> + Send>(&self, state: S) -> Result<String> {
84        let AuthConfig {
85            client_id,
86            redirect_uri,
87            scope,
88            ..
89        } = &self.config;
90        Self::authorize_url(AuthRequest {
91            client_id: client_id.to_string(),
92            redirect_uri: redirect_uri.to_string(),
93            state: state.into(),
94            scope: scope
95                .clone()
96                .or_else(|| Some(vec!["read:user".into(), "user:email".into()]))
97                .expect("scope is empty"),
98            ..Default::default()
99        })
100    }
101
102    async fn login<S: Into<String> + Send>(&self, callback: S) -> Result<AuthUser> {
103        let callback: AuthCallback = serde_urlencoded::from_str(&callback.into())?;
104        let token = self.get_access_token(callback).await?;
105        let user = self.get_user_info(token.clone()).await?;
106        Ok(AuthUser {
107            user_id: user.id.to_string(),
108            name: user.name,
109            access_token: token.access_token,
110            refresh_token: token.token_type,
111            expires_in: i64::MAX,
112            extra: user.extra,
113        })
114    }
115}
116
117#[serde_as]
118#[derive(Debug, Default, Serialize, Deserialize)]
119pub struct AuthRequest {
120    client_id: String,
121    redirect_uri: String,
122    login: Option<String>,
123    #[serde_as(as = "StringWithSeparator::<SpaceSeparator, String>")]
124    scope: Vec<String>,
125    state: String,
126    allow_signup: Option<String>,
127    prompt: Option<String>,
128}
129
130#[derive(Debug, Serialize, Deserialize)]
131pub struct AuthCallback {
132    code: String,
133    state: String,
134}
135
136#[derive(Debug, Serialize)]
137pub struct GetTokenRequest {
138    client_id: String,
139    client_secret: String,
140    code: String,
141    redirect_uri: String,
142}
143
144#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
145pub struct TokenResponse {
146    pub access_token: String,
147    pub scope: String,
148    pub token_type: String,
149}
150
151#[derive(Debug, Serialize, Deserialize)]
152pub struct GetUserInfoRequest {}
153
154/// https://docs.github.com/en/rest/users/users
155#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
156pub struct UserInfoResponse {
157    pub id: i64,
158    pub name: String,
159    #[serde(flatten)]
160    pub extra: HashMap<String, Value>,
161}