#![allow(non_upper_case_globals)]
use crate::error::{BotError, Result};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::Mutex;
/// Authorization scheme string `Bearer` (kept for botgo naming parity).
pub const TypeBearer: &str = "Bearer";
/// Authorization scheme string used in `QQBot {access_token}` headers.
pub const TypeQQBot: &str = "QQBot";
/// Margin (9 s) subtracted from a token's TTL when scheduling a proactive refresh.
const DEFAULT_EXPIRY_DELTA_MILLIS: u64 = 9 * 1_000;
/// Exclusive upper bound of the random jitter subtracted from the refresh delay.
const RAND_TIME_UPPER_LIMIT_MILLIS: u64 = 500;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct QQBotCredentials {
#[serde(alias = "appid", alias = "appId")]
pub app_id: String,
#[serde(alias = "secret", alias = "appSecret")]
pub app_secret: String,
}
pub type QQBotTokenSource = Token;
#[allow(non_snake_case)]
pub fn NewQQBotTokenSource(credentials: &QQBotCredentials) -> QQBotTokenSource {
Token::new(&credentials.app_id, &credentials.app_secret)
}
/// Mutable access-token cache shared by a `Token` and its clones.
#[derive(Default, Debug)]
struct TokenState {
    /// Last access token returned by the token endpoint, if any.
    access_token: Option<String>,
    /// Unix time (seconds) from which the cached token is treated as expired.
    expires_at: Option<u64>,
    /// Lifetime in seconds reported by the server on the last successful fetch.
    expires_in: Option<u64>,
}
/// Builds a fresh, empty token cache; also used as the serde default for `Token::state`.
fn default_state() -> Arc<Mutex<TokenState>> {
    let empty = TokenState::default();
    Arc::new(Mutex::new(empty))
}
/// QQ bot credentials plus a lazily-filled access-token cache.
///
/// The cache lives behind an `Arc`, so clones share it: a token fetched through
/// one clone is visible to all of them. The cache is skipped by serde; a
/// deserialized `Token` starts with an empty cache.
#[derive(Serialize, Deserialize, Clone)]
pub struct Token {
    app_id: String,
    secret: String,
    #[serde(default = "default_state", skip)]
    state: Arc<Mutex<TokenState>>,
}
impl Token {
pub fn new(app_id: impl Into<String>, secret: impl Into<String>) -> Self {
Self {
app_id: app_id.into(),
secret: secret.into(),
state: default_state(),
}
}
pub fn app_id(&self) -> &str {
&self.app_id
}
pub fn secret(&self) -> &str {
&self.secret
}
#[allow(non_snake_case)]
pub fn GetAppID(&self) -> &str {
self.app_id()
}
pub async fn authorization_header(&self) -> Result<String> {
let access_token = self.access_token().await?;
Ok(format!("QQBot {access_token}"))
}
pub async fn bot_token(&self) -> Result<String> {
self.authorization_header().await
}
async fn ensure_valid_token(&self) -> Result<()> {
let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|_| BotError::internal("Failed to get current time"))?
.as_secs();
let is_valid = {
let state = self.state.lock().await;
state.access_token.is_some() && state.expires_at.is_some_and(|exp| current_time < exp)
};
if !is_valid {
self.refresh_access_token(current_time, false).await?;
}
Ok(())
}
async fn refresh_access_token(&self, current_time: u64, force: bool) -> Result<()> {
let mut state = self.state.lock().await;
if !force
&& state.access_token.is_some()
&& state.expires_at.is_some_and(|exp| current_time < exp)
{
return Ok(());
}
let client = reqwest::Client::new();
let request_body = serde_json::json!({
"appId": self.app_id,
"clientSecret": self.secret
});
let response = client
.post("https://bots.qq.com/app/getAppAccessToken")
.json(&request_body)
.timeout(std::time::Duration::from_secs(20))
.send()
.await
.map_err(|e| BotError::connection(format!("Failed to request access token: {e}")))?;
if !response.status().is_success() {
return Err(BotError::api(
response.status().as_u16() as u32,
format!(
"Token request failed: {}",
response.text().await.unwrap_or_default()
),
));
}
let token_response: serde_json::Value = response.json().await.map_err(BotError::Http)?;
let access_token = token_response
.get("access_token")
.and_then(|v| v.as_str())
.ok_or_else(|| BotError::auth("No access_token in response"))?;
let expires_in = token_response
.get("expires_in")
.and_then(parse_expires_in)
.ok_or_else(|| BotError::auth("No expires_in in response"))?;
state.access_token = Some(access_token.to_string());
state.expires_at = Some(current_time + expires_in);
state.expires_in = Some(expires_in);
Ok(())
}
async fn force_refresh_access_token(&self) -> Result<()> {
let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|_| BotError::internal("Failed to get current time"))?
.as_secs();
self.refresh_access_token(current_time, true).await
}
async fn access_token(&self) -> Result<String> {
self.ensure_valid_token().await?;
self.state
.lock()
.await
.access_token
.clone()
.ok_or_else(|| BotError::auth("No valid access token available"))
}
async fn cached_expires_in(&self) -> Option<u64> {
self.state.lock().await.expires_in
}
#[cfg(test)]
pub(crate) async fn set_cached_access_token_for_test(&self, access_token: impl Into<String>) {
let mut state = self.state.lock().await;
state.access_token = Some(access_token.into());
state.expires_at = Some(u64::MAX);
}
pub fn validate(&self) -> Result<()> {
if self.app_id.is_empty() {
return Err(BotError::auth("App ID cannot be empty"));
}
if self.secret.is_empty() {
return Err(BotError::auth("Secret cannot be empty"));
}
Ok(())
}
pub fn from_env() -> Result<Self> {
let app_id = std::env::var("QQ_BOT_APP_ID")
.map_err(|_| BotError::config("QQ_BOT_APP_ID environment variable not found"))?;
let secret = std::env::var("QQ_BOT_SECRET")
.map_err(|_| BotError::config("QQ_BOT_SECRET environment variable not found"))?;
let token = Self::new(app_id, secret);
token.validate()?;
Ok(token)
}
pub fn safe_display(&self) -> String {
let masked_secret = if self.secret.len() > 8 {
format!(
"{}****{}",
&self.secret[..4],
&self.secret[self.secret.len() - 4..]
)
} else {
"****".to_string()
};
format!(
"Token {{ app_id: {}, secret: {} }}",
self.app_id, masked_secret
)
}
}
#[allow(non_snake_case)]
pub async fn StartRefreshAccessToken(
token_source: QQBotTokenSource,
) -> Result<tokio::task::JoinHandle<()>> {
token_source.ensure_valid_token().await?;
Ok(tokio::spawn(async move {
let mut consecutive_failures = 0;
loop {
let refresh_millis = if consecutive_failures > 0 {
if consecutive_failures > 10 {
panic!("get token failed continuously for more than ten times");
}
1_000
} else {
token_source
.cached_expires_in()
.await
.map(get_refresh_millis)
.unwrap_or(1_000)
};
tracing::debug!("refresh after {} milli sec", refresh_millis);
tokio::time::sleep(Duration::from_millis(refresh_millis)).await;
match token_source.force_refresh_access_token().await {
Ok(()) => consecutive_failures = 0,
Err(err) => {
consecutive_failures += 1;
tracing::error!("refresh access token failed: {}", err);
}
}
}
}))
}
fn parse_expires_in(value: &serde_json::Value) -> Option<u64> {
value
.as_u64()
.or_else(|| value.as_str().and_then(|value| value.parse().ok()))
}
/// Delay (ms) before proactively refreshing a token that lives `token_ttl_secs` seconds.
///
/// A TTL shorter than `DEFAULT_EXPIRY_DELTA_MILLIS` is returned unchanged. Otherwise
/// the delta is subtracted. If more than `RAND_TIME_UPPER_LIMIT_MILLIS` then remains,
/// a jitter below that limit is subtracted as well.
fn get_refresh_millis(token_ttl_secs: u64) -> u64 {
    let ttl_millis = token_ttl_secs.saturating_mul(1_000);
    match ttl_millis.checked_sub(DEFAULT_EXPIRY_DELTA_MILLIS) {
        None => ttl_millis,
        Some(early) if early > RAND_TIME_UPPER_LIMIT_MILLIS => {
            early - jitter_millis(RAND_TIME_UPPER_LIMIT_MILLIS)
        }
        Some(early) => early,
    }
}
/// Cheap pseudo-random jitter in `0..upper_bound` ms, taken from the clock's sub-second nanos.
///
/// Returns 0 when `upper_bound` is 0 or the clock reads earlier than the Unix epoch.
fn jitter_millis(upper_bound: u64) -> u64 {
    match upper_bound {
        0 => 0,
        bound => SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |elapsed| u64::from(elapsed.subsec_nanos()) % bound),
    }
}
impl fmt::Display for Token {
    /// Writes the masked form produced by [`Token::safe_display`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = self.safe_display();
        f.write_str(&rendered)
    }
}
/// Tokens are equal when their credentials match; the shared cache is ignored.
impl PartialEq for Token {
    fn eq(&self, other: &Self) -> bool {
        let mine = (self.app_id.as_str(), self.secret.as_str());
        let theirs = (other.app_id.as_str(), other.secret.as_str());
        mine == theirs
    }
}
impl Eq for Token {}
impl fmt::Debug for Token {
    /// Shows the app id; the secret is always replaced by a placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        let mut builder = f.debug_struct("Token");
        builder.field("app_id", &self.app_id);
        builder.field("secret", &REDACTED);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_creation() {
        let token = Token::new("123456", "secret123");
        assert_eq!(token.app_id(), "123456");
        assert_eq!(token.secret(), "secret123");
    }

