use crate::{
model::{IncomingEvent, Opcode, OutgoingEvent, PlayerUpdate, Stats, StatsCpu, StatsMemory},
player::PlayerManager,
};
use futures_util::{
lock::BiLock,
sink::SinkExt,
stream::{Stream, StreamExt},
};
use http::{header::HeaderName, Request, Response, StatusCode};
use std::{
error::Error,
fmt::{Display, Formatter, Result as FmtResult},
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::{
net::TcpStream,
sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
time as tokio_time,
};
use tokio_tungstenite::{
tungstenite::{client::IntoClientRequest, Error as TungsteniteError, Message},
MaybeTlsStream, WebSocketStream,
};
use twilight_model::id::{marker::UserMarker, Id};
/// Error that occurred while connecting to or communicating with a Lavalink
/// node.
///
/// Use [`kind`](Self::kind) to inspect the category of failure; the underlying
/// cause, if any, is exposed through [`Error::source`].
#[derive(Debug)]
pub struct NodeError {
    /// Category of the failure.
    kind: NodeErrorType,
    /// Underlying error that caused this one, if any.
    source: Option<Box<dyn Error + Send + Sync>>,
}
impl NodeError {
    /// Immutable reference to the type of error that occurred.
    #[must_use = "retrieving the type has no effect if left unused"]
    pub const fn kind(&self) -> &NodeErrorType {
        &self.kind
    }

    /// Consume the error, returning the source error if there is one.
    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
        self.into_parts().1
    }

    /// Consume the error, returning the owned error type and the source error.
    #[must_use = "consuming the error into its parts has no effect if left unused"]
    pub fn into_parts(self) -> (NodeErrorType, Option<Box<dyn Error + Send + Sync>>) {
        let Self { kind, source } = self;

        (kind, source)
    }
}
/// Human-readable description of a [`NodeError`].
///
/// Messages are lowercase without trailing punctuation so they compose
/// cleanly when wrapped by callers' own error messages.
impl Display for NodeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match &self.kind {
            NodeErrorType::BuildingConnectionRequest { .. } => {
                f.write_str("failed to build connection request")
            }
            // Lowercase like every other message in this impl.
            NodeErrorType::Connecting { .. } => f.write_str("failed to connect to the node"),
            NodeErrorType::SerializingMessage { .. } => {
                f.write_str("failed to serialize outgoing message as json")
            }
            NodeErrorType::Unauthorized { address, .. } => {
                // The authorization value itself is deliberately not printed.
                f.write_str("the authorization used to connect to node ")?;
                Display::fmt(address, f)?;
                f.write_str(" is invalid")
            }
        }
    }
}
/// Exposes the boxed cause, if any, as the error source.
impl Error for NodeError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match &self.source {
            // Drop the `Send + Sync` bounds to match the trait's return type.
            Some(inner) => Some(&**inner as &(dyn Error + 'static)),
            None => None,
        }
    }
}
/// Type of [`NodeError`] that occurred.
#[derive(Debug)]
#[non_exhaustive]
pub enum NodeErrorType {
    /// Building the websocket connection request failed.
    BuildingConnectionRequest,
    /// Connecting to the node failed.
    Connecting,
    /// Serializing an outgoing event as JSON failed.
    SerializingMessage {
        /// The event that could not be serialized.
        message: OutgoingEvent,
    },
    /// The node answered the connection attempt with HTTP 401 Unauthorized.
    Unauthorized {
        /// Address of the node that rejected the connection.
        address: SocketAddr,
        /// Authorization value that was rejected.
        authorization: String,
    },
}
/// Error returned when an event could not be queued for sending to a node.
#[derive(Debug)]
pub struct NodeSenderError {
    /// Category of the failure.
    kind: NodeSenderErrorType,
    /// Underlying error that caused this one, if any.
    source: Option<Box<dyn Error + Send + Sync>>,
}
impl NodeSenderError {
    /// Immutable reference to the type of error that occurred.
    // `#[must_use]` matches the equivalent accessors on `NodeError`.
    #[must_use = "retrieving the type has no effect if left unused"]
    pub const fn kind(&self) -> &NodeSenderErrorType {
        &self.kind
    }

