lavalink_rs/
node.rs

1use crate::client::LavalinkClient;
2use crate::error::LavalinkError;
3use crate::model::{events, BoxFuture, Secret, UserId};
4
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::Arc;
7
8use arc_swap::ArcSwap;
9use futures::stream::StreamExt;
10#[cfg(feature = "_tungstenite")]
11use http::HeaderMap;
12
13#[cfg(feature = "_tungstenite")]
14use tokio_tungstenite::tungstenite::client::IntoClientRequest;
15#[cfg(feature = "_tungstenite")]
16use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
17#[cfg(feature = "_tungstenite")]
18use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
19
20#[derive(Debug, Clone)]
21#[cfg_attr(not(feature = "python"), derive(Hash, Default))]
22#[cfg_attr(feature = "python", pyo3::pyclass)]
23/// A builder for the node.
24///
25/// # Example
26///
27/// ```
28/// # use crate::model::UserId;
29/// let node_builder = NodeBuilder {
30///     hostname: "localhost:2333".to_string(),
31///     password: "youshallnotpass".to_string(),
32///     user_id: UserId(551759974905151548),
33///     ..Default::default()
34/// };
35/// ```
36pub struct NodeBuilder {
37    /// The hostname of the Lavalink server.
38    ///
39    /// Example: "localhost:2333"
40    pub hostname: String,
41    /// If the Lavalink server is behind SSL encryption.
42    pub is_ssl: bool,
43    /// The event handler specific for this node.
44    ///
45    /// In most cases, the default is good.
46    pub events: events::Events,
47    /// The Lavalink server password.
48    pub password: String,
49    /// The bot User ID that will use Lavalink.
50    pub user_id: UserId,
51    /// The previous Session ID if resuming.
52    pub session_id: Option<String>,
53}
54
55#[derive(Debug)]
56/// A Lavalink server node.
57pub struct Node {
58    pub id: usize,
59    pub session_id: ArcSwap<String>,
60    pub websocket_address: String,
61    pub http: crate::http::Http,
62    pub events: events::Events,
63    pub is_running: AtomicBool,
64    pub(crate) password: Secret,
65    pub user_id: UserId,
66    pub cpu: ArcSwap<crate::model::events::Cpu>,
67    pub memory: ArcSwap<crate::model::events::Memory>,
68}
69
70#[derive(Copy, Clone)]
71struct EventDispatcher<'a>(&'a Node, &'a LavalinkClient);
72
73// Thanks Alba :D
74impl<'a> EventDispatcher<'a> {
75    pub(crate) async fn dispatch<T, F>(self, event: T, handler: F)
76    where
77        F: Fn(&events::Events) -> Option<fn(LavalinkClient, String, &T) -> BoxFuture<()>>,
78    {
79        let EventDispatcher(self_node, lavalink_client) = self;
80        let session_id = self_node.session_id.load_full();
81        let targets = [&self_node.events, &lavalink_client.events].into_iter();
82
83        for handler in targets.filter_map(handler) {
84            handler(lavalink_client.clone(), (*session_id).clone(), &event).await;
85        }
86    }
87
88    #[cfg(not(feature = "python"))]
89    pub(crate) async fn parse_and_dispatch<T: serde::de::DeserializeOwned, F>(
90        self,
91        event: serde_json::Value,
92        handler: F,
93    ) where
94        F: Fn(&events::Events) -> Option<fn(LavalinkClient, String, &T) -> BoxFuture<()>>,
95        T: serde::de::DeserializeOwned,
96    {
97        trace!("{:?}", event);
98        let event = serde_json::from_value(event).unwrap();
99        self.dispatch(event, handler).await
100    }
101}
102
103impl Node {
104    /// Create a connection to the Lavalink server.
105    #[cfg(feature = "_tungstenite")]
106    pub async fn connect(&self, lavalink_client: LavalinkClient) -> Result<(), LavalinkError> {
107        //let mut url = Request::builder()
108        //    .method("GET")
109        //    .header("Host", &self.websocket_address)
110        //    .header("Connection", "Upgrade")
111        //    .header("Upgrade", "websocket")
112        //    .header("Sec-WebSocket-Version", "13")
113        //    .header("Sec-WebSocket-Key", generate_key())
114        //    .uri(&self.websocket_address)
115        //    .body(())?;
116
117        let mut url = self.websocket_address.clone().into_client_request()?;
118
119        {
120            let ref_headers = url.headers_mut();
121
122            let mut headers = HeaderMap::new();
123            headers.insert("Authorization", self.password.0.parse()?);
124            headers.insert("User-Id", self.user_id.0.to_string().parse()?);
125            headers.insert("Session-Id", self.session_id.to_string().parse()?);
126            headers.insert(
127                "Client-Name",
128                format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"),)
129                    .to_string()
130                    .parse()?,
131            );
132
133            ref_headers.extend(headers.clone());
134        }
135
136        let (ws_stream, _) = tokio_tungstenite::connect_async_with_config(
137            url,
138            Some(
139                WebSocketConfig::default()
140                    .max_message_size(None)
141                    .max_frame_size(None),
142            ),
143            false,
144        )
145        .await?;
146
147        info!("Connected to {}", self.websocket_address);
148
149        let (_write, mut read) = ws_stream.split();
150
151        self.is_running.store(true, Ordering::SeqCst);
152
153        let self_node_id = self.id;
154
155        tokio::spawn(async move {
156            while let Some(Ok(resp)) = read.next().await {
157                let x = match resp {
158                    TungsteniteMessage::Text(x) => x,
159                    _ => continue,
160                };
161
162                let base_event = match serde_json::from_str::<serde_json::Value>(&x) {
163                    Ok(base_event) => base_event,
164                    _ => continue,
165                };
166
167                let lavalink_client = lavalink_client.clone();
168
169                tokio::spawn(async move {
170                    Node::handle_event(lavalink_client, self_node_id, base_event).await;
171                });
172            }
173
174            let self_node = lavalink_client.nodes.get(self_node_id).unwrap();
175            self_node.is_running.store(false, Ordering::SeqCst);
176            error!("Connection Closed.");
177        });
178
179        Ok(())
180    }
181
182    /// Create a connection to the Lavalink server.
183    #[cfg(feature = "_websockets")]
184    pub async fn connect(&self, lavalink_client: LavalinkClient) -> Result<(), LavalinkError> {
185        let uri = <::http::Uri as std::str::FromStr>::from_str(&self.websocket_address)?;
186
187        let (client, _) = tokio_websockets::ClientBuilder::from_uri(uri)
188            .add_header(
189                "authorization".try_into().unwrap(),
190                self.password.0.parse()?,
191            )?
192            .add_header(
193                "user-id".try_into().unwrap(),
194                self.user_id.0.to_string().parse()?,
195            )?
196            .add_header(
197                "session-id".try_into().unwrap(),
198                self.session_id.to_string().parse()?,
199            )?
200            .add_header(
201                "client-name".try_into().unwrap(),
202                format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"),)
203                    .to_string()
204                    .parse()?,
205            )?
206            .connect()
207            .await?;
208
209        info!("Connected to {}", self.websocket_address);
210
211        let (_write, mut read) = client.split();
212
213        self.is_running.store(true, Ordering::SeqCst);
214
215        let self_node_id = self.id;
216
217        tokio::spawn(async move {
218            while let Some(Ok(resp)) = read.next().await {
219                let x = match resp.as_text() {
220                    Some(x) => x,
221                    _ => continue,
222                };
223
224                let base_event = match serde_json::from_str::<serde_json::Value>(&x) {
225                    Ok(base_event) => base_event,
226                    _ => continue,
227                };
228
229                let lavalink_client = lavalink_client.clone();
230
231                tokio::spawn(async move {
232                    Node::handle_event(lavalink_client, self_node_id, base_event).await;
233                });
234            }
235
236            let self_node = lavalink_client.nodes.get(self_node_id).unwrap();
237            self_node.is_running.store(false, Ordering::SeqCst);
238            error!("Connection Closed.");
239        });
240
241        Ok(())
242    }
243
244    async fn handle_event(
245        lavalink_client: LavalinkClient,
246        self_node_id: usize,
247        base_event: serde_json::Value,
248    ) {
249        let base_event_clone = base_event.clone();
250        let self_node = lavalink_client.nodes.get(self_node_id).unwrap();
251        let ed = EventDispatcher(self_node, &lavalink_client);
252
253        match base_event.get("op").unwrap().as_str().unwrap() {
254            "ready" => {
255                let ready_event: events::Ready = serde_json::from_value(base_event).unwrap();
256
257                self_node
258                    .session_id
259                    .swap(Arc::new(ready_event.session_id.to_string()));
260
261                #[cfg(feature = "python")]
262                {
263                    let session_id = self_node.session_id.load_full();
264
265                    if let Some(handler) = &self_node.events.event_handler {
266                        handler
267                            .event_ready(
268                                lavalink_client.clone(),
269                                (*session_id).clone(),
270                                ready_event.clone(),
271                            )
272                            .await;
273                    }
274                    if let Some(handler) = &lavalink_client.events.event_handler {
275                        handler
276                            .event_ready(
277                                lavalink_client.clone(),
278                                (*session_id).clone(),
279                                ready_event.clone(),
280                            )
281                            .await;
282                    }
283                }
284
285                ed.dispatch(ready_event, |e| e.ready).await;
286            }
287            "playerUpdate" => {
288                let player_update_event: events::PlayerUpdate =
289                    serde_json::from_value(base_event).unwrap();
290
291                if let Some(player) =
292                    lavalink_client.get_player_context(player_update_event.guild_id)
293                {
294                    if let Err(why) = player.update_state(player_update_event.state.clone()) {
295                        error!(
296                            "Error updating state for player {}: {}",
297                            player_update_event.guild_id.0, why
298                        );
299                    }
300                }
301
302                #[cfg(feature = "python")]
303                {
304                    let session_id = self_node.session_id.load_full();
305
306                    if let Some(handler) = &self_node.events.event_handler {
307                        handler
308                            .event_player_update(
309                                lavalink_client.clone(),
310                                (*session_id).clone(),
311                                player_update_event.clone(),
312                            )
313                            .await;
314                    }
315                    if let Some(handler) = &lavalink_client.events.event_handler {
316                        handler
317                            .event_player_update(
318                                lavalink_client.clone(),
319                                (*session_id).clone(),
320                                player_update_event.clone(),
321                            )
322                            .await;
323                    }
324                }
325
326                ed.dispatch(player_update_event, |e| e.player_update).await;
327            }
328            "stats" => {
329                #[cfg(feature = "python")]
330                {
331                    let event: events::Stats = serde_json::from_value(base_event).unwrap();
332                    let session_id = self_node.session_id.load_full();
333
334                    self_node.cpu.store(Arc::new(event.cpu.clone()));
335                    self_node.memory.store(Arc::new(event.memory.clone()));
336
337                    if let Some(handler) = &self_node.events.event_handler {
338                        handler
339                            .event_stats(
340                                lavalink_client.clone(),
341                                (*session_id).clone(),
342                                event.clone(),
343                            )
344                            .await;
345                    }
346                    if let Some(handler) = &lavalink_client.events.event_handler {
347                        handler
348                            .event_stats(
349                                lavalink_client.clone(),
350                                (*session_id).clone(),
351                                event.clone(),
352                            )
353                            .await;
354                    }
355
356                    ed.dispatch(event, |e| e.stats).await;
357                }
358                #[cfg(not(feature = "python"))]
359                ed.parse_and_dispatch(base_event, |e| e.stats).await;
360            }
361            "event" => match base_event.get("type").unwrap().as_str().unwrap() {
362                "TrackStartEvent" => {
363                    let track_event: events::TrackStart =
364                        serde_json::from_value(base_event).unwrap();
365
366                    if let Some(player) = lavalink_client.get_player_context(track_event.guild_id) {
367                        if let Err(why) = player.update_track(track_event.track.clone().into()) {
368                            error!(
369                                "Error sending update track message for player {}: {}",
370                                track_event.guild_id.0, why
371                            );
372                        }
373                    }
374
375                    #[cfg(feature = "python")]
376                    {
377                        let session_id = self_node.session_id.load_full();
378
379                        if let Some(handler) = &self_node.events.event_handler {
380                            handler
381                                .event_track_start(
382                                    lavalink_client.clone(),
383                                    (*session_id).clone(),
384                                    track_event.clone(),
385                                )
386                                .await;
387                        }
388                        if let Some(handler) = &lavalink_client.events.event_handler {
389                            handler
390                                .event_track_start(
391                                    lavalink_client.clone(),
392                                    (*session_id).clone(),
393                                    track_event.clone(),
394                                )
395                                .await;
396                        }
397                    }
398
399                    ed.dispatch(track_event, |e| e.track_start).await;
400                }
401                "TrackEndEvent" => {
402                    let track_event: events::TrackEnd = serde_json::from_value(base_event).unwrap();
403
404                    if let Some(player) = lavalink_client.get_player_context(track_event.guild_id) {
405                        if let Err(why) = player.finish(track_event.reason.clone().into()) {
406                            error!(
407                                "Error sending finish message for player {}: {}",
408                                track_event.guild_id.0, why
409                            );
410                        }
411
412                        if let Err(why) = player.update_track(None) {
413                            error!(
414                                "Error sending update track message for player {}: {}",
415                                track_event.guild_id.0, why
416                            );
417                        }
418                    }
419
420                    #[cfg(feature = "python")]
421                    {
422                        let session_id = self_node.session_id.load_full();
423
424                        if let Some(handler) = &self_node.events.event_handler {
425                            handler
426                                .event_track_end(
427                                    lavalink_client.clone(),
428                                    (*session_id).clone(),
429                                    track_event.clone(),
430                                )
431                                .await;
432                        }
433                        if let Some(handler) = &lavalink_client.events.event_handler {
434                            handler
435                                .event_track_end(
436                                    lavalink_client.clone(),
437                                    (*session_id).clone(),
438                                    track_event.clone(),
439                                )
440                                .await;
441                        }
442                    }
443
444                    ed.dispatch(track_event, |e| e.track_end).await;
445                }
446                "TrackExceptionEvent" => {
447                    #[cfg(feature = "python")]
448                    {
449                        let event: events::TrackException =
450                            serde_json::from_value(base_event).unwrap();
451                        let session_id = self_node.session_id.load_full();
452
453                        if let Some(handler) = &self_node.events.event_handler {
454                            handler
455                                .event_track_exception(
456                                    lavalink_client.clone(),
457                                    (*session_id).clone(),
458                                    event.clone(),
459                                )
460                                .await;
461                        }
462                        if let Some(handler) = &lavalink_client.events.event_handler {
463                            handler
464                                .event_track_exception(
465                                    lavalink_client.clone(),
466                                    (*session_id).clone(),
467                                    event.clone(),
468                                )
469                                .await;
470                        }
471
472                        ed.dispatch(event, |e| e.track_exception).await;
473                    }
474                    #[cfg(not(feature = "python"))]
475                    ed.parse_and_dispatch(base_event, |e| e.track_exception)
476                        .await;
477                }
478                "TrackStuckEvent" => {
479                    #[cfg(feature = "python")]
480                    {
481                        let event: events::TrackStuck = serde_json::from_value(base_event).unwrap();
482                        let session_id = self_node.session_id.load_full();
483
484                        if let Some(handler) = &self_node.events.event_handler {
485                            handler
486                                .event_track_stuck(
487                                    lavalink_client.clone(),
488                                    (*session_id).clone(),
489                                    event.clone(),
490                                )
491                                .await;
492                        }
493                        if let Some(handler) = &lavalink_client.events.event_handler {
494                            handler
495                                .event_track_stuck(
496                                    lavalink_client.clone(),
497                                    (*session_id).clone(),
498                                    event.clone(),
499                                )
500                                .await;
501                        }
502
503                        ed.dispatch(event, |e| e.track_stuck).await;
504                    }
505                    #[cfg(not(feature = "python"))]
506                    ed.parse_and_dispatch(base_event, |e| e.track_stuck).await;
507                }
508                "WebSocketClosedEvent" => {
509                    #[cfg(feature = "python")]
510                    {
511                        let event: events::WebSocketClosed =
512                            serde_json::from_value(base_event).unwrap();
513                        let session_id = self_node.session_id.load_full();
514
515                        if let Some(handler) = &self_node.events.event_handler {
516                            handler
517                                .event_websocket_closed(
518                                    lavalink_client.clone(),
519                                    (*session_id).clone(),
520                                    event.clone(),
521                                )
522                                .await;
523                        }
524                        if let Some(handler) = &lavalink_client.events.event_handler {
525                            handler
526                                .event_websocket_closed(
527                                    lavalink_client.clone(),
528                                    (*session_id).clone(),
529                                    event.clone(),
530                                )
531                                .await;
532                        }
533
534                        ed.dispatch(event, |e| e.websocket_closed).await;
535                    }
536                    #[cfg(not(feature = "python"))]
537                    ed.parse_and_dispatch(base_event, |e| e.websocket_closed)
538                        .await;
539                }
540                _ => (),
541            },
542
543            _ => (),
544        }
545
546        ed.dispatch(base_event_clone, |e| e.raw).await;
547    }
548}