#[macro_use]
extern crate tracing;
pub mod builders;
pub mod error;
pub mod gateway;
pub mod model;
use builders::*;
use gateway::LavalinkEventHandler;
use model::*;
use std::{
cmp::{max, min},
sync::Arc,
time::Duration,
};
#[cfg(feature = "tokio-02-marker")]
use async_tungstenite_compat as async_tungstenite;
#[cfg(feature = "tokio-02-marker")]
use reqwest_compat as reqwest;
#[cfg(feature = "tokio-02-marker")]
use tokio_compat as tokio;
use songbird::ConnectionInfo;
use http::Request;
use reqwest::{header::*, Client as ReqwestClient, Url};
#[cfg(all(feature = "native-marker", not(feature = "tokio-02-marker")))]
use tokio_native_tls::TlsStream;
#[cfg(all(feature = "rustls-marker", not(feature = "tokio-02-marker")))]
use tokio_rustls::client::TlsStream;
#[cfg(all(feature = "native-marker", feature = "tokio-02-marker"))]
use tokio_native_tls_compat::TlsStream;
#[cfg(all(feature = "rustls-marker", feature = "tokio-02-marker"))]
use tokio_rustls_compat::client::TlsStream;
use tokio::{net::TcpStream, sync::Mutex};
use regex::Regex;
use futures::stream::{SplitSink, SplitStream, StreamExt};
use async_tungstenite::{
stream::Stream,
tokio::{connect_async, TokioAdapter},
tungstenite::Message as TungsteniteMessage,
WebSocketStream,
};
use dashmap::{DashMap, DashSet};
// Equalizer presets intended for `LavalinkClient::equalize_all`: element `i`
// of each array is the gain sent for band `i` (bands 0 through 14).

/// Flat preset: every one of the 15 bands at a gain of `0.0`.
pub const EQ_BASE: [f64; 15] = [0.0; 15];

/// "Boost" preset. Per-band gains; the name suggests a bass/treble boost —
/// NOTE(review): intended sound profile not documented here, confirm.
pub const EQ_BOOST: [f64; 15] = [
    -0.075, 0.125, 0.125, 0.1, 0.1, // bands 0-4
    0.05, 0.075, 0.0, 0.0, 0.0, // bands 5-9
    0.0, 0.0, 0.125, 0.15, 0.05, // bands 10-14
];

/// "Metal" preset: per-band gains.
pub const EQ_METAL: [f64; 15] = [
    0.0, 0.1, 0.1, 0.15, 0.13, // bands 0-4
    0.1, 0.0, 0.125, 0.175, 0.175, // bands 5-9
    0.125, 0.125, 0.1, 0.075, 0.0, // bands 10-14
];

