1#[cfg(not(feature = "std"))]
37use alloc::{collections::VecDeque, vec, vec::Vec};
38#[cfg(feature = "std")]
39use std::collections::VecDeque;
40
41use super::characteristics::{SyncDataHeader, SyncDataOp};
42
43pub const fn max_payload_size(mtu: u16) -> usize {
45 (mtu as usize).saturating_sub(3 + SyncDataHeader::SIZE)
46}
47
/// Per-fragment payload budget used until [`SyncProtocol::set_mtu`] is called.
///
/// Equals `max_payload_size(23)` (the unit tests pin that value at 15); 23 is
/// presumably the default BLE ATT MTU — confirm.
pub const DEFAULT_MAX_PAYLOAD: usize = 15;

/// Kind of a sync protocol message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncMessageType {
    /// Sync vector sent when a session starts (see `SyncProtocol::start_sync`).
    SyncVector,
    /// A document payload.
    Document,
    /// Acknowledgement of a received document; carries that document's `seq`.
    Ack,
    /// End-of-sync marker.
    EndSync,
    /// Error notification.
    ///
    /// NOTE: `SyncMessage::encode` maps this to the same wire op as `EndSync`,
    /// so a peer decodes it as `EndSync`.
    Error,
}
65
66#[derive(Debug, Clone)]
68pub struct SyncMessage {
69 pub msg_type: SyncMessageType,
71 pub seq: u16,
73 pub total_fragments: u8,
75 pub fragment_index: u8,
77 pub payload: Vec<u8>,
79}
80
81impl SyncMessage {
82 pub fn new(msg_type: SyncMessageType, seq: u16, payload: Vec<u8>) -> Self {
84 Self {
85 msg_type,
86 seq,
87 total_fragments: 1,
88 fragment_index: 0,
89 payload,
90 }
91 }
92
93 pub fn encode(&self) -> Vec<u8> {
95 let op = match self.msg_type {
96 SyncMessageType::SyncVector => SyncDataOp::Vector,
97 SyncMessageType::Document => SyncDataOp::Document,
98 SyncMessageType::Ack => SyncDataOp::Ack,
99 SyncMessageType::EndSync | SyncMessageType::Error => SyncDataOp::End,
100 };
101
102 let header = SyncDataHeader {
103 op,
104 seq: self.seq,
105 total_fragments: self.total_fragments,
106 fragment_index: self.fragment_index,
107 };
108
109 let mut buf = Vec::with_capacity(SyncDataHeader::SIZE + self.payload.len());
110 buf.extend_from_slice(&header.encode());
111 buf.extend_from_slice(&self.payload);
112 buf
113 }
114
115 pub fn decode(data: &[u8]) -> Option<Self> {
117 let header = SyncDataHeader::decode(data)?;
118 let payload = if data.len() > SyncDataHeader::SIZE {
119 data[SyncDataHeader::SIZE..].to_vec()
120 } else {
121 Vec::new()
122 };
123
124 let msg_type = match header.op {
125 SyncDataOp::Vector => SyncMessageType::SyncVector,
126 SyncDataOp::Document => SyncMessageType::Document,
127 SyncDataOp::Ack => SyncMessageType::Ack,
128 SyncDataOp::End => SyncMessageType::EndSync,
129 };
130
131 Some(Self {
132 msg_type,
133 seq: header.seq,
134 total_fragments: header.total_fragments,
135 fragment_index: header.fragment_index,
136 payload,
137 })
138 }
139}
140
141pub fn fragment_payload(
143 msg_type: SyncMessageType,
144 seq: u16,
145 payload: &[u8],
146 max_fragment_size: usize,
147) -> Vec<SyncMessage> {
148 if payload.is_empty() || payload.len() <= max_fragment_size {
149 return vec![SyncMessage::new(msg_type, seq, payload.to_vec())];
150 }
151
152 let total_fragments = payload.len().div_ceil(max_fragment_size);
153 let total_fragments = total_fragments.min(255) as u8;
154
155 payload
156 .chunks(max_fragment_size)
157 .enumerate()
158 .map(|(i, chunk)| SyncMessage {
159 msg_type,
160 seq,
161 total_fragments,
162 fragment_index: i as u8,
163 payload: chunk.to_vec(),
164 })
165 .collect()
166}
167
168#[derive(Debug)]
170pub struct FragmentReassembler {
171 total_fragments: u8,
173 fragments: Vec<Option<Vec<u8>>>,
175 seq: u16,
177 msg_type: SyncMessageType,
179}
180
181impl FragmentReassembler {
182 pub fn new(msg: &SyncMessage) -> Self {
184 let mut fragments = vec![None; msg.total_fragments as usize];
185 fragments[msg.fragment_index as usize] = Some(msg.payload.clone());
186
187 Self {
188 total_fragments: msg.total_fragments,
189 fragments,
190 seq: msg.seq,
191 msg_type: msg.msg_type,
192 }
193 }
194
195 pub fn add_fragment(&mut self, msg: &SyncMessage) -> bool {
199 if msg.seq != self.seq || msg.total_fragments != self.total_fragments {
200 return false;
201 }
202
203 if (msg.fragment_index as usize) < self.fragments.len() {
204 self.fragments[msg.fragment_index as usize] = Some(msg.payload.clone());
205 }
206
207 self.is_complete()
208 }
209
210 pub fn is_complete(&self) -> bool {
212 self.fragments.iter().all(|f| f.is_some())
213 }
214
215 pub fn reassemble(&self) -> Option<Vec<u8>> {
219 if !self.is_complete() {
220 return None;
221 }
222
223 let total_size: usize = self.fragments.iter().flatten().map(|f| f.len()).sum();
224 let mut result = Vec::with_capacity(total_size);
225
226 for data in self.fragments.iter().flatten() {
227 result.extend_from_slice(data);
228 }
229
230 Some(result)
231 }
232
233 pub fn seq(&self) -> u16 {
235 self.seq
236 }
237
238 pub fn msg_type(&self) -> SyncMessageType {
240 self.msg_type
241 }
242}
243
/// Coarse state of a [`SyncProtocol`] session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncProtocolState {
    /// No session activity; the state after `new` and `reset`.
    Idle,
    /// `start_sync` has queued our sync vector.
    SendingVector,
    /// The peer's sync vector has been received.
    ReceivingDocuments,
    /// A document was queued while the session was idle.
    SendingDocuments,
    /// Waiting for acknowledgements.
    /// NOTE(review): no code in this module currently enters this state.
    WaitingAck,
    /// `end_sync` was called or the peer's end-of-sync was received.
    Complete,
    /// The peer reported an error.
    Error,
}
262
263pub struct SyncProtocol {
267 state: SyncProtocolState,
269 seq: u16,
271 outgoing: VecDeque<SyncMessage>,
273 pending_acks: Vec<u16>,
275 reassembler: Option<FragmentReassembler>,
277 max_payload: usize,
279}
280
281impl SyncProtocol {
282 pub fn new() -> Self {
284 Self {
285 state: SyncProtocolState::Idle,
286 seq: 0,
287 outgoing: VecDeque::new(),
288 pending_acks: Vec::new(),
289 reassembler: None,
290 max_payload: DEFAULT_MAX_PAYLOAD,
291 }
292 }
293
294 pub fn set_mtu(&mut self, mtu: u16) {
296 self.max_payload = max_payload_size(mtu);
297 }
298
299 pub fn state(&self) -> SyncProtocolState {
301 self.state
302 }
303
304 pub fn start_sync(&mut self, sync_vector: Vec<u8>) {
306 self.state = SyncProtocolState::SendingVector;
307 self.seq = 0;
308
309 let messages = fragment_payload(
311 SyncMessageType::SyncVector,
312 self.next_seq(),
313 &sync_vector,
314 self.max_payload,
315 );
316
317 for msg in messages {
318 self.outgoing.push_back(msg);
319 }
320 }
321
322 pub fn queue_document(&mut self, doc_data: Vec<u8>) {
324 if self.state == SyncProtocolState::Idle {
325 self.state = SyncProtocolState::SendingDocuments;
326 }
327
328 let messages = fragment_payload(
329 SyncMessageType::Document,
330 self.next_seq(),
331 &doc_data,
332 self.max_payload,
333 );
334
335 for msg in messages {
336 self.outgoing.push_back(msg);
337 }
338 }
339
340 pub fn end_sync(&mut self) {
342 let msg = SyncMessage::new(SyncMessageType::EndSync, self.next_seq(), Vec::new());
343 self.outgoing.push_back(msg);
344 self.state = SyncProtocolState::Complete;
345 }
346
347 pub fn next_outgoing(&mut self) -> Option<SyncMessage> {
349 self.outgoing.pop_front()
350 }
351
352 pub fn has_outgoing(&self) -> bool {
354 !self.outgoing.is_empty()
355 }
356
357 pub fn process_incoming(&mut self, data: &[u8]) -> Option<(SyncMessageType, Vec<u8>)> {
361 let msg = SyncMessage::decode(data)?;
362
363 if msg.total_fragments > 1 {
365 if let Some(ref mut reassembler) = self.reassembler {
366 if reassembler.seq() == msg.seq {
367 if reassembler.add_fragment(&msg) {
368 let payload = reassembler.reassemble()?;
369 let msg_type = reassembler.msg_type();
370 self.reassembler = None;
371 return Some((msg_type, payload));
372 }
373 return None;
374 }
375 }
376 self.reassembler = Some(FragmentReassembler::new(&msg));
378 if self.reassembler.as_ref().unwrap().is_complete() {
379 let reassembler = self.reassembler.take().unwrap();
380 let payload = reassembler.reassemble()?;
381 return Some((reassembler.msg_type(), payload));
382 }
383 return None;
384 }
385
386 match msg.msg_type {
388 SyncMessageType::Ack => {
389 self.pending_acks.retain(|&seq| seq != msg.seq);
390 None
391 }
392 SyncMessageType::SyncVector => {
393 self.state = SyncProtocolState::ReceivingDocuments;
394 Some((SyncMessageType::SyncVector, msg.payload))
395 }
396 SyncMessageType::Document => {
397 let ack = SyncMessage::new(SyncMessageType::Ack, msg.seq, Vec::new());
399 self.outgoing.push_back(ack);
400 Some((SyncMessageType::Document, msg.payload))
401 }
402 SyncMessageType::EndSync => {
403 self.state = SyncProtocolState::Complete;
404 Some((SyncMessageType::EndSync, Vec::new()))
405 }
406 SyncMessageType::Error => {
407 self.state = SyncProtocolState::Error;
408 Some((SyncMessageType::Error, msg.payload))
409 }
410 }
411 }
412
413 pub fn reset(&mut self) {
415 self.state = SyncProtocolState::Idle;
416 self.seq = 0;
417 self.outgoing.clear();
418 self.pending_acks.clear();
419 self.reassembler = None;
420 }
421
422 fn next_seq(&mut self) -> u16 {
424 let seq = self.seq;
425 self.seq = self.seq.wrapping_add(1);
426 seq
427 }
428}
429
430impl Default for SyncProtocol {
431 fn default() -> Self {
432 Self::new()
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_payload_size() {
        assert_eq!(max_payload_size(23), 15);
        assert_eq!(max_payload_size(251), 243);
        assert_eq!(max_payload_size(8), 0);
    }

    #[test]
    fn test_sync_message_encode_decode() {
        let msg = SyncMessage::new(SyncMessageType::Document, 42, vec![1, 2, 3, 4, 5]);

        let encoded = msg.encode();
        let decoded = SyncMessage::decode(&encoded).unwrap();

        assert_eq!(decoded.msg_type, SyncMessageType::Document);
        assert_eq!(decoded.seq, 42);
        assert_eq!(decoded.total_fragments, 1);
        assert_eq!(decoded.fragment_index, 0);
        assert_eq!(decoded.payload, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_fragment_payload() {
        let payload = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let fragments = fragment_payload(SyncMessageType::Document, 1, &payload, 4);

        assert_eq!(fragments.len(), 3);
        assert_eq!(fragments[0].total_fragments, 3);
        assert_eq!(fragments[0].fragment_index, 0);
        assert_eq!(fragments[0].payload, vec![1, 2, 3, 4]);
        assert_eq!(fragments[1].fragment_index, 1);
        assert_eq!(fragments[1].payload, vec![5, 6, 7, 8]);
        assert_eq!(fragments[2].fragment_index, 2);
        assert_eq!(fragments[2].payload, vec![9, 10]);
    }

    #[test]
    fn test_fragment_reassembler() {
        let payload = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let fragments = fragment_payload(SyncMessageType::Document, 1, &payload, 4);

        let mut reassembler = FragmentReassembler::new(&fragments[0]);
        assert!(!reassembler.is_complete());

        reassembler.add_fragment(&fragments[1]);
        assert!(!reassembler.is_complete());

        reassembler.add_fragment(&fragments[2]);
        assert!(reassembler.is_complete());

        let result = reassembler.reassemble().unwrap();
        assert_eq!(result, payload);
    }

    #[test]
    fn test_fragment_reassembler_out_of_order_and_foreign_seq() {
        let payload: Vec<u8> = (1..=10).collect();
        let fragments = fragment_payload(SyncMessageType::Document, 7, &payload, 4);
        let foreign = fragment_payload(SyncMessageType::Document, 8, &payload, 4);

        // Reassembly may start from any fragment, not just index 0.
        let mut reassembler = FragmentReassembler::new(&fragments[2]);
        assert_eq!(reassembler.reassemble(), None);

        // A fragment of a different message (other seq) is rejected.
        assert!(!reassembler.add_fragment(&foreign[1]));

        assert!(!reassembler.add_fragment(&fragments[0]));
        assert!(reassembler.add_fragment(&fragments[1]));
        assert_eq!(reassembler.reassemble(), Some(payload));
    }

    #[test]
    fn test_sync_protocol_basic_flow() {
        let mut initiator = SyncProtocol::new();
        let mut responder = SyncProtocol::new();

        initiator.start_sync(vec![1, 2, 3]);
        assert_eq!(initiator.state(), SyncProtocolState::SendingVector);

        let encoded = initiator.next_outgoing().unwrap().encode();
        let (msg_type, payload) = responder.process_incoming(&encoded).unwrap();
        assert_eq!(msg_type, SyncMessageType::SyncVector);
        assert_eq!(payload, vec![1, 2, 3]);

        responder.queue_document(vec![4, 5, 6]);
        let encoded = responder.next_outgoing().unwrap().encode();

        let (msg_type, payload) = initiator.process_incoming(&encoded).unwrap();
        assert_eq!(msg_type, SyncMessageType::Document);
        assert_eq!(payload, vec![4, 5, 6]);

        // Receiving a document queues an Ack.
        assert!(initiator.has_outgoing());

        initiator.end_sync();
        assert_eq!(initiator.state(), SyncProtocolState::Complete);
    }

    #[test]
    fn test_fragmented_document_round_trip() {
        let mut sender = SyncProtocol::new();
        let mut receiver = SyncProtocol::new();

        // 40 bytes at the default 15-byte budget travel as 3 fragments.
        let doc: Vec<u8> = (0..40).collect();
        sender.queue_document(doc.clone());

        let mut delivered = None;
        while let Some(msg) = sender.next_outgoing() {
            assert!(delivered.is_none(), "payload delivered before the last fragment");
            delivered = receiver.process_incoming(&msg.encode());
        }
        assert_eq!(delivered, Some((SyncMessageType::Document, doc)));
    }

    #[test]
    fn test_sync_protocol_with_mtu() {
        let mut protocol = SyncProtocol::new();
        protocol.set_mtu(251);

        // 500 bytes at 243 bytes per fragment needs ceil(500 / 243) = 3 fragments.
        protocol.queue_document(vec![0u8; 500]);

        let mut count = 0;
        while protocol.next_outgoing().is_some() {
            count += 1;
        }
        assert_eq!(count, 3);
    }

    #[test]
    fn test_protocol_reset() {
        let mut protocol = SyncProtocol::new();
        protocol.start_sync(vec![1, 2, 3]);
        protocol.queue_document(vec![4, 5, 6]);

        protocol.reset();

        assert_eq!(protocol.state(), SyncProtocolState::Idle);
        assert!(!protocol.has_outgoing());
    }
}