librespot_core/dealer/mod.rs

1pub mod manager;
2mod maps;
3pub mod protocol;
4
5use std::{
6    iter,
7    pin::Pin,
8    sync::{
9        Arc,
10        atomic::{self, AtomicBool},
11    },
12    task::Poll,
13    time::Duration,
14};
15
16use futures_core::{Future, Stream};
17use futures_util::{SinkExt, StreamExt, future::join_all};
18use parking_lot::Mutex;
19use thiserror::Error;
20use tokio::{
21    select,
22    sync::{
23        Semaphore,
24        mpsc::{self, UnboundedReceiver},
25    },
26    task::JoinHandle,
27};
28use tokio_tungstenite::tungstenite;
29use tungstenite::error::UrlError;
30use url::Url;
31
32use self::{
33    maps::*,
34    protocol::{Message, MessageOrRequest, Request, WebsocketMessage, WebsocketRequest},
35};
36
37use crate::{
38    Error, socket,
39    util::{CancelOnDrop, TimeoutOnDrop, keep_flushing},
40};
41
42type WsMessage = tungstenite::Message;
43type WsError = tungstenite::Error;
44type WsResult<T> = Result<T, Error>;
45type GetUrlResult = Result<Url, Error>;
46
47impl From<WsError> for Error {
48    fn from(err: WsError) -> Self {
49        Error::failed_precondition(err)
50    }
51}
52
/// Timeout handed to `TimeoutOnDrop::new` for the dealer's task handles
/// (presumably how long a dropped handle may keep running before it is aborted).
const WEBSOCKET_CLOSE_TIMEOUT: Duration = Duration::from_millis(3_000);

/// How often a ping is sent to the dealer.
const PING_INTERVAL: Duration = Duration::from_secs(30);
/// How long to wait for the matching pong before the peer is considered unresponsive.
const PING_TIMEOUT: Duration = Duration::from_millis(3_000);