/// "Piano" preset: per-band gains.
pub const EQ_PIANO: [f64; 15] = [
    -0.25, -0.25, -0.125, 0.0, 0.25, // bands 0-4
    0.25, 0.0, -0.25, -0.25, 0.0, // bands 5-9
    0.0, 0.5, 0.25, -0.025, 0.0, // bands 10-14
];
/// Websocket stream used for the Lavalink connection. The underlying transport
/// is a `TcpStream`, either used directly or wrapped in a `TlsStream`; which TLS
/// backend provides `TlsStream` is selected by the `native-marker` /
/// `rustls-marker` cargo features.
pub type WsStream =
WebSocketStream<Stream<TokioAdapter<TcpStream>, TokioAdapter<TlsStream<TcpStream>>>>;
/// A [`WsStream`] shared behind an async (tokio) mutex.
pub type WebsocketConnection = Arc<Mutex<WsStream>>;
/// Shared state behind a [`LavalinkClient`]; accessed through its `inner` mutex.
pub struct LavalinkClientInner {
/// Base URL for REST calls, `http(s)://host:port` as built by `LavalinkClient::new`.
pub rest_uri: String,
/// `Authorization`, `Num-Shards` and `User-Id` headers; sent on the websocket
/// handshake and attached to every REST request.
pub headers: HeaderMap,
/// Write half of the Lavalink websocket; every outgoing opcode goes through it.
pub socket_write: SplitSink<WsStream, TungsteniteMessage>,
/// Per-guild player state, keyed by the raw guild id.
pub nodes: Arc<DashMap<u64, Node>>,
/// Set of raw guild ids. NOTE(review): presumably the guilds with looping
/// enabled — nothing in this file reads or writes it; confirm with callers.
pub loops: Arc<DashSet<u64>>,
}
/// Handle to a Lavalink connection. Cloning is cheap: all clones share the
/// same [`LavalinkClientInner`] through `Arc<Mutex<_>>`.
#[derive(Clone)]
pub struct LavalinkClient {
/// Shared connection state; most client methods lock this mutex.
pub inner: Arc<Mutex<LavalinkClientInner>>,
}
/// Reads every message from the Lavalink websocket and dispatches it to the
/// matching `handler` callback, keeping the client's per-guild `nodes` in sync.
///
/// Only text frames are parsed; other frame types are skipped. The loop returns
/// when the stream ends or yields a transport error.
///
/// Local bookkeeping done before the handler is invoked:
/// - `playerUpdate`: the position stored on the guild's `now_playing` track is
///   updated (only when that track carries `info`).
/// - `TrackEndEvent` with reason `"FINISHED"`: the head of the guild's queue is
///   popped (only when the queue is non-empty) and `now_playing` is cleared.
///
/// The client lock is always released before a handler future is awaited, so
/// handlers may call back into `client` without deadlocking.
async fn event_loop(
    mut read: SplitStream<WsStream>,
    handler: impl LavalinkEventHandler + Send + Sync + 'static,
    client: LavalinkClient,
) {
    while let Some(Ok(resp)) = read.next().await {
        let x = match &resp {
            TungsteniteMessage::Text(x) => x,
            _ => continue,
        };

        let base_event = match serde_json::from_str::<GatewayEvent>(x) {
            Ok(base_event) => base_event,
            Err(why) => {
                warn!("Failed to parse socket response ({}): {}", why, x);
                continue;
            }
        };

        match base_event.op.as_str() {
            "stats" => match serde_json::from_str::<Stats>(x) {
                Ok(stats) => {
                    handler.stats(client.clone(), stats).await;
                }
                Err(why) => {
                    warn!("Failed to parse stats ({}): {}", why, x);
                }
            },
            "playerUpdate" => match serde_json::from_str::<PlayerUpdate>(x) {
                Ok(player_update) => {
                    // Scoped so the client lock is dropped before the handler runs.
                    {
                        let client_lock = client.inner.lock().await;
                        if let Some(mut node) = client_lock.nodes.get_mut(&player_update.guild_id)
                        {
                            if let Some(current_track) = node.now_playing.as_mut() {
                                // A track without info has no position to update;
                                // skip it instead of panicking.
                                if let Some(info) = current_track.track.info.as_mut() {
                                    info.position = player_update.state.position as u64;
                                    trace!(
                                        "Updated track {:?} with position {}",
                                        info,
                                        player_update.state.position
                                    );
                                }
                            }
                        };
                    }
                    handler.player_update(client.clone(), player_update).await;
                }
                Err(why) => {
                    warn!("Failed to parse playerUpdate ({}): {}", why, x);
                }
            },
            // An "event" op without a type used to panic; it now falls through
            // to the "Unknown event" warning.
            "event" => match base_event.event_type.as_deref() {
                Some("TrackStartEvent") => match serde_json::from_str::<TrackStart>(x) {
                    Ok(track_start) => {
                        handler.track_start(client.clone(), track_start).await;
                    }
                    Err(why) => {
                        warn!("Failed to parse TrackStartEvent ({}): {}", why, x);
                    }
                },
                Some("TrackEndEvent") => match serde_json::from_str::<TrackFinish>(x) {
                    Ok(track_finish) => {
                        if track_finish.reason == "FINISHED" {
                            let client_lock = client.inner.lock().await;
                            if let Some(mut node) =
                                client_lock.nodes.get_mut(&track_finish.guild_id)
                            {
                                // The queue may already be empty (e.g. after a skip
                                // or destroy); `remove(0)` would panic then.
                                if !node.queue.is_empty() {
                                    node.queue.remove(0);
                                }
                                node.now_playing = None;
                            };
                        }
                        handler.track_finish(client.clone(), track_finish).await;
                    }
                    Err(why) => {
                        warn!("Failed to parse TrackEndEvent ({}): {}", why, x);
                    }
                },
                _ => warn!("Unknown event: {}", x),
            },
            _ => warn!("Unknown socket response: {}", x),
        }
    }
}
impl LavalinkClient {
    /// Connects to a Lavalink server and spawns the background event loop.
    ///
    /// `wss://` / `https://` are used when `builder.is_ssl` is set, `ws://` /
    /// `http://` otherwise. The `Authorization`, `Num-Shards` and `User-Id`
    /// headers are sent on the websocket handshake and stored for later REST
    /// requests. Incoming websocket messages are forwarded to `handler` from a
    /// spawned tokio task.
    ///
    /// # Errors
    /// Returns an error if a header value is invalid or the websocket
    /// connection cannot be established.
    ///
    /// # Panics
    /// Panics if `builder.host` and `builder.port` do not form a valid URI.
    pub async fn new(
        builder: &LavalinkClientBuilder,
        handler: impl LavalinkEventHandler + Send + Sync + 'static,
    ) -> LavalinkResult<Self> {
        let (socket_scheme, rest_scheme) = if builder.is_ssl {
            ("wss", "https")
        } else {
            ("ws", "http")
        };
        let socket_uri = format!("{}://{}:{}", socket_scheme, &builder.host, builder.port);
        let rest_uri = format!("{}://{}:{}", rest_scheme, &builder.host, builder.port);

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", builder.password.parse()?);
        headers.insert("Num-Shards", builder.shard_count.to_string().parse()?);
        headers.insert("User-Id", builder.bot_id.to_string().parse()?);

        let mut url_builder = Request::builder();
        *url_builder
            .headers_mut()
            .expect("a freshly created request builder cannot be in an error state") =
            headers.clone();
        // `uri` defers validation; an invalid host/port surfaces here.
        let url = url_builder
            .uri(&socket_uri)
            .body(())
            .expect("host and port should form a valid websocket URI");

        let (ws_stream, _) = connect_async(url).await?;
        let (socket_write, socket_read) = ws_stream.split();

        let client = Self {
            inner: Arc::new(Mutex::new(LavalinkClientInner {
                headers,
                socket_write,
                rest_uri,
                nodes: Arc::new(DashMap::new()),
                loops: Arc::new(DashSet::new()),
            })),
        };

        let client_clone = client.clone();
        tokio::spawn(async move {
            debug!("Starting event loop.");
            event_loop(socket_read, handler, client_clone).await;
            error!("Event loop ended unexpectedly.");
        });

        Ok(client)
    }

