use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
use futures::StreamExt;
use tokio::sync::{Mutex, RwLock};
use tokio::time::timeout;
use tracing::{info, instrument, warn};
use typemap_rev::TypeMap;
use super::{
ShardId,
ShardManagerMessage,
ShardManagerMonitor,
ShardQueuer,
ShardQueuerMessage,
ShardRunnerInfo,
};
#[cfg(feature = "voice")]
use crate::client::bridge::voice::VoiceGatewayManager;
use crate::client::{EventHandler, RawEventHandler};
#[cfg(feature = "framework")]
use crate::framework::Framework;
use crate::internal::prelude::*;
use crate::internal::tokio::spawn_named;
use crate::model::gateway::GatewayIntents;
use crate::CacheAndHttp;
/// Handle used to start, restart and shut down shards.
///
/// The manager does not run shards itself: it sends [`ShardQueuerMessage`]s to
/// the [`ShardQueuer`] that `new` hands to `spawn_named`, and waits on a
/// shutdown channel whose sending half is owned by the [`ShardManagerMonitor`]
/// returned from `new`.
#[derive(Debug)]
pub struct ShardManager {
/// Sender for [`ShardManagerMessage`]s; the receiving half is given to the
/// [`ShardManagerMonitor`] created in `new`.
monitor_tx: Sender<ShardManagerMessage>,
/// Map of shard IDs to their runner information. The same `Arc` is shared
/// with the [`ShardQueuer`] built in `new`.
pub runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>,
/// ID of the first shard that `initialize` boots.
shard_index: u64,
/// Number of consecutive shard IDs, starting at `shard_index`, that
/// `initialize` boots.
shard_init: u64,
/// Value sent alongside every shard ID in a start request
/// (presumably the total number of shards in use; confirm against the queuer).
shard_total: u64,
/// Sender for requests to the shard queuer.
shard_queuer: Sender<ShardQueuerMessage>,
/// Receiver of shard IDs that `shutdown` awaits after requesting a shutdown.
shard_shutdown: Receiver<ShardId>,
}
impl ShardManager {
/// Creates a shard manager together with the monitor that services it.
///
/// A [`ShardQueuer`] is assembled from `opt` and its `run` future is handed to
/// `spawn_named` under the name `shard_queuer::run`. Three unbounded channels
/// are wired up:
///
/// - manager/queuer -> monitor ([`ShardManagerMessage`]),
/// - manager -> queuer ([`ShardQueuerMessage`]),
/// - monitor -> manager (IDs of shards awaited by `shutdown`).
///
/// Returns the manager behind an `Arc<Mutex<_>>` and a
/// [`ShardManagerMonitor`] holding another handle to that same manager.
pub async fn new(opt: ShardManagerOptions<'_>) -> (Arc<Mutex<Self>>, ShardManagerMonitor) {
    let (monitor_tx, monitor_rx) = mpsc::unbounded();
    let (queuer_tx, queuer_rx) = mpsc::unbounded();
    let (shutdown_tx, shutdown_rx) = mpsc::unbounded();

    // Shared between the manager and the queuer.
    let runners = Arc::new(Mutex::new(HashMap::new()));

    let mut queuer = ShardQueuer {
        data: Arc::clone(opt.data),
        event_handler: opt.event_handler.clone(),
        raw_event_handler: opt.raw_event_handler.clone(),
        #[cfg(feature = "framework")]
        framework: Arc::clone(opt.framework),
        last_start: None,
        manager_tx: monitor_tx.clone(),
        queue: VecDeque::new(),
        runners: Arc::clone(&runners),
        rx: queuer_rx,
        #[cfg(feature = "voice")]
        voice_manager: opt.voice_manager.clone(),
        ws_url: Arc::clone(opt.ws_url),
        cache_and_http: Arc::clone(opt.cache_and_http),
        intents: opt.intents,
    };

    spawn_named("shard_queuer::run", async move {
        queuer.run().await;
    });

    let manager = Arc::new(Mutex::new(Self {
        monitor_tx,
        runners,
        shard_index: opt.shard_index,
        shard_init: opt.shard_init,
        shard_total: opt.shard_total,
        shard_queuer: queuer_tx,
        shard_shutdown: shutdown_rx,
    }));

    let monitor = ShardManagerMonitor {
        rx: monitor_rx,
        manager: Arc::clone(&manager),
        shutdown: shutdown_tx,
    };

    (manager, monitor)
}
/// Returns `true` if the runners map currently holds an entry for `shard_id`.
///
/// Waits for the lock on `runners` before checking.
pub async fn has(&self, shard_id: ShardId) -> bool {
    let runners = self.runners.lock().await;
    runners.contains_key(&shard_id)
}
#[instrument(skip(self))]
pub fn initialize(&mut self) -> Result<()> {
let shard_to = self.shard_index + self.shard_init;
for shard_id in self.shard_index..shard_to {
let shard_total = self.shard_total;
self.boot([ShardId(shard_id), ShardId(shard_total)]);
}
Ok(())
}
/// Shuts down all shards, then replaces the stored shard configuration.
///
/// `index`, `init` and `total` overwrite `shard_index`, `shard_init` and
/// `shard_total` respectively. No shard is booted here; the new values take
/// effect on the next call to `initialize`.
#[instrument(skip(self))]
pub async fn set_shards(&mut self, index: u64, init: u64, total: u64) {
    // Shards started under the previous configuration are stopped first.
    self.shutdown_all().await;

    self.shard_total = total;
    self.shard_init = init;
    self.shard_index = index;
}
/// Restarts the shard with the given ID.
///
/// The shard is first shut down via `shutdown` with code `4000`, after which
/// a new start request for the same ID is handed to `boot`.
#[instrument(skip(self))]
pub async fn restart(&mut self, shard_id: ShardId) {
    // Code passed to `shutdown` when the shard is about to be started again.
    const RESTART_CODE: u16 = 4000;

    info!("Restarting shard {}", shard_id);
    self.shutdown(shard_id, RESTART_CODE).await;

    let total = ShardId(self.shard_total);
    self.boot([shard_id, total]);
}
/// Returns the IDs of all shards that currently have an entry in the runners
/// map.
///
/// The order of the returned IDs is the map's iteration order and is
/// therefore unspecified.
#[instrument(skip(self))]
pub async fn shards_instantiated(&self) -> Vec<ShardId> {
    let runners = self.runners.lock().await;
    runners.keys().copied().collect()
}
/// Asks the shard queuer to shut down `shard_id` with the given `code` and
/// waits up to five seconds for the shutdown to be reported back.
///
/// Outcomes while waiting on the shutdown channel:
///
/// - a matching ID arrives: the shutdown is considered clean;
/// - a different ID arrives: a warning is logged and waiting stops;
/// - the channel is closed: nothing further is logged;
/// - the timeout elapses: a warning is logged.
///
/// If the request cannot be delivered because the queuer's channel is
/// closed, no acknowledgement for it can ever arrive, so a warning is logged
/// and the wait is skipped instead of stalling for the full timeout.
///
/// In every case the shard's entry is removed from `runners` afterwards.
#[instrument(skip(self))]
pub async fn shutdown(&mut self, shard_id: ShardId, code: u16) {
    const TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(5);

    info!("Shutting down shard {}", shard_id);

    let request = ShardQueuerMessage::ShutdownShard(shard_id, code);
    if self.shard_queuer.unbounded_send(request).is_err() {
        // The request never reached the queuer, so there is nothing to wait for.
        warn!(
            "Failed to cleanly shutdown shard {}: shard queuer channel is closed",
            shard_id,
        );
    } else {
        match timeout(TIMEOUT, self.shard_shutdown.next()).await {
            Ok(Some(shutdown_shard_id)) => {
                if shutdown_shard_id != shard_id {
                    warn!(
                        "Failed to cleanly shutdown shard {}: Shutdown channel sent incorrect ID",
                        shard_id,
                    );
                }
            },
            // Sending half was dropped; no confirmation will ever arrive.
            Ok(None) => (),
            Err(why) => {
                warn!("Failed to cleanly shutdown shard {}, reached timeout: {:?}", shard_id, why);
            },
        }
    }

    // Forget the runner whether or not the shutdown was confirmed.
    self.runners.lock().await.remove(&shard_id);
}
#[instrument(skip(self))]
pub async fn shutdown_all(&mut self) {
let keys = {
let runners = self.runners.lock().await;
if runners.is_empty() {
return;
}
runners.keys().copied().collect::<Vec<_>>()
};
info!("Shutting down all shards");
for shard_id in keys {
self.shutdown(shard_id, 1000).await;
}
drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown));
drop(self.monitor_tx.unbounded_send(ShardManagerMessage::ShutdownInitiated));
}
#[instrument(skip(self))]
fn boot(&mut self, shard_info: [ShardId; 2]) {
info!("Telling shard queuer to start shard {}", shard_info[0]);
let msg = ShardQueuerMessage::Start(shard_info[0], shard_info[1]);
drop(self.shard_queuer.unbounded_send(msg));
}
}
impl Drop for ShardManager {
    /// Notifies the shard queuer and the monitor that the manager is going
    /// away.
    ///
    /// Both sends are best-effort: if either channel is already closed, the
    /// error is ignored.
    fn drop(&mut self) {
        let _ = self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown);
        let _ = self.monitor_tx.unbounded_send(ShardManagerMessage::ShutdownInitiated);
    }
}
/// Borrowed configuration consumed by [`ShardManager::new`].
///
/// The shard numbers are stored on the manager; every other field is cloned
/// into the [`ShardQueuer`] that `new` creates.
pub struct ShardManagerOptions<'a> {
/// Shared type map, cloned into the shard queuer.
pub data: &'a Arc<RwLock<TypeMap>>,
/// Optional event handler, cloned into the shard queuer.
pub event_handler: &'a Option<Arc<dyn EventHandler>>,
/// Optional raw event handler, cloned into the shard queuer.
pub raw_event_handler: &'a Option<Arc<dyn RawEventHandler>>,
/// Framework handle, cloned into the shard queuer.
#[cfg(feature = "framework")]
pub framework: &'a Arc<dyn Framework + Send + Sync>,
/// ID of the first shard to boot when the manager is initialized.
pub shard_index: u64,
/// Number of consecutive shards to boot, starting at `shard_index`.
pub shard_init: u64,
/// Value sent with every shard start request
/// (presumably the total number of shards in use; confirm against the queuer).
pub shard_total: u64,
/// Optional voice gateway manager, cloned into the shard queuer.
#[cfg(feature = "voice")]
pub voice_manager: &'a Option<Arc<dyn VoiceGatewayManager + Send + Sync + 'static>>,
/// Shared URL string, cloned into the shard queuer
/// (presumably the gateway websocket URL; confirm against the queuer).
pub ws_url: &'a Arc<Mutex<String>>,
/// Shared cache and HTTP handle, cloned into the shard queuer.
pub cache_and_http: &'a Arc<CacheAndHttp>,
/// Gateway intents, copied into the shard queuer.
pub intents: GatewayIntents,
}