protocol-core 0.3.12

Reusable mesh protocol framing, sync, and repair primitives.
Documentation
//! Device sync protocol for transferring append-only object history.

use dcbor::prelude::*;

use crate::{repair_peer::FrameType, ProtocolError};

// Integer keys for the CBOR maps built and parsed in this module. Every
// message (and nested entry) is its own map, so the same key number is reused
// with a different meaning in each group.

// Top-level sync manifest map.
const KEY_DEVICE_ID: u64 = 0;
const KEY_ROOMS: u64 = 1;
// Per-room entry inside the manifest's room array.
const KEY_ROOM_ID: u64 = 0;
const KEY_SENDERS: u64 = 1;
// Per-sender entry inside a room's sender array.
const KEY_SENDER_ID: u64 = 0;
const KEY_LATEST_SEQ: u64 = 1;
const KEY_EARLIEST_SEQ: u64 = 2;
// Sync chunk map. KEY_CHUNK_ID is shared with the sync ack map.
const KEY_CHUNK_ID: u64 = 0;
const KEY_TOTAL_CHUNKS: u64 = 1;
const KEY_ENVELOPES: u64 = 2;
// Sync ack map (alongside KEY_CHUNK_ID).
const KEY_ACCEPTED_COUNT: u64 = 1;
const KEY_REJECTED_COUNT: u64 = 2;

/// Sequence-number bounds advertised for a single sender within a room.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SenderSyncState {
    /// Sender identifier, carried on the wire as a CBOR byte string.
    pub sender_id: Vec<u8>,
    /// Highest sequence number held for this sender; used as the upper bound
    /// of ranges produced by `diff_manifests`.
    pub latest_seq: u64,
    /// Lowest sequence number held for this sender; used as the lower bound
    /// when the remote manifest has no entry for the sender.
    pub earliest_seq: u64,
}

/// Per-room section of a sync manifest: the room id plus one entry per sender.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoomSyncState {
    /// Room identifier, carried on the wire as a CBOR byte string.
    pub room_id: Vec<u8>,
    /// Sender states listed for this room, in wire order.
    pub senders: Vec<SenderSyncState>,
}

/// Summary of the history a device holds, grouped by room and sender.
///
/// Encoded by `encode_sync_manifest` and compared by `diff_manifests`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyncManifest {
    /// Identifier of the device that produced the manifest (CBOR byte string).
    pub device_id: Vec<u8>,
    /// One entry per room the manifest describes.
    pub rooms: Vec<RoomSyncState>,
}

/// One chunk of a multi-chunk transfer of envelopes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyncChunk {
    /// Identifier of this chunk within the transfer.
    // NOTE(review): whether ids are zero- or one-based is not established in this file.
    pub chunk_id: u32,
    /// Number of chunks the sender states the transfer consists of.
    pub total_chunks: u32,
    /// Envelope payloads; each is written as a CBOR byte string and treated as
    /// opaque bytes by this module.
    pub envelopes: Vec<Vec<u8>>,
}

/// Acknowledgement for a single sync chunk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyncAck {
    /// Identifier of the chunk being acknowledged.
    pub chunk_id: u32,
    /// Count reported as accepted (presumably envelopes from the chunk — confirm with the receiver logic).
    pub accepted_count: u32,
    /// Count reported as rejected (presumably envelopes from the chunk — confirm with the receiver logic).
    pub rejected_count: u32,
}

/// A span of sequence numbers for one sender in one room, as produced by
/// `diff_manifests`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyncRange {
    /// Room the range belongs to.
    pub room_id: Vec<u8>,
    /// Sender whose history the range covers.
    pub sender_id: Vec<u8>,
    /// First sequence number in the range.
    pub from_seq: u64,
    /// Last sequence number in the range; `diff_manifests` sets it to the local
    /// side's `latest_seq`, so the bound appears to be inclusive.
    pub to_seq: u64,
}

/// Build a wire frame: the one-byte discriminant of `frame` followed by `body`.
fn prepend_frame(frame: FrameType, body: Vec<u8>) -> Vec<u8> {
    // The tag byte always comes first; the payload follows unchanged.
    std::iter::once(frame as u8).chain(body).collect()
}

