1use crate::error::{LavalinkError, LavalinkResult};
2use crate::model::*;
3use crate::node;
4use crate::player_context::*;
5
6use std::collections::VecDeque;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9
10use ::http::header::HeaderMap;
11use arc_swap::{ArcSwap, ArcSwapOption};
12use dashmap::DashMap;
13use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
14use tokio::sync::Mutex;
15
16#[derive(Clone)]
17#[cfg_attr(feature = "python", pyo3::pyclass)]
18pub struct LavalinkClient {
20 pub nodes: Vec<Arc<node::Node>>,
21 pub players: Arc<DashMap<GuildId, (ArcSwapOption<PlayerContext>, Arc<node::Node>)>>,
22 pub events: Arc<events::Events>,
23 tx: UnboundedSender<client::ClientMessage>,
24 user_id: UserId,
25 user_data: Arc<dyn std::any::Any + Send + Sync>,
26 strategy: client::NodeDistributionStrategy,
27}
28
29impl LavalinkClient {
30 pub async fn new(
38 events: events::Events,
39 nodes: Vec<node::NodeBuilder>,
40 strategy: client::NodeDistributionStrategy,
41 ) -> LavalinkClient {
42 Self::new_with_data(events, nodes, strategy, Arc::new(())).await
43 }
44
45 pub async fn new_with_data<Data: std::any::Any + Send + Sync>(
54 events: events::Events,
55 nodes: Vec<node::NodeBuilder>,
56 strategy: client::NodeDistributionStrategy,
57 user_data: Arc<Data>,
58 ) -> LavalinkClient {
59 if nodes.is_empty() {
60 panic!("At least one node must be provided.");
61 }
62
63 let mut built_nodes = Vec::new();
64
65 for (idx, i) in nodes.into_iter().enumerate() {
66 let mut headers = HeaderMap::new();
67 headers.insert("Authorization", i.password.parse().unwrap());
68 headers.insert("User-Id", i.user_id.0.to_string().parse().unwrap());
69 headers.insert("Connection", "keep-alive".parse().unwrap());
70
71 if let Some(session_id) = &i.session_id {
72 headers.insert("Session-Id", session_id.parse().unwrap());
73 }
74
75 headers.insert(
76 "Client-Name",
77 format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"))
78 .to_string()
79 .parse()
80 .unwrap(),
81 );
82
83 #[cfg(feature = "_rustls-webpki-roots")]
84 let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
85 .with_webpki_roots()
86 .https_or_http()
87 .enable_all_versions()
88 .build();
89 #[cfg(feature = "_rustls-native-roots")]
90 let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
91 .with_native_roots()
92 .expect("no native root CA certificates found")
93 .https_or_http()
94 .enable_all_versions()
95 .build();
96 #[cfg(feature = "_native-tls")]
97 let https_connector = hyper_tls::HttpsConnector::new();
98
99 let request_client =
100 hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
101 .pool_idle_timeout(std::time::Duration::from_secs(60))
102 .pool_timer(hyper_util::rt::TokioTimer::new())
103 .build(https_connector);
104
105 let node = if i.is_ssl {
106 let http = crate::http::Http {
107 authority: i.hostname.clone(),
108 rest_address: format!("https://{}/v4", i.hostname),
109 rest_address_versionless: format!("https://{}", i.hostname),
110 headers,
111 request_client: request_client.into(),
112 };
113
114 node::Node {
115 id: idx,
116 websocket_address: format!("wss://{}/v4/websocket", i.hostname),
117 http,
118 events: i.events,
119 password: Secret(i.password.into()),
120 user_id: i.user_id,
121 is_running: AtomicBool::new(false),
122 session_id: ArcSwap::new(if let Some(id) = i.session_id {
123 id.into()
124 } else {
125 idx.to_string().into()
126 }),
127 cpu: ArcSwap::new(Default::default()),
128 memory: ArcSwap::new(Default::default()),
129 }
130 } else {
131 let http = crate::http::Http {
132 authority: i.hostname.clone(),
133 rest_address: format!("http://{}/v4", i.hostname),
134 rest_address_versionless: format!("http://{}", i.hostname),
135 headers,
136 request_client: request_client.into(),
137 };
138
139 node::Node {
140 id: idx,
141 websocket_address: format!("ws://{}/v4/websocket", i.hostname),
142 http,
143 events: i.events,
144 password: Secret(i.password.into()),
145 user_id: i.user_id,
146 is_running: AtomicBool::new(false),
147 session_id: ArcSwap::new(if let Some(id) = i.session_id {
148 id.into()
149 } else {
150 idx.to_string().into()
151 }),
152 cpu: ArcSwap::new(Default::default()),
153 memory: ArcSwap::new(Default::default()),
154 }
155 };
156
157 let node_arc = Arc::new(node);
158
159 built_nodes.push(node_arc);
160 }
161
162 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
163
164 let client = LavalinkClient {
165 user_id: built_nodes[0].user_id,
166 nodes: built_nodes,
167 players: Arc::new(DashMap::new()),
168 events: Arc::new(events),
169 tx,
170 user_data,
171 strategy,
172 };
173
174 for node in &*client.nodes {
175 if let Err(why) = node.connect(client.clone()).await {
176 error!("Failed to connect to the lavalink websocket: {}", why);
177 }
178 }
179
180 tokio::spawn(LavalinkClient::handle_connection_info(client.clone(), rx));
181
182 let lavalink_client = client.clone();
183
184 tokio::spawn(async move {
185 loop {
186 tokio::time::sleep(std::time::Duration::from_secs(15)).await;
187
188 for node in &*lavalink_client.nodes {
189 if !node.is_running.load(Ordering::SeqCst) {
190 if let Err(why) = node.connect(lavalink_client.clone()).await {
191 error!("Failed to connect to the lavalink websocket: {}", why);
192 }
193 }
194 }
195 }
196 });
197
198 client
199 }
200
201 pub fn get_node_by_index(&self, idx: usize) -> Option<Arc<node::Node>> {
203 self.nodes.get(idx).cloned()
204 }
205
206 pub async fn get_node_for_guild(&self, guild_id: impl Into<GuildId>) -> Arc<node::Node> {
208 let guild_id = guild_id.into();
209
210 if let Some(node) = self.players.get(&guild_id) {
211 trace!("Node already selected for guild {:?}", guild_id);
212 return node.1.clone();
213 }
214
215 debug!("First time selecting node for guild {:?}", guild_id);
216
217 use client::NodeDistributionStrategy::*;
218
219 match &self.strategy {
220 Sharded => self
221 .get_node_by_index(guild_id.0 as usize % self.nodes.len())
222 .unwrap(),
223 RoundRobin(x) => {
224 let mut idx = x.fetch_add(1, Ordering::SeqCst);
225 if idx == self.nodes.len() {
226 x.store(1, Ordering::SeqCst);
227 idx = 0;
228 }
229
230 self.get_node_by_index(idx).unwrap()
231 }
232 MainFallback => {
233 for node in &*self.nodes {
234 if node.is_running.load(Ordering::SeqCst) {
235 return node.clone();
236 }
237 }
238
239 warn!("No nodes are currently running, waiting 5 seconds and trying again...");
240 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
241
242 for node in &*self.nodes {
243 if node.is_running.load(Ordering::SeqCst) {
244 return node.clone();
245 }
246 }
247
248 warn!("No nodes are currently running, returning first node.");
249
250 self.get_node_by_index(0).unwrap()
251 }
252 LowestLoad => self
253 .nodes
254 .iter()
255 .min_by_key(|x| x.cpu.load().system_load.abs() as u8)
256 .unwrap()
257 .clone(),
258 HighestFreeMemory => self
259 .nodes
260 .iter()
261 .min_by_key(|x| x.memory.load().free)
262 .unwrap()
263 .clone(),
264 Custom(func) => func(self, guild_id).await,
265 #[cfg(feature = "python")]
266 CustomPython(func) => {
267 use pyo3::prelude::*;
268 let client = self.clone();
269 let (tx, rx) = oneshot::channel();
270
271 Python::with_gil(|py| {
272 let current_loop = pyo3_async_runtimes::tokio::get_current_loop(py).unwrap();
273 let func = func.clone_ref(py);
274
275 let client = client.clone();
276 let client2 = client.clone();
277
278 pyo3_async_runtimes::tokio::future_into_py_with_locals(
279 py,
280 pyo3_async_runtimes::TaskLocals::new(current_loop),
281 async move {
282 let future = Python::with_gil(|py| {
283 let coro = func
284 .call(
285 py,
286 (
287 client.into_pyobject(py).unwrap(),
288 guild_id.into_pyobject(py).unwrap(),
289 ),
290 None,
291 )
292 .unwrap();
293
294 pyo3_async_runtimes::tokio::into_future(coro.into_bound(py))
295 })
296 .unwrap();
297
298 match future.await {
299 Err(e) => {
300 Python::with_gil(|py| {
301 e.print_and_set_sys_last_vars(py);
302 });
303 let _ = tx.send(crate::python::node::Node {
304 inner: client2.get_node_by_index(0).unwrap().clone(),
305 });
306 }
307 Ok(x) => {
308 let _ = tx.send(Python::with_gil(|py| {
309 x.extract::<crate::python::node::Node>(py).unwrap()
310 }));
311 }
312 }
313
314 Ok(())
315 },
316 )
317 .unwrap();
318 });
319
320 rx.await.unwrap().inner
321 }
322 }
323 }
324
325 pub fn get_player_context(&self, guild_id: impl Into<GuildId>) -> Option<PlayerContext> {
327 let guild_id = guild_id.into();
328
329 if let Some(x) = self.players.get(&guild_id) {
330 x.0.load().clone().map(|x| (*x).clone())
331 } else {
332 None
333 }
334 }
335
336 pub async fn create_player(
340 &self,
341 guild_id: impl Into<GuildId>,
342 connection_info: impl Into<player::ConnectionInfo>,
343 ) -> LavalinkResult<player::Player> {
344 let guild_id = guild_id.into();
345 let mut connection_info = connection_info.into();
346 connection_info.fix();
347
348 let node = self.get_node_for_guild(guild_id).await;
349
350 let player = node
351 .http
352 .update_player(
353 guild_id,
354 &node.session_id.load(),
355 &http::UpdatePlayer {
356 voice: Some(connection_info.clone()),
357 ..Default::default()
358 },
359 true,
360 )
361 .await?;
362
363 self.players
364 .entry(guild_id)
365 .or_insert((ArcSwapOption::new(None), node));
366
367 Ok(player)
368 }
369
370 pub async fn create_player_context(
374 &self,
375 guild_id: impl Into<GuildId>,
376 connection_info: impl Into<player::ConnectionInfo>,
377 ) -> LavalinkResult<PlayerContext> {
378 self.create_player_context_with_data(guild_id, connection_info, Arc::new(()))
379 .await
380 }
381
382 pub async fn create_player_context_with_data<Data: std::any::Any + Send + Sync>(
386 &self,
387 guild_id: impl Into<GuildId>,
388 connection_info: impl Into<player::ConnectionInfo>,
389 user_data: Arc<Data>,
390 ) -> LavalinkResult<PlayerContext> {
391 let guild_id = guild_id.into();
392 let mut connection_info = connection_info.into();
393 connection_info.fix();
394
395 let node = self.get_node_for_guild(guild_id).await;
396
397 if let Some(x) = self.players.get(&guild_id) {
398 if let Some(x) = &*x.0.load() {
399 return Ok((**x).clone());
400 }
401 }
402
403 let player = node
404 .http
405 .update_player(
406 guild_id,
407 &node.session_id.load(),
408 &http::UpdatePlayer {
409 voice: Some(connection_info.clone()),
410 ..Default::default()
411 },
412 true,
413 )
414 .await?;
415
416 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
417
418 let player_dummy = PlayerContext {
419 guild_id,
420 client: self.clone(),
421 tx,
422 user_data,
423 };
424
425 let player_context = PlayerContextInner {
426 guild_id,
427 queue: VecDeque::new(),
428 player_data: player,
429 dummy: player_dummy.clone(),
430 last_should_continue: true,
431 };
432
433 player_context.start(rx).await;
434
435 self.players.insert(
436 guild_id,
437 (ArcSwapOption::new(Some(player_dummy.clone().into())), node),
438 );
439
440 Ok(player_dummy)
441 }
442
443 pub async fn delete_player(&self, guild_id: impl Into<GuildId>) -> LavalinkResult<()> {
445 let guild_id = guild_id.into();
446 let node = self.get_node_for_guild(guild_id).await;
447
448 if let Some((_, (player, _))) = self.players.remove(&guild_id) {
449 if let Some(x) = &*player.load() {
450 (**x).clone().close()?;
451 }
452 }
453
454 node.http
455 .delete_player(guild_id, &node.session_id.load())
456 .await?;
457
458 Ok(())
459 }
460
461 pub async fn delete_all_player_contexts(&self) -> LavalinkResult<()> {
466 for guild_id in self
467 .players
468 .iter()
469 .filter_map(|i| i.0.load().clone().map(|x| x.guild_id))
470 .collect::<Vec<_>>()
471 {
472 self.delete_player(guild_id).await?;
473 }
474
475 Ok(())
476 }
477
478 pub async fn update_player(
480 &self,
481 guild_id: impl Into<GuildId>,
482 update_player: &http::UpdatePlayer,
483 no_replace: bool,
484 ) -> LavalinkResult<player::Player> {
485 let guild_id = guild_id.into();
486 let node = self.get_node_for_guild(guild_id).await;
487
488 let result = node
489 .http
490 .update_player(guild_id, &node.session_id.load(), update_player, no_replace)
491 .await?;
492
493 if let Some(player) = self.get_player_context(guild_id) {
494 player.update_player_data(result.clone())?;
495 }
496
497 Ok(result)
498 }
499
500 pub async fn load_tracks(
509 &self,
510 guild_id: impl Into<GuildId>,
511 identifier: &str,
512 ) -> LavalinkResult<track::Track> {
513 let guild_id = guild_id.into();
514 let node = self.get_node_for_guild(guild_id).await;
515
516 let result = node.http.load_tracks(identifier).await?;
517
518 Ok(result)
519 }
520
521 pub async fn decode_track(
527 &self,
528 guild_id: impl Into<GuildId>,
529 track: &str,
530 ) -> LavalinkResult<track::TrackData> {
531 let guild_id = guild_id.into();
532 let node = self.get_node_for_guild(guild_id).await;
533
534 let result = node.http.decode_track(track).await?;
535
536 Ok(result)
537 }
538
539 pub async fn decode_tracks(
545 &self,
546 guild_id: impl Into<GuildId>,
547 tracks: &[String],
548 ) -> LavalinkResult<Vec<track::TrackData>> {
549 let guild_id = guild_id.into();
550 let node = self.get_node_for_guild(guild_id).await;
551
552 let result = node.http.decode_tracks(tracks).await?;
553
554 Ok(result)
555 }
556
557 pub async fn request_version(&self, guild_id: impl Into<GuildId>) -> LavalinkResult<String> {
559 let guild_id = guild_id.into();
560 let node = self.get_node_for_guild(guild_id).await;
561
562 let result = node.http.version().await?;
563
564 Ok(result)
565 }
566
567 pub async fn request_stats(
571 &self,
572 guild_id: impl Into<GuildId>,
573 ) -> LavalinkResult<events::Stats> {
574 let guild_id = guild_id.into();
575 let node = self.get_node_for_guild(guild_id).await;
576
577 let result = node.http.stats().await?;
578
579 Ok(result)
580 }
581
582 pub async fn request_info(&self, guild_id: impl Into<GuildId>) -> LavalinkResult<http::Info> {
584 let guild_id = guild_id.into();
585 let node = self.get_node_for_guild(guild_id).await;
586
587 let result = node.http.info().await?;
588
589 Ok(result)
590 }
591
592 pub async fn request_player(
594 &self,
595 guild_id: impl Into<GuildId>,
596 ) -> LavalinkResult<player::Player> {
597 let guild_id = guild_id.into();
598 let node = self.get_node_for_guild(guild_id).await;
599
600 let result = node
601 .http
602 .get_player(guild_id, &node.session_id.load())
603 .await?;
604
605 Ok(result)
606 }
607
608 pub async fn request_all_players(
610 &self,
611 guild_id: impl Into<GuildId>,
612 ) -> LavalinkResult<Vec<player::Player>> {
613 let guild_id = guild_id.into();
614 let node = self.get_node_for_guild(guild_id).await;
615
616 let result = node.http.get_players(&node.session_id.load()).await?;
617
618 Ok(result)
619 }
620
621 pub fn data<Data: Send + Sync + 'static>(&self) -> LavalinkResult<std::sync::Arc<Data>> {
627 self.user_data
628 .clone()
629 .downcast()
630 .map_err(|_| LavalinkError::InvalidDataType)
631 }
632
633 pub fn handle_voice_server_update(
635 &self,
636 guild_id: impl Into<GuildId>,
637 token: String,
638 endpoint: Option<String>,
639 ) {
640 let _ = self.tx.send(client::ClientMessage::ServerUpdate(
641 guild_id.into(),
642 token,
643 endpoint,
644 ));
645 }
646
647 pub fn handle_voice_state_update(
649 &self,
650 guild_id: impl Into<GuildId>,
651 channel_id: Option<impl Into<ChannelId>>,
652 user_id: impl Into<UserId>,
653 session_id: String,
654 ) {
655 let _ = self.tx.send(client::ClientMessage::StateUpdate(
656 guild_id.into(),
657 channel_id.map(|x| x.into()),
658 user_id.into(),
659 session_id,
660 ));
661 }
662
663 pub async fn get_connection_info(
677 &self,
678 guild_id: impl Into<GuildId>,
679 timeout: std::time::Duration,
680 ) -> LavalinkResult<player::ConnectionInfo> {
681 let (tx, rx) = oneshot::channel();
682
683 let _ = self.tx.send(client::ClientMessage::GetConnectionInfo(
684 guild_id.into(),
685 timeout,
686 tx,
687 ));
688
689 rx.await?.map_err(|_| LavalinkError::Timeout)
690 }
691
692 async fn handle_connection_info(self, mut rx: UnboundedReceiver<client::ClientMessage>) {
693 let data: Arc<DashMap<GuildId, (Option<String>, Option<String>, Option<String>)>> =
694 Arc::new(DashMap::new());
695 let channels: Arc<
696 DashMap<GuildId, (UnboundedSender<()>, Arc<Mutex<UnboundedReceiver<()>>>)>,
697 > = Arc::new(DashMap::new());
698
699 while let Some(x) = rx.recv().await {
700 use client::ClientMessage::*;
701
702 match x {
703 GetConnectionInfo(guild_id, timeout, sender) => {
704 let data = data.clone();
705 let channels = channels.clone();
706
707 tokio::spawn(async move {
708 trace!("Requested connection information for guild {:?}", guild_id);
709
710 {
711 channels.entry(guild_id).or_insert({
712 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
713 (tx, Arc::new(Mutex::new(rx)))
714 });
715 }
716
717 let inner_lock = channels.get(&guild_id).unwrap().1.clone();
718 let mut inner_rx = inner_lock.lock().await;
719
720 trace!("Waiting for events in guild {:?}", guild_id);
721
722 loop {
723 match tokio::time::timeout(timeout, inner_rx.recv()).await {
724 Err(x) => {
725 if let Some((Some(token), Some(endpoint), Some(session_id))) =
726 data.get(&guild_id).map(|x| x.value().clone())
727 {
728 trace!(
729 "Connection information requested in {:?} but no changes since the previous request were received.",
730 guild_id
731 );
732
733 let _ = sender.send(Ok(player::ConnectionInfo {
734 token: token.to_string(),
735 endpoint: endpoint.to_string(),
736 session_id: session_id.to_string(),
737 }));
738 return;
739 }
740
741 trace!("Timeout reached in guild {:?}", guild_id);
742
743 let _ = sender.send(Err(x));
744 return;
745 }
746 Ok(x) => {
747 if x.is_none() {
748 trace!("Connection removed in guild {:?}", guild_id);
749 return;
750 };
751
752 trace!("Event received in guild {:?}", guild_id);
753
754 if let Some((Some(token), Some(endpoint), Some(session_id))) =
755 data.get(&guild_id).map(|x| x.value().clone())
756 {
757 trace!(
758 "Both events have been received in guild {:?}",
759 guild_id
760 );
761
762 let _ = sender.send(Ok(player::ConnectionInfo {
763 token: token.to_string(),
764 endpoint: endpoint.to_string(),
765 session_id: session_id.to_string(),
766 }));
767 return;
768 }
769 }
770 }
771 }
772 });
773 }
774 ServerUpdate(guild_id, token, endpoint) => {
775 trace!(
776 "Started handling ServerUpdate event for guild {:?}",
777 guild_id
778 );
779
780 {
781 channels.entry(guild_id).or_insert({
782 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
783 (tx, Arc::new(Mutex::new(rx)))
784 });
785 }
786
787 let mut entry = data.entry(guild_id).or_insert((None, None, None));
788 let session_id = entry.value().2.clone();
789 *entry.value_mut() = (Some(token), endpoint, session_id);
790
791 {
792 let inner_tx = &channels.get(&guild_id).unwrap().0;
793 let _ = inner_tx.send(());
794 }
795
796 trace!(
797 "Finished handling ServerUpdate event for guild {:?}",
798 guild_id
799 );
800 }
801 StateUpdate(guild_id, channel_id, user_id, session_id) => {
802 if user_id != self.user_id {
803 continue;
804 }
805
806 trace!(
807 "Started handling StateUpdate event for guild {:?}",
808 guild_id
809 );
810
811 {
812 channels.entry(guild_id).or_insert({
813 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
814 (tx, Arc::new(Mutex::new(rx)))
815 });
816 }
817
818 if channel_id.is_none() {
819 trace!("Bot disconnected from voice in the guild {:?}", guild_id);
820 data.remove(&guild_id);
821 channels.remove(&guild_id);
822 continue;
823 }
824
825 let mut entry = data.entry(guild_id).or_insert((None, None, None));
826 let token = entry.value().0.clone();
827 let endpoint = entry.value().1.clone();
828 *entry.value_mut() = (token, endpoint, Some(session_id));
829
830 {
831 let inner_tx = &channels.get(&guild_id).unwrap().0;
832 let _ = inner_tx.send(());
833 }
834
835 trace!(
836 "Finished handling StateUpdate event for guild {:?}",
837 guild_id
838 );
839 }
840 }
841 }
842 }
843}