mod error;
pub mod message;
use log::{debug, info, warn};
use serde::{Deserialize, Serialize};
use std::error::Error as StdError;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
/// Cached WeCom access token together with when it was obtained and how long it is valid.
#[derive(Debug)]
struct AccessToken {
    /// The token string; `None` until the first successful update.
    value: Option<String>,
    /// When `value` was obtained (`UNIX_EPOCH` for a default, never-updated token).
    timestamp: SystemTime,
    /// How long the token stays valid after `timestamp`.
    lifetime: Duration,
}

impl AccessToken {
    /// Returns the cached token, if one has been stored.
    pub fn value(&self) -> Option<&String> {
        self.value.as_ref()
    }

    /// Replaces the cached token and its validity window.
    pub fn update(&mut self, token: &str, timestamp: SystemTime, lifetime: Duration) {
        self.value = Some(token.to_owned());
        self.timestamp = timestamp;
        self.lifetime = lifetime;
    }

    /// Returns `true` once `lifetime` has fully elapsed since `timestamp`.
    ///
    /// A `timestamp` later than the current system time (e.g. after the clock moved
    /// backwards) is treated as not expired.
    pub fn expired(&self) -> bool {
        match SystemTime::now().duration_since(self.timestamp) {
            Ok(elapsed) => elapsed >= self.lifetime,
            Err(_) => false,
        }
    }

    /// Returns `true` if the token has strictly less than `n` seconds of validity left.
    ///
    /// A `timestamp` later than the current system time is treated as not expiring.
    pub fn expire_in(&self, n: u64) -> bool {
        match SystemTime::now().duration_since(self.timestamp) {
            // `saturating_add` instead of `+`: `Duration` addition panics on overflow,
            // which a very large `n` would otherwise trigger.
            Ok(elapsed) => elapsed.saturating_add(Duration::from_secs(n)) > self.lifetime,
            Err(_) => false,
        }
    }

    /// Returns the time at which the current token was stored.
    pub fn timestamp(&self) -> SystemTime {
        self.timestamp
    }
}

impl Default for AccessToken {
    /// An empty token stamped at `UNIX_EPOCH` with a 7200 s lifetime, so it is
    /// immediately considered expired.
    fn default() -> Self {
        Self {
            value: None,
            timestamp: UNIX_EPOCH,
            lifetime: Duration::from_secs(7200),
        }
    }
}
/// Client for sending messages through the WeCom application message API.
///
/// Holds the corp credentials, a cached access token behind an async `RwLock`,
/// and the HTTP client used for API requests.
pub struct WecomAgent {
    /// Corp ID, sent as `corpid` when fetching a token.
    corp_id: String,
    /// Application secret, sent as `corpsecret` when fetching a token.
    secret: String,
    /// Cached access token; replaced while holding the write lock.
    access_token: RwLock<AccessToken>,
    /// HTTP client used for API requests.
    client: reqwest::Client,
}

