librespot_core/dealer/
mod.rs

1pub mod manager;
2mod maps;
3pub mod protocol;
4
5use std::{
6    iter,
7    pin::Pin,
8    sync::{
9        Arc, Mutex,
10        atomic::{self, AtomicBool},
11    },
12    task::Poll,
13    time::Duration,
14};
15
16use futures_core::{Future, Stream};
17use futures_util::{SinkExt, StreamExt, future::join_all};
18use thiserror::Error;
19use tokio::{
20    select,
21    sync::{
22        Semaphore,
23        mpsc::{self, UnboundedReceiver},
24    },
25    task::JoinHandle,
26};
27use tokio_tungstenite::tungstenite;
28use tungstenite::error::UrlError;
29use url::Url;
30
31use self::{
32    maps::*,
33    protocol::{Message, MessageOrRequest, Request, WebsocketMessage, WebsocketRequest},
34};
35
36use crate::{
37    Error, socket,
38    util::{CancelOnDrop, TimeoutOnDrop, keep_flushing},
39};
40
41type WsMessage = tungstenite::Message;
42type WsError = tungstenite::Error;
43type WsResult<T> = Result<T, Error>;
44type GetUrlResult = Result<Url, Error>;
45
46impl From<WsError> for Error {
47    fn from(err: WsError) -> Self {
48        Error::failed_precondition(err)
49    }
50}
51
/// Timeout handed to `TimeoutOnDrop` for the dealer's background task handles.
const WEBSOCKET_CLOSE_TIMEOUT: Duration = Duration::from_secs(3);

/// Period of the keep-alive ping timer.
const PING_INTERVAL: Duration = Duration::from_secs(30);
/// How long to wait for a pong after each ping before treating the peer as gone.
const PING_TIMEOUT: Duration = Duration::from_secs(3);

/// Back-off between a failed connection attempt and the next one.
const RECONNECT_INTERVAL: Duration = Duration::from_secs(10);

/// `expect` message used when locking the request-handler map.
const DEALER_REQUEST_HANDLERS_POISON_MSG: &str =
    "dealer request handlers mutex should not be poisoned";
/// `expect` message used when locking the message-subscriber map.
const DEALER_MESSAGE_HANDLERS_POISON_MSG: &str =
    "dealer message handlers mutex should not be poisoned";
