1use std::{collections::HashMap, mem::size_of, sync::Arc};
2
3use bytes::{Buf, BufMut, BytesMut};
4use serde::{Deserialize, Serialize};
5use serde_with::serde_as;
6use tungstenite::Message;
7
8use crate::{common::*, DigitalisError, DigitalisResult};
9
/// Generates conversions from a payload type into a variant of a message enum.
///
/// `impl_enum_from!(Parent, Variant, PayloadTy)` expands to:
/// - `impl From<PayloadTy> for Parent`, wrapping the value as `Parent::Variant(value)`;
/// - an inherent `PayloadTy::into_message(self) -> DigitalisResult<Message>` that wraps
///   the value in `Parent` and calls `Parent::to_message` on it.
///
/// The two-argument form `impl_enum_from!(Parent, Variant)` is shorthand for the case
/// where the variant and its payload type share the same name.
macro_rules! impl_enum_from {
    ($parent:ident, $child:ident, $child_ty:ident) => {
        impl From<$child_ty> for $parent {
            fn from(msg: $child_ty) -> Self {
                $parent::$child(msg)
            }
        }

        // The inherent impl must target the payload *type*, not the variant name;
        // the two only coincide in the two-argument form.
        impl $child_ty {
            pub fn into_message(self) -> DigitalisResult<Message> {
                $parent::from(self).to_message()
            }
        }
    };
    ($parent:ident, $child:ident) => {
        impl_enum_from!($parent, $child, $child);
    };
}
28
/// No-op macro: takes a parent enum identifier and a child identifier and
/// expands to nothing. It is kept so that existing invocations keep compiling.
macro_rules! impl_into_text_message {
    ($_parent:ident, $_child:ident) => {};
}
32
33#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
34#[serde(tag = "op", rename_all = "camelCase")]
35pub enum ServerJsonMessage {
36 ServerInfo(ServerInfo),
37 Status(Status),
38 Advertise(Advertise),
39 Unadvertise(Unadvertise),
40 ParameterValues(ParameterValues),
41 AdvertiseServices(AdvertiseServices),
42 UnadvertiseServices(UnadvertiseServices),
43 ConnectionGraphUpdate(ConnectionGraphUpdate),
44}
45
46impl ServerJsonMessage {
47 pub fn to_message(&self) -> DigitalisResult<Message> {
48 Ok(Message::Text(self.serialize()?))
49 }
50
51 pub fn serialize(&self) -> DigitalisResult<String> {
52 Ok(serde_json::to_string(self)?)
53 }
54
55 pub fn deserialize(text: &str) -> DigitalisResult<Self> {
56 Ok(serde_json::from_str(text)?)
57 }
58}
59
60impl_enum_from!(ServerJsonMessage, ServerInfo);
61impl_enum_from!(ServerJsonMessage, Status);
62impl_enum_from!(ServerJsonMessage, Advertise);
63impl_enum_from!(ServerJsonMessage, Unadvertise);
64impl_enum_from!(ServerJsonMessage, ParameterValues);
65impl_enum_from!(ServerJsonMessage, AdvertiseServices);
66impl_enum_from!(ServerJsonMessage, UnadvertiseServices);
67
68impl_into_text_message!(ServerJsonMessage, ServerInfo);
69impl_into_text_message!(ServerJsonMessage, Status);
70impl_into_text_message!(ServerJsonMessage, Advertise);
71impl_into_text_message!(ServerJsonMessage, Unadvertise);
72impl_into_text_message!(ServerJsonMessage, ParameterValues);
73impl_into_text_message!(ServerJsonMessage, AdvertiseServices);
74impl_into_text_message!(ServerJsonMessage, UnadvertiseServices);
75
76#[serde_as]
77#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
78#[serde(rename_all = "camelCase")]
79pub struct ServerInfo {
80 pub name: String,
81 pub capabilities: Vec<Capability>,
82 pub supported_encodings: Vec<MessageEncoding>,
83 #[serde(skip_serializing_if = "Option::is_none")]
84 pub metadata: Option<HashMap<String, serde_json::Value>>,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 pub session_id: Option<String>,
87}
88
89#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
90pub struct Status {
91 pub level: Level,
92 pub message: String,
93}
94
95#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
96pub enum Level {
97 Info = 0,
98 Warning = 1,
99 Error = 2,
100}
101
102#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
103#[serde(rename_all = "camelCase")]
104pub struct Advertise {
105 pub channels: Vec<AdvertiseChannel>,
106}
107
108#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
109#[serde(rename_all = "camelCase")]
110pub struct Unadvertise {
111 pub channel_ids: Vec<ChannelId>,
112}
113
114#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
115#[serde(rename_all = "camelCase")]
116pub struct ParameterValues {
117 pub parameters: Vec<Parameter>,
118 #[serde(skip_serializing_if = "Option::is_none")]
119 pub id: Option<String>,
120}
121
122#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
123#[serde(rename_all = "camelCase")]
124pub struct AdvertiseServices {
125 pub services: Vec<Service>,
126}
127
128#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
129#[serde(rename_all = "camelCase")]
130pub struct Service {
131 pub id: ChannelId,
132 pub name: String,
133 pub r#type: String,
134 pub request_schema: String,
135 pub response_schema: String,
136}
137
138#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
139#[serde(rename_all = "camelCase")]
140pub struct UnadvertiseServices {
141 pub ids: Vec<ChannelId>,
142}
143
144#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
145#[serde(rename_all = "camelCase")]
146pub struct ConnectionGraphUpdate {
147 pub publish_topics: Vec<PublishedTopic>,
148 pub suscribed_topics: Vec<SubscribedTopic>,
149 pub advertised_services: Vec<AdvertisedService>,
150 pub removed_topics: Vec<String>,
151 pub removed_services: Vec<String>,
152}
153
154#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
155#[serde(rename_all = "camelCase")]
156pub struct PublishedTopic {
157 pub name: String,
158 pub publisher_ids: Vec<String>,
159}
160
161#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
162#[serde(rename_all = "camelCase")]
163pub struct SubscribedTopic {
164 pub name: String,
165 pub subscriber_ids: Vec<String>,
166}
167
168#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
169#[serde(rename_all = "camelCase")]
170pub struct AdvertisedService {
171 pub name: String,
172 pub provider_ids: Vec<String>,
173}
174
175#[derive(Debug, Clone, PartialEq, Eq)]
176pub enum ServerBinaryMessage {
177 MessageData(MessageData),
178 Time(Time),
179 ServiceCallResponse(ServiceCallResponse),
180}
181
182impl ServerBinaryMessage {
183 pub fn to_message(self) -> DigitalisResult<Message> {
184 let mut buf = BytesMut::new();
185 self.serialize(&mut buf);
186 Ok(Message::Binary(buf.into()))
187 }
188
189 pub fn serialize<T: BufMut>(&self, buf: &mut T) {
190 match self {
191 Self::MessageData(msg) => {
192 buf.put_u8(0x01);
193 msg.serialize(buf);
194 }
195 Self::Time(msg) => {
196 buf.put_u8(0x02);
197 msg.serialize(buf);
198 }
199 Self::ServiceCallResponse(msg) => {
200 buf.put_u8(0x03);
201 msg.serialize(buf);
202 }
203 }
204 }
205
206 pub fn deserialize<T: Buf>(buf: &mut T) -> DigitalisResult<Self> {
207 Ok(match buf.get_u8() {
208 0x01 => Self::from(MessageData::deserialize(buf)?),
209 0x02 => Self::from(Time::deserialize(buf)?),
210 0x03 => Self::from(ServiceCallResponse::deserialize(buf)?),
211 x => {
212 return Err(DigitalisError::BinaryDeserializeError(
213 format!("Unknown protocol {}", x).into(),
214 ))
215 }
216 })
217 }
218}
219
220impl_enum_from!(ServerBinaryMessage, MessageData);
221impl_enum_from!(ServerBinaryMessage, Time);
222impl_enum_from!(ServerBinaryMessage, ServiceCallResponse);
223
224#[derive(Debug, Clone, PartialEq, Eq)]
225pub struct MessageData {
226 pub subscription_id: SubscriptionId,
227 pub receive_timestamp: u64,
228 pub payload: Arc<Vec<u8>>,
229}
230
231impl MessageData {
232 fn serialize<T: BufMut>(&self, buf: &mut T) {
233 buf.put_u32_le(self.subscription_id);
234 buf.put_u64_le(self.receive_timestamp);
235 buf.put_slice(&self.payload);
236 }
237
238 fn deserialize<T: Buf>(buf: &mut T) -> DigitalisResult<Self> {
239 if buf.remaining() < size_of::<u32>() + size_of::<u64>() {
240 return Err(DigitalisError::BinaryDeserializeError(
241 "Data is too short".into(),
242 ));
243 }
244
245 let subscription_id = buf.get_u32_le();
246 let receive_timestamp = buf.get_u64_le();
247 let payload = buf.chunk().to_vec();
248 buf.advance(payload.len());
249
250 Ok(Self {
251 subscription_id,
252 receive_timestamp,
253 payload: Arc::new(payload),
254 })
255 }
256}
257
258#[derive(Debug, Clone, PartialEq, Eq)]
259pub struct Time {
260 pub timestamp: u64,
261}
262
263impl Time {
264 fn serialize<T: BufMut>(&self, buf: &mut T) {
265 buf.put_u64_le(self.timestamp);
266 }
267
268 fn deserialize<T: Buf>(buf: &mut T) -> DigitalisResult<Self> {
269 if buf.remaining() != size_of::<u64>() {
270 return Err(DigitalisError::BinaryDeserializeError(
271 "Data is too short".into(),
272 ));
273 }
274
275 Ok(Self {
276 timestamp: buf.get_u64_le(),
277 })
278 }
279}
280
281#[derive(Debug, Clone, PartialEq, Eq)]
282pub struct ServiceCallResponse {
283 pub service_id: ChannelId,
284 pub call_id: u32,
285 pub encoding: Vec<u8>,
286 pub payload: Vec<u8>,
287}
288
289impl ServiceCallResponse {
290 fn serialize<T: BufMut>(&self, buf: &mut T) {
291 buf.put_u32_le(self.service_id);
292 buf.put_u32_le(self.call_id);
293 buf.put_u32_le(self.encoding.len() as u32);
294 buf.put_slice(&self.encoding);
295 buf.put_slice(&self.payload);
296 }
297
298 fn deserialize<T: Buf>(buf: &mut T) -> DigitalisResult<Self> {
299 if buf.remaining() < size_of::<u32>() * 3 {
300 return Err(DigitalisError::BinaryDeserializeError(
301 "Data is too short".into(),
302 ));
303 }
304
305 let service_id = buf.get_u32_le();
306 let call_id = buf.get_u32_le();
307
308 let encoding_len = buf.get_u32_le() as usize;
309 if buf.remaining() < encoding_len {
310 return Err(DigitalisError::BinaryDeserializeError(
311 "Data is too short".into(),
312 ));
313 }
314 let encoding = buf.chunk()[..encoding_len].to_vec();
315 buf.advance(encoding.len());
316
317 let payload = buf.chunk().to_vec();
318 buf.advance(payload.len());
319
320 Ok(Self {
321 service_id,
322 call_id,
323 encoding,
324 payload,
325 })
326 }
327}
328
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips a `MessageData` through its binary encoding.
    #[test]
    fn test_serialize_and_deserialize_message_data() {
        let original = MessageData {
            subscription_id: 25,
            receive_timestamp: 23893748,
            payload: Arc::new(vec![1, 23, 125]),
        };

        let mut encoded: Vec<u8> = Vec::new();
        original.serialize(&mut encoded);

        let mut reader: &[u8] = &encoded;
        let decoded = MessageData::deserialize(&mut reader).unwrap();

        assert_eq!(original, decoded);
    }

    /// Round-trips a `Time` through its binary encoding.
    #[test]
    fn test_serialize_and_deserialize_time() {
        let original = Time {
            timestamp: 23893748,
        };

        let mut encoded: Vec<u8> = Vec::new();
        original.serialize(&mut encoded);

        let mut reader: &[u8] = &encoded;
        let decoded = Time::deserialize(&mut reader).unwrap();

        assert_eq!(original, decoded);
    }

    /// Round-trips a `ServiceCallResponse` through its binary encoding.
    #[test]
    fn test_serialize_and_deserialize_service_call_response() {
        let original = ServiceCallResponse {
            service_id: 25,
            call_id: 23893748,
            encoding: vec![1, 23, 125],
            payload: vec![25, 225, 23, 125],
        };

        let mut encoded: Vec<u8> = Vec::new();
        original.serialize(&mut encoded);

        let mut reader: &[u8] = &encoded;
        let decoded = ServiceCallResponse::deserialize(&mut reader).unwrap();

        assert_eq!(original, decoded);
    }
}