    /// Consume the error, returning the source error if there is one.
    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
        self.source
    }

    /// Consume the error, returning the owned error type and the source error.
    #[must_use = "consuming the error into its parts has no effect if left unused"]
    pub fn into_parts(self) -> (NodeSenderErrorType, Option<Box<dyn Error + Send + Sync>>) {
        (self.kind, self.source)
    }
}
/// Human-readable description of a [`NodeSenderError`].
impl Display for NodeSenderError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let description = match self.kind {
            NodeSenderErrorType::Sending => "failed to send over channel",
        };

        f.write_str(description)
    }
}
/// Exposes the boxed cause, if any, as the error source.
impl Error for NodeSenderError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Borrow the boxed error and drop its `Send + Sync` bounds.
        self.source
            .as_deref()
            .map(|inner| inner as &(dyn Error + 'static))
    }
}
/// Type of [`NodeSenderError`] that occurred.
#[derive(Debug)]
#[non_exhaustive]
pub enum NodeSenderErrorType {
    /// Sending the event over the channel to the connection failed.
    Sending,
}
/// Stream of events received from a Lavalink node.
///
/// Returned by [`Node::connect`].
pub struct IncomingEvents {
    /// Receiving half of the channel the connection forwards events into.
    inner: UnboundedReceiver<IncomingEvent>,
}
impl IncomingEvents {
    /// Close the stream's underlying channel.
    ///
    /// The connection stops forwarding events once the channel is closed;
    /// events that were already buffered can still be received.
    pub fn close(&mut self) {
        let receiver = &mut self.inner;
        receiver.close();
    }
}
/// Yields events as the connection forwards them; ends once the channel is
/// closed and drained.
impl Stream for IncomingEvents {
    type Item = IncomingEvent;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `IncomingEvents` is `Unpin`, so the pin can be unwrapped directly.
        let this = self.get_mut();

        this.inner.poll_recv(cx)
    }
}
/// Cloneable-from-[`Node`] handle for queueing events to send to a node.
///
/// Created by [`Node::sender`].
pub struct NodeSender {
    /// Sending half of the channel read by the node's connection.
    inner: UnboundedSender<OutgoingEvent>,
}
impl NodeSender {
    /// Whether the receiving half of the channel, held by the node's
    /// connection, has been closed or dropped.
    pub fn is_closed(&self) -> bool {
        self.inner.is_closed()
    }

    /// Queue an event to be sent to the node.
    ///
    /// # Errors
    ///
    /// Returns a [`NodeSenderErrorType::Sending`] error if the channel to the
    /// connection is closed.
    pub fn send(&self, msg: OutgoingEvent) -> Result<(), NodeSenderError> {
        match self.inner.send(msg) {
            Ok(()) => Ok(()),
            Err(source) => Err(NodeSenderError {
                kind: NodeSenderErrorType::Sending,
                source: Some(Box::new(source)),
            }),
        }
    }
}
/// Configuration for connecting to a Lavalink node.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct NodeConfig {
    /// Address of the node; the connection URL is `ws://{address}`.
    pub address: SocketAddr,
    /// Value sent in the `Authorization` header when connecting.
    pub authorization: String,
    /// Session resume configuration. When set, a `Resume-Key` header is sent
    /// and resuming is configured if the node reports the session was not
    /// resumed.
    pub resume: Option<Resume>,
    /// ID of the bot user, sent in the `User-Id` header.
    pub user_id: Id<UserMarker>,
}
/// Configuration for resuming a session with a node.
///
/// Defaults to a timeout of 60 seconds.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct Resume {
    /// Resume timeout in seconds, sent as `timeout` in the
    /// `configureResuming` payload.
    pub timeout: u64,
}
impl Resume {
    /// Configure resuming with a timeout of `seconds`.
    pub const fn new(seconds: u64) -> Self {
        let timeout = seconds;

        Self { timeout }
    }
}
/// Resume configuration with a 60 second timeout.
impl Default for Resume {
    fn default() -> Self {
        Self::new(60)
    }
}
impl NodeConfig {
    /// Create a node configuration.
    ///
    /// `address`, `authorization`, and `resume` accept anything convertible
    /// into their field types.
    pub fn new(
        user_id: Id<UserMarker>,
        address: impl Into<SocketAddr>,
        authorization: impl Into<String>,
        resume: impl Into<Option<Resume>>,
    ) -> Self {
        let address = address.into();
        let authorization = authorization.into();
        let resume = resume.into();

        Self::_new(user_id, address, authorization, resume)
    }

