1use crate::client::LavalinkClient;
2use crate::error::LavalinkError;
3use crate::model::{events, BoxFuture, Secret, UserId};
4
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::Arc;
7
8use arc_swap::ArcSwap;
9use futures::stream::StreamExt;
10#[cfg(feature = "_tungstenite")]
11use http::HeaderMap;
12
13#[cfg(feature = "_tungstenite")]
14use tokio_tungstenite::tungstenite::client::IntoClientRequest;
15#[cfg(feature = "_tungstenite")]
16use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
17#[cfg(feature = "_tungstenite")]
18use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
19
20#[derive(Debug, Clone)]
21#[cfg_attr(not(feature = "python"), derive(Hash, Default))]
22#[cfg_attr(feature = "python", pyo3::pyclass)]
23pub struct NodeBuilder {
37 pub hostname: String,
41 pub is_ssl: bool,
43 pub events: events::Events,
47 pub password: String,
49 pub user_id: UserId,
51 pub session_id: Option<String>,
53}
54
55#[derive(Debug)]
56pub struct Node {
58 pub id: usize,
59 pub session_id: ArcSwap<String>,
60 pub websocket_address: String,
61 pub http: crate::http::Http,
62 pub events: events::Events,
63 pub is_running: AtomicBool,
64 pub(crate) password: Secret,
65 pub user_id: UserId,
66 pub cpu: ArcSwap<crate::model::events::Cpu>,
67 pub memory: ArcSwap<crate::model::events::Memory>,
68}
69
70#[derive(Copy, Clone)]
71struct EventDispatcher<'a>(&'a Node, &'a LavalinkClient);
72
73impl<'a> EventDispatcher<'a> {
75 pub(crate) async fn dispatch<T, F>(self, event: T, handler: F)
76 where
77 F: Fn(&events::Events) -> Option<fn(LavalinkClient, String, &T) -> BoxFuture<()>>,
78 {
79 let EventDispatcher(self_node, lavalink_client) = self;
80 let session_id = self_node.session_id.load_full();
81 let targets = [&self_node.events, &lavalink_client.events].into_iter();
82
83 for handler in targets.filter_map(handler) {
84 handler(lavalink_client.clone(), (*session_id).clone(), &event).await;
85 }
86 }
87
88 #[cfg(not(feature = "python"))]
89 pub(crate) async fn parse_and_dispatch<T: serde::de::DeserializeOwned, F>(
90 self,
91 event: serde_json::Value,
92 handler: F,
93 ) where
94 F: Fn(&events::Events) -> Option<fn(LavalinkClient, String, &T) -> BoxFuture<()>>,
95 T: serde::de::DeserializeOwned,
96 {
97 trace!("{:?}", event);
98 let event = serde_json::from_value(event).unwrap();
99 self.dispatch(event, handler).await
100 }
101}
102
103impl Node {
104 #[cfg(feature = "_tungstenite")]
106 pub async fn connect(&self, lavalink_client: LavalinkClient) -> Result<(), LavalinkError> {
107 let mut url = self.websocket_address.clone().into_client_request()?;
118
119 {
120 let ref_headers = url.headers_mut();
121
122 let mut headers = HeaderMap::new();
123 headers.insert("Authorization", self.password.0.parse()?);
124 headers.insert("User-Id", self.user_id.0.to_string().parse()?);
125 headers.insert("Session-Id", self.session_id.to_string().parse()?);
126 headers.insert(
127 "Client-Name",
128 format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"),)
129 .to_string()
130 .parse()?,
131 );
132
133 ref_headers.extend(headers.clone());
134 }
135
136 let (ws_stream, _) = tokio_tungstenite::connect_async_with_config(
137 url,
138 Some(
139 WebSocketConfig::default()
140 .max_message_size(None)
141 .max_frame_size(None),
142 ),
143 false,
144 )
145 .await?;
146
147 info!("Connected to {}", self.websocket_address);
148
149 let (_write, mut read) = ws_stream.split();
150
151 self.is_running.store(true, Ordering::SeqCst);
152
153 let self_node_id = self.id;
154
155 tokio::spawn(async move {
156 while let Some(Ok(resp)) = read.next().await {
157 let x = match resp {
158 TungsteniteMessage::Text(x) => x,
159 _ => continue,
160 };
161
162 let base_event = match serde_json::from_str::<serde_json::Value>(&x) {
163 Ok(base_event) => base_event,
164 _ => continue,
165 };
166
167 let lavalink_client = lavalink_client.clone();
168
169 tokio::spawn(async move {
170 Node::handle_event(lavalink_client, self_node_id, base_event).await;
171 });
172 }
173
174 let self_node = lavalink_client.nodes.get(self_node_id).unwrap();
175 self_node.is_running.store(false, Ordering::SeqCst);
176 error!("Connection Closed.");
177 });
178
179 Ok(())
180 }
181
182 #[cfg(feature = "_websockets")]
184 pub async fn connect(&self, lavalink_client: LavalinkClient) -> Result<(), LavalinkError> {
185 let uri = <::http::Uri as std::str::FromStr>::from_str(&self.websocket_address)?;
186
187 let (client, _) = tokio_websockets::ClientBuilder::from_uri(uri)
188 .add_header(
189 "authorization".try_into().unwrap(),
190 self.password.0.parse()?,
191 )?
192 .add_header(
193 "user-id".try_into().unwrap(),
194 self.user_id.0.to_string().parse()?,
195 )?
196 .add_header(
197 "session-id".try_into().unwrap(),
198 self.session_id.to_string().parse()?,
199 )?
200 .add_header(
201 "client-name".try_into().unwrap(),
202 format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"),)
203 .to_string()
204 .parse()?,
205 )?
206 .connect()
207 .await?;
208
209 info!("Connected to {}", self.websocket_address);
210
211 let (_write, mut read) = client.split();
212
213 self.is_running.store(true, Ordering::SeqCst);
214
215 let self_node_id = self.id;
216
217 tokio::spawn(async move {
218 while let Some(Ok(resp)) = read.next().await {
219 let x = match resp.as_text() {
220 Some(x) => x,
221 _ => continue,
222 };
223
224 let base_event = match serde_json::from_str::<serde_json::Value>(&x) {
225 Ok(base_event) => base_event,
226 _ => continue,
227 };
228
229 let lavalink_client = lavalink_client.clone();
230
231 tokio::spawn(async move {
232 Node::handle_event(lavalink_client, self_node_id, base_event).await;
233 });
234 }
235
236 let self_node = lavalink_client.nodes.get(self_node_id).unwrap();
237 self_node.is_running.store(false, Ordering::SeqCst);
238 error!("Connection Closed.");
239 });
240
241 Ok(())
242 }
243
244 async fn handle_event(
245 lavalink_client: LavalinkClient,
246 self_node_id: usize,
247 base_event: serde_json::Value,
248 ) {
249 let base_event_clone = base_event.clone();
250 let self_node = lavalink_client.nodes.get(self_node_id).unwrap();
251 let ed = EventDispatcher(self_node, &lavalink_client);
252
253 match base_event.get("op").unwrap().as_str().unwrap() {
254 "ready" => {
255 let ready_event: events::Ready = serde_json::from_value(base_event).unwrap();
256
257 self_node
258 .session_id
259 .swap(Arc::new(ready_event.session_id.to_string()));
260
261 #[cfg(feature = "python")]
262 {
263 let session_id = self_node.session_id.load_full();
264
265 if let Some(handler) = &self_node.events.event_handler {
266 handler
267 .event_ready(
268 lavalink_client.clone(),
269 (*session_id).clone(),
270 ready_event.clone(),
271 )
272 .await;
273 }
274 if let Some(handler) = &lavalink_client.events.event_handler {
275 handler
276 .event_ready(
277 lavalink_client.clone(),
278 (*session_id).clone(),
279 ready_event.clone(),
280 )
281 .await;
282 }
283 }
284
285 ed.dispatch(ready_event, |e| e.ready).await;
286 }
287 "playerUpdate" => {
288 let player_update_event: events::PlayerUpdate =
289 serde_json::from_value(base_event).unwrap();
290
291 if let Some(player) =
292 lavalink_client.get_player_context(player_update_event.guild_id)
293 {
294 if let Err(why) = player.update_state(player_update_event.state.clone()) {
295 error!(
296 "Error updating state for player {}: {}",
297 player_update_event.guild_id.0, why
298 );
299 }
300 }
301
302 #[cfg(feature = "python")]
303 {
304 let session_id = self_node.session_id.load_full();
305
306 if let Some(handler) = &self_node.events.event_handler {
307 handler
308 .event_player_update(
309 lavalink_client.clone(),
310 (*session_id).clone(),
311 player_update_event.clone(),
312 )
313 .await;
314 }
315 if let Some(handler) = &lavalink_client.events.event_handler {
316 handler
317 .event_player_update(
318 lavalink_client.clone(),
319 (*session_id).clone(),
320 player_update_event.clone(),
321 )
322 .await;
323 }
324 }
325
326 ed.dispatch(player_update_event, |e| e.player_update).await;
327 }
328 "stats" => {
329 #[cfg(feature = "python")]
330 {
331 let event: events::Stats = serde_json::from_value(base_event).unwrap();
332 let session_id = self_node.session_id.load_full();
333
334 self_node.cpu.store(Arc::new(event.cpu.clone()));
335 self_node.memory.store(Arc::new(event.memory.clone()));
336
337 if let Some(handler) = &self_node.events.event_handler {
338 handler
339 .event_stats(
340 lavalink_client.clone(),
341 (*session_id).clone(),
342 event.clone(),
343 )
344 .await;
345 }
346 if let Some(handler) = &lavalink_client.events.event_handler {
347 handler
348 .event_stats(
349 lavalink_client.clone(),
350 (*session_id).clone(),
351 event.clone(),
352 )
353 .await;
354 }
355
356 ed.dispatch(event, |e| e.stats).await;
357 }
358 #[cfg(not(feature = "python"))]
359 ed.parse_and_dispatch(base_event, |e| e.stats).await;
360 }
361 "event" => match base_event.get("type").unwrap().as_str().unwrap() {
362 "TrackStartEvent" => {
363 let track_event: events::TrackStart =
364 serde_json::from_value(base_event).unwrap();
365
366 if let Some(player) = lavalink_client.get_player_context(track_event.guild_id) {
367 if let Err(why) = player.update_track(track_event.track.clone().into()) {
368 error!(
369 "Error sending update track message for player {}: {}",
370 track_event.guild_id.0, why
371 );
372 }
373 }
374
375 #[cfg(feature = "python")]
376 {
377 let session_id = self_node.session_id.load_full();
378
379 if let Some(handler) = &self_node.events.event_handler {
380 handler
381 .event_track_start(
382 lavalink_client.clone(),
383 (*session_id).clone(),
384 track_event.clone(),
385 )
386 .await;
387 }
388 if let Some(handler) = &lavalink_client.events.event_handler {
389 handler
390 .event_track_start(
391 lavalink_client.clone(),
392 (*session_id).clone(),
393 track_event.clone(),
394 )
395 .await;
396 }
397 }
398
399 ed.dispatch(track_event, |e| e.track_start).await;
400 }
401 "TrackEndEvent" => {
402 let track_event: events::TrackEnd = serde_json::from_value(base_event).unwrap();
403
404 if let Some(player) = lavalink_client.get_player_context(track_event.guild_id) {
405 if let Err(why) = player.finish(track_event.reason.clone().into()) {
406 error!(
407 "Error sending finish message for player {}: {}",
408 track_event.guild_id.0, why
409 );
410 }
411
412 if let Err(why) = player.update_track(None) {
413 error!(
414 "Error sending update track message for player {}: {}",
415 track_event.guild_id.0, why
416 );
417 }
418 }
419
420 #[cfg(feature = "python")]
421 {
422 let session_id = self_node.session_id.load_full();
423
424 if let Some(handler) = &self_node.events.event_handler {
425 handler
426 .event_track_end(
427 lavalink_client.clone(),
428 (*session_id).clone(),
429 track_event.clone(),
430 )
431 .await;
432 }
433 if let Some(handler) = &lavalink_client.events.event_handler {
434 handler
435 .event_track_end(
436 lavalink_client.clone(),
437 (*session_id).clone(),
438 track_event.clone(),
439 )
440 .await;
441 }
442 }
443
444 ed.dispatch(track_event, |e| e.track_end).await;
445 }
446 "TrackExceptionEvent" => {
447 #[cfg(feature = "python")]
448 {
449 let event: events::TrackException =
450 serde_json::from_value(base_event).unwrap();
451 let session_id = self_node.session_id.load_full();
452
453 if let Some(handler) = &self_node.events.event_handler {
454 handler
455 .event_track_exception(
456 lavalink_client.clone(),
457 (*session_id).clone(),
458 event.clone(),
459 )
460 .await;
461 }
462 if let Some(handler) = &lavalink_client.events.event_handler {
463 handler
464 .event_track_exception(
465 lavalink_client.clone(),
466 (*session_id).clone(),
467 event.clone(),
468 )
469 .await;
470 }
471
472 ed.dispatch(event, |e| e.track_exception).await;
473 }
474 #[cfg(not(feature = "python"))]
475 ed.parse_and_dispatch(base_event, |e| e.track_exception)
476 .await;
477 }
478 "TrackStuckEvent" => {
479 #[cfg(feature = "python")]
480 {
481 let event: events::TrackStuck = serde_json::from_value(base_event).unwrap();
482 let session_id = self_node.session_id.load_full();
483
484 if let Some(handler) = &self_node.events.event_handler {
485 handler
486 .event_track_stuck(
487 lavalink_client.clone(),
488 (*session_id).clone(),
489 event.clone(),
490 )
491 .await;
492 }
493 if let Some(handler) = &lavalink_client.events.event_handler {
494 handler
495 .event_track_stuck(
496 lavalink_client.clone(),
497 (*session_id).clone(),
498 event.clone(),
499 )
500 .await;
501 }
502
503 ed.dispatch(event, |e| e.track_stuck).await;
504 }
505 #[cfg(not(feature = "python"))]
506 ed.parse_and_dispatch(base_event, |e| e.track_stuck).await;
507 }
508 "WebSocketClosedEvent" => {
509 #[cfg(feature = "python")]
510 {
511 let event: events::WebSocketClosed =
512 serde_json::from_value(base_event).unwrap();
513 let session_id = self_node.session_id.load_full();
514
515 if let Some(handler) = &self_node.events.event_handler {
516 handler
517 .event_websocket_closed(
518 lavalink_client.clone(),
519 (*session_id).clone(),
520 event.clone(),
521 )
522 .await;
523 }
524 if let Some(handler) = &lavalink_client.events.event_handler {
525 handler
526 .event_websocket_closed(
527 lavalink_client.clone(),
528 (*session_id).clone(),
529 event.clone(),
530 )
531 .await;
532 }
533
534 ed.dispatch(event, |e| e.websocket_closed).await;
535 }
536 #[cfg(not(feature = "python"))]
537 ed.parse_and_dispatch(base_event, |e| e.websocket_closed)
538 .await;
539 }
540 _ => (),
541 },
542
543 _ => (),
544 }
545
546 ed.dispatch(base_event_clone, |e| e.raw).await;
547 }
548}