1use dcbor::prelude::*;
4
5use crate::{repair_peer::FrameType, ProtocolError};
6
// Integer keys of the CBOR maps used on the wire. Each group is a separate map,
// so keys are only unique within their group.

/// Top-level manifest map: device id.
const KEY_DEVICE_ID: u64 = 0;
/// Top-level manifest map: array of room entries.
const KEY_ROOMS: u64 = 1;

/// Room entry map: room id.
const KEY_ROOM_ID: u64 = 0;
/// Room entry map: array of sender entries.
const KEY_SENDERS: u64 = 1;

/// Sender entry map: sender id.
const KEY_SENDER_ID: u64 = 0;
/// Sender entry map: latest sequence number.
const KEY_LATEST_SEQ: u64 = 1;
/// Sender entry map: earliest sequence number.
const KEY_EARLIEST_SEQ: u64 = 2;

/// Chunk id; shared by the chunk map and the ack map.
const KEY_CHUNK_ID: u64 = 0;
/// Chunk map: total number of chunks.
const KEY_TOTAL_CHUNKS: u64 = 1;
/// Chunk map: array of envelope byte strings.
const KEY_ENVELOPES: u64 = 2;
/// Ack map: number of accepted envelopes.
const KEY_ACCEPTED_COUNT: u64 = 1;
/// Ack map: number of rejected envelopes.
const KEY_REJECTED_COUNT: u64 = 2;
19
/// Per-sender sequence watermarks reported for one room in a [`SyncManifest`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SenderSyncState {
    /// Opaque sender identifier (encoded as a CBOR byte string).
    pub sender_id: Vec<u8>,
    /// Latest sequence number reported for this sender.
    pub latest_seq: u64,
    /// Earliest sequence number reported for this sender.
    pub earliest_seq: u64,
}
26
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct RoomSyncState {
29 pub room_id: Vec<u8>,
30 pub senders: Vec<SenderSyncState>,
31}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34pub struct SyncManifest {
35 pub device_id: Vec<u8>,
36 pub rooms: Vec<RoomSyncState>,
37}
38
/// One piece of a chunked envelope transfer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SyncChunk {
    /// Identifier of this chunk within the transfer.
    pub chunk_id: u32,
    /// Number of chunks in the whole transfer.
    pub total_chunks: u32,
    /// Opaque envelope payloads, each encoded as a CBOR byte string.
    pub envelopes: Vec<Vec<u8>>,
}
45
/// Receiver's acknowledgement of a chunk, with per-chunk acceptance counts.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SyncAck {
    /// Chunk being acknowledged (presumably the `chunk_id` of a [`SyncChunk`]; confirm with caller).
    pub chunk_id: u32,
    /// Number of envelopes accepted.
    pub accepted_count: u32,
    /// Number of envelopes rejected.
    pub rejected_count: u32,
}
52
/// An inclusive sequence range `from_seq..=to_seq` for one sender in one room,
/// as produced by `diff_manifests`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SyncRange {
    /// Room the range belongs to.
    pub room_id: Vec<u8>,
    /// Sender the range belongs to.
    pub sender_id: Vec<u8>,
    /// First sequence number in the range (inclusive).
    pub from_seq: u64,
    /// Last sequence number in the range (inclusive).
    pub to_seq: u64,
}
60
61fn prepend_frame(frame: FrameType, body: Vec<u8>) -> Vec<u8> {
62 let mut out = Vec::with_capacity(1 + body.len());
63 out.push(frame as u8);
64 out.extend_from_slice(&body);
65 out
66}
67
68fn strip_frame(expected: FrameType, bytes: &[u8]) -> Result<&[u8], ProtocolError> {
69 match bytes.first() {
70 None => Err(ProtocolError::InvalidEncoding("empty frame".to_string())),
71 Some(&b) if b != expected as u8 => Err(ProtocolError::InvalidEncoding(format!(
72 "expected frame type 0x{:02x}, got 0x{:02x}",
73 expected as u8, b
74 ))),
75 _ => Ok(&bytes[1..]),
76 }
77}
78
79fn parse_map(bytes: &[u8]) -> Result<Map, ProtocolError> {
80 let cbor =
81 CBOR::try_from_data(bytes).map_err(|e| ProtocolError::InvalidEncoding(e.to_string()))?;
82 cbor.try_into_map()
83 .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))
84}
85
86fn extract_bytes(map: &Map, key: u64) -> Result<Vec<u8>, ProtocolError> {
87 let cbor: CBOR = map
88 .extract(key)
89 .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))?;
90 cbor.try_into_byte_string()
91 .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))
92}
93
94fn extract_u64(map: &Map, key: u64) -> Result<u64, ProtocolError> {
95 map.extract(key)
96 .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))
97}
98
99fn extract_u32(map: &Map, key: u64, field: &str) -> Result<u32, ProtocolError> {
100 let value = extract_u64(map, key)?;
101 u32::try_from(value)
102 .map_err(|_| ProtocolError::InvalidEnvelope(format!("{field} exceeds u32 range: {value}")))
103}
104
105fn extract_array(map: &Map, key: u64) -> Result<Vec<CBOR>, ProtocolError> {
106 let cbor: CBOR = map
107 .extract(key)
108 .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))?;
109 cbor.try_into_array()
110 .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))
111}
112
113fn encode_sender_sync_state(state: &SenderSyncState) -> CBOR {
114 let mut map = Map::new();
115 map.insert(KEY_SENDER_ID, CBOR::to_byte_string(&state.sender_id));
116 map.insert(KEY_LATEST_SEQ, state.latest_seq);
117 map.insert(KEY_EARLIEST_SEQ, state.earliest_seq);
118 CBOR::from(map)
119}
120
121fn decode_sender_sync_state(cbor: CBOR) -> Result<SenderSyncState, ProtocolError> {
122 let map = cbor
123 .try_into_map()
124 .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))?;
125 Ok(SenderSyncState {
126 sender_id: extract_bytes(&map, KEY_SENDER_ID)?,
127 latest_seq: extract_u64(&map, KEY_LATEST_SEQ)?,
128 earliest_seq: extract_u64(&map, KEY_EARLIEST_SEQ)?,
129 })
130}
131
132fn encode_room_sync_state(room: &RoomSyncState) -> CBOR {
133 let mut map = Map::new();
134 map.insert(KEY_ROOM_ID, CBOR::to_byte_string(&room.room_id));
135 let senders: Vec<CBOR> = room.senders.iter().map(encode_sender_sync_state).collect();
136 map.insert(KEY_SENDERS, CBOR::from(senders));
137 CBOR::from(map)
138}
139
140fn decode_room_sync_state(cbor: CBOR) -> Result<RoomSyncState, ProtocolError> {
141 let map = cbor
142 .try_into_map()
143 .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))?;
144 let room_id = extract_bytes(&map, KEY_ROOM_ID)?;
145 let senders = extract_array(&map, KEY_SENDERS)?
146 .into_iter()
147 .map(decode_sender_sync_state)
148 .collect::<Result<Vec<_>, _>>()?;
149 Ok(RoomSyncState { room_id, senders })
150}
151
152pub fn encode_sync_manifest(manifest: &SyncManifest) -> Vec<u8> {
153 let mut map = Map::new();
154 map.insert(KEY_DEVICE_ID, CBOR::to_byte_string(&manifest.device_id));
155 let rooms: Vec<CBOR> = manifest.rooms.iter().map(encode_room_sync_state).collect();
156 map.insert(KEY_ROOMS, CBOR::from(rooms));
157 prepend_frame(FrameType::SyncManifest, CBOR::from(map).to_cbor_data())
158}
159
160pub fn decode_sync_manifest(bytes: &[u8]) -> Result<SyncManifest, ProtocolError> {
161 let body = strip_frame(FrameType::SyncManifest, bytes)?;
162 let map = parse_map(body)?;
163 let device_id = extract_bytes(&map, KEY_DEVICE_ID)?;
164 let rooms = extract_array(&map, KEY_ROOMS)?
165 .into_iter()
166 .map(decode_room_sync_state)
167 .collect::<Result<Vec<_>, _>>()?;
168 Ok(SyncManifest { device_id, rooms })
169}
170
171pub fn encode_sync_chunk(chunk: &SyncChunk) -> Vec<u8> {
172 let mut map = Map::new();
173 map.insert(KEY_CHUNK_ID, u64::from(chunk.chunk_id));
174 map.insert(KEY_TOTAL_CHUNKS, u64::from(chunk.total_chunks));
175 let envelopes: Vec<CBOR> = chunk.envelopes.iter().map(CBOR::to_byte_string).collect();
176 map.insert(KEY_ENVELOPES, CBOR::from(envelopes));
177 prepend_frame(FrameType::SyncChunk, CBOR::from(map).to_cbor_data())
178}
179
180pub fn decode_sync_chunk(bytes: &[u8]) -> Result<SyncChunk, ProtocolError> {
181 let body = strip_frame(FrameType::SyncChunk, bytes)?;
182 let map = parse_map(body)?;
183 let chunk_id = extract_u32(&map, KEY_CHUNK_ID, "chunk_id")?;
184 let total_chunks = extract_u32(&map, KEY_TOTAL_CHUNKS, "total_chunks")?;
185 let envelopes = extract_array(&map, KEY_ENVELOPES)?
186 .into_iter()
187 .map(|cbor| {
188 cbor.try_into_byte_string()
189 .map_err(|e| ProtocolError::InvalidEnvelope(e.to_string()))
190 })
191 .collect::<Result<Vec<_>, _>>()?;
192 Ok(SyncChunk {
193 chunk_id,
194 total_chunks,
195 envelopes,
196 })
197}
198
199pub fn encode_sync_ack(ack: &SyncAck) -> Vec<u8> {
200 let mut map = Map::new();
201 map.insert(KEY_CHUNK_ID, u64::from(ack.chunk_id));
202 map.insert(KEY_ACCEPTED_COUNT, u64::from(ack.accepted_count));
203 map.insert(KEY_REJECTED_COUNT, u64::from(ack.rejected_count));
204 prepend_frame(FrameType::SyncAck, CBOR::from(map).to_cbor_data())
205}
206
207pub fn decode_sync_ack(bytes: &[u8]) -> Result<SyncAck, ProtocolError> {
208 let body = strip_frame(FrameType::SyncAck, bytes)?;
209 let map = parse_map(body)?;
210 Ok(SyncAck {
211 chunk_id: extract_u32(&map, KEY_CHUNK_ID, "chunk_id")?,
212 accepted_count: extract_u32(&map, KEY_ACCEPTED_COUNT, "accepted_count")?,
213 rejected_count: extract_u32(&map, KEY_REJECTED_COUNT, "rejected_count")?,
214 })
215}
216
217pub fn diff_manifests(local: &SyncManifest, remote: &SyncManifest) -> Vec<SyncRange> {
219 let mut ranges = Vec::new();
220
221 for local_room in &local.rooms {
222 let remote_room = remote
223 .rooms
224 .iter()
225 .find(|room| room.room_id == local_room.room_id);
226
227 for local_sender in &local_room.senders {
228 let remote_sender = remote_room.and_then(|room| {
229 room.senders
230 .iter()
231 .find(|sender| sender.sender_id == local_sender.sender_id)
232 });
233
234 match remote_sender {
235 None => ranges.push(SyncRange {
236 room_id: local_room.room_id.clone(),
237 sender_id: local_sender.sender_id.clone(),
238 from_seq: local_sender.earliest_seq,
239 to_seq: local_sender.latest_seq,
240 }),
241 Some(remote_sender) if local_sender.latest_seq > remote_sender.latest_seq => {
242 ranges.push(SyncRange {
243 room_id: local_room.room_id.clone(),
244 sender_id: local_sender.sender_id.clone(),
245 from_seq: remote_sender.latest_seq + 1,
246 to_seq: local_sender.latest_seq,
247 });
248 }
249 _ => {}
250 }
251 }
252 }
253
254 ranges
255}
256
257#[cfg(test)]
258mod tests {
259 use super::*;
260
261 fn sample_manifest(device_byte: u8) -> SyncManifest {
262 SyncManifest {
263 device_id: vec![device_byte; 16],
264 rooms: vec![RoomSyncState {
265 room_id: vec![0xAA; 8],
266 senders: vec![
267 SenderSyncState {
268 sender_id: vec![0x01; 8],
269 latest_seq: 10,
270 earliest_seq: 1,
271 },
272 SenderSyncState {
273 sender_id: vec![0x02; 8],
274 latest_seq: 5,
275 earliest_seq: 5,
276 },
277 ],
278 }],
279 }
280 }
281
282 #[test]
283 fn test_sync_manifest_round_trip() {
284 let manifest = sample_manifest(0xDE);
285 let encoded = encode_sync_manifest(&manifest);
286 assert_eq!(encoded[0], FrameType::SyncManifest as u8);
287 let decoded = decode_sync_manifest(&encoded).expect("decode must succeed");
288 assert_eq!(decoded, manifest);
289 }
290
291 #[test]
292 fn test_sync_chunk_round_trip() {
293 let chunk = SyncChunk {
294 chunk_id: 2,
295 total_chunks: 5,
296 envelopes: vec![vec![0xAA, 0xBB], vec![0xCC, 0xDD, 0xEE]],
297 };
298 let encoded = encode_sync_chunk(&chunk);
299 assert_eq!(encoded[0], FrameType::SyncChunk as u8);
300 let decoded = decode_sync_chunk(&encoded).expect("decode must succeed");
301 assert_eq!(decoded, chunk);
302 }
303
304 #[test]
305 fn test_sync_ack_round_trip() {
306 let ack = SyncAck {
307 chunk_id: 3,
308 accepted_count: 10,
309 rejected_count: 2,
310 };
311 let encoded = encode_sync_ack(&ack);
312 assert_eq!(encoded[0], FrameType::SyncAck as u8);
313 let decoded = decode_sync_ack(&encoded).expect("decode must succeed");
314 assert_eq!(decoded, ack);
315 }
316
317 fn make_sender(id_byte: u8, earliest: u64, latest: u64) -> SenderSyncState {
318 SenderSyncState {
319 sender_id: vec![id_byte],
320 latest_seq: latest,
321 earliest_seq: earliest,
322 }
323 }
324
325 fn make_room(room_byte: u8, senders: Vec<SenderSyncState>) -> RoomSyncState {
326 RoomSyncState {
327 room_id: vec![room_byte],
328 senders,
329 }
330 }
331
332 fn make_manifest(rooms: Vec<RoomSyncState>) -> SyncManifest {
333 SyncManifest {
334 device_id: vec![0x00],
335 rooms,
336 }
337 }
338
339 #[test]
340 fn test_diff_local_has_room_remote_doesnt_returns_full_range() {
341 let local = make_manifest(vec![make_room(0xAA, vec![make_sender(0x01, 3, 9)])]);
342 let remote = make_manifest(vec![]);
343 let ranges = diff_manifests(&local, &remote);
344 assert_eq!(ranges.len(), 1);
345 assert_eq!(ranges[0].room_id, vec![0xAA]);
346 assert_eq!(ranges[0].sender_id, vec![0x01]);
347 assert_eq!(ranges[0].from_seq, 3);
348 assert_eq!(ranges[0].to_seq, 9);
349 }
350
351 #[test]
352 fn test_diff_local_higher_seq_than_remote_returns_partial_range() {
353 let local = make_manifest(vec![make_room(0xBB, vec![make_sender(0x01, 1, 20)])]);
354 let remote = make_manifest(vec![make_room(0xBB, vec![make_sender(0x01, 1, 12)])]);
355 let ranges = diff_manifests(&local, &remote);
356 assert_eq!(ranges.len(), 1);
357 assert_eq!(ranges[0].from_seq, 13);
358 assert_eq!(ranges[0].to_seq, 20);
359 }
360}