use dcbor::prelude::*;
use crate::{repair_peer::FrameType, ProtocolError};
// Integer keys used in the CBOR maps built and parsed in this file.
// Keys are scoped to the map they appear in, which is why the same number
// is reused by several constants below.
//
// Keys of the top-level `SyncManifest` map.
const KEY_DEVICE_ID: u64 = 0;
const KEY_ROOMS: u64 = 1;
// Keys of each `RoomSyncState` map nested under `KEY_ROOMS`.
const KEY_ROOM_ID: u64 = 0;
const KEY_SENDERS: u64 = 1;
// Keys of each `SenderSyncState` map nested under `KEY_SENDERS`.
const KEY_SENDER_ID: u64 = 0;
const KEY_LATEST_SEQ: u64 = 1;
const KEY_EARLIEST_SEQ: u64 = 2;
// Keys of the `SyncChunk` map. `KEY_CHUNK_ID` is also used by the `SyncAck` map.
const KEY_CHUNK_ID: u64 = 0;
const KEY_TOTAL_CHUNKS: u64 = 1;
const KEY_ENVELOPES: u64 = 2;
// Remaining keys of the `SyncAck` map (key 0 is `KEY_CHUNK_ID`).
const KEY_ACCEPTED_COUNT: u64 = 1;
const KEY_REJECTED_COUNT: u64 = 2;
/// Sequence-number state recorded for one sender inside a room.
///
/// Serialized as a CBOR map by `encode_sender_sync_state` and compared
/// between devices by `diff_manifests`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SenderSyncState {
/// Sender identifier as raw bytes (written as a CBOR byte string).
pub sender_id: Vec<u8>,
/// Latest sequence number for this sender; `diff_manifests` uses it as the
/// end (`to_seq`) of any range it reports.
pub latest_seq: u64,
/// Earliest sequence number for this sender; `diff_manifests` uses it as the
/// start (`from_seq`) when the other side has no entry for the sender.
pub earliest_seq: u64,
}
/// Per-room entry of a `SyncManifest`: the room identifier plus the state of
/// every sender listed for that room.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoomSyncState {
/// Room identifier as raw bytes (written as a CBOR byte string).
pub room_id: Vec<u8>,
/// Sender states for this room, kept in the order given / decoded.
pub senders: Vec<SenderSyncState>,
}
/// Summary of one device's sync state, exchanged as a `FrameType::SyncManifest`
/// frame (see `encode_sync_manifest` / `decode_sync_manifest`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyncManifest {
/// Device identifier as raw bytes (written as a CBOR byte string).
pub device_id: Vec<u8>,
/// One entry per room included in the manifest.
pub rooms: Vec<RoomSyncState>,
}
/// One chunk of a chunked transfer, exchanged as a `FrameType::SyncChunk`
/// frame (see `encode_sync_chunk` / `decode_sync_chunk`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyncChunk {
/// Identifier of this chunk. NOTE(review): whether ids are 0- or 1-based is
/// not established in this file — confirm with the sender/receiver code.
pub chunk_id: u32,
/// Total number of chunks in the transfer this chunk belongs to.
pub total_chunks: u32,
/// Payloads carried by this chunk, each an opaque byte string; this module
/// does not interpret their contents.
pub envelopes: Vec<Vec<u8>>,
}
/// Acknowledgement message exchanged as a `FrameType::SyncAck` frame
/// (see `encode_sync_ack` / `decode_sync_ack`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyncAck {
/// Chunk identifier; presumably the `chunk_id` of the `SyncChunk` being
/// acknowledged — confirm against the caller.
pub chunk_id: u32,
/// Count of accepted items; presumably envelopes of the chunk — TODO confirm.
pub accepted_count: u32,
/// Count of rejected items; presumably envelopes of the chunk — TODO confirm.
pub rejected_count: u32,
}
/// A span of sequence numbers for one sender in one room, as produced by
/// `diff_manifests`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyncRange {
/// Room the range belongs to.
pub room_id: Vec<u8>,
/// Sender whose sequence numbers the range covers.
pub sender_id: Vec<u8>,
/// First sequence number of the range. `diff_manifests` sets it either to a
/// sender's `earliest_seq` or to one past the other side's `latest_seq`, so
/// the bound appears to be inclusive.
pub from_seq: u64,
/// Last sequence number of the range. `diff_manifests` sets it to a sender's
/// `latest_seq`, so the bound appears to be inclusive.
pub to_seq: u64,
}
/// Builds a wire frame: one tag byte (the `FrameType` discriminant) followed
/// by `body`.
fn prepend_frame(frame: FrameType, body: Vec<u8>) -> Vec<u8> {
    // `once(..).chain(vec)` reports an exact size, so `collect` allocates the
    // final `1 + body.len()` bytes up front.
    std::iter::once(frame as u8).chain(body).collect()
}
/// Checks that `bytes` starts with the tag byte of `expected` and returns the
/// remainder of the slice.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEncoding` when `bytes` is empty or when its
/// first byte differs from `expected as u8`.
fn strip_frame(expected: FrameType, bytes: &[u8]) -> Result<&[u8], ProtocolError> {
    let wanted = expected as u8;
    match bytes.split_first() {
        None => Err(ProtocolError::InvalidEncoding("empty frame".to_string())),
        Some((&tag, payload)) => {
            if tag == wanted {
                Ok(payload)
            } else {
                Err(ProtocolError::InvalidEncoding(format!(
                    "expected frame type 0x{:02x}, got 0x{:02x}",
                    wanted, tag
                )))
            }
        }
    }
}
/// Parses `bytes` as CBOR and requires the top-level item to be a map.
///
/// # Errors
///
/// - `ProtocolError::InvalidEncoding` when the bytes cannot be parsed as CBOR.
/// - `ProtocolError::InvalidEnvelope` when the parsed item is not a map.
fn parse_map(bytes: &[u8]) -> Result<Map, ProtocolError> {
    match CBOR::try_from_data(bytes) {
        Err(err) => Err(ProtocolError::InvalidEncoding(err.to_string())),
        Ok(item) => match item.try_into_map() {
            Ok(fields) => Ok(fields),
            Err(err) => Err(ProtocolError::InvalidEnvelope(err.to_string())),
        },
    }
}
/// Looks up `key` in `map` and returns the value as a byte string.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEnvelope` when the lookup fails or the
/// value is not a byte string.
fn extract_bytes(map: &Map, key: u64) -> Result<Vec<u8>, ProtocolError> {
    let item: CBOR = match map.extract(key) {
        Ok(item) => item,
        Err(err) => return Err(ProtocolError::InvalidEnvelope(err.to_string())),
    };
    match item.try_into_byte_string() {
        Ok(raw) => Ok(raw),
        Err(err) => Err(ProtocolError::InvalidEnvelope(err.to_string())),
    }
}
/// Looks up `key` in `map` and returns the value as a `u64`.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEnvelope` when the extraction fails.
fn extract_u64(map: &Map, key: u64) -> Result<u64, ProtocolError> {
    let extracted: Result<u64, _> = map.extract(key);
    match extracted {
        Ok(value) => Ok(value),
        Err(err) => Err(ProtocolError::InvalidEnvelope(err.to_string())),
    }
}
/// Looks up `key` in `map` as an unsigned integer and narrows it to `u32`.
///
/// `field` is only used to name the offending field in the error message.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEnvelope` when the lookup fails (via
/// `extract_u64`) or when the value does not fit in a `u32`.
fn extract_u32(map: &Map, key: u64, field: &str) -> Result<u32, ProtocolError> {
    let value = extract_u64(map, key)?;
    match u32::try_from(value) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(ProtocolError::InvalidEnvelope(format!(
            "{field} exceeds u32 range: {value}"
        ))),
    }
}
/// Looks up `key` in `map` and returns the value as an array of CBOR items.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEnvelope` when the lookup fails or the
/// value is not an array.
fn extract_array(map: &Map, key: u64) -> Result<Vec<CBOR>, ProtocolError> {
    let item: CBOR = match map.extract(key) {
        Ok(item) => item,
        Err(err) => return Err(ProtocolError::InvalidEnvelope(err.to_string())),
    };
    match item.try_into_array() {
        Ok(elements) => Ok(elements),
        Err(err) => Err(ProtocolError::InvalidEnvelope(err.to_string())),
    }
}
fn encode_sender_sync_state(state: &SenderSyncState) -> CBOR {
let mut map = Map::new();
map.insert(KEY_SENDER_ID, CBOR::to_byte_string(&state.sender_id));
map.insert(KEY_LATEST_SEQ, state.latest_seq);
map.insert(KEY_EARLIEST_SEQ, state.earliest_seq);
CBOR::from(map)
}
/// Decodes a CBOR item produced by `encode_sender_sync_state`.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidEnvelope` when `cbor` is not a map or when
/// any of the three fields cannot be extracted with the expected type.
fn decode_sender_sync_state(cbor: CBOR) -> Result<SenderSyncState, ProtocolError> {
    let fields = match cbor.try_into_map() {
        Ok(fields) => fields,
        Err(err) => return Err(ProtocolError::InvalidEnvelope(err.to_string())),
    };
    // Fields are read in the same order as before so the first failing field
    // determines the reported error.
    let sender_id = extract_bytes(&fields, KEY_SENDER_ID)?;
    let latest_seq = extract_u64(&fields, KEY_LATEST_SEQ)?;
    let earliest_seq = extract_u64(&fields, KEY_EARLIEST_SEQ)?;
    Ok(SenderSyncState {
        sender_id,
        latest_seq,
        earliest_seq,
    })
}
fn encode_room_sync_state(room: &RoomSyncState) -> CBOR {
let mut map = Map::new();
map.insert(KEY_ROOM_ID, CBOR::to_byte_string(&room.room_id));
let senders: Vec<CBOR> = room.senders.iter().map(encode_sender_sync_state).collect();
map.insert(KEY_SENDERS, CBOR::from(senders));
CBOR::from(map)
}
fn decode_room_sync_state(cbor: CBOR) -> Result<RoomSyncState, ProtocolError> {
let map = cbor
.try_into_map()
.map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))?;
let room_id = extract_bytes(&map, KEY_ROOM_ID)?;
let senders = extract_array(&map, KEY_SENDERS)?
.into_iter()
.map(decode_sender_sync_state)
.collect::<Result<Vec<_>, _>>()?;
Ok(RoomSyncState { room_id, senders })
}
pub fn encode_sync_manifest(manifest: &SyncManifest) -> Vec<u8> {
let mut map = Map::new();
map.insert(KEY_DEVICE_ID, CBOR::to_byte_string(&manifest.device_id));
let rooms: Vec<CBOR> = manifest.rooms.iter().map(encode_room_sync_state).collect();
map.insert(KEY_ROOMS, CBOR::from(rooms));
prepend_frame(FrameType::SyncManifest, CBOR::from(map).to_cbor_data())
}
pub fn decode_sync_manifest(bytes: &[u8]) -> Result<SyncManifest, ProtocolError> {
let body = strip_frame(FrameType::SyncManifest, bytes)?;
let map = parse_map(body)?;
let device_id = extract_bytes(&map, KEY_DEVICE_ID)?;
let rooms = extract_array(&map, KEY_ROOMS)?
.into_iter()
.map(decode_room_sync_state)
.collect::<Result<Vec<_>, _>>()?;
Ok(SyncManifest { device_id, rooms })
}
/// Serializes `chunk` into a `FrameType::SyncChunk` frame: the frame tag byte
/// followed by a CBOR map with the chunk id, the total chunk count and an
/// array of byte strings (one per envelope).
pub fn encode_sync_chunk(chunk: &SyncChunk) -> Vec<u8> {
    let mut fields = Map::new();
    fields.insert(KEY_CHUNK_ID, u64::from(chunk.chunk_id));
    fields.insert(KEY_TOTAL_CHUNKS, u64::from(chunk.total_chunks));
    let mut encoded_envelopes: Vec<CBOR> = Vec::with_capacity(chunk.envelopes.len());
    for envelope in &chunk.envelopes {
        encoded_envelopes.push(CBOR::to_byte_string(envelope));
    }
    fields.insert(KEY_ENVELOPES, CBOR::from(encoded_envelopes));
    let payload = CBOR::from(fields).to_cbor_data();
    prepend_frame(FrameType::SyncChunk, payload)
}
/// Parses a `FrameType::SyncChunk` frame back into a `SyncChunk`.
///
/// # Errors
///
/// - `ProtocolError::InvalidEncoding` when `bytes` is empty, carries a
///   different frame tag, or the payload is not parseable CBOR.
/// - `ProtocolError::InvalidEnvelope` when the payload is not a map, a field
///   is missing or mistyped, `chunk_id` / `total_chunks` exceed `u32`, or an
///   element of the envelope array is not a byte string.
pub fn decode_sync_chunk(bytes: &[u8]) -> Result<SyncChunk, ProtocolError> {
    let payload = strip_frame(FrameType::SyncChunk, bytes)?;
    let fields = parse_map(payload)?;
    let chunk_id = extract_u32(&fields, KEY_CHUNK_ID, "chunk_id")?;
    let total_chunks = extract_u32(&fields, KEY_TOTAL_CHUNKS, "total_chunks")?;
    let items = extract_array(&fields, KEY_ENVELOPES)?;
    let mut envelopes = Vec::with_capacity(items.len());
    for item in items {
        // Stop at the first element that is not a byte string.
        match item.try_into_byte_string() {
            Ok(raw) => envelopes.push(raw),
            Err(err) => return Err(ProtocolError::InvalidEnvelope(err.to_string())),
        }
    }
    Ok(SyncChunk {
        chunk_id,
        total_chunks,
        envelopes,
    })
}
pub fn encode_sync_ack(ack: &SyncAck) -> Vec<u8> {
let mut map = Map::new();
map.insert(KEY_CHUNK_ID, u64::from(ack.chunk_id));
map.insert(KEY_ACCEPTED_COUNT, u64::from(ack.accepted_count));
map.insert(KEY_REJECTED_COUNT, u64::from(ack.rejected_count));
prepend_frame(FrameType::SyncAck, CBOR::from(map).to_cbor_data())
}
/// Parses a `FrameType::SyncAck` frame back into a `SyncAck`.
///
/// # Errors
///
/// - `ProtocolError::InvalidEncoding` when `bytes` is empty, carries a
///   different frame tag, or the payload is not parseable CBOR.
/// - `ProtocolError::InvalidEnvelope` when the payload is not a map, a field
///   is missing or mistyped, or a value exceeds `u32`.
pub fn decode_sync_ack(bytes: &[u8]) -> Result<SyncAck, ProtocolError> {
    let payload = strip_frame(FrameType::SyncAck, bytes)?;
    let fields = parse_map(payload)?;
    // Read in declaration order so the first failing field decides the error.
    let chunk_id = extract_u32(&fields, KEY_CHUNK_ID, "chunk_id")?;
    let accepted_count = extract_u32(&fields, KEY_ACCEPTED_COUNT, "accepted_count")?;
    let rejected_count = extract_u32(&fields, KEY_REJECTED_COUNT, "rejected_count")?;
    Ok(SyncAck {
        chunk_id,
        accepted_count,
        rejected_count,
    })
}
/// Computes the sequence ranges that `local` lists and `remote` is missing.
///
/// For every sender of every room in `local`:
///
/// - if `remote` has no entry for that room or sender, the whole local span
///   `earliest_seq ..= latest_seq` is reported;
/// - if `remote` has an entry whose `latest_seq` is lower than the local one,
///   the span starting one past the remote `latest_seq` and ending at the
///   local `latest_seq` is reported. The start is clamped to the local
///   `earliest_seq` so the range never covers sequence numbers that the local
///   manifest itself does not list;
/// - otherwise (remote is level or ahead) nothing is reported.
///
/// Rooms and senders present only in `remote` are ignored. Results follow the
/// order of `local.rooms` and each room's `senders`. When `remote` contains
/// duplicate ids, the first matching entry is used.
pub fn diff_manifests(local: &SyncManifest, remote: &SyncManifest) -> Vec<SyncRange> {
    let mut ranges = Vec::new();
    for local_room in &local.rooms {
        let remote_room = remote
            .rooms
            .iter()
            .find(|room| room.room_id == local_room.room_id);
        for local_sender in &local_room.senders {
            let remote_sender = remote_room.and_then(|room| {
                room.senders
                    .iter()
                    .find(|sender| sender.sender_id == local_sender.sender_id)
            });
            let from_seq = match remote_sender {
                None => local_sender.earliest_seq,
                Some(remote_sender) if local_sender.latest_seq > remote_sender.latest_seq => {
                    // The guard proves `remote_sender.latest_seq < u64::MAX`,
                    // so the increment cannot overflow.
                    let next_missing = remote_sender.latest_seq + 1;
                    next_missing.max(local_sender.earliest_seq)
                }
                // Remote is level with or ahead of local: nothing to report.
                Some(_) => continue,
            };
            ranges.push(SyncRange {
                room_id: local_room.room_id.clone(),
                sender_id: local_sender.sender_id.clone(),
                from_seq,
                to_seq: local_sender.latest_seq,
            });
        }
    }
    ranges
}
#[cfg(test)]
mod tests {
    //! Round-trip tests for the three frame codecs, negative-path tests for the
    //! decoders, and behavior tests for `diff_manifests`.
    use super::*;

