use std::sync::Arc;
use std::time::{Duration as StdDuration, Instant};
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
use tracing::{debug, error, info, instrument, trace, warn};
use url::Url;
use super::{
ActivityData,
ChunkGuildFilter,
ConnectionStage,
GatewayError,
PresenceData,
ReconnectType,
ShardAction,
WsClient,
};
use crate::constants::{self, close_codes};
use crate::internal::prelude::*;
use crate::model::event::{Event, GatewayEvent};
use crate::model::gateway::{GatewayIntents, ShardInfo};
use crate::model::id::{ApplicationId, GuildId};
use crate::model::user::OnlineStatus;
pub struct Shard {
pub client: WsClient,
presence: PresenceData,
last_heartbeat_sent: Option<Instant>,
last_heartbeat_ack: Option<Instant>,
heartbeat_interval: Option<std::time::Duration>,
application_id_callback: Option<Box<dyn FnOnce(ApplicationId) + Send + Sync>>,
last_heartbeat_acknowledged: bool,
seq: u64,
session_id: Option<String>,
info: ShardInfo,
stage: ConnectionStage,
pub started: Instant,
pub token: String,
ws_url: Arc<Mutex<String>>,
pub intents: GatewayIntents,
}
impl Shard {
    /// Connects to the gateway URL currently stored in `ws_url` and returns a shard in the
    /// [`ConnectionStage::Handshake`] stage.
    ///
    /// No heartbeat interval, session id or sequence number is known yet. The gateway's `Hello`
    /// and `Ready` events fill these in later. When `presence` is `None`, the default
    /// [`PresenceData`] is used.
    ///
    /// # Errors
    ///
    /// Returns an error if the gateway URL cannot be built or the websocket connection fails.
    pub async fn new(
        ws_url: Arc<Mutex<String>>,
        token: &str,
        info: ShardInfo,
        intents: GatewayIntents,
        presence: Option<PresenceData>,
    ) -> Result<Shard> {
        // Copy the URL out so the lock guard is dropped before the connect is awaited.
        let url = ws_url.lock().await.clone();
        let client = connect(&url).await?;

        Ok(Shard {
            client,
            presence: presence.unwrap_or_default(),
            last_heartbeat_sent: None,
            last_heartbeat_ack: None,
            heartbeat_interval: None,
            application_id_callback: None,
            // Start out "acknowledged". Otherwise `do_heartbeat` would treat the very first
            // heartbeat as a missed ACK.
            last_heartbeat_acknowledged: true,
            seq: 0,
            stage: ConnectionStage::Handshake,
            started: Instant::now(),
            token: token.to_string(),
            session_id: None,
            info,
            ws_url,
            intents,
        })
    }

    /// Registers a callback that is run once, with the application id carried by the next
    /// `Ready` event. Any previously registered callback is replaced.
    pub fn set_application_id_callback(
        &mut self,
        callback: impl FnOnce(ApplicationId) + Send + Sync + 'static,
    ) {
        self.application_id_callback = Some(Box::new(callback));
    }

    /// The presence stored locally for this shard.
    #[inline]
    pub fn presence(&self) -> &PresenceData {
        &self.presence
    }

    /// When a heartbeat was last sent, if ever.
    #[inline]
    pub fn last_heartbeat_sent(&self) -> Option<Instant> {
        self.last_heartbeat_sent
    }

    /// When the last heartbeat ACK was received, if ever.
    #[inline]
    pub fn last_heartbeat_ack(&self) -> Option<Instant> {
        self.last_heartbeat_ack
    }

    /// Sends a heartbeat carrying the current sequence number.
    ///
    /// On success, records the send time and marks the heartbeat as not yet acknowledged.
    ///
    /// # Errors
    ///
    /// Returns [`GatewayError::HeartbeatFailed`] if sending fails. Errors are logged as follows:
    /// - broken-pipe I/O errors are not logged;
    /// - other I/O errors are logged at debug level;
    /// - all other errors are logged at warn level.
    #[instrument(skip(self))]
    pub async fn heartbeat(&mut self) -> Result<()> {
        match self.client.send_heartbeat(&self.info, Some(self.seq)).await {
            Ok(()) => {
                self.last_heartbeat_sent = Some(Instant::now());
                self.last_heartbeat_acknowledged = false;
                Ok(())
            },
            Err(why) => {
                match why {
                    // Check the portable `ErrorKind` instead of the Unix-only raw errno
                    // (EPIPE == 32), so broken pipes are recognised on every platform.
                    Error::Tungstenite(TungsteniteError::Io(err)) => {
                        if err.kind() != std::io::ErrorKind::BrokenPipe {
                            debug!("[{:?}] Err heartbeating: {:?}", self.info, err);
                        }
                    },
                    other => {
                        warn!("[{:?}] Other err w/ keepalive: {:?}", self.info, other);
                    },
                }
                Err(Error::Gateway(GatewayError::HeartbeatFailed))
            },
        }
    }

