use super::{
builder::ShardBuilder,
config::Config,
event::Events,
json,
processor::{ConnectingError, Latency, Session, ShardProcessor},
sink::ShardSink,
stage::Stage,
};
use crate::{listener::Listeners, EventTypeFlags, Intents};
use async_tungstenite::tungstenite::{
protocol::{frame::coding::CloseCode, CloseFrame},
Error as TungsteniteError, Message,
};
use futures_channel::mpsc::TrySendError;
use futures_util::{
future::{self, AbortHandle},
stream::StreamExt,
};
use once_cell::sync::OnceCell;
use std::{
borrow::Cow,
error::Error,
fmt::{Display, Formatter, Result as FmtResult},
sync::{atomic::Ordering, Arc},
};
use tokio::sync::watch::Receiver as WatchReceiver;
use twilight_http::Error as HttpError;
use twilight_model::gateway::event::Event;
use url::ParseError as UrlParseError;
#[cfg(not(feature = "simd-json"))]
use serde_json::Error as JsonError;
#[cfg(feature = "simd-json")]
use simd_json::Error as JsonError;
/// Sending a command through a [`Shard`] failed.
#[derive(Debug)]
#[non_exhaustive]
pub enum CommandError {
    /// Queueing the websocket message on the session's outgoing channel failed.
    Sending { source: TrySendError<Message> },
    /// Serializing the command payload to JSON failed.
    Serializing { source: JsonError },
    /// The shard has no session, i.e. it has not been successfully started.
    SessionInactive { source: SessionInactiveError },
}
impl Display for CommandError {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
f.write_str("the shard session is inactive and has not been started")
}
}
impl Error for CommandError {
    /// Every variant wraps the lower-level error that caused it.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::SessionInactive { source } => source,
            Self::Serializing { source } => source,
            Self::Sending { source } => source,
        };

        Some(cause)
    }
}
/// Returned when an operation needs a shard session but none exists yet.
///
/// The session is only populated by `Shard::start`, so this means the shard
/// was never started (or starting it failed).
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct SessionInactiveError;

impl Display for SessionInactiveError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        write!(f, "the shard session is inactive and was not started")
    }
}

// Leaf error: there is no underlying cause, so the default `source` (None) applies.
impl Error for SessionInactiveError {}
/// Starting a [`Shard`] failed.
#[derive(Debug)]
#[non_exhaustive]
pub enum ShardStartError {
    /// Establishing the websocket connection failed.
    Establishing { source: TungsteniteError },
    /// The gateway URL could not be parsed.
    ParsingGatewayUrl { source: UrlParseError, url: String },
    /// No gateway URL was configured and fetching one over HTTP failed.
    RetrievingGatewayUrl { source: HttpError },
}
impl Display for ShardStartError {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self {
Self::Establishing { source } => Display::fmt(source, f),
Self::ParsingGatewayUrl { source, url } => f.write_fmt(format_args!(
"the gateway url `{}` is invalid: {}",
url, source,
)),
Self::RetrievingGatewayUrl { .. } => {
f.write_str("retrieving the gateway URL via HTTP failed")
}
}
}
}
impl Error for ShardStartError {
    /// Every variant wraps the lower-level error that caused it.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::RetrievingGatewayUrl { source } => source,
            Self::ParsingGatewayUrl { source, .. } => source,
            Self::Establishing { source } => source,
        };

        Some(cause)
    }
}
/// Map the processor's connection errors onto the public start error.
impl From<ConnectingError> for ShardStartError {
    fn from(connecting: ConnectingError) -> Self {
        match connecting {
            ConnectingError::ParsingUrl { url, source } => Self::ParsingGatewayUrl { url, source },
            ConnectingError::Establishing { source } => Self::Establishing { source },
        }
    }
}
/// Snapshot of a running shard's state.
#[derive(Debug, Clone)]
pub struct Information {
    id: u64,
    latency: Latency,
    seq: u64,
    stage: Stage,
}

impl Information {
    /// ID of the shard this snapshot was taken from.
    pub fn id(&self) -> u64 {
        let Self { id, .. } = self;
        *id
    }

    /// Heartbeat latency information of the session.
    pub fn latency(&self) -> &Latency {
        let Self { latency, .. } = self;
        latency
    }

    /// Sequence number recorded by the session when the snapshot was taken.
    pub fn seq(&self) -> u64 {
        let Self { seq, .. } = self;
        *seq
    }