/// Check the leading frame-type byte and return the payload that follows it.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEncoding` when `bytes` is empty or when the
/// first byte is not the discriminant of `expected`.
fn strip_frame(expected: FrameType, bytes: &[u8]) -> Result<&[u8], ProtocolError> {
    let want = expected as u8;
    match bytes.split_first() {
        Some((&tag, payload)) if tag == want => Ok(payload),
        Some((&tag, _)) => Err(ProtocolError::InvalidEncoding(format!(
            "expected frame type 0x{:02x}, got 0x{:02x}",
            want, tag
        ))),
        None => Err(ProtocolError::InvalidEncoding("empty frame".to_string())),
    }
}

/// Decode `bytes` as CBOR and require the top-level item to be a map.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEncoding` when the bytes cannot be decoded,
/// and `ProtocolError::InvalidEnvelope` when the decoded item cannot be
/// converted into a map.
fn parse_map(bytes: &[u8]) -> Result<Map, ProtocolError> {
    let decoded = match CBOR::try_from_data(bytes) {
        Ok(item) => item,
        Err(err) => return Err(ProtocolError::InvalidEncoding(err.to_string())),
    };
    match decoded.try_into_map() {
        Ok(map) => Ok(map),
        Err(err) => Err(ProtocolError::InvalidEnvelope(err.to_string())),
    }
}

/// Look up `key` in `map` and return its value as a byte string.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEnvelope` when the lookup fails or the value
/// cannot be converted into a byte string.
fn extract_bytes(map: &Map, key: u64) -> Result<Vec<u8>, ProtocolError> {
    let value: CBOR = match map.extract(key) {
        Ok(found) => found,
        Err(err) => return Err(ProtocolError::InvalidEnvelope(err.to_string())),
    };
    match value.try_into_byte_string() {
        Ok(bytes) => Ok(bytes),
        Err(err) => Err(ProtocolError::InvalidEnvelope(err.to_string())),
    }
}

/// Look up `key` in `map` and return its value as a `u64`.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEnvelope` carrying the extraction error text
/// when the lookup or conversion fails.
fn extract_u64(map: &Map, key: u64) -> Result<u64, ProtocolError> {
    let number: u64 = map
        .extract(key)
        .map_err(|err| ProtocolError::InvalidEnvelope(err.to_string()))?;
    Ok(number)
}

/// Look up `key` in `map` as an unsigned integer and narrow it to `u32`.
///
/// `field` is only used to label the error message.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEnvelope` when extraction fails or when the
/// value does not fit in a `u32`.
fn extract_u32(map: &Map, key: u64, field: &str) -> Result<u32, ProtocolError> {
    let value = extract_u64(map, key)?;
    match u32::try_from(value) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(ProtocolError::InvalidEnvelope(format!(
            "{field} exceeds u32 range: {value}"
        ))),
    }
}

/// Look up `key` in `map` and return its value as an array of CBOR items.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEnvelope` when the lookup fails or the value
/// cannot be converted into an array.
fn extract_array(map: &Map, key: u64) -> Result<Vec<CBOR>, ProtocolError> {
    let item: CBOR = map
        .extract(key)
        .map_err(|err| err.to_string())
        .map_err(ProtocolError::InvalidEnvelope)?;
    item.try_into_array()
        .map_err(|err| err.to_string())
        .map_err(ProtocolError::InvalidEnvelope)
}

fn encode_sender_sync_state(state: &SenderSyncState) -> CBOR {
    let mut map = Map::new();
    map.insert(KEY_SENDER_ID, CBOR::to_byte_string(&state.sender_id));
    map.insert(KEY_LATEST_SEQ, state.latest_seq);
    map.insert(KEY_EARLIEST_SEQ, state.earliest_seq);
    CBOR::from(map)
}

