use {
bytes::Bytes, num_derive::FromPrimitive, num_traits::FromPrimitive,
sender_sink::wrappers::SinkError, serde_json::Value, warp::ws::Message,
};
use crate::types::{ConnID, Frame, Header};
pub trait Address<T> {
fn header(self, header: Header) -> (Header, T);
#[cfg(test)]
fn to(self, to: ConnID) -> (Header, T);
#[cfg(test)]
fn broadcast(self) -> (Header, T);
}
impl Address<Message> for Message {
fn header(self, header: Header) -> (Header, Message) {
(header, self)
}
#[cfg(test)]
fn to(self, to: ConnID) -> (Header, Message) {
(Header::to(to), self)
}
#[cfg(test)]
fn broadcast(self) -> (Header, Message) {
(Header::broadcast(), self)
}
}
pub fn deserialize(msg: &Bytes, frame: Option<Frame>) -> Result<(Header, &[u8]), &'static str> {
match frame {
Some(f) => match f {
Frame::GWSocket => {
let (header, msg_type, length, payload) = parse_binary_header(msg);
let effective_len = payload.len();
let header_len = length as usize;
assert_eq!(
effective_len, header_len,
"Message length {effective_len} does not match {header_len}: Chunked payloads are not supported",
);
if msg_type.is_none() {
return Err("Unknown message type");
}
Ok((header, payload))
}
Frame::JSON => Ok(parse_json_header(msg)),
},
None => Ok((Header::broadcast(), msg)),
}
}
pub fn serialize(msg: Message, conn: ConnID, frame: Option<Frame>) -> Result<Message, SinkError> {
match frame {
Some(f) => match f {
Frame::GWSocket => {
unimplemented!("Client side binary framing has not been implemented")
}
Frame::JSON => match serde_json::from_slice::<Value>(msg.as_bytes()) {
Ok(mut v) if v.is_object() => {
v["_from"] = Value::from(conn);
Ok(Message::text(v.to_string()))
}
Ok(_) => {
tracing::error!("bad data: message is not a JSON object");
Err(SinkError::SendFailed)
}
Err(_) => {
tracing::error!("bad data: message is not valid JSON");
Err(SinkError::SendFailed)
}
},
},
None => Ok(msg),
}
}
#[derive(FromPrimitive, ToPrimitive, Debug, PartialEq)]
pub enum Type {
Text = 1,
Binary = 2,
}
pub(crate) fn parse_json_header(msg: &Bytes) -> (Header, &[u8]) {
(
serde_json::from_slice::<Header>(msg).unwrap_or_default(),
msg,
)
}
pub(crate) fn parse_binary_header(data: &[u8]) -> (Header, Option<Type>, u32, &[u8]) {
let mut id_data = [0; 4];
id_data.copy_from_slice(&data[0..4]);
let id = u32::from_le_bytes(id_data);
let mut msg_type_data = [0; 4];
msg_type_data.copy_from_slice(&data[4..8]);
let msg_type = u32::from_le_bytes(msg_type_data);
let msg_type: Option<Type> = FromPrimitive::from_u32(msg_type);
let mut msg_len_data = [0; 4];
msg_len_data.copy_from_slice(&data[8..12]);
let msg_len = u32::from_le_bytes(msg_len_data);
let header = match id {
0 => Header::broadcast(),
id => Header::to(id),
};
(header, msg_type, msg_len, &data[12..])
}
#[cfg(test)]
mod tests {
use super::{Header, Type, parse_binary_header};
#[test]
fn test_parse_id() {
let payload = [
(123 as u32).to_le_bytes(),
(0 as u32).to_le_bytes(),
(0 as u32).to_le_bytes(),
]
.concat();
let (result, _, _, _) = parse_binary_header(&payload);
assert_eq!(result, Header::to(123));
}
#[test]
fn test_parse_id_0_is_broadcast() {
let payload = [
(0 as u32).to_le_bytes(),
(0 as u32).to_le_bytes(),
(0 as u32).to_le_bytes(),
]
.concat();
let (result, _, _, _) = parse_binary_header(&payload);
assert_eq!(result, Header::broadcast());
}
#[test]
fn test_parse_type() {
let payload = [
(0 as u32).to_le_bytes(),
(2 as u32).to_le_bytes(),
(0 as u32).to_le_bytes(),
]
.concat();
let (_, result, _, _) = parse_binary_header(&payload);
assert_eq!(result, Some(Type::Binary));
}
#[test]
fn test_parse_length() {
let payload = [
(0 as u32).to_le_bytes(),
(0 as u32).to_le_bytes(),
(123 as u32).to_le_bytes(),
]
.concat();
let (_, _, result, _) = parse_binary_header(&payload);
assert_eq!(result, 123);
}
}