librespot_core/
channel.rs1use std::{
2 collections::HashMap,
3 fmt,
4 pin::Pin,
5 task::{Context, Poll},
6 time::{Duration, Instant},
7};
8
9use byteorder::{BigEndian, ByteOrder};
10use bytes::Bytes;
11use futures_core::Stream;
12use futures_util::{StreamExt, lock::BiLock, ready};
13use num_traits::FromPrimitive;
14use thiserror::Error;
15use tokio::sync::mpsc;
16
17use crate::{Error, packet::PacketType, util::SeqGenerator};
18
19component! {
20 ChannelManager : ChannelManagerInner {
21 sequence: SeqGenerator<u16> = SeqGenerator::new(0),
22 channels: HashMap<u16, mpsc::UnboundedSender<(u8, Bytes)>> = HashMap::new(),
23 download_rate_estimate: usize = 0,
24 download_measurement_start: Option<Instant> = None,
25 download_measurement_bytes: usize = 0,
26 invalid: bool = false,
27 }
28}
29
/// One second: the minimum download-rate measurement window, and the factor used to turn the
/// window's byte count into a per-second rate.
const ONE_SECOND: Duration = Duration::from_millis(1_000);
31
32#[derive(Debug, Error, Hash, PartialEq, Eq, Copy, Clone)]
33pub struct ChannelError;
34
35impl From<ChannelError> for Error {
36 fn from(err: ChannelError) -> Self {
37 Error::aborted(err)
38 }
39}
40
41impl fmt::Display for ChannelError {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 write!(f, "channel error")
44 }
45}
46
/// Receiving end of a channel created by `ChannelManager::allocate`.
///
/// Packets routed to it by `ChannelManager::dispatch` are turned into header events followed by
/// data chunks by its `Stream` implementation. Use [`Channel::split`] to consume headers and data
/// from two separate streams.
pub struct Channel {
    // `(packet type byte, payload)` pairs forwarded by the manager.
    receiver: mpsc::UnboundedReceiver<(u8, Bytes)>,
    // Current parsing state: reading headers, reading data, or closed.
    state: ChannelState,
}
51
/// Header half of a split [`Channel`]; streams `(header id, payload)` pairs.
pub struct ChannelHeaders(BiLock<Channel>);
/// Data half of a split [`Channel`]; streams data chunks and skips header events.
pub struct ChannelData(BiLock<Channel>);
54
55pub enum ChannelEvent {
56 Header(u8, Vec<u8>),
57 Data(Bytes),
58}
59
/// Parsing state of a [`Channel`].
#[derive(Clone)]
enum ChannelState {
    /// Reading headers. Holds the unconsumed remainder of the current header packet; when it is
    /// empty, the next received packet is parsed as headers.
    Header(Bytes),
    /// Headers are finished: each received packet is a data chunk and an empty one ends the stream.
    Data,
    /// The channel has ended; polling it again yields `None`.
    Closed,
}
66
67impl ChannelManager {
68 pub fn allocate(&self) -> (u16, Channel) {
69 let (tx, rx) = mpsc::unbounded_channel();
70
71 let seq = self.lock(|inner| {
72 let seq = inner.sequence.get();
73 if !inner.invalid {
74 inner.channels.insert(seq, tx);
75 }
76 seq
77 });
78
79 let channel = Channel {
80 receiver: rx,
81 state: ChannelState::Header(Bytes::new()),
82 };
83
84 (seq, channel)
85 }
86
87 pub(crate) fn dispatch(&self, cmd: PacketType, mut data: Bytes) -> Result<(), Error> {
88 use std::collections::hash_map::Entry;
89
90 let id: u16 = BigEndian::read_u16(data.split_to(2).as_ref());
91
92 self.lock(|inner| {
93 let current_time = Instant::now();
94 if let Some(download_measurement_start) = inner.download_measurement_start {
95 if (current_time - download_measurement_start) > ONE_SECOND {
96 inner.download_rate_estimate = ONE_SECOND.as_millis() as usize
97 * inner.download_measurement_bytes
98 / (current_time - download_measurement_start).as_millis() as usize;
99 inner.download_measurement_start = Some(current_time);
100 inner.download_measurement_bytes = 0;
101 }
102 } else {
103 inner.download_measurement_start = Some(current_time);
104 }
105
106 inner.download_measurement_bytes += data.len();
107
108 if let Entry::Occupied(entry) = inner.channels.entry(id) {
109 entry
110 .get()
111 .send((cmd as u8, data))
112 .map_err(|_| ChannelError)?;
113 }
114
115 Ok(())
116 })
117 }
118
119 pub fn get_download_rate_estimate(&self) -> usize {
120 self.lock(|inner| inner.download_rate_estimate)
121 }
122
123 pub(crate) fn shutdown(&self) {
124 self.lock(|inner| {
125 inner.invalid = true;
126 inner.channels.clear();
128 });
129 }
130}
131
132impl Channel {
133 fn recv_packet(&mut self, cx: &mut Context<'_>) -> Poll<Result<Bytes, ChannelError>> {
134 let (cmd, packet) = ready!(self.receiver.poll_recv(cx)).ok_or(ChannelError)?;
135
136 let packet_type = FromPrimitive::from_u8(cmd);
137 if let Some(PacketType::ChannelError) = packet_type {
138 let code = BigEndian::read_u16(&packet.as_ref()[..2]);
139 error!("channel error: {} {}", packet.len(), code);
140
141 self.state = ChannelState::Closed;
142
143 Poll::Ready(Err(ChannelError))
144 } else {
145 Poll::Ready(Ok(packet))
146 }
147 }
148
149 pub fn split(self) -> (ChannelHeaders, ChannelData) {
150 let (headers, data) = BiLock::new(self);
151
152 (ChannelHeaders(headers), ChannelData(data))
153 }
154}
155
156impl Stream for Channel {
157 type Item = Result<ChannelEvent, ChannelError>;
158
159 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
160 loop {
161 match self.state.clone() {
162 ChannelState::Closed => {
163 error!("Polling already terminated channel");
164 return Poll::Ready(None);
165 }
166
167 ChannelState::Header(mut data) => {
168 if data.is_empty() {
169 data = ready!(self.recv_packet(cx))?;
170 }
171
172 let length = BigEndian::read_u16(data.split_to(2).as_ref()) as usize;
173 if length == 0 {
174 self.state = ChannelState::Data;
175 } else {
176 let header_id = data.split_to(1).as_ref()[0];
177 let header_data = data.split_to(length - 1).as_ref().to_owned();
178
179 self.state = ChannelState::Header(data);
180
181 let event = ChannelEvent::Header(header_id, header_data);
182 return Poll::Ready(Some(Ok(event)));
183 }
184 }
185
186 ChannelState::Data => {
187 let data = ready!(self.recv_packet(cx))?;
188 if data.is_empty() {
189 self.receiver.close();
190 self.state = ChannelState::Closed;
191 return Poll::Ready(None);
192 } else {
193 let event = ChannelEvent::Data(data);
194 return Poll::Ready(Some(Ok(event)));
195 }
196 }
197 }
198 }
199 }
200}
201
202impl Stream for ChannelData {
203 type Item = Result<Bytes, ChannelError>;
204
205 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
206 let mut channel = ready!(self.0.poll_lock(cx));
207
208 loop {
209 match ready!(channel.poll_next_unpin(cx)?) {
210 Some(ChannelEvent::Header(..)) => (),
211 Some(ChannelEvent::Data(data)) => return Poll::Ready(Some(Ok(data))),
212 None => return Poll::Ready(None),
213 }
214 }
215 }
216}
217
218impl Stream for ChannelHeaders {
219 type Item = Result<(u8, Vec<u8>), ChannelError>;
220
221 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222 let mut channel = ready!(self.0.poll_lock(cx));
223
224 match ready!(channel.poll_next_unpin(cx)?) {
225 Some(ChannelEvent::Header(id, data)) => Poll::Ready(Some(Ok((id, data)))),
226 _ => Poll::Ready(None),
227 }
228 }
229}