librespot_core/dealer/
protocol.rs

1pub mod request;
2
3pub use request::*;
4
5use std::collections::HashMap;
6use std::io::{Error as IoError, Read};
7
8use crate::{Error, deserialize_with::json_proto};
9use base64::{DecodeError, Engine, prelude::BASE64_STANDARD};
10use flate2::read::GzDecoder;
11use log::LevelFilter;
12use serde::Deserialize;
13use serde_json::Error as SerdeError;
14use thiserror::Error;
15
/// Protobuf JSON parse options with `ignore_unknown_fields` enabled, so JSON fields the
/// target message does not define are skipped rather than rejected.
// NOTE(review): not referenced in this part of the file; presumably used by the
// `request` submodule via `super::IGNORE_UNKNOWN` — confirm before removing.
const IGNORE_UNKNOWN: protobuf_json_mapping::ParseOptions = protobuf_json_mapping::ParseOptions {
    ignore_unknown_fields: true,
    _future_options: (),
};

/// Shorthand for an untyped `serde_json` value.
type JsonValue = serde_json::Value;
22
23#[derive(Debug, Error)]
24enum ProtocolError {
25    #[error("base64 decoding failed: {0}")]
26    Base64(DecodeError),
27    #[error("gzip decoding failed: {0}")]
28    GZip(IoError),
29    #[error("deserialization failed: {0}")]
30    Deserialization(SerdeError),
31    #[error("payload had more then one value. had {0} values")]
32    MoreThenOneValue(usize),
33    #[error("received unexpected data {0:#?}")]
34    UnexpectedData(PayloadValue),
35    #[error("payload was empty")]
36    Empty,
37}
38
39impl From<ProtocolError> for Error {
40    fn from(err: ProtocolError) -> Self {
41        match err {
42            ProtocolError::UnexpectedData(_) => Error::unavailable(err),
43            _ => Error::failed_precondition(err),
44        }
45    }
46}
47
/// Body of a [`WebsocketRequest`].
#[derive(Clone, Debug, Deserialize)]
pub(super) struct Payload {
    /// Base64 text. It is decoded and, when the request headers announce gzip,
    /// decompressed by `WebsocketRequest::handle_payload`.
    pub compressed: String,
}
52
/// A dealer frame tagged `"type": "request"` (see [`MessageOrRequest`]).
#[derive(Clone, Debug, Deserialize)]
pub(super) struct WebsocketRequest {
    /// Header map; an absent `headers` field deserializes as an empty map.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    // NOTE(review): the meaning of `message_ident` / `key` is not visible here —
    // presumably used to identify/answer the request; confirm against the dealer code.
    pub message_ident: String,
    pub key: String,
    /// Encoded request body; decoded by `WebsocketRequest::handle_payload`.
    pub payload: Payload,
}
61
/// A dealer frame tagged `"type": "message"` (see [`MessageOrRequest`]).
#[derive(Clone, Debug, Deserialize)]
pub(super) struct WebsocketMessage {
    /// Header map; an absent `headers` field deserializes as an empty map.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Optional `method` field; `None` when absent.
    pub method: Option<String>,
    /// Raw payload entries. An absent field deserializes as an empty list.
    /// `WebsocketMessage::handle_payload` accepts at most one entry.
    #[serde(default)]
    pub payloads: Vec<MessagePayloadValue>,
    /// URI carried by the frame.
    pub uri: String,
}
71
/// One entry of `WebsocketMessage::payloads`, in whichever JSON shape it arrived.
///
/// This enum is untagged, so serde tries the variants in declaration order:
/// - a JSON string becomes `String`;
/// - an array of byte values becomes `Bytes`;
/// - any other JSON value falls through to `Json`.
///
/// Do not reorder the variants.
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
pub enum MessagePayloadValue {
    /// Base64 text; decoded in `WebsocketMessage::handle_payload`.
    String(String),
    /// Bytes used as-is (before transfer-encoding handling).
    Bytes(Vec<u8>),
    /// Any other JSON value; re-serialized to text by `handle_payload`.
    Json(JsonValue),
}
79
/// A top-level dealer frame, discriminated by its `"type"` field. Because of
/// `rename_all = "snake_case"`, the field value is `"message"` or `"request"`.
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(super) enum MessageOrRequest {
    Message(WebsocketMessage),
    Request(WebsocketRequest),
}
86
/// Decoded payload of a dealer message, as produced by `WebsocketMessage::handle_payload`.
#[derive(Clone, Debug)]
pub enum PayloadValue {
    /// The message carried no payload.
    Empty,
    /// Binary payload after base64 and transfer-encoding decoding.
    Raw(Vec<u8>),
    /// A JSON payload, kept as serialized text.
    Json(String),
}
93
/// A dealer message whose payload has already been decoded into a [`PayloadValue`].
// NOTE(review): construction is not visible here; presumably built from a
// `WebsocketMessage` plus the result of its `handle_payload` — confirm.
#[derive(Clone, Debug)]
pub struct Message {
    pub headers: HashMap<String, String>,
    pub payload: PayloadValue,
    pub uri: String,
}
100
/// Result of [`Message::try_from_json`]: either the payload parsed as protobuf message
/// `T`, or the raw JSON value when that parse does not succeed.
///
/// This enum is untagged. `Inner` is tried first, via `json_proto`. `Fallback` is
/// tried next and accepts any JSON value.
// NOTE(review): the `T: MessageFull` bound on the type is presumably needed so the
// derived impl can call `json_proto::<T>` — confirm before moving it to an impl block.
#[derive(Deserialize)]
#[serde(untagged)]
pub enum FallbackWrapper<T: protobuf::MessageFull> {
    Inner(#[serde(deserialize_with = "json_proto")] T),
    Fallback(JsonValue),
}
107
108impl Message {
109    pub fn try_from_json<M: protobuf::MessageFull>(
110        value: Self,
111    ) -> Result<FallbackWrapper<M>, Error> {
112        match value.payload {
113            PayloadValue::Json(json) => Ok(serde_json::from_str(&json)?),
114            other => Err(ProtocolError::UnexpectedData(other).into()),
115        }
116    }
117
118    pub fn from_raw<M: protobuf::Message>(value: Self) -> Result<M, Error> {
119        match value.payload {
120            PayloadValue::Raw(bytes) => {
121                M::parse_from_bytes(&bytes).map_err(Error::failed_precondition)
122            }
123            other => Err(ProtocolError::UnexpectedData(other).into()),
124        }
125    }
126}
127
128impl WebsocketMessage {
129    pub fn handle_payload(&mut self) -> Result<PayloadValue, Error> {
130        if self.payloads.is_empty() {
131            return Ok(PayloadValue::Empty);
132        } else if self.payloads.len() > 1 {
133            return Err(ProtocolError::MoreThenOneValue(self.payloads.len()).into());
134        }
135
136        let payload = self.payloads.pop().ok_or(ProtocolError::Empty)?;
137        let bytes = match payload {
138            MessagePayloadValue::String(string) => BASE64_STANDARD
139                .decode(string)
140                .map_err(ProtocolError::Base64)?,
141            MessagePayloadValue::Bytes(bytes) => bytes,
142            MessagePayloadValue::Json(json) => return Ok(PayloadValue::Json(json.to_string())),
143        };
144
145        handle_transfer_encoding(&self.headers, bytes).map(PayloadValue::Raw)
146    }
147}
148
149impl WebsocketRequest {
150    pub fn handle_payload(&self) -> Result<Request, Error> {
151        let payload_bytes = BASE64_STANDARD
152            .decode(&self.payload.compressed)
153            .map_err(ProtocolError::Base64)?;
154
155        let payload = handle_transfer_encoding(&self.headers, payload_bytes)?;
156        let payload = String::from_utf8(payload)?;
157
158        if log::max_level() >= LevelFilter::Trace {
159            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&payload) {
160                trace!("websocket request: {json:#?}");
161            } else {
162                trace!("websocket request: {payload}");
163            }
164        }
165
166        serde_json::from_str(&payload)
167            .map_err(ProtocolError::Deserialization)
168            .map_err(Into::into)
169    }
170}
171
172fn handle_transfer_encoding(
173    headers: &HashMap<String, String>,
174    data: Vec<u8>,
175) -> Result<Vec<u8>, Error> {
176    let encoding = headers.get("Transfer-Encoding").map(String::as_str);
177    if let Some(encoding) = encoding {
178        trace!("message was sent with {encoding} encoding ");
179    } else {
180        trace!("message was sent with no encoding ");
181    }
182
183    if !matches!(encoding, Some("gzip")) {
184        return Ok(data);
185    }
186
187    let mut gz = GzDecoder::new(&data[..]);
188    let mut bytes = vec![];
189    match gz.read_to_end(&mut bytes) {
190        Ok(i) if i == bytes.len() => Ok(bytes),
191        Ok(_) => Err(Error::failed_precondition(
192            "read bytes mismatched with expected bytes",
193        )),
194        Err(why) => Err(ProtocolError::GZip(why).into()),
195    }
196}