1pub mod manager;
2mod maps;
3pub mod protocol;
4
5use std::{
6 iter,
7 pin::Pin,
8 sync::{
9 Arc, Mutex,
10 atomic::{self, AtomicBool},
11 },
12 task::Poll,
13 time::Duration,
14};
15
16use futures_core::{Future, Stream};
17use futures_util::{SinkExt, StreamExt, future::join_all};
18use thiserror::Error;
19use tokio::{
20 select,
21 sync::{
22 Semaphore,
23 mpsc::{self, UnboundedReceiver},
24 },
25 task::JoinHandle,
26};
27use tokio_tungstenite::tungstenite;
28use tungstenite::error::UrlError;
29use url::Url;
30
31use self::{
32 maps::*,
33 protocol::{Message, MessageOrRequest, Request, WebsocketMessage, WebsocketRequest},
34};
35
36use crate::{
37 Error, socket,
38 util::{CancelOnDrop, TimeoutOnDrop, keep_flushing},
39};
40
41type WsMessage = tungstenite::Message;
42type WsError = tungstenite::Error;
43type WsResult<T> = Result<T, Error>;
44type GetUrlResult = Result<Url, Error>;
45
46impl From<WsError> for Error {
47 fn from(err: WsError) -> Self {
48 Error::failed_precondition(err)
49 }
50}
51
/// Timeout handed to `TimeoutOnDrop` for the dealer's background tasks.
const WEBSOCKET_CLOSE_TIMEOUT: Duration = Duration::from_secs(3);

/// Period of the keep-alive ping timer.
const PING_INTERVAL: Duration = Duration::from_secs(30);
/// How long to wait after sending a ping before checking that a pong arrived.
const PING_TIMEOUT: Duration = Duration::from_secs(3);

/// Pause after a failed connection attempt before the next one.
const RECONNECT_INTERVAL: Duration = Duration::from_secs(10);

/// Panic message used when locking the request handler map fails.
const DEALER_REQUEST_HANDLERS_POISON_MSG: &str =
    "dealer request handlers mutex should not be poisoned";
/// Panic message used when locking the message handler map fails.
const DEALER_MESSAGE_HANDLERS_POISON_MSG: &str =
    "dealer message handlers mutex should not be poisoned";
