use crate::{
Result,
access_token::{AccessToken, get_access_token, get_stable_access_token},
constants,
credential::{Credential, CredentialBuilder},
error::Error::InternalServer,
response::Response,
};
use chrono::{Duration, Utc};
use std::{
collections::HashMap,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
};
use tokio::sync::{Notify, RwLock};
use tracing::{debug, instrument};
/// API client holding the app credentials, a shared HTTP client and a cached access token.
///
/// Every field is wrapped in an `Arc`, so cloning a `Client` is cheap and all clones share
/// the same credentials, HTTP client, token cache and refresh-coordination state.
#[derive(Debug, Clone)]
pub struct Client {
    // App id, secret and the `reqwest::Client` used for every request.
    inner: Arc<ClientInner>,
    // Cached access token plus its expiry; read on every token request.
    access_token: Arc<RwLock<AccessToken>>,
    // Set while one task is refreshing the token on behalf of concurrent callers.
    refreshing: Arc<AtomicBool>,
    // Used to wake tasks that are waiting for an in-progress refresh to finish.
    notify: Arc<Notify>,
}
impl Client {
pub fn new(app_id: &str, secret: &str) -> Self {
let client = reqwest::Client::new();
Self {
inner: Arc::new(ClientInner {
app_id: app_id.into(),
secret: secret.into(),
client,
}),
access_token: Arc::new(RwLock::new(AccessToken {
access_token: "".to_string(),
expired_at: Utc::now(),
force_refresh: None,
})),
refreshing: Arc::new(AtomicBool::new(false)),
notify: Arc::new(Notify::new()),
}
}
pub(crate) fn request(&self) -> &reqwest::Client {
&self.inner.client
}
#[instrument(skip(self, code))]
pub async fn login(&self, code: &str) -> Result<Credential> {
debug!("code: {}", code);
let mut map: HashMap<&str, &str> = HashMap::new();
map.insert("appid", &self.inner.app_id);
map.insert("secret", &self.inner.secret);
map.insert("js_code", code);
map.insert("grant_type", "authorization_code");
let response = self
.inner
.client
.get(constants::AUTHENTICATION_END_POINT)
.query(&map)
.send()
.await?;
debug!("authentication response: {:#?}", response);
if response.status().is_success() {
let response = response.json::<Response<CredentialBuilder>>().await?;
let credential = response.extract()?.build();
debug!("credential: {:#?}", credential);
Ok(credential)
} else {
Err(InternalServer(response.text().await?))
}
}
pub async fn access_token(&self) -> Result<String> {
{
let guard = self.access_token.read().await;
if !is_token_expired(&guard) {
return Ok(guard.access_token.clone());
}
}
if self
.refreshing
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
match self.refresh_access_token().await {
Ok(token) => {
self.refreshing.store(false, Ordering::Release);
self.notify.notify_waiters();
Ok(token)
}
Err(e) => {
self.refreshing.store(false, Ordering::Release);
self.notify.notify_waiters();
Err(e)
}
}
} else {
self.notify.notified().await;
let guard = self.access_token.read().await;
Ok(guard.access_token.clone())
}
}
async fn refresh_access_token(&self) -> Result<String> {
let mut guard = self.access_token.write().await;
if !is_token_expired(&guard) {
debug!("token already refreshed by another thread");
return Ok(guard.access_token.clone());
}
debug!("performing network request to refresh token");
let builder = get_access_token(
self.inner.client.clone(),
&self.inner.app_id,
&self.inner.secret,
)
.await?;
guard.access_token = builder.access_token.clone();
guard.expired_at = builder.expired_at;
debug!("fresh access token: {:#?}", guard);
Ok(guard.access_token.clone())
}
pub async fn stable_access_token(
&self,
force_refresh: impl Into<Option<bool>> + Clone + Send,
) -> Result<String> {
{
let guard = self.access_token.read().await;
if !is_token_expired(&guard) {
return Ok(guard.access_token.clone());
}
}
if self
.refreshing
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
match self.refresh_stable_access_token(force_refresh).await {
Ok(token) => {
self.refreshing.store(false, Ordering::Release);
self.notify.notify_waiters();
Ok(token)
}
Err(e) => {
self.refreshing.store(false, Ordering::Release);
self.notify.notify_waiters();
Err(e)
}
}
} else {
self.notify.notified().await;
let guard = self.access_token.read().await;
Ok(guard.access_token.clone())
}
}
async fn refresh_stable_access_token(
&self,
force_refresh: impl Into<Option<bool>> + Clone + Send,
) -> Result<String> {
let mut guard = self.access_token.write().await;
if !is_token_expired(&guard) {
debug!("token already refreshed by another thread");
return Ok(guard.access_token.clone());
}
debug!("performing network request to refresh token");
let builder = get_stable_access_token(
self.inner.client.clone(),
&self.inner.app_id,
&self.inner.secret,
force_refresh,
)
.await?;
guard.access_token = builder.access_token.clone();
guard.expired_at = builder.expired_at;
debug!("fresh access token: {:#?}", guard);
Ok(guard.access_token.clone())
}
}
/// Immutable state shared by all clones of a `Client`.
#[derive(Debug)]
struct ClientInner {
    // Application id sent as `appid` in authentication requests.
    app_id: String,
    // Application secret sent as `secret` in authentication requests.
    // NOTE(review): derived `Debug` prints this value — confirm that is acceptable for logs.
    secret: String,
    // HTTP client reused for every outgoing request.
    client: reqwest::Client,
}
/// Reports whether `token` should be treated as expired.
///
/// A token counts as expired once less than five minutes remain before its `expired_at`.
/// This leaves headroom to refresh before the server starts rejecting it.
fn is_token_expired(token: &AccessToken) -> bool {
    let safety_margin = Duration::minutes(5);
    let time_left = token.expired_at - Utc::now();
    time_left < safety_margin
}