use crate::{
model::{IncomingEvent, OutgoingEvent, VoiceUpdate},
node::{Node, NodeConfig, NodeError, Resume},
player::{Player, PlayerManager},
};
use dashmap::{mapref::one::Ref, DashMap};
use futures_channel::mpsc::{TrySendError, UnboundedReceiver};
use std::{
error::Error,
fmt::{Display, Formatter, Result as FmtResult},
net::SocketAddr,
sync::Arc,
};
use twilight_model::{
gateway::{
event::Event,
payload::{VoiceServerUpdate, VoiceStateUpdate},
},
id::{GuildId, UserId},
};
/// An error that can occur while using the [`Lavalink`] client.
///
/// The enum is marked `#[non_exhaustive]`, so code outside this crate must
/// include a wildcard arm when matching on it.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum ClientError {
/// No node could be selected, so no player could be created for a guild.
///
/// Returned by [`Lavalink::best`] when it finds no node to return.
NodesUnconfigured,
/// Handing a voice update to the guild's player failed.
SendingVoiceUpdate {
/// The channel error reported when the send was attempted.
source: TrySendError<OutgoingEvent>,
},
}
impl Display for ClientError {
    /// Writes a short, human-readable description of the error.
    ///
    /// The underlying cause of `SendingVoiceUpdate` is intentionally not
    /// printed here; it is exposed through `Error::source` instead.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        // Pick the static message first, then emit it with a single write.
        let message = match self {
            Self::SendingVoiceUpdate { .. } => "couldn't send voice update to node",
            Self::NodesUnconfigured => "no node has been configured",
        };

        f.write_str(message)
    }
}
impl Error for ClientError {
    /// Returns the lower-level error that caused this one, if there is one.
    ///
    /// Only `SendingVoiceUpdate` wraps another error; `NodesUnconfigured`
    /// has no underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::SendingVoiceUpdate { source: cause } => Some(cause),
            Self::NodesUnconfigured => None,
        }
    }
}
/// One of the two gateway events that must both be seen for a guild before a
/// voice update can be built and sent to a player.
///
/// The client parks whichever half arrives first in its `waiting` map and
/// combines it with the other half once that one arrives.
#[derive(Debug)]
enum VoiceStateHalf {
/// The voice server update half.
Server(VoiceServerUpdate),
/// The voice state update half (boxed).
State(Box<VoiceStateUpdate>),
}
/// Shared state behind a [`Lavalink`] handle.
#[derive(Debug, Default)]
struct LavalinkRef {
// NOTE(review): not read or written anywhere in this section; presumably
// maps a guild to the address of the node serving it — confirm.
guilds: DashMap<GuildId, SocketAddr>,
// Nodes added through `Lavalink::add`, keyed by their address.
nodes: DashMap<SocketAddr, Node>,
// Player registry; a clone of it is handed to every node on connect.
players: PlayerManager,
// Optional resume configuration copied into each node's `NodeConfig`.
resume: Option<Resume>,
// Shard count, used to map a guild ID to its shard and passed to each node.
shard_count: u64,
// User ID the client was created with; voice state updates from other
// users are ignored.
user_id: UserId,
// Guilds for which only one of the two voice events has been seen so far.
waiting: DashMap<GuildId, VoiceStateHalf>,
}
/// Client that manages a set of nodes and the players attached to them.
///
/// The state lives behind an `Arc`, so cloning a `Lavalink` yields another
/// handle to the same underlying client rather than an independent copy.
#[derive(Clone, Debug)]
pub struct Lavalink(Arc<LavalinkRef>);
impl Lavalink {
    /// Create a new client without a resume configuration.
    ///
    /// `user_id` is the user whose voice state updates are processed; updates
    /// from any other user are ignored. `shard_count` is used to work out
    /// which shard a guild belongs to and is forwarded to every node.
    pub fn new(user_id: UserId, shard_count: u64) -> Self {
        Self::_new_with_resume(user_id, shard_count, None)
    }

    /// Create a new client with an optional resume configuration.
    ///
    /// Behaves like [`Lavalink::new`], except that `resume` (a `Resume` or an
    /// `Option<Resume>`) is copied into the configuration of every node added
    /// afterwards.
    pub fn new_with_resume(
        user_id: UserId,
        shard_count: u64,
        resume: impl Into<Option<Resume>>,
    ) -> Self {
        Self::_new_with_resume(user_id, shard_count, resume.into())
    }

