mod context;
#[cfg(feature = "gateway")]
pub(crate) mod dispatch;
mod error;
#[cfg(feature = "gateway")]
mod event_handler;
use std::future::IntoFuture;
use std::ops::Range;
use std::sync::Arc;
#[cfg(feature = "framework")]
use std::sync::OnceLock;
use futures::channel::mpsc::UnboundedReceiver as Receiver;
use futures::future::BoxFuture;
use futures::StreamExt as _;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, error, info, instrument};
use typemap_rev::{TypeMap, TypeMapKey};
pub use self::context::Context;
pub use self::error::Error as ClientError;
#[cfg(feature = "gateway")]
pub use self::event_handler::{EventHandler, FullEvent, RawEventHandler};
#[cfg(feature = "gateway")]
use super::gateway::GatewayError;
#[cfg(feature = "cache")]
pub use crate::cache::Cache;
#[cfg(feature = "cache")]
use crate::cache::Settings as CacheSettings;
#[cfg(feature = "framework")]
use crate::framework::Framework;
#[cfg(feature = "voice")]
use crate::gateway::VoiceGatewayManager;
use crate::gateway::{ActivityData, PresenceData};
#[cfg(feature = "gateway")]
use crate::gateway::{ShardManager, ShardManagerOptions};
use crate::http::Http;
use crate::internal::prelude::*;
#[cfg(feature = "gateway")]
use crate::model::gateway::GatewayIntents;
use crate::model::id::ApplicationId;
use crate::model::user::OnlineStatus;
/// Collects the settings needed to construct a [`Client`].
///
/// The builder is consumed by awaiting it (through its [`IntoFuture`] implementation), which
/// produces the configured [`Client`].
#[cfg(feature = "gateway")]
#[must_use = "Builders do nothing unless they are awaited"]
pub struct ClientBuilder {
/// User data map; wrapped in `Arc<RwLock<_>>` and exposed as `Client::data` once built.
data: TypeMap,
/// HTTP client, either created from the token or supplied by the caller.
http: Http,
/// Gateway intents handed to the shard manager.
intents: GatewayIntents,
/// Settings used to construct the client's [`Cache`].
#[cfg(feature = "cache")]
cache_settings: CacheSettings,
/// Optional framework; initialised with the finished client when the builder is awaited.
#[cfg(feature = "framework")]
framework: Option<Box<dyn Framework>>,
/// Optional voice gateway manager shared with the shard manager and the client.
#[cfg(feature = "voice")]
voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
/// Event handlers handed to the shard manager; they also receive ratelimit notifications.
event_handlers: Vec<Arc<dyn EventHandler>>,
/// Raw event handlers handed to the shard manager.
raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
/// Initial presence (activity and status) handed to the shard manager.
presence: PresenceData,
}
#[cfg(feature = "gateway")]
impl ClientBuilder {
    /// Shared constructor behind [`Self::new`] and [`Self::new_with_http`].
    ///
    /// Everything except `http` and `intents` starts out empty or at its `Default` value.
    fn new_(http: Http, intents: GatewayIntents) -> Self {
        let data = TypeMap::new();
        let presence = PresenceData::default();

        Self {
            data,
            http,
            intents,
            #[cfg(feature = "cache")]
            cache_settings: CacheSettings::default(),
            #[cfg(feature = "framework")]
            framework: None,
            #[cfg(feature = "voice")]
            voice_manager: None,
            event_handlers: Vec::new(),
            raw_event_handlers: Vec::new(),
            presence,
        }
    }

    /// Creates a builder whose HTTP client is constructed from `token`.
    pub fn new(token: impl AsRef<str>, intents: GatewayIntents) -> Self {
        let http = Http::new(token.as_ref());
        Self::new_(http, intents)
    }

    /// Creates a builder around an already constructed [`Http`] client.
    pub fn new_with_http(http: Http, intents: GatewayIntents) -> Self {
        Self::new_(http, intents)
    }

    /// Replaces the builder's HTTP client with a fresh one constructed from `token`.
    ///
    /// NOTE(review): the previous [`Http`] instance is discarded, so anything configured on it
    /// (for example through [`Self::application_id`]) presumably has to be applied again
    /// afterwards - confirm against `Http::new`.
    pub fn token(mut self, token: impl AsRef<str>) -> Self {
        let replacement = Http::new(token.as_ref());
        self.http = replacement;
        self
    }

