use super::cache::Cache;
use super::context::Context;
use super::handler::EventHandler;
use crate::error::{Error, GatewayError};
use crate::gateway::connection::{GatewayConnection, SessionState};
use crate::gateway::event::GatewayEvent;
use crate::http::client::HttpClient;
use futures_util::{FutureExt, StreamExt};
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use tokio::sync::Semaphore;
/// Gateway endpoint used when no override is configured and discovery through the
/// HTTP API fails.
const DEFAULT_GATEWAY_URL: &str = "wss://gateway.fluxer.app";
/// Default upper bound on concurrently running event-handler tasks (2^10).
const DEFAULT_MAX_CONCURRENT_EVENTS: usize = 1 << 10;
/// Gateway client: runs the connection loop started by [`Client::start`] and routes
/// gateway events to the registered [`EventHandler`]. Construct it with [`Client::builder`].
pub struct Client {
/// Bot token passed to the gateway when identifying or resuming a session.
token: String,
/// Gateway URL supplied through the builder; when `None`, the URL is discovered at startup.
gateway_url_override: Option<String>,
/// User callbacks invoked for each dispatched event.
handler: Arc<dyn EventHandler>,
/// HTTP client, cloned into every `Context`.
http: HttpClient,
/// Cache updated from gateway events and cloned into every `Context`.
cache: Cache,
/// Bounds how many handler tasks may run at the same time.
event_semaphore: Arc<Semaphore>,
}
impl Client {
pub fn builder() -> ClientBuilder {
ClientBuilder::default()
}
pub async fn start(&self) -> Result<(), Error> {
let mut backoff = 1u64;
let mut session: Option<SessionState> = None;
let gateway_url = self.resolve_gateway_url().await;
loop {
let connect_result = if let Some(ref s) = session {
tracing::info!(
url = %gateway_url,
session_id = %s.session_id,
seq = s.sequence,
"resuming gateway session",
);
GatewayConnection::resume(&gateway_url, &self.token, s).await
} else {
tracing::info!(url = %gateway_url, "connecting to gateway");
GatewayConnection::connect(&gateway_url, &self.token).await
};
match connect_result {
Ok(mut gw) => {
backoff = 1;
let ctx = Context::new(self.http.clone(), self.cache.clone());
while let Some(result) = gw.next().await {
match result {
Ok(event) => {
self.dispatch(ctx.clone(), event).await;
}
Err(GatewayError::InvalidSession { resumable }) => {
tracing::warn!(resumable, "session invalidated");
if !resumable {
session = None;
}
break;
}
Err(GatewayError::Closed { code, reason }) => {
tracing::warn!(code, %reason, "gateway closed");
session = gw.session_state();
break;
}
Err(e) => {
tracing::error!(?e, "gateway error");
session = gw.session_state();
break;
}
}
}
if session.is_some() || gw.session_state().is_some() {
session = session.or_else(|| gw.session_state());
}
}
Err(GatewayError::InvalidSession { .. }) => {
tracing::warn!("resume rejected, falling back to fresh identify");
session = None;
}
Err(e) => {
tracing::error!(?e, "failed to connect to gateway");
}
}
let wait = std::time::Duration::from_secs(backoff);
tracing::info!(?wait, "reconnecting");
tokio::time::sleep(wait).await;
backoff = (backoff * 2).min(60);
}
}
async fn resolve_gateway_url(&self) -> String {
if let Some(url) = &self.gateway_url_override {
tracing::debug!(%url, "using gateway URL override");
return url.clone();
}
match self.http.get_gateway_bot().await {
Ok(info) => {
tracing::info!(
url = %info.url,
shards = info.shards,
session_total = info.session_start_limit.total,
session_remaining = info.session_start_limit.remaining,
session_reset_after_ms = info.session_start_limit.reset_after,
max_concurrency = info.session_start_limit.max_concurrency,
"discovered gateway URL",
);
info.url
}
Err(e) => {
tracing::warn!(
error = ?e,
fallback = DEFAULT_GATEWAY_URL,
"get_gateway_bot failed, falling back to default URL",
);
DEFAULT_GATEWAY_URL.to_string()
}
}
}
async fn dispatch(&self, ctx: Context, event: GatewayEvent) {
let handler = self.handler.clone();
let cache = self.cache.clone();
let permit = match self.event_semaphore.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => {
tracing::error!("event dispatch semaphore closed");
return;
}
};
tokio::spawn(async move {
let _permit = permit;
let result = AssertUnwindSafe(Self::handle_event(handler, cache, ctx, event))
.catch_unwind()
.await;
if result.is_err() {
tracing::error!("event handler panicked");
}
});
}
async fn handle_event(
handler: Arc<dyn EventHandler>,
cache: Cache,
ctx: Context,
event: GatewayEvent,
) {
match event {
GatewayEvent::Ready(ready) => {
cache.set_current_user(ready.user.clone()).await;
handler.ready(ctx, ready.user).await;
}
GatewayEvent::Resumed(_) => {
tracing::info!("session resumed successfully");
handler.resumed(ctx).await;
}
GatewayEvent::MessageCreate(msg) => handler.message_create(ctx, msg).await,
GatewayEvent::MessageUpdate(msg) => handler.message_update(ctx, msg).await,
GatewayEvent::MessageDelete(p) => handler.message_delete(ctx, p.channel_id, p.id).await,
GatewayEvent::GuildCreate(guild) => {
cache.insert_guild(guild.clone()).await;
handler.guild_create(ctx, guild).await;
}
GatewayEvent::GuildUpdate(guild) => {
cache.insert_guild(guild.clone()).await;
handler.guild_update(ctx, guild).await;
}
GatewayEvent::GuildDelete(p) => {
cache.remove_guild(p.id).await;
handler.guild_delete(ctx, p.id).await;
}
GatewayEvent::GuildMemberAdd(p) => {
handler.guild_member_add(ctx, p.guild_id, p.member).await
}
GatewayEvent::GuildMemberRemove(p) => {
handler.guild_member_remove(ctx, p.guild_id, p.user).await
}
GatewayEvent::ChannelCreate(ch) => {
cache.insert_channel(ch.clone()).await;
handler.channel_create(ctx, ch).await;
}
GatewayEvent::ChannelUpdate(ch) => {
cache.insert_channel(ch.clone()).await;
handler.channel_update(ctx, ch).await;
}
GatewayEvent::ChannelDelete(ch) => {
cache.remove_channel(ch.id).await;
handler.channel_delete(ctx, ch).await;
}
GatewayEvent::TypingStart(p) => {
handler.typing_start(ctx, p.channel_id, p.user_id).await
}
GatewayEvent::UserUpdate(user) => handler.user_update(ctx, user).await,
GatewayEvent::MessageDeleteBulk(p) => handler.message_delete_bulk(ctx, p).await,
GatewayEvent::MessageReactionAdd(p) => handler.message_reaction_add(ctx, p).await,
GatewayEvent::MessageReactionRemove(p) => handler.message_reaction_remove(ctx, p).await,
GatewayEvent::MessageReactionRemoveAll(p) => {
handler.message_reaction_remove_all(ctx, p).await
}
GatewayEvent::MessageReactionRemoveEmoji(p) => {
handler.message_reaction_remove_emoji(ctx, p).await
}
GatewayEvent::GuildMemberUpdate(p) => handler.guild_member_update(ctx, p).await,
GatewayEvent::GuildRoleCreate(p) => handler.guild_role_create(ctx, p).await,
GatewayEvent::GuildRoleUpdate(p) => handler.guild_role_update(ctx, p).await,
GatewayEvent::GuildRoleDelete(p) => handler.guild_role_delete(ctx, p).await,
GatewayEvent::GuildBanAdd(p) => handler.guild_ban_add(ctx, p).await,
GatewayEvent::GuildBanRemove(p) => handler.guild_ban_remove(ctx, p).await,
GatewayEvent::ChannelPinsUpdate(p) => handler.channel_pins_update(ctx, p).await,
GatewayEvent::InviteCreate(p) => handler.invite_create(ctx, p).await,
GatewayEvent::InviteDelete(p) => handler.invite_delete(ctx, p).await,
GatewayEvent::WebhooksUpdate(p) => handler.webhooks_update(ctx, p).await,
GatewayEvent::GuildEmojisUpdate(p) => handler.guild_emojis_update(ctx, p).await,
GatewayEvent::GuildStickersUpdate(p) => handler.guild_stickers_update(ctx, p).await,
GatewayEvent::Unknown { .. } => {}
}
}
}
/// Builder for [`Client`]; obtain one with [`Client::builder`] and finish with `build`.
pub struct ClientBuilder {
/// Bot token; `build` fails with `Error::MissingToken` when unset.
token: Option<String>,
/// Optional gateway URL override, stored as the client's `gateway_url_override`.
gateway_url: Option<String>,
/// Event callbacks; `build` fails with `Error::MissingEventHandler` when unset.
handler: Option<Arc<dyn EventHandler>>,
/// Optional base URL forwarded to the HTTP client builder.
base_url: Option<String>,
/// Forwarded to `HttpClient::builder().auto_retry(..)`.
auto_retry: bool,
/// Number of permits for the event-dispatch semaphore.
max_concurrent_events: usize,
}
impl Default for ClientBuilder {
fn default() -> Self {
Self {
token: None,
gateway_url: None,
handler: None,
base_url: None,
auto_retry: true,
max_concurrent_events: DEFAULT_MAX_CONCURRENT_EVENTS,
}
}
}
impl ClientBuilder {
    /// Sets the bot token used for the gateway connection and the HTTP client.
    /// Required by [`build`](Self::build).
    pub fn token(mut self, token: &str) -> Self {
        self.token = Some(token.into());
        self
    }

