use std::io::Cursor;
use crate::message::MessageType;
use crate::message::base::{BaseMessage, ProtoError};
use crate::message::client_info::ClientInfo;
use crate::message::codec_header::CodecHeader;
use crate::message::error::Error;
use crate::message::hello::Hello;
use crate::message::server_settings::ServerSettings;
use crate::message::time::Time;
use crate::message::wire_chunk::WireChunk;
#[derive(Debug)]
pub struct TypedMessage {
pub base: BaseMessage,
pub payload: MessagePayload,
}
#[derive(Debug)]
pub enum MessagePayload {
Time(Time),
Hello(Hello),
ServerSettings(ServerSettings),
CodecHeader(CodecHeader),
WireChunk(WireChunk),
ClientInfo(ClientInfo),
Error(Error),
#[cfg(feature = "custom-protocol")]
Custom(Vec<u8>),
}
pub fn deserialize(base: BaseMessage, payload: &[u8]) -> Result<TypedMessage, ProtoError> {
let mut cursor = Cursor::new(payload);
let msg = match base.msg_type {
MessageType::Time => MessagePayload::Time(Time::read_from(&mut cursor)?),
MessageType::Hello => MessagePayload::Hello(Hello::read_from(&mut cursor)?),
MessageType::ServerSettings => {
MessagePayload::ServerSettings(ServerSettings::read_from(&mut cursor)?)
}
MessageType::CodecHeader => {
MessagePayload::CodecHeader(CodecHeader::read_from(&mut cursor)?)
}
MessageType::WireChunk => MessagePayload::WireChunk(WireChunk::read_from(&mut cursor)?),
MessageType::ClientInfo => MessagePayload::ClientInfo(ClientInfo::read_from(&mut cursor)?),
MessageType::Error => MessagePayload::Error(Error::read_from(&mut cursor)?),
MessageType::Base => {
return Err(ProtoError::UnknownMessageType(0));
}
#[cfg(feature = "custom-protocol")]
MessageType::Custom(_) => MessagePayload::Custom(payload.to_vec()),
};
Ok(TypedMessage { base, payload: msg })
}
pub fn serialize(base: &mut BaseMessage, payload: &MessagePayload) -> Result<Vec<u8>, ProtoError> {
let mut payload_buf = Vec::new();
match payload {
MessagePayload::Time(m) => m.write_to(&mut payload_buf)?,
MessagePayload::Hello(m) => m.write_to(&mut payload_buf)?,
MessagePayload::ServerSettings(m) => m.write_to(&mut payload_buf)?,
MessagePayload::CodecHeader(m) => m.write_to(&mut payload_buf)?,
MessagePayload::WireChunk(m) => m.write_to(&mut payload_buf)?,
MessagePayload::ClientInfo(m) => m.write_to(&mut payload_buf)?,
MessagePayload::Error(m) => m.write_to(&mut payload_buf)?,
#[cfg(feature = "custom-protocol")]
MessagePayload::Custom(data) => payload_buf.extend_from_slice(data),
}
base.size = payload_buf.len() as u32;
let mut frame = Vec::with_capacity(BaseMessage::HEADER_SIZE + payload_buf.len());
base.write_to(&mut frame)?;
frame.extend_from_slice(&payload_buf);
Ok(frame)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::Timeval;
fn make_base(msg_type: MessageType, size: u32) -> BaseMessage {
BaseMessage {
msg_type,
id: 0,
refers_to: 0,
sent: Timeval::default(),
received: Timeval::default(),
size,
}
}
#[test]
fn deserialize_time() {
let payload = [0x00, 0x00, 0x00, 0x00, 0xDC, 0x05, 0x00, 0x00];
let base = make_base(MessageType::Time, payload.len() as u32);
let msg = deserialize(base, &payload).unwrap();
match msg.payload {
MessagePayload::Time(t) => assert_eq!(t.latency.usec, 1500),
_ => panic!("expected Time"),
}
}
#[test]
fn deserialize_server_settings() {
let json = r#"{"bufferMs":1000,"latency":0,"muted":false,"volume":100}"#;
let mut payload = Vec::new();
crate::message::wire::write_string(&mut payload, json).unwrap();
let base = make_base(MessageType::ServerSettings, payload.len() as u32);
let msg = deserialize(base, &payload).unwrap();
match msg.payload {
MessagePayload::ServerSettings(ss) => {
assert_eq!(ss.buffer_ms, 1000);
assert_eq!(ss.volume, 100);
}
_ => panic!("expected ServerSettings"),
}
}
#[test]
fn deserialize_codec_header() {
let mut payload = Vec::new();
let ch = CodecHeader {
codec: "flac".into(),
payload: vec![0x66, 0x4C, 0x61, 0x43],
};
ch.write_to(&mut payload).unwrap();
let base = make_base(MessageType::CodecHeader, payload.len() as u32);
let msg = deserialize(base, &payload).unwrap();
match msg.payload {
MessagePayload::CodecHeader(c) => {
assert_eq!(c.codec, "flac");
assert_eq!(c.payload, vec![0x66, 0x4C, 0x61, 0x43]);
}
_ => panic!("expected CodecHeader"),
}
}
#[test]
fn deserialize_wire_chunk() {
let mut payload = Vec::new();
let wc = WireChunk {
timestamp: Timeval { sec: 1, usec: 0 },
payload: vec![0xAA, 0xBB],
};
wc.write_to(&mut payload).unwrap();
let base = make_base(MessageType::WireChunk, payload.len() as u32);
let msg = deserialize(base, &payload).unwrap();
match msg.payload {
MessagePayload::WireChunk(w) => {
assert_eq!(w.timestamp.sec, 1);
assert_eq!(w.payload, vec![0xAA, 0xBB]);
}
_ => panic!("expected WireChunk"),
}
}
#[test]
fn deserialize_error() {
let mut payload = Vec::new();
let err = Error {
code: 401,
error: "Unauthorized".into(),
message: "bad creds".into(),
};
err.write_to(&mut payload).unwrap();
let base = make_base(MessageType::Error, payload.len() as u32);
let msg = deserialize(base, &payload).unwrap();
match msg.payload {
MessagePayload::Error(e) => {
assert_eq!(e.code, 401);
assert_eq!(e.error, "Unauthorized");
}
_ => panic!("expected Error"),
}
}
#[test]
fn deserialize_base_type_returns_error() {
let base = make_base(MessageType::Base, 0);
assert!(deserialize(base, &[]).is_err());
}
#[test]
fn full_frame_round_trip_time() {
let payload = MessagePayload::Time(Time {
latency: Timeval {
sec: 5,
usec: 999_000,
},
});
let mut base = make_base(MessageType::Time, 0);
let frame = serialize(&mut base, &payload).unwrap();
let mut cursor = std::io::Cursor::new(&frame);
let decoded_base = BaseMessage::read_from(&mut cursor).unwrap();
assert_eq!(decoded_base.size, Time::SIZE);
let payload_bytes = &frame[BaseMessage::HEADER_SIZE..];
let msg = deserialize(decoded_base, payload_bytes).unwrap();
match msg.payload {
MessagePayload::Time(t) => {
assert_eq!(t.latency.sec, 5);
assert_eq!(t.latency.usec, 999_000);
}
_ => panic!("expected Time"),
}
}
#[test]
fn full_frame_round_trip_error() {
let payload = MessagePayload::Error(Error {
code: 403,
error: "Forbidden".into(),
message: "not allowed".into(),
});
let mut base = make_base(MessageType::Error, 0);
let frame = serialize(&mut base, &payload).unwrap();
let mut cursor = std::io::Cursor::new(&frame);
let decoded_base = BaseMessage::read_from(&mut cursor).unwrap();
let payload_bytes = &frame[BaseMessage::HEADER_SIZE..];
let msg = deserialize(decoded_base, payload_bytes).unwrap();
match msg.payload {
MessagePayload::Error(e) => {
assert_eq!(e.code, 403);
assert_eq!(e.error, "Forbidden");
assert_eq!(e.message, "not allowed");
}
_ => panic!("expected Error"),
}
}
#[test]
fn serialize_sets_correct_size() {
let payload = MessagePayload::ClientInfo(ClientInfo {
volume: 50,
muted: true,
});
let mut base = make_base(MessageType::ClientInfo, 0);
assert_eq!(base.size, 0);
let frame = serialize(&mut base, &payload).unwrap();
assert!(base.size > 0);
assert_eq!(frame.len(), BaseMessage::HEADER_SIZE + base.size as usize);
}
#[cfg(feature = "custom-protocol")]
#[test]
fn round_trip_custom_message() {
let payload = b"hello custom";
let mut base = make_base(MessageType::Custom(42), 0);
let frame = serialize(&mut base, &MessagePayload::Custom(payload.to_vec())).unwrap();
let mut cursor = std::io::Cursor::new(&frame);
let header = BaseMessage::read_from(&mut cursor).unwrap();
assert_eq!(header.msg_type, MessageType::Custom(42));
assert_eq!(header.size, payload.len() as u32);
let mut body = vec![0u8; header.size as usize];
std::io::Read::read_exact(&mut cursor, &mut body).unwrap();
let msg = deserialize(header, &body).unwrap();
match msg.payload {
MessagePayload::Custom(data) => assert_eq!(data, payload),
_ => panic!("expected Custom"),
}
}
}