    /// Returns a [`LavalinkClientBuilder`] for the given bot user id.
    pub fn builder(user_id: impl Into<UserId>) -> LavalinkClientBuilder {
        LavalinkClientBuilder::new(user_id)
    }

    /// Queries the server's `/loadtracks` REST endpoint with `query` as the
    /// `identifier` parameter (percent-encoded) and returns the decoded result.
    ///
    /// The client lock is held only long enough to copy the base URL and
    /// headers, so a slow HTTP request doesn't block the event loop or other
    /// commands.
    ///
    /// # Errors
    /// Returns an error if the HTTP request fails or the body can't be decoded.
    ///
    /// # Panics
    /// Panics if the stored `rest_uri` is not a valid base URL.
    pub async fn get_tracks(&self, query: impl ToString) -> LavalinkResult<Tracks> {
        let (rest_uri, headers) = {
            let client = self.inner.lock().await;
            (client.rest_uri.clone(), client.headers.clone())
        };

        // `parse_with_params` percent-encodes the query, so only an invalid
        // base URL can make this fail.
        let url = Url::parse_with_params(
            &format!("{}/loadtracks", rest_uri),
            &[("identifier", &query.to_string())],
        )
        .expect("rest_uri should be a valid base URL");

        let resp = ReqwestClient::new()
            .get(url)
            .headers(headers)
            .send()
            .await?
            .json::<Tracks>()
            .await?;
        Ok(resp)
    }

