use std::collections::VecDeque;

use thiserror::Error;

use super::message::{MessageError, SyncMessage};
use super::tracker::SyncTracker;
9
/// Errors returned by [`SyncEngine`] operations.
#[derive(Debug, Error)]
pub enum SyncError {
    /// A [`MessageError`], converted automatically via `#[from]`.
    #[error("message error: {0}")]
    Message(#[from] MessageError),

    /// The engine's `decode_diff` callback rejected an incoming payload; holds its error text.
    #[error("diff decode error: {0}")]
    DiffDecode(String),

    /// The engine's `apply_diff` callback failed; holds its error text.
    #[error("diff apply error: {0}")]
    DiffApply(String),

    /// Per its message, signals that a diff's base version did not match the expected one.
    ///
    /// NOTE(review): `SyncEngine` in this module never constructs this variant —
    /// confirm whether it is produced elsewhere or is dead.
    #[error("version mismatch: expected base {expected}, got {actual}")]
    VersionMismatch {
        /// The base version that was expected.
        expected: u64,
        /// The base version that was actually received.
        actual: u64,
    },

    /// An operation that needs state was called before `init` (or after `reset`).
    #[error("state not initialized")]
    NotInitialized,
}
38
/// Outcome of [`SyncEngine::process_message`].
///
/// A plain fieldless enum, so it is `Copy` and `Hash` and can be compared,
/// stored and passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProcessResult {
    /// A new data message was accepted and its diff (if non-empty) applied.
    Updated,
    /// The message was ack-only (`SyncMessage::is_ack_only`); no diff was applied.
    AckOnly,
    /// The tracker reported the message as already seen; nothing was applied.
    Duplicate,
}
49
50pub struct SyncEngine<S, D> {
61 tracker: SyncTracker,
63
64 state: Option<S>,
66
67 acked_snapshot: Option<S>,
70
71 encode_diff: fn(&D) -> Vec<u8>,
73
74 decode_diff: fn(&[u8]) -> Result<D, String>,
76
77 compute_diff: fn(&S, &S) -> D,
79
80 apply_diff: fn(&mut S, &D) -> Result<(), String>,
82
83 is_diff_empty: fn(&D) -> bool,
85}
86
87impl<S: Clone, D> SyncEngine<S, D> {
88 pub fn new(
90 encode_diff: fn(&D) -> Vec<u8>,
91 decode_diff: fn(&[u8]) -> Result<D, String>,
92 compute_diff: fn(&S, &S) -> D,
93 apply_diff: fn(&mut S, &D) -> Result<(), String>,
94 is_diff_empty: fn(&D) -> bool,
95 ) -> Self {
96 Self {
97 tracker: SyncTracker::new(),
98 state: None,
99 acked_snapshot: None,
100 encode_diff,
101 decode_diff,
102 compute_diff,
103 apply_diff,
104 is_diff_empty,
105 }
106 }
107
108 pub fn init(&mut self, initial_state: S) {
110 self.state = Some(initial_state.clone());
111 self.acked_snapshot = Some(initial_state);
112 self.tracker.reset();
113 }
114
115 pub fn is_initialized(&self) -> bool {
117 self.state.is_some()
118 }
119
120 pub fn state(&self) -> Option<&S> {
122 self.state.as_ref()
123 }
124
125 pub fn state_mut(&mut self) -> Option<&mut S> {
129 self.state.as_mut()
130 }
131
132 pub fn mark_changed(&mut self) -> u64 {
136 self.tracker.bump_version()
137 }
138
139 pub fn update_state(&mut self, new_state: S) -> u64 {
141 self.state = Some(new_state);
142 self.tracker.bump_version()
143 }
144
145 pub fn tracker(&self) -> &SyncTracker {
147 &self.tracker
148 }
149
150 pub fn has_pending_updates(&self) -> bool {
152 self.tracker.has_pending_updates()
153 }
154
155 pub fn needs_ack(&self) -> bool {
157 self.tracker.needs_ack()
158 }
159
160 pub fn generate_message(&mut self) -> Result<Option<SyncMessage>, SyncError> {
164 let state = self.state.as_ref().ok_or(SyncError::NotInitialized)?;
165
166 if !self.tracker.has_pending_updates() && !self.tracker.needs_ack() {
168 return Ok(None);
169 }
170
171 if !self.tracker.has_pending_updates() {
173 let msg = self.tracker.create_ack();
174 return Ok(Some(msg));
175 }
176
177 let base_state = self.acked_snapshot.as_ref().ok_or(SyncError::NotInitialized)?;
179 let diff = (self.compute_diff)(base_state, state);
180
181 let diff_bytes = if (self.is_diff_empty)(&diff) {
184 Vec::new()
185 } else {
186 (self.encode_diff)(&diff)
187 };
188
189 let base_version = self.tracker.diff_base_version();
190 let msg = self.tracker.create_message(diff_bytes, base_version);
191 self.tracker.record_sent(self.tracker.current_version());
192
193 Ok(Some(msg))
194 }
195
196 pub fn generate_ack(&self) -> Result<SyncMessage, SyncError> {
198 if !self.is_initialized() {
199 return Err(SyncError::NotInitialized);
200 }
201 Ok(self.tracker.create_ack())
202 }
203
204 pub fn process_message(&mut self, msg: &SyncMessage) -> Result<ProcessResult, SyncError> {
208 let state = self.state.as_mut().ok_or(SyncError::NotInitialized)?;
209
210 let is_new = self.tracker.process_incoming(msg);
212
213 if msg.is_ack_only() {
214 if msg.acked_state_num > 0 {
216 self.update_acked_snapshot();
217 }
218 return Ok(ProcessResult::AckOnly);
219 }
220
221 if !is_new {
222 return Ok(ProcessResult::Duplicate);
223 }
224
225 if !msg.diff.is_empty() {
227 let diff = (self.decode_diff)(&msg.diff)
228 .map_err(SyncError::DiffDecode)?;
229 (self.apply_diff)(state, &diff)
230 .map_err(SyncError::DiffApply)?;
231 }
232
233 if msg.acked_state_num > 0 {
235 self.update_acked_snapshot();
236 }
237
238 Ok(ProcessResult::Updated)
239 }
240
241 fn update_acked_snapshot(&mut self) {
243 if let Some(state) = &self.state {
244 if self.tracker.last_acked_version() > 0 {
246 self.acked_snapshot = Some(state.clone());
249 }
250 }
251 }
252
253 pub fn current_version(&self) -> u64 {
255 self.tracker.current_version()
256 }
257
258 pub fn peer_version(&self) -> u64 {
260 self.tracker.peer_version()
261 }
262
263 pub fn is_synchronized(&self) -> bool {
265 self.tracker.is_synchronized()
266 }
267
268 pub fn reset(&mut self) {
270 self.tracker.reset();
271 self.state = None;
272 self.acked_snapshot = None;
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal synchronised state: one integer.
    #[derive(Debug, Clone, PartialEq)]
    struct TestState {
        value: i32,
    }

    /// Additive diff: the amount to add to `TestState::value`.
    #[derive(Debug, Clone, PartialEq)]
    struct TestDiff {
        delta: i32,
    }

    /// Serialises the delta as four little-endian bytes.
    fn encode_diff(diff: &TestDiff) -> Vec<u8> {
        Vec::from(diff.delta.to_le_bytes())
    }

    /// Inverse of `encode_diff`; anything other than exactly four bytes is rejected.
    fn decode_diff(data: &[u8]) -> Result<TestDiff, String> {
        let bytes: [u8; 4] = data
            .try_into()
            .map_err(|_| "invalid diff length".to_string())?;
        Ok(TestDiff {
            delta: i32::from_le_bytes(bytes),
        })
    }

    fn compute_diff(old: &TestState, new: &TestState) -> TestDiff {
        let delta = new.value - old.value;
        TestDiff { delta }
    }

    fn apply_diff(state: &mut TestState, diff: &TestDiff) -> Result<(), String> {
        state.value += diff.delta;
        Ok(())
    }

    fn is_diff_empty(diff: &TestDiff) -> bool {
        matches!(diff, TestDiff { delta: 0 })
    }

    /// Engine wired to the integer test codec, not yet initialised.
    fn create_engine() -> SyncEngine<TestState, TestDiff> {
        SyncEngine::new(
            encode_diff,
            decode_diff,
            compute_diff,
            apply_diff,
            is_diff_empty,
        )
    }

    /// Engine already initialised with the given value.
    fn engine_with(value: i32) -> SyncEngine<TestState, TestDiff> {
        let mut engine = create_engine();
        engine.init(TestState { value });
        engine
    }

    #[test]
    fn test_init() {
        let mut engine = create_engine();
        assert!(!engine.is_initialized());

        engine.init(TestState { value: 42 });
        assert!(engine.is_initialized());
        assert_eq!(engine.state().map(|s| s.value), Some(42));
    }

    #[test]
    fn test_update_state() {
        let mut engine = engine_with(0);

        let version = engine.update_state(TestState { value: 100 });
        assert_eq!(version, 1);
        assert_eq!(engine.state(), Some(&TestState { value: 100 }));
        assert!(engine.has_pending_updates());
    }

    #[test]
    fn test_generate_message() {
        let mut engine = engine_with(0);

        // Nothing has changed yet, so there is nothing to send.
        assert!(engine.generate_message().unwrap().is_none());

        engine.update_state(TestState { value: 10 });
        let msg = engine
            .generate_message()
            .unwrap()
            .expect("a pending update should produce a message");
        assert_eq!(msg.sender_state_num, 1);
        assert!(!msg.is_ack_only());
        assert_eq!(decode_diff(&msg.diff), Ok(TestDiff { delta: 10 }));
    }

    #[test]
    fn test_process_message() {
        let mut engine = engine_with(0);

        let payload = encode_diff(&TestDiff { delta: 50 });
        let msg = SyncMessage::new(1, 0, 0, payload);

        assert_eq!(engine.process_message(&msg).unwrap(), ProcessResult::Updated);
        assert_eq!(engine.state().unwrap().value, 50);
        assert_eq!(engine.peer_version(), 1);
    }

    #[test]
    fn test_process_ack_only() {
        let mut engine = engine_with(0);
        engine.update_state(TestState { value: 10 });

        let ack = SyncMessage::ack_only(1, 1);
        assert_eq!(engine.process_message(&ack).unwrap(), ProcessResult::AckOnly);
    }

    #[test]
    fn test_duplicate_message() {
        let mut engine = engine_with(0);
        let msg = SyncMessage::new(1, 0, 0, encode_diff(&TestDiff { delta: 10 }));

        engine.process_message(&msg).unwrap();

        // Replaying the same sender version must be recognised as a duplicate.
        assert_eq!(engine.process_message(&msg).unwrap(), ProcessResult::Duplicate);
    }

    #[test]
    fn test_bidirectional_sync() {
        let (mut a, mut b) = (engine_with(0), engine_with(0));

        a.update_state(TestState { value: 100 });
        let from_a = a.generate_message().unwrap().unwrap();

        b.process_message(&from_a).unwrap();
        assert_eq!(b.state().unwrap().value, 100);
        assert_eq!(b.peer_version(), 1);

        let ack = b.generate_ack().unwrap();
        a.process_message(&ack).unwrap();

        assert_eq!(a.tracker().last_acked_version(), 1);
    }

    #[test]
    fn test_not_initialized_error() {
        let mut engine = create_engine();

        assert!(matches!(
            engine.generate_message(),
            Err(SyncError::NotInitialized)
        ));

        let ack = SyncMessage::ack_only(1, 0);
        assert!(matches!(
            engine.process_message(&ack),
            Err(SyncError::NotInitialized)
        ));
    }

    #[test]
    fn test_empty_diff() {
        let mut engine = engine_with(42);

        // Bump the version without touching the value, so the diff is empty.
        engine.mark_changed();

        let msg = engine.generate_message().unwrap().unwrap();
        assert!(msg.diff.is_empty() || msg.is_ack_only());
    }

    #[test]
    fn test_reset() {
        let mut engine = engine_with(100);
        engine.update_state(TestState { value: 200 });

        engine.reset();

        assert!(!engine.is_initialized());
        assert!(engine.state().is_none());
        assert_eq!(engine.current_version(), 0);
    }
}