wechat_minapp/client.rs
1use crate::{
2 Result,
3 access_token::{AccessToken, get_access_token, get_stable_access_token},
4 constants,
5 credential::{Credential, CredentialBuilder},
6 error::Error::InternalServer,
7 response::Response,
8};
9use chrono::{Duration, Utc};
10use std::{
11 collections::HashMap,
12 sync::{
13 Arc,
14 atomic::{AtomicBool, Ordering},
15 },
16};
17use tokio::sync::{Notify, RwLock};
18use tracing::{debug, instrument};
19
20///
21/// 提供与微信小程序后端 API 交互的核心功能,包括用户登录、访问令牌管理等。
22///
23/// # 功能特性
24///
25/// - 用户登录凭证校验
26/// - 访问令牌自动管理(支持普通令牌和稳定版令牌)
27/// - 线程安全的令牌刷新机制
28/// - 内置 HTTP 客户端
29///
30/// # 快速开始
31///
32/// ```no_run
33/// use wechat_minapp::Client;
34///
35/// #[tokio::main]
36/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
37/// // 初始化客户端
38/// let app_id = "your_app_id";
39/// let secret = "your_app_secret";
40/// let client = Client::new(app_id, secret);
41///
42/// // 用户登录
43/// let code = "user_login_code_from_frontend";
44/// let credential = client.login(code).await?;
45/// println!("用户OpenID: {}", credential.open_id());
46///
47/// // 获取访问令牌
48/// let access_token = client.access_token().await?;
49/// println!("访问令牌: {}", access_token);
50///
51/// Ok(())
52/// }
53/// ```
54///
55/// # 令牌管理
56///
57/// 客户端自动管理访问令牌的生命周期:
58///
59/// - 令牌过期前自动刷新
60/// - 多线程环境下的安全并发访问
61/// - 避免重复刷新(令牌锁机制)
62/// - 支持强制刷新选项
63///
64/// # 线程安全
65///
66/// `Client` 实现了 `Send` 和 `Sync`,可以在多线程环境中安全使用。
67#[derive(Debug, Clone)]
68pub struct Client {
69 inner: Arc<ClientInner>,
70 access_token: Arc<RwLock<AccessToken>>,
71 refreshing: Arc<AtomicBool>,
72 notify: Arc<Notify>,
73 use_stable_token: bool,
74}
75
76impl Client {
77 /// 创建新的微信小程序客户端
78 ///
79 /// # 参数
80 ///
81 /// - `app_id`: 小程序 AppID
82 /// - `secret`: 小程序 AppSecret
83 ///
84 /// # 返回
85 ///
86 /// 新的 `Client` 实例
87 ///
88 /// # 示例
89 ///
90 /// ```
91 /// use wechat_minapp::Client;
92 ///
93 /// let client = Client::new("wx1234567890abcdef", "your_app_secret_here");
94 /// ```
95 pub fn new(app_id: &str, secret: &str) -> Self {
96 let client = reqwest::Client::new();
97
98 Self {
99 inner: Arc::new(ClientInner {
100 app_id: app_id.into(),
101 secret: secret.into(),
102 client,
103 }),
104 access_token: Arc::new(RwLock::new(AccessToken {
105 access_token: "".to_string(),
106 expired_at: Utc::now(),
107 force_refresh: None,
108 })),
109 refreshing: Arc::new(AtomicBool::new(false)),
110 notify: Arc::new(Notify::new()),
111 use_stable_token: true,
112 }
113 }
114
115 pub fn with_non_stable(app_id: &str, secret: &str) -> Self {
116 let client = reqwest::Client::new();
117
118 Self {
119 inner: Arc::new(ClientInner {
120 app_id: app_id.into(),
121 secret: secret.into(),
122 client,
123 }),
124 access_token: Arc::new(RwLock::new(AccessToken {
125 access_token: "".to_string(),
126 expired_at: Utc::now(),
127 force_refresh: None,
128 })),
129 refreshing: Arc::new(AtomicBool::new(false)),
130 notify: Arc::new(Notify::new()),
131 use_stable_token: false,
132 }
133 }
134
135 pub(crate) fn request(&self) -> &reqwest::Client {
136 &self.inner.client
137 }
138
139 /// 用户登录凭证校验
140 ///
141 /// 通过微信前端获取的临时登录凭证 code,换取用户的唯一标识 OpenID 和会话密钥。
142 ///
143 /// # 参数
144 ///
145 /// - `code`: 微信前端通过 `wx.login()` 获取的临时登录凭证
146 ///
147 /// # 返回
148 ///
149 /// 成功返回 `Ok(Credential)`,包含用户身份信息
150 ///
151 /// # 错误
152 ///
153 /// - 网络错误
154 /// - 微信 API 返回错误
155 /// - 响应解析错误
156 ///
157 /// # 示例
158 ///
159 /// ```no_run
160 /// use wechat_minapp::Client;
161 ///
162 /// #[tokio::main]
163 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
164 /// let client = Client::new("app_id", "secret");
165 /// let code = "0816abc123def456";
166 /// let credential = client.login(code).await?;
167 ///
168 /// println!("用户OpenID: {}", credential.open_id());
169 /// println!("会话密钥: {}", credential.session_key());
170 ///
171 /// Ok(())
172 /// }
173 /// ```
174 ///
175 /// # API 文档
176 ///
177 /// [微信官方文档 - code2Session](https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/user-login/code2Session.html)
178 #[instrument(skip(self, code))]
179 pub async fn login(&self, code: &str) -> Result<Credential> {
180 debug!("code: {}", code);
181
182 let mut map: HashMap<&str, &str> = HashMap::new();
183
184 map.insert("appid", &self.inner.app_id);
185 map.insert("secret", &self.inner.secret);
186 map.insert("js_code", code);
187 map.insert("grant_type", "authorization_code");
188
189 let response = self
190 .inner
191 .client
192 .get(constants::AUTHENTICATION_END_POINT)
193 .query(&map)
194 .send()
195 .await?;
196
197 debug!("authentication response: {:#?}", response);
198
199 if response.status().is_success() {
200 let response = response.json::<Response<CredentialBuilder>>().await?;
201
202 let credential = response.extract()?.build();
203
204 debug!("credential: {:#?}", credential);
205
206 Ok(credential)
207 } else {
208 Err(InternalServer(response.text().await?))
209 }
210 }
211
212 pub async fn token(&self) -> Result<String> {
213 if self.use_stable_token {
214 self.stable_access_token(None).await
215 } else {
216 self.access_token().await
217 }
218 }
219
220 /// 获取访问令牌
221 ///
222 /// 获取用于调用微信小程序接口的访问令牌。如果当前令牌已过期或即将过期,会自动刷新。
223 ///
224 /// # 返回
225 ///
226 /// 成功返回 `Ok(String)`,包含有效的访问令牌
227 ///
228 /// # 错误
229 ///
230 /// - 网络错误
231 /// - 微信 API 返回错误
232 /// - 令牌刷新失败
233 ///
234 /// # 示例
235 ///
236 /// ```no_run
237 /// use wechat_minapp::Client;
238 ///
239 /// #[tokio::main]
240 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
241 /// let client = Client::new("app_id", "secret");
242 /// let access_token = client.access_token().await?;
243 ///
244 /// println!("访问令牌: {}", access_token);
245 /// Ok(())
246 /// }
247 /// ```
248 ///
249 /// # 注意
250 ///
251 /// - 令牌有效期为 2 小时
252 /// - 客户端会自动管理令牌刷新,无需手动处理
253 /// - 多线程环境下安全
254 pub async fn access_token(&self) -> Result<String> {
255 // 第一次检查:快速路径
256 {
257 let guard = self.access_token.read().await;
258 if !is_token_expired(&guard) {
259 return Ok(guard.access_token.clone());
260 }
261 }
262
263 // 使用CAS竞争刷新权
264 if self
265 .refreshing
266 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
267 .is_ok()
268 {
269 // 获得刷新权
270 match self.refresh_access_token().await {
271 Ok(token) => {
272 self.refreshing.store(false, Ordering::Release);
273 self.notify.notify_waiters();
274 Ok(token)
275 }
276 Err(e) => {
277 self.refreshing.store(false, Ordering::Release);
278 self.notify.notify_waiters();
279 Err(e)
280 }
281 }
282 } else {
283 // 等待其他线程刷新完成
284 self.notify.notified().await;
285 // 刷新完成后重新读取
286 let guard = self.access_token.read().await;
287 Ok(guard.access_token.clone())
288 }
289 }
290
291 async fn refresh_access_token(&self) -> Result<String> {
292 let mut guard = self.access_token.write().await;
293
294 if !is_token_expired(&guard) {
295 debug!("token already refreshed by another thread");
296 return Ok(guard.access_token.clone());
297 }
298
299 debug!("performing network request to refresh token");
300
301 let builder = get_access_token(
302 self.inner.client.clone(),
303 &self.inner.app_id,
304 &self.inner.secret,
305 )
306 .await?;
307
308 guard.access_token = builder.access_token.clone();
309 guard.expired_at = builder.expired_at;
310
311 debug!("fresh access token: {:#?}", guard);
312
313 Ok(guard.access_token.clone())
314 }
315
316 /// 获取稳定版访问令牌
317 ///
318 /// 获取稳定版的访问令牌,相比普通令牌有更长的有效期和更好的稳定性。
319 ///
320 /// # 参数
321 ///
322 /// - `force_refresh`: 是否强制刷新令牌
323 /// - `Some(true)`: 强制从微信服务器获取最新令牌
324 /// - `Some(false)` 或 `None`: 仅在令牌过期时刷新
325 ///
326 /// # 返回
327 ///
328 /// 成功返回 `Ok(String)`,包含有效的稳定版访问令牌
329 ///
330 /// # 错误
331 ///
332 /// - 网络错误
333 /// - 微信 API 返回错误
334 /// - 令牌刷新失败
335 ///
336 /// # 示例
337 ///
338 /// ```no_run
339 /// use wechat_minapp::Client;
340 ///
341 /// #[tokio::main]
342 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
343 /// let client = Client::new("app_id", "secret");
344 ///
345 /// // 仅在过期时刷新
346 /// let token1 = client.stable_access_token(None).await?;
347 ///
348 /// // 强制刷新
349 /// let token2 = client.stable_access_token(true).await?;
350 ///
351 /// Ok(())
352 /// }
353 /// ```
354 ///
355 /// # 注意
356 ///
357 /// - 稳定版令牌有效期更长,推荐在生产环境使用
358 /// - 强制刷新会忽略本地缓存,直接请求新令牌
359 pub async fn stable_access_token(
360 &self,
361 force_refresh: impl Into<Option<bool>> + Clone + Send,
362 ) -> Result<String> {
363 // 第一次检查:快速路径
364 {
365 let guard = self.access_token.read().await;
366 if !is_token_expired(&guard) {
367 return Ok(guard.access_token.clone());
368 }
369 }
370
371 // 使用CAS竞争刷新权
372 if self
373 .refreshing
374 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
375 .is_ok()
376 {
377 // 获得刷新权
378 match self.refresh_stable_access_token(force_refresh).await {
379 Ok(token) => {
380 self.refreshing.store(false, Ordering::Release);
381 self.notify.notify_waiters();
382 Ok(token)
383 }
384 Err(e) => {
385 self.refreshing.store(false, Ordering::Release);
386 self.notify.notify_waiters();
387 Err(e)
388 }
389 }
390 } else {
391 // 等待其他线程刷新完成
392 self.notify.notified().await;
393 // 刷新完成后重新读取
394 let guard = self.access_token.read().await;
395 Ok(guard.access_token.clone())
396 }
397 }
398
399 async fn refresh_stable_access_token(
400 &self,
401 force_refresh: impl Into<Option<bool>> + Clone + Send,
402 ) -> Result<String> {
403 // 1. Acquire the write lock. This blocks if another thread won CAS but is refreshing.
404 let mut guard = self.access_token.write().await;
405
406 // 2. Double-check expiration under the write lock (CRITICAL)
407 // If another CAS-winner refreshed the token while we were waiting for the write lock,
408 // we return the new token without performing a new network call.
409 if !is_token_expired(&guard) {
410 // Token is now fresh, return it
411 debug!("token already refreshed by another thread");
412 return Ok(guard.access_token.clone());
413 }
414
415 // 3. Perform the network request since the token is still stale
416 debug!("performing network request to refresh token");
417
418 let builder = get_stable_access_token(
419 self.inner.client.clone(),
420 &self.inner.app_id,
421 &self.inner.secret,
422 force_refresh,
423 )
424 .await?;
425
426 // 4. Update the token
427 guard.access_token = builder.access_token.clone();
428 guard.expired_at = builder.expired_at;
429
430 debug!("fresh access token: {:#?}", guard);
431
432 // Return the newly fetched token (cloned here for consistency)
433 Ok(guard.access_token.clone())
434 }
435}
436
437#[derive(Debug)]
438struct ClientInner {
439 app_id: String,
440 secret: String,
441 client: reqwest::Client,
442}
443
444/// 检查令牌是否过期
445///
446/// 添加安全边界,在令牌过期前5分钟就认为需要刷新
447fn is_token_expired(token: &AccessToken) -> bool {
448 // 添加安全边界,提前刷新
449 let now = Utc::now();
450 token.expired_at.signed_duration_since(now) < Duration::minutes(5)
451}