    /// Returns the token held by the builder's HTTP client.
    pub fn get_token(&self) -> &str {
        let http = &self.http;
        http.token()
    }

    /// Stores `application_id` on the builder's HTTP client.
    pub fn application_id(self, application_id: ApplicationId) -> Self {
        self.http.set_application_id(application_id);
        self
    }

    /// Returns the application id reported by the builder's HTTP client, if any.
    pub fn get_application_id(&self) -> Option<ApplicationId> {
        let http = &self.http;
        http.application_id()
    }

    /// Replaces the whole user data map with `type_map`.
    pub fn type_map(mut self, type_map: TypeMap) -> Self {
        self.data = type_map;
        self
    }

    /// Returns the user data map collected so far.
    pub fn get_type_map(&self) -> &TypeMap {
        &self.data
    }

    /// Inserts `value` into the user data map under the key type `T`.
    pub fn type_map_insert<T: TypeMapKey>(mut self, value: T::Value) -> Self {
        let data = &mut self.data;
        data.insert::<T>(value);
        self
    }

    /// Sets the cache settings used when the client's cache is created.
    #[cfg(feature = "cache")]
    pub fn cache_settings(mut self, settings: CacheSettings) -> Self {
        self.cache_settings = settings;
        self
    }

    /// Returns the cache settings configured so far.
    #[cfg(feature = "cache")]
    pub fn get_cache_settings(&self) -> &CacheSettings {
        &self.cache_settings
    }

    /// Sets the framework, replacing any previously configured one.
    #[cfg(feature = "framework")]
    pub fn framework<F>(mut self, framework: F) -> Self
    where
        F: Framework + 'static,
    {
        let boxed: Box<dyn Framework> = Box::new(framework);
        self.framework = Some(boxed);
        self
    }

    /// Returns the configured framework, if one has been set.
    #[cfg(feature = "framework")]
    pub fn get_framework(&self) -> Option<&dyn Framework> {
        self.framework.as_deref()
    }

    /// Sets the voice gateway manager, wrapping it in an [`Arc`].
    #[cfg(feature = "voice")]
    pub fn voice_manager<V>(mut self, voice_manager: V) -> Self
    where
        V: VoiceGatewayManager + 'static,
    {
        let shared: Arc<dyn VoiceGatewayManager> = Arc::new(voice_manager);
        self.voice_manager = Some(shared);
        self
    }

    /// Sets the voice gateway manager from an already shared instance.
    #[cfg(feature = "voice")]
    pub fn voice_manager_arc(
        mut self,
        voice_manager: Arc<dyn VoiceGatewayManager + 'static>,
    ) -> Self {
        self.voice_manager = Some(voice_manager);
        self
    }

    /// Returns a new handle to the configured voice gateway manager, if one has been set.
    #[cfg(feature = "voice")]
    pub fn get_voice_manager(&self) -> Option<Arc<dyn VoiceGatewayManager>> {
        self.voice_manager.as_ref().map(Arc::clone)
    }

    /// Replaces the gateway intents.
    pub fn intents(mut self, intents: GatewayIntents) -> Self {
        self.intents = intents;
        self
    }

    /// Returns the gateway intents configured so far.
    pub fn get_intents(&self) -> GatewayIntents {
        self.intents
    }

