#![allow(non_snake_case, non_upper_case_globals)]
use std::collections::HashSet;
use std::future::Future;
use std::pin::Pin;
use std::sync::{LazyLock, RwLock};
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::{debug, error, info};
use crate::error::{
BotError, CodeConnCloseCantIdentify, CodeConnCloseCantResume, err_session_limit,
};
use crate::gateway::Gateway;
use crate::intents::Intents;
use crate::models::api::{GatewayResponse, WebsocketAP};
use crate::models::gateway::GatewayEvent;
use crate::token::Token;
/// Error codes after which the stored session must not be resumed;
/// `ChanManager::new_connect` clears the session id and sequence number for these.
pub static CanNotResumeErrSet: LazyLock<HashSet<i32>> =
    LazyLock::new(|| std::iter::once(CodeConnCloseCantResume).collect::<HashSet<i32>>());
/// Error codes after which identifying again is pointless;
/// `ChanManager::new_connect` drops the session instead of requeueing it.
pub static CanNotIdentifyErrSet: LazyLock<HashSet<i32>> =
    LazyLock::new(|| std::iter::once(CodeConnCloseCantIdentify).collect::<HashSet<i32>>());
/// Boxed future produced by a [`SessionConnectFn`]: resolves to the session state
/// to use for the next attempt together with the outcome of the connection.
pub type SessionFuture = Pin<Box<dyn Future<Output = (Session, crate::Result<()>)> + Send>>;
/// Connection routine used by [`ChanManager`]: receives a session and the sender
/// for gateway events, and returns a [`SessionFuture`].
pub type SessionConnectFn =
    dyn Fn(Session, mpsc::UnboundedSender<GatewayEvent>) -> SessionFuture + Send + Sync;
