// wx_sdk/access_token.rs

//! The [access_token](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Get_access_token.html) related module.
//!
//! This module provides the [AccessTokenProvider] trait, whose [get_access_token][AccessTokenProvider::get_access_token] method returns an [AccessToken].
//!
//! A default [TokenClient] (built on the [reqwest](https://crates.io/crates/reqwest) crate) is also provided for users who don't want to implement one themselves.
6use crate::{
7    cache,
8    error::{CommonResponse, SdkError, SdkResult},
9};
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use std::sync::{Arc, RwLock};
13use std::time::Duration;
14
15/// [WxSdk][crate::wechat::WxSdk] take a struct which impl [AccessTokenProvider].
16/// You need to use [async_trait](https://crates.io/crates/async-trait) to implement [AccessTokenProvider].
17#[async_trait]
18pub trait AccessTokenProvider: Sync + Send + Sized + Clone {
19    /// This trait derive [async_trait](https://crates.io/crates/async-trait), it return a [std::future] of [AccessToken].
20    async fn get_access_token(&self) -> SdkResult<AccessToken>;
21}
22
23/// Access token with a expires time.
24#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
25pub struct AccessToken {
26    pub access_token: String,
27    pub expires_in: i32,
28}
29
30impl From<cache::Item<String>> for AccessToken {
31    fn from(c: cache::Item<String>) -> Self {
32        AccessToken {
33            access_token: c.object,
34            expires_in: 0,
35        }
36    }
37}
38
39impl From<AccessToken> for cache::Item<String> {
40    fn from(t: AccessToken) -> Self {
41        let duration = Duration::from_secs((t.expires_in - 5) as u64);
42        cache::Item::new(t.access_token, Some(duration))
43    }
44}
45
46/// That's a default token client implement [AccessTokenProvider].
47#[derive(Clone)]
48pub struct TokenClient {
49    app_id: String,
50    app_secret: String,
51    cache_token: Arc<RwLock<Option<cache::Item<String>>>>,
52}
53
54impl TokenClient {
55    pub fn new(app_id: String, app_secret: String) -> Self {
56        TokenClient {
57            app_id,
58            app_secret,
59            cache_token: Arc::new(RwLock::new(None)),
60        }
61    }
62
63    fn get_cache_token(&self) -> Option<AccessToken> {
64        let locked = self.cache_token.read().unwrap();
65        match &*locked {
66            Some(i) if !i.expired() => Some(i.clone().into()),
67            _ => None,
68        }
69    }
70
71    fn set_cache_token(&self, token: AccessToken) {
72        let mut locked = self.cache_token.write().unwrap();
73        *locked = Some(token.into())
74    }
75}
76
77#[async_trait]
78impl AccessTokenProvider for TokenClient {
79    async fn get_access_token(&self) -> SdkResult<AccessToken> {
80        let url = format!(
81            "https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid={}&secret={}",
82            self.app_id.clone(),
83            self.app_secret.clone()
84        );
85        let cache_token = self.get_cache_token();
86        match cache_token {
87            Some(token) => Ok(token),
88            None => {
89                let msg = reqwest::get(&url)
90                    .await?
91                    .json::<CommonResponse<AccessToken>>()
92                    .await?;
93
94                match msg {
95                    CommonResponse::Ok(at) => {
96                        self.set_cache_token(at.clone());
97                        Ok(at)
98                    }
99                    CommonResponse::Err(e) => Err(SdkError::AccessTokenError(e)),
100                }
101            }
102        }
103    }
104}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::sleep;

    use crate::{
        access_token::AccessTokenProvider, cache, error::CommonResponse, AccessToken, TokenClient,
    };

    /// Both the success and the error payload of the token endpoint
    /// deserialize into the matching `CommonResponse` variant.
    #[test]
    fn test_deserialize_response() {
        let input = r#"{"access_token":"ACCESS_TOKEN","expires_in":7200}"#;
        let expected = CommonResponse::Ok(AccessToken {
            access_token: "ACCESS_TOKEN".to_string(),
            expires_in: 7200,
        });
        assert_eq!(expected, serde_json::from_str(input).unwrap());

        let input = r#"{"errcode":40013,"errmsg":"invalid appid"}"#;
        let expected = CommonResponse::<AccessToken>::Err(crate::error::CommonError {
            errcode: 40013,
            errmsg: "invalid appid".to_string(),
        });
        assert_eq!(expected, serde_json::from_str(input).unwrap());
    }

    /// A token still within its cache lifetime is served from the cache,
    /// with `expires_in` reported as `0`, on repeated calls.
    #[tokio::test]
    async fn test_get_from_cache() {
        let token_client = TokenClient {
            app_id: "app_id".to_owned(),
            app_secret: "secret".to_owned(),
            cache_token: std::sync::Arc::new(std::sync::RwLock::new(Some(cache::Item::new(
                "ACCESS_TOKEN".to_owned(),
                Some(Duration::from_secs(60)),
            )))),
        };
        // Let some time pass; the 60 s lifetime leaves ample headroom so the
        // test never falls through to a real network request on a slow machine.
        sleep(Duration::from_millis(100)).await;

        let expected = AccessToken {
            access_token: "ACCESS_TOKEN".to_owned(),
            expires_in: 0,
        };
        let first = token_client.get_access_token().await.unwrap();
        assert_eq!(first, expected);
        let second = token_client.get_access_token().await.unwrap();
        assert_eq!(second, expected);
    }
}