y-octo 0.0.3

High-performance and thread-safe CRDT implementation compatible with Yjs
Documentation
use byteorder::WriteBytesExt;

use super::*;

/// Kind of a sync-protocol message, as identified by the tag that leads it
/// on the wire.
///
/// Variants are listed in the order of their wire values (0 through 3); the
/// actual mapping lives in `read_sync_tag` / `write_sync_tag`.
#[derive(Debug, Clone, PartialEq)]
enum MessageType {
    /// Document message (wire tag 0).
    Doc,
    /// Awareness update (wire tag 1).
    Awareness,
    /// Authentication result (wire tag 2).
    Auth,
    /// Request for awareness state (wire tag 3).
    AwarenessQuery,
}

fn read_sync_tag(input: &[u8]) -> IResult<&[u8], MessageType> {
    let (tail, tag) = read_var_u64(input)?;
    let tag = match tag {
        0 => MessageType::Doc,
        1 => MessageType::Awareness,
        2 => MessageType::Auth,
        3 => MessageType::AwarenessQuery,
        _ => return Err(nom::Err::Error(Error::new(input, ErrorKind::Tag))),
    };

    Ok((tail, tag))
}

/// Appends the wire tag for `tag` to `buffer` as a variable-length integer.
///
/// This is the inverse of `read_sync_tag`: the numeric values written here
/// are exactly the ones that function accepts.
fn write_sync_tag<W: Write>(buffer: &mut W, tag: MessageType) -> Result<(), IoError> {
    let wire_value: u64 = match tag {
        MessageType::AwarenessQuery => 3,
        MessageType::Auth => 2,
        MessageType::Awareness => 1,
        MessageType::Doc => 0,
    };

    write_var_u64(buffer, wire_value)?;
    Ok(())
}

/// A decoded sync-protocol message.
///
/// Values are produced by `read_sync_message` and serialized by
/// `write_sync_message`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub enum SyncMessage {
    /// Authentication result: `None` when permission was granted,
    /// `Some(reason)` when it was denied.
    Auth(Option<String>),
    /// Awareness states; on the wire they are carried as a length-prefixed
    /// buffer.
    Awareness(AwarenessStates),
    /// Awareness query; carries no payload beyond the tag.
    AwarenessQuery,
    /// Document message; see `DocMessage` for the individual kinds.
    Doc(DocMessage),
}

/// Decodes one [`SyncMessage`] from the front of `input`.
///
/// The message starts with a tag (see `read_sync_tag`) that selects how the
/// rest is parsed:
///
/// * `Doc` - the payload is handed to `read_doc_message`.
/// * `Awareness` - a length-prefixed buffer is read and then decoded with
///   `read_awareness`. Bytes left over inside that buffer are tolerated and
///   only reported through a `debug!` log line.
/// * `Auth` - a variable-length integer flag is read. The value `1` means
///   permission granted (`Auth(None)`); any other value is followed by a
///   string holding the denial reason (`Auth(Some(reason))`).
/// * `AwarenessQuery` - no payload.
///
/// # Errors
///
/// Returns the `nom` error of whichever sub-parser fails, including the
/// `ErrorKind::Tag` error raised for an unknown message tag.
pub fn read_sync_message(input: &[u8]) -> IResult<&[u8], SyncMessage> {
    // Flag value that `write_sync_message` emits for a granted permission.
    const PERMISSION_GRANTED: u64 = 1;

    let (tail, tag) = read_sync_tag(input)?;

    let (tail, message) = match tag {
        MessageType::Doc => {
            let (tail, doc) = read_doc_message(tail)?;
            (tail, SyncMessage::Doc(doc))
        }
        MessageType::Awareness => {
            let (tail, update) = read_var_buffer(tail)?;
            let (awareness_tail, awareness) = read_awareness(update)?;

            // Trailing bytes inside the awareness buffer are not an error;
            // they are logged so unexpected peers can be diagnosed. (The
            // previous `debug_assert!(tail_len > 0, ..)` in this branch could
            // never fire and has been dropped.)
            let tail_len = awareness_tail.len();
            if tail_len > 0 {
                debug!("awareness update has trailing bytes: {tail_len}");
            }

            (tail, SyncMessage::Awareness(awareness))
        }
        MessageType::Auth => {
            let (tail, success) = read_var_u64(tail)?;

            if success == PERMISSION_GRANTED {
                (tail, SyncMessage::Auth(None))
            } else {
                // Every non-granted flag is treated as a denial with a reason.
                let (tail, reason) = read_var_string(tail)?;
                (tail, SyncMessage::Auth(Some(reason)))
            }
        }
        MessageType::AwarenessQuery => (tail, SyncMessage::AwarenessQuery),
    };

    Ok((tail, message))
}

