Skip to main content

kevy_replicate/
replica_decode.rs

1//! Snapshot-aware event decoding for [`crate::replica::ReplicaClient`]
2//! — split from `replica.rs` to keep that file under the 500-LOC
3//! project ceiling. The state-machine helpers live here; the type
4//! definitions, `connect`, and `next_frame` stay in `replica.rs`.
5
6use crate::replica::{DecodedFrame, ReplicaClient, ReplicaError, ReplicaEvent};
7use crate::wire::{
8    SnapshotMarker, WireError, decode_frame, decode_snapshot_chunk, decode_snapshot_marker,
9};
10use std::io::{self, Read};
11
12impl ReplicaClient {
13    /// Snapshot-aware iterator step. Returns one [`ReplicaEvent`] per
14    /// call — a live `Frame`, a `SnapshotBegin`/`SnapshotChunk`/
15    /// `SnapshotEnd`, or one of the [`ReplicaError`] variants.
16    /// Returns `None` on clean peer EOF.
17    ///
18    /// Snapshot bookkeeping:
19    /// - Entering `SnapshotBegin` sets `in_snapshot = true`; chunk
20    ///   bytes are valid until `SnapshotEnd`.
21    /// - `SnapshotEnd { ack_offset }` sets `expected_offset =
22    ///   ack_offset` (so the next live `Frame` has no gap) and
23    ///   clears `in_snapshot`.
24    /// - Live `*2\r\n` bytes during a snapshot return
25    ///   [`ReplicaError::UnexpectedInSnapshot`] (v1.18 forbids
26    ///   interleaving — see `docs/snapshot.md`).
27    pub fn next_event(&mut self) -> Option<Result<ReplicaEvent, ReplicaError>> {
28        loop {
29            if let Some(result) = self.try_decode_one_event() {
30                return Some(result);
31            }
32            // Need more bytes off the socket.
33            let mut chunk = [0u8; 4096];
34            match self.sock.read(&mut chunk) {
35                Ok(0) => {
36                    if self.cursor < self.buf.len() {
37                        return Some(Err(ReplicaError::Truncated));
38                    }
39                    return None;
40                }
41                Ok(n) => self.buf.extend_from_slice(&chunk[..n]),
42                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
43                Err(e) => return Some(Err(ReplicaError::Io(e))),
44            }
45        }
46    }
47
48    /// Try to decode one event from the buffered bytes. Returns
49    /// `None` when more bytes are needed (the loop in [`Self::next_event`]
50    /// will read more). Split out so the outer loop stays tiny + the
51    /// per-event dispatch fits the project's 50-LOC-fn rule.
52    fn try_decode_one_event(&mut self) -> Option<Result<ReplicaEvent, ReplicaError>> {
53        if self.cursor >= self.buf.len() {
54            return None;
55        }
56        let first = self.buf[self.cursor];
57        match first {
58            b'+' => self.try_decode_snapshot_marker(),
59            b'$' if self.in_snapshot => self.try_decode_snapshot_chunk(),
60            b'*' if self.in_snapshot => {
61                Some(Err(ReplicaError::UnexpectedInSnapshot))
62            }
63            b'*' => self.try_decode_live_frame(),
64            _ => Some(Err(ReplicaError::Frame(WireError::BadEnvelope))),
65        }
66    }
67
68    fn try_decode_live_frame(&mut self) -> Option<Result<ReplicaEvent, ReplicaError>> {
69        match decode_frame(&self.buf[self.cursor..]) {
70            Ok((offset, argv, used)) => {
71                self.cursor += used;
72                self.maybe_compact_buf();
73                if offset != self.expected_offset {
74                    return Some(Err(ReplicaError::OffsetGap {
75                        expected: self.expected_offset,
76                        got: offset,
77                    }));
78                }
79                self.expected_offset = self.expected_offset.saturating_add(1);
80                Some(Ok(ReplicaEvent::Frame(DecodedFrame { offset, argv })))
81            }
82            Err(WireError::Truncated) => None,
83            Err(e) => Some(Err(ReplicaError::Frame(e))),
84        }
85    }
86
87    fn try_decode_snapshot_marker(&mut self) -> Option<Result<ReplicaEvent, ReplicaError>> {
88        match decode_snapshot_marker(&self.buf[self.cursor..]) {
89            Ok(Some((SnapshotMarker::Begin, used))) => {
90                self.cursor += used;
91                self.maybe_compact_buf();
92                self.in_snapshot = true;
93                Some(Ok(ReplicaEvent::SnapshotBegin))
94            }
95            Ok(Some((SnapshotMarker::End(ack_offset), used))) => {
96                self.cursor += used;
97                self.maybe_compact_buf();
98                self.in_snapshot = false;
99                self.expected_offset = ack_offset;
100                Some(Ok(ReplicaEvent::SnapshotEnd { ack_offset }))
101            }
102            Ok(None) => Some(Err(ReplicaError::Frame(WireError::BadEnvelope))),
103            Err(WireError::Truncated) => None,
104            Err(e) => Some(Err(ReplicaError::Frame(e))),
105        }
106    }
107
108    fn try_decode_snapshot_chunk(&mut self) -> Option<Result<ReplicaEvent, ReplicaError>> {
109        match decode_snapshot_chunk(&self.buf[self.cursor..]) {
110            Ok((chunk, used)) => {
111                let owned = chunk.to_vec();
112                self.cursor += used;
113                self.maybe_compact_buf();
114                Some(Ok(ReplicaEvent::SnapshotChunk(owned)))
115            }
116            Err(WireError::Truncated) => None,
117            Err(e) => Some(Err(ReplicaError::Frame(e))),
118        }
119    }
120}
121
#[cfg(test)]
mod tests {
    use crate::replica::{ReplicaClient, ReplicaError, ReplicaEvent};
    use crate::wire::{encode_frame, encode_snapshot_begin, encode_snapshot_chunk, encode_snapshot_end};
    use kevy_resp::Argv;
    use std::io::Write;
    use std::net::{TcpListener, TcpStream};
    use std::thread;
    use std::time::Duration;

