Skip to main content

protocol_core/
device_sync.rs

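//! Device sync protocol for transferring append-only object history.
//!
//! Every message is a single `FrameType` tag byte followed by a CBOR map body.
//! A `SyncManifest` advertises which sequence ranges a device holds per room and
//! sender, a `SyncChunk` carries a batch of opaque envelopes, and a `SyncAck`
//! reports accepted/rejected counts for a chunk. `diff_manifests` compares two
//! manifests to decide which ranges to send.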
2
3use dcbor::prelude::*;
4
5use crate::{repair_peer::FrameType, ProtocolError};
6
// Integer keys used inside the CBOR map bodies. Keys are scoped to the map they
// appear in, so the same small integers are intentionally reused across maps.

// `SyncManifest` body.
const KEY_DEVICE_ID: u64 = 0;
const KEY_ROOMS: u64 = 1;

// Per-room entry inside `KEY_ROOMS`.
const KEY_ROOM_ID: u64 = 0;
const KEY_SENDERS: u64 = 1;

// Per-sender entry inside `KEY_SENDERS`.
const KEY_SENDER_ID: u64 = 0;
const KEY_LATEST_SEQ: u64 = 1;
const KEY_EARLIEST_SEQ: u64 = 2;

// `SyncChunk` body. `KEY_CHUNK_ID` is shared with the `SyncAck` body.
const KEY_CHUNK_ID: u64 = 0;
const KEY_TOTAL_CHUNKS: u64 = 1;
const KEY_ENVELOPES: u64 = 2;

// `SyncAck` body (together with `KEY_CHUNK_ID`).
const KEY_ACCEPTED_COUNT: u64 = 1;
const KEY_REJECTED_COUNT: u64 = 2;
19
/// Sequence bounds one device holds for a single sender within a room.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SenderSyncState {
    /// Opaque sender identifier; matched byte-for-byte when diffing manifests.
    pub sender_id: Vec<u8>,
    /// Highest sequence number held for this sender.
    pub latest_seq: u64,
    /// Lowest sequence number held for this sender.
    pub earliest_seq: u64,
}
26
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct RoomSyncState {
29    pub room_id: Vec<u8>,
30    pub senders: Vec<SenderSyncState>,
31}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34pub struct SyncManifest {
35    pub device_id: Vec<u8>,
36    pub rooms: Vec<RoomSyncState>,
37}
38
/// One batch of envelopes in a multi-chunk transfer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SyncChunk {
    /// Position of this chunk in the transfer.
    /// NOTE(review): whether this is zero-based is not established in this file.
    pub chunk_id: u32,
    /// Number of chunks in the whole transfer.
    pub total_chunks: u32,
    /// Opaque serialized envelopes, carried on the wire as CBOR byte strings.
    pub envelopes: Vec<Vec<u8>>,
}
45
/// Acknowledgement for the chunk identified by `chunk_id`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SyncAck {
    /// The `SyncChunk::chunk_id` being acknowledged.
    pub chunk_id: u32,
    /// Count reported as accepted (presumably envelopes — confirm with the receiver).
    pub accepted_count: u32,
    /// Count reported as rejected (presumably envelopes — confirm with the receiver).
    pub rejected_count: u32,
}
52
/// A run of one sender's sequence numbers in one room that should be transferred.
///
/// As produced by `diff_manifests`, both `from_seq` and `to_seq` are inclusive.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SyncRange {
    /// Room the sequence numbers belong to.
    pub room_id: Vec<u8>,
    /// Sender the sequence numbers belong to.
    pub sender_id: Vec<u8>,
    /// First sequence number to send (inclusive).
    pub from_seq: u64,
    /// Last sequence number to send (inclusive).
    pub to_seq: u64,
}
60
61fn prepend_frame(frame: FrameType, body: Vec<u8>) -> Vec<u8> {
62    let mut out = Vec::with_capacity(1 + body.len());
63    out.push(frame as u8);
64    out.extend_from_slice(&body);
65    out
66}
67
68fn strip_frame(expected: FrameType, bytes: &[u8]) -> Result<&[u8], ProtocolError> {
69    match bytes.first() {
70        None => Err(ProtocolError::InvalidEncoding("empty frame".to_string())),
71        Some(&b) if b != expected as u8 => Err(ProtocolError::InvalidEncoding(format!(
72            "expected frame type 0x{:02x}, got 0x{:02x}",
73            expected as u8, b
74        ))),
75        _ => Ok(&bytes[1..]),
76    }
77}
78
79fn parse_map(bytes: &[u8]) -> Result<Map, ProtocolError> {
80    let cbor =
81        CBOR::try_from_data(bytes).map_err(|e| ProtocolError::InvalidEncoding(e.to_string()))?;
82    cbor.try_into_map()
83        .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))
84}
85
86fn extract_bytes(map: &Map, key: u64) -> Result<Vec<u8>, ProtocolError> {
87    let cbor: CBOR = map
88        .extract(key)
89        .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))?;
90    cbor.try_into_byte_string()
91        .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))
92}
93
94fn extract_u64(map: &Map, key: u64) -> Result<u64, ProtocolError> {
95    map.extract(key)
96        .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))
97}
98
99fn extract_u32(map: &Map, key: u64, field: &str) -> Result<u32, ProtocolError> {
100    let value = extract_u64(map, key)?;
101    u32::try_from(value)
102        .map_err(|_| ProtocolError::InvalidEnvelope(format!("{field} exceeds u32 range: {value}")))
103}
104
105fn extract_array(map: &Map, key: u64) -> Result<Vec<CBOR>, ProtocolError> {
106    let cbor: CBOR = map
107        .extract(key)
108        .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))?;
109    cbor.try_into_array()
110        .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))
111}
112
113fn encode_sender_sync_state(state: &SenderSyncState) -> CBOR {
114    let mut map = Map::new();
115    map.insert(KEY_SENDER_ID, CBOR::to_byte_string(&state.sender_id));
116    map.insert(KEY_LATEST_SEQ, state.latest_seq);
117    map.insert(KEY_EARLIEST_SEQ, state.earliest_seq);
118    CBOR::from(map)
119}
120
121fn decode_sender_sync_state(cbor: CBOR) -> Result<SenderSyncState, ProtocolError> {
122    let map = cbor
123        .try_into_map()
124        .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))?;
125    Ok(SenderSyncState {
126        sender_id: extract_bytes(&map, KEY_SENDER_ID)?,
127        latest_seq: extract_u64(&map, KEY_LATEST_SEQ)?,
128        earliest_seq: extract_u64(&map, KEY_EARLIEST_SEQ)?,
129    })
130}
131
132fn encode_room_sync_state(room: &RoomSyncState) -> CBOR {
133    let mut map = Map::new();
134    map.insert(KEY_ROOM_ID, CBOR::to_byte_string(&room.room_id));
135    let senders: Vec<CBOR> = room.senders.iter().map(encode_sender_sync_state).collect();
136    map.insert(KEY_SENDERS, CBOR::from(senders));
137    CBOR::from(map)
138}
139
140fn decode_room_sync_state(cbor: CBOR) -> Result<RoomSyncState, ProtocolError> {
141    let map = cbor
142        .try_into_map()
143        .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))?;
144    let room_id = extract_bytes(&map, KEY_ROOM_ID)?;
145    let senders = extract_array(&map, KEY_SENDERS)?
146        .into_iter()
147        .map(decode_sender_sync_state)
148        .collect::<Result<Vec<_>, _>>()?;
149    Ok(RoomSyncState { room_id, senders })
150}
151
152pub fn encode_sync_manifest(manifest: &SyncManifest) -> Vec<u8> {
153    let mut map = Map::new();
154    map.insert(KEY_DEVICE_ID, CBOR::to_byte_string(&manifest.device_id));
155    let rooms: Vec<CBOR> = manifest.rooms.iter().map(encode_room_sync_state).collect();
156    map.insert(KEY_ROOMS, CBOR::from(rooms));
157    prepend_frame(FrameType::SyncManifest, CBOR::from(map).to_cbor_data())
158}
159
160pub fn decode_sync_manifest(bytes: &[u8]) -> Result<SyncManifest, ProtocolError> {
161    let body = strip_frame(FrameType::SyncManifest, bytes)?;
162    let map = parse_map(body)?;
163    let device_id = extract_bytes(&map, KEY_DEVICE_ID)?;
164    let rooms = extract_array(&map, KEY_ROOMS)?
165        .into_iter()
166        .map(decode_room_sync_state)
167        .collect::<Result<Vec<_>, _>>()?;
168    Ok(SyncManifest { device_id, rooms })
169}
170
171pub fn encode_sync_chunk(chunk: &SyncChunk) -> Vec<u8> {
172    let mut map = Map::new();
173    map.insert(KEY_CHUNK_ID, u64::from(chunk.chunk_id));
174    map.insert(KEY_TOTAL_CHUNKS, u64::from(chunk.total_chunks));
175    let envelopes: Vec<CBOR> = chunk.envelopes.iter().map(CBOR::to_byte_string).collect();
176    map.insert(KEY_ENVELOPES, CBOR::from(envelopes));
177    prepend_frame(FrameType::SyncChunk, CBOR::from(map).to_cbor_data())
178}
179
180pub fn decode_sync_chunk(bytes: &[u8]) -> Result<SyncChunk, ProtocolError> {
181    let body = strip_frame(FrameType::SyncChunk, bytes)?;
182    let map = parse_map(body)?;
183    let chunk_id = extract_u32(&map, KEY_CHUNK_ID, "chunk_id")?;
184    let total_chunks = extract_u32(&map, KEY_TOTAL_CHUNKS, "total_chunks")?;
185    let envelopes = extract_array(&map, KEY_ENVELOPES)?
186        .into_iter()
187        .map(|cbor| {
188            cbor.try_into_byte_string()
189                .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))
190        })
191        .collect::<Result<Vec<_>, _>>()?;
192    Ok(SyncChunk {
193        chunk_id,
194        total_chunks,
195        envelopes,
196    })
197}
198
199pub fn encode_sync_ack(ack: &SyncAck) -> Vec<u8> {
200    let mut map = Map::new();
201    map.insert(KEY_CHUNK_ID, u64::from(ack.chunk_id));
202    map.insert(KEY_ACCEPTED_COUNT, u64::from(ack.accepted_count));
203    map.insert(KEY_REJECTED_COUNT, u64::from(ack.rejected_count));
204    prepend_frame(FrameType::SyncAck, CBOR::from(map).to_cbor_data())
205}
206
207pub fn decode_sync_ack(bytes: &[u8]) -> Result<SyncAck, ProtocolError> {
208    let body = strip_frame(FrameType::SyncAck, bytes)?;
209    let map = parse_map(body)?;
210    Ok(SyncAck {
211        chunk_id: extract_u32(&map, KEY_CHUNK_ID, "chunk_id")?,
212        accepted_count: extract_u32(&map, KEY_ACCEPTED_COUNT, "accepted_count")?,
213        rejected_count: extract_u32(&map, KEY_REJECTED_COUNT, "rejected_count")?,
214    })
215}
216
217/// Compute the ranges the local side has that the remote side lacks.
218pub fn diff_manifests(local: &SyncManifest, remote: &SyncManifest) -> Vec<SyncRange> {
219    let mut ranges = Vec::new();
220
221    for local_room in &local.rooms {
222        let remote_room = remote
223            .rooms
224            .iter()
225            .find(|room| room.room_id == local_room.room_id);
226
227        for local_sender in &local_room.senders {
228            let remote_sender = remote_room.and_then(|room| {
229                room.senders
230                    .iter()
231                    .find(|sender| sender.sender_id == local_sender.sender_id)
232            });
233
234            match remote_sender {
235                None => ranges.push(SyncRange {
236                    room_id: local_room.room_id.clone(),
237                    sender_id: local_sender.sender_id.clone(),
238                    from_seq: local_sender.earliest_seq,
239                    to_seq: local_sender.latest_seq,
240                }),
241                Some(remote_sender) if local_sender.latest_seq > remote_sender.latest_seq => {
242                    ranges.push(SyncRange {
243                        room_id: local_room.room_id.clone(),
244                        sender_id: local_sender.sender_id.clone(),
245                        from_seq: remote_sender.latest_seq + 1,
246                        to_seq: local_sender.latest_seq,
247                    });
248                }
249                _ => {}
250            }
251        }
252    }
253
254    ranges
255}
256
#[cfg(test)]
mod tests {
    // Round-trip, error-path, and manifest-diff tests for the sync frames.
    use super::*;

    fn sample_manifest(device_byte: u8) -> SyncManifest {
        SyncManifest {
            device_id: vec![device_byte; 16],
            rooms: vec![RoomSyncState {
                room_id: vec![0xAA; 8],
                senders: vec![
                    SenderSyncState {
                        sender_id: vec![0x01; 8],
                        latest_seq: 10,
                        earliest_seq: 1,
                    },
                    SenderSyncState {
                        sender_id: vec![0x02; 8],
                        latest_seq: 5,
                        earliest_seq: 5,
                    },
                ],
            }],
        }
    }

    #[test]
    fn test_sync_manifest_round_trip() {
        let manifest = sample_manifest(0xDE);
        let encoded = encode_sync_manifest(&manifest);
        assert_eq!(encoded[0], FrameType::SyncManifest as u8);
        let decoded = decode_sync_manifest(&encoded).expect("decode must succeed");
        assert_eq!(decoded, manifest);
    }

    #[test]
    fn test_empty_manifest_round_trip() {
        let manifest = SyncManifest {
            device_id: Vec::new(),
            rooms: Vec::new(),
        };
        let decoded =
            decode_sync_manifest(&encode_sync_manifest(&manifest)).expect("decode must succeed");
        assert_eq!(decoded, manifest);
    }

    #[test]
    fn test_sync_chunk_round_trip() {
        let chunk = SyncChunk {
            chunk_id: 2,
            total_chunks: 5,
            envelopes: vec![vec![0xAA, 0xBB], vec![0xCC, 0xDD, 0xEE]],
        };
        let encoded = encode_sync_chunk(&chunk);
        assert_eq!(encoded[0], FrameType::SyncChunk as u8);
        let decoded = decode_sync_chunk(&encoded).expect("decode must succeed");
        assert_eq!(decoded, chunk);
    }

    #[test]
    fn test_sync_ack_round_trip() {
        let ack = SyncAck {
            chunk_id: 3,
            accepted_count: 10,
            rejected_count: 2,
        };
        let encoded = encode_sync_ack(&ack);
        assert_eq!(encoded[0], FrameType::SyncAck as u8);
        let decoded = decode_sync_ack(&encoded).expect("decode must succeed");
        assert_eq!(decoded, ack);
    }

    #[test]
    fn test_decoders_reject_empty_input() {
        assert!(matches!(
            decode_sync_manifest(&[]),
            Err(ProtocolError::InvalidEncoding(_))
        ));
        assert!(matches!(
            decode_sync_chunk(&[]),
            Err(ProtocolError::InvalidEncoding(_))
        ));
        assert!(matches!(
            decode_sync_ack(&[]),
            Err(ProtocolError::InvalidEncoding(_))
        ));
    }

    #[test]
    fn test_decoders_reject_wrong_frame_type() {
        let ack_frame = encode_sync_ack(&SyncAck {
            chunk_id: 0,
            accepted_count: 0,
            rejected_count: 0,
        });
        assert!(matches!(
            decode_sync_chunk(&ack_frame),
            Err(ProtocolError::InvalidEncoding(_))
        ));
        assert!(matches!(
            decode_sync_manifest(&ack_frame),
            Err(ProtocolError::InvalidEncoding(_))
        ));

        let chunk_frame = encode_sync_chunk(&SyncChunk {
            chunk_id: 0,
            total_chunks: 1,
            envelopes: Vec::new(),
        });
        assert!(matches!(
            decode_sync_ack(&chunk_frame),
            Err(ProtocolError::InvalidEncoding(_))
        ));
    }

    #[test]
    fn test_decode_sync_ack_rejects_value_above_u32() {
        let mut map = Map::new();
        map.insert(KEY_CHUNK_ID, u64::from(u32::MAX) + 1);
        map.insert(KEY_ACCEPTED_COUNT, 0u64);
        map.insert(KEY_REJECTED_COUNT, 0u64);
        let frame = prepend_frame(FrameType::SyncAck, CBOR::from(map).to_cbor_data());
        assert!(matches!(
            decode_sync_ack(&frame),
            Err(ProtocolError::InvalidEnvelope(_))
        ));
    }

    fn make_sender(id_byte: u8, earliest: u64, latest: u64) -> SenderSyncState {
        SenderSyncState {
            sender_id: vec![id_byte],
            latest_seq: latest,
            earliest_seq: earliest,
        }
    }

    fn make_room(room_byte: u8, senders: Vec<SenderSyncState>) -> RoomSyncState {
        RoomSyncState {
            room_id: vec![room_byte],
            senders,
        }
    }

    fn make_manifest(rooms: Vec<RoomSyncState>) -> SyncManifest {
        SyncManifest {
            device_id: vec![0x00],
            rooms,
        }
    }

    #[test]
    fn test_diff_local_has_room_remote_doesnt_returns_full_range() {
        let local = make_manifest(vec![make_room(0xAA, vec![make_sender(0x01, 3, 9)])]);
        let remote = make_manifest(vec![]);
        let ranges = diff_manifests(&local, &remote);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].room_id, vec![0xAA]);
        assert_eq!(ranges[0].sender_id, vec![0x01]);
        assert_eq!(ranges[0].from_seq, 3);
        assert_eq!(ranges[0].to_seq, 9);
    }

    #[test]
    fn test_diff_local_higher_seq_than_remote_returns_partial_range() {
        let local = make_manifest(vec![make_room(0xBB, vec![make_sender(0x01, 1, 20)])]);
        let remote = make_manifest(vec![make_room(0xBB, vec![make_sender(0x01, 1, 12)])]);
        let ranges = diff_manifests(&local, &remote);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].from_seq, 13);
        assert_eq!(ranges[0].to_seq, 20);
    }

    #[test]
    fn test_diff_remote_up_to_date_or_ahead_returns_nothing() {
        let local = make_manifest(vec![make_room(0xCC, vec![make_sender(0x01, 1, 10)])]);
        let same = make_manifest(vec![make_room(0xCC, vec![make_sender(0x01, 1, 10)])]);
        let ahead = make_manifest(vec![make_room(0xCC, vec![make_sender(0x01, 1, 15)])]);
        assert!(diff_manifests(&local, &same).is_empty());
        assert!(diff_manifests(&local, &ahead).is_empty());
    }

    #[test]
    fn test_diff_remote_room_missing_sender_returns_full_range() {
        let local = make_manifest(vec![make_room(
            0xDD,
            vec![make_sender(0x01, 1, 4), make_sender(0x02, 2, 7)],
        )]);
        let remote = make_manifest(vec![make_room(0xDD, vec![make_sender(0x01, 1, 4)])]);
        let ranges = diff_manifests(&local, &remote);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].room_id, vec![0xDD]);
        assert_eq!(ranges[0].sender_id, vec![0x02]);
        assert_eq!(ranges[0].from_seq, 2);
        assert_eq!(ranges[0].to_seq, 7);
    }
}