/// Decode one sender entry from a CBOR map.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEnvelope` when `cbor` is not a map or when
/// any of the three fields is missing or has the wrong type.
fn decode_sender_sync_state(cbor: CBOR) -> Result<SenderSyncState, ProtocolError> {
    let fields = match cbor.try_into_map() {
        Ok(map) => map,
        Err(err) => return Err(ProtocolError::InvalidEnvelope(err.to_string())),
    };
    // Fields are read in the same order they are encoded.
    let sender_id = extract_bytes(&fields, KEY_SENDER_ID)?;
    let latest_seq = extract_u64(&fields, KEY_LATEST_SEQ)?;
    let earliest_seq = extract_u64(&fields, KEY_EARLIEST_SEQ)?;
    Ok(SenderSyncState {
        sender_id,
        latest_seq,
        earliest_seq,
    })
}

fn encode_room_sync_state(room: &RoomSyncState) -> CBOR {
    let mut map = Map::new();
    map.insert(KEY_ROOM_ID, CBOR::to_byte_string(&room.room_id));
    let senders: Vec<CBOR> = room.senders.iter().map(encode_sender_sync_state).collect();
    map.insert(KEY_SENDERS, CBOR::from(senders));
    CBOR::from(map)
}

fn decode_room_sync_state(cbor: CBOR) -> Result<RoomSyncState, ProtocolError> {
    let map = cbor
        .try_into_map()
        .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))?;
    let room_id = extract_bytes(&map, KEY_ROOM_ID)?;
    let senders = extract_array(&map, KEY_SENDERS)?
        .into_iter()
        .map(decode_sender_sync_state)
        .collect::<Result<Vec<_>, _>>()?;
    Ok(RoomSyncState { room_id, senders })
}

pub fn encode_sync_manifest(manifest: &SyncManifest) -> Vec<u8> {
    let mut map = Map::new();
    map.insert(KEY_DEVICE_ID, CBOR::to_byte_string(&manifest.device_id));
    let rooms: Vec<CBOR> = manifest.rooms.iter().map(encode_room_sync_state).collect();
    map.insert(KEY_ROOMS, CBOR::from(rooms));
    prepend_frame(FrameType::SyncManifest, CBOR::from(map).to_cbor_data())
}

pub fn decode_sync_manifest(bytes: &[u8]) -> Result<SyncManifest, ProtocolError> {
    let body = strip_frame(FrameType::SyncManifest, bytes)?;
    let map = parse_map(body)?;
    let device_id = extract_bytes(&map, KEY_DEVICE_ID)?;
    let rooms = extract_array(&map, KEY_ROOMS)?
        .into_iter()
        .map(decode_room_sync_state)
        .collect::<Result<Vec<_>, _>>()?;
    Ok(SyncManifest { device_id, rooms })
}

/// Serialize a chunk into a framed message.
///
/// The output is the `FrameType::SyncChunk` byte followed by a CBOR map holding
/// the chunk id, the total chunk count, and the envelopes as byte strings.
pub fn encode_sync_chunk(chunk: &SyncChunk) -> Vec<u8> {
    let mut encoded_envelopes: Vec<CBOR> = Vec::with_capacity(chunk.envelopes.len());
    for envelope in &chunk.envelopes {
        encoded_envelopes.push(CBOR::to_byte_string(envelope));
    }

    let mut fields = Map::new();
    fields.insert(KEY_CHUNK_ID, u64::from(chunk.chunk_id));
    fields.insert(KEY_TOTAL_CHUNKS, u64::from(chunk.total_chunks));
    fields.insert(KEY_ENVELOPES, CBOR::from(encoded_envelopes));

    let payload = CBOR::from(fields).to_cbor_data();
    prepend_frame(FrameType::SyncChunk, payload)
}