// `Debug` is implemented by hand rather than derived so that `{:?}` on the agent
// (e.g. in log output) never prints the application secret or the cached token.
impl std::fmt::Debug for WecomAgent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WecomAgent")
            .field("corp_id", &self.corp_id)
            .field("secret", &"<redacted>")
            .field("client", &self.client)
            .finish_non_exhaustive()
    }
}
impl WecomAgent {
pub fn new(corp_id: &str, secret: &str) -> Self {
Self {
corp_id: String::from(corp_id),
secret: String::from(secret),
access_token: RwLock::new(AccessToken::default()),
client: reqwest::Client::new(),
}
}
pub async fn update_token(
&self,
backoff_seconds: u64,
) -> Result<(), Box<dyn StdError + Send + Sync>> {
let mut access_token = self.access_token.write().await;
let seconds_since_last_update = SystemTime::now()
.duration_since(access_token.timestamp())?
.as_secs();
if seconds_since_last_update < backoff_seconds {
return Err(Box::new(error::Error::new(
-9,
format!("Access token更新过于频繁。上次更新于{seconds_since_last_update}秒前。"),
)));
}
let url = format!(
"https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={}&corpsecret={}",
self.corp_id, self.secret,
);
let response = reqwest::get(url)
.await?
.json::<AccessTokenResponse>()
.await?;
if response.errcode != 0 {
return Err(Box::<error::Error>::new(error::Error::new(
response.errcode,
response.errmsg,
)));
};
access_token.update(
&response.access_token,
SystemTime::now(),
Duration::from_secs(response.expires_in),
);
Ok(())
}
pub async fn send<T>(&self, msg: T) -> Result<MsgSendResponse, Box<dyn StdError + Send + Sync>>
where
T: Serialize,
{
let token_should_update: bool = {
let access_token = self.access_token.read().await;
access_token.value().is_none() || access_token.expire_in(300) || access_token.expired()
};
if token_should_update {
warn!("Token invalid. Updating...");
self.update_token(10).await?;
info!("Token updated");
}
let url = {
let access_token = self.access_token.read().await;
format!(
"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={}",
access_token
.value()
.expect("Access token should not be None.")
)
};
debug!("Sending [try 1]...");
let mut response: MsgSendResponse = self
.client
.post(&url)
.json(&msg)
.send()
.await?
.json::<MsgSendResponse>()
.await?;
if response.error_code() == 40014 {
warn!("Token invalid. Updating...");
self.update_token(10).await?;
debug!("Sending [try 2]...");
response = self
.client
.post(&url)
.json(&msg)
.send()
.await?
.json::<MsgSendResponse>()
.await?;
};
debug!("Sending [Done]");
Ok(response)
}
}
/// Decoded JSON reply of the `message/send` endpoint.
///
/// `errcode`/`errmsg` are always expected; the remaining fields are optional and
/// deserialize to `None` when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct MsgSendResponse {
    errcode: i64,
    errmsg: String,
    // `dead_code` is allowed on the optional fields below because they may have no
    // reader inside the crate; they are still deserialized from the response.
    #[allow(dead_code)]
    invaliduser: Option<String>,
    #[allow(dead_code)]
    invalidparty: Option<String>,
    #[allow(dead_code)]
    invalidtag: Option<String>,
    #[allow(dead_code)]
    unlicenseduser: Option<String>,
    #[allow(dead_code)]
    msgid: Option<String>,
    #[allow(dead_code)]
    response_code: Option<String>,
}
impl MsgSendResponse {
    /// Returns `true` when the API reported a non-zero `errcode`.
    pub fn is_error(&self) -> bool {
        self.errcode != 0
    }

    /// The raw `errcode` from the response (`0` means success).
    pub fn error_code(&self) -> i64 {
        self.errcode
    }

    /// The raw `errmsg` from the response.
    pub fn error_msg(&self) -> &str {
        &self.errmsg
    }

    /// The `invaliduser` field of the response, if present.
    pub fn invalid_user(&self) -> Option<&str> {
        self.invaliduser.as_deref()
    }

    /// The `invalidparty` field of the response, if present.
    pub fn invalid_party(&self) -> Option<&str> {
        self.invalidparty.as_deref()
    }

    /// The `invalidtag` field of the response, if present.
    pub fn invalid_tag(&self) -> Option<&str> {
        self.invalidtag.as_deref()
    }

    /// The `unlicenseduser` field of the response, if present.
    pub fn unlicensed_user(&self) -> Option<&str> {
        self.unlicenseduser.as_deref()
    }

    /// The `msgid` field of the response, if present.
    pub fn msg_id(&self) -> Option<&str> {
        self.msgid.as_deref()
    }

    /// The `response_code` field of the response, if present.
    pub fn response_code(&self) -> Option<&str> {
        self.response_code.as_deref()
    }
}
#[derive(Deserialize)]
struct AccessTokenResponse {
errcode: i64,
errmsg: String,
access_token: String,
expires_in: u64,
}