    // The next two tests post bogus credentials to bots.qq.com (or fail to connect
    // when offline); either way the token fetch must end in an error.
    #[tokio::test]
    async fn test_authorization_header() {
        let token = Token::new("test", "secret");
        let result = token.authorization_header().await;
        assert!(
            result.is_err(),
            "Expected authorization_header to fail with invalid credentials"
        );
    }

    #[tokio::test]
    async fn test_bot_token() {
        let token = Token::new("test", "secret");
        let bot_token_result = token.bot_token().await;
        let auth_header_result = token.authorization_header().await;
        assert!(bot_token_result.is_err());
        assert!(auth_header_result.is_err());
    }

    #[tokio::test]
    async fn cloned_tokens_share_cached_access_token() {
        let token = Token::new("123", "secret");
        {
            let mut state = token.state.lock().await;
            state.access_token = Some("cached-token".to_string());
            state.expires_at = Some(u64::MAX);
            state.expires_in = Some(7200);
        }
        let cloned = token.clone();
        assert_eq!(
            cloned.authorization_header().await.unwrap(),
            "QQBot cached-token"
        );
        assert_eq!(cloned.cached_expires_in().await, Some(7200));
    }

    #[test]
    fn refresh_millis_matches_botgo_bounds() {
        assert_eq!(get_refresh_millis(8), 8_000);
        assert_eq!(get_refresh_millis(9), 0);
        let refresh_millis = get_refresh_millis(10);
        assert!((501..=1_000).contains(&refresh_millis));
        let refresh_millis = get_refresh_millis(7200);
        assert!((7_190_501..=7_191_000).contains(&refresh_millis));
    }