/// Parse a framed chunk message produced by `encode_sync_chunk`.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEncoding` when the frame byte is missing or
/// wrong or the payload is not decodable CBOR, and
/// `ProtocolError::InvalidEnvelope` when the payload is not a map, a field is
/// missing or mistyped, a counter exceeds `u32`, or an envelope entry is not a
/// byte string.
pub fn decode_sync_chunk(bytes: &[u8]) -> Result<SyncChunk, ProtocolError> {
    let payload = strip_frame(FrameType::SyncChunk, bytes)?;
    let fields = parse_map(payload)?;
    let chunk_id = extract_u32(&fields, KEY_CHUNK_ID, "chunk_id")?;
    let total_chunks = extract_u32(&fields, KEY_TOTAL_CHUNKS, "total_chunks")?;

    let raw_envelopes = extract_array(&fields, KEY_ENVELOPES)?;
    let mut envelopes = Vec::with_capacity(raw_envelopes.len());
    for item in raw_envelopes {
        // Stop at the first entry that is not a byte string.
        match item.try_into_byte_string() {
            Ok(envelope) => envelopes.push(envelope),
            Err(err) => return Err(ProtocolError::InvalidEnvelope(err.to_string())),
        }
    }

    Ok(SyncChunk {
        chunk_id,
        total_chunks,
        envelopes,
    })
}

pub fn encode_sync_ack(ack: &SyncAck) -> Vec<u8> {
    let mut map = Map::new();
    map.insert(KEY_CHUNK_ID, u64::from(ack.chunk_id));
    map.insert(KEY_ACCEPTED_COUNT, u64::from(ack.accepted_count));
    map.insert(KEY_REJECTED_COUNT, u64::from(ack.rejected_count));
    prepend_frame(FrameType::SyncAck, CBOR::from(map).to_cbor_data())
}

/// Parse a framed acknowledgement produced by `encode_sync_ack`.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEncoding` when the frame byte is missing or
/// wrong or the payload is not decodable CBOR, and
/// `ProtocolError::InvalidEnvelope` when the payload is not a map or a field is
/// missing, mistyped, or exceeds `u32`.
pub fn decode_sync_ack(bytes: &[u8]) -> Result<SyncAck, ProtocolError> {
    let payload = strip_frame(FrameType::SyncAck, bytes)?;
    let fields = parse_map(payload)?;
    // Fields are read in the same order they are encoded.
    let chunk_id = extract_u32(&fields, KEY_CHUNK_ID, "chunk_id")?;
    let accepted_count = extract_u32(&fields, KEY_ACCEPTED_COUNT, "accepted_count")?;
    let rejected_count = extract_u32(&fields, KEY_REJECTED_COUNT, "rejected_count")?;
    Ok(SyncAck {
        chunk_id,
        accepted_count,
        rejected_count,
    })
}