    /// Heartbeat interval from the most recent `Hello`. It is `None` until a `Hello` arrives and
    /// after a reset.
    #[inline]
    pub fn heartbeat_interval(&self) -> Option<StdDuration> {
        self.heartbeat_interval
    }

    /// Whether the most recently sent heartbeat has been acknowledged.
    #[inline]
    pub fn last_heartbeat_acknowledged(&self) -> bool {
        self.last_heartbeat_acknowledged
    }

    /// Sequence number of the last dispatch event handled.
    #[inline]
    pub fn seq(&self) -> u64 {
        self.seq
    }

    /// Session id received in `Ready`, if any.
    #[inline]
    pub fn session_id(&self) -> Option<&String> {
        self.session_id.as_ref()
    }

    /// Sets the activity on the locally stored presence. Call `update_presence` to send it.
    #[inline]
    #[instrument(skip(self))]
    pub fn set_activity(&mut self, activity: Option<ActivityData>) {
        self.presence.activity = activity;
    }

    /// Sets both activity and status on the locally stored presence. Call `update_presence` to
    /// send it.
    #[inline]
    #[instrument(skip(self))]
    pub fn set_presence(&mut self, activity: Option<ActivityData>, status: OnlineStatus) {
        self.set_activity(activity);
        self.set_status(status);
    }

    /// Sets the status on the locally stored presence. [`OnlineStatus::Offline`] is stored as
    /// [`OnlineStatus::Invisible`]. Call `update_presence` to send it.
    #[inline]
    #[instrument(skip(self))]
    pub fn set_status(&mut self, mut status: OnlineStatus) {
        if status == OnlineStatus::Offline {
            status = OnlineStatus::Invisible;
        }
        self.presence.status = status;
    }

    /// This shard's [`ShardInfo`].
    pub fn shard_info(&self) -> ShardInfo {
        self.info
    }

    /// Current connection stage.
    pub fn stage(&self) -> ConnectionStage {
        self.stage
    }

    /// Applies the bookkeeping for a dispatch event and records its sequence number.
    ///
    /// - `Ready` stores the session id, marks the shard connected, and runs the application-id
    ///   callback if one is registered.
    /// - `Resumed` marks the shard connected and restarts heartbeat tracking.
    ///
    /// Always returns `None`. A dispatch never requires an immediate shard action.
    #[instrument(skip(self))]
    fn handle_gateway_dispatch(&mut self, seq: u64, event: &Event) -> Option<ShardAction> {
        if seq > self.seq + 1 {
            warn!("[{:?}] Sequence off; them: {}, us: {}", self.info, seq, self.seq);
        }

        match event {
            Event::Ready(ready) => {
                debug!("[{:?}] Received Ready", self.info);
                self.session_id = Some(ready.ready.session_id.clone());
                self.stage = ConnectionStage::Connected;
                if let Some(callback) = self.application_id_callback.take() {
                    callback(ready.ready.application.id);
                }
            },
            Event::Resumed(_) => {
                info!("[{:?}] Resumed", self.info);
                self.stage = ConnectionStage::Connected;
                self.last_heartbeat_acknowledged = true;
                self.last_heartbeat_sent = Some(Instant::now());
                self.last_heartbeat_ack = None;
            },
            _ => {},
        }

        self.seq = seq;
        None
    }