    /// Manifest with one room and two senders, used by the round-trip test.
    fn sample_manifest(device_byte: u8) -> SyncManifest {
        SyncManifest {
            device_id: vec![device_byte; 16],
            rooms: vec![RoomSyncState {
                room_id: vec![0xAA; 8],
                senders: vec![
                    SenderSyncState {
                        sender_id: vec![0x01; 8],
                        latest_seq: 10,
                        earliest_seq: 1,
                    },
                    SenderSyncState {
                        sender_id: vec![0x02; 8],
                        latest_seq: 5,
                        earliest_seq: 5,
                    },
                ],
            }],
        }
    }

    #[test]
    fn test_sync_manifest_round_trip() {
        let manifest = sample_manifest(0xDE);
        let encoded = encode_sync_manifest(&manifest);
        assert_eq!(encoded[0], FrameType::SyncManifest as u8);
        let decoded = decode_sync_manifest(&encoded).expect("decode must succeed");
        assert_eq!(decoded, manifest);
    }

    #[test]
    fn test_sync_chunk_round_trip() {
        let chunk = SyncChunk {
            chunk_id: 2,
            total_chunks: 5,
            envelopes: vec![vec![0xAA, 0xBB], vec![0xCC, 0xDD, 0xEE]],
        };
        let encoded = encode_sync_chunk(&chunk);
        assert_eq!(encoded[0], FrameType::SyncChunk as u8);
        let decoded = decode_sync_chunk(&encoded).expect("decode must succeed");
        assert_eq!(decoded, chunk);
    }