    /// Non-generic constructor backing [`Self::new`].
    const fn _new(
        user_id: Id<UserMarker>,
        address: SocketAddr,
        authorization: String,
        resume: Option<Resume>,
    ) -> Self {
        Self {
            user_id,
            address,
            authorization,
            resume,
        }
    }
}
/// Handle to a connection to a Lavalink node.
///
/// Created by [`Node::connect`]; the websocket itself is driven by a
/// background task.
#[derive(Debug)]
pub struct Node {
    /// Configuration used to connect to the node.
    config: NodeConfig,
    /// Sender for events to forward to the node via the connection task.
    lavalink_tx: UnboundedSender<OutgoingEvent>,
    /// Player manager; a clone is handed to the connection task.
    players: PlayerManager,
    /// Latest statistics reported by the node; the other half of this lock is
    /// held and updated by the connection task.
    stats: BiLock<Stats>,
}
impl Node {
pub async fn connect(
config: NodeConfig,
players: PlayerManager,
) -> Result<(Self, IncomingEvents), NodeError> {
let (bilock_left, bilock_right) = BiLock::new(Stats {
cpu: StatsCpu {
cores: 0,
lavalink_load: 0f64,
system_load: 0f64,
},
frames: None,
memory: StatsMemory {
allocated: 0,
free: 0,
used: 0,
reservable: 0,
},
players: 0,
playing_players: 0,
op: Opcode::Stats,
uptime: 0,
});
tracing::debug!("starting connection to {}", config.address);
let (conn_loop, lavalink_tx, lavalink_rx) =
Connection::connect(config.clone(), players.clone(), bilock_right).await?;
tracing::debug!("started connection to {}", config.address);
tokio::spawn(conn_loop.run());
Ok((
Self {
config,
lavalink_tx,
players,
stats: bilock_left,
},
IncomingEvents { inner: lavalink_rx },
))
}
pub const fn config(&self) -> &NodeConfig {
&self.config
}
pub const fn players(&self) -> &PlayerManager {
&self.players
}
pub fn send(&self, event: OutgoingEvent) -> Result<(), NodeSenderError> {
self.sender().send(event)
}
pub fn sender(&self) -> NodeSender {
NodeSender {
inner: self.lavalink_tx.clone(),
}
}
pub async fn stats(&self) -> Stats {
(*self.stats.lock().await).clone()
}
#[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
pub async fn penalty(&self) -> i32 {
let stats = self.stats.lock().await;
let cpu = 1.05f64.powf(100f64 * stats.cpu.system_load) * 10f64 - 10f64;
let (deficit_frame, null_frame) = (
1.03f64
.powf(500f64 * (stats.frames.as_ref().map_or(0, |f| f.deficit) as f64 / 3000f64))
* 300f64
- 300f64,
(1.03f64
.powf(500f64 * (stats.frames.as_ref().map_or(0, |f| f.nulled) as f64 / 3000f64))
* 300f64
- 300f64)
* 2f64,
);
stats.playing_players as i32 + cpu as i32 + deficit_frame as i32 + null_frame as i32
}
}
/// State owned by the background task that drives a node's websocket.
struct Connection {
    /// Configuration used to (re)connect to the node.
    config: NodeConfig,
    /// Websocket to the node; replaced whenever the connection is re-established.
    connection: WebSocketStream<MaybeTlsStream<TcpStream>>,
    /// Receives events queued through the `Node` handle, to forward to the node.
    node_from: UnboundedReceiver<OutgoingEvent>,
    /// Delivers events received from the node to `IncomingEvents`.
    node_to: UnboundedSender<IncomingEvent>,
    /// Player manager updated from player update events.
    players: PlayerManager,
    /// Half of the statistics lock shared with the `Node` handle.
    stats: BiLock<Stats>,
}
impl Connection {
    /// Open the websocket described by `config` and create the channels that
    /// link it to its [`Node`] handle.
    ///
    /// Returns the connection, the sender the node handle uses to queue
    /// outgoing events, and the receiver on which incoming events are
    /// delivered.
    ///
    /// # Errors
    ///
    /// Returns the error from establishing the websocket connection.
    async fn connect(
        config: NodeConfig,
        players: PlayerManager,
        stats: BiLock<Stats>,
    ) -> Result<
        (
            Self,
            UnboundedSender<OutgoingEvent>,
            UnboundedReceiver<IncomingEvent>,
        ),
        NodeError,
    > {
        let connection = reconnect(&config).await?;
        // Incoming: this connection -> `IncomingEvents`.
        let (to_node, from_lavalink) = mpsc::unbounded_channel();
        // Outgoing: `Node` / `NodeSender` -> this connection.
        let (to_lavalink, from_node) = mpsc::unbounded_channel();

        Ok((
            Self {
                config,
                connection,
                node_from: from_node,
                node_to: to_node,
                players,
                stats,
            },
            to_lavalink,
            from_lavalink,
        ))
    }