/// Owned, type-erased [`SessionManager`].
pub type BoxedSessionManager = Box<dyn SessionManager>;
/// Factory producing a fresh [`BoxedSessionManager`]; the registered default is
/// invoked by [`new_session_manager`].
pub type SessionManagerFactory = dyn Fn() -> BoxedSessionManager + Send + Sync;
static DEFAULT_SESSION_MANAGER: LazyLock<RwLock<Box<SessionManagerFactory>>> =
LazyLock::new(|| RwLock::new(Box::new(|| Box::new(ChanManager::new()))));
/// State of one gateway connection (one shard), carried between connection attempts.
#[derive(Debug, Clone)]
pub struct Session {
    /// Gateway session id. Empty means "no session yet"; when non-empty, the default
    /// connect routine hands it (with `last_seq`) to `Gateway::with_resume_state`.
    pub id: String,
    /// Gateway URL this session connects to.
    pub url: String,
    /// Credentials passed to the gateway.
    pub token: Token,
    /// Intents requested when connecting.
    pub intent: Intents,
    /// Last sequence number reported by the gateway (`Gateway::last_sequence`).
    pub last_seq: u64,
    /// Shard id and total shard count of this connection; see `Session::shard`.
    pub shards: crate::models::api::ShardConfig,
    /// Application id; set by `Session::from_app_id`, `None` from `Session::new`.
    pub app_id: Option<String>,
}
impl Session {
pub fn new(
url: impl Into<String>,
token: Token,
intent: Intents,
shard_id: u32,
shard_count: u32,
) -> Self {
Self {
id: String::new(),
url: url.into(),
token,
intent,
last_seq: 0,
shards: crate::models::api::ShardConfig {
shard_id,
shard_count,
},
app_id: None,
}
}
pub fn shard(&self) -> [u32; 2] {
[self.shards.shard_id, self.shards.shard_count]
}
pub fn from_app_id(app_id: impl Into<String>) -> Self {
Self {
id: String::new(),
url: String::new(),
token: Token::new("", ""),
intent: Intents::default(),
last_seq: 0,
shards: crate::models::api::ShardConfig {
shard_id: 0,
shard_count: 0,
},
app_id: Some(app_id.into()),
}
}
#[allow(non_snake_case)]
pub fn FromAppID(app_id: impl Into<String>) -> Self {
Self::from_app_id(app_id)
}
}
/// Starts and supervises gateway sessions.
#[async_trait::async_trait]
pub trait SessionManager: Send + Sync {
    /// Starts the sessions described by `ap_info`, authenticating with `token`,
    /// requesting `intents`, and forwarding gateway events to `event_sender`.
    ///
    /// NOTE(review): whether this returns after startup or keeps running is up to
    /// the implementation; `ChanManager`'s implementation keeps running.
    async fn start(
        &mut self,
        ap_info: &WebsocketAP,
        token: Token,
        intents: Intents,
        event_sender: mpsc::UnboundedSender<GatewayEvent>,
    ) -> crate::Result<()>;
    /// PascalCase alias of [`SessionManager::start`]; delegates to it unchanged.
    #[allow(non_snake_case)]
    async fn Start(
        &mut self,
        ap_info: &WebsocketAP,
        token: Token,
        intents: Intents,
        event_sender: mpsc::UnboundedSender<GatewayEvent>,
    ) -> crate::Result<()> {
        self.start(ap_info, token, intents, event_sender).await
    }
}
/// Session manager that feeds sessions through an unbounded queue and runs each
/// connection attempt in its own spawned task. After an attempt ends, the session
/// is requeued unless the error says the bot can not identify.
#[derive(Clone)]
pub struct ChanManager {
    /// Sender half of the session queue; stored by `start`.
    session_sender: Option<mpsc::UnboundedSender<Session>>,
    /// Routine that performs one connection attempt for a session.
    connect_fn: std::sync::Arc<SessionConnectFn>,
}
impl Default for ChanManager {
fn default() -> Self {
Self::new()
}
}
impl ChanManager {
pub fn new() -> Self {
Self::with_connect_fn(|session, event_sender| {
Box::pin(async move {
let mut next_session = session.clone();
let shard = session.shard();
let mut gateway =
Gateway::new(session.url, session.token, session.intent, Some(shard));
if !session.id.is_empty() {
gateway = gateway.with_resume_state(session.id, session.last_seq);
}
let result = gateway.connect_once(event_sender).await;
next_session.id = gateway.session_id().unwrap_or_default().to_string();
next_session.last_seq = gateway.last_sequence();
(next_session, result)
})
})
}
pub fn with_connect_fn<F>(connect_fn: F) -> Self
where
F: Fn(Session, mpsc::UnboundedSender<GatewayEvent>) -> SessionFuture
+ Send
+ Sync
+ 'static,
{
Self {
session_sender: None,
connect_fn: std::sync::Arc::new(connect_fn),
}
}
pub fn sessions(ap_info: &WebsocketAP, token: Token, intents: Intents) -> Vec<Session> {
(0..ap_info.shards)
.map(|shard_id| {
Session::new(
ap_info.url.clone(),
token.clone(),
intents,
shard_id,
ap_info.shards,
)
})
.collect()
}
async fn new_connect(
session_sender: mpsc::UnboundedSender<Session>,
connect_fn: std::sync::Arc<SessionConnectFn>,
event_sender: mpsc::UnboundedSender<GatewayEvent>,
session: Session,
) {
let (mut reconnect_session, result) = connect_fn(session, event_sender).await;
if let Err(err) = result {
error!("[ws/session/local] Listening err {}", err);
if CanNotResume(&err) {
reconnect_session.id.clear();
reconnect_session.last_seq = 0;
}
if CanNotIdentify(&err) {
error!("can not identify because server return {}", err);
return;
}
}
if let Err(err) = session_sender.send(reconnect_session) {
debug!("[ws/session/local] session queue closed: {}", err);
}
}
}
#[async_trait::async_trait]
impl SessionManager for ChanManager {
    /// Queues one session per shard and launches them one by one, sleeping for
    /// `CalcInterval(max_concurrency)` before each launch.
    ///
    /// Sessions handed back by finished attempts are launched on the same cadence.
    /// Both this call and `self.session_sender` keep a sender to the queue alive,
    /// so once the sessions are queued the receive loop never sees a closed channel.
    ///
    /// # Errors
    ///
    /// Fails if the session limit check rejects `ap_info`, or if a session cannot
    /// be placed on the queue.
    async fn start(
        &mut self,
        ap_info: &WebsocketAP,
        token: Token,
        intents: Intents,
        event_sender: mpsc::UnboundedSender<GatewayEvent>,
    ) -> crate::Result<()> {
        CheckSessionLimit(ap_info)?;
        let start_interval = CalcInterval(ap_info.session_start_limit.max_concurrency);
        info!(
            "[ws/session/local] will start {} sessions and per session start interval is {:?}",
            ap_info.shards, start_interval
        );

        let (queue_tx, mut queue_rx) = mpsc::unbounded_channel();
        self.session_sender = Some(queue_tx.clone());

        Self::sessions(ap_info, token, intents)
            .into_iter()
            .try_for_each(|session| queue_tx.send(session))
            .map_err(|err| BotError::session(format!("produce session failed: {err}")))?;

        while let Some(session) = queue_rx.recv().await {
            tokio::time::sleep(start_interval).await;
            let worker = Self::new_connect(
                queue_tx.clone(),
                self.connect_fn.clone(),
                event_sender.clone(),
                session,
            );
            tokio::spawn(worker);
        }
        Ok(())
    }
}
pub fn new_session_manager() -> BoxedSessionManager {
let factory = DEFAULT_SESSION_MANAGER
.read()
.expect("default session manager lock poisoned");
factory()
}
#[allow(non_snake_case)]
pub fn NewSessionManager() -> BoxedSessionManager {
new_session_manager()
}
pub fn set_session_manager_factory(
factory: impl Fn() -> BoxedSessionManager + Send + Sync + 'static,
) {
*DEFAULT_SESSION_MANAGER
.write()
.expect("default session manager lock poisoned") = Box::new(factory);
}
pub fn set_session_manager(manager: impl SessionManager + Clone + 'static) {
set_session_manager_factory(move || Box::new(manager.clone()));
}
#[allow(non_snake_case)]
pub fn SetSessionManager(manager: impl SessionManager + Clone + 'static) {
set_session_manager(manager);
}
/// Delay to wait before starting each session for the given `max_concurrency`,
/// as computed by `Gateway::session_start_interval`.
pub fn calc_interval(max_concurrency: u32) -> Duration {
    Gateway::session_start_interval(max_concurrency)
}