    /// Build a connected loopback TCP pair on an OS-assigned port.
    /// Returns `(accepted server end, connecting client end)`.
    fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let bound = listener.local_addr().unwrap();
        let client = TcpStream::connect(bound).unwrap();
        let (server, _peer) = listener.accept().unwrap();
        (server, client)
    }

    /// Assemble an `Argv` from a list of byte-string arguments.
    fn argv_for(args: &[&[u8]]) -> Argv {
        let mut argv = Argv::default();
        for piece in args {
            argv.push(piece);
        }
        argv
    }

    /// Play `script` against the server end of the connection on a
    /// background thread, then keep the socket open for 50 ms before
    /// closing it. The thread is not joined.
    fn serve<F>(mut srv: TcpStream, script: F)
    where
        F: FnOnce(&mut TcpStream) + Send + 'static,
    {
        thread::spawn(move || {
            script(&mut srv);
            thread::sleep(Duration::from_millis(50));
            drop(srv);
        });
    }

    /// Full snapshot handshake followed by a live frame at the acked offset.
    #[test]
    fn next_event_snapshot_path_begin_chunks_end_then_frame() {
        let (srv, cli) = tcp_pair();
        serve(srv, |srv| {
            srv.write_all(&encode_snapshot_begin()).unwrap();
            srv.write_all(&encode_snapshot_chunk(b"hello-snapshot")).unwrap();
            srv.write_all(&encode_snapshot_chunk(b"more-snapshot-bytes")).unwrap();
            srv.write_all(&encode_snapshot_end(42)).unwrap();
            srv.write_all(&encode_frame(42, &argv_for(&[b"SET", b"k", b"v"]))).unwrap();
        });
        let mut client = ReplicaClient::from_socket_for_test(cli, 0);

        assert!(matches!(client.next_event(), Some(Ok(ReplicaEvent::SnapshotBegin))));

        match client.next_event() {
            Some(Ok(ReplicaEvent::SnapshotChunk(bytes))) => {
                assert_eq!(bytes, b"hello-snapshot");
            }
            other => panic!("expected SnapshotChunk, got {other:?}"),
        }

        match client.next_event() {
            Some(Ok(ReplicaEvent::SnapshotChunk(bytes))) => {
                assert_eq!(bytes, b"more-snapshot-bytes");
            }
            other => panic!("expected SnapshotChunk, got {other:?}"),
        }

        match client.next_event() {
            Some(Ok(ReplicaEvent::SnapshotEnd { ack_offset })) => assert_eq!(ack_offset, 42),
            other => panic!("expected SnapshotEnd, got {other:?}"),
        }
        // The end marker must have re-based the expected live offset.
        assert_eq!(client.expected_offset(), 42);

        match client.next_event() {
            Some(Ok(ReplicaEvent::Frame(f))) => {
                assert_eq!(f.offset, 42);
                assert_eq!(f.argv, argv_for(&[b"SET", b"k", b"v"]));
            }
            other => panic!("expected Frame, got {other:?}"),
        }
    }

    /// A live frame sent before `SnapshotEnd` is rejected.
    #[test]
    fn next_event_live_frame_during_snapshot_is_unexpected() {
        let (srv, cli) = tcp_pair();
        serve(srv, |srv| {
            srv.write_all(&encode_snapshot_begin()).unwrap();
            srv.write_all(&encode_snapshot_chunk(b"first")).unwrap();
            srv.write_all(&encode_frame(0, &argv_for(&[b"PING"]))).unwrap();
        });
        let mut client = ReplicaClient::from_socket_for_test(cli, 0);

        assert!(matches!(client.next_event(), Some(Ok(ReplicaEvent::SnapshotBegin))));
        assert!(matches!(client.next_event(), Some(Ok(ReplicaEvent::SnapshotChunk(_)))));
        assert!(matches!(
            client.next_event(),
            Some(Err(ReplicaError::UnexpectedInSnapshot))
        ));
    }

    /// `next_frame` reports `SnapshotInProgress` when the peer opens a snapshot.
    #[test]
    fn next_frame_returns_snapshot_in_progress_when_snapshot_starts() {
        let (srv, cli) = tcp_pair();
        serve(srv, |srv| {
            srv.write_all(&encode_snapshot_begin()).unwrap();
        });
        let mut client = ReplicaClient::from_socket_for_test(cli, 0);

        assert!(matches!(
            client.next_frame(),
            Some(Err(ReplicaError::SnapshotInProgress))
        ));
    }

    /// Two consecutive live frames decode in order and advance the offset.
    #[test]
    fn live_frame_path_via_next_event() {
        let (srv, cli) = tcp_pair();
        serve(srv, |srv| {
            srv.write_all(&encode_frame(0, &argv_for(&[b"SET", b"a", b"1"]))).unwrap();
            srv.write_all(&encode_frame(1, &argv_for(&[b"SET", b"b", b"2"]))).unwrap();
        });
        let mut client = ReplicaClient::from_socket_for_test(cli, 0);

        for expected_off in 0..2 {
            match client.next_event() {
                Some(Ok(ReplicaEvent::Frame(f))) => assert_eq!(f.offset, expected_off),
                other => panic!("expected Frame {expected_off}, got {other:?}"),
            }
        }
        assert_eq!(client.expected_offset(), 2);
    }

    /// A snapshot with no chunks and `ack_offset == 0` is accepted.
    #[test]
    fn snapshot_end_with_zero_offset_handled() {
        let (srv, cli) = tcp_pair();
        serve(srv, |srv| {
            srv.write_all(&encode_snapshot_begin()).unwrap();
            srv.write_all(&encode_snapshot_end(0)).unwrap();
        });
        let mut client = ReplicaClient::from_socket_for_test(cli, 0);

        assert!(matches!(client.next_event(), Some(Ok(ReplicaEvent::SnapshotBegin))));
        match client.next_event() {
            Some(Ok(ReplicaEvent::SnapshotEnd { ack_offset })) => assert_eq!(ack_offset, 0),
            other => panic!("expected SnapshotEnd, got {other:?}"),
        }
        assert_eq!(client.expected_offset(), 0);
    }
}