    #[test]
    fn parse_expires_in_accepts_number_or_string() {
        assert_eq!(parse_expires_in(&serde_json::json!("7200")), Some(7200));
        assert_eq!(parse_expires_in(&serde_json::json!(7200)), Some(7200));
        assert_eq!(parse_expires_in(&serde_json::json!("bad")), None);
        assert_eq!(parse_expires_in(&serde_json::json!(-1)), None);
        assert_eq!(parse_expires_in(&serde_json::json!(null)), None);
    }

    #[test]
    fn test_validation() {
        let valid_token = Token::new("123", "secret");
        assert!(valid_token.validate().is_ok());
        let empty_app_id = Token::new("", "secret");
        assert!(empty_app_id.validate().is_err());
        let empty_secret = Token::new("123", "");
        assert!(empty_secret.validate().is_err());
    }

    #[test]
    fn test_safe_display() {
        // Pin the exact output: a loose `contains("123")` would also match the app id
        // and let a broken suffix mask slip through.
        let token = Token::new("123456", "verylongsecret123");
        let display = token.safe_display();
        assert_eq!(display, "Token { app_id: 123456, secret: very****t123 }");
        assert!(!display.contains("longsecret"));
        assert_eq!(token.to_string(), display);

        let short_token = Token::new("123", "short");
        let short_display = short_token.safe_display();
        assert_eq!(short_display, "Token { app_id: 123, secret: **** }");
        assert!(!short_display.contains("short"));

        // Exactly 8 characters is still fully masked.
        assert_eq!(
            Token::new("1", "12345678").safe_display(),
            "Token { app_id: 1, secret: **** }"
        );
    }

    #[test]
    fn test_debug_format() {
        let token = Token::new("123456", "secret123");
        let debug_str = format!("{:?}", token);
        assert!(debug_str.contains("123456"));
        assert!(debug_str.contains("[REDACTED]"));
        assert!(!debug_str.contains("secret123"));
    }

    #[test]
    fn test_equality_uses_credentials_only() {
        assert_eq!(Token::new("1", "s"), Token::new("1", "s"));
        assert_ne!(Token::new("1", "s"), Token::new("1", "t"));
        assert_ne!(Token::new("1", "s"), Token::new("2", "s"));
    }

    #[test]
    fn test_botgo_token_source_alias() {
        let credentials = QQBotCredentials {
            app_id: "123456".to_string(),
            app_secret: "secret123".to_string(),
        };
        let token = NewQQBotTokenSource(&credentials);
        assert_eq!(token.app_id(), "123456");
        assert_eq!(token.GetAppID(), "123456");
        assert_eq!(token.secret(), "secret123");
        assert_eq!(TypeQQBot, "QQBot");
        assert_eq!(TypeBearer, "Bearer");
    }
}