63
/// Answer to a dealer request; serialized into the `payload` of a `reply` frame.
struct Response {
    /// Whether the request was handled successfully.
    pub success: bool,
}
67
68struct Responder {
69 key: String,
70 tx: mpsc::UnboundedSender<WsMessage>,
71 sent: bool,
72}
73
74impl Responder {
75 fn new(key: String, tx: mpsc::UnboundedSender<WsMessage>) -> Self {
76 Self {
77 key,
78 tx,
79 sent: false,
80 }
81 }
82
83 fn send_internal(&mut self, response: Response) {
85 let response = serde_json::json!({
86 "type": "reply",
87 "key": &self.key,
88 "payload": {
89 "success": response.success,
90 }
91 })
92 .to_string();
93
94 if let Err(e) = self.tx.send(WsMessage::Text(response.into())) {
95 warn!("Wasn't able to reply to dealer request: {e}");
96 }
97 }
98
99 pub fn send(mut self, response: Response) {
100 self.send_internal(response);
101 self.sent = true;
102 }
103
104 pub fn force_unanswered(mut self) {
105 self.sent = true;
106 }
107}
108
109impl Drop for Responder {
110 fn drop(&mut self) {
111 if !self.sent {
112 self.send_internal(Response { success: false });
113 }
114 }
115}
116
117trait IntoResponse {
118 fn respond(self, responder: Responder);
119}
120
121impl IntoResponse for Response {
122 fn respond(self, responder: Responder) {
123 responder.send(self)
124 }
125}
126
127impl<F> IntoResponse for F
128where
129 F: Future<Output = Response> + Send + 'static,
130{
131 fn respond(self, responder: Responder) {
132 tokio::spawn(async move {
133 responder.send(self.await);
134 });
135 }
136}
137
138impl<F, R> RequestHandler for F
139where
140 F: (Fn(Request) -> R) + Send + 'static,
141 R: IntoResponse,
142{
143 fn handle_request(&self, request: Request, responder: Responder) {
144 self(request).respond(responder);
145 }
146}
147
148trait RequestHandler: Send + 'static {
149 fn handle_request(&self, request: Request, responder: Responder);
150}
151
152type MessageHandler = mpsc::UnboundedSender<Message>;
153
154pub struct Subscription(UnboundedReceiver<Message>);
157
158impl Stream for Subscription {
159 type Item = Message;
160
161 fn poll_next(
162 mut self: Pin<&mut Self>,
163 cx: &mut std::task::Context<'_>,
164 ) -> Poll<Option<Self::Item>> {
165 self.0.poll_recv(cx)
166 }
167}
168
/// Splits a dealer URI into its scheme followed by its path segments.
///
/// * `hm://a/b`      -> `"hm", "a", "b"` (split on `/`)
/// * `spotify:a:b`   -> `"spotify", "a", "b"` (split on `:`)
/// * `a/b` (no scheme, but contains `/`) -> `"", "a", "b"`
///
/// Trailing separators are ignored. Returns `None` when the input matches none
/// of the forms above.
fn split_uri(s: &str) -> Option<impl Iterator<Item = &'_ str>> {
    let (scheme, sep, tail) = s
        .strip_prefix("hm://")
        .map(|tail| ("hm", '/', tail))
        .or_else(|| s.strip_prefix("spotify:").map(|tail| ("spotify", ':', tail)))
        .or_else(|| s.contains('/').then_some(("", '/', s)))?;

    let segments = tail.trim_end_matches(sep).split(sep);

    Some(iter::once(scheme).chain(segments))
}
185
186#[derive(Debug, Clone, Error)]
187enum AddHandlerError {
188 #[error("There is already a handler for the given uri")]
189 AlreadyHandled,
190 #[error("The specified uri {0} is invalid")]
191 InvalidUri(String),
192}
193
194impl From<AddHandlerError> for Error {
195 fn from(err: AddHandlerError) -> Self {
196 match err {
197 AddHandlerError::AlreadyHandled => Error::aborted(err),
198 AddHandlerError::InvalidUri(_) => Error::invalid_argument(err),
199 }
200 }
201}
202
203#[derive(Debug, Clone, Error)]
204enum SubscriptionError {
205 #[error("The specified uri is invalid")]
206 InvalidUri(String),
207}
208
209impl From<SubscriptionError> for Error {
210 fn from(err: SubscriptionError) -> Self {
211 Error::invalid_argument(err)
212 }
213}
214
215fn add_handler(
216 map: &mut HandlerMap<Box<dyn RequestHandler>>,
217 uri: &str,
218 handler: impl RequestHandler,
219) -> Result<(), Error> {
220 let split = split_uri(uri).ok_or_else(|| AddHandlerError::InvalidUri(uri.to_string()))?;
221 map.insert(split, Box::new(handler))
222}
223
224fn remove_handler<T>(map: &mut HandlerMap<T>, uri: &str) -> Option<T> {
225 map.remove(split_uri(uri)?)
226}
227
228fn subscribe(
229 map: &mut SubscriberMap<MessageHandler>,
230 uris: &[&str],
231) -> Result<Subscription, Error> {
232 let (tx, rx) = mpsc::unbounded_channel();
233
234 for &uri in uris {
235 let split = split_uri(uri).ok_or_else(|| SubscriptionError::InvalidUri(uri.to_string()))?;
236 map.insert(split, tx.clone());
237 }
238
239 Ok(Subscription(rx))
240}
241
242fn handles(
243 req_map: &HandlerMap<Box<dyn RequestHandler>>,
244 msg_map: &SubscriberMap<MessageHandler>,
245 uri: &str,
246) -> bool {
247 if req_map.contains(uri) {
248 return true;
249 }
250
251 match split_uri(uri) {
252 None => false,
253 Some(mut split) => msg_map.contains(&mut split),
254 }
255}
256
257#[derive(Default)]
258struct Builder {
259 message_handlers: SubscriberMap<MessageHandler>,
260 request_handlers: HandlerMap<Box<dyn RequestHandler>>,
261}
262
/// Turns a [`Builder`] into a running [`Dealer`].
///
/// `$builder` is consumed to build the shared state, `$shared` names a clone of
/// that state that is visible inside `$body`, and the value `$body` evaluates to
/// is spawned on the tokio runtime. Note that `$body` itself is evaluated at the
/// call site, so any `.await` or `?` in it runs in the caller before spawning.
macro_rules! create_dealer {
    ($builder:expr, $shared:ident -> $body:expr) => {
        match $builder {
            parts => {
                let state = Arc::new(DealerShared {
                    message_handlers: Mutex::new(parts.message_handlers),
                    request_handlers: Mutex::new(parts.request_handlers),
                    // No permits are ever added; the semaphore only signals "closed".
                    notify_drop: Semaphore::new(0),
                });

                let task = {
                    let $shared = Arc::clone(&state);
                    tokio::spawn($body)
                };

                Dealer {
                    shared: state,
                    handle: TimeoutOnDrop::new(task, WEBSOCKET_CLOSE_TIMEOUT),
                }
            }
        }
    };
}
286
287impl Builder {
288 pub fn new() -> Self {
289 Self::default()
290 }
291
292 pub fn add_handler(&mut self, uri: &str, handler: impl RequestHandler) -> Result<(), Error> {
293 add_handler(&mut self.request_handlers, uri, handler)
294 }
295
296 pub fn subscribe(&mut self, uris: &[&str]) -> Result<Subscription, Error> {
297 subscribe(&mut self.message_handlers, uris)
298 }
299
300 pub fn handles(&self, uri: &str) -> bool {
301 handles(&self.request_handlers, &self.message_handlers, uri)
302 }
303
304 pub fn launch_in_background<Fut, F>(self, get_url: F, proxy: Option<Url>) -> Dealer
305 where
306 Fut: Future<Output = GetUrlResult> + Send + 'static,
307 F: (Fn() -> Fut) + Send + 'static,
308 {
309 create_dealer!(self, shared -> run(shared, None, get_url, proxy))
310 }
311
312 pub async fn launch<Fut, F>(self, get_url: F, proxy: Option<Url>) -> WsResult<Dealer>
313 where
314 Fut: Future<Output = GetUrlResult> + Send + 'static,
315 F: (Fn() -> Fut) + Send + 'static,
316 {
317 let dealer = create_dealer!(self, shared -> {
318 let url = get_url().await?;
320 let tasks = connect(&url, proxy.as_ref(), &shared).await?;
321
322 run(shared, Some(tasks), get_url, proxy)
324 });
325
326 Ok(dealer)
327 }
328}
329
330struct DealerShared {
331 message_handlers: Mutex<SubscriberMap<MessageHandler>>,
332 request_handlers: Mutex<HandlerMap<Box<dyn RequestHandler>>>,
333
334 notify_drop: Semaphore,
337}
338
339impl DealerShared {
340 fn dispatch_message(&self, mut msg: WebsocketMessage) {
341 let msg = match msg.handle_payload() {
342 Ok(value) => Message {
343 headers: msg.headers,
344 payload: value,
345 uri: msg.uri,
346 },
347 Err(why) => {
348 warn!("failure during data parsing for {}: {why}", msg.uri);
349 return;
350 }
351 };
352
353 if let Some(split) = split_uri(&msg.uri) {
354 if self
355 .message_handlers
356 .lock()
357 .expect(DEALER_MESSAGE_HANDLERS_POISON_MSG)
358 .retain(split, &mut |tx| tx.send(msg.clone()).is_ok())
359 {
360 return;
361 }
362 }
363
364 debug!("No subscriber for msg.uri: {}", msg.uri);
365 }
366
367 fn dispatch_request(
368 &self,
369 request: WebsocketRequest,
370 send_tx: &mpsc::UnboundedSender<WsMessage>,
371 ) {
372 trace!("dealer request {}", &request.message_ident);
373
374 let payload_request = match request.handle_payload() {
375 Ok(payload) => payload,
376 Err(why) => {
377 warn!("request payload handling failed because of {why}");
378 return;
379 }
380 };
381
382 let responder = Responder::new(request.key.clone(), send_tx.clone());
384
385 let split = if let Some(split) = split_uri(&request.message_ident) {
386 split
387 } else {
388 warn!(
389 "Dealer request with invalid message_ident: {}",
390 &request.message_ident
391 );
392 return;
393 };
394
395 let handler_map = self
396 .request_handlers
397 .lock()
398 .expect(DEALER_REQUEST_HANDLERS_POISON_MSG);
399
400 if let Some(handler) = handler_map.get(split) {
401 handler.handle_request(payload_request, responder);
402 return;
403 }
404
405 warn!("No handler for message_ident: {}", &request.message_ident);
406 }
407
408 fn dispatch(&self, m: MessageOrRequest, send_tx: &mpsc::UnboundedSender<WsMessage>) {
409 match m {
410 MessageOrRequest::Message(m) => self.dispatch_message(m),
411 MessageOrRequest::Request(r) => self.dispatch_request(r, send_tx),
412 }
413 }
414
415 async fn closed(&self) {
416 if self.notify_drop.acquire().await.is_ok() {
417 error!("should never have gotten a permit");
418 }
419 }
420
421 fn is_closed(&self) -> bool {
422 self.notify_drop.is_closed()
423 }
424}
425
426struct Dealer {
427 shared: Arc<DealerShared>,
428 handle: TimeoutOnDrop<Result<(), Error>>,
429}
430
431impl Dealer {
432 pub fn add_handler<H>(&self, uri: &str, handler: H) -> Result<(), Error>
433 where
434 H: RequestHandler,
435 {
436 add_handler(
437 &mut self
438 .shared
439 .request_handlers
440 .lock()
441 .expect(DEALER_REQUEST_HANDLERS_POISON_MSG),
442 uri,
443 handler,
444 )
445 }
446
447 pub fn remove_handler(&self, uri: &str) -> Option<Box<dyn RequestHandler>> {
448 remove_handler(
449 &mut self
450 .shared
451 .request_handlers
452 .lock()
453 .expect(DEALER_REQUEST_HANDLERS_POISON_MSG),
454 uri,
455 )
456 }
457
458 pub fn subscribe(&self, uris: &[&str]) -> Result<Subscription, Error> {
459 subscribe(
460 &mut self
461 .shared
462 .message_handlers
463 .lock()
464 .expect(DEALER_MESSAGE_HANDLERS_POISON_MSG),
465 uris,
466 )
467 }
468
469 pub fn handles(&self, uri: &str) -> bool {
470 handles(
471 &self
472 .shared
473 .request_handlers
474 .lock()
475 .expect(DEALER_REQUEST_HANDLERS_POISON_MSG),
476 &self
477 .shared
478 .message_handlers
479 .lock()
480 .expect(DEALER_MESSAGE_HANDLERS_POISON_MSG),
481 uri,
482 )
483 }
484
485 pub async fn close(mut self) {
486 debug!("closing dealer");
487
488 self.shared.notify_drop.close();
489
490 if let Some(handle) = self.handle.take() {
491 if let Err(e) = CancelOnDrop(handle).await {
492 error!("error aborting dealer operations: {e}");
493 }
494 }
495 }
496}
497
498async fn connect(
500 address: &Url,
501 proxy: Option<&Url>,
502 shared: &Arc<DealerShared>,
503) -> WsResult<(JoinHandle<()>, JoinHandle<()>)> {
504 let host = address
505 .host_str()
506 .ok_or(WsError::Url(UrlError::NoHostName))?;
507
508 let default_port = match address.scheme() {
509 "ws" => 80,
510 "wss" => 443,
511 _ => return Err(WsError::Url(UrlError::UnsupportedUrlScheme).into()),
512 };
513
514 let port = address.port().unwrap_or(default_port);
515
516 let stream = socket::connect(host, port, proxy).await?;
517
518 let (mut ws_tx, ws_rx) = tokio_tungstenite::client_async_tls(address.as_str(), stream)
519 .await?
520 .0
521 .split();
522
523 let (send_tx, mut send_rx) = mpsc::unbounded_channel::<WsMessage>();
524
525 let send_task = {
527 let shared = Arc::clone(shared);
528
529 tokio::spawn(async move {
530 let result = loop {
531 select! {
532 biased;
533 () = shared.closed() => {
534 break Ok(None);
535 }
536 msg = send_rx.recv() => {
537 if let Some(msg) = msg {
538 if let WsMessage::Close(close_frame) = msg {
540 break Ok(close_frame);
541 }
542
543 if let Err(e) = ws_tx.feed(msg).await {
544 break Err(e);
545 }
546 } else {
547 break Ok(None);
548 }
549 },
550 e = keep_flushing(&mut ws_tx) => {
551 break Err(e)
552 }
553 else => (),
554 }
555 };
556
557 send_rx.close();
558
559 let result = match result {
561 Ok(close_frame) => ws_tx.send(WsMessage::Close(close_frame)).await,
562 Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => ws_tx.flush().await,
563 Err(e) => {
564 warn!("Dealer finished with an error: {e}");
565 ws_tx.send(WsMessage::Close(None)).await
566 }
567 };
568
569 if let Err(e) = result {
570 warn!("Error while closing websocket: {e}");
571 }
572
573 debug!("Dropping send task");
574 })
575 };
576
577 let shared = Arc::clone(shared);
578
579 let receive_task = tokio::spawn(async {
581 let pong_received = AtomicBool::new(true);
582 let send_tx = send_tx;
583 let shared = shared;
584
585 let receive_task = async {
586 let mut ws_rx = ws_rx;
587
588 loop {
589 match ws_rx.next().await {
590 Some(Ok(msg)) => match msg {
591 WsMessage::Text(t) => match serde_json::from_str(&t) {
592 Ok(m) => shared.dispatch(m, &send_tx),
593 Err(e) => warn!("Message couldn't be parsed: {e}. Message was {t}"),
594 },
595 WsMessage::Binary(_) => {
596 info!("Received invalid binary message");
597 }
598 WsMessage::Pong(_) => {
599 trace!("Received pong");
600 pong_received.store(true, atomic::Ordering::Relaxed);
601 }
602 _ => (), },
604 Some(Err(e)) => {
605 warn!("Websocket connection failed: {e}");
606 break;
607 }
608 None => {
609 debug!("Websocket connection closed.");
610 break;
611 }
612 }
613 }
614 };
615
616 let ping_task = async {
618 use tokio::time::{interval, sleep};
619
620 let mut timer = interval(PING_INTERVAL);
621
622 loop {
623 timer.tick().await;
624
625 pong_received.store(false, atomic::Ordering::Relaxed);
626 if send_tx
627 .send(WsMessage::Ping(bytes::Bytes::default()))
628 .is_err()
629 {
630 break;
632 }
633
634 trace!("Sent ping");
635
636 sleep(PING_TIMEOUT).await;
637
638 if !pong_received.load(atomic::Ordering::SeqCst) {
639 warn!("Websocket peer does not respond.");
641 break;
642 }
643 }
644 };
645
646 select! {
649 () = ping_task => (),
650 () = receive_task => ()
651 }
652
653 let _ = send_tx.send(WsMessage::Close(None));
655
656 debug!("Dropping receive task");
657 });
658
659 Ok((send_task, receive_task))
660}
661
662async fn run<F, Fut>(
664 shared: Arc<DealerShared>,
665 initial_tasks: Option<(JoinHandle<()>, JoinHandle<()>)>,
666 mut get_url: F,
667 proxy: Option<Url>,
668) -> Result<(), Error>
669where
670 Fut: Future<Output = GetUrlResult> + Send + 'static,
671 F: (FnMut() -> Fut) + Send + 'static,
672{
673 let init_task = |t| Some(TimeoutOnDrop::new(t, WEBSOCKET_CLOSE_TIMEOUT));
674
675 let mut tasks = if let Some((s, r)) = initial_tasks {
676 (init_task(s), init_task(r))
677 } else {
678 (None, None)
679 };
680
681 while !shared.is_closed() {
682 match &mut tasks {
683 (Some(t0), Some(t1)) => {
684 select! {
685 () = shared.closed() => break,
686 r = t0 => {
687 if let Err(e) = r {
688 error!("timeout on task 0: {e}");
689 }
690 tasks.0.take();
691 },
692 r = t1 => {
693 if let Err(e) = r {
694 error!("timeout on task 1: {e}");
695 }
696 tasks.1.take();
697 }
698 }
699 }
700 _ => {
701 let url = select! {
702 () = shared.closed() => {
703 break
704 },
705 e = get_url() => e
706 }?;
707
708 match connect(&url, proxy.as_ref(), &shared).await {
709 Ok((s, r)) => tasks = (init_task(s), init_task(r)),
710 Err(e) => {
711 error!("Error while connecting: {e}");
712 tokio::time::sleep(RECONNECT_INTERVAL).await;
713 }
714 }
715 }
716 }
717 }
718
719 let tasks = tasks.0.into_iter().chain(tasks.1);
720
721 let _ = join_all(tasks).await;
722
723 Ok(())
724}