/// Compute the ranges the local side has that the remote side lacks.
///
/// For every `(room, sender)` pair in `local`:
///
/// * if `remote` has no entry for the pair, the whole local span
///   `earliest_seq..=latest_seq` is reported;
/// * if `remote` has the pair but a lower `latest_seq`, the span after the
///   remote's `latest_seq` is reported, clamped so it never starts below the
///   local `earliest_seq` (the local side cannot supply sequence numbers it
///   does not hold);
/// * otherwise nothing is reported for the pair.
///
/// Pairs that exist only in `remote` are ignored. Ranges are returned in the
/// order rooms and senders appear in `local`. When an id occurs more than once
/// in `remote`, the first matching entry is used. The manifests are not
/// validated: a local entry with `earliest_seq > latest_seq` is passed through
/// as given.
pub fn diff_manifests(local: &SyncManifest, remote: &SyncManifest) -> Vec<SyncRange> {
    let mut ranges = Vec::new();

    for local_room in &local.rooms {
        let remote_room = remote
            .rooms
            .iter()
            .find(|room| room.room_id == local_room.room_id);

        for local_sender in &local_room.senders {
            let remote_sender = remote_room.and_then(|room| {
                room.senders
                    .iter()
                    .find(|sender| sender.sender_id == local_sender.sender_id)
            });

            let from_seq = match remote_sender {
                // Remote has never seen this sender: offer everything we hold.
                None => local_sender.earliest_seq,
                Some(remote_sender) if local_sender.latest_seq > remote_sender.latest_seq => {
                    // `+ 1` cannot overflow: remote.latest_seq < local.latest_seq <= u64::MAX.
                    // Clamp to the local earliest so the range only covers history
                    // the local side actually holds.
                    (remote_sender.latest_seq + 1).max(local_sender.earliest_seq)
                }
                // Remote is level with or ahead of us for this sender.
                Some(_) => continue,
            };

            ranges.push(SyncRange {
                room_id: local_room.room_id.clone(),
                sender_id: local_sender.sender_id.clone(),
                from_seq,
                to_seq: local_sender.latest_seq,
            });
        }
    }

    ranges
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Single-room manifest with two senders; the device id is `device_byte`
    /// repeated 16 times.
    fn sample_manifest(device_byte: u8) -> SyncManifest {
        let first_sender = SenderSyncState {
            sender_id: vec![0x01; 8],
            latest_seq: 10,
            earliest_seq: 1,
        };
        let second_sender = SenderSyncState {
            sender_id: vec![0x02; 8],
            latest_seq: 5,
            earliest_seq: 5,
        };
        let room = RoomSyncState {
            room_id: vec![0xAA; 8],
            senders: vec![first_sender, second_sender],
        };
        SyncManifest {
            device_id: vec![device_byte; 16],
            rooms: vec![room],
        }
    }

    #[test]
    fn test_sync_manifest_round_trip() {
        let original = sample_manifest(0xDE);
        let wire = encode_sync_manifest(&original);
        // The frame tag must lead the encoded message.
        assert_eq!(wire.first().copied(), Some(FrameType::SyncManifest as u8));
        let round_tripped = decode_sync_manifest(&wire).expect("decode must succeed");
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn test_sync_chunk_round_trip() {
        let original = SyncChunk {
            chunk_id: 2,
            total_chunks: 5,
            envelopes: vec![vec![0xAA, 0xBB], vec![0xCC, 0xDD, 0xEE]],
        };
        let wire = encode_sync_chunk(&original);
        assert_eq!(wire.first().copied(), Some(FrameType::SyncChunk as u8));
        let round_tripped = decode_sync_chunk(&wire).expect("decode must succeed");
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn test_sync_ack_round_trip() {
        let original = SyncAck {
            chunk_id: 3,
            accepted_count: 10,
            rejected_count: 2,
        };
        let wire = encode_sync_ack(&original);
        assert_eq!(wire.first().copied(), Some(FrameType::SyncAck as u8));
        let round_tripped = decode_sync_ack(&wire).expect("decode must succeed");
        assert_eq!(round_tripped, original);
    }

    /// Sender entry with a one-byte id covering `earliest..=latest`.
    fn make_sender(id_byte: u8, earliest: u64, latest: u64) -> SenderSyncState {
        let sender_id = vec![id_byte];
        SenderSyncState {
            sender_id,
            latest_seq: latest,
            earliest_seq: earliest,
        }
    }

    /// Room entry with a one-byte id and the given senders.
    fn make_room(room_byte: u8, senders: Vec<SenderSyncState>) -> RoomSyncState {
        let room_id = vec![room_byte];
        RoomSyncState { room_id, senders }
    }

    /// Manifest with a fixed one-byte device id and the given rooms.
    fn make_manifest(rooms: Vec<RoomSyncState>) -> SyncManifest {
        let device_id = vec![0x00];
        SyncManifest { device_id, rooms }
    }

    #[test]
    fn test_diff_local_has_room_remote_doesnt_returns_full_range() {
        let local_room = make_room(0xAA, vec![make_sender(0x01, 3, 9)]);
        let local = make_manifest(vec![local_room]);
        let remote = make_manifest(Vec::new());

        // Exactly one range, spanning everything the local side holds.
        let expected = vec![SyncRange {
            room_id: vec![0xAA],
            sender_id: vec![0x01],
            from_seq: 3,
            to_seq: 9,
        }];
        assert_eq!(diff_manifests(&local, &remote), expected);
    }

    #[test]
    fn test_diff_local_higher_seq_than_remote_returns_partial_range() {
        let local = make_manifest(vec![make_room(0xBB, vec![make_sender(0x01, 1, 20)])]);
        let remote = make_manifest(vec![make_room(0xBB, vec![make_sender(0x01, 1, 12)])]);

        let ranges = diff_manifests(&local, &remote);
        assert_eq!(ranges.len(), 1);
        let only = &ranges[0];
        // The range starts right after the remote's latest sequence number.
        assert_eq!((only.from_seq, only.to_seq), (13, 20));
    }
}