    #[test]
    fn test_sync_ack_round_trip() {
        let ack = SyncAck {
            chunk_id: 3,
            accepted_count: 10,
            rejected_count: 2,
        };
        let encoded = encode_sync_ack(&ack);
        assert_eq!(encoded[0], FrameType::SyncAck as u8);
        let decoded = decode_sync_ack(&encoded).expect("decode must succeed");
        assert_eq!(decoded, ack);
    }

    #[test]
    fn test_decode_rejects_empty_input() {
        assert!(decode_sync_manifest(&[]).is_err());
        assert!(decode_sync_chunk(&[]).is_err());
        assert!(decode_sync_ack(&[]).is_err());
    }

    #[test]
    fn test_decode_rejects_mismatched_frame_type() {
        let ack = SyncAck {
            chunk_id: 1,
            accepted_count: 0,
            rejected_count: 0,
        };
        let encoded = encode_sync_ack(&ack);
        // An ack frame must not be accepted by the other two decoders.
        assert!(decode_sync_chunk(&encoded).is_err());
        assert!(decode_sync_manifest(&encoded).is_err());
    }

    fn make_sender(id_byte: u8, earliest: u64, latest: u64) -> SenderSyncState {
        SenderSyncState {
            sender_id: vec![id_byte],
            latest_seq: latest,
            earliest_seq: earliest,
        }
    }

