use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
#[cfg(feature = "framework")]
use std::sync::OnceLock;
use std::time::Duration;
use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
use futures::{SinkExt, StreamExt};
use tokio::sync::{Mutex, RwLock};
use tokio::time::timeout;
use tracing::{info, instrument, warn};
use typemap_rev::TypeMap;
#[cfg(feature = "voice")]
use super::VoiceGatewayManager;
use super::{ShardId, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo};
#[cfg(feature = "cache")]
use crate::cache::Cache;
use crate::client::{EventHandler, RawEventHandler};
#[cfg(feature = "framework")]
use crate::framework::Framework;
use crate::gateway::{ConnectionStage, GatewayError, PresenceData};
use crate::http::Http;
use crate::internal::prelude::*;
use crate::internal::tokio::spawn_named;
use crate::model::gateway::GatewayIntents;
/// Tracks instantiated shard runners and drives their lifecycle by sending start/shutdown
/// requests to a [`ShardQueuer`] task (spawned in [`ShardManager::new`]).
#[derive(Debug)]
pub struct ShardManager {
/// Sending half of the channel whose receiver is returned from [`ShardManager::new`].
/// [`ShardManager::shutdown_all`] sends `Ok(())` here; [`ShardManager::return_with_value`]
/// forwards its argument.
return_value_tx: Mutex<Sender<Result<(), GatewayError>>>,
/// Runner info for each instantiated shard, keyed by shard id. The same map is shared with the
/// shard queuer.
pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
/// First shard id started by [`ShardManager::initialize`]; replaced by
/// [`ShardManager::set_shards`].
shard_index: AtomicU32,
/// Number of consecutive shards (from `shard_index`) started by [`ShardManager::initialize`].
shard_init: AtomicU32,
/// Total shard count sent along with every start request.
shard_total: AtomicU32,
/// Request channel to the shard queuer task (start / shutdown messages).
shard_queuer: Sender<ShardQueuerMessage>,
/// Receives ids of shards that reported a finished shutdown. [`ShardManager::shutdown`] holds
/// this lock while waiting, so shutdowns are processed one at a time.
shard_shutdown: Mutex<Receiver<ShardId>>,
/// Sending half paired with `shard_shutdown`.
shard_shutdown_send: Sender<ShardId>,
/// Gateway intents the manager was created with; returned by [`ShardManager::intents`].
gateway_intents: GatewayIntents,
}
impl ShardManager {
#[must_use]
pub fn new(opt: ShardManagerOptions) -> (Arc<Self>, Receiver<Result<(), GatewayError>>) {
let (return_value_tx, return_value_rx) = mpsc::unbounded();
let (shard_queue_tx, shard_queue_rx) = mpsc::unbounded();
let runners = Arc::new(Mutex::new(HashMap::new()));
let (shutdown_send, shutdown_recv) = mpsc::unbounded();
let manager = Arc::new(Self {
return_value_tx: Mutex::new(return_value_tx),
shard_index: AtomicU32::new(opt.shard_index),
shard_init: AtomicU32::new(opt.shard_init),
shard_queuer: shard_queue_tx,
shard_total: AtomicU32::new(opt.shard_total),
shard_shutdown: Mutex::new(shutdown_recv),
shard_shutdown_send: shutdown_send,
runners: Arc::clone(&runners),
gateway_intents: opt.intents,
});
let mut shard_queuer = ShardQueuer {
data: opt.data,
event_handlers: opt.event_handlers,
raw_event_handlers: opt.raw_event_handlers,
#[cfg(feature = "framework")]
framework: opt.framework,
last_start: None,
manager: Arc::clone(&manager),
queue: VecDeque::new(),
runners,
rx: shard_queue_rx,
#[cfg(feature = "voice")]
voice_manager: opt.voice_manager,
ws_url: opt.ws_url,
#[cfg(feature = "cache")]
cache: opt.cache,
http: opt.http,
intents: opt.intents,
presence: opt.presence,
};
spawn_named("shard_queuer::run", async move {
shard_queuer.run().await;
});
(Arc::clone(&manager), return_value_rx)
}
pub async fn has(&self, shard_id: ShardId) -> bool {
self.runners.lock().await.contains_key(&shard_id)
}
#[instrument(skip(self))]
#[allow(clippy::missing_errors_doc)] pub fn initialize(&self) -> Result<()> {
let shard_index = self.shard_index.load(Ordering::Relaxed);
let shard_init = self.shard_init.load(Ordering::Relaxed);
let shard_total = self.shard_total.load(Ordering::Relaxed);
let shard_to = shard_index + shard_init;
for shard_id in shard_index..shard_to {
self.boot([ShardId(shard_id), ShardId(shard_total)]);
}
Ok(())
}
#[instrument(skip(self))]
pub async fn set_shards(&self, index: u32, init: u32, total: u32) {
self.shutdown_all().await;
self.shard_index.store(index, Ordering::Relaxed);
self.shard_init.store(init, Ordering::Relaxed);
self.shard_total.store(total, Ordering::Relaxed);
}
#[instrument(skip(self))]
pub async fn restart(&self, shard_id: ShardId) {
info!("Restarting shard {}", shard_id);
self.shutdown(shard_id, 4000).await;
let shard_total = self.shard_total.load(Ordering::Relaxed);
self.boot([shard_id, ShardId(shard_total)]);
}
#[instrument(skip(self))]
pub async fn shards_instantiated(&self) -> Vec<ShardId> {
self.runners.lock().await.keys().copied().collect()
}
#[instrument(skip(self))]
pub async fn shutdown(&self, shard_id: ShardId, code: u16) {
const TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(5);
info!("Shutting down shard {}", shard_id);
{
let mut shard_shutdown = self.shard_shutdown.lock().await;
drop(
self.shard_queuer.unbounded_send(ShardQueuerMessage::ShutdownShard(shard_id, code)),
);
match timeout(TIMEOUT, shard_shutdown.next()).await {
Ok(Some(shutdown_shard_id)) => {
if shutdown_shard_id != shard_id {
warn!(
"Failed to cleanly shutdown shard {}: Shutdown channel sent incorrect ID",
shard_id,
);
}
},
Ok(None) => (),
Err(why) => {
warn!(
"Failed to cleanly shutdown shard {}, reached timeout: {:?}",
shard_id, why
);
},
}
}
self.runners.lock().await.remove(&shard_id);
}
#[instrument(skip(self))]
pub async fn shutdown_all(&self) {
let keys = {
let runners = self.runners.lock().await;
if runners.is_empty() {
return;
}
runners.keys().copied().collect::<Vec<_>>()
};
info!("Shutting down all shards");
for shard_id in keys {
self.shutdown(shard_id, 1000).await;
}
drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown));
drop(self.return_value_tx.lock().await.unbounded_send(Ok(())));
}
#[instrument(skip(self))]
fn boot(&self, shard_info: [ShardId; 2]) {
info!("Telling shard queuer to start shard {}", shard_info[0]);
let msg = ShardQueuerMessage::Start(shard_info[0], shard_info[1]);
drop(self.shard_queuer.unbounded_send(msg));
}
#[must_use]
pub fn intents(&self) -> GatewayIntents {
self.gateway_intents
}
pub async fn return_with_value(&self, ret: Result<(), GatewayError>) {
if let Err(e) = self.return_value_tx.lock().await.send(ret).await {
tracing::warn!("failed to send return value: {}", e);
}
}
pub fn shutdown_finished(&self, id: ShardId) {
if let Err(e) = self.shard_shutdown_send.unbounded_send(id) {
tracing::warn!("failed to notify about finished shutdown: {}", e);
}
}
pub async fn restart_shard(&self, id: ShardId) {
self.restart(id).await;
if let Err(e) = self.shard_shutdown_send.unbounded_send(id) {
tracing::warn!("failed to notify about finished shutdown: {}", e);
}
}
pub async fn update_shard_latency_and_stage(
&self,
id: ShardId,
latency: Option<Duration>,
stage: ConnectionStage,
) {
if let Some(runner) = self.runners.lock().await.get_mut(&id) {
runner.latency = latency;
runner.stage = stage;
}
}
}
impl Drop for ShardManager {
    /// Best-effort request for the shard queuer to stop once the manager itself is dropped.
    fn drop(&mut self) {
        // The send result is intentionally discarded. An error only means the queuer's receiving
        // side is already gone, and there is nothing left to tell it.
        let _ = self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown);
    }
}
/// Everything [`ShardManager::new`] needs to build the manager and its shard queuer.
///
/// The sharding fields and `intents` seed the manager. The remaining fields are moved into the
/// [`ShardQueuer`] (`intents` goes to both).
pub struct ShardManagerOptions {
/// Shared user data map, handed to the shard queuer.
pub data: Arc<RwLock<TypeMap>>,
/// Event handlers, handed to the shard queuer.
pub event_handlers: Vec<Arc<dyn EventHandler>>,
/// Raw event handlers, handed to the shard queuer.
pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
/// Framework slot, handed to the shard queuer. Being a `OnceLock`, it may be filled after
/// construction (presumably by the client; verify against callers).
#[cfg(feature = "framework")]
pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
/// Initial first shard id started by [`ShardManager::initialize`].
pub shard_index: u32,
/// Initial number of consecutive shards started by [`ShardManager::initialize`].
pub shard_init: u32,
/// Initial total shard count sent with every start request.
pub shard_total: u32,
/// Optional voice gateway manager, handed to the shard queuer.
#[cfg(feature = "voice")]
pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
/// Shared websocket URL, handed to the shard queuer (presumably the gateway URL; confirm).
pub ws_url: Arc<Mutex<String>>,
/// Cache handle, handed to the shard queuer.
#[cfg(feature = "cache")]
pub cache: Arc<Cache>,
/// HTTP client, handed to the shard queuer.
pub http: Arc<Http>,
/// Gateway intents, stored on the manager and also handed to the shard queuer.
pub intents: GatewayIntents,
/// Optional presence, handed to the shard queuer (presumably the initial presence; confirm).
pub presence: Option<PresenceData>,
}