    /// Loads `query` directly if it contains an `http://` / `https://` URL;
    /// otherwise performs a YouTube search (`ytsearch:` prefix).
    pub async fn auto_search_tracks(&self, query: impl ToString) -> LavalinkResult<Tracks> {
        let query = query.to_string();
        // Unanchored: matches if a URL appears anywhere in the query.
        let url_pattern = Regex::new(r"https?://(?:www\.)?.+").unwrap();
        if url_pattern.is_match(&query) {
            self.get_tracks(query).await
        } else {
            self.get_tracks(format!("ytsearch:{}", query)).await
        }
    }

    /// Performs a YouTube search for `query` (`ytsearch:` prefix).
    pub async fn search_tracks(&self, query: impl ToString) -> LavalinkResult<Tracks> {
        self.get_tracks(format!("ytsearch:{}", query.to_string()))
            .await
    }

    /// Sends a `VoiceUpdate` opcode built from songbird's `ConnectionInfo`
    /// (session id, token, endpoint and guild id) for that guild.
    pub async fn create_session(&self, connection_info: &ConnectionInfo) -> LavalinkResult<()> {
        let event = crate::model::Event {
            token: connection_info.token.to_string(),
            endpoint: connection_info.endpoint.to_string(),
            guild_id: connection_info.guild_id.0.to_string(),
        };
        let payload = crate::model::VoiceUpdate {
            session_id: connection_info.session_id.to_string(),
            event,
        };
        let mut client = self.inner.lock().await;
        crate::model::SendOpcode::VoiceUpdate(payload)
            .send(connection_info.guild_id, &mut client.socket_write)
            .await?;
        Ok(())
    }

    /// Builds [`PlayParameters`] for playing `track` in `guild_id`, with
    /// `replace = false`, `start = 0`, `finish = 0` and no requester.
    /// Nothing is sent here; presumably the returned parameters are submitted
    /// by a method defined on `PlayParameters` (not visible in this file).
    pub fn play(&self, guild_id: impl Into<GuildId>, track: Track) -> PlayParameters {
        PlayParameters {
            track,
            guild_id: guild_id.into().0,
            client: self.clone(),
            replace: false,
            start: 0,
            finish: 0,
            requester: None,
        }
    }

    /// Clears the guild's `now_playing`, pops the head of its queue (if any),
    /// and sends a `Destroy` opcode for the guild.
    pub async fn destroy(&self, guild_id: impl Into<GuildId>) -> LavalinkResult<()> {
        let guild_id = guild_id.into();
        let mut client = self.inner.lock().await;
        if let Some(mut node) = client.nodes.get_mut(&guild_id.0) {
            node.now_playing = None;
            if !node.queue.is_empty() {
                node.queue.remove(0);
            }
        }
        crate::model::SendOpcode::Destroy
            .send(guild_id, &mut client.socket_write)
            .await?;
        Ok(())
    }

    /// Sends a `Stop` opcode for the guild. Local queue state is not modified.
    pub async fn stop(&self, guild_id: impl Into<GuildId>) -> LavalinkResult<()> {
        let mut client = self.inner.lock().await;
        crate::model::SendOpcode::Stop
            .send(guild_id, &mut client.socket_write)
            .await?;
        Ok(())
    }

    /// Clears the guild's `now_playing` and pops the head of its queue.
    ///
    /// Returns the removed entry, or `None` if the guild has no node or its
    /// queue is empty. Only local state changes; no opcode is sent.
    pub async fn skip(&self, guild_id: impl Into<GuildId>) -> Option<TrackQueue> {
        let client = self.inner.lock().await;
        let mut node = client.nodes.get_mut(&guild_id.into().0)?;
        node.now_playing = None;
        if node.queue.is_empty() {
            None
        } else {
            Some(node.queue.remove(0))
        }
    }

    /// Sends a `Pause` opcode with the given `pause` flag.
    pub async fn set_pause(&self, guild_id: impl Into<GuildId>, pause: bool) -> LavalinkResult<()> {
        let payload = crate::model::Pause { pause };
        let mut client = self.inner.lock().await;
        crate::model::SendOpcode::Pause(payload)
            .send(guild_id, &mut client.socket_write)
            .await?;
        Ok(())
    }

