1use kevy_resp::{Argv, ArgvView, ProtocolError, parse_command_into};
13
14pub use crate::wire_snapshot::{
18 SNAPSHOT_CHUNK_MAX, SNAPSHOT_LINE_MAX, SnapshotMarker, decode_snapshot_chunk,
19 decode_snapshot_marker, encode_snapshot_begin, encode_snapshot_chunk, encode_snapshot_end,
20};
21
/// Failure modes of [`decode_frame`].
///
/// `Truncated` signals that the buffer ended before a complete frame could be
/// parsed; the remaining variants describe bytes that are present but do not
/// form a valid frame.
#[derive(Debug)]
pub enum WireError {
    /// The buffer ends before a complete frame could be parsed (e.g. a header
    /// line has no CRLF yet, or the inner command is incomplete).
    Truncated,
    /// The frame does not start with a `*` array header whose count is
    /// exactly `2`.
    BadEnvelope,
    /// The first envelope element is not a `:`-prefixed integer, or its
    /// digits do not parse as an `i64` (empty, non-digit, or out of range).
    BadOffset,
    /// The offset parsed as an integer but is negative; carries that value.
    NegativeOffset(i64),
    /// The inner command failed to parse; carries the parser's error.
    BadPayload(ProtocolError),
}
41
42impl std::fmt::Display for WireError {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 match self {
45 Self::Truncated => write!(f, "wire frame truncated"),
46 Self::BadEnvelope => write!(f, "wire envelope not *2"),
47 Self::BadOffset => write!(f, "wire offset element not RESP integer"),
48 Self::NegativeOffset(n) => write!(f, "wire offset is negative: {n}"),
49 Self::BadPayload(e) => write!(f, "wire inner payload malformed: {e:?}"),
50 }
51 }
52}
53
/// Marker impl so `WireError` can be used wherever a `std::error::Error` is
/// expected (e.g. boxed as `dyn Error`). `source()` is not overridden, so the
/// `ProtocolError` inside `BadPayload` is not exposed through it.
impl std::error::Error for WireError {}
55
56impl PartialEq for WireError {
57 fn eq(&self, other: &Self) -> bool {
58 core::mem::discriminant(self) == core::mem::discriminant(other)
63 }
64}
65
66pub fn encode_frame<A: ArgvView + ?Sized>(offset: u64, argv: &A) -> Vec<u8> {
79 debug_assert!(
80 offset <= i64::MAX as u64,
81 "replication offset {offset} exceeds i64::MAX — wire envelope cannot encode",
82 );
83 let est = 32 + argv_byte_estimate_view(argv);
87 let mut out = Vec::with_capacity(est);
88 out.extend_from_slice(b"*2\r\n");
90 out.push(b':');
92 push_u64(&mut out, offset);
93 out.extend_from_slice(b"\r\n");
94 let n = argv.len();
97 out.push(b'*');
98 push_u64(&mut out, n as u64);
99 out.extend_from_slice(b"\r\n");
100 for i in 0..n {
101 let arg = &argv[i];
102 out.push(b'$');
103 push_u64(&mut out, arg.len() as u64);
104 out.extend_from_slice(b"\r\n");
105 out.extend_from_slice(arg);
106 out.extend_from_slice(b"\r\n");
107 }
108 out
109}
110
111pub fn decode_frame(buf: &[u8]) -> Result<(u64, Argv, usize), WireError> {
119 let after_env = parse_envelope_header(buf)?;
121 let (offset, after_offset) = parse_offset_line(buf, after_env)?;
123 let inner = &buf[after_offset..];
125 let mut argv = Argv::default();
126 let consumed_inner = match parse_command_into(inner, &mut argv) {
127 Ok(Some(n)) => n,
128 Ok(None) => return Err(WireError::Truncated),
129 Err(e) => return Err(WireError::BadPayload(e)),
130 };
131 Ok((offset, argv, after_offset + consumed_inner))
132}
133
134fn parse_envelope_header(buf: &[u8]) -> Result<usize, WireError> {
137 if buf.len() < 4 {
139 return Err(WireError::Truncated);
140 }
141 if buf[0] != b'*' {
142 return Err(WireError::BadEnvelope);
143 }
144 let eol = find_crlf(buf, 1).ok_or(WireError::Truncated)?;
145 let count = parse_decimal(&buf[1..eol]).ok_or(WireError::BadEnvelope)?;
146 if count != 2 {
147 return Err(WireError::BadEnvelope);
148 }
149 Ok(eol + 2)
150}
151
152fn parse_offset_line(buf: &[u8], start: usize) -> Result<(u64, usize), WireError> {
154 if start >= buf.len() {
155 return Err(WireError::Truncated);
156 }
157 if buf[start] != b':' {
158 return Err(WireError::BadOffset);
159 }
160 let eol = find_crlf(buf, start + 1).ok_or(WireError::Truncated)?;
161 let raw = &buf[start + 1..eol];
162 let signed = parse_signed_decimal(raw).ok_or(WireError::BadOffset)?;
165 if signed < 0 {
166 return Err(WireError::NegativeOffset(signed));
167 }
168 Ok((signed as u64, eol + 2))
169}
170
/// Returns the index of the first `\r\n` pair that starts at or after `from`,
/// or `None` if there is none (including when `from` is past the end).
pub(crate) fn find_crlf(buf: &[u8], from: usize) -> Option<usize> {
    let tail = buf.get(from..)?;
    tail.windows(2)
        .position(|pair| pair == b"\r\n")
        .map(|rel| from + rel)
}
183
/// Parses a non-empty run of ASCII digits as a `u64`.
///
/// Returns `None` for an empty slice, any non-digit byte (signs included),
/// or a value that overflows `u64`. Leading zeros are accepted.
pub(crate) fn parse_decimal(bytes: &[u8]) -> Option<u64> {
    if bytes.is_empty() {
        return None;
    }
    bytes.iter().try_fold(0u64, |acc, &digit| {
        if digit.is_ascii_digit() {
            acc.checked_mul(10)?.checked_add(u64::from(digit - b'0'))
        } else {
            None
        }
    })
}
198
199fn parse_signed_decimal(bytes: &[u8]) -> Option<i64> {
202 if bytes.is_empty() {
203 return None;
204 }
205 let (neg, digits) = match bytes[0] {
206 b'-' => (true, &bytes[1..]),
207 b'+' => (false, &bytes[1..]),
208 _ => (false, bytes),
209 };
210 let n = parse_decimal(digits)?;
211 if neg {
212 if n > (i64::MAX as u64) + 1 {
215 return None;
216 }
217 if n == (i64::MAX as u64) + 1 {
218 return Some(i64::MIN);
219 }
220 Some(-(n as i64))
221 } else {
222 if n > i64::MAX as u64 {
223 return None;
224 }
225 Some(n as i64)
226 }
227}
228
/// Appends the decimal ASCII representation of `n` to `out`, without going
/// through `fmt` and without allocating.
pub(crate) fn push_u64(out: &mut Vec<u8>, n: u64) {
    // u64::MAX has 20 decimal digits, so this buffer always suffices.
    let mut digits = [0u8; 20];
    let mut start = digits.len();
    let mut rest = n;
    // Do-while shape: always emits at least one digit, which covers n == 0.
    loop {
        start -= 1;
        digits[start] = b'0' + (rest % 10) as u8;
        rest /= 10;
        if rest == 0 {
            break;
        }
    }
    out.extend_from_slice(&digits[start..]);
}
246
247fn argv_byte_estimate_view<A: ArgvView + ?Sized>(argv: &A) -> usize {
251 let mut bytes = 0usize;
254 for i in 0..argv.len() {
255 bytes += argv[i].len();
256 }
257 bytes + argv.len() * 10
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `Argv` by pushing each argument in order.
    fn argv_from(args: &[&[u8]]) -> Argv {
        let mut a = Argv::default();
        for arg in args {
            a.push(arg);
        }
        a
    }

    /// A simple command survives encode -> decode unchanged, and the decoder
    /// reports consuming exactly the encoded length.
    #[test]
    fn roundtrip_simple_set() {
        let argv = argv_from(&[b"SET", b"foo", b"bar"]);
        let bytes = encode_frame(42, &argv);
        let (offset, decoded, used) = decode_frame(&bytes).expect("decode");
        assert_eq!(offset, 42);
        assert_eq!(decoded, argv);
        assert_eq!(used, bytes.len());
    }

    /// Boundary offsets round-trip, including `i64::MAX`, the largest value
    /// the signed offset parse on the decode side accepts.
    #[test]
    fn roundtrip_offset_zero_and_max() {
        for offset in [0u64, 1, i64::MAX as u64] {
            let argv = argv_from(&[b"PING"]);
            let bytes = encode_frame(offset, &argv);
            let (back, _, _) = decode_frame(&bytes).expect("decode");
            assert_eq!(back, offset);
        }
    }

    /// Offsets above `i64::MAX` trip `encode_frame`'s `debug_assert!`. That
    /// check is compiled out of release builds, hence the
    /// `cfg(debug_assertions)` gate on this test.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "exceeds i64::MAX")]
    fn encoding_offset_above_i64_max_panics_in_debug() {
        let argv = argv_from(&[b"PING"]);
        let _ = encode_frame(u64::MAX, &argv);
    }

    /// An argument containing every byte value (CR and LF included) and an
    /// empty argument are both carried verbatim.
    #[test]
    fn roundtrip_argv_with_binary_and_empty_args() {
        let bin: Vec<u8> = (0u8..=255).collect();
        let argv = argv_from(&[b"HSET", b"key", b"field", &bin, b""]);
        let bytes = encode_frame(7, &argv);
        let (_, decoded, _) = decode_frame(&bytes).expect("decode");
        assert_eq!(decoded.len(), 5);
        assert_eq!(decoded.get(3), Some(bin.as_slice()));
        assert_eq!(decoded.get(4), Some(&b""[..]));
    }

    /// `decode_frame` consumes exactly one frame, so back-to-back frames in
    /// one buffer are decoded by advancing over the reported `used` bytes.
    #[test]
    fn two_concatenated_frames_decode_in_order() {
        let a = encode_frame(1, &argv_from(&[b"SET", b"k", b"a"]));
        let b = encode_frame(2, &argv_from(&[b"DEL", b"k"]));
        let mut buf = a.clone();
        buf.extend_from_slice(&b);

        let (off1, argv1, used1) = decode_frame(&buf).expect("frame 1");
        assert_eq!(off1, 1);
        assert_eq!(argv1, argv_from(&[b"SET", b"k", b"a"]));
        assert_eq!(used1, a.len());

        let (off2, argv2, used2) = decode_frame(&buf[used1..]).expect("frame 2");
        assert_eq!(off2, 2);
        assert_eq!(argv2, argv_from(&[b"DEL", b"k"]));
        assert_eq!(used1 + used2, buf.len());
    }

    /// Sixteen frames encoded with ascending offsets decode one at a time into
    /// the same ascending sequence, and together consume the whole buffer.
    #[test]
    fn offsets_are_strictly_increasing_when_emitted_in_order() {
        let mut bytes = Vec::new();
        for o in 0u64..16 {
            bytes.extend(encode_frame(o, &argv_from(&[b"PING"])));
        }
        let mut pos = 0;
        let mut last: Option<u64> = None;
        while pos < bytes.len() {
            let (offset, _, used) = decode_frame(&bytes[pos..]).expect("decode");
            if let Some(prev) = last {
                assert!(offset > prev, "offset {offset} not > prev {prev}");
            }
            last = Some(offset);
            pos += used;
        }
        assert_eq!(last, Some(15));
        assert_eq!(pos, bytes.len());
    }

    /// Each prefix that stops before the frame is complete reports
    /// `Truncated` (need more bytes), never a malformed-frame error.
    #[test]
    fn truncated_envelope_is_truncated_not_bad() {
        // Empty buffer.
        assert_eq!(decode_frame(&[]), Err(WireError::Truncated));
        // Envelope type byte only; no count or CRLF yet.
        assert_eq!(decode_frame(b"*"), Err(WireError::Truncated));
        // Envelope header complete; offset element missing.
        assert_eq!(decode_frame(b"*2\r\n"), Err(WireError::Truncated));
        // Offset digits without their terminating CRLF.
        assert_eq!(decode_frame(b"*2\r\n:42"), Err(WireError::Truncated));
        // Offset complete; inner command missing.
        assert_eq!(decode_frame(b"*2\r\n:42\r\n"), Err(WireError::Truncated));
        // Inner bulk string declares 3 bytes but only 2 are present.
        assert_eq!(decode_frame(b"*2\r\n:42\r\n*1\r\n$3\r\nfo"), Err(WireError::Truncated));
    }

    /// Envelopes with a count other than 2 are rejected.
    #[test]
    fn wrong_envelope_count_rejected() {
        // One element: too few.
        let bad = b"*1\r\n:42\r\n";
        assert!(matches!(decode_frame(bad), Err(WireError::BadEnvelope)));
        // Three elements: too many.
        let bad3 = b"*3\r\n:42\r\n*0\r\n:0\r\n";
        assert!(matches!(decode_frame(bad3), Err(WireError::BadEnvelope)));
    }

    /// A frame that starts with a bare integer instead of the `*2` array
    /// header is rejected.
    #[test]
    fn non_array_envelope_rejected() {
        let bad = b":42\r\n*1\r\n$4\r\nPING\r\n";
        assert!(matches!(decode_frame(bad), Err(WireError::BadEnvelope)));
    }

    /// An offset sent as a bulk string (`$2\r\n42`) rather than a RESP
    /// integer is rejected.
    #[test]
    fn offset_not_integer_rejected() {
        let bad = b"*2\r\n$2\r\n42\r\n*1\r\n$4\r\nPING\r\n";
        assert!(matches!(decode_frame(bad), Err(WireError::BadOffset)));
    }

    /// A negative offset is rejected and the error carries the parsed value.
    #[test]
    fn negative_offset_rejected_with_value() {
        let bad = b"*2\r\n:-7\r\n*1\r\n$4\r\nPING\r\n";
        match decode_frame(bad) {
            Err(WireError::NegativeOffset(n)) => assert_eq!(n, -7),
            other => panic!("expected NegativeOffset, got {other:?}"),
        }
    }

    /// Errors from the inner command parser surface as `BadPayload`.
    /// NOTE(review): relies on `!` not being accepted where a bulk string is
    /// expected by `parse_command_into`.
    #[test]
    fn malformed_inner_payload_surfaces_bad_payload() {
        let bad = b"*2\r\n:1\r\n*1\r\n!nope\r\n";
        assert!(matches!(decode_frame(bad), Err(WireError::BadPayload(_))));
    }

    /// 21 nines exceed `u64::MAX` (20 digits), so the offset parse overflows
    /// and the frame is rejected as `BadOffset`.
    #[test]
    fn offset_with_extra_digits_overflow_rejected() {
        let mut bad = b"*2\r\n:".to_vec();
        bad.extend(std::iter::repeat_n(b'9', 21));
        bad.extend_from_slice(b"\r\n*1\r\n$4\r\nPING\r\n");
        assert!(matches!(decode_frame(&bad), Err(WireError::BadOffset)));
    }

    /// Pins the exact byte layout: `*2` envelope, `:`-prefixed offset, then
    /// the command as a RESP array of bulk strings.
    #[test]
    fn encoded_bytes_are_exactly_what_spec_says() {
        let argv = argv_from(&[b"SET", b"foo", b"bar"]);
        let bytes = encode_frame(99, &argv);
        let expected =
            b"*2\r\n:99\r\n*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n";
        assert_eq!(bytes, expected);
    }
}