1use crate::{
4 auth_server_builder, error::Result, AuthAction, AuthConfig, AuthUrlProvider, AuthUser,
5 GenericAuthAction,
6};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use serde_with::{formats::CommaSeparator, serde_as, StringWithSeparator};
11use std::collections::HashMap;
12
13pub struct AuthorizationServer {
14 config: AuthConfig,
15}
16
17auth_server_builder!();
18
19impl AuthUrlProvider for AuthorizationServer {
20 type AuthRequest = AuthRequest;
21
22 type TokenRequest = GetTokenRequest;
23
24 type UserInfoRequest = GetUserInfoRequest;
25
26 fn authorize_url(request: Self::AuthRequest) -> Result<String> {
27 let query = serde_urlencoded::to_string(request)?;
28 Ok(format!(
29 "https://open.weixin.qq.com/connect/qrconnect?response_type=code&{query}"
30 ))
31 }
32
33 fn access_token_url(request: Self::TokenRequest) -> Result<String> {
34 let query = serde_urlencoded::to_string(request)?;
35 Ok(format!(
36 "https://api.weixin.qq.com/sns/oauth2/access_token?grant_type=authorization_code&{query}"
37 ))
38 }
39
40 fn user_info_url(request: Self::UserInfoRequest) -> Result<String> {
41 let query = serde_urlencoded::to_string(request)?;
42 Ok(format!("https://api.weixin.qq.com/sns/userinfo?{query}"))
43 }
44}
45
46#[async_trait]
47impl AuthAction for AuthorizationServer {
48 type AuthCallback = AuthCallback;
49 type AuthToken = TokenResponse;
50 type AuthUser = UserInfoResponse;
51
52 async fn get_access_token(&self, callback: Self::AuthCallback) -> Result<Self::AuthToken> {
53 let AuthConfig {
54 client_id,
55 client_secret,
56 ..
57 } = &self.config;
58 let access_token_url = Self::access_token_url(GetTokenRequest {
59 appid: client_id.to_string(),
60 secret: client_secret.clone().expect("client_secret is empty"),
61 code: callback.code,
62 })?;
63 Ok(reqwest::get(access_token_url).await?.json().await?)
64 }
65
66 async fn get_user_info(&self, token: Self::AuthToken) -> Result<Self::AuthUser> {
67 let user_info_url = Self::user_info_url(GetUserInfoRequest {
68 openid: token.unionid,
69 access_token: token.access_token,
70 ..Default::default()
71 })?;
72 Ok(reqwest::get(user_info_url).await?.json().await?)
73 }
74}
75
76#[async_trait]
77impl GenericAuthAction for AuthorizationServer {
78 async fn authorize<S: Into<String> + Send>(&self, state: S) -> Result<String> {
79 let AuthConfig {
80 client_id,
81 redirect_uri,
82 scope,
83 ..
84 } = &self.config;
85 Self::authorize_url(AuthRequest {
86 appid: client_id.to_string(),
87 redirect_uri: redirect_uri.to_string(),
88 state: Some(state.into()),
89 scope: scope
90 .clone()
91 .or_else(|| {
92 Some(vec![
93 "snsapi_base".into(),
94 "snsapi_login".into(),
95 "snsapi_userinfo".into(),
96 ])
97 })
98 .expect("scope is empty"),
99 ..Default::default()
100 })
101 }
102
103 async fn login<S: Into<String> + Send>(&self, callback: S) -> Result<AuthUser> {
104 let callback: AuthCallback = serde_urlencoded::from_str(&callback.into())?;
105 let token = self.get_access_token(callback).await?;
106 let user = self.get_user_info(token.clone()).await?;
107 Ok(AuthUser {
108 user_id: user.unionid,
109 name: user.nickname,
110 access_token: token.access_token,
111 refresh_token: token.refresh_token,
112 expires_in: token.expires_in,
113 extra: user.extra,
114 })
115 }
116}
117
118#[serde_as]
119#[derive(Debug, Default, Serialize, Deserialize)]
120pub struct AuthRequest {
121 appid: String,
122 redirect_uri: String,
123 #[serde_as(as = "StringWithSeparator::<CommaSeparator, String>")]
124 scope: Vec<String>,
125 state: Option<String>,
126 lang: Option<Lang>,
127}
128
129#[derive(Debug, Serialize, Deserialize)]
130#[serde(rename_all = "lowercase")]
131pub enum Lang {
132 En,
133 Cn,
134}
135
136#[derive(Debug, Serialize, Deserialize)]
137pub struct AuthCallback {
138 code: String,
139 state: String,
140}
141
142#[derive(Debug, Serialize, Deserialize)]
143pub struct GetTokenRequest {
144 appid: String,
145 secret: String,
146 code: String,
147}
148
149#[derive(Debug, Serialize, Deserialize)]
150pub struct RefreshTokenRequest {
151 grant_type: String,
152 appid: String,
153 refresh_token: String,
154}
155
156#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
157pub struct TokenResponse {
158 pub access_token: String,
159 pub expires_in: i64,
160 pub refresh_token: String,
161 pub openid: String,
162 pub scope: String,
163 pub unionid: String,
164}
165
166#[derive(Debug, Default, Serialize, Deserialize)]
167pub struct GetUserInfoRequest {
168 access_token: String,
169 openid: String,
170 lang: Option<String>,
171}
172
173#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
175pub struct UserInfoResponse {
176 pub unionid: String,
177 pub nickname: String,
178 #[serde(flatten)]
179 pub extra: HashMap<String, Value>,
180}