    /// Connection stage of the session when the snapshot was taken.
    pub fn stage(&self) -> Stage {
        let Self { stage, .. } = self;
        *stage
    }
}
/// Data needed to resume a session, returned by `Shard::shutdown_resumable`.
#[derive(Debug, Clone)]
pub struct ResumeSession {
    /// ID of the session that was closed.
    pub session_id: String,
    /// Last sequence number the session had recorded before closing.
    pub sequence: u64,
}
/// Shared state behind every clone of a [`Shard`].
#[derive(Debug)]
struct ShardRef {
    /// Configuration the shard was built with.
    config: Arc<Config>,
    /// Registered event listeners.
    listeners: Listeners<Event>,
    /// Abort handle of the spawned processor task; set by `Shard::start`.
    processor_handle: OnceCell<AbortHandle>,
    /// Watch receiver for the current session; set by `Shard::start`.
    session: OnceCell<WatchReceiver<Arc<Session>>>,
}
/// Handle to a gateway shard.
///
/// Cloning is cheap: all clones share the same underlying state via an `Arc`.
#[derive(Debug, Clone)]
pub struct Shard(Arc<ShardRef>);
impl Shard {
    /// Create a shard with default settings for the given token and intents.
    ///
    /// Shorthand for building via [`ShardBuilder`] without customization.
    pub fn new(token: impl Into<String>, intents: Intents) -> Self {
        ShardBuilder::new(token, intents).build()
    }

    /// Create a shard from an already-built configuration.
    ///
    /// No processor or session exists until [`Shard::start`] is called.
    pub(crate) fn new_with_config(config: Config) -> Self {
        let shared = ShardRef {
            config: Arc::new(config),
            listeners: Listeners::default(),
            processor_handle: OnceCell::new(),
            session: OnceCell::new(),
        };

        Self(Arc::new(shared))
    }

    /// Create a builder for configuring a shard.
    pub fn builder(token: impl Into<String>, intents: Intents) -> ShardBuilder {
        ShardBuilder::new(token, intents)
    }

    /// Configuration the shard was built with.
    pub fn config(&self) -> &Config {
        self.0.config.as_ref()
    }

    /// Connect the shard and spawn its processor on the tokio runtime.
    ///
    /// The configured gateway URL is used when present; otherwise it is
    /// fetched through the HTTP client's authenticated gateway endpoint.
    ///
    /// NOTE(review): the abort handle and session receiver are stored in
    /// write-once cells whose `set` results are ignored, so calling this a
    /// second time spawns another processor whose handle is discarded.
    ///
    /// # Errors
    ///
    /// Returns [`ShardStartError::RetrievingGatewayUrl`] if fetching the URL
    /// over HTTP fails, and the other variants if connecting fails.
    pub async fn start(&mut self) -> Result<(), ShardStartError> {
        let inner = &self.0;

        let url = match inner.config.gateway_url.clone() {
            Some(configured) => configured,
            None => {
                let gateway = inner
                    .config
                    .http_client()
                    .gateway()
                    .authed()
                    .await
                    .map_err(|source| ShardStartError::RetrievingGatewayUrl { source })?;

                gateway.url
            }
        };

        let (processor, session_rx) = ShardProcessor::new(
            Arc::clone(&inner.config),
            url,
            inner.listeners.clone(),
        )
        .await
        .map_err(ShardStartError::from)?;

        let (abortable_run, abort_handle) = future::abortable(processor.run());

        tokio::spawn(async move {
            // Finishing normally and being aborted are both just "ended" here.
            drop(abortable_run.await);
            tracing::debug!("shard processor future ended");
        });

        let _ = inner.processor_handle.set(abort_handle);
        let _ = inner.session.set(session_rx);

        Ok(())
    }

    /// Stream of events matching `EventTypeFlags::default()`.
    pub fn events(&self) -> Events {
        self.some_events(EventTypeFlags::default())
    }

    /// Stream of events restricted to the given event types.
    pub fn some_events(&self, event_types: EventTypeFlags) -> Events {
        let receiver = self.0.listeners.add(event_types);

        Events::new(event_types, receiver)
    }

    /// Snapshot of the shard's ID, heartbeat latency, sequence and stage.
    ///
    /// # Errors
    ///
    /// Returns [`SessionInactiveError`] if the shard has not been started.
    pub fn info(&self) -> Result<Information, SessionInactiveError> {
        let session = self.session()?;
        let id = self.config().shard()[0];

        Ok(Information {
            id,
            latency: session.heartbeats.latency(),
            seq: session.seq(),
            stage: session.stage(),
        })
    }

    /// Sink that writes to the session's outgoing message channel.
    ///
    /// # Errors
    ///
    /// Returns [`SessionInactiveError`] if the shard has not been started.
    pub fn sink(&self) -> Result<ShardSink, SessionInactiveError> {
        self.session().map(|session| ShardSink(session.tx.clone()))
    }