    /// Non-generic constructor shared by the public ones.
    fn _new_with_resume(user_id: UserId, shard_count: u64, resume: Option<Resume>) -> Self {
        Self(Arc::new(LavalinkRef {
            guilds: DashMap::new(),
            nodes: DashMap::new(),
            players: PlayerManager::new(),
            resume,
            shard_count,
            user_id,
            waiting: DashMap::new(),
        }))
    }

    /// Process a gateway event.
    ///
    /// * `Ready` drops every half-finished voice handshake that belongs to
    ///   the shard that became ready.
    /// * `VoiceServerUpdate` and `VoiceStateUpdate` (for the configured user
    ///   only) are paired up per guild. The first one to arrive is stored;
    ///   when the other arrives, a voice update is built and sent to the
    ///   guild's player.
    /// * Every other event is ignored.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NodesUnconfigured`] if a player has to be
    /// created but no node is available, and
    /// [`ClientError::SendingVoiceUpdate`] if handing the update to the player
    /// fails.
    pub async fn process(&self, event: &Event) -> Result<(), ClientError> {
        tracing::trace!("processing event: {:?}", event);

        let (guild_id, half) = match event {
            Event::Ready(e) => {
                let shard_id = e.shard.map_or(0, |[id, _]| id);

                self.clear_shard_states(shard_id);

                return Ok(());
            }
            Event::VoiceServerUpdate(e) => (e.guild_id, VoiceStateHalf::Server(e.clone())),
            Event::VoiceStateUpdate(e) => {
                if e.0.user_id != self.0.user_id {
                    tracing::trace!("got voice state update from another user");

                    return Ok(());
                }

                (e.0.guild_id, VoiceStateHalf::State(e.clone()))
            }
            _ => return Ok(()),
        };

        tracing::debug!(
            "got voice server/state update for {:?}: {:?}",
            guild_id,
            half
        );

        let guild_id = match guild_id {
            Some(guild_id) => guild_id,
            None => {
                tracing::trace!("event has no guild ID: {:?}", event);

                return Ok(());
            }
        };

        // Take the stored half *out* of the map instead of borrowing it.
        // Holding a `get()` reference while calling `insert()` for the same
        // key dead-locks the map, because the insert needs the write lock of
        // the shard the reference keeps read-locked.
        let existing_half = match self.0.waiting.remove(&guild_id) {
            Some((_, existing_half)) => existing_half,
            None => {
                tracing::debug!(
                    "guild {} is now waiting for other half; got: {:?}",
                    guild_id,
                    half
                );
                self.0.waiting.insert(guild_id, half);

                return Ok(());
            }
        };

        tracing::debug!(
            "got both halves for {}: {:?}; {:?}",
            guild_id,
            half,
            existing_half
        );

        let update = match (existing_half, half) {
            // The same kind of half arrived twice: keep the newer one and
            // continue waiting for its counterpart.
            (VoiceStateHalf::Server(_), VoiceStateHalf::Server(server)) => {
                tracing::debug!(
                    "got the same server half twice for guild {}: {:?}",
                    guild_id,
                    server
                );
                self.0
                    .waiting
                    .insert(guild_id, VoiceStateHalf::Server(server));

                return Ok(());
            }
            (VoiceStateHalf::State(_), VoiceStateHalf::State(state)) => {
                tracing::debug!(
                    "got the same state half twice for guild {}: {:?}",
                    guild_id,
                    state
                );
                self.0
                    .waiting
                    .insert(guild_id, VoiceStateHalf::State(state));

                return Ok(());
            }
            // Both halves are present, in either arrival order. Both are
            // owned at this point, so the server half can be converted
            // without cloning it.
            (VoiceStateHalf::Server(server), VoiceStateHalf::State(state))
            | (VoiceStateHalf::State(state), VoiceStateHalf::Server(server)) => {
                VoiceUpdate::new(guild_id, &state.0.session_id, From::from(server))
            }
        };

        // The guild has already been taken out of the waiting map above.
        tracing::debug!("getting player for guild {}", guild_id);
        let player = self.player(guild_id).await?;

        tracing::debug!("sending voice update for guild {}: {:?}", guild_id, update);
        player
            .send(update)
            .map_err(|source| ClientError::SendingVoiceUpdate { source })?;
        tracing::debug!("sent voice update for guild {}", guild_id);

        Ok(())
    }

