1use std::{
2 collections::HashMap,
3 sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc,
6 },
7};
8
9use aes::{
10 cipher::{block_padding::Pkcs7, generic_array::GenericArray, BlockDecryptMut, KeyIvInit},
11 Aes128,
12};
13use async_trait::async_trait;
14use base64::{engine::general_purpose::STANDARD, Engine};
15use cbc::Decryptor;
16use chrono::{DateTime, Duration, Utc};
17use hex::encode;
18use hmac::{Hmac, Mac};
19use serde::{Deserialize, Deserializer, Serialize};
20use serde_json::from_slice;
21use sha2::Sha256;
22use tokio::sync::{Notify, RwLock};
23use tracing::{event, instrument, Level};
24
25use crate::{
26 client::Client,
27 error::Error::InternalServer,
28 response::Response,
29 user::{User, UserBuilder},
30 Result,
31};
32
/// AES-128 in CBC mode, decryption direction; used by `Credential::decrypt`
/// (together with PKCS#7 unpadding) to open encrypted user data.
type Aes128CbcDec = Decryptor<Aes128>;
34
/// A user's `open_id`, `session_key` and optional `union_id`.
///
/// Instances are produced from a [`CredentialBuilder`] (see `CredentialBuilder::build`).
/// The serde representation uses the Rust field names; `union_id` is left out of the
/// serialized form when it is `None`.
#[derive(Serialize, Deserialize, Clone)]
pub struct Credential {
    open_id: String,
    // Secret: masked in this type's `Debug` output.
    // NOTE(review): it is still emitted by the derived `Serialize` — confirm this is intended.
    session_key: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    union_id: Option<String>,
}
42
43impl Credential {
44 pub fn open_id(&self) -> &str {
45 &self.open_id
46 }
47
48 pub fn session_key(&self) -> &str {
49 &self.session_key
50 }
51
52 pub fn union_id(&self) -> Option<&str> {
53 self.union_id.as_deref()
54 }
55
56 #[instrument(skip(self, encrypted_data, iv))]
82 pub fn decrypt(&self, encrypted_data: &str, iv: &str) -> Result<User> {
83 event!(Level::DEBUG, "encrypted_data: {}", encrypted_data);
84 event!(Level::DEBUG, "iv: {}", iv);
85
86 let key = STANDARD.decode(self.session_key.as_bytes())?;
87 let iv = STANDARD.decode(iv.as_bytes())?;
88
89 let decryptor = Aes128CbcDec::new(
90 &GenericArray::clone_from_slice(&key),
91 &GenericArray::clone_from_slice(&iv),
92 );
93
94 let encrypted_data = STANDARD.decode(encrypted_data.as_bytes())?;
95
96 let buffer = decryptor.decrypt_padded_vec_mut::<Pkcs7>(&encrypted_data)?;
97
98 let builder = from_slice::<UserBuilder>(&buffer)?;
99
100 event!(Level::DEBUG, "user builder: {:#?}", builder);
101
102 Ok(builder.build())
103 }
104}
105
106impl std::fmt::Debug for Credential {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 f.debug_struct("Credential")
110 .field("open_id", &self.open_id)
111 .field("session_key", &"********")
112 .field("union_id", &self.union_id)
113 .finish()
114 }
115}
116
/// Deserialization target for API responses carrying `openid`, `session_key` and an
/// optional `unionid` (used e.g. for the reset-session-key response); turn it into a
/// [`Credential`] with `build`.
#[derive(Deserialize)]
pub(crate) struct CredentialBuilder {
    #[serde(rename = "openid")]
    open_id: String,
    session_key: String,
    // Absent in the JSON => `None`.
    #[serde(rename = "unionid")]
    union_id: Option<String>,
}
125
126impl CredentialBuilder {
127 pub(crate) fn build(self) -> Credential {
128 Credential {
129 open_id: self.open_id,
130 session_key: self.session_key,
131 union_id: self.union_id,
132 }
133 }
134}
135
136impl std::fmt::Debug for CredentialBuilder {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 f.debug_struct("CredentialBuilder")
139 .field("open_id", &self.open_id)
140 .field("session_key", &"********")
141 .field("union_id", &self.union_id)
142 .finish()
143 }
144}
145
/// A cached access token and the instant after which it is treated as expired.
#[derive(Clone)]
pub struct AccessToken {
    // Secret: masked in this type's `Debug` output.
    access_token: String,
    // Computed as "time of deserialization + expires_in" by `AccessTokenBuilder`.
    expired_at: DateTime<Utc>,
}
151
152impl std::fmt::Debug for AccessToken {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 f.debug_struct("AccessToken")
155 .field("access_token", &"********")
156 .field("expired_at", &self.expired_at)
157 .finish()
158 }
159}
160
/// A cached stable access token, its expiry, and the `force_refresh` option used
/// whenever it is (re)fetched.
#[derive(Clone)]
pub struct StableAccessToken {
    // Secret: masked in this type's `Debug` output.
    access_token: String,
    expired_at: DateTime<Utc>,
    // Passed to `Client::get_stable_access_token` on every refresh; changeable via
    // `GetStableAccessToken::set_force_refresh`.
    force_refresh: Option<bool>,
}
167
168impl std::fmt::Debug for StableAccessToken {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 f.debug_struct("StableAccessToken")
171 .field("access_token", &"********")
172 .field("expired_at", &self.expired_at)
173 .field("force_refresh", &self.force_refresh)
174 .finish()
175 }
176}
177
/// A shared handle to a cached token of type `T` plus the [`Client`] used to refresh it.
///
/// Token state lives behind `Arc`s, so clones of a handle observe and refresh the same
/// cached token (the `client` itself is cloned by value).
#[derive(Debug, Clone)]
pub struct GenericAccessToken<T = AccessToken> {
    // The cached token; read-locked for lookups, write-locked while being replaced.
    inner: Arc<RwLock<T>>,
    // Set while some task is refreshing the token.
    refreshing: Arc<AtomicBool>,
    // Used to wake tasks that are waiting for an in-progress refresh.
    notify: Arc<Notify>,
    client: Client,
}
185
/// Construction of, and access to, a (non-stable) access token handle.
#[async_trait]
pub trait GetAccessToken {
    /// Creates a handle from `client`.
    ///
    /// NOTE: the implementation in this module fetches the initial token immediately.
    async fn new(client: Client) -> Result<Self>
    where
        Self: Sized;

