1pub mod manager;
2mod maps;
3pub mod protocol;
4
5use std::{
6 iter,
7 pin::Pin,
8 sync::{
9 Arc,
10 atomic::{self, AtomicBool},
11 },
12 task::Poll,
13 time::Duration,
14};
15
16use futures_core::{Future, Stream};
17use futures_util::{SinkExt, StreamExt, future::join_all};
18use parking_lot::Mutex;
19use thiserror::Error;
20use tokio::{
21 select,
22 sync::{
23 Semaphore,
24 mpsc::{self, UnboundedReceiver},
25 },
26 task::JoinHandle,
27};
28use tokio_tungstenite::tungstenite;
29use tungstenite::error::UrlError;
30use url::Url;
31
32use self::{
33 maps::*,
34 protocol::{Message, MessageOrRequest, Request, WebsocketMessage, WebsocketRequest},
35};
36
37use crate::{
38 Error, socket,
39 util::{CancelOnDrop, TimeoutOnDrop, keep_flushing},
40};
41
42type WsMessage = tungstenite::Message;
43type WsError = tungstenite::Error;
44type WsResult<T> = Result<T, Error>;
45type GetUrlResult = Result<Url, Error>;
46
47impl From<WsError> for Error {
48 fn from(err: WsError) -> Self {
49 Error::failed_precondition(err)
50 }
51}
52
/// Timeout handed to `TimeoutOnDrop` for the dealer's background task handles.
const WEBSOCKET_CLOSE_TIMEOUT: Duration = Duration::from_millis(3_000);

/// Period between keep-alive pings.
const PING_INTERVAL: Duration = Duration::from_millis(30_000);
/// How long to wait for a pong after each ping before giving up on the peer.
const PING_TIMEOUT: Duration = Duration::from_millis(3_000);