    /// Appends an event handler, wrapping it in an [`Arc`].
    pub fn event_handler<H: EventHandler + 'static>(mut self, event_handler: H) -> Self {
        let handler: Arc<dyn EventHandler> = Arc::new(event_handler);
        self.event_handlers.push(handler);
        self
    }

    /// Appends an event handler that is already shared behind an [`Arc`].
    pub fn event_handler_arc<H: EventHandler + 'static>(
        mut self,
        event_handler_arc: Arc<H>,
    ) -> Self {
        let handler: Arc<dyn EventHandler> = event_handler_arc;
        self.event_handlers.push(handler);
        self
    }

    /// Returns the event handlers registered so far, in registration order.
    pub fn get_event_handlers(&self) -> &[Arc<dyn EventHandler>] {
        self.event_handlers.as_slice()
    }

    /// Appends a raw event handler, wrapping it in an [`Arc`].
    pub fn raw_event_handler<H: RawEventHandler + 'static>(mut self, raw_event_handler: H) -> Self {
        let handler: Arc<dyn RawEventHandler> = Arc::new(raw_event_handler);
        self.raw_event_handlers.push(handler);
        self
    }

    /// Returns the raw event handlers registered so far, in registration order.
    pub fn get_raw_event_handlers(&self) -> &[Arc<dyn RawEventHandler>] {
        self.raw_event_handlers.as_slice()
    }

    /// Sets the activity part of the initial presence.
    pub fn activity(mut self, activity: ActivityData) -> Self {
        let presence = &mut self.presence;
        presence.activity = Some(activity);
        self
    }

    /// Sets the online status part of the initial presence.
    pub fn status(mut self, status: OnlineStatus) -> Self {
        let presence = &mut self.presence;
        presence.status = status;
        self
    }

    /// Returns the initial presence configured so far.
    pub fn get_presence(&self) -> &PresenceData {
        &self.presence
    }
}
/// Awaiting a [`ClientBuilder`] consumes it and produces the configured [`Client`].
#[cfg(feature = "gateway")]
impl IntoFuture for ClientBuilder {
    type Output = Result<Client>;
    type IntoFuture = BoxFuture<'static, Result<Client>>;