    /// Drive the connection until every handle's sender has been dropped.
    ///
    /// Websocket messages are handled by [`Self::incoming`]; queued outgoing
    /// events are serialized to JSON and written to the websocket. Whenever
    /// the websocket yields an error or ends, it is re-established.
    ///
    /// # Errors
    ///
    /// Returns an error if reconnecting fails or an outgoing event cannot be
    /// serialized.
    async fn run(mut self) -> Result<(), NodeError> {
        loop {
            tokio::select! {
                incoming = self.connection.next() => {
                    // NOTE(review): the `bool` returned by `incoming` does not stop
                    // the loop; after a close frame the next read presumably ends
                    // the stream, which takes the reconnect branch below.
                    if let Some(Ok(incoming)) = incoming {
                        self.incoming(incoming).await?;
                    } else {
                        tracing::debug!("connection to {} closed, reconnecting", self.config.address);
                        self.connection = reconnect(&self.config).await?;
                    }
                }
                outgoing = self.node_from.recv() => {
                    if let Some(outgoing) = outgoing {
                        tracing::debug!(
                            "forwarding event to {}: {outgoing:?}",
                            self.config.address,
                        );

                        let payload = serde_json::to_string(&outgoing).map_err(|source| NodeError {
                            kind: NodeErrorType::SerializingMessage { message: outgoing },
                            source: Some(Box::new(source)),
                        })?;
                        let msg = Message::Text(payload);

                        // A failed write must not panic the background task. If the
                        // socket is broken, the next read is expected to fail too,
                        // which triggers the reconnect branch above.
                        if let Err(source) = self.connection.send(msg).await {
                            tracing::warn!(
                                "failed to send event to {}: {source}",
                                self.config.address,
                            );
                        }
                    } else {
                        tracing::debug!("node {} closed, ending connection", self.config.address);
                        break;
                    }
                }
            }
        }

        Ok(())
    }

    /// Handle one websocket message from the node.
    ///
    /// Close frames are answered with a close frame and yield `Ok(false)`.
    /// Pings are answered with a pong. Text frames are parsed as
    /// [`IncomingEvent`]s; player updates and stats are applied locally, and
    /// the event is then forwarded to [`IncomingEvents`] unless that channel is
    /// closed. Unparseable text, pongs, and binary frames are logged and
    /// skipped. Every case other than a close frame yields `Ok(true)`.
    async fn incoming(&mut self, incoming: Message) -> Result<bool, NodeError> {
        tracing::debug!(
            "received message from {}: {incoming:?}",
            self.config.address,
        );

        let text = match incoming {
            Message::Close(_) => {
                tracing::debug!("got close, closing connection");
                // Best effort: the connection may already be unusable.
                let _result = self.connection.send(Message::Close(None)).await;

                return Ok(false);
            }
            Message::Ping(data) => {
                tracing::debug!("got ping, sending pong");
                let msg = Message::Pong(data);
                // Best effort, like the close reply above.
                let _result = self.connection.send(msg).await;

                return Ok(true);
            }
            Message::Text(text) => text,
            other => {
                tracing::debug!("got pong or bytes payload: {other:?}");

                return Ok(true);
            }
        };

        let event = if let Ok(event) = serde_json::from_str(&text) {
            event
        } else {
            tracing::warn!("unknown message from lavalink node: {text}");

            return Ok(true);
        };

        match &event {
            IncomingEvent::PlayerUpdate(update) => self.player_update(update)?,
            IncomingEvent::Stats(stats) => self.stats(stats).await?,
            _ => {}
        }

        // Nobody is listening anymore; drop the event.
        if !self.node_to.is_closed() {
            let _result = self.node_to.send(event);
        }

        Ok(true)
    }

    /// Apply a player update to the matching player, if the guild has one.
    ///
    /// Updates for unknown guilds are logged and ignored. A missing position is
    /// stored as `0`.
    fn player_update(&self, update: &PlayerUpdate) -> Result<(), NodeError> {
        let player = if let Some(player) = self.players.get(&update.guild_id) {
            player
        } else {
            tracing::warn!(
                "invalid player update for guild {}: {update:?}",
                update.guild_id,
            );

            return Ok(());
        };

        player.set_position(update.state.position.unwrap_or(0));
        player.set_time(update.state.time);

        Ok(())
    }