    /// Maps a gateway close frame to either a fatal error or a reconnect action.
    ///
    /// Close codes for missing or invalid authentication, invalid shard data, sharding required,
    /// and invalid or disallowed intents return an error. Every other code results in a
    /// reconnect. `4006`/`SESSION_TIMEOUT` first drop the session id, and `INVALID_SEQUENCE`
    /// first resets `seq` to 0.
    ///
    /// # Errors
    ///
    /// Returns the matching [`GatewayError`] for the fatal close codes listed above.
    #[instrument(skip(self))]
    fn handle_gateway_closed(
        &mut self,
        data: Option<&CloseFrame<'static>>,
    ) -> Result<Option<ShardAction>> {
        let num = data.map(|d| d.code.into());
        let clean = num == Some(1000);

        match num {
            Some(close_codes::UNKNOWN_OPCODE) => {
                warn!("[{:?}] Sent invalid opcode.", self.info);
            },
            Some(close_codes::DECODE_ERROR) => {
                warn!("[{:?}] Sent invalid message.", self.info);
            },
            Some(close_codes::NOT_AUTHENTICATED) => {
                warn!("[{:?}] Sent no authentication.", self.info);
                return Err(Error::Gateway(GatewayError::NoAuthentication));
            },
            Some(close_codes::AUTHENTICATION_FAILED) => {
                error!("[{:?}] Sent invalid authentication, please check the token.", self.info);
                return Err(Error::Gateway(GatewayError::InvalidAuthentication));
            },
            Some(close_codes::ALREADY_AUTHENTICATED) => {
                warn!("[{:?}] Already authenticated.", self.info);
            },
            Some(close_codes::INVALID_SEQUENCE) => {
                warn!("[{:?}] Sent invalid seq: {}.", self.info, self.seq);
                self.seq = 0;
            },
            Some(close_codes::RATE_LIMITED) => {
                warn!("[{:?}] Gateway ratelimited.", self.info);
            },
            Some(close_codes::INVALID_SHARD) => {
                warn!("[{:?}] Sent invalid shard data.", self.info);
                return Err(Error::Gateway(GatewayError::InvalidShardData));
            },
            Some(close_codes::SHARDING_REQUIRED) => {
                error!("[{:?}] Shard has too many guilds.", self.info);
                return Err(Error::Gateway(GatewayError::OverloadedShard));
            },
            Some(4006 | close_codes::SESSION_TIMEOUT) => {
                info!("[{:?}] Invalid session.", self.info);
                self.session_id = None;
            },
            Some(close_codes::INVALID_GATEWAY_INTENTS) => {
                error!("[{:?}] Invalid gateway intents have been provided.", self.info);
                return Err(Error::Gateway(GatewayError::InvalidGatewayIntents));
            },
            Some(close_codes::DISALLOWED_GATEWAY_INTENTS) => {
                error!("[{:?}] Disallowed gateway intents have been provided.", self.info);
                return Err(Error::Gateway(GatewayError::DisallowedGatewayIntents));
            },
            Some(other) if !clean => {
                warn!(
                    "[{:?}] Unknown unclean close {}: {:?}",
                    self.info,
                    other,
                    data.map(|d| &d.reason),
                );
            },
            _ => {},
        }

        // The fatal codes, including AUTHENTICATION_FAILED, have already returned, so only the
        // presence of a session id decides the reconnect type. This also covers a close without
        // a frame: without a session id, a resume would just fail with `NoSessionId`.
        Ok(Some(ShardAction::Reconnect(self.reconnection_type())))
    }