/// Pause after a failed connection attempt before the next one.
const RECONNECT_INTERVAL: Duration = Duration::from_secs(10);
59
/// Outcome reported back to the dealer for a handled request.
struct Response {
    /// Sent as `payload.success` in the reply frame built by `Responder`.
    pub success: bool,
}
63
64struct Responder {
65    key: String,
66    tx: mpsc::UnboundedSender<WsMessage>,
67    sent: bool,
68}
69
70impl Responder {
71    fn new(key: String, tx: mpsc::UnboundedSender<WsMessage>) -> Self {
72        Self {
73            key,
74            tx,
75            sent: false,
76        }
77    }
78
79    // Should only be called once
80    fn send_internal(&mut self, response: Response) {
81        let response = serde_json::json!({
82            "type": "reply",
83            "key": &self.key,
84            "payload": {
85                "success": response.success,
86            }
87        })
88        .to_string();
89
90        if let Err(e) = self.tx.send(WsMessage::Text(response.into())) {
91            warn!("Wasn't able to reply to dealer request: {e}");
92        }
93    }
94
95    pub fn send(mut self, response: Response) {
96        self.send_internal(response);
97        self.sent = true;
98    }
99
100    pub fn force_unanswered(mut self) {
101        self.sent = true;
102    }
103}
104
105impl Drop for Responder {
106    fn drop(&mut self) {
107        if !self.sent {
108            self.send_internal(Response { success: false });
109        }
110    }
111}
112
113trait IntoResponse {
114    fn respond(self, responder: Responder);
115}
116
117impl IntoResponse for Response {
118    fn respond(self, responder: Responder) {
119        responder.send(self)
120    }
121}
122
123impl<F> IntoResponse for F
124where
125    F: Future<Output = Response> + Send + 'static,
126{
127    fn respond(self, responder: Responder) {
128        tokio::spawn(async move {
129            responder.send(self.await);
130        });
131    }
132}
133
134impl<F, R> RequestHandler for F
135where
136    F: (Fn(Request) -> R) + Send + 'static,
137    R: IntoResponse,
138{
139    fn handle_request(&self, request: Request, responder: Responder) {
140        self(request).respond(responder);
141    }
142}
143
144trait RequestHandler: Send + 'static {
145    fn handle_request(&self, request: Request, responder: Responder);
146}
147
148type MessageHandler = mpsc::UnboundedSender<Message>;
149
150// TODO: Maybe it's possible to unregister subscription directly when they
151//       are dropped instead of on next failed attempt.
152pub struct Subscription(UnboundedReceiver<Message>);
153
154impl Stream for Subscription {
155    type Item = Message;
156
157    fn poll_next(
158        mut self: Pin<&mut Self>,
159        cx: &mut std::task::Context<'_>,
160    ) -> Poll<Option<Self::Item>> {
161        self.0.poll_recv(cx)
162    }
163}
164
/// Splits a dealer uri into its scheme followed by its path segments.
///
/// * `hm://a/b/`   -> `["hm", "a", "b"]`
/// * `spotify:a:b` -> `["spotify", "a", "b"]`
/// * `a/b`         -> `["", "a", "b"]` (no scheme, but contains a `/`)
///
/// Trailing separators are ignored. Returns `None` when the string has neither a
/// known scheme nor a `/`.
fn split_uri(s: &str) -> Option<impl Iterator<Item = &'_ str>> {
    // The two prefixes are mutually exclusive, so checking both up front is harmless.
    let (scheme, sep, rest) = match (s.strip_prefix("hm://"), s.strip_prefix("spotify:")) {
        (Some(rest), _) => ("hm", '/', rest),
        (None, Some(rest)) => ("spotify", ':', rest),
        (None, None) if s.contains('/') => ("", '/', s),
        (None, None) => return None,
    };

    let segments = rest.trim_end_matches(sep).split(sep);
    Some(iter::once(scheme).chain(segments))
}
181
182#[derive(Debug, Clone, Error)]
183enum AddHandlerError {
184    #[error("There is already a handler for the given uri")]
185    AlreadyHandled,
186    #[error("The specified uri {0} is invalid")]
187    InvalidUri(String),
188}
189
190impl From<AddHandlerError> for Error {
191    fn from(err: AddHandlerError) -> Self {
192        match err {
193            AddHandlerError::AlreadyHandled => Error::aborted(err),
194            AddHandlerError::InvalidUri(_) => Error::invalid_argument(err),
195        }
196    }
197}
198
199#[derive(Debug, Clone, Error)]
200enum SubscriptionError {
201    #[error("The specified uri is invalid")]
202    InvalidUri(String),
203}
204
205impl From<SubscriptionError> for Error {
206    fn from(err: SubscriptionError) -> Self {
207        Error::invalid_argument(err)
208    }
209}
210
211fn add_handler(
212    map: &mut HandlerMap<Box<dyn RequestHandler>>,
213    uri: &str,
214    handler: impl RequestHandler,
215) -> Result<(), Error> {
216    let split = split_uri(uri).ok_or_else(|| AddHandlerError::InvalidUri(uri.to_string()))?;
217    map.insert(split, Box::new(handler))
218}
219
220fn remove_handler<T>(map: &mut HandlerMap<T>, uri: &str) -> Option<T> {
221    map.remove(split_uri(uri)?)
222}
223
224fn subscribe(
225    map: &mut SubscriberMap<MessageHandler>,
226    uris: &[&str],
227) -> Result<Subscription, Error> {
228    let (tx, rx) = mpsc::unbounded_channel();
229
230    for &uri in uris {
231        let split = split_uri(uri).ok_or_else(|| SubscriptionError::InvalidUri(uri.to_string()))?;
232        map.insert(split, tx.clone());
233    }
234
235    Ok(Subscription(rx))
236}
237
238fn handles(
239    req_map: &HandlerMap<Box<dyn RequestHandler>>,
240    msg_map: &SubscriberMap<MessageHandler>,
241    uri: &str,
242) -> bool {
243    if req_map.contains(uri) {
244        return true;
245    }
246
247    match split_uri(uri) {
248        None => false,
249        Some(mut split) => msg_map.contains(&mut split),
250    }
251}
252
253#[derive(Default)]
254struct Builder {
255    message_handlers: SubscriberMap<MessageHandler>,
256    request_handlers: HandlerMap<Box<dyn RequestHandler>>,
257}
258
/// Builds a [`Dealer`] from a `Builder`.
///
/// `$shared` is bound to a clone of the new `Arc<DealerShared>` while `$body` is evaluated.
/// `$body` must produce the future that becomes the dealer's background task. Because
/// `$body` is evaluated in the caller's context, it may `.await` and use `?` there.
macro_rules! create_dealer {
    ($builder:expr, $shared:ident -> $body:expr) => {{
        let builder = $builder;
        let dealer_shared = Arc::new(DealerShared {
            message_handlers: Mutex::new(builder.message_handlers),
            request_handlers: Mutex::new(builder.request_handlers),
            notify_drop: Semaphore::new(0),
        });

        let background = {
            let $shared = Arc::clone(&dealer_shared);
            $body
        };

        Dealer {
            shared: dealer_shared,
            handle: TimeoutOnDrop::new(tokio::spawn(background), WEBSOCKET_CLOSE_TIMEOUT),
        }
    }};
}
282
283impl Builder {
284    pub fn new() -> Self {
285        Self::default()
286    }
287
288    pub fn add_handler(&mut self, uri: &str, handler: impl RequestHandler) -> Result<(), Error> {
289        add_handler(&mut self.request_handlers, uri, handler)
290    }
291
292    pub fn subscribe(&mut self, uris: &[&str]) -> Result<Subscription, Error> {
293        subscribe(&mut self.message_handlers, uris)
294    }
295
296    pub fn handles(&self, uri: &str) -> bool {
297        handles(&self.request_handlers, &self.message_handlers, uri)
298    }
299
300    pub fn launch_in_background<Fut, F>(self, get_url: F, proxy: Option<Url>) -> Dealer
301    where
302        Fut: Future<Output = GetUrlResult> + Send + 'static,
303        F: (Fn() -> Fut) + Send + 'static,
304    {
305        create_dealer!(self, shared -> run(shared, None, get_url, proxy))
306    }
307
308    pub async fn launch<Fut, F>(self, get_url: F, proxy: Option<Url>) -> WsResult<Dealer>
309    where
310        Fut: Future<Output = GetUrlResult> + Send + 'static,
311        F: (Fn() -> Fut) + Send + 'static,
312    {
313        let dealer = create_dealer!(self, shared -> {
314            // Try to connect.
315            let url = get_url().await?;
316            let tasks = connect(&url, proxy.as_ref(), &shared).await?;
317
318            // If a connection is established, continue in a background task.
319            run(shared, Some(tasks), get_url, proxy)
320        });
321
322        Ok(dealer)
323    }
324}
325
326struct DealerShared {
327    message_handlers: Mutex<SubscriberMap<MessageHandler>>,
328    request_handlers: Mutex<HandlerMap<Box<dyn RequestHandler>>>,
329
330    // Semaphore with 0 permits. By closing this semaphore, we indicate
331    // that the actual Dealer struct has been dropped.
332    notify_drop: Semaphore,
333}
334
335impl DealerShared {
336    fn dispatch_message(&self, mut msg: WebsocketMessage) {
337        let msg = match msg.handle_payload() {
338            Ok(value) => Message {
339                headers: msg.headers,
340                payload: value,
341                uri: msg.uri,
342            },
343            Err(why) => {
344                warn!("failure during data parsing for {}: {why}", msg.uri);
345                return;
346            }
347        };
348
349        if let Some(split) = split_uri(&msg.uri) {
350            if self
351                .message_handlers
352                .lock()
353                .retain(split, &mut |tx| tx.send(msg.clone()).is_ok())
354            {
355                return;
356            }
357        }
358
359        debug!("No subscriber for msg.uri: {}", msg.uri);
360    }
361
362    fn dispatch_request(
363        &self,
364        request: WebsocketRequest,
365        send_tx: &mpsc::UnboundedSender<WsMessage>,
366    ) {
367        trace!("dealer request {}", &request.message_ident);
368
369        let payload_request = match request.handle_payload() {
370            Ok(payload) => payload,
371            Err(why) => {
372                warn!("request payload handling failed because of {why}");
373                return;
374            }
375        };
376
377        // ResponseSender will automatically send "success: false" if it is dropped without an answer.
378        let responder = Responder::new(request.key.clone(), send_tx.clone());
379
380        let split = if let Some(split) = split_uri(&request.message_ident) {
381            split
382        } else {
383            warn!(
384                "Dealer request with invalid message_ident: {}",
385                &request.message_ident
386            );
387            return;
388        };
389
390        let handler_map = self.request_handlers.lock();
391
392        if let Some(handler) = handler_map.get(split) {
393            handler.handle_request(payload_request, responder);
394            return;
395        }
396
397        warn!("No handler for message_ident: {}", &request.message_ident);
398    }
399
400    fn dispatch(&self, m: MessageOrRequest, send_tx: &mpsc::UnboundedSender<WsMessage>) {
401        match m {
402            MessageOrRequest::Message(m) => self.dispatch_message(m),
403            MessageOrRequest::Request(r) => self.dispatch_request(r, send_tx),
404        }
405    }
406
407    async fn closed(&self) {
408        if self.notify_drop.acquire().await.is_ok() {
409            error!("should never have gotten a permit");
410        }
411    }
412
413    fn is_closed(&self) -> bool {
414        self.notify_drop.is_closed()
415    }
416}
417
418struct Dealer {
419    shared: Arc<DealerShared>,
420    handle: TimeoutOnDrop<Result<(), Error>>,
421}
422
423impl Dealer {
424    pub fn add_handler<H>(&self, uri: &str, handler: H) -> Result<(), Error>
425    where
426        H: RequestHandler,
427    {
428        add_handler(&mut self.shared.request_handlers.lock(), uri, handler)
429    }
430
431    pub fn remove_handler(&self, uri: &str) -> Option<Box<dyn RequestHandler>> {
432        remove_handler(&mut self.shared.request_handlers.lock(), uri)
433    }
434
435    pub fn subscribe(&self, uris: &[&str]) -> Result<Subscription, Error> {
436        subscribe(&mut self.shared.message_handlers.lock(), uris)
437    }
438
439    pub fn handles(&self, uri: &str) -> bool {
440        handles(
441            &self.shared.request_handlers.lock(),
442            &self.shared.message_handlers.lock(),
443            uri,
444        )
445    }
446
447    pub async fn close(mut self) {
448        debug!("closing dealer");
449
450        self.shared.notify_drop.close();
451
452        if let Some(handle) = self.handle.take() {
453            if let Err(e) = CancelOnDrop(handle).await {
454                error!("error aborting dealer operations: {e}");
455            }
456        }
457    }
458}
459
/// Initializes a connection and returns futures that will finish when the connection is closed/lost.
///
/// Opens a socket to `address` (through `proxy`, if given) via `socket::connect`, performs
/// the websocket handshake with `client_async_tls`, and spawns two tasks, returned as
/// `(send_task, receive_task)`:
///
/// * the send task forwards messages queued on an internal channel to the websocket. It
///   closes the socket when `shared` is closed, when the channel is closed, or when a
///   `Close` message is queued.
/// * the receive task dispatches incoming text frames via `DealerShared::dispatch`. It also
///   sends a ping every `PING_INTERVAL` and gives up if no pong arrives within
///   `PING_TIMEOUT`.
///
/// # Errors
///
/// Fails if `address` has no host or uses a scheme other than `ws`/`wss`, or if the
/// socket connection or the websocket handshake fails.
async fn connect(
    address: &Url,
    proxy: Option<&Url>,
    shared: &Arc<DealerShared>,
) -> WsResult<(JoinHandle<()>, JoinHandle<()>)> {
    let host = address
        .host_str()
        .ok_or(WsError::Url(UrlError::NoHostName))?;

    // Only websocket schemes are accepted; the scheme determines the default port.
    let default_port = match address.scheme() {
        "ws" => 80,
        "wss" => 443,
        _ => return Err(WsError::Url(UrlError::UnsupportedUrlScheme).into()),
    };

    let port = address.port().unwrap_or(default_port);

    let stream = socket::connect(host, port, proxy).await?;

    let (mut ws_tx, ws_rx) = tokio_tungstenite::client_async_tls(address.as_str(), stream)
        .await?
        .0
        .split();

    // Everything written to the websocket (replies, pings, close requests) goes through
    // this channel and is serialized by the send task below.
    let (send_tx, mut send_rx) = mpsc::unbounded_channel::<WsMessage>();

    // Spawn a task that will forward messages from the channel to the websocket.
    let send_task = {
        let shared = Arc::clone(shared);

        tokio::spawn(async move {
            // `result` is `Ok(close_frame)` for an orderly stop, `Err` for a websocket error.
            let result = loop {
                select! {
                    // Branches are polled in order, so shutdown takes priority.
                    biased;
                    () = shared.closed() => {
                        break Ok(None);
                    }
                    msg = send_rx.recv() => {
                        if let Some(msg) = msg {
                            // New message arrived through channel
                            if let WsMessage::Close(close_frame) = msg {
                                break Ok(close_frame);
                            }

                            // `feed` only buffers; flushing is done by `keep_flushing`.
                            if let Err(e) = ws_tx.feed(msg).await  {
                                break Err(e);
                            }
                        } else {
                            break Ok(None);
                        }
                    },
                    // NOTE(review): `keep_flushing` presumably only resolves when a flush
                    // fails (its output is treated as an error) - confirm in `crate::util`.
                    e = keep_flushing(&mut ws_tx) => {
                        break Err(e)
                    }
                    // Unreachable in practice: every pattern above is irrefutable, so no
                    // branch is ever disabled.
                    else => (),
                }
            };

            // Stop accepting messages; further sends (e.g. pings) on `send_tx` now fail.
            send_rx.close();

            // I don't trust in tokio_tungstenite's implementation of Sink::close.
            let result = match result {
                Ok(close_frame) => ws_tx.send(WsMessage::Close(close_frame)).await,
                Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => ws_tx.flush().await,
                Err(e) => {
                    warn!("Dealer finished with an error: {e}");
                    ws_tx.send(WsMessage::Close(None)).await
                }
            };

            if let Err(e) = result {
                warn!("Error while closing websocket: {e}");
            }

            debug!("Dropping send task");
        })
    };

    let shared = Arc::clone(shared);

    // A task that receives messages from the web socket.
    // Not an `async move` block, but every captured value (`send_tx`, `shared`, `ws_rx`)
    // is used by value inside it, so each one is moved into the task anyway.
    let receive_task = tokio::spawn(async {
        // Both sub-futures below run inside this one task (raced by `select!`). The atomic
        // lets them share the flag through `&` references while keeping the task `Send`.
        let pong_received = AtomicBool::new(true);
        let send_tx = send_tx;
        let shared = shared;

        let receive_task = async {
            let mut ws_rx = ws_rx;

            loop {
                match ws_rx.next().await {
                    Some(Ok(msg)) => match msg {
                        // Text frames carry JSON that is decoded into a `MessageOrRequest`.
                        WsMessage::Text(t) => match serde_json::from_str(&t) {
                            Ok(m) => shared.dispatch(m, &send_tx),
                            Err(e) => warn!("Message couldn't be parsed: {e}. Message was {t}"),
                        },
                        WsMessage::Binary(_) => {
                            info!("Received invalid binary message");
                        }
                        WsMessage::Pong(_) => {
                            trace!("Received pong");
                            pong_received.store(true, atomic::Ordering::Relaxed);
                        }
                        _ => (), // tungstenite handles Close and Ping automatically
                    },
                    Some(Err(e)) => {
                        warn!("Websocket connection failed: {e}");
                        break;
                    }
                    None => {
                        debug!("Websocket connection closed.");
                        break;
                    }
                }
            }
        };

        // Sends pings and checks whether a pong comes back.
        let ping_task = async {
            use tokio::time::{interval, sleep};

            // The first tick completes immediately, so the first ping goes out right away.
            let mut timer = interval(PING_INTERVAL);

            loop {
                timer.tick().await;

                pong_received.store(false, atomic::Ordering::Relaxed);
                if send_tx
                    .send(WsMessage::Ping(bytes::Bytes::default()))
                    .is_err()
                {
                    // The sender is closed.
                    break;
                }

                trace!("Sent ping");

                sleep(PING_TIMEOUT).await;

                if !pong_received.load(atomic::Ordering::SeqCst) {
                    // No response
                    warn!("Websocket peer does not respond.");
                    break;
                }
            }
        };

        // Exit this task as soon as one of our subtasks finishes.
        // In both cases the connection is probably lost.
        select! {
            () = ping_task => (),
            () = receive_task => ()
        }

        // Try to take send_task down with us, in case it's still alive.
        let _ = send_tx.send(WsMessage::Close(None));

        debug!("Dropping receive task");
    });

    Ok((send_task, receive_task))
}
623
624/// The main background task for `Dealer`, which coordinates reconnecting.
625async fn run<F, Fut>(
626    shared: Arc<DealerShared>,
627    initial_tasks: Option<(JoinHandle<()>, JoinHandle<()>)>,
628    mut get_url: F,
629    proxy: Option<Url>,
630) -> Result<(), Error>
631where
632    Fut: Future<Output = GetUrlResult> + Send + 'static,
633    F: (FnMut() -> Fut) + Send + 'static,
634{
635    let init_task = |t| Some(TimeoutOnDrop::new(t, WEBSOCKET_CLOSE_TIMEOUT));
636
637    let mut tasks = if let Some((s, r)) = initial_tasks {
638        (init_task(s), init_task(r))
639    } else {
640        (None, None)
641    };
642
643    while !shared.is_closed() {
644        match &mut tasks {
645            (Some(t0), Some(t1)) => {
646                select! {
647                    () = shared.closed() => break,
648                    r = t0 => {
649                        if let Err(e) = r {
650                            error!("timeout on task 0: {e}");
651                        }
652                        tasks.0.take();
653                    },
654                    r = t1 => {
655                        if let Err(e) = r {
656                            error!("timeout on task 1: {e}");
657                        }
658                        tasks.1.take();
659                    }
660                }
661            }
662            _ => {
663                let url = select! {
664                    () = shared.closed() => {
665                        break
666                    },
667                    e = get_url() => e
668                }?;
669
670                match connect(&url, proxy.as_ref(), &shared).await {
671                    Ok((s, r)) => tasks = (init_task(s), init_task(r)),
672                    Err(e) => {
673                        error!("Error while connecting: {e}");
674                        tokio::time::sleep(RECONNECT_INTERVAL).await;
675                    }
676                }
677            }
678        }
679    }
680
681    let tasks = tasks.0.into_iter().chain(tasks.1);
682
683    let _ = join_all(tasks).await;
684
685    Ok(())
686}