librespot_core/mercury/mod.rs

1use std::{
2    collections::HashMap,
3    future::Future,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use byteorder::{BigEndian, ByteOrder};
9use bytes::Bytes;
10use futures_util::FutureExt;
11use protobuf::Message;
12use tokio::sync::{mpsc, oneshot};
13
14use crate::{Error, packet::PacketType, protocol, util::SeqGenerator};
15
16mod types;
17pub use self::types::*;
18
19mod sender;
20pub use self::sender::MercurySender;
21
// Shared state behind `MercuryManager`: sequence allocation, in-flight requests and
// event subscriptions.
component! {
    MercuryManager : MercuryManagerInner {
        // Source of request sequence numbers; encoded to bytes by `next_seq`.
        sequence: SeqGenerator<u64> = SeqGenerator::new(0),
        // Requests/events whose parts are still being received, keyed by the encoded sequence.
        pending: HashMap<Vec<u8>, MercuryPending> = HashMap::new(),
        // (URI prefix, channel) pairs; pushed events whose URI starts with the prefix are forwarded.
        subscriptions: Vec<(String, mpsc::UnboundedSender<MercuryResponse>)> = Vec::new(),
        // Set by `shutdown`; once true, `request`, `subscribe` and `listen_for` stop registering state.
        invalid: bool = false,
    }
}
30
31pub struct MercuryPending {
32    parts: Vec<Vec<u8>>,
33    partial: Option<Vec<u8>>,
34    callback: Option<oneshot::Sender<Result<MercuryResponse, Error>>>,
35}
36
37pub struct MercuryFuture<T> {
38    receiver: oneshot::Receiver<Result<T, Error>>,
39}
40
41impl<T> Future for MercuryFuture<T> {
42    type Output = Result<T, Error>;
43
44    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
45        self.receiver.poll_unpin(cx)?
46    }
47}
48
49impl MercuryManager {
50    fn next_seq(&self) -> Vec<u8> {
51        let mut seq = vec![0u8; 8];
52        BigEndian::write_u64(&mut seq, self.lock(|inner| inner.sequence.get()));
53        seq
54    }
55
56    fn request(&self, req: MercuryRequest) -> Result<MercuryFuture<MercuryResponse>, Error> {
57        let (tx, rx) = oneshot::channel();
58
59        let pending = MercuryPending {
60            parts: Vec::new(),
61            partial: None,
62            callback: Some(tx),
63        };
64
65        let seq = self.next_seq();
66        self.lock(|inner| {
67            if !inner.invalid {
68                inner.pending.insert(seq.clone(), pending);
69            }
70        });
71
72        let cmd = req.method.command();
73        let data = req.encode(&seq)?;
74
75        self.session().send_packet(cmd, data)?;
76        Ok(MercuryFuture { receiver: rx })
77    }
78
79    pub fn get<T: Into<String>>(&self, uri: T) -> Result<MercuryFuture<MercuryResponse>, Error> {
80        self.request(MercuryRequest {
81            method: MercuryMethod::Get,
82            uri: uri.into(),
83            content_type: None,
84            payload: Vec::new(),
85        })
86    }
87
88    pub fn send<T: Into<String>>(
89        &self,
90        uri: T,
91        data: Vec<u8>,
92    ) -> Result<MercuryFuture<MercuryResponse>, Error> {
93        self.request(MercuryRequest {
94            method: MercuryMethod::Send,
95            uri: uri.into(),
96            content_type: None,
97            payload: vec![data],
98        })
99    }
100
101    pub fn sender<T: Into<String>>(&self, uri: T) -> MercurySender {
102        MercurySender::new(self.clone(), uri.into())
103    }
104
105    pub fn subscribe<T: Into<String>>(
106        &self,
107        uri: T,
108    ) -> impl Future<Output = Result<mpsc::UnboundedReceiver<MercuryResponse>, Error>> + 'static
109    {
110        let uri = uri.into();
111        let request = self.request(MercuryRequest {
112            method: MercuryMethod::Sub,
113            uri: uri.clone(),
114            content_type: None,
115            payload: Vec::new(),
116        });
117
118        let manager = self.clone();
119        async move {
120            let response = request?.await?;
121
122            let (tx, rx) = mpsc::unbounded_channel();
123
124            manager.lock(move |inner| {
125                if !inner.invalid {
126                    debug!("subscribed uri={} count={}", uri, response.payload.len());
127                    if !response.payload.is_empty() {
128                        // Old subscription protocol, watch the provided list of URIs
129                        for sub in response.payload {
130                            match protocol::pubsub::Subscription::parse_from_bytes(&sub) {
131                                Ok(mut sub) => {
132                                    let sub_uri = sub.take_uri();
133
134                                    debug!("subscribed sub_uri={sub_uri}");
135
136                                    inner.subscriptions.push((sub_uri, tx.clone()));
137                                }
138                                Err(e) => {
139                                    error!("could not subscribe to {uri}: {e}");
140                                }
141                            }
142                        }
143                    } else {
144                        // New subscription protocol, watch the requested URI
145                        inner.subscriptions.push((uri, tx));
146                    }
147                }
148            });
149
150            Ok(rx)
151        }
152    }
153
154    pub fn listen_for<T: Into<String>>(
155        &self,
156        uri: T,
157    ) -> impl Future<Output = mpsc::UnboundedReceiver<MercuryResponse>> + 'static {
158        let uri = uri.into();
159
160        let manager = self.clone();
161        async move {
162            let (tx, rx) = mpsc::unbounded_channel();
163
164            manager.lock(move |inner| {
165                if !inner.invalid {
166                    debug!("listening to uri={uri}");
167                    inner.subscriptions.push((uri, tx));
168                }
169            });
170
171            rx
172        }
173    }
174
175    pub(crate) fn dispatch(&self, cmd: PacketType, mut data: Bytes) -> Result<(), Error> {
176        let seq_len = BigEndian::read_u16(data.split_to(2).as_ref()) as usize;
177        let seq = data.split_to(seq_len).as_ref().to_owned();
178
179        let flags = data.split_to(1).as_ref()[0];
180        let count = BigEndian::read_u16(data.split_to(2).as_ref()) as usize;
181
182        let pending = self.lock(|inner| inner.pending.remove(&seq));
183
184        let mut pending = match pending {
185            Some(pending) => pending,
186            None => {
187                if let PacketType::MercuryEvent = cmd {
188                    MercuryPending {
189                        parts: Vec::new(),
190                        partial: None,
191                        callback: None,
192                    }
193                } else {
194                    warn!("Ignore seq {:?} cmd {:x}", seq, cmd as u8);
195                    return Err(MercuryError::Command(cmd).into());
196                }
197            }
198        };
199
200        for i in 0..count {
201            let mut part = Self::parse_part(&mut data);
202            if let Some(mut partial) = pending.partial.take() {
203                partial.extend_from_slice(&part);
204                part = partial;
205            }
206
207            if i == count - 1 && (flags == 2) {
208                pending.partial = Some(part)
209            } else {
210                pending.parts.push(part);
211            }
212        }
213
214        if flags == 0x1 {
215            self.complete_request(cmd, pending)?;
216        } else {
217            self.lock(move |inner| inner.pending.insert(seq, pending));
218        }
219
220        Ok(())
221    }
222
223    fn parse_part(data: &mut Bytes) -> Vec<u8> {
224        let size = BigEndian::read_u16(data.split_to(2).as_ref()) as usize;
225        data.split_to(size).as_ref().to_owned()
226    }
227
228    fn complete_request(&self, cmd: PacketType, mut pending: MercuryPending) -> Result<(), Error> {
229        let header_data = pending.parts.remove(0);
230        let header = protocol::mercury::Header::parse_from_bytes(&header_data)?;
231
232        let response = MercuryResponse {
233            uri: header.uri().to_string(),
234            status_code: header.status_code(),
235            payload: pending.parts,
236        };
237
238        let status_code = response.status_code;
239        if status_code >= 500 {
240            error!("error {} for uri {}", status_code, &response.uri);
241            Err(MercuryError::Response(response).into())
242        } else if status_code >= 400 {
243            error!("error {} for uri {}", status_code, &response.uri);
244            if let Some(cb) = pending.callback {
245                cb.send(Err(MercuryError::Response(response.clone()).into()))
246                    .map_err(|_| MercuryError::Channel)?;
247            }
248            Err(MercuryError::Response(response).into())
249        } else if let PacketType::MercuryEvent = cmd {
250            // TODO: This is just a workaround to make utf-8 encoded usernames work.
251            // A better solution would be to use an uri struct and urlencode it directly
252            // before sending while saving the subscription under its unencoded form.
253            let mut uri_split = response.uri.split('/');
254
255            let encoded_uri = std::iter::once(uri_split.next().unwrap_or_default().to_string())
256                .chain(uri_split.map(|component| {
257                    form_urlencoded::byte_serialize(component.as_bytes()).collect::<String>()
258                }))
259                .collect::<Vec<String>>()
260                .join("/");
261
262            let mut found = false;
263
264            self.lock(|inner| {
265                inner.subscriptions.retain(|(prefix, sub)| {
266                    if encoded_uri.starts_with(prefix) {
267                        found = true;
268
269                        // if send fails, remove from list of subs
270                        // TODO: send unsub message
271                        sub.send(response.clone()).is_ok()
272                    } else {
273                        // URI doesn't match
274                        true
275                    }
276                });
277            });
278
279            if found {
280                Ok(())
281            } else if self.session().dealer().handles(&response.uri) {
282                trace!("mercury response <{}> is handled by dealer", response.uri);
283                Ok(())
284            } else {
285                debug!("unknown subscription uri={}", &response.uri);
286                trace!("response pushed over Mercury: {response:?}");
287                Err(MercuryError::Response(response).into())
288            }
289        } else if let Some(cb) = pending.callback {
290            cb.send(Ok(response)).map_err(|_| MercuryError::Channel)?;
291            Ok(())
292        } else {
293            error!("can't handle Mercury response: {response:?}");
294            Err(MercuryError::Response(response).into())
295        }
296    }
297
298    pub(crate) fn shutdown(&self) {
299        self.lock(|inner| {
300            inner.invalid = true;
301            // destroy the sending halves of the channels to signal everyone who is waiting for something.
302            inner.pending.clear();
303            inner.subscriptions.clear();
304        });
305    }
306}