    fn make_room(room_byte: u8, senders: Vec<SenderSyncState>) -> RoomSyncState {
        RoomSyncState {
            room_id: vec![room_byte],
            senders,
        }
    }

    fn make_manifest(rooms: Vec<RoomSyncState>) -> SyncManifest {
        SyncManifest {
            device_id: vec![0x00],
            rooms,
        }
    }

    #[test]
    fn test_diff_local_has_room_remote_doesnt_returns_full_range() {
        let local = make_manifest(vec![make_room(0xAA, vec![make_sender(0x01, 3, 9)])]);
        let remote = make_manifest(vec![]);
        let ranges = diff_manifests(&local, &remote);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].room_id, vec![0xAA]);
        assert_eq!(ranges[0].sender_id, vec![0x01]);
        assert_eq!(ranges[0].from_seq, 3);
        assert_eq!(ranges[0].to_seq, 9);
    }

    #[test]
    fn test_diff_local_higher_seq_than_remote_returns_partial_range() {
        let local = make_manifest(vec![make_room(0xBB, vec![make_sender(0x01, 1, 20)])]);
        let remote = make_manifest(vec![make_room(0xBB, vec![make_sender(0x01, 1, 12)])]);
        let ranges = diff_manifests(&local, &remote);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].from_seq, 13);
        assert_eq!(ranges[0].to_seq, 20);
    }

    #[test]
    fn test_diff_remote_level_or_ahead_returns_no_ranges() {
        let local = make_manifest(vec![make_room(
            0xCC,
            vec![make_sender(0x01, 1, 7), make_sender(0x02, 1, 4)],
        )]);
        let remote = make_manifest(vec![make_room(
            0xCC,
            vec![make_sender(0x01, 1, 7), make_sender(0x02, 1, 9)],
        )]);
        assert!(diff_manifests(&local, &remote).is_empty());
    }

    #[test]
    fn test_diff_remote_missing_sender_in_known_room_returns_full_range() {
        let local = make_manifest(vec![make_room(
            0xDD,
            vec![make_sender(0x01, 1, 5), make_sender(0x02, 2, 6)],
        )]);
        let remote = make_manifest(vec![make_room(0xDD, vec![make_sender(0x01, 1, 5)])]);
        let ranges = diff_manifests(&local, &remote);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].room_id, vec![0xDD]);
        assert_eq!(ranges[0].sender_id, vec![0x02]);
        assert_eq!(ranges[0].from_seq, 2);
        assert_eq!(ranges[0].to_seq, 6);
    }
}