    /// Updates shard state for a received gateway event (or receive error) and returns the
    /// action the caller should take next, if any.
    ///
    /// # Errors
    ///
    /// Returns an error only for fatal gateway close codes (see `handle_gateway_closed`).
    #[instrument(skip(self))]
    pub fn handle_event(&mut self, event: &Result<GatewayEvent>) -> Result<Option<ShardAction>> {
        match event {
            Ok(GatewayEvent::Dispatch(seq, event)) => Ok(self.handle_gateway_dispatch(*seq, event)),
            Ok(GatewayEvent::Heartbeat(..)) => {
                info!("[{:?}] Received shard heartbeat", self.info);
                Ok(Some(ShardAction::Heartbeat))
            },
            Ok(GatewayEvent::HeartbeatAck) => {
                self.last_heartbeat_ack = Some(Instant::now());
                self.last_heartbeat_acknowledged = true;
                trace!("[{:?}] Received heartbeat ack", self.info);
                Ok(None)
            },
            &Ok(GatewayEvent::Hello(interval)) => {
                debug!("[{:?}] Received a Hello; interval: {}", self.info, interval);
                // Every new connection announces its own interval, including the one `resume`
                // opens. Record it before the early return below so it is not discarded.
                self.heartbeat_interval = Some(StdDuration::from_millis(interval));

                if self.stage == ConnectionStage::Resuming {
                    // `resume` has already sent the resume payload on this connection.
                    return Ok(None);
                }

                Ok(Some(if self.stage == ConnectionStage::Handshake {
                    ShardAction::Identify
                } else {
                    debug!("[{:?}] Received late Hello; autoreconnecting", self.info);
                    ShardAction::Reconnect(self.reconnection_type())
                }))
            },
            &Ok(GatewayEvent::InvalidateSession(resumable)) => {
                info!("[{:?}] Received session invalidation", self.info);
                // Even a resumable invalidation can only be resumed with a session id.
                let kind = if resumable {
                    self.reconnection_type()
                } else {
                    ReconnectType::Reidentify
                };
                Ok(Some(ShardAction::Reconnect(kind)))
            },
            Ok(GatewayEvent::Reconnect) => {
                Ok(Some(ShardAction::Reconnect(self.reconnection_type())))
            },
            Err(Error::Gateway(GatewayError::Closed(data))) => {
                self.handle_gateway_closed(data.as_ref())
            },
            Err(Error::Tungstenite(why)) => {
                info!("[{:?}] Websocket error: {:?}", self.info, why);
                info!("[{:?}] Will attempt to auto-reconnect", self.info);
                Ok(Some(ShardAction::Reconnect(self.reconnection_type())))
            },
            Err(why) => {
                warn!("[{:?}] Unhandled error: {:?}", self.info, why);
                Ok(None)
            },
        }
    }

    /// Sends a heartbeat if one is due, and reports whether the shard is still healthy.
    ///
    /// Returns `false` in any of these cases:
    /// - no `Hello` interval is known and more than 15 seconds have passed since `started`;
    /// - a heartbeat is due but the previous one was never acknowledged;
    /// - sending the heartbeat fails.
    ///
    /// Otherwise it returns `true`, including when no heartbeat is due yet.
    #[instrument(skip(self))]
    pub async fn do_heartbeat(&mut self) -> bool {
        let Some(heartbeat_interval) = self.heartbeat_interval else {
            return self.started.elapsed() < StdDuration::from_secs(15);
        };

        if let Some(last_sent) = self.last_heartbeat_sent {
            if last_sent.elapsed() <= heartbeat_interval {
                return true;
            }
        }

        if !self.last_heartbeat_acknowledged {
            debug!("[{:?}] Last heartbeat not acknowledged", self.info);
            return false;
        }

        if let Err(why) = self.heartbeat().await {
            warn!("[{:?}] Err heartbeating: {:?}", self.info, why);
            false
        } else {
            trace!("[{:?}] Heartbeat", self.info);
            true
        }
    }

    /// Time between the last heartbeat sent and the last ACK received. Returns `None` unless
    /// both are known and the ACK is newer than the send.
    #[instrument(skip(self))]
    pub fn latency(&self) -> Option<StdDuration> {
        if let (Some(sent), Some(received)) = (self.last_heartbeat_sent, self.last_heartbeat_ack) {
            if received > sent {
                return Some(received - sent);
            }
        }
        None
    }

    /// The reconnect type to use right now. Returns `None` while a connection attempt is still
    /// in progress ([`ConnectionStage::Connecting`]).
    pub fn should_reconnect(&mut self) -> Option<ReconnectType> {
        if self.stage == ConnectionStage::Connecting {
            return None;
        }
        Some(self.reconnection_type())
    }

    /// Returns [`ReconnectType::Resume`] when a session id is held, and
    /// [`ReconnectType::Reidentify`] otherwise.
    pub fn reconnection_type(&self) -> ReconnectType {
        if self.session_id().is_some() {
            ReconnectType::Resume
        } else {
            ReconnectType::Reidentify
        }
    }

    /// Requests member chunks for `guild_id` over this shard's connection.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent.
    #[instrument(skip(self))]
    pub async fn chunk_guild(
        &mut self,
        guild_id: GuildId,
        limit: Option<u16>,
        presences: bool,
        filter: ChunkGuildFilter,
        nonce: Option<&str>,
    ) -> Result<()> {
        debug!("[{:?}] Requesting member chunks", self.info);
        self.client.send_chunk_guild(guild_id, &self.info, limit, presences, filter, nonce).await
    }