    /// Connect to a node and register it with the client.
    ///
    /// The node is configured with the given `address` and `authorization`
    /// plus the client's resume configuration, shard count and user ID. On
    /// success the node is stored under `address` (replacing any node that
    /// was stored there before) and returned together with the receiver of
    /// its incoming events.
    ///
    /// # Errors
    ///
    /// Returns the [`NodeError`] produced by `Node::connect` if connecting
    /// fails; nothing is registered in that case.
    pub async fn add(
        &self,
        address: SocketAddr,
        authorization: impl Into<String>,
    ) -> Result<(Node, UnboundedReceiver<IncomingEvent>), NodeError> {
        let config = NodeConfig {
            address,
            authorization: authorization.into(),
            resume: self.0.resume.clone(),
            shard_count: self.0.shard_count,
            user_id: self.0.user_id,
        };

        let (node, rx) = Node::connect(config, self.0.players.clone()).await?;
        self.0.nodes.insert(address, node.clone());

        Ok((node, rx))
    }

    /// Remove a node from the client's list of nodes.
    ///
    /// Returns the removed address/node pair, or `None` if no node was
    /// registered under `address`. This only unregisters the node; whether
    /// its connection is closed depends on `Node` and is not decided here.
    pub async fn remove(&self, address: SocketAddr) -> Option<(SocketAddr, Node)> {
        self.0.nodes.remove(&address)
    }

    /// Determine the node with the lowest penalty.
    ///
    /// When several nodes share the lowest penalty, the first one encountered
    /// is returned.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NodesUnconfigured`] if no node is registered.
    pub async fn best(&self) -> Result<Node, ClientError> {
        // Snapshot the nodes before awaiting anything. Iterating the map
        // keeps one of its internal locks held, and that lock must not be
        // carried across an `.await`: a concurrent `add`/`remove` would then
        // block on it.
        let candidates: Vec<Node> = self
            .0
            .nodes
            .iter()
            .map(|entry| entry.value().clone())
            .collect();

        // Track the best candidate together with its penalty instead of
        // seeding the search with `i32::MAX`, which would make a node whose
        // penalty is exactly `i32::MAX` impossible to select.
        let mut best: Option<(i32, Node)> = None;

        for node in candidates {
            let penalty = node.penalty().await;
            let is_better = match &best {
                Some((lowest, _)) => penalty < *lowest,
                None => true,
            };

            if is_better {
                best = Some((penalty, node));
            }
        }

        best.map(|(_, node)| node)
            .ok_or(ClientError::NodesUnconfigured)
    }

    /// Return a reference to the client's player manager.
    pub fn players(&self) -> &PlayerManager {
        &self.0.players
    }

    /// Return the player of a guild, creating it if it does not exist yet.
    ///
    /// A newly created player is attached to the node returned by
    /// [`Lavalink::best`].
    ///
    /// The returned value is a reference into the player manager's map; do
    /// not hold on to it longer than necessary.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NodesUnconfigured`] if a player has to be
    /// created but no node is available.
    pub async fn player(&self, guild_id: GuildId) -> Result<Ref<'_, GuildId, Player>, ClientError> {
        if let Some(player) = self.players().get(&guild_id) {
            return Ok(player);
        }

        let node = self.best().await?;

        Ok(self.players().get_or_insert(guild_id, node).downgrade())
    }

    /// Drop the pending voice halves of every guild handled by `shard_id`.
    ///
    /// A guild's shard is computed as `(guild_id >> 22) % shard_count`.
    fn clear_shard_states(&self, shard_id: u64) {
        let shard_count = self.0.shard_count;

        // `retain` filters in place. Removing entries while iterating the
        // same map would dead-lock, since the iterator keeps the shard
        // read-locked that `remove` needs to write-lock.
        //
        // `checked_rem` yields `None` for a shard count of zero, so in that
        // case nothing matches and nothing is removed, instead of panicking
        // on a division by zero.
        self.0.waiting.retain(|guild_id, _| {
            (guild_id.0 >> 22).checked_rem(shard_count) != Some(shard_id)
        });
    }
}
// Compile-time checks of the module's types; there are no runtime test
// functions in this module.
#[cfg(test)]
mod tests {
use super::{ClientError, Lavalink, VoiceStateHalf};
use static_assertions::{assert_fields, assert_impl_all};
use std::{error::Error, fmt::Debug};
// `SendingVoiceUpdate` must keep exposing its `source` field.
assert_fields!(ClientError::SendingVoiceUpdate: source);
// The error type must stay cloneable, comparable and thread-safe.
assert_impl_all!(ClientError: Clone, Debug, Error, PartialEq, Send, Sync);
// The client handle must stay cloneable and shareable across threads.
assert_impl_all!(Lavalink: Clone, Debug, Send, Sync);
// Stored halves live in a shared map, so they must be thread-safe as well.
assert_impl_all!(VoiceStateHalf: Debug, Send, Sync);
}