/// PascalCase alias of [`calc_interval`].
#[allow(non_snake_case)]
pub fn CalcInterval(max_concurrency: u32) -> Duration {
    self::calc_interval(max_concurrency)
}
pub fn can_not_resume(err: &(dyn std::error::Error + 'static)) -> bool {
CanNotResumeErrSet.contains(&crate::error::Error(err).Code())
}
#[allow(non_snake_case)]
pub fn CanNotResume(err: &(dyn std::error::Error + 'static)) -> bool {
can_not_resume(err)
}
pub fn can_not_identify(err: &(dyn std::error::Error + 'static)) -> bool {
CanNotIdentifyErrSet.contains(&crate::error::Error(err).Code())
}
#[allow(non_snake_case)]
pub fn CanNotIdentify(err: &(dyn std::error::Error + 'static)) -> bool {
can_not_identify(err)
}
/// Checks that the requested shard count fits the remaining session starts.
///
/// # Errors
///
/// Returns the session-limit error when `ap_info.shards` exceeds
/// `ap_info.session_start_limit.remaining`.
pub fn check_session_limit(ap_info: &GatewayResponse) -> crate::Result<()> {
    let requested = ap_info.shards;
    let remaining = ap_info.session_start_limit.remaining;
    if requested <= remaining {
        return Ok(());
    }
    Err(err_session_limit().into())
}

/// PascalCase alias of [`check_session_limit`].
#[allow(non_snake_case)]
pub fn CheckSessionLimit(ap_info: &GatewayResponse) -> crate::Result<()> {
    self::check_session_limit(ap_info)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{
        CodeConnCloseCantIdentify, CodeConnCloseCantResume, CodeNeedReConnect, New,
    };
    use crate::models::api::SessionStartLimit;

    /// Gateway response fixture with the given shard count and start limits.
    fn ap_info(shards: u32, remaining: u32, max_concurrency: u32) -> GatewayResponse {
        GatewayResponse {
            url: "wss://example.com".to_string(),
            shards,
            session_start_limit: SessionStartLimit {
                total: 10,
                remaining,
                reset_after: 1000,
                max_concurrency,
            },
        }
    }

    /// Fresh session for shard 0 of 1, used by the `new_connect` tests.
    fn single_shard_session() -> Session {
        Session::new(
            "wss://example.com",
            Token::new("app_id", "secret"),
            Intents::default(),
            0,
            1,
        )
    }

    #[test]
    fn calc_interval_matches_botgo() {
        assert_eq!(CalcInterval(0), Duration::from_secs(2));
        assert_eq!(CalcInterval(1), Duration::from_secs(2));
        assert_eq!(CalcInterval(2), Duration::from_secs(1));
        assert_eq!(CalcInterval(3), Duration::from_secs(1));
        assert_eq!(CalcInterval(100), Duration::from_secs(1));
    }

    #[test]
    fn check_session_limit_matches_botgo() {
        assert!(CheckSessionLimit(&ap_info(2, 2, 1)).is_ok());
        let err = CheckSessionLimit(&ap_info(3, 2, 1)).unwrap_err();
        assert!(CanNotIdentify(&err));
    }

    #[test]
    fn resume_and_identify_error_sets_match_botgo() {
        let resume = New(CodeConnCloseCantResume, "invalid session");
        let identify = New(CodeConnCloseCantIdentify, "bot banned");
        let reconnect = New(CodeNeedReConnect, "need reconnect");
        assert!(CanNotResume(&resume));
        assert!(!CanNotResume(&identify));
        assert!(CanNotIdentify(&identify));
        assert!(!CanNotIdentify(&resume));
        assert!(!CanNotIdentify(&reconnect));
    }

    #[test]
    fn sessions_are_generated_per_shard() {
        let token = Token::new("app_id", "secret");
        let sessions = ChanManager::sessions(&ap_info(3, 3, 1), token, Intents::default());
        let shards = sessions
            .into_iter()
            .map(|session| session.shard())
            .collect::<Vec<_>>();
        assert_eq!(shards, vec![[0, 3], [1, 3], [2, 3]]);
    }

    #[test]
    fn app_id_session_matches_botgo_webhook_shape() {
        let session = Session::from_app_id("app-id-1");
        assert_eq!(session.app_id.as_deref(), Some("app-id-1"));
        assert!(session.id.is_empty());
        assert_eq!(session.last_seq, 0);
        assert_eq!(session.shard(), [0, 0]);
    }

    #[tokio::test]
    async fn non_resumable_error_clears_session_before_requeue() {
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();
        let (event_tx, _event_rx) = mpsc::unbounded_channel();
        let connect_fn: std::sync::Arc<SessionConnectFn> =
            std::sync::Arc::new(|session, _event_sender| {
                Box::pin(async move {
                    let mut next = session;
                    next.id = "stale-session".to_string();
                    next.last_seq = 42;
                    (
                        next,
                        Err(New(CodeConnCloseCantResume, "invalid session").into()),
                    )
                })
            });
        ChanManager::new_connect(session_tx, connect_fn, event_tx, single_shard_session())
            .await;
        let session = session_rx.recv().await.unwrap();
        assert!(session.id.is_empty());
        assert_eq!(session.last_seq, 0);
    }

    /// A can-not-identify error must drop the session rather than requeue it.
    #[tokio::test]
    async fn non_identifiable_error_drops_session() {
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();
        let (event_tx, _event_rx) = mpsc::unbounded_channel();
        let connect_fn: std::sync::Arc<SessionConnectFn> =
            std::sync::Arc::new(|session, _event_sender| {
                Box::pin(async move {
                    let outcome: crate::Result<()> =
                        Err(New(CodeConnCloseCantIdentify, "bot banned").into());
                    (session, outcome)
                })
            });
        ChanManager::new_connect(session_tx, connect_fn, event_tx, single_shard_session())
            .await;
        // `new_connect` consumed the only sender. A requeued session would still be
        // buffered; otherwise the closed, empty channel yields `None`.
        assert!(session_rx.recv().await.is_none());
    }

    /// A connection that ends without error requeues the session with the id and
    /// sequence number reported by the connect routine.
    #[tokio::test]
    async fn clean_exit_requeues_session_with_resume_state() {
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();
        let (event_tx, _event_rx) = mpsc::unbounded_channel();
        let connect_fn: std::sync::Arc<SessionConnectFn> =
            std::sync::Arc::new(|session, _event_sender| {
                Box::pin(async move {
                    let mut next = session;
                    next.id = "live-session".to_string();
                    next.last_seq = 7;
                    let outcome: crate::Result<()> = Ok(());
                    (next, outcome)
                })
            });
        ChanManager::new_connect(session_tx, connect_fn, event_tx, single_shard_session())
            .await;
        let session = session_rx.recv().await.unwrap();
        assert_eq!(session.id, "live-session");
        assert_eq!(session.last_seq, 7);
        assert_eq!(session.shard(), [0, 1]);
    }
}