    /// Pauses playback; shorthand for `set_pause(guild_id, true)`.
    pub async fn pause(&self, guild_id: impl Into<GuildId>) -> LavalinkResult<()> {
        self.set_pause(guild_id, true).await
    }

    /// Resumes playback; shorthand for `set_pause(guild_id, false)`.
    pub async fn resume(&self, guild_id: impl Into<GuildId>) -> LavalinkResult<()> {
        self.set_pause(guild_id, false).await
    }

    /// Sends a `Seek` opcode to position `time`, expressed in milliseconds.
    /// Durations beyond `u64::MAX` milliseconds saturate instead of wrapping.
    pub async fn seek(&self, guild_id: impl Into<GuildId>, time: Duration) -> LavalinkResult<()> {
        let payload = crate::model::Seek {
            position: time.as_millis().min(u64::MAX as u128) as u64,
        };
        let mut client = self.inner.lock().await;
        crate::model::SendOpcode::Seek(payload)
            .send(guild_id, &mut client.socket_write)
            .await?;
        Ok(())
    }

    /// Alias for [`LavalinkClient::seek`].
    pub async fn jump_to_time(
        &self,
        guild_id: impl Into<GuildId>,
        time: Duration,
    ) -> LavalinkResult<()> {
        self.seek(guild_id, time).await
    }

    /// Alias for [`LavalinkClient::seek`].
    pub async fn scrub(&self, guild_id: impl Into<GuildId>, time: Duration) -> LavalinkResult<()> {
        self.seek(guild_id, time).await
    }

    /// Sends a `Volume` opcode; `volume` is clamped to at most 1000.
    pub async fn volume(&self, guild_id: impl Into<GuildId>, volume: u16) -> LavalinkResult<()> {
        // The lower bound is a no-op for `u16`; only the upper clamp matters.
        let good_volume = max(min(volume, 1000), 0);
        let payload = crate::model::Volume {
            volume: good_volume,
        };
        let mut client = self.inner.lock().await;
        crate::model::SendOpcode::Volume(payload)
            .send(guild_id, &mut client.socket_write)
            .await?;
        Ok(())
    }

    /// Sends an `Equalizer` opcode setting all 15 bands; `bands[i]` is the
    /// gain for band `i` (see the `EQ_*` presets).
    pub async fn equalize_all(
        &self,
        guild_id: impl Into<GuildId>,
        bands: [f64; 15],
    ) -> LavalinkResult<()> {
        let bands = bands
            .iter()
            .copied()
            .enumerate()
            .map(|(index, gain)| crate::model::Band {
                band: index as u8,
                gain,
            })
            .collect::<Vec<_>>();
        let payload = crate::model::Equalizer { bands };
        let mut client = self.inner.lock().await;
        crate::model::SendOpcode::Equalizer(payload)
            .send(guild_id, &mut client.socket_write)
            .await?;
        Ok(())
    }

    /// Sends an `Equalizer` opcode that changes a single band.
    pub async fn equalize_band(
        &self,
        guild_id: impl Into<GuildId>,
        band: crate::model::Band,
    ) -> LavalinkResult<()> {
        let payload = crate::model::Equalizer { bands: vec![band] };
        let mut client = self.inner.lock().await;
        crate::model::SendOpcode::Equalizer(payload)
            .send(guild_id, &mut client.socket_write)
            .await?;
        Ok(())
    }

    /// Resets all 15 bands to a gain of `0.0` (the flat `EQ_BASE` preset).
    pub async fn equalize_reset(&self, guild_id: impl Into<GuildId>) -> LavalinkResult<()> {
        self.equalize_all(guild_id, EQ_BASE).await
    }

    /// Returns a shared handle to the per-guild node map.
    pub async fn nodes(&self) -> Arc<DashMap<u64, Node>> {
        let client = self.inner.lock().await;
        client.nodes.clone()
    }

    /// Returns a shared handle to the `loops` guild-id set.
    pub async fn loops(&self) -> Arc<DashSet<u64>> {
        let client = self.inner.lock().await;
        client.loops.clone()
    }
}