use byteorder::WriteBytesExt;
use super::*;
/// Discriminates the kinds of sync-protocol messages.
///
/// The numeric tag used on the wire is assigned explicitly in `read_sync_tag`
/// and `write_sync_tag` (not derived from the variant order below):
/// `Doc` = 0, `Awareness` = 1, `Auth` = 2, `AwarenessQuery` = 3.
#[derive(Debug, Clone, PartialEq)]
enum MessageType {
    /// Document message, payload handled by the doc-message codec (tag 0).
    Doc,
    /// Awareness update carried in a var-length buffer (tag 1).
    Awareness,
    /// Authentication outcome: granted, or denied with a reason (tag 2).
    Auth,
    /// Awareness query; carries no payload (tag 3).
    AwarenessQuery,
}
fn read_sync_tag(input: &[u8]) -> IResult<&[u8], MessageType> {
let (tail, tag) = read_var_u64(input)?;
let tag = match tag {
0 => MessageType::Doc,
1 => MessageType::Awareness,
2 => MessageType::Auth,
3 => MessageType::AwarenessQuery,
_ => return Err(nom::Err::Error(Error::new(input, ErrorKind::Tag))),
};
Ok((tail, tag))
}
/// Encodes `tag` into `buffer` using `write_var_u64`.
///
/// The codes (`Doc` = 0, `Awareness` = 1, `Auth` = 2, `AwarenessQuery` = 3)
/// must stay in sync with the mapping in `read_sync_tag`.
///
/// # Errors
///
/// Propagates any error returned by `write_var_u64`.
fn write_sync_tag<W: Write>(buffer: &mut W, tag: MessageType) -> Result<(), IoError> {
    write_var_u64(
        buffer,
        match tag {
            MessageType::Doc => 0,
            MessageType::Awareness => 1,
            MessageType::Auth => 2,
            MessageType::AwarenessQuery => 3,
        },
    )?;
    Ok(())
}
/// A sync-protocol message.
///
/// Encoded with `write_sync_message` and decoded with `read_sync_message`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub enum SyncMessage {
    /// Authentication outcome: `None` means permission granted,
    /// `Some(reason)` means permission denied with the given reason.
    Auth(Option<String>),
    /// Awareness states, carried on the wire inside a var-length buffer.
    Awareness(AwarenessStates),
    /// Awareness query; carries no payload beyond its tag.
    AwarenessQuery,
    /// Document message, handled by `read_doc_message` / `write_doc_message`.
    Doc(DocMessage),
}
/// Decodes a single [`SyncMessage`] from the front of `input`.
///
/// The message starts with a tag decoded by `read_sync_tag`, followed by a
/// tag-specific payload:
///
/// - `Doc`: parsed by `read_doc_message`.
/// - `Awareness`: a buffer read with `read_var_buffer`, whose contents are
///   parsed by `read_awareness`. Bytes left over inside that buffer are logged
///   at debug level and otherwise ignored.
/// - `Auth`: a status read with `read_var_u64`. A status of `1` yields
///   `Auth(None)` (permission granted). Any other status is followed by a
///   reason read with `read_var_string`, yielding `Auth(Some(reason))`.
/// - `AwarenessQuery`: no payload.
///
/// On success, returns the bytes of `input` that follow the message together
/// with the decoded message.
///
/// # Errors
///
/// Returns a nom error if the tag is unknown or any payload fails to parse.
pub fn read_sync_message(input: &[u8]) -> IResult<&[u8], SyncMessage> {
    let (tail, tag) = read_sync_tag(input)?;
    let (tail, message) = match tag {
        MessageType::Doc => {
            let (tail, doc) = read_doc_message(tail)?;
            (tail, SyncMessage::Doc(doc))
        }
        MessageType::Awareness => {
            let (tail, update) = read_var_buffer(tail)?;
            let (awareness_tail, awareness) = read_awareness(update)?;
            // The awareness payload is already delimited by the outer buffer,
            // so leftover bytes inside it do not affect `tail`. Decoder input
            // should be treated as untrusted, so this is only logged rather
            // than asserted (an assertion would let a peer panic debug builds).
            let tail_len = awareness_tail.len();
            if tail_len > 0 {
                debug!("awareness update has trailing bytes: {tail_len}");
            }
            (tail, SyncMessage::Awareness(awareness))
        }
        MessageType::Auth => {
            // Mirrors `PERMISSION_GRANTED` in `write_sync_message`.
            const PERMISSION_GRANTED: u64 = 1;
            let (tail, status) = read_var_u64(tail)?;
            if status == PERMISSION_GRANTED {
                (tail, SyncMessage::Auth(None))
            } else {
                let (tail, reason) = read_var_string(tail)?;
                (tail, SyncMessage::Auth(Some(reason)))
            }
        }
        MessageType::AwarenessQuery => (tail, SyncMessage::AwarenessQuery),
    };
    Ok((tail, message))
}
/// Encodes `msg` into `buffer`: the message tag (via `write_sync_tag`)
/// followed by its payload.
///
/// - `Auth(None)` writes the byte `1` (granted); `Auth(Some(reason))` writes
///   the byte `0` (denied) followed by `reason` via `write_var_string`.
/// - `AwarenessQuery` writes only the tag.
/// - `Awareness` serializes the states with `write_awareness` into a scratch
///   `Vec` and emits it via `write_var_buffer`.
/// - `Doc` delegates to `write_doc_message`.
///
/// # Errors
///
/// Propagates any error from the writer or the payload encoders. Bytes
/// written before a failure are not rolled back.
pub fn write_sync_message<W: Write>(buffer: &mut W, msg: &SyncMessage) -> Result<(), IoError> {
    const PERMISSION_DENIED: u8 = 0;
    const PERMISSION_GRANTED: u8 = 1;

    match msg {
        SyncMessage::Doc(doc) => {
            write_sync_tag(buffer, MessageType::Doc)?;
            write_doc_message(buffer, doc)?;
        }
        SyncMessage::Awareness(awareness) => {
            write_sync_tag(buffer, MessageType::Awareness)?;
            let mut encoded_states = Vec::new();
            write_awareness(&mut encoded_states, awareness)?;
            write_var_buffer(buffer, &encoded_states)?;
        }
        SyncMessage::Auth(outcome) => {
            write_sync_tag(buffer, MessageType::Auth)?;
            match outcome {
                Some(reason) => {
                    buffer.write_u8(PERMISSION_DENIED)?;
                    write_var_string(buffer, reason)?;
                }
                None => buffer.write_u8(PERMISSION_GRANTED)?,
            }
        }
        SyncMessage::AwarenessQuery => write_sync_tag(buffer, MessageType::AwarenessQuery)?,
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::{awareness::AwarenessState, *};

    /// Every known tag survives an encode/decode round trip with no leftovers.
    #[test]
    fn test_sync_tag() {
        let messages = [
            MessageType::Auth,
            MessageType::Awareness,
            MessageType::AwarenessQuery,
            MessageType::Doc,
        ];
        for msg in messages {
            let mut buffer = Vec::new();
            write_sync_tag(&mut buffer, msg.clone()).unwrap();
            let (tail, decoded) = read_sync_tag(&buffer).unwrap();
            assert_eq!(tail.len(), 0);
            assert_eq!(decoded, msg);
        }
    }

    /// Only tags 0..=3 are mapped by `read_sync_tag`; anything else is an error.
    #[test]
    fn test_sync_tag_rejects_unknown_tag() {
        let mut buffer = Vec::new();
        write_var_u64(&mut buffer, 4).unwrap();
        assert!(read_sync_tag(&buffer).is_err());
    }

    /// Every message variant (including both auth outcomes) round-trips.
    #[test]
    fn test_sync_message() {
        let messages = [
            SyncMessage::Auth(None),
            SyncMessage::Auth(Some("reason".to_string())),
            SyncMessage::Awareness(HashMap::from([(1, AwarenessState::new(1, "test".into()))])),
            SyncMessage::AwarenessQuery,
            SyncMessage::Doc(DocMessage::Step1(vec![4, 5, 6])),
            SyncMessage::Doc(DocMessage::Step2(vec![7, 8, 9])),
            SyncMessage::Doc(DocMessage::Update(vec![10, 11, 12])),
        ];
        for msg in messages {
            let mut buffer = Vec::new();
            write_sync_message(&mut buffer, &msg).unwrap();
            let (tail, decoded) = read_sync_message(&buffer).unwrap();
            assert_eq!(tail.len(), 0);
            assert_eq!(decoded, msg);
        }
    }

    /// Bytes following a message are returned untouched as the remaining input.
    #[test]
    fn test_sync_message_leaves_trailing_input() {
        let mut buffer = Vec::new();
        write_sync_message(&mut buffer, &SyncMessage::AwarenessQuery).unwrap();
        buffer.extend_from_slice(&[0xAB, 0xCD]);
        let (tail, decoded) = read_sync_message(&buffer).unwrap();
        assert_eq!(decoded, SyncMessage::AwarenessQuery);
        assert_eq!(tail, &[0xAB_u8, 0xCD][..]);
    }
}