1use crate::{
2 Result,
3 access_token::{AccessToken, get_access_token, get_stable_access_token},
4 constants,
5 credential::{Credential, CredentialBuilder},
6 error::Error::InternalServer,
7 response::Response,
8};
9use chrono::{Duration, Utc};
10use std::{
11 collections::HashMap,
12 sync::{
13 Arc,
14 atomic::{AtomicBool, Ordering},
15 },
16};
17use tokio::sync::{Notify, RwLock};
18use tracing::{debug, instrument};
19
20#[derive(Debug, Clone)]
22pub struct Client {
23 inner: Arc<ClientInner>,
24 access_token: Arc<RwLock<AccessToken>>,
25 refreshing: Arc<AtomicBool>,
26 notify: Arc<Notify>,
27}
28
29impl Client {
30 pub fn new(app_id: &str, secret: &str) -> Self {
44 let client = reqwest::Client::new();
45
46 Self {
47 inner: Arc::new(ClientInner {
48 app_id: app_id.into(),
49 secret: secret.into(),
50 client,
51 }),
52 access_token: Arc::new(RwLock::new(AccessToken {
53 access_token: "".to_string(),
54 expired_at: Utc::now(),
55 force_refresh: None,
56 })),
57 refreshing: Arc::new(AtomicBool::new(false)),
58 notify: Arc::new(Notify::new()),
59 }
60 }
61
62 pub(crate) fn request(&self) -> &reqwest::Client {
63 &self.inner.client
64 }
65
66 #[instrument(skip(self, code))]
89 pub async fn login(&self, code: &str) -> Result<Credential> {
90 debug!("code: {}", code);
91
92 let mut map: HashMap<&str, &str> = HashMap::new();
93
94 map.insert("appid", &self.inner.app_id);
95 map.insert("secret", &self.inner.secret);
96 map.insert("js_code", code);
97 map.insert("grant_type", "authorization_code");
98
99 let response = self
100 .inner
101 .client
102 .get(constants::AUTHENTICATION_END_POINT)
103 .query(&map)
104 .send()
105 .await?;
106
107 debug!("authentication response: {:#?}", response);
108
109 if response.status().is_success() {
110 let response = response.json::<Response<CredentialBuilder>>().await?;
111
112 let credential = response.extract()?.build();
113
114 debug!("credential: {:#?}", credential);
115
116 Ok(credential)
117 } else {
118 Err(InternalServer(response.text().await?))
119 }
120 }
121
122 pub async fn access_token(&self) -> Result<String> {
138 {
140 let guard = self.access_token.read().await;
141 if !is_token_expired(&guard) {
142 return Ok(guard.access_token.clone());
143 }
144 }
145
146 if self
148 .refreshing
149 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
150 .is_ok()
151 {
152 match self.refresh_access_token().await {
154 Ok(token) => {
155 self.refreshing.store(false, Ordering::Release);
156 self.notify.notify_waiters();
157 Ok(token)
158 }
159 Err(e) => {
160 self.refreshing.store(false, Ordering::Release);
161 self.notify.notify_waiters();
162 Err(e)
163 }
164 }
165 } else {
166 self.notify.notified().await;
168 let guard = self.access_token.read().await;
170 Ok(guard.access_token.clone())
171 }
172 }
173
174 async fn refresh_access_token(&self) -> Result<String> {
175 let mut guard = self.access_token.write().await;
176
177 if !is_token_expired(&guard) {
178 debug!("token already refreshed by another thread");
179 return Ok(guard.access_token.clone());
180 }
181
182 debug!("performing network request to refresh token");
183
184 let builder = get_access_token(
185 self.inner.client.clone(),
186 &self.inner.app_id,
187 &self.inner.secret,
188 )
189 .await?;
190
191 guard.access_token = builder.access_token.clone();
192 guard.expired_at = builder.expired_at;
193
194 debug!("fresh access token: {:#?}", guard);
195
196 Ok(guard.access_token.clone())
197 }
198
199 pub async fn stable_access_token(
216 &self,
217 force_refresh: impl Into<Option<bool>> + Clone + Send,
218 ) -> Result<String> {
219 {
221 let guard = self.access_token.read().await;
222 if !is_token_expired(&guard) {
223 return Ok(guard.access_token.clone());
224 }
225 }
226
227 if self
229 .refreshing
230 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
231 .is_ok()
232 {
233 match self.refresh_stable_access_token(force_refresh).await {
235 Ok(token) => {
236 self.refreshing.store(false, Ordering::Release);
237 self.notify.notify_waiters();
238 Ok(token)
239 }
240 Err(e) => {
241 self.refreshing.store(false, Ordering::Release);
242 self.notify.notify_waiters();
243 Err(e)
244 }
245 }
246 } else {
247 self.notify.notified().await;
249 let guard = self.access_token.read().await;
251 Ok(guard.access_token.clone())
252 }
253 }
254
255 async fn refresh_stable_access_token(
256 &self,
257 force_refresh: impl Into<Option<bool>> + Clone + Send,
258 ) -> Result<String> {
259 let mut guard = self.access_token.write().await;
261
262 if !is_token_expired(&guard) {
266 debug!("token already refreshed by another thread");
268 return Ok(guard.access_token.clone());
269 }
270
271 debug!("performing network request to refresh token");
273
274 let builder = get_stable_access_token(
275 self.inner.client.clone(),
276 &self.inner.app_id,
277 &self.inner.secret,
278 force_refresh,
279 )
280 .await?;
281
282 guard.access_token = builder.access_token.clone();
284 guard.expired_at = builder.expired_at;
285
286 debug!("fresh access token: {:#?}", guard);
287
288 Ok(guard.access_token.clone())
290 }
291}
292
293#[derive(Debug)]
294struct ClientInner {
295 app_id: String,
296 secret: String,
297 client: reqwest::Client,
298}
299
300fn is_token_expired(token: &AccessToken) -> bool {
301 let now = Utc::now();
303 token.expired_at.signed_duration_since(now) < Duration::minutes(5)
304}