use crate::client::{DispatchEvent, DispatchEventType};
use crate::model::{Emoji, Message};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc};
use tokio::time::{self, Duration, Instant};
/// Limits that decide when a collector stops gathering items.
///
/// Both limits are optional; a collector with neither keeps running until
/// the event channel closes or its output can no longer be delivered.
#[derive(Debug, Clone)]
pub struct CollectorOptions {
/// Maximum time to keep collecting, measured from when the collector task
/// starts running. `None` disables the time limit.
pub time: Option<Duration>,
/// Maximum number of accepted items. `None` disables the count limit.
pub max: Option<usize>,
}
impl Default for CollectorOptions {
fn default() -> Self {
Self {
time: Some(Duration::from_secs(30)),
max: None,
}
}
}
/// Fan-out point that forwards dispatched events to collector tasks.
///
/// Cloning the hub clones the broadcast sender, so every clone feeds the
/// same set of subscribers.
#[derive(Clone)]
pub struct CollectorHub {
/// Sending half of the broadcast channel that collectors subscribe to.
tx: broadcast::Sender<DispatchEvent>,
}
impl CollectorHub {
pub fn new() -> Self {
let (tx, _) = broadcast::channel(256);
Self { tx }
}
pub fn dispatch(&self, event: DispatchEvent) {
let _ = self.tx.send(event);
}
pub fn message_collector<F>(&self, options: CollectorOptions, filter: F) -> MessageCollector
where
F: Fn(&Message) -> bool + Send + Sync + 'static,
{
let mut rx = self.tx.subscribe();
let (out_tx, out_rx) = mpsc::unbounded_channel();
let filter = Arc::new(filter);
tokio::spawn(async move {
let deadline = options.time.map(|t| Instant::now() + t);
let mut collected = 0usize;
loop {
if let Some(max) = options.max {
if collected >= max {
break;
}
}
let event = if let Some(deadline) = deadline {
let now = Instant::now();
if now >= deadline {
break;
}
match time::timeout_at(deadline, rx.recv()).await {
Ok(Ok(evt)) => evt,
Ok(Err(broadcast::error::RecvError::Lagged(_))) => continue,
Ok(Err(broadcast::error::RecvError::Closed)) => break,
Err(_) => break,
}
} else {
match rx.recv().await {
Ok(evt) => evt,
Err(broadcast::error::RecvError::Lagged(_)) => continue,
Err(broadcast::error::RecvError::Closed) => break,
}
};
if event.kind != DispatchEventType::MessageCreate {
continue;
}
let Ok(message) = serde_json::from_value::<Message>(event.data.clone()) else {
continue;
};
if !(filter)(&message) {
continue;
}
if out_tx.send(message).is_err() {
break;
}
collected += 1;
}
});
MessageCollector { rx: out_rx }
}
pub fn reaction_collector<F>(&self, options: CollectorOptions, filter: F) -> ReactionCollector
where
F: Fn(&ReactionCollectEvent) -> bool + Send + Sync + 'static,
{
let mut rx = self.tx.subscribe();
let (out_tx, out_rx) = mpsc::unbounded_channel();
let filter = Arc::new(filter);
tokio::spawn(async move {
let deadline = options.time.map(|t| Instant::now() + t);
let mut collected = 0usize;
loop {
if let Some(max) = options.max {
if collected >= max {
break;
}
}
let event = if let Some(deadline) = deadline {
let now = Instant::now();
if now >= deadline {
break;
}
match time::timeout_at(deadline, rx.recv()).await {
Ok(Ok(evt)) => evt,
Ok(Err(broadcast::error::RecvError::Lagged(_))) => continue,
Ok(Err(broadcast::error::RecvError::Closed)) => break,
Err(_) => break,
}
} else {
match rx.recv().await {
Ok(evt) => evt,
Err(broadcast::error::RecvError::Lagged(_)) => continue,
Err(broadcast::error::RecvError::Closed) => break,
}
};
let Some(reaction_event) = ReactionCollectEvent::from_dispatch(&event) else {
continue;
};
if !(filter)(&reaction_event) {
continue;
}
if out_tx.send(reaction_event).is_err() {
break;
}
collected += 1;
}
});
ReactionCollector { rx: out_rx }
}
}
impl Default for CollectorHub {
fn default() -> Self {
Self::new()
}
}
/// Receiving handle for messages gathered by a collector task started with
/// `CollectorHub::message_collector`.
pub struct MessageCollector {
/// Receives the messages forwarded by the background collector task.
rx: mpsc::UnboundedReceiver<Message>,
}
impl MessageCollector {
    /// Waits for the next collected message.
    ///
    /// Resolves to `None` once the collector task has finished and every
    /// message it forwarded has been consumed.
    pub async fn next(&mut self) -> Option<Message> {
        self.rx.recv().await
    }