63
64struct Response {
65    pub success: bool,
66}
67
68struct Responder {
69    key: String,
70    tx: mpsc::UnboundedSender<WsMessage>,
71    sent: bool,
72}
73
74impl Responder {
75    fn new(key: String, tx: mpsc::UnboundedSender<WsMessage>) -> Self {
76        Self {
77            key,
78            tx,
79            sent: false,
80        }
81    }
82
83    // Should only be called once
84    fn send_internal(&mut self, response: Response) {
85        let response = serde_json::json!({
86            "type": "reply",
87            "key": &self.key,
88            "payload": {
89                "success": response.success,
90            }
91        })
92        .to_string();
93
94        if let Err(e) = self.tx.send(WsMessage::Text(response.into())) {
95            warn!("Wasn't able to reply to dealer request: {e}");
96        }
97    }
98
99    pub fn send(mut self, response: Response) {
100        self.send_internal(response);
101        self.sent = true;
102    }
103
104    pub fn force_unanswered(mut self) {
105        self.sent = true;
106    }
107}
108
109impl Drop for Responder {
110    fn drop(&mut self) {
111        if !self.sent {
112            self.send_internal(Response { success: false });
113        }
114    }
115}
116
117trait IntoResponse {
118    fn respond(self, responder: Responder);
119}
120
121impl IntoResponse for Response {
122    fn respond(self, responder: Responder) {
123        responder.send(self)
124    }
125}
126
127impl<F> IntoResponse for F
128where
129    F: Future<Output = Response> + Send + 'static,
130{
131    fn respond(self, responder: Responder) {
132        tokio::spawn(async move {
133            responder.send(self.await);
134        });
135    }
136}
137
138impl<F, R> RequestHandler for F
139where
140    F: (Fn(Request) -> R) + Send + 'static,
141    R: IntoResponse,
142{
143    fn handle_request(&self, request: Request, responder: Responder) {
144        self(request).respond(responder);
145    }
146}
147
148trait RequestHandler: Send + 'static {
149    fn handle_request(&self, request: Request, responder: Responder);
150}
151
152type MessageHandler = mpsc::UnboundedSender<Message>;
153
154// TODO: Maybe it's possible to unregister subscription directly when they
155//       are dropped instead of on next failed attempt.
156pub struct Subscription(UnboundedReceiver<Message>);
157
158impl Stream for Subscription {
159    type Item = Message;
160
161    fn poll_next(
162        mut self: Pin<&mut Self>,
163        cx: &mut std::task::Context<'_>,
164    ) -> Poll<Option<Self::Item>> {
165        self.0.poll_recv(cx)
166    }
167}
168
/// Splits a uri into its scheme followed by its path segments.
///
/// * `hm://a/b` yields `["hm", "a", "b"]` (segments separated by `/`),
/// * `spotify:a:b` yields `["spotify", "a", "b"]` (segments separated by `:`),
/// * any other string containing `/` yields `""` followed by its `/` segments.
///
/// Trailing separators are ignored. Returns `None` for strings matching none of
/// these forms.
fn split_uri(s: &str) -> Option<impl Iterator<Item = &'_ str>> {
    let hm = s.strip_prefix("hm://");
    let spotify = s.strip_prefix("spotify:");

    let (scheme, separator, remainder) = match (hm, spotify) {
        (Some(rest), _) => ("hm", '/', rest),
        (None, Some(rest)) => ("spotify", ':', rest),
        (None, None) if s.contains('/') => ("", '/', s),
        (None, None) => return None,
    };

    let segments = remainder.trim_end_matches(separator).split(separator);
    Some(iter::once(scheme).chain(segments))
}
185
186#[derive(Debug, Clone, Error)]
187enum AddHandlerError {
188    #[error("There is already a handler for the given uri")]
189    AlreadyHandled,
190    #[error("The specified uri {0} is invalid")]
191    InvalidUri(String),
192}
193
194impl From<AddHandlerError> for Error {
195    fn from(err: AddHandlerError) -> Self {
196        match err {
197            AddHandlerError::AlreadyHandled => Error::aborted(err),
198            AddHandlerError::InvalidUri(_) => Error::invalid_argument(err),
199        }
200    }
201}
202
203#[derive(Debug, Clone, Error)]
204enum SubscriptionError {
205    #[error("The specified uri is invalid")]
206    InvalidUri(String),
207}
208
209impl From<SubscriptionError> for Error {
210    fn from(err: SubscriptionError) -> Self {
211        Error::invalid_argument(err)
212    }
213}
214
215fn add_handler(
216    map: &mut HandlerMap<Box<dyn RequestHandler>>,
217    uri: &str,
218    handler: impl RequestHandler,
219) -> Result<(), Error> {
220    let split = split_uri(uri).ok_or_else(|| AddHandlerError::InvalidUri(uri.to_string()))?;
221    map.insert(split, Box::new(handler))
222}
223
224fn remove_handler<T>(map: &mut HandlerMap<T>, uri: &str) -> Option<T> {
225    map.remove(split_uri(uri)?)
226}
227
228fn subscribe(
229    map: &mut SubscriberMap<MessageHandler>,
230    uris: &[&str],
231) -> Result<Subscription, Error> {
232    let (tx, rx) = mpsc::unbounded_channel();
233
234    for &uri in uris {
235        let split = split_uri(uri).ok_or_else(|| SubscriptionError::InvalidUri(uri.to_string()))?;
236        map.insert(split, tx.clone());
237    }
238
239    Ok(Subscription(rx))
240}
241
242fn handles(
243    req_map: &HandlerMap<Box<dyn RequestHandler>>,
244    msg_map: &SubscriberMap<MessageHandler>,
245    uri: &str,
246) -> bool {
247    if req_map.contains(uri) {
248        return true;
249    }
250
251    match split_uri(uri) {
252        None => false,
253        Some(mut split) => msg_map.contains(&mut split),
254    }
255}
256
257#[derive(Default)]
258struct Builder {
259    message_handlers: SubscriberMap<MessageHandler>,
260    request_handlers: HandlerMap<Box<dyn RequestHandler>>,
261}
262
/// Turns `$builder` into a `Dealer` whose background task is `$body`.
///
/// `$shared` names the `Arc<DealerShared>` clone available to `$body`. `$body` is
/// evaluated in the caller's context and its value is passed to `tokio::spawn`, so
/// inside an async caller it may `.await` or use `?` before the task is spawned.
macro_rules! create_dealer {
    ($builder:expr, $shared:ident -> $body:expr) => {{
        let builder = $builder;

        let shared = Arc::new(DealerShared {
            message_handlers: Mutex::new(builder.message_handlers),
            request_handlers: Mutex::new(builder.request_handlers),
            notify_drop: Semaphore::new(0),
        });

        let handle = {
            let $shared = Arc::clone(&shared);
            tokio::spawn($body)
        };

        Dealer {
            shared,
            handle: TimeoutOnDrop::new(handle, WEBSOCKET_CLOSE_TIMEOUT),
        }
    }};
}
286
287impl Builder {
288    pub fn new() -> Self {
289        Self::default()
290    }
291
292    pub fn add_handler(&mut self, uri: &str, handler: impl RequestHandler) -> Result<(), Error> {
293        add_handler(&mut self.request_handlers, uri, handler)
294    }
295
296    pub fn subscribe(&mut self, uris: &[&str]) -> Result<Subscription, Error> {
297        subscribe(&mut self.message_handlers, uris)
298    }
299
300    pub fn handles(&self, uri: &str) -> bool {
301        handles(&self.request_handlers, &self.message_handlers, uri)
302    }
303
304    pub fn launch_in_background<Fut, F>(self, get_url: F, proxy: Option<Url>) -> Dealer
305    where
306        Fut: Future<Output = GetUrlResult> + Send + 'static,
307        F: (Fn() -> Fut) + Send + 'static,
308    {
309        create_dealer!(self, shared -> run(shared, None, get_url, proxy))
310    }
311
312    pub async fn launch<Fut, F>(self, get_url: F, proxy: Option<Url>) -> WsResult<Dealer>
313    where
314        Fut: Future<Output = GetUrlResult> + Send + 'static,
315        F: (Fn() -> Fut) + Send + 'static,
316    {
317        let dealer = create_dealer!(self, shared -> {
318            // Try to connect.
319            let url = get_url().await?;
320            let tasks = connect(&url, proxy.as_ref(), &shared).await?;
321
322            // If a connection is established, continue in a background task.
323            run(shared, Some(tasks), get_url, proxy)
324        });
325
326        Ok(dealer)
327    }
328}
329
330struct DealerShared {
331    message_handlers: Mutex<SubscriberMap<MessageHandler>>,
332    request_handlers: Mutex<HandlerMap<Box<dyn RequestHandler>>>,
333
334    // Semaphore with 0 permits. By closing this semaphore, we indicate
335    // that the actual Dealer struct has been dropped.
336    notify_drop: Semaphore,
337}
338
339impl DealerShared {
340    fn dispatch_message(&self, mut msg: WebsocketMessage) {
341        let msg = match msg.handle_payload() {
342            Ok(value) => Message {
343                headers: msg.headers,
344                payload: value,
345                uri: msg.uri,
346            },
347            Err(why) => {
348                warn!("failure during data parsing for {}: {why}", msg.uri);
349                return;
350            }
351        };
352
353        if let Some(split) = split_uri(&msg.uri) {
354            if self
355                .message_handlers
356                .lock()
357                .expect(DEALER_MESSAGE_HANDLERS_POISON_MSG)
358                .retain(split, &mut |tx| tx.send(msg.clone()).is_ok())
359            {
360                return;
361            }
362        }
363
364        debug!("No subscriber for msg.uri: {}", msg.uri);
365    }
366
367    fn dispatch_request(
368        &self,
369        request: WebsocketRequest,
370        send_tx: &mpsc::UnboundedSender<WsMessage>,
371    ) {
372        trace!("dealer request {}", &request.message_ident);
373
374        let payload_request = match request.handle_payload() {
375            Ok(payload) => payload,
376            Err(why) => {
377                warn!("request payload handling failed because of {why}");
378                return;
379            }
380        };
381
382        // ResponseSender will automatically send "success: false" if it is dropped without an answer.
383        let responder = Responder::new(request.key.clone(), send_tx.clone());
384
385        let split = if let Some(split) = split_uri(&request.message_ident) {
386            split
387        } else {
388            warn!(
389                "Dealer request with invalid message_ident: {}",
390                &request.message_ident
391            );
392            return;
393        };
394
395        let handler_map = self
396            .request_handlers
397            .lock()
398            .expect(DEALER_REQUEST_HANDLERS_POISON_MSG);
399
400        if let Some(handler) = handler_map.get(split) {
401            handler.handle_request(payload_request, responder);
402            return;
403        }
404
405        warn!("No handler for message_ident: {}", &request.message_ident);
406    }
407
408    fn dispatch(&self, m: MessageOrRequest, send_tx: &mpsc::UnboundedSender<WsMessage>) {
409        match m {
410            MessageOrRequest::Message(m) => self.dispatch_message(m),
411            MessageOrRequest::Request(r) => self.dispatch_request(r, send_tx),
412        }
413    }
414
415    async fn closed(&self) {
416        if self.notify_drop.acquire().await.is_ok() {
417            error!("should never have gotten a permit");
418        }
419    }
420
421    fn is_closed(&self) -> bool {
422        self.notify_drop.is_closed()
423    }
424}
425
426struct Dealer {
427    shared: Arc<DealerShared>,
428    handle: TimeoutOnDrop<Result<(), Error>>,
429}
430
431impl Dealer {
432    pub fn add_handler<H>(&self, uri: &str, handler: H) -> Result<(), Error>
433    where
434        H: RequestHandler,
435    {
436        add_handler(
437            &mut self
438                .shared
439                .request_handlers
440                .lock()
441                .expect(DEALER_REQUEST_HANDLERS_POISON_MSG),
442            uri,
443            handler,
444        )
445    }
446
447    pub fn remove_handler(&self, uri: &str) -> Option<Box<dyn RequestHandler>> {
448        remove_handler(
449            &mut self
450                .shared
451                .request_handlers
452                .lock()
453                .expect(DEALER_REQUEST_HANDLERS_POISON_MSG),
454            uri,
455        )
456    }
457
458    pub fn subscribe(&self, uris: &[&str]) -> Result<Subscription, Error> {
459        subscribe(
460            &mut self
461                .shared
462                .message_handlers
463                .lock()
464                .expect(DEALER_MESSAGE_HANDLERS_POISON_MSG),
465            uris,
466        )
467    }
468
469    pub fn handles(&self, uri: &str) -> bool {
470        handles(
471            &self
472                .shared
473                .request_handlers
474                .lock()
475                .expect(DEALER_REQUEST_HANDLERS_POISON_MSG),
476            &self
477                .shared
478                .message_handlers
479                .lock()
480                .expect(DEALER_MESSAGE_HANDLERS_POISON_MSG),
481            uri,
482        )
483    }
484
485    pub async fn close(mut self) {
486        debug!("closing dealer");
487
488        self.shared.notify_drop.close();
489
490        if let Some(handle) = self.handle.take() {
491            if let Err(e) = CancelOnDrop(handle).await {
492                error!("error aborting dealer operations: {e}");
493            }
494        }
495    }
496}
497
498/// Initializes a connection and returns futures that will finish when the connection is closed/lost.
499async fn connect(
500    address: &Url,
501    proxy: Option<&Url>,
502    shared: &Arc<DealerShared>,
503) -> WsResult<(JoinHandle<()>, JoinHandle<()>)> {
504    let host = address
505        .host_str()
506        .ok_or(WsError::Url(UrlError::NoHostName))?;
507
508    let default_port = match address.scheme() {
509        "ws" => 80,
510        "wss" => 443,
511        _ => return Err(WsError::Url(UrlError::UnsupportedUrlScheme).into()),
512    };
513
514    let port = address.port().unwrap_or(default_port);
515
516    let stream = socket::connect(host, port, proxy).await?;
517
518    let (mut ws_tx, ws_rx) = tokio_tungstenite::client_async_tls(address.as_str(), stream)
519        .await?
520        .0
521        .split();
522
523    let (send_tx, mut send_rx) = mpsc::unbounded_channel::<WsMessage>();
524
525    // Spawn a task that will forward messages from the channel to the websocket.
526    let send_task = {
527        let shared = Arc::clone(shared);
528
529        tokio::spawn(async move {
530            let result = loop {
531                select! {
532                    biased;
533                    () = shared.closed() => {
534                        break Ok(None);
535                    }
536                    msg = send_rx.recv() => {
537                        if let Some(msg) = msg {
538                            // New message arrived through channel
539                            if let WsMessage::Close(close_frame) = msg {
540                                break Ok(close_frame);
541                            }
542
543                            if let Err(e) = ws_tx.feed(msg).await  {
544                                break Err(e);
545                            }
546                        } else {
547                            break Ok(None);
548                        }
549                    },
550                    e = keep_flushing(&mut ws_tx) => {
551                        break Err(e)
552                    }
553                    else => (),
554                }
555            };
556
557            send_rx.close();
558
559            // I don't trust in tokio_tungstenite's implementation of Sink::close.
560            let result = match result {
561                Ok(close_frame) => ws_tx.send(WsMessage::Close(close_frame)).await,
562                Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => ws_tx.flush().await,
563                Err(e) => {
564                    warn!("Dealer finished with an error: {e}");
565                    ws_tx.send(WsMessage::Close(None)).await
566                }
567            };
568
569            if let Err(e) = result {
570                warn!("Error while closing websocket: {e}");
571            }
572
573            debug!("Dropping send task");
574        })
575    };
576
577    let shared = Arc::clone(shared);
578
579    // A task that receives messages from the web socket.
580    let receive_task = tokio::spawn(async {
581        let pong_received = AtomicBool::new(true);
582        let send_tx = send_tx;
583        let shared = shared;
584
585        let receive_task = async {
586            let mut ws_rx = ws_rx;
587
588            loop {
589                match ws_rx.next().await {
590                    Some(Ok(msg)) => match msg {
591                        WsMessage::Text(t) => match serde_json::from_str(&t) {
592                            Ok(m) => shared.dispatch(m, &send_tx),
593                            Err(e) => warn!("Message couldn't be parsed: {e}. Message was {t}"),
594                        },
595                        WsMessage::Binary(_) => {
596                            info!("Received invalid binary message");
597                        }
598                        WsMessage::Pong(_) => {
599                            trace!("Received pong");
600                            pong_received.store(true, atomic::Ordering::Relaxed);
601                        }
602                        _ => (), // tungstenite handles Close and Ping automatically
603                    },
604                    Some(Err(e)) => {
605                        warn!("Websocket connection failed: {e}");
606                        break;
607                    }
608                    None => {
609                        debug!("Websocket connection closed.");
610                        break;
611                    }
612                }
613            }
614        };
615
616        // Sends pings and checks whether a pong comes back.
617        let ping_task = async {
618            use tokio::time::{interval, sleep};
619
620            let mut timer = interval(PING_INTERVAL);
621
622            loop {
623                timer.tick().await;
624
625                pong_received.store(false, atomic::Ordering::Relaxed);
626                if send_tx
627                    .send(WsMessage::Ping(bytes::Bytes::default()))
628                    .is_err()
629                {
630                    // The sender is closed.
631                    break;
632                }
633
634                trace!("Sent ping");
635
636                sleep(PING_TIMEOUT).await;
637
638                if !pong_received.load(atomic::Ordering::SeqCst) {
639                    // No response
640                    warn!("Websocket peer does not respond.");
641                    break;
642                }
643            }
644        };
645
646        // Exit this task as soon as one our subtasks fails.
647        // In both cases the connection is probably lost.
648        select! {
649            () = ping_task => (),
650            () = receive_task => ()
651        }
652
653        // Try to take send_task down with us, in case it's still alive.
654        let _ = send_tx.send(WsMessage::Close(None));
655
656        debug!("Dropping receive task");
657    });
658
659    Ok((send_task, receive_task))
660}
661
662/// The main background task for `Dealer`, which coordinates reconnecting.
663async fn run<F, Fut>(
664    shared: Arc<DealerShared>,
665    initial_tasks: Option<(JoinHandle<()>, JoinHandle<()>)>,
666    mut get_url: F,
667    proxy: Option<Url>,
668) -> Result<(), Error>
669where
670    Fut: Future<Output = GetUrlResult> + Send + 'static,
671    F: (FnMut() -> Fut) + Send + 'static,
672{
673    let init_task = |t| Some(TimeoutOnDrop::new(t, WEBSOCKET_CLOSE_TIMEOUT));
674
675    let mut tasks = if let Some((s, r)) = initial_tasks {
676        (init_task(s), init_task(r))
677    } else {
678        (None, None)
679    };
680
681    while !shared.is_closed() {
682        match &mut tasks {
683            (Some(t0), Some(t1)) => {
684                select! {
685                    () = shared.closed() => break,
686                    r = t0 => {
687                        if let Err(e) = r {
688                            error!("timeout on task 0: {e}");
689                        }
690                        tasks.0.take();
691                    },
692                    r = t1 => {
693                        if let Err(e) = r {
694                            error!("timeout on task 1: {e}");
695                        }
696                        tasks.1.take();
697                    }
698                }
699            }
700            _ => {
701                let url = select! {
702                    () = shared.closed() => {
703                        break
704                    },
705                    e = get_url() => e
706                }?;
707
708                match connect(&url, proxy.as_ref(), &shared).await {
709                    Ok((s, r)) => tasks = (init_task(s), init_task(r)),
710                    Err(e) => {
711                        error!("Error while connecting: {e}");
712                        tokio::time::sleep(RECONNECT_INTERVAL).await;
713                    }
714                }
715            }
716        }
717    }
718
719    let tasks = tasks.0.into_iter().chain(tasks.1);
720
721    let _ = join_all(tasks).await;
722
723    Ok(())
724}