    /// Overrides the gateway URL, skipping discovery through the HTTP API at startup.
    pub fn gateway_url(mut self, url: &str) -> Self {
        self.gateway_url = Some(url.into());
        self
    }

    /// Overrides the base URL used by the HTTP client.
    pub fn base_url(mut self, url: &str) -> Self {
        self.base_url = Some(url.into());
        self
    }

    /// Enables or disables the HTTP client's automatic retry (enabled by default).
    pub fn auto_retry(mut self, enabled: bool) -> Self {
        self.auto_retry = enabled;
        self
    }

    /// Limits how many event-handler tasks may run at once.
    ///
    /// The value is clamped to `1..=Semaphore::MAX_PERMITS`. Zero permits would make event
    /// dispatch wait forever, and anything above tokio's limit would make `Semaphore::new`
    /// panic inside [`build`](Self::build).
    pub fn max_concurrent_events(mut self, limit: usize) -> Self {
        self.max_concurrent_events = limit.clamp(1, Semaphore::MAX_PERMITS);
        self
    }

    /// Registers the callbacks that receive gateway events. Required by [`build`](Self::build).
    pub fn event_handler(mut self, handler: impl EventHandler) -> Self {
        self.handler = Some(Arc::new(handler));
        self
    }

    /// Validates the configuration and constructs a [`Client`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::MissingToken`] if no token was set, [`Error::MissingEventHandler`]
    /// if no handler was registered, or whatever error the HTTP client builder reports.
    // `Error` is the crate-wide error type and is returned unboxed, like the rest of the API.
    #[allow(clippy::result_large_err)]
    pub fn build(self) -> Result<Client, Error> {
        // Validate every required field before constructing anything.
        let token = self.token.ok_or(Error::MissingToken)?;
        let handler = self.handler.ok_or(Error::MissingEventHandler)?;
        let mut http_builder = HttpClient::builder()
            .token(&token)
            .auto_retry(self.auto_retry);
        if let Some(base_url) = &self.base_url {
            http_builder = http_builder.base_url(base_url);
        }
        let http = http_builder.build()?;
        Ok(Client {
            token,
            gateway_url_override: self.gateway_url,
            handler,
            http,
            cache: Cache::new(),
            event_semaphore: Arc::new(Semaphore::new(self.max_concurrent_events)),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that relies entirely on the trait's default (no-op) callbacks.
    struct NoopHandler;

    impl EventHandler for NoopHandler {}

    #[test]
    fn builder_requires_token() {
        let outcome = Client::builder().event_handler(NoopHandler).build();
        assert!(
            matches!(outcome, Err(Error::MissingToken)),
            "building without a token must fail with MissingToken"
        );
    }

    #[test]
    fn builder_requires_event_handler() {
        let outcome = Client::builder().token("test-token").build();
        assert!(
            matches!(outcome, Err(Error::MissingEventHandler)),
            "building without a handler must fail with MissingEventHandler"
        );
    }
}