    /// Consumes the collector and gathers every message until the collector
    /// task finishes.
    ///
    /// Messages are returned in the order they were received.
    pub async fn collect(mut self) -> Vec<Message> {
        let mut gathered = Vec::new();
        loop {
            match self.rx.recv().await {
                Some(message) => gathered.push(message),
                None => break gathered,
            }
        }
    }
}
/// Distinguishes the two reaction events a reaction collector reports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReactionEventType {
/// Produced from a `MessageReactionAdd` dispatch event.
Add,
/// Produced from a `MessageReactionRemove` dispatch event.
Remove,
}
/// A reaction add/remove event extracted from a raw dispatch payload.
#[derive(Debug, Clone)]
pub struct ReactionCollectEvent {
/// Whether the reaction was added or removed.
pub kind: ReactionEventType,
/// Value of the payload's `channel_id` field.
pub channel_id: String,
/// Value of the payload's `message_id` field.
pub message_id: String,
/// Value of the payload's `user_id` field.
pub user_id: String,
/// Value of the payload's `guild_id` field; `None` when the field is
/// missing or is not a string.
pub guild_id: Option<String>,
/// Emoji deserialized from the payload's `emoji` field.
pub emoji: Emoji,
}
impl ReactionCollectEvent {
    /// Builds a reaction event from a raw dispatch event.
    ///
    /// Returns `None` when the event is not a reaction add/remove, when
    /// `channel_id`, `message_id` or `user_id` is missing or not a string, or
    /// when the `emoji` field is missing or cannot be deserialized. A missing
    /// or non-string `guild_id` is tolerated and yields `guild_id: None`.
    fn from_dispatch(event: &DispatchEvent) -> Option<Self> {
        let kind = match event.kind {
            DispatchEventType::MessageReactionAdd => ReactionEventType::Add,
            DispatchEventType::MessageReactionRemove => ReactionEventType::Remove,
            _ => return None,
        };

        let payload = &event.data;
        // Reads `key` from the payload as an owned string, if present and textual.
        let string_field =
            |key: &str| payload.get(key).and_then(Value::as_str).map(str::to_owned);

        let channel_id = string_field("channel_id")?;
        let message_id = string_field("message_id")?;
        let user_id = string_field("user_id")?;
        let guild_id = string_field("guild_id");

        // The event is only borrowed, so the emoji value has to be cloned
        // before it can be handed to the deserializer.
        let raw_emoji = payload.get("emoji")?;
        let emoji = serde_json::from_value::<Emoji>(raw_emoji.clone()).ok()?;

        Some(Self {
            kind,
            channel_id,
            message_id,
            user_id,
            guild_id,
            emoji,
        })
    }
}
/// Receiving handle for reaction events gathered by a collector task started
/// with `CollectorHub::reaction_collector`.
pub struct ReactionCollector {
/// Receives the reaction events forwarded by the background collector task.
rx: mpsc::UnboundedReceiver<ReactionCollectEvent>,
}
impl ReactionCollector {
    /// Waits for the next collected reaction event.
    ///
    /// Resolves to `None` once the collector task has finished and every
    /// event it forwarded has been consumed.
    pub async fn next(&mut self) -> Option<ReactionCollectEvent> {
        self.rx.recv().await
    }

    /// Consumes the collector and gathers every reaction event until the
    /// collector task finishes.
    ///
    /// Events are returned in the order they were received.
    pub async fn collect(mut self) -> Vec<ReactionCollectEvent> {
        let mut gathered = Vec::new();
        loop {
            match self.rx.recv().await {
                Some(reaction) => gathered.push(reaction),
                None => break gathered,
            }
        }
    }
}