open_wechat/
credential.rs

1use std::{
2    collections::HashMap,
3    sync::{
4        atomic::{AtomicBool, Ordering},
5        Arc,
6    },
7};
8
9use aes::{
10    cipher::{block_padding::Pkcs7, generic_array::GenericArray, BlockDecryptMut, KeyIvInit},
11    Aes128,
12};
13use async_trait::async_trait;
14use base64::{engine::general_purpose::STANDARD, Engine};
15use cbc::Decryptor;
16use chrono::{DateTime, Duration, Utc};
17use hex::encode;
18use hmac::{Hmac, Mac};
19use serde::{Deserialize, Deserializer, Serialize};
20use serde_json::from_slice;
21use sha2::Sha256;
22use tokio::sync::{Notify, RwLock};
23use tracing::{event, instrument, Level};
24
25use crate::{
26    client::Client,
27    error::Error::InternalServer,
28    response::Response,
29    user::{User, UserBuilder},
30    Result,
31};
32
33type Aes128CbcDec = Decryptor<Aes128>;
34
35#[derive(Serialize, Deserialize, Clone)]
36pub struct Credential {
37    open_id: String,
38    session_key: String,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    union_id: Option<String>,
41}
42
43impl Credential {
44    pub fn open_id(&self) -> &str {
45        &self.open_id
46    }
47
48    pub fn session_key(&self) -> &str {
49        &self.session_key
50    }
51
52    pub fn union_id(&self) -> Option<&str> {
53        self.union_id.as_deref()
54    }
55
56    /// 解密用户数据,使用的是 AES-128-CBC 算法,数据采用PKCS#7填充。
57    /// https://developers.weixin.qq.com/miniprogram/dev/framework/open-ability/signature.html
58    /// ```rust
59    /// use axum::{extract::State, response::IntoResponse, Json};
60    /// use open_wechat::{client::Client, Result};
61    /// use serde::Deserialize;
62    ///
63    /// #[derive(Deserialize, Default)]
64    /// pub(crate) struct EncryptedPayload {
65    ///     code: String,
66    ///     encrypted_data: String,
67    ///     iv: String,
68    /// }
69    ///
70    /// pub(crate) async fn decrypt(
71    ///     State(client): State<Client>,
72    ///     Json(payload): Json<EncryptedPayload>,
73    /// ) -> Result<impl IntoResponse> {
74    ///     let credential = client.login(&payload.code).await?;
75    ///
76    ///     let user = credential.decrypt(&payload.encrypted_data, &payload.iv)?;
77    ///
78    ///     Ok(())
79    /// }
80    /// ```
81    #[instrument(skip(self, encrypted_data, iv))]
82    pub fn decrypt(&self, encrypted_data: &str, iv: &str) -> Result<User> {
83        event!(Level::DEBUG, "encrypted_data: {}", encrypted_data);
84        event!(Level::DEBUG, "iv: {}", iv);
85
86        let key = STANDARD.decode(self.session_key.as_bytes())?;
87        let iv = STANDARD.decode(iv.as_bytes())?;
88
89        let decryptor = Aes128CbcDec::new(
90            &GenericArray::clone_from_slice(&key),
91            &GenericArray::clone_from_slice(&iv),
92        );
93
94        let encrypted_data = STANDARD.decode(encrypted_data.as_bytes())?;
95
96        let buffer = decryptor.decrypt_padded_vec_mut::<Pkcs7>(&encrypted_data)?;
97
98        let builder = from_slice::<UserBuilder>(&buffer)?;
99
100        event!(Level::DEBUG, "user builder: {:#?}", builder);
101
102        Ok(builder.build())
103    }
104}
105
106impl std::fmt::Debug for Credential {
107    // 为了安全,不打印 session_key
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109        f.debug_struct("Credential")
110            .field("open_id", &self.open_id)
111            .field("session_key", &"********")
112            .field("union_id", &self.union_id)
113            .finish()
114    }
115}
116
117#[derive(Deserialize)]
118pub(crate) struct CredentialBuilder {
119    #[serde(rename = "openid")]
120    open_id: String,
121    session_key: String,
122    #[serde(rename = "unionid")]
123    union_id: Option<String>,
124}
125
126impl CredentialBuilder {
127    pub(crate) fn build(self) -> Credential {
128        Credential {
129            open_id: self.open_id,
130            session_key: self.session_key,
131            union_id: self.union_id,
132        }
133    }
134}
135
136impl std::fmt::Debug for CredentialBuilder {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        f.debug_struct("CredentialBuilder")
139            .field("open_id", &self.open_id)
140            .field("session_key", &"********")
141            .field("union_id", &self.union_id)
142            .finish()
143    }
144}
145
146#[derive(Clone)]
147pub struct AccessToken {
148    access_token: String,
149    expired_at: DateTime<Utc>,
150}
151
152impl std::fmt::Debug for AccessToken {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        f.debug_struct("AccessToken")
155            .field("access_token", &"********")
156            .field("expired_at", &self.expired_at)
157            .finish()
158    }
159}
160
161#[derive(Clone)]
162pub struct StableAccessToken {
163    access_token: String,
164    expired_at: DateTime<Utc>,
165    force_refresh: Option<bool>,
166}
167
168impl std::fmt::Debug for StableAccessToken {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        f.debug_struct("StableAccessToken")
171            .field("access_token", &"********")
172            .field("expired_at", &self.expired_at)
173            .field("force_refresh", &self.force_refresh)
174            .finish()
175    }
176}
177
178#[derive(Debug, Clone)]
179pub struct GenericAccessToken<T = AccessToken> {
180    inner: Arc<RwLock<T>>,
181    refreshing: Arc<AtomicBool>,
182    notify: Arc<Notify>,
183    client: Client,
184}
185
186#[async_trait]
187pub trait GetAccessToken {
188    async fn new(client: Client) -> Result<Self>
189    where
190        Self: Sized;
191
192    async fn access_token(&self) -> Result<String>;
193}
194
195#[async_trait]
196impl GetAccessToken for GenericAccessToken<AccessToken> {
197    /// ```ignore
198    /// use open_wechat::{
199    ///     client::Client,
200    ///     credential::{GenericAccessToken, GetAccessToken}
201    /// };
202    ///
203    /// #[tokio::main]
204    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
205    ///     let app_id = "your app id";
206    ///     let app_secret = "your app secret";
207    ///    
208    ///     let client = Client::new(app_id, app_secret);
209    ///     
210    ///     let access_token = GenericAccessToken::new(client.clone()).await?;
211    ///     
212    ///     Ok(())
213    /// }
214    /// ```
215    async fn new(client: Client) -> Result<Self> {
216        let builder = client.get_access_token().await?;
217
218        Ok(Self {
219            inner: Arc::new(RwLock::new(AccessToken {
220                access_token: builder.access_token,
221                expired_at: builder.expired_at,
222            })),
223            refreshing: Arc::new(AtomicBool::new(false)),
224            notify: Arc::new(Notify::new()),
225            client,
226        })
227    }
228
229    async fn access_token(&self) -> Result<String> {
230        event!(Level::DEBUG, "read access token guard");
231
232        let guard = self.inner.read().await;
233
234        if guard.expired_at <= Utc::now() {
235            event!(Level::DEBUG, "expired at: {}", guard.expired_at);
236
237            if self.refreshing.load(Ordering::Acquire) {
238                event!(Level::DEBUG, "refreshing");
239
240                self.notify.notified().await;
241            } else {
242                event!(Level::DEBUG, "prepare to fresh");
243
244                self.refreshing.store(true, Ordering::Release);
245
246                drop(guard);
247
248                event!(Level::DEBUG, "write access token guard");
249
250                let mut guard = self.inner.write().await;
251
252                let builder = self.client.get_access_token().await?;
253
254                guard.access_token = builder.access_token;
255                guard.expired_at = builder.expired_at;
256
257                self.refreshing.store(false, Ordering::Release);
258
259                self.notify.notify_waiters();
260
261                event!(Level::DEBUG, "fresh access token: {:#?}", guard);
262
263                return Ok(guard.access_token.clone());
264            }
265        }
266
267        event!(Level::DEBUG, "access token not expired");
268
269        Ok(guard.access_token.clone())
270    }
271}
272
273#[async_trait]
274pub trait GetStableAccessToken {
275    async fn new(
276        client: Client,
277        force_refresh: impl Into<Option<bool>> + Clone + Send,
278    ) -> Result<Self>
279    where
280        Self: Sized;
281
282    async fn access_token(&self) -> Result<String>;
283
284    async fn set_force_refresh(&self, force_refresh: bool) -> Result<()>;
285}
286
287#[async_trait]
288impl GetStableAccessToken for GenericAccessToken<StableAccessToken> {
289    /// ```ignore
290    /// use open_wechat::{
291    ///     client::Client,
292    ///     credential::{GenericAccessToken, GetStableAccessToken}
293    /// };
294    ///
295    /// #[tokio::main]
296    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
297    ///     let app_id = "your app id";
298    ///     let app_secret = "your app secret";
299    ///
300    ///     let client = Client::new(app_id, app_secret);
301    ///
302    ///     let stable_access_token = GenericAccessToken::new(client.clone(), None).await?;
303    ///
304    ///     Ok(())
305    /// }
306    /// ```
307    async fn new(
308        client: Client,
309        force_refresh: impl Into<Option<bool>> + Clone + Send,
310    ) -> Result<Self> {
311        let builder = client
312            .get_stable_access_token(force_refresh.clone())
313            .await?;
314
315        Ok(Self {
316            inner: Arc::new(RwLock::new(StableAccessToken {
317                access_token: builder.access_token,
318                expired_at: builder.expired_at,
319                force_refresh: force_refresh.into(),
320            })),
321            refreshing: Arc::new(AtomicBool::new(false)),
322            notify: Arc::new(Notify::new()),
323            client,
324        })
325    }
326
327    async fn access_token(&self) -> Result<String> {
328        event!(Level::DEBUG, "read stable access token guard");
329
330        let guard = self.inner.read().await;
331
332        if guard.expired_at <= Utc::now() {
333            event!(Level::DEBUG, "expired at: {}", guard.expired_at);
334
335            if self.refreshing.load(Ordering::Acquire) {
336                event!(Level::DEBUG, "refreshing");
337
338                self.notify.notified().await;
339            } else {
340                event!(Level::DEBUG, "prepare to fresh");
341
342                self.refreshing.store(true, Ordering::Release);
343
344                drop(guard);
345
346                event!(Level::DEBUG, "write stable access token guard");
347
348                let mut guard = self.inner.write().await;
349
350                let builder = self
351                    .client
352                    .get_stable_access_token(guard.force_refresh)
353                    .await?;
354
355                guard.access_token = builder.access_token;
356                guard.expired_at = builder.expired_at;
357
358                self.refreshing.store(false, Ordering::Release);
359
360                self.notify.notify_waiters();
361
362                event!(Level::DEBUG, "fresh stable access token: {:#?}", guard);
363
364                return Ok(guard.access_token.clone());
365            }
366        }
367
368        event!(Level::DEBUG, "stable access token not expired");
369
370        Ok(guard.access_token.clone())
371    }
372
373    async fn set_force_refresh(&self, force_refresh: bool) -> Result<()> {
374        let mut guard = self.inner.write().await;
375
376        guard.force_refresh = Some(force_refresh);
377
378        Ok(())
379    }
380}
381
382#[derive(Deserialize)]
383pub(crate) struct AccessTokenBuilder {
384    access_token: String,
385    #[serde(
386        deserialize_with = "AccessTokenBuilder::deserialize_expired_at",
387        rename = "expires_in"
388    )]
389    expired_at: DateTime<Utc>,
390}
391
392impl AccessTokenBuilder {
393    fn deserialize_expired_at<'de, D>(
394        deserializer: D,
395    ) -> std::result::Result<DateTime<Utc>, D::Error>
396    where
397        D: Deserializer<'de>,
398    {
399        let seconds = Duration::seconds(i64::deserialize(deserializer)?);
400
401        Ok(Utc::now() + seconds)
402    }
403}
404
405impl std::fmt::Debug for AccessTokenBuilder {
406    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407        f.debug_struct("AccessTokenBuilder")
408            .field("access_token", &"********")
409            .field("expired_at", &self.expired_at)
410            .finish()
411    }
412}
413
414#[async_trait]
415pub trait CheckSessionKey {
416    const CHECK_SESSION_KEY: &'static str = "https://api.weixin.qq.com/wxa/checksession";
417
418    /// 检查登录态是否过期
419    /// https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/user-login/checkSessionKey.html
420    async fn check_session_key(&self, session_key: &str, open_id: &str) -> Result<()>;
421}
422
423type HmacSha256 = Hmac<Sha256>;
424
425#[async_trait]
426impl CheckSessionKey for GenericAccessToken<AccessToken> {
427    #[instrument(skip(self, session_key, open_id))]
428    async fn check_session_key(&self, session_key: &str, open_id: &str) -> Result<()> {
429        let mut mac = HmacSha256::new_from_slice(session_key.as_bytes())?;
430        mac.update(b"");
431        let hasher = mac.finalize();
432        let signature = encode(hasher.into_bytes());
433
434        let mut map = HashMap::new();
435
436        map.insert("openid", open_id.to_string());
437        map.insert("signature", signature);
438        map.insert("sig_method", "hmac_sha256".into());
439
440        let response = self
441            .client
442            .request()
443            .get(Self::CHECK_SESSION_KEY)
444            .query(&map)
445            .send()
446            .await?;
447
448        event!(Level::DEBUG, "response: {:#?}", response);
449
450        if response.status().is_success() {
451            let response = response.json::<Response<()>>().await?;
452
453            response.extract()
454        } else {
455            Err(crate::error::Error::InternalServer(response.text().await?))
456        }
457    }
458}
459
460#[async_trait]
461impl CheckSessionKey for GenericAccessToken<StableAccessToken> {
462    #[instrument(skip(self, session_key, open_id))]
463    async fn check_session_key(&self, session_key: &str, open_id: &str) -> Result<()> {
464        let mut mac = HmacSha256::new_from_slice(session_key.as_bytes())?;
465        mac.update(b"");
466        let hasher = mac.finalize();
467        let signature = encode(hasher.into_bytes());
468
469        let mut map = HashMap::new();
470
471        map.insert("openid", open_id.to_string());
472        map.insert("signature", signature);
473        map.insert("sig_method", "hmac_sha256".into());
474
475        let response = self
476            .client
477            .request()
478            .get(Self::CHECK_SESSION_KEY)
479            .query(&map)
480            .send()
481            .await?;
482
483        event!(Level::DEBUG, "response: {:#?}", response);
484
485        if response.status().is_success() {
486            let response = response.json::<Response<()>>().await?;
487
488            response.extract()
489        } else {
490            Err(InternalServer(response.text().await?))
491        }
492    }
493}
494
495#[async_trait]
496pub trait ResetSessionKey {
497    const RESET_SESSION_KEY: &'static str = "https://api.weixin.qq.com/wxa/resetusersessionkey";
498
499    /// 重置用户的 session_key
500    /// https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/user-login/ResetUserSessionKey.html
501    async fn reset_session_key(&self, session_key: &str, open_id: &str) -> Result<Credential>;
502}
503
504#[async_trait]
505impl ResetSessionKey for GenericAccessToken<AccessToken> {
506    #[instrument(skip(self, open_id))]
507    async fn reset_session_key(&self, session_key: &str, open_id: &str) -> Result<Credential> {
508        let mut mac = HmacSha256::new_from_slice(session_key.as_bytes())?;
509        mac.update(b"");
510        let hasher = mac.finalize();
511        let signature = encode(hasher.into_bytes());
512
513        let mut map = HashMap::new();
514
515        map.insert("access_token", self.access_token().await?);
516        map.insert("openid", open_id.to_string());
517        map.insert("signature", signature);
518        map.insert("sig_method", "hmac_sha256".into());
519
520        let response = self
521            .client
522            .request()
523            .get(Self::RESET_SESSION_KEY)
524            .query(&map)
525            .send()
526            .await?;
527
528        event!(Level::DEBUG, "response: {:#?}", response);
529
530        if response.status().is_success() {
531            let response = response.json::<Response<CredentialBuilder>>().await?;
532
533            let credential = response.extract()?.build();
534
535            event!(Level::DEBUG, "credential: {:#?}", credential);
536
537            Ok(credential)
538        } else {
539            Err(InternalServer(response.text().await?))
540        }
541    }
542}
543
544#[async_trait]
545impl ResetSessionKey for GenericAccessToken<StableAccessToken> {
546    #[instrument(skip(self, open_id))]
547    async fn reset_session_key(&self, session_key: &str, open_id: &str) -> Result<Credential> {
548        let mut mac = HmacSha256::new_from_slice(session_key.as_bytes())?;
549        mac.update(b"");
550        let hasher = mac.finalize();
551        let signature = encode(hasher.into_bytes());
552
553        let mut map = HashMap::new();
554
555        map.insert("access_token", self.access_token().await?);
556        map.insert("openid", open_id.to_string());
557        map.insert("signature", signature);
558        map.insert("sig_method", "hmac_sha256".into());
559
560        let response = self
561            .client
562            .request()
563            .get(Self::RESET_SESSION_KEY)
564            .query(&map)
565            .send()
566            .await?;
567
568        event!(Level::DEBUG, "response: {:#?}", response);
569
570        if response.status().is_success() {
571            let response = response.json::<Response<CredentialBuilder>>().await?;
572
573            let credential = response.extract()?.build();
574
575            event!(Level::DEBUG, "credential: {:#?}", credential);
576
577            Ok(credential)
578        } else {
579            Err(InternalServer(response.text().await?))
580        }
581    }
582}