1use crate::error::Result;
3use crate::{
4 auth_server_builder, AuthAction, AuthConfig, AuthUrlProvider, AuthUser, GenericAuthAction,
5};
6use async_trait::async_trait;
7use reqwest::header::ACCEPT;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use serde_with::{formats::SpaceSeparator, serde_as, StringWithSeparator};
11use std::collections::HashMap;
12
13pub struct AuthorizationServer {
14 config: AuthConfig,
15}
16
17auth_server_builder!();
18
19impl AuthUrlProvider for AuthorizationServer {
20 type AuthRequest = AuthRequest;
21 type TokenRequest = GetTokenRequest;
22 type UserInfoRequest = GetUserInfoRequest;
23
24 fn authorize_url(request: Self::AuthRequest) -> Result<String> {
25 let query = serde_urlencoded::to_string(request)?;
26 Ok(format!("https://github.com/login/oauth/authorize?{query}"))
27 }
28
29 fn access_token_url(request: Self::TokenRequest) -> Result<String> {
30 let query = serde_urlencoded::to_string(request)?;
31 Ok(format!(
32 "https://github.com/login/oauth/access_token?token_type=bearer&{query}"
33 ))
34 }
35
36 fn user_info_url(_request: Self::UserInfoRequest) -> Result<String> {
37 Ok(format!("https://api.github.com/user"))
38 }
39}
40
41#[async_trait]
42impl AuthAction for AuthorizationServer {
43 type AuthCallback = AuthCallback;
44 type AuthToken = TokenResponse;
45 type AuthUser = UserInfoResponse;
46
47 async fn get_access_token(&self, callback: Self::AuthCallback) -> Result<Self::AuthToken> {
48 let AuthConfig {
49 client_id,
50 client_secret,
51 redirect_uri,
52 ..
53 } = &self.config;
54 let access_token_url = Self::access_token_url(GetTokenRequest {
55 client_id: client_id.to_string(),
56 client_secret: client_secret.clone().expect("client_secret is empty"),
57 code: callback.code,
58 redirect_uri: redirect_uri.clone(),
59 })?;
60 Ok(reqwest::Client::default()
61 .get(access_token_url)
62 .header(ACCEPT, "application/json")
63 .send()
64 .await?
65 .json()
66 .await?)
67 }
68
69 async fn get_user_info(&self, token: Self::AuthToken) -> Result<Self::AuthUser> {
70 let user_info_url = Self::user_info_url(GetUserInfoRequest {})?;
71 Ok(reqwest::Client::default()
72 .get(user_info_url)
73 .bearer_auth(token.access_token)
74 .send()
75 .await?
76 .json()
77 .await?)
78 }
79}
80
81#[async_trait]
82impl GenericAuthAction for AuthorizationServer {
83 async fn authorize<S: Into<String> + Send>(&self, state: S) -> Result<String> {
84 let AuthConfig {
85 client_id,
86 redirect_uri,
87 scope,
88 ..
89 } = &self.config;
90 Self::authorize_url(AuthRequest {
91 client_id: client_id.to_string(),
92 redirect_uri: redirect_uri.to_string(),
93 state: state.into(),
94 scope: scope
95 .clone()
96 .or_else(|| Some(vec!["read:user".into(), "user:email".into()]))
97 .expect("scope is empty"),
98 ..Default::default()
99 })
100 }
101
102 async fn login<S: Into<String> + Send>(&self, callback: S) -> Result<AuthUser> {
103 let callback: AuthCallback = serde_urlencoded::from_str(&callback.into())?;
104 let token = self.get_access_token(callback).await?;
105 let user = self.get_user_info(token.clone()).await?;
106 Ok(AuthUser {
107 user_id: user.id.to_string(),
108 name: user.name,
109 access_token: token.access_token,
110 refresh_token: token.token_type,
111 expires_in: i64::MAX,
112 extra: user.extra,
113 })
114 }
115}
116
117#[serde_as]
118#[derive(Debug, Default, Serialize, Deserialize)]
119pub struct AuthRequest {
120 client_id: String,
121 redirect_uri: String,
122 login: Option<String>,
123 #[serde_as(as = "StringWithSeparator::<SpaceSeparator, String>")]
124 scope: Vec<String>,
125 state: String,
126 allow_signup: Option<String>,
127 prompt: Option<String>,
128}
129
130#[derive(Debug, Serialize, Deserialize)]
131pub struct AuthCallback {
132 code: String,
133 state: String,
134}
135
136#[derive(Debug, Serialize)]
137pub struct GetTokenRequest {
138 client_id: String,
139 client_secret: String,
140 code: String,
141 redirect_uri: String,
142}
143
144#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
145pub struct TokenResponse {
146 pub access_token: String,
147 pub scope: String,
148 pub token_type: String,
149}
150
151#[derive(Debug, Serialize, Deserialize)]
152pub struct GetUserInfoRequest {}
153
154#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
156pub struct UserInfoResponse {
157 pub id: i64,
158 pub name: String,
159 #[serde(flatten)]
160 pub extra: HashMap<String, Value>,
161}