    /// Store the latest statistics where the `Node` handle can read them.
    async fn stats(&self, stats: &Stats) -> Result<(), NodeError> {
        *self.stats.lock().await = stats.clone();

        Ok(())
    }
}
/// When the connection goes away, remove every player whose node has this
/// connection's address from the player manager.
impl Drop for Connection {
    fn drop(&mut self) {
        let address = self.config.address;

        self.players
            .players
            .retain(|_, player| player.node().config().address != address);
    }
}
fn connect_request(state: &NodeConfig) -> Result<Request<()>, NodeError> {
let mut request = format!("ws://{}", state.address)
.into_client_request()
.map_err(|source| NodeError {
kind: NodeErrorType::BuildingConnectionRequest,
source: Some(Box::new(source)),
})?;
let headers = request.headers_mut();
headers.insert("Authorization", state.authorization.parse().unwrap());
headers.insert("User-Id", state.user_id.get().into());
if state.resume.is_some() {
headers.insert("Resume-Key", state.address.to_string().parse().unwrap());
}
Ok(request)
}
async fn reconnect(
config: &NodeConfig,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, NodeError> {
let (mut stream, res) = backoff(config).await?;
let headers = res.headers();
if let Some(resume) = config.resume.as_ref() {
let header = HeaderName::from_static("session-resumed");
if let Some(value) = headers.get(header) {
if value.as_bytes() == b"false" {
tracing::debug!("session to node {} didn't resume", config.address);
let payload = serde_json::json!({
"op": "configureResuming",
"key": config.address,
"timeout": resume.timeout,
});
let msg = Message::Text(serde_json::to_string(&payload).unwrap());
stream.send(msg).await.unwrap();
} else {
tracing::debug!("session to {} resumed", config.address);
}
}
}
Ok(stream)
}
/// Try to open a websocket to the node, retrying with exponential backoff.
///
/// After each failed attempt the task sleeps 1, 2, 4, ... 64 seconds (eight
/// attempts in total) before giving up.
///
/// # Errors
///
/// Returns a [`NodeErrorType::BuildingConnectionRequest`] error if the request
/// cannot be built, a [`NodeErrorType::Unauthorized`] error immediately if the
/// node answers with HTTP 401, or a [`NodeErrorType::Connecting`] error once
/// all attempts have failed.
async fn backoff(
    config: &NodeConfig,
) -> Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response<()>), NodeError> {
    // Delay before the next attempt; doubles after every failure.
    let mut seconds = 1;

    loop {
        // `connect_async` consumes the request, so build a fresh one per attempt.
        let request = connect_request(config)?;

        let source = match tokio_tungstenite::connect_async(request).await {
            Ok((stream, response)) => return Ok((stream, response)),
            Err(source) => source,
        };

        tracing::warn!("failed to connect to node {}: {source}", config.address);

        // Retrying cannot fix a rejected authorization.
        if matches!(&source, TungsteniteError::Http(resp) if resp.status() == StatusCode::UNAUTHORIZED)
        {
            return Err(NodeError {
                kind: NodeErrorType::Unauthorized {
                    address: config.address,
                    authorization: config.authorization.clone(),
                },
                source: None,
            });
        }

        if seconds > 64 {
            tracing::debug!("no longer trying to connect to node {}", config.address);

            return Err(NodeError {
                kind: NodeErrorType::Connecting,
                source: Some(Box::new(source)),
            });
        }

        tracing::debug!(
            "waiting {seconds} seconds before attempting to connect to node {} again",
            config.address,
        );
        tokio_time::sleep(Duration::from_secs(seconds)).await;
        seconds *= 2;
    }
}
#[cfg(test)]
mod tests {
    use super::{
        Node, NodeConfig, NodeError, NodeErrorType, NodeSender, NodeSenderError,
        NodeSenderErrorType, Resume,
    };
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{error::Error, fmt::Debug};

    // Configuration.
    assert_fields!(NodeConfig: address, authorization, resume, user_id);
    assert_impl_all!(NodeConfig: Clone, Debug, Eq, PartialEq, Send, Sync);
    assert_fields!(Resume: timeout);
    assert_impl_all!(Resume: Clone, Debug, Default, Eq, PartialEq, Send, Sync);

    // Connection errors.
    assert_fields!(NodeErrorType::SerializingMessage: message);
    assert_fields!(NodeErrorType::Unauthorized: address, authorization);
    assert_impl_all!(NodeErrorType: Debug, Send, Sync);
    assert_impl_all!(NodeError: Error, Send, Sync);

    // Sending errors: public like `NodeError`, so they get the same guarantees.
    assert_impl_all!(NodeSenderErrorType: Debug, Send, Sync);
    assert_impl_all!(NodeSenderError: Error, Send, Sync);

    // Handles.
    assert_impl_all!(Node: Debug, Send, Sync);
    assert_impl_all!(NodeSender: Send, Sync);
}