use super::{
config::Config as ClusterConfig,
r#impl::{Cluster, ClusterStartError},
};
use crate::shard::{LargeThresholdError, ResumeSession, ShardBuilder};
use std::{
collections::HashMap,
convert::TryFrom,
error::Error,
fmt::{Display, Formatter, Result as FmtResult},
ops::{Bound, RangeBounds},
sync::Arc,
};
use twilight_gateway_queue::{LocalQueue, Queue};
use twilight_http::Client;
use twilight_model::gateway::{payload::update_status::UpdateStatusInfo, Intents};
/// Error returned when a shard ID range cannot be converted into a
/// [`ShardScheme`].
#[derive(Debug)]
pub enum ShardSchemeRangeError {
/// The shard ID range is not valid.
///
/// Produced by the `TryFrom<(T, u64)>` conversion for [`ShardScheme`], for
/// example when the start of the range is greater than its end.
IdTooLarge {
/// Last shard ID of the rejected range.
end: u64,
/// First shard ID of the rejected range.
start: u64,
/// Total number of shards that was supplied alongside the range.
total: u64,
},
}
impl Display for ShardSchemeRangeError {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self {
Self::IdTooLarge { end, start, total } => f.write_fmt(format_args!(
"The shard ID range {}-{}/{} is larger than the total",
start, end, total
)),
}
}
}
// Marks `ShardSchemeRangeError` as a standard error type. The impl is empty,
// so every `Error` method keeps its default implementation.
impl Error for ShardSchemeRangeError {}
/// Describes which shards a cluster is responsible for.
///
/// A `Range` scheme can be created from a `(range, total)` tuple through the
/// `TryFrom<(T, u64)>` implementation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum ShardScheme {
/// The default scheme.
///
/// NOTE(review): presumably the shard count is determined automatically by
/// the cluster when it starts — confirm against the cluster implementation.
Auto,
/// An explicit range of shard IDs out of a total number of shards.
Range {
/// First shard ID of the range.
from: u64,
/// Last shard ID of the range (inclusive when built via `TryFrom`).
to: u64,
/// Total number of shards the range is taken from.
total: u64,
},
}
impl Default for ShardScheme {
fn default() -> Self {
Self::Auto
}
}
/// Builds a [`ShardScheme::Range`] from a `(range, total)` tuple.
///
/// The range may use any combination of bounds:
///
/// - an included bound is used as-is;
/// - an excluded start bound `n` starts the range at `n + 1`;
/// - an excluded end bound `n` ends the range at `n - 1`;
/// - an unbounded start begins at shard `0`;
/// - an unbounded end finishes at shard `total - 1`.
///
/// # Errors
///
/// Returns [`ShardSchemeRangeError::IdTooLarge`] when the resolved start is
/// greater than the resolved end, or when a bound cannot be represented (an
/// empty range such as `0..0`, an unbounded end with a `total` of `0`, or an
/// excluded start of `u64::MAX`). For unrepresentable bounds the error
/// reports `0` as the end and `u64::MAX` as the start respectively.
impl<T: RangeBounds<u64>> TryFrom<(T, u64)> for ShardScheme {
    type Error = ShardSchemeRangeError;

    fn try_from((range, total): (T, u64)) -> Result<Self, Self::Error> {
        // Resolve both bounds to inclusive shard IDs. Checked arithmetic is
        // used so that edge values yield an error instead of panicking (debug
        // builds) or silently wrapping around (release builds).
        let start = match range.start_bound() {
            // Excluding `n` means the first usable ID is the one after it.
            Bound::Excluded(num) => num.checked_add(1),
            Bound::Included(num) => Some(*num),
            Bound::Unbounded => Some(0),
        };
        let end = match range.end_bound() {
            // Excluding `n` means the last usable ID is the one before it.
            Bound::Excluded(num) => num.checked_sub(1),
            Bound::Included(num) => Some(*num),
            Bound::Unbounded => total.checked_sub(1),
        };

        match (start, end) {
            (Some(from), Some(to)) if from <= to => Ok(Self::Range { from, to, total }),
            (start, end) => Err(ShardSchemeRangeError::IdTooLarge {
                end: end.unwrap_or(0),
                start: start.unwrap_or(u64::MAX),
                total,
            }),
        }
    }
}
/// Builder used to configure and create a [`Cluster`].
///
/// Field `0` holds the cluster-level configuration. Field `1` holds the shard
/// builder that the shard-related setters write to; its configuration is
/// moved into the cluster configuration by [`ClusterBuilder::build`].
#[derive(Debug)]
pub struct ClusterBuilder(ClusterConfig, ShardBuilder);
impl ClusterBuilder {
    /// Create a new builder from a bot token and the gateway intents to use.
    ///
    /// The token is prefixed with `"Bot "` when it does not already start
    /// with it. The builder starts with the [`ShardScheme::Auto`] scheme, a
    /// new [`LocalQueue`], no resume sessions and an HTTP client created from
    /// the token.
    pub fn new(token: impl Into<String>, intents: Intents) -> Self {
        Self::_new(token.into(), intents)
    }

