use crate::{NoError, SaasContext, Wechat, WechatEncryptError, WechatError, WechatToken};
use async_trait::async_trait;
use log::{debug, info};
use smol::{Task, Timer};
use std::marker::{Send, Sync};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
pub struct TokenResolveGard {
finish_flag: Arc<AtomicBool>,
}
impl TokenResolveGard {
pub fn new(finish_flag: Arc<AtomicBool>) -> Self {
TokenResolveGard { finish_flag }
}
}
impl Drop for TokenResolveGard {
fn drop(&mut self) {
self.finish_flag.store(true, Ordering::SeqCst);
}
}
#[async_trait]
pub trait TokenProvider: Send + Sync {
async fn get_token(
&self,
wechat: &Wechat,
context: &SaasContext,
) -> Result<Option<WechatToken>, WechatError>;
async fn set_token(
&self,
wechat: &Wechat,
context: &SaasContext,
token: Option<WechatToken>,
) -> Result<(), WechatError>;
async fn lock_token_resolver(
&self,
wechat: &Wechat,
context: &SaasContext,
) -> Result<TokenResolveGard, WechatError>;
async fn unlock_token_resolver(
&self,
wechat: &Wechat,
context: &SaasContext,
gard: TokenResolveGard,
) -> Result<(), WechatError>;
}
pub mod memory {
use super::*;
use async_mutex::Mutex;
use std::collections::HashMap;
use std::sync::RwLock;
pub struct MemoryTokenProvider {
token_list: RwLock<HashMap<u64, WechatToken>>,
rw_mutex: Arc<Mutex<()>>,
}
impl MemoryTokenProvider {
pub fn new() -> Self {
MemoryTokenProvider {
token_list: RwLock::new(HashMap::new()),
rw_mutex: Arc::new(Mutex::new(())),
}
}
}
impl Drop for MemoryTokenProvider {
fn drop(&mut self) {
todo!()
}
}
#[allow(unused_variables)]
#[async_trait]
impl TokenProvider for MemoryTokenProvider {
async fn get_token(
&self,
wechat: &Wechat,
context: &SaasContext,
) -> Result<Option<WechatToken>, WechatError> {
use chrono::Utc;
let list = self.token_list.read().unwrap();
let token = (*list).get(&context.id);
Ok(match token {
None => None,
Some(token) => {
if token.expire_at < Utc::now() {
None
} else {
Some(token.clone())
}
}
})
}
async fn set_token(
&self,
wechat: &Wechat,
context: &SaasContext,
token: Option<WechatToken>,
) -> Result<(), WechatError> {
let mut list = self.token_list.write().unwrap();
if let Some(token) = token {
info!("set token:{:?}, token:{:?}", context, token);
(*list).insert(context.id, token);
} else {
info!("remove token:{:?}", context);
(*list).remove(&context.id);
}
Ok(())
}
async fn lock_token_resolver(
&self,
wechat: &Wechat,
context: &SaasContext,
) -> Result<TokenResolveGard, WechatError> {
use std::sync::Arc;
let context = context.clone();
let rw_mutex = self.rw_mutex.clone();
let lock = *rw_mutex.lock().await;
let flag = Arc::new(AtomicBool::new(false));
let gard = TokenResolveGard::new(flag.clone());
Task::spawn(async move {
use std::time::Duration;
while !flag.load(Ordering::SeqCst) {
Timer::after(Duration::from_secs(10)).await;
}
drop(lock);
debug!("token lock released, {:?}", context);
})
.detach();
return Ok(gard);
}
async fn unlock_token_resolver(
&self,
wechat: &Wechat,
context: &SaasContext,
gard: TokenResolveGard,
) -> Result<(), WechatError> {
debug!("begin release token lock:{:?}", context);
drop(gard);
Ok(())
}
}
}
pub mod reids {
use super::*;
use bb8_redis::{
redis::{cmd, AsyncCommands},
RedisPool,
};
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
use std::time::Duration;
fn get_key(context: &SaasContext, key: &str) -> String {
format!("wechat::{}::{}", context.id, key)
}
fn get_token_key(context: &SaasContext) -> String {
get_key(context, "token")
}
fn get_lock_key(context: &SaasContext) -> String {
get_key(context, "lock")
}
pub struct RedisTokenProvider {
redis_pool: RedisPool,
seq: AtomicI64,
}
impl RedisTokenProvider {
pub fn new(pool: RedisPool) -> Self {
RedisTokenProvider {
redis_pool: pool,
seq: AtomicI64::new(1),
}
}
}
#[allow(unused_variables)]
#[async_trait]
impl TokenProvider for RedisTokenProvider {
async fn get_token(
&self,
wechat: &Wechat,
context: &SaasContext,
) -> Result<Option<WechatToken>, WechatError> {
let mut conn = self.redis_pool.get().await?;
let conn = conn.as_mut().unwrap();
let key = get_token_key(context);
let reply: Option<String> = cmd("GET").arg(key).query_async(conn).await?;
let reply = match reply {
Some(reply) => reply,
None => return Ok(None),
};
let token: WechatToken = serde_json::from_str(&reply)?;
Ok(Some(token))
}
async fn set_token(
&self,
wechat: &Wechat,
context: &SaasContext,
token: Option<WechatToken>,
) -> Result<(), WechatError> {
let mut conn = self.redis_pool.get().await?;
let conn = conn.as_mut().unwrap();
let key = get_token_key(context);
if let Some(token) = token {
let value = serde_json::to_string(&token)?;
let ttl = token.remain_ttl();
if ttl <= 100 {
info!("token将要过期:{}, ttl:{}, {:?}", context.id, ttl, token);
cmd("DEL").arg(key).query_async(conn).await?;
} else {
debug!("设置token:{}, ttl:{}, {:?}", context.id, ttl, token);
cmd("SET")
.arg(key)
.arg(value)
.arg("EX")
.arg(ttl - 100)
.query_async(conn)
.await?;
}
} else {
debug!("删除token:{}", context.id);
cmd("DEL").arg(key).query_async(conn).await?;
}
Ok(())
}
async fn lock_token_resolver(
&self,
wechat: &Wechat,
context: &SaasContext,
) -> Result<TokenResolveGard, WechatError> {
use std::sync::Arc;
let context = context.clone();
let redis_pool = self.redis_pool.clone();
let key = get_lock_key(&context);
let seq = self.seq.fetch_add(1, Ordering::SeqCst);
let seq = format!("{}", seq);
for i in 0..=100i32 {
let mut conn = redis_pool.get().await?;
let conn = conn.as_mut().unwrap();
let result: Option<String> = cmd("SET")
.arg(key.clone())
.arg(seq.clone())
.arg("NX")
.arg("EX")
.arg(30)
.query_async(conn)
.await?;
let result = match result {
Some(result) => result,
None => {
break;
}
};
if "OK" == result.as_str() {
debug!("token加锁成功:{:?}", context);
break;
}
Timer::after(Duration::from_secs(1)).await;
if i > 99 {
return Err(WechatError::EncryptError {
source: WechatEncryptError::ApiRequestError {
msg: format!("无法获得token, 超时. {:?}", context),
source: Box::new(NoError),
},
});
}
}
let flag = Arc::new(AtomicBool::new(false));
let gard = TokenResolveGard::new(flag.clone());
Task::spawn(async move {
let r = Task::spawn(async move {
while !flag.load(Ordering::SeqCst) {
Timer::after(Duration::from_secs(1)).await;
let mut conn = redis_pool.get().await?;
let conn = conn.as_mut().unwrap();
cmd("EXPIRE")
.arg(key.clone())
.arg(30i32)
.query_async(conn)
.await?;
}
debug!("token解锁, {:?}", context);
let mut conn = redis_pool.get().await?;
let conn = conn.as_mut().unwrap();
cmd("DEL").arg(key).query_async(conn).await?;
Ok::<(), WechatError>(())
})
.await;
if let Err(e) = r {
info!("{:?}", e);
}
})
.detach();
return Ok(gard);
}
async fn unlock_token_resolver(
&self,
wechat: &Wechat,
context: &SaasContext,
gard: TokenResolveGard,
) -> Result<(), WechatError> {
debug!("begin release token lock:{:?}", context);
drop(gard);
Ok(())
}
}
}