    /// Returns an access token string, refreshing the cached one if needed.
    async fn access_token(&self) -> Result<String>;
}
194
195#[async_trait]
196impl GetAccessToken for GenericAccessToken<AccessToken> {
197 async fn new(client: Client) -> Result<Self> {
216 let builder = client.get_access_token().await?;
217
218 Ok(Self {
219 inner: Arc::new(RwLock::new(AccessToken {
220 access_token: builder.access_token,
221 expired_at: builder.expired_at,
222 })),
223 refreshing: Arc::new(AtomicBool::new(false)),
224 notify: Arc::new(Notify::new()),
225 client,
226 })
227 }
228
229 async fn access_token(&self) -> Result<String> {
230 event!(Level::DEBUG, "read access token guard");
231
232 let guard = self.inner.read().await;
233
234 if guard.expired_at <= Utc::now() {
235 event!(Level::DEBUG, "expired at: {}", guard.expired_at);
236
237 if self.refreshing.load(Ordering::Acquire) {
238 event!(Level::DEBUG, "refreshing");
239
240 self.notify.notified().await;
241 } else {
242 event!(Level::DEBUG, "prepare to fresh");
243
244 self.refreshing.store(true, Ordering::Release);
245
246 drop(guard);
247
248 event!(Level::DEBUG, "write access token guard");
249
250 let mut guard = self.inner.write().await;
251
252 let builder = self.client.get_access_token().await?;
253
254 guard.access_token = builder.access_token;
255 guard.expired_at = builder.expired_at;
256
257 self.refreshing.store(false, Ordering::Release);
258
259 self.notify.notify_waiters();
260
261 event!(Level::DEBUG, "fresh access token: {:#?}", guard);
262
263 return Ok(guard.access_token.clone());
264 }
265 }
266
267 event!(Level::DEBUG, "access token not expired");
268
269 Ok(guard.access_token.clone())
270 }
271}
272
/// Construction of, and access to, a stable access token handle.
#[async_trait]
pub trait GetStableAccessToken {
    /// Creates a handle from `client` and the initial `force_refresh` option.
    ///
    /// NOTE: the implementation in this module fetches the initial token immediately
    /// and keeps `force_refresh` for later refreshes.
    async fn new(
        client: Client,
        force_refresh: impl Into<Option<bool>> + Clone + Send,
    ) -> Result<Self>
    where
        Self: Sized;

    /// Returns a stable access token string, refreshing the cached one if needed.
    async fn access_token(&self) -> Result<String>;

