1use crate::error::Result;
3use crate::{
4 auth_server_builder, AuthAction, AuthConfig, AuthUrlProvider, AuthUser, GenericAuthAction,
5};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use serde_with::{formats::SpaceSeparator, serde_as, StringWithSeparator};
10use std::collections::HashMap;
11
/// OAuth 2.0 client for Baidu's authorization server (`openapi.baidu.com`).
///
/// Holds the application's [`AuthConfig`]; this file reads its `client_id`,
/// `client_secret`, `redirect_uri` and `scope` to build the authorization URL,
/// exchange an authorization code for a token, and fetch the user's profile.
pub struct AuthorizationServer {
    /// Application credentials and redirect settings registered with Baidu.
    config: AuthConfig,
}
15
// Expands the crate-level `auth_server_builder!` macro for this module.
// NOTE(review): the macro is defined elsewhere in the crate; presumably it
// generates the constructor/builder that fills `AuthorizationServer::config`
// — confirm against the macro definition.
auth_server_builder!();
17
18impl AuthUrlProvider for AuthorizationServer {
19 type AuthRequest = AuthRequest;
20 type TokenRequest = GetTokenRequest;
21 type UserInfoRequest = GetUserInfoRequest;
22
23 fn authorize_url(request: Self::AuthRequest) -> Result<String> {
24 let query = serde_urlencoded::to_string(request)?;
25 Ok(format!(
26 "https://openapi.baidu.com/oauth/2.0/authorize?response_type=CODE&{query}"
27 ))
28 }
29
30 fn access_token_url(request: Self::TokenRequest) -> Result<String> {
31 let query = serde_urlencoded::to_string(request)?;
32 Ok(format!(
33 "https://openapi.baidu.com/oauth/2.0/token?grant_type=authorization_code&{query}"
34 ))
35 }
36
37 fn user_info_url(request: Self::UserInfoRequest) -> Result<String> {
38 let query = serde_urlencoded::to_string(request)?;
39 Ok(format!(
40 "https://openapi.baidu.com/rest/2.0/passport/users/getInfo?{query}"
41 ))
42 }
43}
44
45#[async_trait]
46impl AuthAction for AuthorizationServer {
47 type AuthCallback = AuthCallback;
48 type AuthToken = TokenResponse;
49 type AuthUser = UserInfoResponse;
50
51 async fn get_access_token(&self, callback: Self::AuthCallback) -> Result<Self::AuthToken> {
52 let AuthConfig {
53 client_id,
54 client_secret,
55 redirect_uri,
56 ..
57 } = &self.config;
58 let access_token_url = Self::access_token_url(GetTokenRequest {
59 client_id: client_id.to_string(),
60 client_secret: client_secret.clone().expect("client_secret is empty"),
61 code: callback.code,
62 redirect_uri: redirect_uri.to_string(),
63 })?;
64 Ok(reqwest::get(access_token_url).await?.json().await?)
65 }
66
67 async fn get_user_info(&self, token: Self::AuthToken) -> Result<Self::AuthUser> {
68 let user_info_url = Self::user_info_url(GetUserInfoRequest {
69 access_token: token.access_token,
70 get_unionid: Some(1),
71 })?;
72 Ok(reqwest::get(user_info_url).await?.json().await?)
73 }
74}
75
76#[async_trait]
77impl GenericAuthAction for AuthorizationServer {
78 async fn authorize<S: Into<String> + Send>(&self, state: S) -> Result<String> {
79 let AuthConfig {
80 client_id,
81 redirect_uri,
82 scope,
83 ..
84 } = &self.config;
85 Self::authorize_url(AuthRequest {
86 client_id: client_id.to_string(),
87 redirect_uri: redirect_uri.to_string(),
88 state: Some(state.into()),
89 scope: scope.clone().unwrap_or_default(),
90 ..Default::default()
91 })
92 }
93
94 async fn login<S: Into<String> + Send>(&self, callback: S) -> Result<AuthUser> {
95 let callback: AuthCallback = serde_urlencoded::from_str(&callback.into())?;
96 let token = self.get_access_token(callback).await?;
97 let user = self.get_user_info(token.clone()).await?;
98 Ok(AuthUser {
99 user_id: user.openid,
100 name: user.username.unwrap_or_default(),
101 access_token: token.access_token,
102 refresh_token: token.refresh_token,
103 expires_in: token.expires_in,
104 extra: user.extra,
105 })
106 }
107}
108
/// Query parameters for Baidu's `/oauth/2.0/authorize` endpoint.
///
/// Serialized with `serde_urlencoded` by `authorize_url`, which also prepends
/// the fixed `response_type` parameter. Field order here determines parameter
/// order in the query string. Optional fields are Baidu-specific login-page
/// options; see Baidu's OAuth documentation for their accepted values.
#[serde_as]
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct AuthRequest {
    /// Application id registered with Baidu.
    client_id: String,
    /// Where Baidu redirects after the user authorizes.
    redirect_uri: String,
    /// Requested scopes, encoded as one space-separated string.
    #[serde_as(as = "StringWithSeparator::<SpaceSeparator, String>")]
    scope: Vec<String>,
    /// Opaque value the caller passes to `authorize`.
    state: Option<String>,
    /// Login-page display style, serialized in lowercase.
    display: Option<DisplayStyle>,
    force_login: Option<i8>,
    confirm_login: Option<i8>,
    login_type: Option<String>,
    qrext_clientid: Option<String>,
    bgurl: Option<String>,
    // The following fields are renamed to the camelCase names Baidu expects.
    #[serde(rename = "qrcodeW")]
    qrcode_width: Option<u32>,
    #[serde(rename = "qrcodeH")]
    qrcode_height: Option<u32>,
    qrcode: Option<i8>,
    qrloginfrom: Option<String>,
    #[serde(rename = "userReg")]
    user_reg: Option<i8>,
    #[serde(rename = "appTip")]
    app_tip: Option<String>,
    #[serde(rename = "appName")]
    app_name: Option<String>,
}
136
/// Value of the `display` parameter of [`AuthRequest`].
///
/// Variants serialize as their lowercase names (`page`, `popup`, `dialog`,
/// `mobile`, `pad`, `tv`). NOTE(review): presumably this selects the layout of
/// Baidu's login page — confirm against Baidu's OAuth docs.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DisplayStyle {
    Page,
    Popup,
    Dialog,
    Mobile,
    Pad,
    Tv,
}
148
/// Parameters of the OAuth redirect back to `redirect_uri`.
///
/// `login` parses this from the callback query string. Only `code` is used,
/// when it is exchanged for a token. `state` is not compared with the value
/// given to `authorize` anywhere in this file, so CSRF protection presumably
/// has to happen in the caller — confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct AuthCallback {
    /// Authorization code to exchange at the token endpoint.
    code: String,
    /// Echo of the `state` sent in the authorization request.
    state: String,
}
154
/// Query parameters for Baidu's `/oauth/2.0/token` endpoint.
///
/// `access_token_url` appends these after the fixed
/// `grant_type=authorization_code` parameter.
#[derive(Debug, Serialize)]
pub struct GetTokenRequest {
    client_id: String,
    client_secret: String,
    /// Authorization code taken from the redirect callback.
    code: String,
    /// Redirect URI configured for this client.
    redirect_uri: String,
}
162
/// JSON body returned by Baidu's `/oauth/2.0/token` endpoint.
///
/// `get_access_token` deserializes this directly from the response. No field has
/// `#[serde(default)]`, so a body missing any of them fails to decode. That
/// includes an error payload, if Baidu returns one.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TokenResponse {
    /// Token sent to the user-info endpoint and copied into `AuthUser`.
    pub access_token: String,
    /// Token lifetime. NOTE(review): presumably seconds (OAuth convention) — confirm.
    pub expires_in: i64,
    /// Refresh token, copied into `AuthUser` by `login`.
    pub refresh_token: String,
    /// Granted scope string as returned by Baidu.
    pub scope: String,
    /// Baidu session credential; not read in this file.
    pub session_key: String,
    /// Baidu session secret; not read in this file.
    pub session_secret: String,
}
172
/// Parameters for refreshing an access token.
///
/// NOTE(review): nothing in this file constructs or sends this type. Its fields
/// are private with no constructor, so it is presumably a placeholder for a
/// future refresh-token flow — confirm before relying on it.
#[derive(Debug, Serialize, Deserialize)]
pub struct RefreshTokenRequest {
    grant_type: String,
    client_id: String,
    client_secret: String,
    refresh_token: String,
}
180
/// Query parameters for Baidu's `passport/users/getInfo` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct GetUserInfoRequest {
    /// Access token obtained from the token endpoint.
    access_token: String,
    /// `get_user_info` sets this to `Some(1)`. NOTE(review): presumably it
    /// requests the `unionid` field — confirm against Baidu's docs.
    get_unionid: Option<i8>,
}
186
/// JSON body returned by Baidu's `passport/users/getInfo` endpoint.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UserInfoResponse {
    /// Baidu user identifier, used as `AuthUser::user_id` by `login`.
    pub openid: String,
    /// Display name, if present; `login` falls back to an empty string.
    pub username: Option<String>,
    /// Every other top-level JSON field, captured through `#[serde(flatten)]`
    /// and passed through as `AuthUser::extra`.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}