    /// Non-generic body of [`ClusterBuilder::new`].
    fn _new(mut token: String, intents: Intents) -> Self {
        if !token.starts_with("Bot ") {
            token.insert_str(0, "Bot ");
        }

        let http_client = Client::new(token.clone());
        let shard_config =
            ShardBuilder::new(token.clone(), intents).http_client(http_client.clone());

        Self(
            ClusterConfig {
                http_client,
                // Placeholder: replaced by the shard builder's configuration
                // when `build` is called.
                shard_config: shard_config.0,
                shard_scheme: ShardScheme::Auto,
                queue: Arc::new(Box::new(LocalQueue::new())),
                resume_sessions: HashMap::new(),
            },
            ShardBuilder::new(token, intents),
        )
    }

    /// Consume the builder and create the configured [`Cluster`].
    ///
    /// When no gateway URL has been configured, one is requested through the
    /// shard builder's HTTP client. A failed request is not treated as an
    /// error here; the gateway URL is simply left unset.
    ///
    /// # Errors
    ///
    /// Returns the [`ClusterStartError`] produced by
    /// `Cluster::new_with_config`.
    pub async fn build(mut self) -> Result<Cluster, ClusterStartError> {
        // Setters write to the shard builder (`self.1`), so that is where a
        // user-provided URL lives. `self.0.shard_config` is only a snapshot
        // taken in `_new`; checking it would always overwrite the user's URL.
        if (self.1).0.gateway_url.is_none() {
            let gateway_url = (self.1)
                .0
                .http_client
                .gateway()
                .authed()
                .await
                .ok()
                .map(|info| info.url);

            self = self.gateway_url(gateway_url);
        }

        self.0.shard_config = (self.1).0;

        Cluster::new_with_config(self.0).await
    }

    /// Set the gateway URL handed to the shard configuration.
    ///
    /// When this is `None` at the time [`ClusterBuilder::build`] runs, the
    /// URL is looked up through the HTTP client instead.
    pub fn gateway_url(mut self, gateway_url: Option<String>) -> Self {
        self.1 = self.1.gateway_url(gateway_url);

        self
    }

    /// Set the HTTP client used by the cluster and by the shards it manages.
    pub fn http_client(mut self, http_client: Client) -> Self {
        // Keep the cluster configuration and the shard builder on the same
        // client, mirroring how `queue` updates both.
        self.0.http_client = http_client.clone();
        self.1 = self.1.http_client(http_client);

        self
    }

    /// Set the large threshold forwarded to the shard builder.
    ///
    /// # Errors
    ///
    /// Returns a [`LargeThresholdError`] when the shard builder rejects the
    /// value.
    pub fn large_threshold(mut self, large_threshold: u64) -> Result<Self, LargeThresholdError> {
        self.1 = self.1.large_threshold(large_threshold)?;

        Ok(self)
    }

    /// Set the presence forwarded to the shard builder.
    pub fn presence(mut self, presence: UpdateStatusInfo) -> Self {
        self.1 = self.1.presence(presence);

        self
    }

    /// Set the scheme describing which shards the cluster manages.
    ///
    /// Defaults to [`ShardScheme::Auto`].
    pub fn shard_scheme(mut self, scheme: ShardScheme) -> Self {
        self.0.shard_scheme = scheme;

        self
    }

    /// Set the queue used by both the cluster and the shard builder.
    ///
    /// Defaults to a new [`LocalQueue`].
    pub fn queue(mut self, queue: Arc<Box<dyn Queue>>) -> Self {
        self.0.queue = Arc::clone(&queue);
        self.1 = self.1.queue(queue);

        self
    }

    /// Set the sessions to resume, stored in the cluster configuration.
    ///
    /// NOTE(review): the `u64` keys are presumably shard IDs — confirm
    /// against the cluster implementation.
    pub fn resume_sessions(mut self, resume_sessions: HashMap<u64, ResumeSession>) -> Self {
        self.0.resume_sessions = resume_sessions;

        self
    }
}
impl<T: Into<String>> From<(T, Intents)> for ClusterBuilder {
fn from((token, intents): (T, Intents)) -> Self {
Self::new(token, intents)
}
}
#[cfg(test)]
mod tests {
    use super::{ClusterBuilder, ShardScheme, ShardSchemeRangeError};
    use crate::Intents;
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{
        convert::TryFrom,
        error::Error,
        fmt::{Debug, Display},
        hash::Hash,
    };

    // Compile-time checks that the public fields and trait implementations
    // of the builder types stay in place.
    assert_fields!(ShardSchemeRangeError::IdTooLarge: end, start, total);
    assert_fields!(ShardScheme::Range: from, to, total);
    assert_impl_all!(ClusterBuilder: Debug, From<(String, Intents)>, Send, Sync);
    assert_impl_all!(ShardSchemeRangeError: Debug, Display, Error, Send, Sync);
    assert_impl_all!(
        ShardScheme: Clone,
        Debug,
        Default,
        Eq,
        Hash,
        PartialEq,
        Send,
        Sync
    );

    /// An inclusive range converts into a `Range` scheme with the same
    /// bounds and total.
    #[test]
    fn test_shard_scheme() -> Result<(), Box<dyn Error>> {
        let expected = ShardScheme::Range {
            from: 0,
            to: 9,
            total: 10,
        };
        let converted = ShardScheme::try_from((0..=9, 10))?;

        assert_eq!(expected, converted);

        Ok(())
    }
}