/// Back-off before retrying after a failed connection attempt.
const RECONNECT_INTERVAL: Duration = Duration::from_millis(10_000);
59
60struct Response {
61 pub success: bool,
62}
63
64struct Responder {
65 key: String,
66 tx: mpsc::UnboundedSender<WsMessage>,
67 sent: bool,
68}
69
70impl Responder {
71 fn new(key: String, tx: mpsc::UnboundedSender<WsMessage>) -> Self {
72 Self {
73 key,
74 tx,
75 sent: false,
76 }
77 }
78
79 fn send_internal(&mut self, response: Response) {
81 let response = serde_json::json!({
82 "type": "reply",
83 "key": &self.key,
84 "payload": {
85 "success": response.success,
86 }
87 })
88 .to_string();
89
90 if let Err(e) = self.tx.send(WsMessage::Text(response.into())) {
91 warn!("Wasn't able to reply to dealer request: {e}");
92 }
93 }
94
95 pub fn send(mut self, response: Response) {
96 self.send_internal(response);
97 self.sent = true;
98 }
99
100 pub fn force_unanswered(mut self) {
101 self.sent = true;
102 }
103}
104
105impl Drop for Responder {
106 fn drop(&mut self) {
107 if !self.sent {
108 self.send_internal(Response { success: false });
109 }
110 }
111}
112
113trait IntoResponse {
114 fn respond(self, responder: Responder);
115}
116
117impl IntoResponse for Response {
118 fn respond(self, responder: Responder) {
119 responder.send(self)
120 }
121}
122
123impl<F> IntoResponse for F
124where
125 F: Future<Output = Response> + Send + 'static,
126{
127 fn respond(self, responder: Responder) {
128 tokio::spawn(async move {
129 responder.send(self.await);
130 });
131 }
132}
133
134impl<F, R> RequestHandler for F
135where
136 F: (Fn(Request) -> R) + Send + 'static,
137 R: IntoResponse,
138{
139 fn handle_request(&self, request: Request, responder: Responder) {
140 self(request).respond(responder);
141 }
142}
143
144trait RequestHandler: Send + 'static {
145 fn handle_request(&self, request: Request, responder: Responder);
146}
147
148type MessageHandler = mpsc::UnboundedSender<Message>;
149
150pub struct Subscription(UnboundedReceiver<Message>);
153
154impl Stream for Subscription {
155 type Item = Message;
156
157 fn poll_next(
158 mut self: Pin<&mut Self>,
159 cx: &mut std::task::Context<'_>,
160 ) -> Poll<Option<Self::Item>> {
161 self.0.poll_recv(cx)
162 }
163}
164
/// Splits a dealer uri into path components, prefixed by its scheme.
///
/// * `hm://a/b`       -> `["hm", "a", "b"]`
/// * `spotify:a:b`    -> `["spotify", "a", "b"]`
/// * `a/b` (no scheme) -> `["", "a", "b"]`
///
/// Trailing separators are ignored. Returns `None` for a uri that has no known
/// scheme and no `/`.
fn split_uri(s: &str) -> Option<impl Iterator<Item = &'_ str>> {
    let (scheme, sep, rest) = match (s.strip_prefix("hm://"), s.strip_prefix("spotify:")) {
        (Some(rest), _) => ("hm", '/', rest),
        (None, Some(rest)) => ("spotify", ':', rest),
        (None, None) if s.contains('/') => ("", '/', s),
        (None, None) => return None,
    };

    let components = rest.trim_end_matches(sep).split(sep);
    Some(iter::once(scheme).chain(components))
}
181
182#[derive(Debug, Clone, Error)]
183enum AddHandlerError {
184 #[error("There is already a handler for the given uri")]
185 AlreadyHandled,
186 #[error("The specified uri {0} is invalid")]
187 InvalidUri(String),
188}
189
190impl From<AddHandlerError> for Error {
191 fn from(err: AddHandlerError) -> Self {
192 match err {
193 AddHandlerError::AlreadyHandled => Error::aborted(err),
194 AddHandlerError::InvalidUri(_) => Error::invalid_argument(err),
195 }
196 }
197}
198
199#[derive(Debug, Clone, Error)]
200enum SubscriptionError {
201 #[error("The specified uri is invalid")]
202 InvalidUri(String),
203}
204
205impl From<SubscriptionError> for Error {
206 fn from(err: SubscriptionError) -> Self {
207 Error::invalid_argument(err)
208 }
209}
210
211fn add_handler(
212 map: &mut HandlerMap<Box<dyn RequestHandler>>,
213 uri: &str,
214 handler: impl RequestHandler,
215) -> Result<(), Error> {
216 let split = split_uri(uri).ok_or_else(|| AddHandlerError::InvalidUri(uri.to_string()))?;
217 map.insert(split, Box::new(handler))
218}
219
220fn remove_handler<T>(map: &mut HandlerMap<T>, uri: &str) -> Option<T> {
221 map.remove(split_uri(uri)?)
222}
223
224fn subscribe(
225 map: &mut SubscriberMap<MessageHandler>,
226 uris: &[&str],
227) -> Result<Subscription, Error> {
228 let (tx, rx) = mpsc::unbounded_channel();
229
230 for &uri in uris {
231 let split = split_uri(uri).ok_or_else(|| SubscriptionError::InvalidUri(uri.to_string()))?;
232 map.insert(split, tx.clone());
233 }
234
235 Ok(Subscription(rx))
236}
237
238fn handles(
239 req_map: &HandlerMap<Box<dyn RequestHandler>>,
240 msg_map: &SubscriberMap<MessageHandler>,
241 uri: &str,
242) -> bool {
243 if req_map.contains(uri) {
244 return true;
245 }
246
247 match split_uri(uri) {
248 None => false,
249 Some(mut split) => msg_map.contains(&mut split),
250 }
251}
252
/// Handlers and subscriptions collected before a `Dealer` is launched.
#[derive(Default)]
struct Builder {
    message_handlers: SubscriberMap<MessageHandler>,
    request_handlers: HandlerMap<Box<dyn RequestHandler>>,
}

