// digitalis_core/client.rs

1use std::{mem::size_of, sync::Arc};
2
3use bytes::{Buf, BufMut};
4use serde::{Deserialize, Serialize};
5use tungstenite::Message;
6
7use crate::{common::*, DigitalisError, DigitalisResult};
8
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub enum ClientMessage {
11    Subscribe(Subscribe),
12    Unsubscribe(Unsubscribe),
13    Advertise(ClientAdvertise),
14    Unadvertise(ClientUnadvertise),
15    GetParameters(GetParameters),
16    SetParameters(SetParameters),
17    SubscribeParameterUpdates(SubscribeParameterUpdates),
18    UnsubscribeParameterUpdates(UnsubscribeParameterUpdates),
19    SubscribeConnectionGraph,
20    UnsubscribeConnectionGraph,
21    MessageData(MessageData),
22    ServiceCallRequest(ServiceCallRequest),
23    Close,
24}
25
26impl ClientMessage {
27    pub fn from_ws_message(raw_msg: Message) -> DigitalisResult<Self> {
28        match raw_msg {
29            Message::Binary(msg) => {
30                Ok(ClientBinaryMessage::deserialize(&mut msg.as_slice())?.into())
31            }
32            Message::Text(msg) => Ok(ClientJsonMessage::deserialize(&msg)?.into()),
33            Message::Close(_) => Ok(Self::Close),
34            m => Err(DigitalisError::UnexpectedWebsocketMessage(
35                format!("{}", m).into(),
36            )),
37        }
38    }
39}
40
41impl From<ClientJsonMessage> for ClientMessage {
42    fn from(msg: ClientJsonMessage) -> Self {
43        use ClientJsonMessage::*;
44        match msg {
45            Subscribe(msg) => Self::Subscribe(msg),
46            Unsubscribe(msg) => Self::Unsubscribe(msg),
47            Advertise(msg) => Self::Advertise(msg),
48            Unadvertise(msg) => Self::Unadvertise(msg),
49            GetParameters(msg) => Self::GetParameters(msg),
50            SetParameters(msg) => Self::SetParameters(msg),
51            SubscribeParameterUpdates(msg) => Self::SubscribeParameterUpdates(msg),
52            UnsubscribeParameterUpdates(msg) => Self::UnsubscribeParameterUpdates(msg),
53            SubscribeConnectionGraph => Self::SubscribeConnectionGraph,
54            UnsubscribeConnectionGraph => Self::UnsubscribeConnectionGraph,
55        }
56    }
57}
58
59impl From<ClientBinaryMessage> for ClientMessage {
60    fn from(msg: ClientBinaryMessage) -> Self {
61        use ClientBinaryMessage::*;
62        match msg {
63            MessageData(msg) => Self::MessageData(msg),
64            ServiceCallRequest(msg) => Self::ServiceCallRequest(msg),
65        }
66    }
67}
68
// Implements `From<$payload> for $enum_ty` by wrapping the value in the
// `$enum_ty::$variant` tuple variant.
//
// The two-argument form is shorthand for the common case where the variant
// and its payload type share the same name.
macro_rules! impl_enum_from {
    ($enum_ty:ident, $variant:ident, $payload:ident) => {
        impl From<$payload> for $enum_ty {
            fn from(value: $payload) -> Self {
                Self::$variant(value)
            }
        }
    };
    ($enum_ty:ident, $variant:ident) => {
        impl_enum_from!($enum_ty, $variant, $variant);
    };
}
81
// Adds an inherent `into_message` method to `$payload` that wraps the value
// in the `$envelope` enum (through its `From` impl) and then encodes it with
// `$envelope::to_message`.
macro_rules! impl_into_text_message {
    ($envelope:ident, $payload:ident) => {
        impl $payload {
            /// Wraps this payload in its envelope enum and encodes it as a websocket message.
            pub fn into_message(self) -> DigitalisResult<Message> {
                let wrapped: $envelope = self.into();
                wrapped.to_message()
            }
        }
    };
}
91
92#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
93#[serde(tag = "op", rename_all = "camelCase")]
94pub enum ClientJsonMessage {
95    Subscribe(Subscribe),
96    Unsubscribe(Unsubscribe),
97    Advertise(ClientAdvertise),
98    Unadvertise(ClientUnadvertise),
99    GetParameters(GetParameters),
100    SetParameters(SetParameters),
101    SubscribeParameterUpdates(SubscribeParameterUpdates),
102    UnsubscribeParameterUpdates(UnsubscribeParameterUpdates),
103    SubscribeConnectionGraph,
104    UnsubscribeConnectionGraph,
105}
106
107impl ClientJsonMessage {
108    pub fn to_message(&self) -> DigitalisResult<Message> {
109        Ok(Message::Text(self.serialize()?))
110    }
111
112    pub fn serialize(&self) -> DigitalisResult<String> {
113        Ok(serde_json::to_string(self)?)
114    }
115
116    pub fn deserialize(text: &str) -> DigitalisResult<Self> {
117        Ok(serde_json::from_str(text)?)
118    }
119}
120
121impl_enum_from!(ClientJsonMessage, Subscribe);
122impl_enum_from!(ClientJsonMessage, Unsubscribe);
123impl_enum_from!(ClientJsonMessage, Advertise, ClientAdvertise);
124impl_enum_from!(ClientJsonMessage, Unadvertise, ClientUnadvertise);
125impl_enum_from!(ClientJsonMessage, GetParameters);
126impl_enum_from!(ClientJsonMessage, SetParameters);
127impl_enum_from!(ClientJsonMessage, SubscribeParameterUpdates);
128impl_enum_from!(ClientJsonMessage, UnsubscribeParameterUpdates);
129
130impl_into_text_message!(ClientJsonMessage, Subscribe);
131impl_into_text_message!(ClientJsonMessage, Unsubscribe);
132impl_into_text_message!(ClientJsonMessage, ClientAdvertise);
133impl_into_text_message!(ClientJsonMessage, ClientUnadvertise);
134impl_into_text_message!(ClientJsonMessage, GetParameters);
135impl_into_text_message!(ClientJsonMessage, SetParameters);
136impl_into_text_message!(ClientJsonMessage, SubscribeParameterUpdates);
137impl_into_text_message!(ClientJsonMessage, UnsubscribeParameterUpdates);
138
139#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
140#[serde(rename_all = "camelCase")]
141pub struct Subscribe {
142    pub subscriptions: Vec<SubscribeChannel>,
143}
144
145#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
146#[serde(rename_all = "camelCase")]
147pub struct Unsubscribe {
148    pub subscription_ids: Vec<SubscriptionId>,
149}
150
151#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
152#[serde(rename_all = "camelCase")]
153pub struct ClientAdvertise {
154    pub channels: Vec<AdvertiseChannel>,
155}
156
157#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
158#[serde(rename_all = "camelCase")]
159pub struct ClientUnadvertise {
160    pub channel_ids: Vec<ChannelId>,
161}
162
163#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
164#[serde(rename_all = "camelCase")]
165pub struct GetParameters {
166    pub parameter_names: Vec<String>,
167    #[serde(skip_serializing_if = "Option::is_none")]
168    pub id: Option<String>,
169}
170
171#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
172#[serde(rename_all = "camelCase")]
173pub struct SetParameters {
174    pub parameters: Vec<Parameter>,
175    #[serde(skip_serializing_if = "Option::is_none")]
176    pub id: Option<String>,
177}
178
179#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
180#[serde(rename_all = "camelCase")]
181pub struct SubscribeParameterUpdates {
182    pub parameter_names: Vec<String>,
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
186#[serde(rename_all = "camelCase")]
187pub struct UnsubscribeParameterUpdates {
188    pub parameter_names: Vec<String>,
189}
190
191#[derive(Debug, Clone, PartialEq, Eq)]
192pub enum ClientBinaryMessage {
193    MessageData(MessageData),
194    ServiceCallRequest(ServiceCallRequest),
195}
196
197impl ClientBinaryMessage {
198    pub fn serialize<T: BufMut>(&self, buf: &mut T) {
199        match self {
200            Self::MessageData(msg) => {
201                buf.put_u8(0x01);
202                msg.serialize(buf);
203            }
204            Self::ServiceCallRequest(msg) => {
205                buf.put_u8(0x02);
206                msg.serialize(buf);
207            }
208        }
209    }
210
211    pub fn deserialize<T: Buf>(buf: &mut T) -> DigitalisResult<Self> {
212        Ok(match buf.get_u8() {
213            0x01 => Self::from(MessageData::deserialize(buf)?),
214            0x02 => Self::from(ServiceCallRequest::deserialize(buf)?),
215            x => {
216                return Err(DigitalisError::BinaryDeserializeError(
217                    format!("Unknown protocol {}", x).into(),
218                ))
219            }
220        })
221    }
222}
223
224impl_enum_from!(ClientBinaryMessage, MessageData);
225impl_enum_from!(ClientBinaryMessage, ServiceCallRequest);
226
227#[derive(Debug, Clone, PartialEq, Eq)]
228pub struct MessageData {
229    pub channel_id: ChannelId,
230    pub payload: Arc<Vec<u8>>,
231}
232
233impl MessageData {
234    fn serialize<T: BufMut>(&self, buf: &mut T) {
235        buf.put_u32_le(self.channel_id);
236        buf.put_slice(&self.payload);
237    }
238
239    fn deserialize<T: Buf>(buf: &mut T) -> DigitalisResult<Self> {
240        if buf.remaining() < size_of::<u32>() {
241            return Err(DigitalisError::BinaryDeserializeError(
242                "Data is too short".into(),
243            ));
244        }
245
246        let channel_id = buf.get_u32_le();
247        let payload = buf.chunk().to_vec();
248        buf.advance(payload.len());
249
250        Ok(Self {
251            channel_id,
252            payload: Arc::new(payload),
253        })
254    }
255}
256
257#[derive(Debug, Clone, PartialEq, Eq)]
258pub struct ServiceCallRequest {
259    pub service_id: ChannelId,
260    pub call_id: u32,
261    pub encoding: Vec<u8>,
262    pub payload: Vec<u8>,
263}
264
265impl ServiceCallRequest {
266    fn serialize<T: BufMut>(&self, buf: &mut T) {
267        buf.put_u32_le(self.service_id);
268        buf.put_u32_le(self.call_id);
269        buf.put_u32_le(self.encoding.len() as u32);
270        buf.put_slice(&self.encoding);
271        buf.put_slice(&self.payload);
272    }
273
274    fn deserialize<T: Buf>(buf: &mut T) -> DigitalisResult<Self> {
275        if buf.remaining() < size_of::<u32>() * 3 {
276            return Err(DigitalisError::BinaryDeserializeError(
277                "Data is too short".into(),
278            ));
279        }
280
281        let service_id = buf.get_u32_le();
282        let call_id = buf.get_u32_le();
283
284        let encoding_len = buf.get_u32_le() as usize;
285        if buf.remaining() < encoding_len {
286            return Err(DigitalisError::BinaryDeserializeError(
287                "Data is too short".into(),
288            ));
289        }
290        let encoding = buf.chunk()[..encoding_len].to_vec();
291        buf.advance(encoding.len());
292
293        let payload = buf.chunk().to_vec();
294        buf.advance(payload.len());
295
296        Ok(Self {
297            service_id,
298            call_id,
299            encoding,
300            payload,
301        })
302    }
303}