use crate::{error::ReceiveMessageError, message::Message, Config, ConfigBuilder, Shard, ShardId};
use futures_util::{
future::BoxFuture,
stream::{FuturesUnordered, Stream, StreamExt},
};
#[cfg(feature = "twilight-http")]
use std::{
error::Error,
fmt::{Display, Formatter, Result as FmtResult},
};
use std::{
ops::{Bound, Deref, DerefMut, Range, RangeBounds},
pin::Pin,
sync::mpsc,
task::{Context, Poll},
};
#[cfg(feature = "twilight-http")]
use twilight_http::Client;
use twilight_model::gateway::event::Event;
/// Unordered set of boxed futures, each resolving to the [`NextItemOutput`]
/// (the received item plus the shard it came from) of a single shard.
type FutureList<'a, Item> = FuturesUnordered<BoxFuture<'a, NextItemOutput<'a, Item>>>;
/// Error returned by [`create_recommended`] when the gateway information
/// could not be requested or deserialized.
#[cfg(feature = "twilight-http")]
#[derive(Debug)]
pub struct StartRecommendedError {
/// Category of the failure.
pub(crate) kind: StartRecommendedErrorType,
/// Underlying error that caused the failure, if one was recorded.
pub(crate) source: Option<Box<dyn Error + Send + Sync>>,
}
#[cfg(feature = "twilight-http")]
impl Display for StartRecommendedError {
    /// Writes a short, human-readable description selected by the error's
    /// kind.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        // Pick the message first so the formatter is written to exactly once.
        let description = match &self.kind {
            StartRecommendedErrorType::Request => "request failed to complete",
            StartRecommendedErrorType::Deserializing => "payload isn't a recognized type",
        };

        f.write_str(description)
    }
}
#[cfg(feature = "twilight-http")]
impl Error for StartRecommendedError {
    /// Returns the underlying error, or `None` when no source was recorded.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let boxed = self.source.as_ref()?;

