use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
#[cfg(feature = "framework")]
use std::sync::OnceLock;
use futures::channel::mpsc::UnboundedReceiver as Receiver;
use futures::StreamExt;
use tokio::sync::{Mutex, RwLock};
use tokio::time::{sleep, timeout, Duration, Instant};
use tracing::{debug, info, instrument, warn};
use typemap_rev::TypeMap;
#[cfg(feature = "voice")]
use super::VoiceGatewayManager;
use super::{
ShardId,
ShardManager,
ShardMessenger,
ShardQueuerMessage,
ShardRunner,
ShardRunnerInfo,
ShardRunnerOptions,
};
#[cfg(feature = "cache")]
use crate::cache::Cache;
use crate::client::{EventHandler, RawEventHandler};
#[cfg(feature = "framework")]
use crate::framework::Framework;
use crate::gateway::{ConnectionStage, PresenceData, Shard, ShardRunnerMessage};
use crate::http::Http;
use crate::internal::prelude::*;
use crate::internal::tokio::spawn_named;
use crate::model::gateway::{GatewayIntents, ShardInfo};
/// Minimum number of seconds between two consecutive shard start attempts.
///
/// Also used by `ShardQueuer::run` as the idle timeout: if no command arrives within this many
/// seconds, one queued shard is retried.
// NOTE(review): presumably chosen to stay within the gateway's identify rate limit — confirm.
const WAIT_BETWEEN_BOOTS_IN_SECONDS: u64 = 5;
/// Starts shard runners one at a time in response to [`ShardQueuerMessage`]s received on `rx`.
///
/// Consecutive start attempts are spaced at least [`WAIT_BETWEEN_BOOTS_IN_SECONDS`] apart.
/// Shards whose start failed are pushed onto `queue` and retried when no command arrives for
/// that long.
pub struct ShardQueuer {
    /// Shared user data; a clone of the `Arc` is handed to every runner started here.
    pub data: Arc<RwLock<TypeMap>>,
    /// Event handlers, cloned into every runner started here.
    pub event_handlers: Vec<Arc<dyn EventHandler>>,
    /// Raw event handlers, cloned into every runner started here.
    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    /// Framework slot. Its current value (if set) is read each time a shard is started, so a
    /// framework set later is only seen by shards started after that point.
    #[cfg(feature = "framework")]
    pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
    /// Time at which the most recent start attempt (successful or not) finished; `None` until
    /// the first attempt. Used to space out subsequent starts.
    pub last_start: Option<Instant>,
    /// Shard manager; a clone of the `Arc` is passed to every runner started here.
    pub manager: Arc<ShardManager>,
    /// Shards waiting to be (re)started. Failed starts are pushed to the back; one entry is
    /// popped from the front whenever `rx` stays idle for `WAIT_BETWEEN_BOOTS_IN_SECONDS`.
    pub queue: VecDeque<ShardInfo>,
    /// Registry of started runners keyed by shard id, shared with other holders of this `Arc`.
    /// Starting a shard inserts (or replaces) its entry.
    pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
    /// Incoming commands for this queuer.
    pub rx: Receiver<ShardQueuerMessage>,
    /// Optional voice gateway manager, cloned into every runner started here.
    #[cfg(feature = "voice")]
    pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
    /// Gateway URL shared with every `Shard` created here (the `Arc` is cloned, not the string).
    pub ws_url: Arc<Mutex<String>>,
    /// Cache shared with every runner started here.
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    /// HTTP client shared with every runner. It also supplies the token for new shards and
    /// receives the application id through the shard's callback.
    pub http: Arc<Http>,
    /// Gateway intents used for every shard created here.
    pub intents: GatewayIntents,
    /// Presence data, cloned into every shard created here.
    pub presence: Option<PresenceData>,
}
impl ShardQueuer {
    /// Processes queuer commands until told to shut down or until the command channel closes.
    ///
    /// Each wait on `rx` is bounded by [`WAIT_BETWEEN_BOOTS_IN_SECONDS`]:
    /// - `ShardQueuerMessage::Shutdown` sends a shutdown to every registered runner and returns.
    /// - `ShardQueuerMessage::ShutdownShard` forwards a shutdown to that single runner.
    /// - `ShardQueuerMessage::Start` starts the shard, rate-limited against the previous start.
    /// - If no message arrives within the timeout, one shard is popped from `queue` (e.g. one
    ///   whose earlier start failed) and started, if the queue is non-empty.
    ///
    /// When the channel is closed (`rx` yields `None`), the loop returns *without* shutting
    /// down runners that are still registered.
    #[instrument(skip(self))]
    pub async fn run(&mut self) {
        // How long to wait for a command before retrying a queued shard.
        const TIMEOUT: Duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);