    /// Sets the `force_refresh` option used for subsequent refreshes.
    async fn set_force_refresh(&self, force_refresh: bool) -> Result<()>;
}
286
287#[async_trait]
288impl GetStableAccessToken for GenericAccessToken<StableAccessToken> {
289 async fn new(
308 client: Client,
309 force_refresh: impl Into<Option<bool>> + Clone + Send,
310 ) -> Result<Self> {
311 let builder = client
312 .get_stable_access_token(force_refresh.clone())
313 .await?;
314
315 Ok(Self {
316 inner: Arc::new(RwLock::new(StableAccessToken {
317 access_token: builder.access_token,
318 expired_at: builder.expired_at,
319 force_refresh: force_refresh.into(),
320 })),
321 refreshing: Arc::new(AtomicBool::new(false)),
322 notify: Arc::new(Notify::new()),
323 client,
324 })
325 }
326
327 async fn access_token(&self) -> Result<String> {
328 event!(Level::DEBUG, "read stable access token guard");
329
330 let guard = self.inner.read().await;
331
332 if guard.expired_at <= Utc::now() {
333 event!(Level::DEBUG, "expired at: {}", guard.expired_at);
334
335 if self.refreshing.load(Ordering::Acquire) {
336 event!(Level::DEBUG, "refreshing");
337
338 self.notify.notified().await;
339 } else {
340 event!(Level::DEBUG, "prepare to fresh");
341
342 self.refreshing.store(true, Ordering::Release);
343
344 drop(guard);
345
346 event!(Level::DEBUG, "write stable access token guard");
347
348 let mut guard = self.inner.write().await;
349
350 let builder = self
351 .client
352 .get_stable_access_token(guard.force_refresh)
353 .await?;
354
355 guard.access_token = builder.access_token;
356 guard.expired_at = builder.expired_at;
357
358 self.refreshing.store(false, Ordering::Release);
359
360 self.notify.notify_waiters();
361
362 event!(Level::DEBUG, "fresh stable access token: {:#?}", guard);
363
364 return Ok(guard.access_token.clone());
365 }
366 }
367
368 event!(Level::DEBUG, "stable access token not expired");
369
370 Ok(guard.access_token.clone())
371 }
372
373 async fn set_force_refresh(&self, force_refresh: bool) -> Result<()> {
374 let mut guard = self.inner.write().await;
375
376 guard.force_refresh = Some(force_refresh);
377
378 Ok(())
379 }
380}
381
/// Deserialization target for access-token responses.
///
/// The response's relative `expires_in` (seconds) is converted into an absolute
/// `expired_at` instant at deserialization time.
#[derive(Deserialize)]
pub(crate) struct AccessTokenBuilder {
    access_token: String,
    #[serde(
        deserialize_with = "AccessTokenBuilder::deserialize_expired_at",
        rename = "expires_in"
    )]
    expired_at: DateTime<Utc>,
}
391
392impl AccessTokenBuilder {
393 fn deserialize_expired_at<'de, D>(
394 deserializer: D,
395 ) -> std::result::Result<DateTime<Utc>, D::Error>
396 where
397 D: Deserializer<'de>,
398 {
399 let seconds = Duration::seconds(i64::deserialize(deserializer)?);
400
401 Ok(Utc::now() + seconds)
402 }
403}
404
405impl std::fmt::Debug for AccessTokenBuilder {
406 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407 f.debug_struct("AccessTokenBuilder")
408 .field("access_token", &"********")
409 .field("expired_at", &self.expired_at)
410 .finish()
411 }
412}
413
/// Checking whether a user's session key is still valid.
#[async_trait]
pub trait CheckSessionKey {
    /// Endpoint URL queried by `check_session_key`.
    const CHECK_SESSION_KEY: &'static str = "https://api.weixin.qq.com/wxa/checksession";