        // Drop the `Send + Sync` auto traits to match the trait's signature.
        Some(&**boxed as &(dyn Error + 'static))
    }
}
/// Category of a [`StartRecommendedError`].
#[cfg(feature = "twilight-http")]
#[derive(Debug)]
pub enum StartRecommendedErrorType {
/// The response payload could not be deserialized into the expected model.
Deserializing,
/// The HTTP request failed to complete.
Request,
}
/// Stream that yields the next gateway [`Event`] from any of a set of
/// mutably borrowed shards, paired with a [`ShardRef`] to the shard that
/// produced it.
pub struct ShardEventStream<'a> {
/// In-flight "next event" futures, one per shard currently held by the stream.
futures: FutureList<'a, Event>,
/// Sending half cloned into every yielded [`ShardRef`] so that it can hand
/// its shard back when dropped.
sender: mpsc::Sender<&'a mut Shard>,
/// Receiving half drained at the start of every poll to re-queue returned
/// shards.
receiver: mpsc::Receiver<&'a mut Shard>,
}
impl<'a> ShardEventStream<'a> {
pub fn new(shards: impl Iterator<Item = &'a mut Shard>) -> Self {
let (sender, receiver) = mpsc::channel();
let mut this = Self {
futures: FuturesUnordered::new(),
sender,
receiver,
};
for shard in shards {
this.add_shard(shard);
}
this
}
fn add_shard(&mut self, shard: &'a mut Shard) {
self.futures.push(Box::pin(async {
let result = shard.next_event().await;
NextItemOutput { result, shard }
}));
}
}
impl<'a> Stream for ShardEventStream<'a> {
type Item = (ShardRef<'a>, Result<Event, ReceiveMessageError>);
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
while let Some(shard) = self.receiver.try_iter().next() {
self.add_shard(shard);
}
match self.futures.poll_next_unpin(cx) {
Poll::Ready(Some(output)) => Poll::Ready(Some((
ShardRef {
channel: self.sender.clone(),
shard: Some(output.shard),
},
output.result,
))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
/// Stream that yields the next websocket [`Message`] from any of a set of
/// mutably borrowed shards, paired with a [`ShardRef`] to the shard that
/// produced it.
pub struct ShardMessageStream<'a> {
/// In-flight "next message" futures, one per shard currently held by the
/// stream.
futures: FutureList<'a, Message>,
/// Sending half cloned into every yielded [`ShardRef`] so that it can hand
/// its shard back when dropped.
sender: mpsc::Sender<&'a mut Shard>,
/// Receiving half drained at the start of every poll to re-queue returned
/// shards.
receiver: mpsc::Receiver<&'a mut Shard>,
}
impl<'a> ShardMessageStream<'a> {
pub fn new(shards: impl Iterator<Item = &'a mut Shard>) -> Self {
let (sender, receiver) = mpsc::channel();
let mut this = Self {
futures: FuturesUnordered::new(),
sender,
receiver,
};
for shard in shards {
this.add_shard(shard);
}
this
}
fn add_shard(&mut self, shard: &'a mut Shard) {
self.futures.push(Box::pin(async {
let result = shard.next_message().await;
NextItemOutput { result, shard }
}));
}
}
impl<'a> Stream for ShardMessageStream<'a> {
type Item = (ShardRef<'a>, Result<Message, ReceiveMessageError>);
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
while let Some(shard) = self.receiver.try_iter().next() {
self.add_shard(shard);
}
match self.futures.poll_next_unpin(cx) {
Poll::Ready(Some(output)) => Poll::Ready(Some((
ShardRef {
channel: self.sender.clone(),
shard: Some(output.shard),
},
output.result,
))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
/// Guard around a mutably borrowed [`Shard`] yielded by a shard stream.
///
/// Dereferences to the shard; when dropped, the shard is sent back over
/// `channel` so the stream that yielded it can poll it again.
pub struct ShardRef<'a> {
/// Channel over which the shard is returned on drop.
channel: mpsc::Sender<&'a mut Shard>,
/// The borrowed shard; only set to `None` while the guard is being dropped.
shard: Option<&'a mut Shard>,
}
impl Deref for ShardRef<'_> {
    type Target = Shard;

    /// Borrows the wrapped shard.
    ///
    /// Panics if the shard has already been taken, which in this file only
    /// happens inside `Drop`.
    fn deref(&self) -> &Self::Target {
        self.shard.as_deref().unwrap()
    }
}
impl DerefMut for ShardRef<'_> {
    /// Mutably borrows the wrapped shard.
    ///
    /// Panics if the shard has already been taken, which in this file only
    /// happens inside `Drop`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.shard.as_deref_mut().unwrap()
    }
}
impl Drop for ShardRef<'_> {
    /// Hands the shard back over the channel, if it is still held.
    fn drop(&mut self) {
        let returned = self.shard.take();

        if let Some(shard) = returned {
            // Sending only fails when the receiving half is gone, in which
            // case nothing is waiting for the shard; ignore the error.
            let _ = self.channel.send(shard);
        }
    }
}
/// Output of a queued "next item" future: what was received and the shard it
/// was received from.
struct NextItemOutput<'a, Item> {
/// The received item, or the error encountered while receiving it.
result: Result<Item, ReceiveMessageError>,
/// The shard the future was polling, handed back so it can be reused.
shard: &'a mut Shard,
}
/// Creates an iterator of shards for the indexes `bucket_id`,
/// `bucket_id + concurrency`, `bucket_id + 2 * concurrency`, ... that are
/// below `total`.
///
/// Every shard's configuration is produced by calling `per_shard_config`
/// with the shard's ID and a [`ConfigBuilder`] seeded with a clone of
/// `config`.
///
/// # Panics
///
/// Panics if `bucket_id` is not less than `total`, if `concurrency` is not
/// less than `total`, if `concurrency` is zero, or if `concurrency` does not
/// fit into a `usize`.
#[track_caller]
pub fn create_bucket<F: Fn(ShardId, ConfigBuilder) -> Config>(
    bucket_id: u64,
    concurrency: u64,
    total: u64,
    config: Config,
    per_shard_config: F,
) -> impl Iterator<Item = Shard> {
    assert!(bucket_id < total, "bucket id must be less than the total");
    assert!(
        concurrency < total,
        "concurrency must be less than the total"
    );
    // `step_by(0)` panics with an internal assertion that says nothing about
    // this function's arguments; reject a zero step explicitly instead.
    assert!(concurrency > 0, "concurrency must be greater than zero");

    let step: usize = concurrency
        .try_into()
        .expect("concurrency must fit into a usize");

    (bucket_id..total).step_by(step).map(move |index| {
        let id = ShardId::new(index, total);
        let config = per_shard_config(id, ConfigBuilder::with_config(config.clone()));

        Shard::with_config(id, config)
    })
}
/// Creates an iterator of shards for every index in `range`.
///
/// Every shard's configuration is produced by calling `per_shard_config`
/// with the shard's ID and a [`ConfigBuilder`] seeded with a clone of
/// `config`.
///
/// # Panics
///
/// Panics if the start of the range is not less than `total` or if the end
/// of the range is greater than `total`.
#[track_caller]
pub fn create_range<F: Fn(ShardId, ConfigBuilder) -> Config>(
    range: impl RangeBounds<u64>,
    total: u64,
    config: Config,
    per_shard_config: F,
) -> impl Iterator<Item = Shard> {
    calculate_range(range, total).map(move |index| {
        let shard_id = ShardId::new(index, total);
        let builder = ConfigBuilder::with_config(config.clone());

        Shard::with_config(shard_id, per_shard_config(shard_id, builder))
    })
}
/// Creates an iterator of shards using the shard count reported by the
/// authenticated gateway endpoint.
///
/// The reported `shards` value is used as the total, and a shard is created
/// for the full, unbounded range via [`create_range`].
///
/// # Errors
///
/// Returns a [`StartRecommendedErrorType::Request`] error type if the request
/// failed to complete.
///
/// Returns a [`StartRecommendedErrorType::Deserializing`] error type if the
/// response body could not be deserialized.
#[cfg(feature = "twilight-http")]
pub async fn create_recommended<F: Fn(ShardId, ConfigBuilder) -> Config>(
    client: &Client,
    config: Config,
    per_shard_config: F,
) -> Result<impl Iterator<Item = Shard>, StartRecommendedError> {
    let response = client.gateway().authed().await.map_err(|source| {
        StartRecommendedError {
            kind: StartRecommendedErrorType::Request,
            source: Some(Box::new(source)),
        }
    })?;

    let info = response.model().await.map_err(|source| {
        StartRecommendedError {
            kind: StartRecommendedErrorType::Deserializing,
            source: Some(Box::new(source)),
        }
    })?;

    let recommended_total = info.shards;

    Ok(create_range(
        ..,
        recommended_total,
        config,
        per_shard_config,
    ))
}
/// Resolves `range` into a concrete half-open range of shard indexes.
///
/// An unbounded start resolves to `0` and an unbounded end resolves to
/// `total`; excluded starts and included ends are advanced by one.
///
/// # Panics
///
/// Panics if the resolved start is not less than `total` or if the resolved
/// end is greater than `total`. This includes bounds of `u64::MAX` that
/// cannot be advanced by one.
fn calculate_range(range: impl RangeBounds<u64>, total: u64) -> Range<u64> {
    // Advancing a bound of `u64::MAX` would overflow: a panic with an
    // unrelated message in debug builds, and a silent wrap to zero in
    // release builds that slips past the checks below. Such a bound is
    // necessarily out of range, so report it as the matching range error.
    let start = match range.start_bound() {
        Bound::Excluded(from) => from
            .checked_add(1)
            .expect("range start must be less than the total"),
        Bound::Included(from) => *from,
        Bound::Unbounded => 0,
    };
    let end = match range.end_bound() {
        Bound::Excluded(to) => *to,
        Bound::Included(to) => to
            .checked_add(1)
            .expect("range end must be less than the total"),
        Bound::Unbounded => total,
    };

    assert!(start < total, "range start must be less than the total");
    assert!(end <= total, "range end must be less than the total");

    start..end
}
// Compile-time checks that the stream types and the shard guard implement
// the traits listed in each assertion.
#[cfg(test)]
mod tests {
use super::{ShardEventStream, ShardMessageStream, ShardRef};
use futures_util::Stream;
use static_assertions::assert_impl_all;
use std::ops::{Deref, DerefMut};
assert_impl_all!(ShardEventStream<'_>: Send, Stream, Unpin);
assert_impl_all!(ShardMessageStream<'_>: Send, Stream, Unpin);
assert_impl_all!(ShardRef<'_>: Deref, DerefMut, Send);
}