        loop {
            match timeout(TIMEOUT, self.rx.next()).await {
                Ok(Some(ShardQueuerMessage::Shutdown)) => {
                    debug!("[Shard Queuer] Received to shutdown.");
                    self.shutdown_runners().await;
                    break;
                },
                Ok(Some(ShardQueuerMessage::ShutdownShard(shard, code))) => {
                    debug!("[Shard Queuer] Received to shutdown shard {} with {}.", shard.0, code);
                    self.shutdown(shard, code).await;
                },
                Ok(Some(ShardQueuerMessage::Start(id, total))) => {
                    debug!("[Shard Queuer] Received to start shard {} of {}.", id.0, total.0);
                    self.checked_start(id, total.0).await;
                },
                // Every sender has been dropped; nothing more can arrive.
                Ok(None) => break,
                // Idle for a full interval: use the quiet time to retry one queued shard.
                Err(_) => {
                    if let Some(shard) = self.queue.pop_front() {
                        self.checked_start(shard.id, shard.total).await;
                    }
                },
            }
        }
    }

    /// Sleeps until at least [`WAIT_BETWEEN_BOOTS_IN_SECONDS`] have passed since `last_start`.
    ///
    /// Returns immediately if no start has been attempted yet or the interval already elapsed.
    #[instrument(skip(self))]
    async fn check_last_start(&mut self) {
        let Some(instant) = self.last_start else { return };

        let duration = Duration::from_secs(WAIT_BETWEEN_BOOTS_IN_SECONDS);
        let elapsed = instant.elapsed();

        // `checked_sub` yields `None` once the full interval has passed: nothing left to wait.
        if let Some(to_sleep) = duration.checked_sub(elapsed) {
            sleep(to_sleep).await;
        }
    }

    /// Starts shard `id` of `total` after waiting out the minimum interval since the last start.
    ///
    /// A failed start is logged and the shard is pushed to the back of `queue` for a later
    /// retry. `last_start` is updated whether or not the start succeeded, so retries are spaced
    /// out as well.
    #[instrument(skip(self))]
    async fn checked_start(&mut self, id: ShardId, total: u32) {
        debug!("[Shard Queuer] Checked start for shard {} out of {}", id, total);
        self.check_last_start().await;

        if let Err(why) = self.start(id, total).await {
            warn!("[Shard Queuer] Err starting shard {}: {:?}", id, why);
            info!("[Shard Queuer] Re-queueing start of shard {}", id);

            self.queue.push_back(ShardInfo::new(id, total));
        }

        self.last_start = Some(Instant::now());
    }

    /// Creates shard `id` of `total`, wraps it in a [`ShardRunner`], registers the runner in
    /// `runners`, and spawns a task that drives the runner to completion.
    ///
    /// The registered [`ShardRunnerInfo`] starts with no latency and
    /// `ConnectionStage::Disconnected`; any existing entry for `id` is replaced.
    ///
    /// # Errors
    ///
    /// Returns the error from `Shard::new` if the shard cannot be created; in that case nothing
    /// is registered and no task is spawned.
    #[instrument(skip(self))]
    async fn start(&mut self, id: ShardId, total: u32) -> Result<()> {
        let shard_info = ShardInfo::new(id, total);

        let mut shard = Shard::new(
            Arc::clone(&self.ws_url),
            self.http.token(),
            shard_info,
            self.intents,
            self.presence.clone(),
        )
        .await?;

        // Give the shard a callback that stores an application id on the shared HTTP client.
        let cloned_http = Arc::clone(&self.http);
        shard.set_application_id_callback(move |application_id| {
            cloned_http.set_application_id(application_id)
        });

        let mut runner = ShardRunner::new(ShardRunnerOptions {
            data: Arc::clone(&self.data),
            event_handlers: self.event_handlers.clone(),
            raw_event_handlers: self.raw_event_handlers.clone(),
            #[cfg(feature = "framework")]
            framework: self.framework.get().cloned(),
            manager: Arc::clone(&self.manager),
            #[cfg(feature = "voice")]
            voice_manager: self.voice_manager.clone(),
            shard,
            #[cfg(feature = "cache")]
            cache: Arc::clone(&self.cache),
            http: Arc::clone(&self.http),
        });

        let runner_info = ShardRunnerInfo {
            latency: None,
            runner_tx: ShardMessenger::new(&runner),
            stage: ConnectionStage::Disconnected,
        };

        // Register the runner *before* spawning it. A spawned task may be polled (and even
        // finish) before this method re-acquires the `runners` lock, e.g. on a multi-threaded
        // runtime. If anything acting for the runner (it holds `manager`) updates or removes
        // this shard's entry in that window, inserting afterwards would lose the update or
        // resurrect an entry for a runner that has already exited. The guard is a temporary,
        // so the lock is released before the task starts.
        self.runners.lock().await.insert(id, runner_info);

        spawn_named("shard_queuer::stop", async move {
            // `Box::pin` moves the runner future to the heap (presumably to keep this task's
            // own future small); its output is intentionally discarded.
            drop(Box::pin(runner.run()).await);
            debug!("[ShardRunner {:?}] Stopping", runner.shard.shard_info());
        });

        Ok(())
    }

    /// Sends a shutdown with close code `1000` to every runner currently in `runners`.
    ///
    /// The shard ids are copied out first and the lock is released before shutting anything
    /// down. `shutdown` locks `runners` itself, and `tokio::sync::Mutex` is not re-entrant, so
    /// holding the guard across those calls would deadlock.
    #[instrument(skip(self))]
    async fn shutdown_runners(&mut self) {
        let keys = {
            let runners = self.runners.lock().await;

            if runners.is_empty() {
                return;
            }

            runners.keys().copied().collect::<Vec<_>>()
        };

        info!("Shutting down all shards");

        for shard_id in keys {
            self.shutdown(shard_id, 1000).await;
        }
    }

    /// Asks the runner for `shard_id` to shut down with the given close `code`.
    ///
    /// This only sends `ShardRunnerMessage::Shutdown`. It does not wait for the runner to stop
    /// and does not remove its entry from `runners`. An unknown `shard_id` is ignored. A failed
    /// send (presumably because the runner has already stopped) is logged rather than returned.
    #[instrument(skip(self))]
    pub async fn shutdown(&mut self, shard_id: ShardId, code: u16) {
        info!("Shutting down shard {}", shard_id);

        if let Some(runner) = self.runners.lock().await.get(&shard_id) {
            let msg = ShardRunnerMessage::Shutdown(shard_id, code);

            if let Err(why) = runner.runner_tx.tx.unbounded_send(msg) {
                warn!(
                    "Failed to cleanly shutdown shard {} when sending message to shard runner: {:?}",
                    shard_id,
                    why,
                );
            }
        }
    }
}