// nomad_protocol/sync/engine.rs

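//! Sync engine
//!
//! Coordinates state synchronization between two endpoints.
//! Generic over the state type `S` and its diff type `D`; the state-specific
//! operations (computing, encoding, decoding, and applying diffs) are supplied
//! as function pointers when the engine is constructed.
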
use std::collections::VecDeque;

use thiserror::Error;

use super::message::{MessageError, SyncMessage};
use super::tracker::SyncTracker;

10/// Errors from the sync engine.
11#[derive(Debug, Error)]
12pub enum SyncError {
13    /// Error encoding or decoding sync messages.
14    #[error("message error: {0}")]
15    Message(#[from] MessageError),
16
17    /// Failed to decode the diff payload.
18    #[error("diff decode error: {0}")]
19    DiffDecode(String),
20
21    /// Failed to apply the diff to local state.
22    #[error("diff apply error: {0}")]
23    DiffApply(String),
24
25    /// Diff was based on a different version than expected.
26    #[error("version mismatch: expected base {expected}, got {actual}")]
27    VersionMismatch {
28        /// The base version we expected.
29        expected: u64,
30        /// The base version received.
31        actual: u64,
32    },
33
34    /// Operation requires initialized state but none exists.
35    #[error("state not initialized")]
36    NotInitialized,
37}
38
/// Outcome of handing one incoming message to [`SyncEngine::process_message`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProcessResult {
    /// The message carried a new version and its diff was applied locally.
    Updated,
    /// The message only acknowledged our state; local state is untouched.
    AckOnly,
    /// The message repeated a version already received; its diff was not applied.
    Duplicate,
}

50/// Sync engine for bidirectional state synchronization
51///
52/// The engine is generic over:
53/// - `S`: The state type being synchronized
54/// - `D`: The diff type for that state
55///
56/// The engine manages:
57/// - Version tracking via SyncTracker
58/// - State snapshots for diff computation
59/// - Diff generation and application
60pub struct SyncEngine<S, D> {
61    /// Version tracking
62    tracker: SyncTracker,
63
64    /// Current local state
65    state: Option<S>,
66
67    /// State snapshot at last_acked version (for diff computation)
68    /// This is the state we know the peer has
69    acked_snapshot: Option<S>,
70
71    /// Callback for encoding diffs
72    encode_diff: fn(&D) -> Vec<u8>,
73
74    /// Callback for decoding diffs
75    decode_diff: fn(&[u8]) -> Result<D, String>,
76
77    /// Callback for computing diff between states
78    compute_diff: fn(&S, &S) -> D,
79
80    /// Callback for applying diff to state
81    apply_diff: fn(&mut S, &D) -> Result<(), String>,
82
83    /// Callback for checking if diff is empty
84    is_diff_empty: fn(&D) -> bool,
85}
86
87impl<S: Clone, D> SyncEngine<S, D> {
88    /// Create a new sync engine with the required callbacks
89    pub fn new(
90        encode_diff: fn(&D) -> Vec<u8>,
91        decode_diff: fn(&[u8]) -> Result<D, String>,
92        compute_diff: fn(&S, &S) -> D,
93        apply_diff: fn(&mut S, &D) -> Result<(), String>,
94        is_diff_empty: fn(&D) -> bool,
95    ) -> Self {
96        Self {
97            tracker: SyncTracker::new(),
98            state: None,
99            acked_snapshot: None,
100            encode_diff,
101            decode_diff,
102            compute_diff,
103            apply_diff,
104            is_diff_empty,
105        }
106    }
107
108    /// Initialize the engine with initial state
109    pub fn init(&mut self, initial_state: S) {
110        self.state = Some(initial_state.clone());
111        self.acked_snapshot = Some(initial_state);
112        self.tracker.reset();
113    }
114
115    /// Check if the engine is initialized
116    pub fn is_initialized(&self) -> bool {
117        self.state.is_some()
118    }
119
120    /// Get a reference to the current state
121    pub fn state(&self) -> Option<&S> {
122        self.state.as_ref()
123    }
124
125    /// Get a mutable reference to the current state
126    ///
127    /// Note: After modifying, call `mark_changed()` to bump version
128    pub fn state_mut(&mut self) -> Option<&mut S> {
129        self.state.as_mut()
130    }
131
132    /// Mark that the local state has changed
133    ///
134    /// Call this after modifying the state to bump the version.
135    pub fn mark_changed(&mut self) -> u64 {
136        self.tracker.bump_version()
137    }
138
139    /// Update local state and bump version atomically
140    pub fn update_state(&mut self, new_state: S) -> u64 {
141        self.state = Some(new_state);
142        self.tracker.bump_version()
143    }
144
145    /// Get the tracker for inspection
146    pub fn tracker(&self) -> &SyncTracker {
147        &self.tracker
148    }
149
150    /// Check if we have updates to send
151    pub fn has_pending_updates(&self) -> bool {
152        self.tracker.has_pending_updates()
153    }
154
155    /// Check if we need to send an ack
156    pub fn needs_ack(&self) -> bool {
157        self.tracker.needs_ack()
158    }
159
160    /// Generate a sync message to send to peer
161    ///
162    /// Returns None if there's nothing to send
163    pub fn generate_message(&mut self) -> Result<Option<SyncMessage>, SyncError> {
164        let state = self.state.as_ref().ok_or(SyncError::NotInitialized)?;
165
166        // If no pending updates and no ack needed, nothing to send
167        if !self.tracker.has_pending_updates() && !self.tracker.needs_ack() {
168            return Ok(None);
169        }
170
171        // If only need ack, send ack-only
172        if !self.tracker.has_pending_updates() {
173            let msg = self.tracker.create_ack();
174            return Ok(Some(msg));
175        }
176
177        // Compute diff from acked snapshot
178        let base_state = self.acked_snapshot.as_ref().ok_or(SyncError::NotInitialized)?;
179        let diff = (self.compute_diff)(base_state, state);
180
181        // If diff is empty but we have pending updates, still send it
182        // (version bump matters even without content change)
183        let diff_bytes = if (self.is_diff_empty)(&diff) {
184            Vec::new()
185        } else {
186            (self.encode_diff)(&diff)
187        };
188
189        let base_version = self.tracker.diff_base_version();
190        let msg = self.tracker.create_message(diff_bytes, base_version);
191        self.tracker.record_sent(self.tracker.current_version());
192
193        Ok(Some(msg))
194    }
195
196    /// Generate an ack-only message
197    pub fn generate_ack(&self) -> Result<SyncMessage, SyncError> {
198        if !self.is_initialized() {
199            return Err(SyncError::NotInitialized);
200        }
201        Ok(self.tracker.create_ack())
202    }
203
204    /// Process an incoming sync message
205    ///
206    /// Returns the result of processing
207    pub fn process_message(&mut self, msg: &SyncMessage) -> Result<ProcessResult, SyncError> {
208        let state = self.state.as_mut().ok_or(SyncError::NotInitialized)?;
209
210        // Update tracker first (this handles ack fields)
211        let is_new = self.tracker.process_incoming(msg);
212
213        if msg.is_ack_only() {
214            // Update acked snapshot if peer acked new version
215            if msg.acked_state_num > 0 {
216                self.update_acked_snapshot();
217            }
218            return Ok(ProcessResult::AckOnly);
219        }
220
221        if !is_new {
222            return Ok(ProcessResult::Duplicate);
223        }
224
225        // Decode and apply diff
226        if !msg.diff.is_empty() {
227            let diff = (self.decode_diff)(&msg.diff)
228                .map_err(SyncError::DiffDecode)?;
229            (self.apply_diff)(state, &diff)
230                .map_err(SyncError::DiffApply)?;
231        }
232
233        // Update acked snapshot if peer acked new version
234        if msg.acked_state_num > 0 {
235            self.update_acked_snapshot();
236        }
237
238        Ok(ProcessResult::Updated)
239    }
240
241    /// Update the acked snapshot to current state
242    fn update_acked_snapshot(&mut self) {
243        if let Some(state) = &self.state {
244            // Only update if we have a valid ack
245            if self.tracker.last_acked_version() > 0 {
246                // For simplicity, snapshot current state
247                // In practice, might want versioned history
248                self.acked_snapshot = Some(state.clone());
249            }
250        }
251    }
252
253    /// Get current local version
254    pub fn current_version(&self) -> u64 {
255        self.tracker.current_version()
256    }
257
258    /// Get peer's version
259    pub fn peer_version(&self) -> u64 {
260        self.tracker.peer_version()
261    }
262
263    /// Check if synchronized with peer
264    pub fn is_synchronized(&self) -> bool {
265        self.tracker.is_synchronized()
266    }
267
268    /// Reset the engine
269    pub fn reset(&mut self) {
270        self.tracker.reset();
271        self.state = None;
272        self.acked_snapshot = None;
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    // Simple test state type
    #[derive(Debug, Clone, PartialEq)]
    struct TestState {
        value: i32,
    }

    // Simple diff type: an additive delta
    #[derive(Debug, Clone, PartialEq)]
    struct TestDiff {
        delta: i32,
    }

    fn encode_diff(diff: &TestDiff) -> Vec<u8> {
        diff.delta.to_le_bytes().to_vec()
    }

    fn decode_diff(data: &[u8]) -> Result<TestDiff, String> {
        if data.len() != 4 {
            return Err("invalid diff length".to_string());
        }
        let delta = i32::from_le_bytes(data.try_into().unwrap());
        Ok(TestDiff { delta })
    }

    fn compute_diff(old: &TestState, new: &TestState) -> TestDiff {
        TestDiff {
            delta: new.value - old.value,
        }
    }

    fn apply_diff(state: &mut TestState, diff: &TestDiff) -> Result<(), String> {
        state.value += diff.delta;
        Ok(())
    }

    fn is_diff_empty(diff: &TestDiff) -> bool {
        diff.delta == 0
    }

    fn create_engine() -> SyncEngine<TestState, TestDiff> {
        SyncEngine::new(encode_diff, decode_diff, compute_diff, apply_diff, is_diff_empty)
    }

    #[test]
    fn test_init() {
        let mut engine = create_engine();
        assert!(!engine.is_initialized());

        engine.init(TestState { value: 42 });
        assert!(engine.is_initialized());
        assert_eq!(engine.state().unwrap().value, 42);
    }

    #[test]
    fn test_update_state() {
        let mut engine = create_engine();
        engine.init(TestState { value: 0 });

        let version = engine.update_state(TestState { value: 100 });
        assert_eq!(version, 1);
        assert_eq!(engine.state().unwrap().value, 100);
        assert!(engine.has_pending_updates());
    }

    #[test]
    fn test_generate_message() {
        let mut engine = create_engine();
        engine.init(TestState { value: 0 });

        // No pending updates initially
        let msg = engine.generate_message().unwrap();
        assert!(msg.is_none());

        // Update state
        engine.update_state(TestState { value: 10 });

        // Now should generate message
        let msg = engine.generate_message().unwrap().unwrap();
        assert_eq!(msg.sender_state_num, 1);
        assert!(!msg.is_ack_only());

        // Diff should encode the delta
        let diff = decode_diff(&msg.diff).unwrap();
        assert_eq!(diff.delta, 10);
    }

    #[test]
    fn test_process_message() {
        let mut engine = create_engine();
        engine.init(TestState { value: 0 });

        // Create incoming message with diff
        let diff = TestDiff { delta: 50 };
        let msg = SyncMessage::new(1, 0, 0, encode_diff(&diff));

        let result = engine.process_message(&msg).unwrap();
        assert_eq!(result, ProcessResult::Updated);
        assert_eq!(engine.state().unwrap().value, 50);
        assert_eq!(engine.peer_version(), 1);
    }

    #[test]
    fn test_process_ack_only() {
        let mut engine = create_engine();
        engine.init(TestState { value: 0 });
        engine.update_state(TestState { value: 10 });

        let msg = SyncMessage::ack_only(1, 1);
        let result = engine.process_message(&msg).unwrap();
        assert_eq!(result, ProcessResult::AckOnly);
    }

    #[test]
    fn test_duplicate_message() {
        let mut engine = create_engine();
        engine.init(TestState { value: 0 });

        let diff = TestDiff { delta: 10 };
        let msg = SyncMessage::new(1, 0, 0, encode_diff(&diff));

        // First message
        engine.process_message(&msg).unwrap();

        // Same message again (same sender_state_num)
        let result = engine.process_message(&msg).unwrap();
        assert_eq!(result, ProcessResult::Duplicate);
    }

    #[test]
    fn test_bidirectional_sync() {
        let mut engine_a = create_engine();
        let mut engine_b = create_engine();

        engine_a.init(TestState { value: 0 });
        engine_b.init(TestState { value: 0 });

        // A updates state
        engine_a.update_state(TestState { value: 100 });
        let msg_from_a = engine_a.generate_message().unwrap().unwrap();

        // B receives and processes
        engine_b.process_message(&msg_from_a).unwrap();
        assert_eq!(engine_b.state().unwrap().value, 100);
        assert_eq!(engine_b.peer_version(), 1);

        // B sends ack back
        let ack_from_b = engine_b.generate_ack().unwrap();
        engine_a.process_message(&ack_from_b).unwrap();

        // A should see B acked version 1
        assert_eq!(engine_a.tracker().last_acked_version(), 1);
    }

    #[test]
    fn test_not_initialized_error() {
        let mut engine = create_engine();

        let result = engine.generate_message();
        assert!(matches!(result, Err(SyncError::NotInitialized)));

        let msg = SyncMessage::ack_only(1, 0);
        let result = engine.process_message(&msg);
        assert!(matches!(result, Err(SyncError::NotInitialized)));
    }

    #[test]
    fn test_generate_ack_not_initialized() {
        let engine = create_engine();
        assert!(matches!(engine.generate_ack(), Err(SyncError::NotInitialized)));
    }

    #[test]
    fn test_empty_diff() {
        let mut engine = create_engine();
        engine.init(TestState { value: 42 });

        // Bump the version without changing the value, so the diff is empty.
        engine.mark_changed();

        // The version bump is still announced, but with an empty payload.
        let msg = engine.generate_message().unwrap().unwrap();
        assert_eq!(msg.sender_state_num, 1);
        assert!(msg.diff.is_empty());
    }

    #[test]
    fn test_invalid_diff_is_rejected() {
        let mut engine = create_engine();
        engine.init(TestState { value: 7 });

        // Three bytes cannot decode into a TestDiff, which needs exactly four.
        let msg = SyncMessage::new(1, 0, 0, vec![1u8, 2, 3]);
        let result = engine.process_message(&msg);
        assert!(matches!(result, Err(SyncError::DiffDecode(_))));

        // Local state is left untouched by the rejected message.
        assert_eq!(engine.state().unwrap().value, 7);
    }

    #[test]
    fn test_reset() {
        let mut engine = create_engine();
        engine.init(TestState { value: 100 });
        engine.update_state(TestState { value: 200 });

        engine.reset();

        assert!(!engine.is_initialized());
        assert!(engine.state().is_none());
        assert_eq!(engine.current_version(), 0);
    }
}