librespot_core/
channel.rs

1use std::{
2    collections::HashMap,
3    fmt,
4    pin::Pin,
5    task::{Context, Poll},
6    time::{Duration, Instant},
7};
8
9use byteorder::{BigEndian, ByteOrder};
10use bytes::Bytes;
11use futures_core::Stream;
12use futures_util::{StreamExt, lock::BiLock, ready};
13use num_traits::FromPrimitive;
14use thiserror::Error;
15use tokio::sync::mpsc;
16
17use crate::{Error, packet::PacketType, util::SeqGenerator};
18
// Shared state used to route incoming channel packets to their `Channel` receivers.
// `component!` is a crate macro defined elsewhere; it presumably generates
// `ChannelManager` with `ChannelManagerInner` as its inner state, which the methods
// below access through `self.lock(|inner| ...)` — TODO confirm against the macro.
component! {
    ChannelManager : ChannelManagerInner {
        // Source of the ids handed out by `allocate`.
        sequence: SeqGenerator<u16> = SeqGenerator::new(0),
        // Senders for registered channels, keyed by channel id; `dispatch` looks packets up here.
        channels: HashMap<u16, mpsc::UnboundedSender<(u8, Bytes)>> = HashMap::new(),
        // Last computed download rate in bytes per second; 0 until a window completes.
        download_rate_estimate: usize = 0,
        // Start of the current measurement window; `None` until the first dispatched packet.
        download_measurement_start: Option<Instant> = None,
        // Payload bytes dispatched since `download_measurement_start`.
        download_measurement_bytes: usize = 0,
        // Set by `shutdown`; while true, `allocate` no longer registers senders.
        invalid: bool = false,
    }
}
29
/// Length of a download-rate measurement window (one second).
const ONE_SECOND: Duration = Duration::from_millis(1_000);
31
32#[derive(Debug, Error, Hash, PartialEq, Eq, Copy, Clone)]
33pub struct ChannelError;
34
35impl From<ChannelError> for Error {
36    fn from(err: ChannelError) -> Self {
37        Error::aborted(err)
38    }
39}
40
41impl fmt::Display for ChannelError {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        write!(f, "channel error")
44    }
45}
46
47pub struct Channel {
48    receiver: mpsc::UnboundedReceiver<(u8, Bytes)>,
49    state: ChannelState,
50}
51
52pub struct ChannelHeaders(BiLock<Channel>);
53pub struct ChannelData(BiLock<Channel>);
54
55pub enum ChannelEvent {
56    Header(u8, Vec<u8>),
57    Data(Bytes),
58}
59
60#[derive(Clone)]
61enum ChannelState {
62    Header(Bytes),
63    Data,
64    Closed,
65}
66
67impl ChannelManager {
68    pub fn allocate(&self) -> (u16, Channel) {
69        let (tx, rx) = mpsc::unbounded_channel();
70
71        let seq = self.lock(|inner| {
72            let seq = inner.sequence.get();
73            if !inner.invalid {
74                inner.channels.insert(seq, tx);
75            }
76            seq
77        });
78
79        let channel = Channel {
80            receiver: rx,
81            state: ChannelState::Header(Bytes::new()),
82        };
83
84        (seq, channel)
85    }
86
87    pub(crate) fn dispatch(&self, cmd: PacketType, mut data: Bytes) -> Result<(), Error> {
88        use std::collections::hash_map::Entry;
89
90        let id: u16 = BigEndian::read_u16(data.split_to(2).as_ref());
91
92        self.lock(|inner| {
93            let current_time = Instant::now();
94            if let Some(download_measurement_start) = inner.download_measurement_start {
95                if (current_time - download_measurement_start) > ONE_SECOND {
96                    inner.download_rate_estimate = ONE_SECOND.as_millis() as usize
97                        * inner.download_measurement_bytes
98                        / (current_time - download_measurement_start).as_millis() as usize;
99                    inner.download_measurement_start = Some(current_time);
100                    inner.download_measurement_bytes = 0;
101                }
102            } else {
103                inner.download_measurement_start = Some(current_time);
104            }
105
106            inner.download_measurement_bytes += data.len();
107
108            if let Entry::Occupied(entry) = inner.channels.entry(id) {
109                entry
110                    .get()
111                    .send((cmd as u8, data))
112                    .map_err(|_| ChannelError)?;
113            }
114
115            Ok(())
116        })
117    }
118
119    pub fn get_download_rate_estimate(&self) -> usize {
120        self.lock(|inner| inner.download_rate_estimate)
121    }
122
123    pub(crate) fn shutdown(&self) {
124        self.lock(|inner| {
125            inner.invalid = true;
126            // destroy the sending halves of the channels to signal everyone who is waiting for something.
127            inner.channels.clear();
128        });
129    }
130}
131
132impl Channel {
133    fn recv_packet(&mut self, cx: &mut Context<'_>) -> Poll<Result<Bytes, ChannelError>> {
134        let (cmd, packet) = ready!(self.receiver.poll_recv(cx)).ok_or(ChannelError)?;
135
136        let packet_type = FromPrimitive::from_u8(cmd);
137        if let Some(PacketType::ChannelError) = packet_type {
138            let code = BigEndian::read_u16(&packet.as_ref()[..2]);
139            error!("channel error: {} {}", packet.len(), code);
140
141            self.state = ChannelState::Closed;
142
143            Poll::Ready(Err(ChannelError))
144        } else {
145            Poll::Ready(Ok(packet))
146        }
147    }
148
149    pub fn split(self) -> (ChannelHeaders, ChannelData) {
150        let (headers, data) = BiLock::new(self);
151
152        (ChannelHeaders(headers), ChannelData(data))
153    }
154}
155
156impl Stream for Channel {
157    type Item = Result<ChannelEvent, ChannelError>;
158
159    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
160        loop {
161            match self.state.clone() {
162                ChannelState::Closed => {
163                    error!("Polling already terminated channel");
164                    return Poll::Ready(None);
165                }
166
167                ChannelState::Header(mut data) => {
168                    if data.is_empty() {
169                        data = ready!(self.recv_packet(cx))?;
170                    }
171
172                    let length = BigEndian::read_u16(data.split_to(2).as_ref()) as usize;
173                    if length == 0 {
174                        self.state = ChannelState::Data;
175                    } else {
176                        let header_id = data.split_to(1).as_ref()[0];
177                        let header_data = data.split_to(length - 1).as_ref().to_owned();
178
179                        self.state = ChannelState::Header(data);
180
181                        let event = ChannelEvent::Header(header_id, header_data);
182                        return Poll::Ready(Some(Ok(event)));
183                    }
184                }
185
186                ChannelState::Data => {
187                    let data = ready!(self.recv_packet(cx))?;
188                    if data.is_empty() {
189                        self.receiver.close();
190                        self.state = ChannelState::Closed;
191                        return Poll::Ready(None);
192                    } else {
193                        let event = ChannelEvent::Data(data);
194                        return Poll::Ready(Some(Ok(event)));
195                    }
196                }
197            }
198        }
199    }
200}
201
202impl Stream for ChannelData {
203    type Item = Result<Bytes, ChannelError>;
204
205    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
206        let mut channel = ready!(self.0.poll_lock(cx));
207
208        loop {
209            match ready!(channel.poll_next_unpin(cx)?) {
210                Some(ChannelEvent::Header(..)) => (),
211                Some(ChannelEvent::Data(data)) => return Poll::Ready(Some(Ok(data))),
212                None => return Poll::Ready(None),
213            }
214        }
215    }
216}
217
218impl Stream for ChannelHeaders {
219    type Item = Result<(u8, Vec<u8>), ChannelError>;
220
221    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222        let mut channel = ready!(self.0.poll_lock(cx));
223
224        match ready!(channel.poll_next_unpin(cx)?) {
225            Some(ChannelEvent::Header(id, data)) => Poll::Ready(Some(Ok((id, data)))),
226            _ => Poll::Ready(None),
227        }
228    }
229}