wechat_minapp/
client.rs

1use crate::{
2    Result,
3    access_token::{AccessToken, get_access_token, get_stable_access_token},
4    constants,
5    credential::{Credential, CredentialBuilder},
6    error::Error::InternalServer,
7    response::Response,
8};
9use chrono::{Duration, Utc};
10use std::{
11    collections::HashMap,
12    sync::{
13        Arc,
14        atomic::{AtomicBool, Ordering},
15    },
16};
17use tokio::sync::{Notify, RwLock};
18use tracing::{debug, instrument};
19
20///
21/// 提供与微信小程序后端 API 交互的核心功能,包括用户登录、访问令牌管理等。
22///
23/// # 功能特性
24///
25/// - 用户登录凭证校验
26/// - 访问令牌自动管理(支持普通令牌和稳定版令牌)
27/// - 线程安全的令牌刷新机制
28/// - 内置 HTTP 客户端
29///
30/// # 快速开始
31///
32/// ```no_run
33/// use wechat_minapp::Client;
34///
35/// #[tokio::main]
36/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
37///     // 初始化客户端
38///     let app_id = "your_app_id";
39///     let secret = "your_app_secret";
40///     let client = Client::new(app_id, secret);
41///
42///     // 用户登录
43///     let code = "user_login_code_from_frontend";
44///     let credential = client.login(code).await?;
45///     println!("用户OpenID: {}", credential.open_id());
46///
47///     // 获取访问令牌
48///     let access_token = client.access_token().await?;
49///     println!("访问令牌: {}", access_token);
50///
51///     Ok(())
52/// }
53/// ```
54///
55/// # 令牌管理
56///
57/// 客户端自动管理访问令牌的生命周期:
58///
59/// - 令牌过期前自动刷新
60/// - 多线程环境下的安全并发访问
61/// - 避免重复刷新(令牌锁机制)
62/// - 支持强制刷新选项
63///
64/// # 线程安全
65///
66/// `Client` 实现了 `Send` 和 `Sync`,可以在多线程环境中安全使用。
67#[derive(Debug, Clone)]
68pub struct Client {
69    inner: Arc<ClientInner>,
70    access_token: Arc<RwLock<AccessToken>>,
71    refreshing: Arc<AtomicBool>,
72    notify: Arc<Notify>,
73    use_stable_token: bool,
74}
75
76impl Client {
77    /// 创建新的微信小程序客户端
78    ///
79    /// # 参数
80    ///
81    /// - `app_id`: 小程序 AppID
82    /// - `secret`: 小程序 AppSecret
83    ///
84    /// # 返回
85    ///
86    /// 新的 `Client` 实例
87    ///
88    /// # 示例
89    ///
90    /// ```
91    /// use wechat_minapp::Client;
92    ///
93    /// let client = Client::new("wx1234567890abcdef", "your_app_secret_here");
94    /// ```
95    pub fn new(app_id: &str, secret: &str) -> Self {
96        let client = reqwest::Client::new();
97
98        Self {
99            inner: Arc::new(ClientInner {
100                app_id: app_id.into(),
101                secret: secret.into(),
102                client,
103            }),
104            access_token: Arc::new(RwLock::new(AccessToken {
105                access_token: "".to_string(),
106                expired_at: Utc::now(),
107                force_refresh: None,
108            })),
109            refreshing: Arc::new(AtomicBool::new(false)),
110            notify: Arc::new(Notify::new()),
111            use_stable_token: true,
112        }
113    }
114
115    pub fn with_non_stable(app_id: &str, secret: &str) -> Self {
116        let client = reqwest::Client::new();
117
118        Self {
119            inner: Arc::new(ClientInner {
120                app_id: app_id.into(),
121                secret: secret.into(),
122                client,
123            }),
124            access_token: Arc::new(RwLock::new(AccessToken {
125                access_token: "".to_string(),
126                expired_at: Utc::now(),
127                force_refresh: None,
128            })),
129            refreshing: Arc::new(AtomicBool::new(false)),
130            notify: Arc::new(Notify::new()),
131            use_stable_token: false,
132        }
133    }
134
135    pub(crate) fn request(&self) -> &reqwest::Client {
136        &self.inner.client
137    }
138
139    /// 用户登录凭证校验
140    ///
141    /// 通过微信前端获取的临时登录凭证 code,换取用户的唯一标识 OpenID 和会话密钥。
142    ///
143    /// # 参数
144    ///
145    /// - `code`: 微信前端通过 `wx.login()` 获取的临时登录凭证
146    ///
147    /// # 返回
148    ///
149    /// 成功返回 `Ok(Credential)`,包含用户身份信息
150    ///
151    /// # 错误
152    ///
153    /// - 网络错误
154    /// - 微信 API 返回错误
155    /// - 响应解析错误
156    ///
157    /// # 示例
158    ///
159    /// ```no_run
160    /// use wechat_minapp::Client;
161    ///
162    /// #[tokio::main]
163    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
164    ///     let client = Client::new("app_id", "secret");
165    ///     let code = "0816abc123def456";
166    ///     let credential = client.login(code).await?;
167    ///
168    ///     println!("用户OpenID: {}", credential.open_id());
169    ///     println!("会话密钥: {}", credential.session_key());
170    ///     
171    ///     Ok(())
172    /// }
173    /// ```
174    ///
175    /// # API 文档
176    ///
177    /// [微信官方文档 - code2Session](https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/user-login/code2Session.html)
178    #[instrument(skip(self, code))]
179    pub async fn login(&self, code: &str) -> Result<Credential> {
180        debug!("code: {}", code);
181
182        let mut map: HashMap<&str, &str> = HashMap::new();
183
184        map.insert("appid", &self.inner.app_id);
185        map.insert("secret", &self.inner.secret);
186        map.insert("js_code", code);
187        map.insert("grant_type", "authorization_code");
188
189        let response = self
190            .inner
191            .client
192            .get(constants::AUTHENTICATION_END_POINT)
193            .query(&map)
194            .send()
195            .await?;
196
197        debug!("authentication response: {:#?}", response);
198
199        if response.status().is_success() {
200            let response = response.json::<Response<CredentialBuilder>>().await?;
201
202            let credential = response.extract()?.build();
203
204            debug!("credential: {:#?}", credential);
205
206            Ok(credential)
207        } else {
208            Err(InternalServer(response.text().await?))
209        }
210    }
211
212    pub async fn token(&self) -> Result<String> {
213        if self.use_stable_token {
214            self.stable_access_token(None).await
215        } else {
216            self.access_token().await
217        }
218    }
219
220    /// 获取访问令牌
221    ///
222    /// 获取用于调用微信小程序接口的访问令牌。如果当前令牌已过期或即将过期,会自动刷新。
223    ///
224    /// # 返回
225    ///
226    /// 成功返回 `Ok(String)`,包含有效的访问令牌
227    ///
228    /// # 错误
229    ///
230    /// - 网络错误
231    /// - 微信 API 返回错误
232    /// - 令牌刷新失败
233    ///
234    /// # 示例
235    ///
236    /// ```no_run
237    /// use wechat_minapp::Client;
238    ///
239    /// #[tokio::main]
240    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
241    ///     let client = Client::new("app_id", "secret");
242    ///     let access_token = client.access_token().await?;
243    ///     
244    ///     println!("访问令牌: {}", access_token);
245    ///     Ok(())
246    /// }
247    /// ```
248    ///
249    /// # 注意
250    ///
251    /// - 令牌有效期为 2 小时
252    /// - 客户端会自动管理令牌刷新,无需手动处理
253    /// - 多线程环境下安全
254    pub async fn access_token(&self) -> Result<String> {
255        // 第一次检查:快速路径
256        {
257            let guard = self.access_token.read().await;
258            if !is_token_expired(&guard) {
259                return Ok(guard.access_token.clone());
260            }
261        }
262
263        // 使用CAS竞争刷新权
264        if self
265            .refreshing
266            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
267            .is_ok()
268        {
269            // 获得刷新权
270            match self.refresh_access_token().await {
271                Ok(token) => {
272                    self.refreshing.store(false, Ordering::Release);
273                    self.notify.notify_waiters();
274                    Ok(token)
275                }
276                Err(e) => {
277                    self.refreshing.store(false, Ordering::Release);
278                    self.notify.notify_waiters();
279                    Err(e)
280                }
281            }
282        } else {
283            // 等待其他线程刷新完成
284            self.notify.notified().await;
285            // 刷新完成后重新读取
286            let guard = self.access_token.read().await;
287            Ok(guard.access_token.clone())
288        }
289    }
290
291    async fn refresh_access_token(&self) -> Result<String> {
292        let mut guard = self.access_token.write().await;
293
294        if !is_token_expired(&guard) {
295            debug!("token already refreshed by another thread");
296            return Ok(guard.access_token.clone());
297        }
298
299        debug!("performing network request to refresh token");
300
301        let builder = get_access_token(
302            self.inner.client.clone(),
303            &self.inner.app_id,
304            &self.inner.secret,
305        )
306        .await?;
307
308        guard.access_token = builder.access_token.clone();
309        guard.expired_at = builder.expired_at;
310
311        debug!("fresh access token: {:#?}", guard);
312
313        Ok(guard.access_token.clone())
314    }
315
316    /// 获取稳定版访问令牌
317    ///
318    /// 获取稳定版的访问令牌,相比普通令牌有更长的有效期和更好的稳定性。
319    ///
320    /// # 参数
321    ///
322    /// - `force_refresh`: 是否强制刷新令牌
323    ///   - `Some(true)`: 强制从微信服务器获取最新令牌
324    ///   - `Some(false)` 或 `None`: 仅在令牌过期时刷新
325    ///
326    /// # 返回
327    ///
328    /// 成功返回 `Ok(String)`,包含有效的稳定版访问令牌
329    ///
330    /// # 错误
331    ///
332    /// - 网络错误
333    /// - 微信 API 返回错误
334    /// - 令牌刷新失败
335    ///
336    /// # 示例
337    ///
338    /// ```no_run
339    /// use wechat_minapp::Client;
340    ///
341    /// #[tokio::main]
342    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
343    ///     let client = Client::new("app_id", "secret");
344    ///     
345    ///     // 仅在过期时刷新
346    ///     let token1 = client.stable_access_token(None).await?;
347    ///     
348    ///     // 强制刷新
349    ///     let token2 = client.stable_access_token(true).await?;
350    ///     
351    ///     Ok(())
352    /// }
353    /// ```
354    ///
355    /// # 注意
356    ///
357    /// - 稳定版令牌有效期更长,推荐在生产环境使用
358    /// - 强制刷新会忽略本地缓存,直接请求新令牌
359    pub async fn stable_access_token(
360        &self,
361        force_refresh: impl Into<Option<bool>> + Clone + Send,
362    ) -> Result<String> {
363        // 第一次检查:快速路径
364        {
365            let guard = self.access_token.read().await;
366            if !is_token_expired(&guard) {
367                return Ok(guard.access_token.clone());
368            }
369        }
370
371        // 使用CAS竞争刷新权
372        if self
373            .refreshing
374            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
375            .is_ok()
376        {
377            // 获得刷新权
378            match self.refresh_stable_access_token(force_refresh).await {
379                Ok(token) => {
380                    self.refreshing.store(false, Ordering::Release);
381                    self.notify.notify_waiters();
382                    Ok(token)
383                }
384                Err(e) => {
385                    self.refreshing.store(false, Ordering::Release);
386                    self.notify.notify_waiters();
387                    Err(e)
388                }
389            }
390        } else {
391            // 等待其他线程刷新完成
392            self.notify.notified().await;
393            // 刷新完成后重新读取
394            let guard = self.access_token.read().await;
395            Ok(guard.access_token.clone())
396        }
397    }
398
399    async fn refresh_stable_access_token(
400        &self,
401        force_refresh: impl Into<Option<bool>> + Clone + Send,
402    ) -> Result<String> {
403        // 1. Acquire the write lock. This blocks if another thread won CAS but is refreshing.
404        let mut guard = self.access_token.write().await;
405
406        // 2. Double-check expiration under the write lock (CRITICAL)
407        // If another CAS-winner refreshed the token while we were waiting for the write lock,
408        // we return the new token without performing a new network call.
409        if !is_token_expired(&guard) {
410            // Token is now fresh, return it
411            debug!("token already refreshed by another thread");
412            return Ok(guard.access_token.clone());
413        }
414
415        // 3. Perform the network request since the token is still stale
416        debug!("performing network request to refresh token");
417
418        let builder = get_stable_access_token(
419            self.inner.client.clone(),
420            &self.inner.app_id,
421            &self.inner.secret,
422            force_refresh,
423        )
424        .await?;
425
426        // 4. Update the token
427        guard.access_token = builder.access_token.clone();
428        guard.expired_at = builder.expired_at;
429
430        debug!("fresh access token: {:#?}", guard);
431
432        // Return the newly fetched token (cloned here for consistency)
433        Ok(guard.access_token.clone())
434    }
435}
436
437#[derive(Debug)]
438struct ClientInner {
439    app_id: String,
440    secret: String,
441    client: reqwest::Client,
442}
443
444/// 检查令牌是否过期
445///
446/// 添加安全边界,在令牌过期前5分钟就认为需要刷新
447fn is_token_expired(token: &AccessToken) -> bool {
448    // 添加安全边界,提前刷新
449    let now = Utc::now();
450    token.expired_at.signed_duration_since(now) < Duration::minutes(5)
451}