1#[cfg(not(feature = "std"))]
52use alloc::{collections::VecDeque, vec, vec::Vec};
53#[cfg(feature = "std")]
54use std::collections::VecDeque;
55
56use super::characteristics::{SyncDataHeader, SyncDataOp};
57
58pub const fn max_payload_size(mtu: u16) -> usize {
60 (mtu as usize).saturating_sub(3 + SyncDataHeader::SIZE)
61}
62
/// Default per-fragment payload budget in bytes, used until [`SyncProtocol::set_mtu`] is
/// called. It equals `max_payload_size(23)`, the payload room at a 23-byte MTU.
pub const DEFAULT_MAX_PAYLOAD: usize = 15;

/// Logical kind of a [`SyncMessage`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncMessageType {
    /// Carries the sender's sync vector (queued by `SyncProtocol::start_sync`).
    SyncVector,
    /// Carries one document's bytes, possibly split across several fragments.
    Document,
    /// Acknowledges a message by its sequence number.
    Ack,
    /// Marks the end of a sync session.
    EndSync,
    /// Reports an error. On the wire it shares the `End` opcode with `EndSync`, so a peer
    /// decodes it as `EndSync`.
    Error,
}
80
81#[derive(Debug, Clone)]
83pub struct SyncMessage {
84 pub msg_type: SyncMessageType,
86 pub seq: u16,
88 pub total_fragments: u8,
90 pub fragment_index: u8,
92 pub payload: Vec<u8>,
94}
95
96impl SyncMessage {
97 pub fn new(msg_type: SyncMessageType, seq: u16, payload: Vec<u8>) -> Self {
99 Self {
100 msg_type,
101 seq,
102 total_fragments: 1,
103 fragment_index: 0,
104 payload,
105 }
106 }
107
108 pub fn encode(&self) -> Vec<u8> {
110 let op = match self.msg_type {
111 SyncMessageType::SyncVector => SyncDataOp::Vector,
112 SyncMessageType::Document => SyncDataOp::Document,
113 SyncMessageType::Ack => SyncDataOp::Ack,
114 SyncMessageType::EndSync | SyncMessageType::Error => SyncDataOp::End,
115 };
116
117 let header = SyncDataHeader {
118 op,
119 seq: self.seq,
120 total_fragments: self.total_fragments,
121 fragment_index: self.fragment_index,
122 };
123
124 let mut buf = Vec::with_capacity(SyncDataHeader::SIZE + self.payload.len());
125 buf.extend_from_slice(&header.encode());
126 buf.extend_from_slice(&self.payload);
127 buf
128 }
129
130 pub fn decode(data: &[u8]) -> Option<Self> {
132 let header = SyncDataHeader::decode(data)?;
133 let payload = if data.len() > SyncDataHeader::SIZE {
134 data[SyncDataHeader::SIZE..].to_vec()
135 } else {
136 Vec::new()
137 };
138
139 let msg_type = match header.op {
140 SyncDataOp::Vector => SyncMessageType::SyncVector,
141 SyncDataOp::Document => SyncMessageType::Document,
142 SyncDataOp::Ack => SyncMessageType::Ack,
143 SyncDataOp::End => SyncMessageType::EndSync,
144 };
145
146 Some(Self {
147 msg_type,
148 seq: header.seq,
149 total_fragments: header.total_fragments,
150 fragment_index: header.fragment_index,
151 payload,
152 })
153 }
154}
155
156pub fn fragment_payload(
158 msg_type: SyncMessageType,
159 seq: u16,
160 payload: &[u8],
161 max_fragment_size: usize,
162) -> Vec<SyncMessage> {
163 if payload.is_empty() || payload.len() <= max_fragment_size {
164 return vec![SyncMessage::new(msg_type, seq, payload.to_vec())];
165 }
166
167 let total_fragments = payload.len().div_ceil(max_fragment_size);
168 let total_fragments = total_fragments.min(255) as u8;
169
170 payload
171 .chunks(max_fragment_size)
172 .enumerate()
173 .map(|(i, chunk)| SyncMessage {
174 msg_type,
175 seq,
176 total_fragments,
177 fragment_index: i as u8,
178 payload: chunk.to_vec(),
179 })
180 .collect()
181}
182
183#[derive(Debug)]
185pub struct FragmentReassembler {
186 total_fragments: u8,
188 fragments: Vec<Option<Vec<u8>>>,
190 seq: u16,
192 msg_type: SyncMessageType,
194}
195
196impl FragmentReassembler {
197 pub fn new(msg: &SyncMessage) -> Self {
199 let mut fragments = vec![None; msg.total_fragments as usize];
200 fragments[msg.fragment_index as usize] = Some(msg.payload.clone());
201
202 Self {
203 total_fragments: msg.total_fragments,
204 fragments,
205 seq: msg.seq,
206 msg_type: msg.msg_type,
207 }
208 }
209
210 pub fn add_fragment(&mut self, msg: &SyncMessage) -> bool {
214 if msg.seq != self.seq || msg.total_fragments != self.total_fragments {
215 return false;
216 }
217
218 if (msg.fragment_index as usize) < self.fragments.len() {
219 self.fragments[msg.fragment_index as usize] = Some(msg.payload.clone());
220 }
221
222 self.is_complete()
223 }
224
225 pub fn is_complete(&self) -> bool {
227 self.fragments.iter().all(|f| f.is_some())
228 }
229
230 pub fn reassemble(&self) -> Option<Vec<u8>> {
234 if !self.is_complete() {
235 return None;
236 }
237
238 let total_size: usize = self.fragments.iter().flatten().map(|f| f.len()).sum();
239 let mut result = Vec::with_capacity(total_size);
240
241 for data in self.fragments.iter().flatten() {
242 result.extend_from_slice(data);
243 }
244
245 Some(result)
246 }
247
248 pub fn seq(&self) -> u16 {
250 self.seq
251 }
252
253 pub fn msg_type(&self) -> SyncMessageType {
255 self.msg_type
256 }
257}
258
/// Coarse phase of a [`SyncProtocol`] session.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncProtocolState {
    /// No session in progress; the state after `SyncProtocol::new` and `reset`.
    Idle,
    /// `start_sync` has queued this side's sync vector.
    SendingVector,
    /// A sync vector has been received from the peer.
    ReceivingDocuments,
    /// `queue_document` was called while idle.
    SendingDocuments,
    /// Awaiting acknowledgements.
    /// NOTE(review): no `SyncProtocol` method currently transitions into this state.
    WaitingAck,
    /// The session finished, via `end_sync` or an incoming `EndSync`.
    Complete,
    /// Set when `process_incoming` handles an `Error` message.
    Error,
}
277
278pub struct SyncProtocol {
282 state: SyncProtocolState,
284 seq: u16,
286 outgoing: VecDeque<SyncMessage>,
288 pending_acks: Vec<u16>,
290 reassembler: Option<FragmentReassembler>,
292 max_payload: usize,
294}
295
296impl SyncProtocol {
297 pub fn new() -> Self {
299 Self {
300 state: SyncProtocolState::Idle,
301 seq: 0,
302 outgoing: VecDeque::new(),
303 pending_acks: Vec::new(),
304 reassembler: None,
305 max_payload: DEFAULT_MAX_PAYLOAD,
306 }
307 }
308
309 pub fn set_mtu(&mut self, mtu: u16) {
311 self.max_payload = max_payload_size(mtu);
312 }
313
314 pub fn state(&self) -> SyncProtocolState {
316 self.state
317 }
318
319 pub fn start_sync(&mut self, sync_vector: Vec<u8>) {
321 self.state = SyncProtocolState::SendingVector;
322 self.seq = 0;
323
324 let messages = fragment_payload(
326 SyncMessageType::SyncVector,
327 self.next_seq(),
328 &sync_vector,
329 self.max_payload,
330 );
331
332 for msg in messages {
333 self.outgoing.push_back(msg);
334 }
335 }
336
337 pub fn queue_document(&mut self, doc_data: Vec<u8>) {
339 if self.state == SyncProtocolState::Idle {
340 self.state = SyncProtocolState::SendingDocuments;
341 }
342
343 let messages = fragment_payload(
344 SyncMessageType::Document,
345 self.next_seq(),
346 &doc_data,
347 self.max_payload,
348 );
349
350 for msg in messages {
351 self.outgoing.push_back(msg);
352 }
353 }
354
355 pub fn end_sync(&mut self) {
357 let msg = SyncMessage::new(SyncMessageType::EndSync, self.next_seq(), Vec::new());
358 self.outgoing.push_back(msg);
359 self.state = SyncProtocolState::Complete;
360 }
361
362 pub fn next_outgoing(&mut self) -> Option<SyncMessage> {
364 self.outgoing.pop_front()
365 }
366
367 pub fn has_outgoing(&self) -> bool {
369 !self.outgoing.is_empty()
370 }
371
372 pub fn process_incoming(&mut self, data: &[u8]) -> Option<(SyncMessageType, Vec<u8>)> {
376 let msg = SyncMessage::decode(data)?;
377
378 if msg.total_fragments > 1 {
380 if let Some(ref mut reassembler) = self.reassembler {
381 if reassembler.seq() == msg.seq {
382 if reassembler.add_fragment(&msg) {
383 let payload = reassembler.reassemble()?;
384 let msg_type = reassembler.msg_type();
385 self.reassembler = None;
386 return Some((msg_type, payload));
387 }
388 return None;
389 }
390 }
391 self.reassembler = Some(FragmentReassembler::new(&msg));
393 if self.reassembler.as_ref().unwrap().is_complete() {
394 let reassembler = self.reassembler.take().unwrap();
395 let payload = reassembler.reassemble()?;
396 return Some((reassembler.msg_type(), payload));
397 }
398 return None;
399 }
400
401 match msg.msg_type {
403 SyncMessageType::Ack => {
404 self.pending_acks.retain(|&seq| seq != msg.seq);
405 None
406 }
407 SyncMessageType::SyncVector => {
408 self.state = SyncProtocolState::ReceivingDocuments;
409 Some((SyncMessageType::SyncVector, msg.payload))
410 }
411 SyncMessageType::Document => {
412 let ack = SyncMessage::new(SyncMessageType::Ack, msg.seq, Vec::new());
414 self.outgoing.push_back(ack);
415 Some((SyncMessageType::Document, msg.payload))
416 }
417 SyncMessageType::EndSync => {
418 self.state = SyncProtocolState::Complete;
419 Some((SyncMessageType::EndSync, Vec::new()))
420 }
421 SyncMessageType::Error => {
422 self.state = SyncProtocolState::Error;
423 Some((SyncMessageType::Error, msg.payload))
424 }
425 }
426 }
427
428 pub fn reset(&mut self) {
430 self.state = SyncProtocolState::Idle;
431 self.seq = 0;
432 self.outgoing.clear();
433 self.pending_acks.clear();
434 self.reassembler = None;
435 }
436
437 fn next_seq(&mut self) -> u16 {
439 let seq = self.seq;
440 self.seq = self.seq.wrapping_add(1);
441 seq
442 }
443}
444
445impl Default for SyncProtocol {
446 fn default() -> Self {
447 Self::new()
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_payload_size() {
        assert_eq!(max_payload_size(23), 15);
        assert_eq!(max_payload_size(251), 243);
        // Too small to hold the headers: saturates at 0 instead of underflowing.
        assert_eq!(max_payload_size(8), 0);
        assert_eq!(max_payload_size(0), 0);
    }

    #[test]
    fn test_sync_message_encode_decode() {
        let msg = SyncMessage::new(SyncMessageType::Document, 42, vec![1, 2, 3, 4, 5]);

        let encoded = msg.encode();
        let decoded = SyncMessage::decode(&encoded).unwrap();

        assert_eq!(decoded.msg_type, SyncMessageType::Document);
        assert_eq!(decoded.seq, 42);
        assert_eq!(decoded.total_fragments, 1);
        assert_eq!(decoded.fragment_index, 0);
        assert_eq!(decoded.payload, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_fragment_payload() {
        let payload = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let fragments = fragment_payload(SyncMessageType::Document, 1, &payload, 4);

        assert_eq!(fragments.len(), 3);
        assert!(fragments
            .iter()
            .all(|f| f.seq == 1 && f.msg_type == SyncMessageType::Document));
        assert_eq!(fragments[0].total_fragments, 3);
        assert_eq!(fragments[0].fragment_index, 0);
        assert_eq!(fragments[0].payload, vec![1, 2, 3, 4]);
        assert_eq!(fragments[1].fragment_index, 1);
        assert_eq!(fragments[1].payload, vec![5, 6, 7, 8]);
        assert_eq!(fragments[2].fragment_index, 2);
        assert_eq!(fragments[2].payload, vec![9, 10]);
    }

    #[test]
    fn test_fragment_payload_single_message() {
        // A payload that exactly fits is not split.
        let exact = fragment_payload(SyncMessageType::Document, 3, &[1, 2, 3, 4], 4);
        assert_eq!(exact.len(), 1);
        assert_eq!(exact[0].total_fragments, 1);
        assert_eq!(exact[0].payload, vec![1, 2, 3, 4]);

        // An empty payload still produces one message.
        let empty = fragment_payload(SyncMessageType::Ack, 9, &[], 4);
        assert_eq!(empty.len(), 1);
        assert_eq!(empty[0].fragment_index, 0);
        assert!(empty[0].payload.is_empty());
    }

    #[test]
    fn test_fragment_reassembler() {
        let payload = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let fragments = fragment_payload(SyncMessageType::Document, 1, &payload, 4);

        let mut reassembler = FragmentReassembler::new(&fragments[0]);
        assert!(!reassembler.is_complete());

        reassembler.add_fragment(&fragments[1]);
        assert!(!reassembler.is_complete());

        reassembler.add_fragment(&fragments[2]);
        assert!(reassembler.is_complete());

        let result = reassembler.reassemble().unwrap();
        assert_eq!(result, payload);
    }

    #[test]
    fn test_sync_protocol_basic_flow() {
        let mut initiator = SyncProtocol::new();
        let mut responder = SyncProtocol::new();

        // Initiator opens the session with its sync vector.
        initiator.start_sync(vec![1, 2, 3]);
        assert_eq!(initiator.state(), SyncProtocolState::SendingVector);

        let vector_msg = initiator.next_outgoing().unwrap();
        let (msg_type, payload) = responder.process_incoming(&vector_msg.encode()).unwrap();
        assert_eq!(msg_type, SyncMessageType::SyncVector);
        assert_eq!(payload, vec![1, 2, 3]);
        assert_eq!(responder.state(), SyncProtocolState::ReceivingDocuments);

        // Responder sends a document back.
        responder.queue_document(vec![4, 5, 6]);
        let doc_msg = responder.next_outgoing().unwrap();
        let (msg_type, payload) = initiator.process_incoming(&doc_msg.encode()).unwrap();
        assert_eq!(msg_type, SyncMessageType::Document);
        assert_eq!(payload, vec![4, 5, 6]);

        // Receiving the document queued an Ack for its sequence number.
        let ack = initiator.next_outgoing().unwrap();
        assert_eq!(ack.msg_type, SyncMessageType::Ack);
        assert_eq!(ack.seq, doc_msg.seq);
        assert!(ack.payload.is_empty());
        assert!(!initiator.has_outgoing());

        // Acks are consumed without surfacing anything to the caller.
        assert!(responder.process_incoming(&ack.encode()).is_none());

        initiator.end_sync();
        assert_eq!(initiator.state(), SyncProtocolState::Complete);

        let end_msg = initiator.next_outgoing().unwrap();
        let (msg_type, payload) = responder.process_incoming(&end_msg.encode()).unwrap();
        assert_eq!(msg_type, SyncMessageType::EndSync);
        assert!(payload.is_empty());
        assert_eq!(responder.state(), SyncProtocolState::Complete);
    }

    #[test]
    fn test_sync_protocol_with_mtu() {
        let mut protocol = SyncProtocol::new();
        protocol.set_mtu(251); // 243-byte payload budget

        protocol.queue_document(vec![0u8; 500]);

        let mut fragments = Vec::new();
        while let Some(msg) = protocol.next_outgoing() {
            fragments.push(msg);
        }

        // 500 bytes -> 243 + 243 + 14.
        assert_eq!(fragments.len(), 3);
        assert!(fragments.iter().all(|f| f.total_fragments == 3));
        assert!(fragments
            .iter()
            .all(|f| f.payload.len() <= max_payload_size(251)));
        assert_eq!(
            fragments.iter().map(|f| f.payload.len()).sum::<usize>(),
            500
        );
    }

    #[test]
    fn test_fragmented_document_round_trip() {
        let mut sender = SyncProtocol::new();
        let mut receiver = SyncProtocol::new();

        // 40 bytes at the default 15-byte budget -> 15 + 15 + 10.
        let doc: Vec<u8> = (0u8..40).collect();
        sender.queue_document(doc.clone());

        let mut frames = Vec::new();
        while let Some(msg) = sender.next_outgoing() {
            frames.push(msg.encode());
        }
        assert_eq!(frames.len(), 3);

        // Deliver out of order; only the final missing fragment completes the message.
        let (first, rest) = frames.split_first().unwrap();
        for frame in rest.iter().rev() {
            assert!(receiver.process_incoming(frame).is_none());
        }
        let (msg_type, payload) = receiver.process_incoming(first).unwrap();
        assert_eq!(msg_type, SyncMessageType::Document);
        assert_eq!(payload, doc);
    }

    #[test]
    fn test_protocol_reset() {
        let mut protocol = SyncProtocol::new();
        protocol.start_sync(vec![1, 2, 3]);
        protocol.queue_document(vec![4, 5, 6]);

        protocol.reset();

        assert_eq!(protocol.state(), SyncProtocolState::Idle);
        assert!(!protocol.has_outgoing());

        // Sequence numbering restarts after a reset.
        protocol.queue_document(vec![7]);
        assert_eq!(protocol.state(), SyncProtocolState::SendingDocuments);
        assert_eq!(protocol.next_outgoing().unwrap().seq, 0);
    }
}