/// Serializes `msg` into `buffer`: first the message tag, then the payload
/// that belongs to the variant.
///
/// * `Doc` - payload written by `write_doc_message`.
/// * `Awareness` - the states are encoded into a scratch buffer that is then
///   written length-prefixed.
/// * `AwarenessQuery` - tag only.
/// * `Auth` - a single flag byte (`1` granted, `0` denied); a denial is
///   followed by the reason string.
///
/// # Errors
///
/// Propagates the first error returned by any of the underlying writes.
pub fn write_sync_message<W: Write>(buffer: &mut W, msg: &SyncMessage) -> Result<(), IoError> {
    const PERMISSION_DENIED: u8 = 0;
    const PERMISSION_GRANTED: u8 = 1;

    match msg {
        SyncMessage::Doc(doc) => {
            write_sync_tag(buffer, MessageType::Doc)?;
            write_doc_message(buffer, doc)?;
        }
        SyncMessage::Awareness(states) => {
            // The tag goes out before the states are encoded, so an encoding
            // failure leaves the tag already written.
            write_sync_tag(buffer, MessageType::Awareness)?;
            let mut encoded: Vec<u8> = Vec::new();
            write_awareness(&mut encoded, states)?;
            write_var_buffer(buffer, &encoded)?;
        }
        SyncMessage::AwarenessQuery => write_sync_tag(buffer, MessageType::AwarenessQuery)?,
        SyncMessage::Auth(None) => {
            write_sync_tag(buffer, MessageType::Auth)?;
            buffer.write_u8(PERMISSION_GRANTED)?;
        }
        SyncMessage::Auth(Some(reason)) => {
            write_sync_tag(buffer, MessageType::Auth)?;
            buffer.write_u8(PERMISSION_DENIED)?;
            write_var_string(buffer, reason)?;
        }
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::{awareness::AwarenessState, *};

    /// Every message type must survive a tag write/read round trip without
    /// leaving unread bytes behind.
    #[test]
    fn test_sync_tag() {
        let messages = [
            MessageType::Auth,
            MessageType::Awareness,
            MessageType::AwarenessQuery,
            MessageType::Doc,
        ];

        for msg in messages {
            let mut buffer = Vec::new();

            write_sync_tag(&mut buffer, msg.clone()).unwrap();
            let (tail, decoded) = read_sync_tag(&buffer).unwrap();

            assert_eq!(tail.len(), 0);
            assert_eq!(decoded, msg);
        }
    }

    /// A tag value that maps to no message type must be reported as an error.
    #[test]
    fn test_sync_tag_rejects_unknown_value() {
        let mut buffer = Vec::new();
        // 4 is the first value past the known tags 0..=3.
        write_var_u64(&mut buffer, 4).unwrap();

        assert!(read_sync_tag(&buffer).is_err());
    }

    /// Full messages must round trip for every variant, including both the
    /// granted (`None`) and denied (`Some(reason)`) forms of `Auth`.
    #[test]
    fn test_sync_message() {
        let messages = [
            SyncMessage::Auth(None),
            SyncMessage::Auth(Some("reason".to_string())),
            SyncMessage::Awareness(HashMap::from([(1, AwarenessState::new(1, "test".into()))])),
            SyncMessage::AwarenessQuery,
            SyncMessage::Doc(DocMessage::Step1(vec![4, 5, 6])),
            SyncMessage::Doc(DocMessage::Step2(vec![7, 8, 9])),
            SyncMessage::Doc(DocMessage::Update(vec![10, 11, 12])),
        ];

        for msg in messages {
            let mut buffer = Vec::new();
            write_sync_message(&mut buffer, &msg).unwrap();
            let (tail, decoded) = read_sync_message(&buffer).unwrap();
            assert_eq!(tail.len(), 0);
            assert_eq!(decoded, msg);
        }
    }
}