#![allow(deprecated)]
use futures::future::pending;
use futures::{Stream, StreamExt as _};
use crate::gateway::{CollectorCallback, ShardMessenger};
use crate::model::prelude::*;
/// Registers a collector on `shard` and returns a [`Stream`] of the values `extractor` yields.
///
/// `extractor` is called with each [`Event`] the shard hands to its collectors: a `Some(item)`
/// is forwarded into the returned stream, while `None` skips the event. Forwarded items are
/// buffered in an unbounded channel until the stream is polled.
///
/// The registered callback returns `false` once the returned stream has been dropped (either a
/// send fails, or a skipped event finds the channel closed). Presumably this asks the shard to
/// unregister the collector — confirm against `ShardMessenger::add_collector`.
pub fn collect<T: Send + 'static>(
    shard: &ShardMessenger,
    extractor: impl Fn(&Event) -> Option<T> + Send + Sync + 'static,
) -> impl Stream<Item = T> {
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    shard.add_collector(CollectorCallback(Box::new(move |event| {
        if let Some(item) = extractor(event) {
            tx.send(item).is_ok()
        } else {
            !tx.is_closed()
        }
    })));
    futures::stream::poll_fn(move |cx| rx.poll_recv(cx))
}
/// Generates a builder-style collector type for one kind of gateway [`Event`].
///
/// Input, in order:
/// - any number of outer attributes (including `///` doc comments), forwarded to the generated
///   struct after its generated summary line;
/// - `$collector_type, $item_type`: the struct to generate and the item type it yields;
/// - `$extractor => $extracted_item`: a pattern matched against each `&Event`; when it matches,
///   `$extracted_item` names the bound item, which is cloned into the stream if all filters pass;
/// - zero or more `$filter_name: $filter_type => $filter_passes,` entries. Each becomes an
///   `Option<$filter_type>` field plus a builder method of the same name. `$filter_passes` is
///   evaluated with `$filter_name` bound to `&$filter_type` and `$extracted_item` bound to
///   `&$item_type`; evaluating to `false` rejects the item.
macro_rules! make_specific_collector {
    (
        $( #[ $($meta:tt)* ] )*
        $collector_type:ident, $item_type:ident,
        $extractor:pat => $extracted_item:ident,
        $( $filter_name:ident: $filter_type:ty => $filter_passes:expr, )*
    ) => {
        #[doc = concat!("A [`", stringify!($collector_type), "`] receives [`", stringify!($item_type), "`]s that match the given filters, optionally for a set duration.")]
        $( #[ $($meta)* ] )*
        #[must_use]
        pub struct $collector_type {
            shard: ShardMessenger,
            duration: Option<std::time::Duration>,
            filter: Option<Box<dyn Fn(&$item_type) -> bool + Send + Sync>>,
            $( $filter_name: Option<$filter_type>, )*
        }

        impl $collector_type {
            /// Creates a collector for events from the given shard, with no timeout and no
            /// filters set.
            pub fn new(shard: impl AsRef<ShardMessenger>) -> Self {
                Self {
                    shard: shard.as_ref().clone(),
                    duration: None,
                    filter: None,
                    $( $filter_name: None, )*
                }
            }

            /// Ends the stream once `duration` has elapsed.
            ///
            /// Without a timeout, no timer ever ends the stream.
            pub fn timeout(mut self, duration: std::time::Duration) -> Self {
                self.duration = Some(duration);
                self
            }

            #[doc = concat!("Sets a custom predicate: only [`", stringify!($item_type), "`]s for which it returns `true` are collected.")]
            ///
            /// The predicate runs after the built-in filters. Calling this again replaces the
            /// previously set predicate.
            pub fn filter(mut self, filter: impl Fn(&$item_type) -> bool + Send + Sync + 'static) -> Self {
                self.filter = Some(Box::new(filter));
                self
            }

            $(
                #[doc = concat!("Filters [`", stringify!($item_type), "`]s by the given `", stringify!($filter_name), "` ([`", stringify!($filter_type), "`]).")]
                ///
                /// Calling this again replaces the previously set value.
                pub fn $filter_name(mut self, $filter_name: $filter_type) -> Self {
                    self.$filter_name = Some($filter_name);
                    self
                }
            )*

            #[doc = concat!("Returns a [`Stream`] over all collected [`", stringify!($item_type), "`]s.")]
            ///
            /// The collector is registered with the shard as soon as this is called. If a
            /// timeout was set, its timer is created inside a lazy `async` block, so it starts
            /// counting when the returned stream is first polled.
            pub fn stream(self) -> impl Stream<Item = $item_type> {
                // Every built-in filter that was set must pass, then the custom predicate.
                let filters_pass = move |$extracted_item: &$item_type| {
                    $( if let Some($filter_name) = &self.$filter_name {
                        if !$filter_passes {
                            return false;
                        }
                    } )*
                    if let Some(custom_filter) = &self.filter {
                        if !custom_filter($extracted_item) {
                            return false;
                        }
                    }
                    true
                };

                // Completes when the timeout elapses, or never if no timeout was set.
                let timeout = async move { match self.duration {
                    Some(d) => tokio::time::sleep(d).await,
                    None => pending::<()>().await,
                } };

                let stream = collect(&self.shard, move |event| match event {
                    $extractor if filters_pass($extracted_item) => Some($extracted_item.clone()),
                    _ => None,
                });
                // Boxing pins the timeout future, keeping the returned stream `Unpin` so callers
                // (including `next` below) can use `StreamExt::next` without pinning it first.
                stream.take_until(Box::pin(timeout))
            }

            /// Deprecated alias for [`Self::stream`].
            #[deprecated = "use `.stream()` instead"]
            pub fn build(self) -> impl Stream<Item = $item_type> {
                self.stream()
            }

            #[doc = concat!("Returns the next [`", stringify!($item_type), "`] which passes the filters.")]
            #[doc = concat!("You can also call `.await` on the [`", stringify!($collector_type), "`] directly.")]
            ///
            /// Returns `None` if the stream ends first, e.g. because the timeout elapsed.
            pub async fn next(self) -> Option<$item_type> {
                self.stream().next().await
            }
        }

        /// Awaiting a collector resolves to the same value as calling its `next` method.
        impl std::future::IntoFuture for $collector_type {
            type Output = Option<$item_type>;
            type IntoFuture = futures::future::BoxFuture<'static, Self::Output>;

            fn into_future(self) -> Self::IntoFuture {
                Box::pin(self.next())
            }
        }
    };
}
make_specific_collector!(
    /// Only [`Event::InteractionCreate`] events carrying a component interaction are considered.
    ///
    /// The `guild_id` filter lets through interactions that carry no guild ID, and `custom_ids`
    /// matches when the interaction's custom ID equals any of the given IDs.
    ComponentInteractionCollector, ComponentInteraction,
    Event::InteractionCreate(InteractionCreateEvent {
        interaction: Interaction::Component(interaction),
    }) => interaction,
    author_id: UserId => interaction.user.id == *author_id,
    channel_id: ChannelId => interaction.channel_id == *channel_id,
    guild_id: GuildId => interaction.guild_id.map_or(true, |g| g == *guild_id),
    message_id: MessageId => interaction.message.id == *message_id,
    custom_ids: Vec<String> => custom_ids.contains(&interaction.data.custom_id),
);
make_specific_collector!(
    /// Only [`Event::InteractionCreate`] events carrying a modal interaction are considered.
    ///
    /// The `guild_id` and `message_id` filters let through interactions that carry no guild ID
    /// or no message, respectively; `custom_ids` matches when the modal's custom ID equals any
    /// of the given IDs.
    ModalInteractionCollector, ModalInteraction,
    Event::InteractionCreate(InteractionCreateEvent {
        interaction: Interaction::Modal(interaction),
    }) => interaction,
    author_id: UserId => *author_id == interaction.user.id,
    channel_id: ChannelId => *channel_id == interaction.channel_id,
    guild_id: GuildId => interaction.guild_id.map_or(true, |id| *guild_id == id),
    message_id: MessageId => interaction.message.as_ref().map_or(true, |msg| *message_id == msg.id),
    custom_ids: Vec<String> => custom_ids.contains(&interaction.data.custom_id),
);
make_specific_collector!(
    /// Only reaction additions ([`Event::ReactionAdd`]) are considered.
    ///
    /// The `author_id` and `guild_id` filters let through reactions that carry no user ID or no
    /// guild ID, respectively.
    ReactionCollector, Reaction,
    Event::ReactionAdd(ReactionAddEvent { reaction }) => reaction,
    author_id: UserId => reaction.user_id.map_or(true, |id| *author_id == id),
    channel_id: ChannelId => *channel_id == reaction.channel_id,
    guild_id: GuildId => reaction.guild_id.map_or(true, |id| *guild_id == id),
    message_id: MessageId => *message_id == reaction.message_id,
);
make_specific_collector!(
    /// Only newly created messages ([`Event::MessageCreate`]) are considered.
    ///
    /// The `guild_id` filter lets through messages that carry no guild ID.
    MessageCollector, Message,
    Event::MessageCreate(MessageCreateEvent { message }) => message,
    author_id: UserId => *author_id == message.author.id,
    channel_id: ChannelId => *channel_id == message.channel_id,
    guild_id: GuildId => message.guild_id.map_or(true, |id| *guild_id == id),
);
make_specific_collector!(
    /// It has no built-in filters: every [`Event`] is a candidate, narrowed only by the custom
    /// predicate set with `.filter()`.
    #[deprecated = "prefer the stand-alone collect() function to collect arbitrary events"]
    EventCollector, Event,
    ev => ev,
);