librespot_core/dealer/
protocol.rs1pub mod request;
2
3pub use request::*;
4
5use std::collections::HashMap;
6use std::io::{Error as IoError, Read};
7
8use crate::{Error, deserialize_with::json_proto};
9use base64::{DecodeError, Engine, prelude::BASE64_STANDARD};
10use flate2::read::GzDecoder;
11use log::LevelFilter;
12use serde::Deserialize;
13use serde_json::Error as SerdeError;
14use thiserror::Error;
15
16const IGNORE_UNKNOWN: protobuf_json_mapping::ParseOptions = protobuf_json_mapping::ParseOptions {
17 ignore_unknown_fields: true,
18 _future_options: (),
19};
20
21type JsonValue = serde_json::Value;
22
23#[derive(Debug, Error)]
24enum ProtocolError {
25 #[error("base64 decoding failed: {0}")]
26 Base64(DecodeError),
27 #[error("gzip decoding failed: {0}")]
28 GZip(IoError),
29 #[error("deserialization failed: {0}")]
30 Deserialization(SerdeError),
31 #[error("payload had more then one value. had {0} values")]
32 MoreThenOneValue(usize),
33 #[error("received unexpected data {0:#?}")]
34 UnexpectedData(PayloadValue),
35 #[error("payload was empty")]
36 Empty,
37}
38
39impl From<ProtocolError> for Error {
40 fn from(err: ProtocolError) -> Self {
41 match err {
42 ProtocolError::UnexpectedData(_) => Error::unavailable(err),
43 _ => Error::failed_precondition(err),
44 }
45 }
46}
47
/// Body of a websocket request as received from the dealer.
#[derive(Clone, Debug, Deserialize)]
pub(super) struct Payload {
    /// Base64 text; `WebsocketRequest::handle_payload` decodes it and then applies the
    /// transfer encoding named in the request headers.
    pub compressed: String,
}
52
/// A dealer frame of type `"request"` (see `MessageOrRequest`).
#[derive(Clone, Debug, Deserialize)]
pub(super) struct WebsocketRequest {
    /// Frame headers; defaults to an empty map when the field is absent.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    pub message_ident: String,
    pub key: String,
    /// Encoded request body.
    pub payload: Payload,
}
61
/// A dealer frame of type `"message"` (see `MessageOrRequest`).
#[derive(Clone, Debug, Deserialize)]
pub(super) struct WebsocketMessage {
    /// Frame headers; defaults to an empty map when the field is absent.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    pub method: Option<String>,
    /// Payload values; defaults to empty when absent. `handle_payload` accepts at most one.
    #[serde(default)]
    pub payloads: Vec<MessagePayloadValue>,
    pub uri: String,
}
71
/// One entry of `WebsocketMessage::payloads`, in whatever JSON shape it arrived.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so the order below is
/// load-bearing. A JSON string becomes `String`, a JSON array that deserializes as `Vec<u8>`
/// becomes `Bytes`, and anything else is kept as `Json`.
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
pub enum MessagePayloadValue {
    /// Base64 text (decoded by `WebsocketMessage::handle_payload`).
    String(String),
    /// Raw bytes sent as a JSON array of numbers.
    Bytes(Vec<u8>),
    /// Any other JSON value.
    Json(JsonValue),
}
79
/// A top-level dealer frame, discriminated by its `"type"` field.
///
/// With `rename_all = "snake_case"` the accepted tag values are `"message"` and `"request"`.
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(super) enum MessageOrRequest {
    Message(WebsocketMessage),
    Request(WebsocketRequest),
}
86
/// A message payload after transport decoding.
#[derive(Clone, Debug)]
pub enum PayloadValue {
    /// The message carried no payload.
    Empty,
    /// Bytes after base64 decoding (when the payload was a string) and transfer decoding.
    Raw(Vec<u8>),
    /// A JSON payload, re-serialized to text.
    Json(String),
}
93
/// A decoded dealer message handed to consumers.
#[derive(Clone, Debug)]
pub struct Message {
    /// Headers copied from the websocket frame.
    pub headers: HashMap<String, String>,
    /// The decoded payload.
    pub payload: PayloadValue,
    /// The message URI.
    pub uri: String,
}
100
/// Result of parsing a JSON payload as protobuf message `T`.
///
/// Untagged: deserialization first tries `Inner` (via `json_proto`). If that fails, the JSON is
/// kept as an untyped value in `Fallback`, so variant order here matters.
#[derive(Deserialize)]
#[serde(untagged)]
pub enum FallbackWrapper<T: protobuf::MessageFull> {
    /// The payload parsed as `T`.
    // presumably via the protobuf JSON mapping — see `deserialize_with::json_proto` to confirm
    Inner(#[serde(deserialize_with = "json_proto")] T),
    /// The payload did not parse as `T`; the raw JSON is kept.
    Fallback(JsonValue),
}
107
108impl Message {
109 pub fn try_from_json<M: protobuf::MessageFull>(
110 value: Self,
111 ) -> Result<FallbackWrapper<M>, Error> {
112 match value.payload {
113 PayloadValue::Json(json) => Ok(serde_json::from_str(&json)?),
114 other => Err(ProtocolError::UnexpectedData(other).into()),
115 }
116 }
117
118 pub fn from_raw<M: protobuf::Message>(value: Self) -> Result<M, Error> {
119 match value.payload {
120 PayloadValue::Raw(bytes) => {
121 M::parse_from_bytes(&bytes).map_err(Error::failed_precondition)
122 }
123 other => Err(ProtocolError::UnexpectedData(other).into()),
124 }
125 }
126}
127
128impl WebsocketMessage {
129 pub fn handle_payload(&mut self) -> Result<PayloadValue, Error> {
130 if self.payloads.is_empty() {
131 return Ok(PayloadValue::Empty);
132 } else if self.payloads.len() > 1 {
133 return Err(ProtocolError::MoreThenOneValue(self.payloads.len()).into());
134 }
135
136 let payload = self.payloads.pop().ok_or(ProtocolError::Empty)?;
137 let bytes = match payload {
138 MessagePayloadValue::String(string) => BASE64_STANDARD
139 .decode(string)
140 .map_err(ProtocolError::Base64)?,
141 MessagePayloadValue::Bytes(bytes) => bytes,
142 MessagePayloadValue::Json(json) => return Ok(PayloadValue::Json(json.to_string())),
143 };
144
145 handle_transfer_encoding(&self.headers, bytes).map(PayloadValue::Raw)
146 }
147}
148
149impl WebsocketRequest {
150 pub fn handle_payload(&self) -> Result<Request, Error> {
151 let payload_bytes = BASE64_STANDARD
152 .decode(&self.payload.compressed)
153 .map_err(ProtocolError::Base64)?;
154
155 let payload = handle_transfer_encoding(&self.headers, payload_bytes)?;
156 let payload = String::from_utf8(payload)?;
157
158 if log::max_level() >= LevelFilter::Trace {
159 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&payload) {
160 trace!("websocket request: {json:#?}");
161 } else {
162 trace!("websocket request: {payload}");
163 }
164 }
165
166 serde_json::from_str(&payload)
167 .map_err(ProtocolError::Deserialization)
168 .map_err(Into::into)
169 }
170}
171
172fn handle_transfer_encoding(
173 headers: &HashMap<String, String>,
174 data: Vec<u8>,
175) -> Result<Vec<u8>, Error> {
176 let encoding = headers.get("Transfer-Encoding").map(String::as_str);
177 if let Some(encoding) = encoding {
178 trace!("message was sent with {encoding} encoding ");
179 } else {
180 trace!("message was sent with no encoding ");
181 }
182
183 if !matches!(encoding, Some("gzip")) {
184 return Ok(data);
185 }
186
187 let mut gz = GzDecoder::new(&data[..]);
188 let mut bytes = vec![];
189 match gz.read_to_end(&mut bytes) {
190 Ok(i) if i == bytes.len() => Ok(bytes),
191 Ok(_) => Err(Error::failed_precondition(
192 "read bytes mismatched with expected bytes",
193 )),
194 Err(why) => Err(ProtocolError::GZip(why).into()),
195 }
196}