/// Turns a `Builder` into a `Dealer` and spawns its background task.
///
/// `$builder` is consumed, and its handler maps move into a fresh `DealerShared`.
/// `$shared` names a binding, an `Arc` clone of that state, which `$body` can use.
/// `$body` must evaluate to the future passed to `tokio::spawn`.
///
/// `$body` is expanded in the caller's context as the argument of `tokio::spawn`.
/// Any `.await` or `?` inside it therefore runs in the caller before the task is
/// spawned.
macro_rules! create_dealer {
    ($builder:expr, $shared:ident -> $body:expr) => {
        // Bind the builder expression exactly once.
        match $builder {
            builder => {
                let shared = Arc::new(DealerShared {
                    message_handlers: Mutex::new(builder.message_handlers),
                    request_handlers: Mutex::new(builder.request_handlers),
                    // Starts with zero permits; closing it is the shutdown signal.
                    notify_drop: Semaphore::new(0),
                });

                let handle = {
                    let $shared = Arc::clone(&shared);
                    tokio::spawn($body)
                };

                Dealer {
                    shared,
                    handle: TimeoutOnDrop::new(handle, WEBSOCKET_CLOSE_TIMEOUT),
                }
            }
        }
    };
}
282
283impl Builder {
284 pub fn new() -> Self {
285 Self::default()
286 }
287
288 pub fn add_handler(&mut self, uri: &str, handler: impl RequestHandler) -> Result<(), Error> {
289 add_handler(&mut self.request_handlers, uri, handler)
290 }
291
292 pub fn subscribe(&mut self, uris: &[&str]) -> Result<Subscription, Error> {
293 subscribe(&mut self.message_handlers, uris)
294 }
295
296 pub fn handles(&self, uri: &str) -> bool {
297 handles(&self.request_handlers, &self.message_handlers, uri)
298 }
299
300 pub fn launch_in_background<Fut, F>(self, get_url: F, proxy: Option<Url>) -> Dealer
301 where
302 Fut: Future<Output = GetUrlResult> + Send + 'static,
303 F: (Fn() -> Fut) + Send + 'static,
304 {
305 create_dealer!(self, shared -> run(shared, None, get_url, proxy))
306 }
307
308 pub async fn launch<Fut, F>(self, get_url: F, proxy: Option<Url>) -> WsResult<Dealer>
309 where
310 Fut: Future<Output = GetUrlResult> + Send + 'static,
311 F: (Fn() -> Fut) + Send + 'static,
312 {
313 let dealer = create_dealer!(self, shared -> {
314 let url = get_url().await?;
316 let tasks = connect(&url, proxy.as_ref(), &shared).await?;
317
318 run(shared, Some(tasks), get_url, proxy)
320 });
321
322 Ok(dealer)
323 }
324}
325
326struct DealerShared {
327 message_handlers: Mutex<SubscriberMap<MessageHandler>>,
328 request_handlers: Mutex<HandlerMap<Box<dyn RequestHandler>>>,
329
330 notify_drop: Semaphore,
333}
334
335impl DealerShared {
336 fn dispatch_message(&self, mut msg: WebsocketMessage) {
337 let msg = match msg.handle_payload() {
338 Ok(value) => Message {
339 headers: msg.headers,
340 payload: value,
341 uri: msg.uri,
342 },
343 Err(why) => {
344 warn!("failure during data parsing for {}: {why}", msg.uri);
345 return;
346 }
347 };
348
349 if let Some(split) = split_uri(&msg.uri) {
350 if self
351 .message_handlers
352 .lock()
353 .retain(split, &mut |tx| tx.send(msg.clone()).is_ok())
354 {
355 return;
356 }
357 }
358
359 debug!("No subscriber for msg.uri: {}", msg.uri);
360 }
361
362 fn dispatch_request(
363 &self,
364 request: WebsocketRequest,
365 send_tx: &mpsc::UnboundedSender<WsMessage>,
366 ) {
367 trace!("dealer request {}", &request.message_ident);
368
369 let payload_request = match request.handle_payload() {
370 Ok(payload) => payload,
371 Err(why) => {
372 warn!("request payload handling failed because of {why}");
373 return;
374 }
375 };
376
377 let responder = Responder::new(request.key.clone(), send_tx.clone());
379
380 let split = if let Some(split) = split_uri(&request.message_ident) {
381 split
382 } else {
383 warn!(
384 "Dealer request with invalid message_ident: {}",
385 &request.message_ident
386 );
387 return;
388 };
389
390 let handler_map = self.request_handlers.lock();
391
392 if let Some(handler) = handler_map.get(split) {
393 handler.handle_request(payload_request, responder);
394 return;
395 }
396
397 warn!("No handler for message_ident: {}", &request.message_ident);
398 }
399
400 fn dispatch(&self, m: MessageOrRequest, send_tx: &mpsc::UnboundedSender<WsMessage>) {
401 match m {
402 MessageOrRequest::Message(m) => self.dispatch_message(m),
403 MessageOrRequest::Request(r) => self.dispatch_request(r, send_tx),
404 }
405 }
406
407 async fn closed(&self) {
408 if self.notify_drop.acquire().await.is_ok() {
409 error!("should never have gotten a permit");
410 }
411 }
412
413 fn is_closed(&self) -> bool {
414 self.notify_drop.is_closed()
415 }
416}
417
418struct Dealer {
419 shared: Arc<DealerShared>,
420 handle: TimeoutOnDrop<Result<(), Error>>,
421}
422
423impl Dealer {
424 pub fn add_handler<H>(&self, uri: &str, handler: H) -> Result<(), Error>
425 where
426 H: RequestHandler,
427 {
428 add_handler(&mut self.shared.request_handlers.lock(), uri, handler)
429 }
430
431 pub fn remove_handler(&self, uri: &str) -> Option<Box<dyn RequestHandler>> {
432 remove_handler(&mut self.shared.request_handlers.lock(), uri)
433 }
434
435 pub fn subscribe(&self, uris: &[&str]) -> Result<Subscription, Error> {
436 subscribe(&mut self.shared.message_handlers.lock(), uris)
437 }
438
439 pub fn handles(&self, uri: &str) -> bool {
440 handles(
441 &self.shared.request_handlers.lock(),
442 &self.shared.message_handlers.lock(),
443 uri,
444 )
445 }
446
447 pub async fn close(mut self) {
448 debug!("closing dealer");
449
450 self.shared.notify_drop.close();
451
452 if let Some(handle) = self.handle.take() {
453 if let Err(e) = CancelOnDrop(handle).await {
454 error!("error aborting dealer operations: {e}");
455 }
456 }
457 }
458}
459
/// Opens a websocket connection to `address` and spawns the two tasks that service it.
///
/// Returns `(send_task, receive_task)`. Both handles complete once the connection is
/// closed or lost.
///
/// # Errors
/// Fails if `address` has no host, if its scheme is neither `ws` nor `wss`, or if
/// either of these fails: the TCP connection (`socket::connect`, optionally through
/// `proxy`) or the websocket handshake.
async fn connect(
    address: &Url,
    proxy: Option<&Url>,
    shared: &Arc<DealerShared>,
) -> WsResult<(JoinHandle<()>, JoinHandle<()>)> {
    let host = address
        .host_str()
        .ok_or(WsError::Url(UrlError::NoHostName))?;

    // Port used when the URL does not specify one explicitly.
    let default_port = match address.scheme() {
        "ws" => 80,
        "wss" => 443,
        _ => return Err(WsError::Url(UrlError::UnsupportedUrlScheme).into()),
    };

    let port = address.port().unwrap_or(default_port);

    let stream = socket::connect(host, port, proxy).await?;

    // Keep only the websocket stream (`.0`) and split it into sink and stream halves.
    let (mut ws_tx, ws_rx) = tokio_tungstenite::client_async_tls(address.as_str(), stream)
        .await?
        .0
        .split();

    // All outgoing frames (replies, pings, the final close) go through this channel,
    // so only the send task writes to `ws_tx`.
    let (send_tx, mut send_rx) = mpsc::unbounded_channel::<WsMessage>();

    // Send task: forwards queued frames to the websocket. It stops when one of these
    // happens: the dealer closes, the channel closes, a `Close` frame is queued, or a
    // write or flush fails.
    let send_task = {
        let shared = Arc::clone(shared);

        tokio::spawn(async move {
            let result = loop {
                select! {
                    // Branches are polled top to bottom, so shutdown takes priority.
                    biased;
                    () = shared.closed() => {
                        break Ok(None);
                    }
                    msg = send_rx.recv() => {
                        if let Some(msg) = msg {
                            // A queued `Close` ends the loop; the frame itself is sent
                            // during shutdown below.
                            if let WsMessage::Close(close_frame) = msg {
                                break Ok(close_frame);
                            }

                            // `feed` does not flush. The `keep_flushing` branch below
                            // presumably drives flushing.
                            if let Err(e) = ws_tx.feed(msg).await {
                                break Err(e);
                            }
                        } else {
                            // Every sender has been dropped.
                            break Ok(None);
                        }
                    },
                    e = keep_flushing(&mut ws_tx) => {
                        break Err(e)
                    }
                    else => (),
                }
            };

            send_rx.close();

            // The close handshake is started by `send`ing a `Close` frame, not by
            // calling `SinkExt::close`.
            let result = match result {
                Ok(close_frame) => ws_tx.send(WsMessage::Close(close_frame)).await,
                // The connection is already gone; just flush what is left.
                Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => ws_tx.flush().await,
                Err(e) => {
                    warn!("Dealer finished with an error: {e}");
                    ws_tx.send(WsMessage::Close(None)).await
                }
            };

            if let Err(e) = result {
                warn!("Error while closing websocket: {e}");
            }

            debug!("Dropping send task");
        })
    };

    let shared = Arc::clone(shared);

    // Receive task: dispatches incoming frames and runs the keep-alive ping loop. It
    // ends as soon as either of the two stops.
    let receive_task = tokio::spawn(async {
        // Set by the receive loop on each pong, cleared by the ping loop before each
        // ping. Both accesses happen on this single task (inside one `select!`), so the
        // memory orderings used below are not load-bearing.
        let pong_received = AtomicBool::new(true);
        // Rebinding moves these values into the (non-`move`) async block, so the inner
        // futures below can borrow them.
        let send_tx = send_tx;
        let shared = shared;

        let receive_task = async {
            let mut ws_rx = ws_rx;

            loop {
                match ws_rx.next().await {
                    Some(Ok(msg)) => match msg {
                        WsMessage::Text(t) => match serde_json::from_str(&t) {
                            Ok(m) => shared.dispatch(m, &send_tx),
                            Err(e) => warn!("Message couldn't be parsed: {e}. Message was {t}"),
                        },
                        WsMessage::Binary(_) => {
                            info!("Received invalid binary message");
                        }
                        WsMessage::Pong(_) => {
                            trace!("Received pong");
                            pong_received.store(true, atomic::Ordering::Relaxed);
                        }
                        // Other frame kinds (ping, close, raw frames) are ignored here;
                        // presumably tungstenite handles ping/close itself.
                        _ => (),
                    },
                    Some(Err(e)) => {
                        warn!("Websocket connection failed: {e}");
                        break;
                    }
                    None => {
                        debug!("Websocket connection closed.");
                        break;
                    }
                }
            }
        };

        // Keep-alive: ping every `PING_INTERVAL`, and give up if no pong has arrived
        // within `PING_TIMEOUT` of a ping.
        let ping_task = async {
            use tokio::time::{interval, sleep};

            let mut timer = interval(PING_INTERVAL);

            loop {
                // A tokio `interval` completes its first tick immediately, so the first
                // ping goes out right after connecting.
                timer.tick().await;

                pong_received.store(false, atomic::Ordering::Relaxed);
                if send_tx
                    .send(WsMessage::Ping(bytes::Bytes::default()))
                    .is_err()
                {
                    // The send task has closed or dropped its receiver.
                    break;
                }

                trace!("Sent ping");

                sleep(PING_TIMEOUT).await;

                if !pong_received.load(atomic::Ordering::SeqCst) {
                    warn!("Websocket peer does not respond.");
                    break;
                }
            }
        };

        // Whichever sub-future finishes first means the connection is unusable; the
        // other one is dropped.
        select! {
            () = ping_task => (),
            () = receive_task => ()
        }

        // Ask the send task to close the websocket as well, in case it is still running.
        let _ = send_tx.send(WsMessage::Close(None));

        debug!("Dropping receive task");
    });

    Ok((send_task, receive_task))
}
623
624async fn run<F, Fut>(
626 shared: Arc<DealerShared>,
627 initial_tasks: Option<(JoinHandle<()>, JoinHandle<()>)>,
628 mut get_url: F,
629 proxy: Option<Url>,
630) -> Result<(), Error>
631where
632 Fut: Future<Output = GetUrlResult> + Send + 'static,
633 F: (FnMut() -> Fut) + Send + 'static,
634{
635 let init_task = |t| Some(TimeoutOnDrop::new(t, WEBSOCKET_CLOSE_TIMEOUT));
636
637 let mut tasks = if let Some((s, r)) = initial_tasks {
638 (init_task(s), init_task(r))
639 } else {
640 (None, None)
641 };
642
643 while !shared.is_closed() {
644 match &mut tasks {
645 (Some(t0), Some(t1)) => {
646 select! {
647 () = shared.closed() => break,
648 r = t0 => {
649 if let Err(e) = r {
650 error!("timeout on task 0: {e}");
651 }
652 tasks.0.take();
653 },
654 r = t1 => {
655 if let Err(e) = r {
656 error!("timeout on task 1: {e}");
657 }
658 tasks.1.take();
659 }
660 }
661 }
662 _ => {
663 let url = select! {
664 () = shared.closed() => {
665 break
666 },
667 e = get_url() => e
668 }?;
669
670 match connect(&url, proxy.as_ref(), &shared).await {
671 Ok((s, r)) => tasks = (init_task(s), init_task(r)),
672 Err(e) => {
673 error!("Error while connecting: {e}");
674 tokio::time::sleep(RECONNECT_INTERVAL).await;
675 }
676 }
677 }
678 }
679 }
680
681 let tasks = tasks.0.into_iter().chain(tasks.1);
682
683 let _ = join_all(tasks).await;
684
685 Ok(())
686}