    /// Serialize `value` to JSON and send it as a command.
    ///
    /// # Errors
    ///
    /// Returns [`CommandError::Serializing`] if serialization fails, otherwise
    /// any error from [`Shard::command_raw`].
    pub async fn command(&self, value: &impl serde::Serialize) -> Result<(), CommandError> {
        match json::to_vec(value) {
            Ok(bytes) => self.command_raw(bytes).await,
            Err(source) => Err(CommandError::Serializing { source }),
        }
    }

    /// Send pre-serialized bytes as a binary websocket message.
    ///
    /// Waits on the session's ratelimiter before queueing the message.
    ///
    /// # Errors
    ///
    /// Returns [`CommandError::SessionInactive`] if the shard has not been
    /// started, or [`CommandError::Sending`] if queueing the message fails.
    pub async fn command_raw(&self, value: Vec<u8>) -> Result<(), CommandError> {
        let session = match self.session() {
            Ok(active) => active,
            Err(source) => return Err(CommandError::SessionInactive { source }),
        };
        let message = Message::Binary(value);

        {
            // The lock stays held while waiting for the next ratelimit permit.
            let mut limiter = session.ratelimit.lock().await;
            limiter.next().await;
        }

        session
            .tx
            .unbounded_send(message)
            .map_err(|source| CommandError::Sending { source })
    }

    /// Shut the shard down, closing the session with a normal close code.
    ///
    /// Removes all listeners and aborts the processor. If a session exists,
    /// it is closed and its heartbeater stopped.
    pub fn shutdown(&self) {
        self.detach_processing();

        if let Ok(session) = self.session() {
            // Best effort: a failed close is deliberately ignored.
            let _ = session.close(Some(CloseFrame {
                code: CloseCode::Normal,
                reason: Cow::from(""),
            }));
            session.stop_heartbeater();
        }
    }

    /// Shut the shard down so that its session can be resumed later.
    ///
    /// Closes with the `Restart` close code and returns the shard ID together
    /// with the session ID and last sequence number, if a session ID exists.
    pub fn shutdown_resumable(&self) -> (u64, Option<ResumeSession>) {
        self.detach_processing();

        let shard_id = self.config().shard()[0];

        let session = if let Ok(active) = self.session() {
            active
        } else {
            return (shard_id, None);
        };

        // Best effort: a failed close is deliberately ignored.
        let _ = session.close(Some(CloseFrame {
            code: CloseCode::Restart,
            reason: Cow::from("Closing in a resumable way"),
        }));

        let id = session.id();
        let sequence = session.seq.load(Ordering::Relaxed);
        session.stop_heartbeater();

        let resume = id.map(|session_id| ResumeSession {
            session_id,
            sequence,
        });

        (shard_id, resume)
    }

    /// Remove every listener and abort the processor task if one was spawned.
    fn detach_processing(&self) {
        self.0.listeners.remove_all();

        if let Some(handle) = self.0.processor_handle.get() {
            handle.abort();
        }
    }

    /// Current session, if the shard has been started.
    fn session(&self) -> Result<Arc<Session>, SessionInactiveError> {
        match self.0.session.get() {
            Some(receiver) => Ok(Arc::clone(&*receiver.borrow())),
            None => Err(SessionInactiveError),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{
        CommandError, ConnectingError, Information, ResumeSession, SessionInactiveError, Shard,
        ShardStartError,
    };
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{error::Error, fmt::Debug};

    // Public handle and data types.
    assert_impl_all!(Shard: Clone, Debug, Send, Sync);
    assert_impl_all!(Information: Clone, Debug, Send, Sync);
    assert_impl_all!(ResumeSession: Clone, Debug, Send, Sync);

    // Command errors expose their cause through a `source` field.
    assert_fields!(CommandError::Sending: source);
    assert_fields!(CommandError::Serializing: source);
    assert_fields!(CommandError::SessionInactive: source);
    assert_impl_all!(CommandError: Debug, Error, Send, Sync);

    assert_impl_all!(
        SessionInactiveError: Clone,
        Debug,
        Error,
        Eq,
        PartialEq,
        Send,
        Sync
    );

    // Start errors, including conversion from the processor's connect error.
    assert_fields!(ShardStartError::Establishing: source);
    assert_fields!(ShardStartError::ParsingGatewayUrl: source, url);
    assert_fields!(ShardStartError::RetrievingGatewayUrl: source);
    assert_impl_all!(
        ShardStartError: Debug,
        Error,
        From<ConnectingError>,
        Send,
        Sync
    );
}