1use crate::replica::{DecodedFrame, ReplicaClient, ReplicaError, ReplicaEvent};
7use crate::wire::{
8 SnapshotMarker, WireError, decode_frame, decode_snapshot_chunk, decode_snapshot_marker,
9};
10use std::io::{self, Read};
11
12impl ReplicaClient {
13 pub fn next_event(&mut self) -> Option<Result<ReplicaEvent, ReplicaError>> {
28 loop {
29 if let Some(result) = self.try_decode_one_event() {
30 return Some(result);
31 }
32 let mut chunk = [0u8; 4096];
34 match self.sock.read(&mut chunk) {
35 Ok(0) => {
36 if self.cursor < self.buf.len() {
37 return Some(Err(ReplicaError::Truncated));
38 }
39 return None;
40 }
41 Ok(n) => self.buf.extend_from_slice(&chunk[..n]),
42 Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
43 Err(e) => return Some(Err(ReplicaError::Io(e))),
44 }
45 }
46 }
47
48 fn try_decode_one_event(&mut self) -> Option<Result<ReplicaEvent, ReplicaError>> {
53 if self.cursor >= self.buf.len() {
54 return None;
55 }
56 let first = self.buf[self.cursor];
57 match first {
58 b'+' => self.try_decode_snapshot_marker(),
59 b'$' if self.in_snapshot => self.try_decode_snapshot_chunk(),
60 b'*' if self.in_snapshot => {
61 Some(Err(ReplicaError::UnexpectedInSnapshot))
62 }
63 b'*' => self.try_decode_live_frame(),
64 _ => Some(Err(ReplicaError::Frame(WireError::BadEnvelope))),
65 }
66 }
67
68 fn try_decode_live_frame(&mut self) -> Option<Result<ReplicaEvent, ReplicaError>> {
69 match decode_frame(&self.buf[self.cursor..]) {
70 Ok((offset, argv, used)) => {
71 self.cursor += used;
72 self.maybe_compact_buf();
73 if offset != self.expected_offset {
74 return Some(Err(ReplicaError::OffsetGap {
75 expected: self.expected_offset,
76 got: offset,
77 }));
78 }
79 self.expected_offset = self.expected_offset.saturating_add(1);
80 Some(Ok(ReplicaEvent::Frame(DecodedFrame { offset, argv })))
81 }
82 Err(WireError::Truncated) => None,
83 Err(e) => Some(Err(ReplicaError::Frame(e))),
84 }
85 }
86
87 fn try_decode_snapshot_marker(&mut self) -> Option<Result<ReplicaEvent, ReplicaError>> {
88 match decode_snapshot_marker(&self.buf[self.cursor..]) {
89 Ok(Some((SnapshotMarker::Begin, used))) => {
90 self.cursor += used;
91 self.maybe_compact_buf();
92 self.in_snapshot = true;
93 Some(Ok(ReplicaEvent::SnapshotBegin))
94 }
95 Ok(Some((SnapshotMarker::End(ack_offset), used))) => {
96 self.cursor += used;
97 self.maybe_compact_buf();
98 self.in_snapshot = false;
99 self.expected_offset = ack_offset;
100 Some(Ok(ReplicaEvent::SnapshotEnd { ack_offset }))
101 }
102 Ok(None) => Some(Err(ReplicaError::Frame(WireError::BadEnvelope))),
103 Err(WireError::Truncated) => None,
104 Err(e) => Some(Err(ReplicaError::Frame(e))),
105 }
106 }
107
108 fn try_decode_snapshot_chunk(&mut self) -> Option<Result<ReplicaEvent, ReplicaError>> {
109 match decode_snapshot_chunk(&self.buf[self.cursor..]) {
110 Ok((chunk, used)) => {
111 let owned = chunk.to_vec();
112 self.cursor += used;
113 self.maybe_compact_buf();
114 Some(Ok(ReplicaEvent::SnapshotChunk(owned)))
115 }
116 Err(WireError::Truncated) => None,
117 Err(e) => Some(Err(ReplicaError::Frame(e))),
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    use crate::replica::{ReplicaClient, ReplicaError, ReplicaEvent};
    use crate::wire::{
        encode_frame, encode_snapshot_begin, encode_snapshot_chunk, encode_snapshot_end,
    };
    use kevy_resp::Argv;
    use std::io::Write;
    use std::net::{TcpListener, TcpStream};
    use std::thread::{self, JoinHandle};
    use std::time::Duration;

    /// Returns a connected `(server, client)` pair of loopback TCP streams.
    fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let client = TcpStream::connect(addr).unwrap();
        let (server, _) = listener.accept().unwrap();
        (server, client)
    }

    /// Builds an `Argv` from the given arguments, in order.
    fn argv_for(args: &[&[u8]]) -> Argv {
        let mut a = Argv::default();
        for arg in args {
            a.push(arg);
        }
        a
    }

    /// Runs `script` against the server end of the connection on a background
    /// thread, keeps the connection open for a short grace period, then closes
    /// it.
    ///
    /// The handle is returned so each test can `join` it. A failed `write_all`
    /// panics inside that thread, and without a join the panic would be
    /// silently discarded.
    fn spawn_server<F>(mut srv: TcpStream, script: F) -> JoinHandle<()>
    where
        F: FnOnce(&mut TcpStream) + Send + 'static,
    {
        thread::spawn(move || {
            script(&mut srv);
            thread::sleep(Duration::from_millis(50));
            drop(srv);
        })
    }

    #[test]
    fn next_event_snapshot_path_begin_chunks_end_then_frame() {
        let (srv, cli) = tcp_pair();
        let server = spawn_server(srv, |srv| {
            srv.write_all(&encode_snapshot_begin()).unwrap();
            srv.write_all(&encode_snapshot_chunk(b"hello-snapshot")).unwrap();
            srv.write_all(&encode_snapshot_chunk(b"more-snapshot-bytes")).unwrap();
            srv.write_all(&encode_snapshot_end(42)).unwrap();
            srv.write_all(&encode_frame(42, &argv_for(&[b"SET", b"k", b"v"]))).unwrap();
        });
        let mut client = ReplicaClient::from_socket_for_test(cli, 0);

        assert!(matches!(client.next_event(), Some(Ok(ReplicaEvent::SnapshotBegin))));
        match client.next_event() {
            Some(Ok(ReplicaEvent::SnapshotChunk(bytes))) => {
                assert_eq!(bytes, b"hello-snapshot");
            }
            other => panic!("expected SnapshotChunk, got {other:?}"),
        }
        match client.next_event() {
            Some(Ok(ReplicaEvent::SnapshotChunk(bytes))) => {
                assert_eq!(bytes, b"more-snapshot-bytes");
            }
            other => panic!("expected SnapshotChunk, got {other:?}"),
        }
        match client.next_event() {
            Some(Ok(ReplicaEvent::SnapshotEnd { ack_offset })) => assert_eq!(ack_offset, 42),
            other => panic!("expected SnapshotEnd, got {other:?}"),
        }
        assert_eq!(client.expected_offset(), 42);
        match client.next_event() {
            Some(Ok(ReplicaEvent::Frame(f))) => {
                assert_eq!(f.offset, 42);
                assert_eq!(f.argv, argv_for(&[b"SET", b"k", b"v"]));
            }
            other => panic!("expected Frame, got {other:?}"),
        }
        server.join().expect("server thread panicked");
    }

    #[test]
    fn next_event_live_frame_during_snapshot_is_unexpected() {
        let (srv, cli) = tcp_pair();
        let server = spawn_server(srv, |srv| {
            srv.write_all(&encode_snapshot_begin()).unwrap();
            srv.write_all(&encode_snapshot_chunk(b"first")).unwrap();
            srv.write_all(&encode_frame(0, &argv_for(&[b"PING"]))).unwrap();
        });
        let mut client = ReplicaClient::from_socket_for_test(cli, 0);
        assert!(matches!(client.next_event(), Some(Ok(ReplicaEvent::SnapshotBegin))));
        assert!(matches!(client.next_event(), Some(Ok(ReplicaEvent::SnapshotChunk(_)))));
        assert!(matches!(
            client.next_event(),
            Some(Err(ReplicaError::UnexpectedInSnapshot))
        ));
        server.join().expect("server thread panicked");
    }

    #[test]
    fn next_frame_returns_snapshot_in_progress_when_snapshot_starts() {
        let (srv, cli) = tcp_pair();
        let server = spawn_server(srv, |srv| {
            srv.write_all(&encode_snapshot_begin()).unwrap();
        });
        let mut client = ReplicaClient::from_socket_for_test(cli, 0);
        assert!(matches!(
            client.next_frame(),
            Some(Err(ReplicaError::SnapshotInProgress))
        ));
        server.join().expect("server thread panicked");
    }

    #[test]
    fn live_frame_path_via_next_event() {
        let (srv, cli) = tcp_pair();
        let server = spawn_server(srv, |srv| {
            srv.write_all(&encode_frame(0, &argv_for(&[b"SET", b"a", b"1"]))).unwrap();
            srv.write_all(&encode_frame(1, &argv_for(&[b"SET", b"b", b"2"]))).unwrap();
        });
        let mut client = ReplicaClient::from_socket_for_test(cli, 0);
        for expected_off in 0..2 {
            match client.next_event() {
                Some(Ok(ReplicaEvent::Frame(f))) => assert_eq!(f.offset, expected_off),
                other => panic!("expected Frame {expected_off}, got {other:?}"),
            }
        }
        assert_eq!(client.expected_offset(), 2);
        server.join().expect("server thread panicked");
    }

    #[test]
    fn snapshot_end_with_zero_offset_handled() {
        let (srv, cli) = tcp_pair();
        let server = spawn_server(srv, |srv| {
            srv.write_all(&encode_snapshot_begin()).unwrap();
            srv.write_all(&encode_snapshot_end(0)).unwrap();
        });
        let mut client = ReplicaClient::from_socket_for_test(cli, 0);
        assert!(matches!(client.next_event(), Some(Ok(ReplicaEvent::SnapshotBegin))));
        match client.next_event() {
            Some(Ok(ReplicaEvent::SnapshotEnd { ack_offset })) => assert_eq!(ack_offset, 0),
            other => panic!("expected SnapshotEnd, got {other:?}"),
        }
        assert_eq!(client.expected_offset(), 0);
        server.join().expect("server thread panicked");
    }
}