1use crate::{
7 cache,
8 error::{CommonResponse, SdkError, SdkResult},
9};
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use std::sync::{Arc, RwLock};
13use std::time::Duration;
14
15#[async_trait]
18pub trait AccessTokenProvider: Sync + Send + Sized + Clone {
19 async fn get_access_token(&self) -> SdkResult<AccessToken>;
21}
22
23#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
25pub struct AccessToken {
26 pub access_token: String,
27 pub expires_in: i32,
28}
29
30impl From<cache::Item<String>> for AccessToken {
31 fn from(c: cache::Item<String>) -> Self {
32 AccessToken {
33 access_token: c.object,
34 expires_in: 0,
35 }
36 }
37}
38
39impl From<AccessToken> for cache::Item<String> {
40 fn from(t: AccessToken) -> Self {
41 let duration = Duration::from_secs((t.expires_in - 5) as u64);
42 cache::Item::new(t.access_token, Some(duration))
43 }
44}
45
46#[derive(Clone)]
48pub struct TokenClient {
49 app_id: String,
50 app_secret: String,
51 cache_token: Arc<RwLock<Option<cache::Item<String>>>>,
52}
53
54impl TokenClient {
55 pub fn new(app_id: String, app_secret: String) -> Self {
56 TokenClient {
57 app_id,
58 app_secret,
59 cache_token: Arc::new(RwLock::new(None)),
60 }
61 }
62
63 fn get_cache_token(&self) -> Option<AccessToken> {
64 let locked = self.cache_token.read().unwrap();
65 match &*locked {
66 Some(i) if !i.expired() => Some(i.clone().into()),
67 _ => None,
68 }
69 }
70
71 fn set_cache_token(&self, token: AccessToken) {
72 let mut locked = self.cache_token.write().unwrap();
73 *locked = Some(token.into())
74 }
75}
76
77#[async_trait]
78impl AccessTokenProvider for TokenClient {
79 async fn get_access_token(&self) -> SdkResult<AccessToken> {
80 let url = format!(
81 "https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid={}&secret={}",
82 self.app_id.clone(),
83 self.app_secret.clone()
84 );
85 let cache_token = self.get_cache_token();
86 match cache_token {
87 Some(token) => Ok(token),
88 None => {
89 let msg = reqwest::get(&url)
90 .await?
91 .json::<CommonResponse<AccessToken>>()
92 .await?;
93
94 match msg {
95 CommonResponse::Ok(at) => {
96 self.set_cache_token(at.clone());
97 Ok(at)
98 }
99 CommonResponse::Err(e) => Err(SdkError::AccessTokenError(e)),
100 }
101 }
102 }
103 }
104}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::sleep;

    use crate::{
        access_token::AccessTokenProvider, cache, error::CommonResponse, AccessToken, TokenClient,
    };

    /// Both response shapes of the token endpoint deserialize into `CommonResponse`.
    #[test]
    fn deserializes_token_response() {
        let input = r#"{"access_token":"ACCESS_TOKEN","expires_in":7200}"#;
        let expected = CommonResponse::Ok(AccessToken {
            access_token: "ACCESS_TOKEN".to_string(),
            expires_in: 7200,
        });
        assert_eq!(expected, serde_json::from_str(input).unwrap());

        let input = r#"{"errcode":40013,"errmsg":"invalid appid"}"#;
        let expected = CommonResponse::<AccessToken>::Err(crate::error::CommonError {
            errcode: 40013,
            errmsg: "invalid appid".to_string(),
        });
        assert_eq!(expected, serde_json::from_str(input).unwrap());
    }

    /// A cached, unexpired token is returned as-is on repeated calls.
    #[tokio::test]
    async fn test_get_from_cache() {
        let token_client = TokenClient {
            app_id: "app_id".to_owned(),
            app_secret: "secret".to_owned(),
            cache_token: std::sync::Arc::new(std::sync::RwLock::new(Some(cache::Item::new(
                "ACCESS_TOKEN".to_owned(),
                Some(Duration::from_secs(60)),
            )))),
        };
        // Let some time pass so the entry is checked after creation. The TTL stays far
        // above this delay, so a slow machine cannot expire the entry and fall through
        // to a real HTTP request.
        sleep(Duration::from_millis(100)).await;

        let expected = AccessToken {
            access_token: "ACCESS_TOKEN".to_owned(),
            expires_in: 0,
        };
        let first = token_client.get_access_token().await.unwrap();
        assert_eq!(first, expected);
        let second = token_client.get_access_token().await.unwrap();
        assert_eq!(second, expected);
    }
}