    /// Requests the soundboard sounds of `guild_ids` over this shard's connection.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent.
    #[instrument(skip(self))]
    pub async fn request_soundboard_sounds(&mut self, guild_ids: &[GuildId]) -> Result<()> {
        debug!("[{:?}] Requesting soundboard sounds", self.info);
        self.client.request_soundboard_sounds(guild_ids, &self.info).await
    }

    /// Sends the identify payload and moves to [`ConnectionStage::Identifying`].
    ///
    /// The heartbeat send time is stamped here, so `do_heartbeat` waits a full interval before
    /// sending the first heartbeat.
    ///
    /// # Errors
    ///
    /// Returns an error if the identify payload cannot be sent. In that case the state is left
    /// unchanged.
    #[instrument(skip(self))]
    pub async fn identify(&mut self) -> Result<()> {
        self.client.send_identify(&self.info, &self.token, self.intents, &self.presence).await?;
        self.last_heartbeat_sent = Some(Instant::now());
        self.stage = ConnectionStage::Identifying;
        Ok(())
    }

    /// Opens a new websocket connection to the current gateway URL and returns it. The caller
    /// is responsible for storing it in `self.client`.
    ///
    /// Resets `started` and sets the stage to [`ConnectionStage::Connecting`]. After a
    /// successful connect the stage becomes [`ConnectionStage::Handshake`]. If the connect
    /// fails, the stage stays `Connecting`.
    ///
    /// # Errors
    ///
    /// Returns an error if the gateway URL cannot be built or the connection fails.
    #[instrument(skip(self))]
    pub async fn initialize(&mut self) -> Result<WsClient> {
        debug!("[{:?}] Initializing.", self.info);
        self.stage = ConnectionStage::Connecting;
        self.started = Instant::now();
        // Copy the URL out so the lock guard is dropped before the connect is awaited.
        let url = self.ws_url.lock().await.clone();
        let client = connect(&url).await?;
        self.stage = ConnectionStage::Handshake;
        Ok(client)
    }

    /// Clears all session and heartbeat state and moves to
    /// [`ConnectionStage::Disconnected`]. The websocket client itself is not touched.
    #[instrument(skip(self))]
    pub async fn reset(&mut self) {
        self.last_heartbeat_sent = Some(Instant::now());
        self.last_heartbeat_ack = None;
        self.heartbeat_interval = None;
        self.last_heartbeat_acknowledged = true;
        self.session_id = None;
        self.stage = ConnectionStage::Disconnected;
        self.seq = 0;
    }

    /// Opens a fresh connection and sends a resume for the current session and sequence number.
    ///
    /// # Errors
    ///
    /// Returns [`GatewayError::NoSessionId`] if no session id is held. This check runs before
    /// any reconnect, so the current connection is left untouched. Also returns an error if
    /// reconnecting or sending the resume payload fails.
    #[instrument(skip(self))]
    pub async fn resume(&mut self) -> Result<()> {
        debug!("[{:?}] Attempting to resume", self.info);

        // Resuming is impossible without a session id. Bail out before replacing the live
        // connection with one that would then sit in `Resuming` forever.
        let Some(session_id) = self.session_id.clone() else {
            return Err(Error::Gateway(GatewayError::NoSessionId));
        };

        self.client = self.initialize().await?;
        self.stage = ConnectionStage::Resuming;
        self.client.send_resume(&self.info, &session_id, self.seq, &self.token).await
    }

    /// Resets all session state and opens a fresh connection. The caller must identify again
    /// afterwards.
    ///
    /// # Errors
    ///
    /// Returns an error if the new connection cannot be established.
    #[instrument(skip(self))]
    pub async fn reconnect(&mut self) -> Result<()> {
        info!("[{:?}] Attempting to reconnect", self.shard_info());
        self.reset().await;
        self.client = self.initialize().await?;
        Ok(())
    }

    /// Sends the locally stored presence to the gateway.
    ///
    /// # Errors
    ///
    /// Returns an error if the presence update cannot be sent.
    #[instrument(skip(self))]
    pub async fn update_presence(&mut self) -> Result<()> {
        self.client.send_presence_update(&self.info, &self.presence).await
    }
}
async fn connect(base_url: &str) -> Result<WsClient> {
let url =
Url::parse(&format!("{base_url}?v={}", constants::GATEWAY_VERSION)).map_err(|why| {
warn!("Error building gateway URL with base `{}`: {:?}", base_url, why);
Error::Gateway(GatewayError::BuildingUrl)
})?;
WsClient::connect(url).await
}