    /// Turns the builder into a future that assembles the [`Client`].
    ///
    /// Synchronously, before the future is first polled, this:
    /// - wraps the user data map in `Arc<RwLock<_>>`,
    /// - installs a ratelimit callback on the HTTP client's ratelimiter (when it has one) that
    ///   spawns one task per registered event handler calling its `ratelimit` method,
    /// - creates the cache from the configured settings (with the `cache` feature).
    ///
    /// The returned future then asks the HTTP API for the gateway URL, falling back to
    /// `wss://gateway.discord.gg` when that request fails, creates the shard manager, builds
    /// the [`Client`] and finally initialises the framework (with the `framework` feature).
    ///
    /// As written, every failure on this path is handled by a fallback or a log line, so the
    /// future resolves to `Ok`.
    #[instrument(skip(self))]
    fn into_future(self) -> Self::IntoFuture {
        let data = Arc::new(RwLock::new(self.data));
        #[cfg(feature = "framework")]
        let framework = self.framework;
        let event_handlers = self.event_handlers;
        let raw_event_handlers = self.raw_event_handlers;
        let intents = self.intents;
        let presence = self.presence;

        // The callback has to be installed while `http` is still uniquely owned, i.e. before it
        // is moved into an `Arc` below.
        let mut http = self.http;
        if let Some(ratelimiter) = &mut http.ratelimiter {
            let event_handlers_clone = event_handlers.clone();
            ratelimiter.set_ratelimit_callback(Box::new(move |info| {
                // Fan the notification out to every handler on its own task so the callback
                // itself never waits on handler code.
                for event_handler in event_handlers_clone.iter().map(Arc::clone) {
                    let info = info.clone();
                    tokio::spawn(async move { event_handler.ratelimit(info).await });
                }
            }));
        }
        let http = Arc::new(http);

        #[cfg(feature = "voice")]
        let voice_manager = self.voice_manager;

        #[cfg(feature = "cache")]
        let cache = Arc::new(Cache::new_with_settings(self.cache_settings));

        Box::pin(async move {
            // A failed lookup is not fatal: fall back to the default gateway URL.
            let ws_url = Arc::new(Mutex::new(match http.get_gateway().await {
                Ok(response) => response.url,
                Err(err) => {
                    tracing::warn!("HTTP request to get gateway URL failed: {}", err);
                    "wss://gateway.discord.gg".to_string()
                },
            }));

            // The shard manager receives an empty cell now; it is filled in further down, once
            // the framework has been initialised with the finished client.
            #[cfg(feature = "framework")]
            let framework_cell = Arc::new(OnceLock::new());

            let (shard_manager, shard_manager_ret_value) = ShardManager::new(ShardManagerOptions {
                data: Arc::clone(&data),
                event_handlers,
                raw_event_handlers,
                #[cfg(feature = "framework")]
                framework: Arc::clone(&framework_cell),
                // The real shard layout is supplied later by the `Client::start*` methods.
                shard_index: 0,
                shard_init: 0,
                shard_total: 0,
                #[cfg(feature = "voice")]
                voice_manager: voice_manager.clone(),
                ws_url: Arc::clone(&ws_url),
                #[cfg(feature = "cache")]
                cache: Arc::clone(&cache),
                http: Arc::clone(&http),
                intents,
                presence: Some(presence),
            });

            let client = Client {
                data,
                shard_manager,
                shard_manager_return_value: shard_manager_ret_value,
                #[cfg(feature = "voice")]
                voice_manager,
                ws_url,
                #[cfg(feature = "cache")]
                cache,
                http,
            };

            // The framework can only be initialised here because `init` needs a reference to
            // the client.
            #[cfg(feature = "framework")]
            if let Some(mut framework) = framework {
                framework.init(&client).await;
                // `OnceLock::set` never overwrites: on `Err` the cell keeps its existing value
                // and the framework passed in is handed back (and dropped here).
                if let Err(_rejected) = framework_cell.set(framework.into()) {
                    tracing::warn!(
                        "framework OnceLock was already initialised; the new framework was not installed"
                    );
                }
            }

            Ok(client)
        })
    }
}
/// Bundles the HTTP client, the shard manager and the shared state of a running bot.
///
/// Instances are produced by awaiting a [`ClientBuilder`]; see [`Client::builder`].
#[cfg(feature = "gateway")]
pub struct Client {
/// User data map, shared with the shard manager.
pub data: Arc<RwLock<TypeMap>>,
/// Manager responsible for the shards started by this client.
pub shard_manager: Arc<ShardManager>,
/// Channel on which the shard manager reports its outcome; the `start*` methods wait on it.
shard_manager_return_value: Receiver<Result<(), GatewayError>>,
/// Optional voice gateway manager; initialised with the shard total and the current user's id
/// when a connection is started.
#[cfg(feature = "voice")]
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
/// Gateway URL, shared with the shard manager.
pub ws_url: Arc<Mutex<String>>,
/// Cache shared with the shard manager.
#[cfg(feature = "cache")]
pub cache: Arc<Cache>,
/// HTTP client shared with the shard manager.
pub http: Arc<Http>,
}
impl Client {
pub fn builder(token: impl AsRef<str>, intents: GatewayIntents) -> ClientBuilder {
ClientBuilder::new(token, intents)
}
#[instrument(skip(self))]
pub async fn start(&mut self) -> Result<()> {
self.start_connection(0, 0, 1).await
}
#[instrument(skip(self))]
pub async fn start_autosharded(&mut self) -> Result<()> {
let (end, total) = {
let res = self.http.get_bot_gateway().await?;
(res.shards - 1, res.shards)
};
self.start_connection(0, end, total).await
}
#[instrument(skip(self))]
pub async fn start_shard(&mut self, shard: u32, shards: u32) -> Result<()> {
self.start_connection(shard, shard, shards).await
}
#[instrument(skip(self))]
pub async fn start_shards(&mut self, total_shards: u32) -> Result<()> {
self.start_connection(0, total_shards - 1, total_shards).await
}
#[instrument(skip(self))]
pub async fn start_shard_range(&mut self, range: Range<u32>, total_shards: u32) -> Result<()> {
self.start_connection(range.start, range.end, total_shards).await
}
#[instrument(skip(self))]
async fn start_connection(
&mut self,
start_shard: u32,
end_shard: u32,
total_shards: u32,
) -> Result<()> {
#[cfg(feature = "voice")]
if let Some(voice_manager) = &self.voice_manager {
let user = self.http.get_current_user().await?;
voice_manager.initialise(total_shards, user.id).await;
}
let init = end_shard - start_shard + 1;
self.shard_manager.set_shards(start_shard, init, total_shards).await;
debug!("Initializing shard info: {} - {}/{}", start_shard, init, total_shards);
if let Err(why) = self.shard_manager.initialize() {
error!("Failed to boot a shard: {:?}", why);
info!("Shutting down all shards");
self.shard_manager.shutdown_all().await;
return Err(Error::Client(ClientError::ShardBootFailure));
}
if let Some(Err(err)) = self.shard_manager_return_value.next().await {
return Err(Error::Gateway(err));
}
Ok(())
}
}