    /// Checks `session_key` for the user identified by `open_id`; `Ok(())` means the
    /// check succeeded.
    async fn check_session_key(&self, session_key: &str, open_id: &str) -> Result<()>;
}
422
/// HMAC-SHA256; used to sign the empty string with a session key for the
/// session-key check/reset requests.
type HmacSha256 = Hmac<Sha256>;
424
425#[async_trait]
426impl CheckSessionKey for GenericAccessToken<AccessToken> {
427 #[instrument(skip(self, session_key, open_id))]
428 async fn check_session_key(&self, session_key: &str, open_id: &str) -> Result<()> {
429 let mut mac = HmacSha256::new_from_slice(session_key.as_bytes())?;
430 mac.update(b"");
431 let hasher = mac.finalize();
432 let signature = encode(hasher.into_bytes());
433
434 let mut map = HashMap::new();
435
436 map.insert("openid", open_id.to_string());
437 map.insert("signature", signature);
438 map.insert("sig_method", "hmac_sha256".into());
439
440 let response = self
441 .client
442 .request()
443 .get(Self::CHECK_SESSION_KEY)
444 .query(&map)
445 .send()
446 .await?;
447
448 event!(Level::DEBUG, "response: {:#?}", response);
449
450 if response.status().is_success() {
451 let response = response.json::<Response<()>>().await?;
452
453 response.extract()
454 } else {
455 Err(crate::error::Error::InternalServer(response.text().await?))
456 }
457 }
458}
459
460#[async_trait]
461impl CheckSessionKey for GenericAccessToken<StableAccessToken> {
462 #[instrument(skip(self, session_key, open_id))]
463 async fn check_session_key(&self, session_key: &str, open_id: &str) -> Result<()> {
464 let mut mac = HmacSha256::new_from_slice(session_key.as_bytes())?;
465 mac.update(b"");
466 let hasher = mac.finalize();
467 let signature = encode(hasher.into_bytes());
468
469 let mut map = HashMap::new();
470
471 map.insert("openid", open_id.to_string());
472 map.insert("signature", signature);
473 map.insert("sig_method", "hmac_sha256".into());
474
475 let response = self
476 .client
477 .request()
478 .get(Self::CHECK_SESSION_KEY)
479 .query(&map)
480 .send()
481 .await?;
482
483 event!(Level::DEBUG, "response: {:#?}", response);
484
485 if response.status().is_success() {
486 let response = response.json::<Response<()>>().await?;
487
488 response.extract()
489 } else {
490 Err(InternalServer(response.text().await?))
491 }
492 }
493}
494
/// Resetting a user's session key.
#[async_trait]
pub trait ResetSessionKey {
    /// Endpoint URL queried by `reset_session_key`.
    const RESET_SESSION_KEY: &'static str = "https://api.weixin.qq.com/wxa/resetusersessionkey";

    /// Requests a new session key for `open_id`, signing the request with the current
    /// `session_key`, and returns the resulting [`Credential`].
    async fn reset_session_key(&self, session_key: &str, open_id: &str) -> Result<Credential>;
}
503
504#[async_trait]
505impl ResetSessionKey for GenericAccessToken<AccessToken> {
506 #[instrument(skip(self, open_id))]
507 async fn reset_session_key(&self, session_key: &str, open_id: &str) -> Result<Credential> {
508 let mut mac = HmacSha256::new_from_slice(session_key.as_bytes())?;
509 mac.update(b"");
510 let hasher = mac.finalize();
511 let signature = encode(hasher.into_bytes());
512
513 let mut map = HashMap::new();
514
515 map.insert("access_token", self.access_token().await?);
516 map.insert("openid", open_id.to_string());
517 map.insert("signature", signature);
518 map.insert("sig_method", "hmac_sha256".into());
519
520 let response = self
521 .client
522 .request()
523 .get(Self::RESET_SESSION_KEY)
524 .query(&map)
525 .send()
526 .await?;
527
528 event!(Level::DEBUG, "response: {:#?}", response);
529
530 if response.status().is_success() {
531 let response = response.json::<Response<CredentialBuilder>>().await?;
532
533 let credential = response.extract()?.build();
534
535 event!(Level::DEBUG, "credential: {:#?}", credential);
536
537 Ok(credential)
538 } else {
539 Err(InternalServer(response.text().await?))
540 }
541 }
542}
543
544#[async_trait]
545impl ResetSessionKey for GenericAccessToken<StableAccessToken> {
546 #[instrument(skip(self, open_id))]
547 async fn reset_session_key(&self, session_key: &str, open_id: &str) -> Result<Credential> {
548 let mut mac = HmacSha256::new_from_slice(session_key.as_bytes())?;
549 mac.update(b"");
550 let hasher = mac.finalize();
551 let signature = encode(hasher.into_bytes());
552
553 let mut map = HashMap::new();
554
555 map.insert("access_token", self.access_token().await?);
556 map.insert("openid", open_id.to_string());
557 map.insert("signature", signature);
558 map.insert("sig_method", "hmac_sha256".into());
559
560 let response = self
561 .client
562 .request()
563 .get(Self::RESET_SESSION_KEY)
564 .query(&map)
565 .send()
566 .await?;
567
568 event!(Level::DEBUG, "response: {:#?}", response);
569
570 if response.status().is_success() {
571 let response = response.json::<Response<CredentialBuilder>>().await?;
572
573 let credential = response.extract()?.build();
574
575 event!(Level::DEBUG, "credential: {:#?}", credential);
576
577 Ok(credential)
578 } else {
579 Err(InternalServer(response.text().await?))
580 }
581 }
582}