Skip to main content

pylon_runtime/
resp.rs

//! RESP (Redis Serialization Protocol) parser and serializer.
//!
//! Implements RESP2, the wire protocol used by Redis, so that pylon's cache
//! can be accessed by any standard Redis client library.
//!
//! # Wire format
//!
//! - Simple strings: `+OK\r\n`
//! - Errors:         `-ERR message\r\n`
//! - Integers:       `:1000\r\n`
//! - Bulk strings:   `$5\r\nhello\r\n`  (length-prefixed; payloads must be UTF-8 here)
//! - Arrays:         `*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n`
//! - Nulls:          `$-1\r\n` (null bulk string), `*-1\r\n` (null array)
14
use std::io::{BufRead, Read};
16
/// A RESP2 value.
///
/// `None` inside [`RespValue::BulkString`] and [`RespValue::Array`] encodes the
/// protocol's null bulk string (`$-1\r\n`) and null array (`*-1\r\n`).
#[derive(Debug, Clone, PartialEq)]
pub enum RespValue {
    /// Single-line status reply, e.g. `+OK\r\n`.
    SimpleString(String),
    /// Single-line error reply, e.g. `-ERR message\r\n`.
    Error(String),
    /// Signed 64-bit integer reply, e.g. `:1000\r\n`.
    Integer(i64),
    /// Length-prefixed string; `None` is the null bulk string.
    BulkString(Option<String>),
    /// Sequence of values; `None` is the null array.
    Array(Option<Vec<RespValue>>),
}

impl RespValue {
    /// Serialize to RESP wire format.
    ///
    /// Simple strings and errors are framed only by their trailing `\r\n`, so
    /// any `\r` or `\n` inside their text is replaced with a space (Redis does
    /// the same for error replies). Without this, text such as a client-supplied
    /// key echoed in an error message could end the frame early and inject
    /// extra replies into the stream. Bulk strings are length-prefixed and are
    /// written verbatim.
    pub fn serialize(&self) -> Vec<u8> {
        let mut out = Vec::new();
        self.serialize_into(&mut out);
        out
    }

    /// Append the RESP encoding of `self` to `out`.
    ///
    /// Produces exactly the bytes [`RespValue::serialize`] returns. Useful for
    /// batching several replies into one buffer. Nested arrays are written
    /// straight into `out` rather than through a temporary `Vec` per element.
    pub fn serialize_into(&self, out: &mut Vec<u8>) {
        match self {
            RespValue::SimpleString(s) => push_single_line(out, b'+', s),
            RespValue::Error(s) => push_single_line(out, b'-', s),
            RespValue::Integer(n) => push_header(out, b':', n),
            RespValue::BulkString(None) => out.extend_from_slice(b"$-1\r\n"),
            RespValue::BulkString(Some(s)) => {
                // The length prefix counts bytes, not chars.
                push_header(out, b'$', s.len());
                out.extend_from_slice(s.as_bytes());
                out.extend_from_slice(b"\r\n");
            }
            RespValue::Array(None) => out.extend_from_slice(b"*-1\r\n"),
            RespValue::Array(Some(items)) => {
                push_header(out, b'*', items.len());
                for item in items {
                    item.serialize_into(out);
                }
            }
        }
    }

    /// Create a bulk string value (copies `s`).
    pub fn bulk(s: &str) -> Self {
        RespValue::BulkString(Some(s.to_string()))
    }

    /// Create a null bulk string.
    pub fn null() -> Self {
        RespValue::BulkString(None)
    }

    /// Create a simple string "OK".
    pub fn ok() -> Self {
        RespValue::SimpleString("OK".to_string())
    }

    /// Create an integer value.
    pub fn int(n: i64) -> Self {
        RespValue::Integer(n)
    }

    /// Create an error value with the standard "ERR" prefix.
    ///
    /// Any `\r`/`\n` in `msg` is replaced with a space when the value is
    /// serialized, so `msg` may safely contain client-supplied text.
    pub fn err(msg: &str) -> Self {
        RespValue::Error(format!("ERR {msg}"))
    }

    /// Create an array value.
    pub fn array(items: Vec<RespValue>) -> Self {
        RespValue::Array(Some(items))
    }
}

/// Write `prefix`, then `text` with every `\r`/`\n` replaced by a space, then `\r\n`.
///
/// Replacing byte-by-byte is UTF-8 safe: `\r` and `\n` are ASCII and never
/// occur inside a multi-byte sequence.
fn push_single_line(out: &mut Vec<u8>, prefix: u8, text: &str) {
    out.push(prefix);
    out.extend(
        text.bytes()
            .map(|b| if matches!(b, b'\r' | b'\n') { b' ' } else { b }),
    );
    out.extend_from_slice(b"\r\n");
}

/// Write `prefix`, then the decimal form of `n`, then `\r\n`.
fn push_header(out: &mut Vec<u8>, prefix: u8, n: impl std::fmt::Display) {
    out.push(prefix);
    out.extend_from_slice(n.to_string().as_bytes());
    out.extend_from_slice(b"\r\n");
}
82
83/// Parse a single RESP value from a buffered reader.
84///
85/// Returns `Err` on I/O errors, malformed input, or EOF.
86pub fn parse_resp<R: BufRead>(reader: &mut R) -> Result<RespValue, String> {
87    let mut line = String::new();
88    let bytes_read = reader
89        .read_line(&mut line)
90        .map_err(|e| format!("Read error: {e}"))?;
91
92    if bytes_read == 0 {
93        return Err("Connection closed".into());
94    }
95
96    // Must have at least type byte + \r\n.
97    if line.len() < 3 || !line.ends_with("\r\n") {
98        return Err(format!("Malformed RESP line: {:?}", line));
99    }
100
101    let content = &line[1..line.len() - 2]; // strip type byte and \r\n
102
103    match line.as_bytes()[0] {
104        b'+' => Ok(RespValue::SimpleString(content.to_string())),
105        b'-' => Ok(RespValue::Error(content.to_string())),
106        b':' => {
107            let n: i64 = content
108                .parse()
109                .map_err(|_| format!("Invalid integer: {content:?}"))?;
110            Ok(RespValue::Integer(n))
111        }
112        b'$' => {
113            let len: i64 = content
114                .parse()
115                .map_err(|_| format!("Invalid bulk length: {content:?}"))?;
116            if len < 0 {
117                return Ok(RespValue::BulkString(None));
118            }
119            let len = len as usize;
120            let mut buf = vec![0u8; len + 2]; // data + trailing \r\n
121            reader
122                .read_exact(&mut buf)
123                .map_err(|e| format!("Read error: {e}"))?;
124            if buf[len] != b'\r' || buf[len + 1] != b'\n' {
125                return Err("Missing \\r\\n after bulk string data".into());
126            }
127            let s = String::from_utf8(buf[..len].to_vec())
128                .map_err(|_| "Invalid UTF-8 in bulk string")?;
129            Ok(RespValue::BulkString(Some(s)))
130        }
131        b'*' => {
132            let count: i64 = content
133                .parse()
134                .map_err(|_| format!("Invalid array length: {content:?}"))?;
135            if count < 0 {
136                return Ok(RespValue::Array(None));
137            }
138            let mut items = Vec::with_capacity(count as usize);
139            for _ in 0..count {
140                items.push(parse_resp(reader)?);
141            }
142            Ok(RespValue::Array(Some(items)))
143        }
144        other => Err(format!("Unknown RESP type byte: {:?}", other as char)),
145    }
146}
147
148// ---------------------------------------------------------------------------
149// Tests
150// ---------------------------------------------------------------------------
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::BufReader;

    /// Helper: parse a single RESP value from raw bytes.
    fn parse(input: &[u8]) -> Result<RespValue, String> {
        let mut reader = BufReader::new(input);
        parse_resp(&mut reader)
    }

    /// Helper: serialize then re-parse, asserting roundtrip equality.
    fn roundtrip(value: &RespValue) {
        let bytes = value.serialize();
        let parsed = parse(&bytes).expect("roundtrip parse failed");
        assert_eq!(&parsed, value, "roundtrip mismatch");
    }

    // -- Simple strings --

    #[test]
    fn parse_simple_string() {
        let val = parse(b"+OK\r\n").unwrap();
        assert_eq!(val, RespValue::SimpleString("OK".into()));
    }

    #[test]
    fn parse_empty_simple_string() {
        let val = parse(b"+\r\n").unwrap();
        assert_eq!(val, RespValue::SimpleString(String::new()));
    }

    #[test]
    fn serialize_simple_string() {
        let val = RespValue::SimpleString("hello world".into());
        assert_eq!(val.serialize(), b"+hello world\r\n");
    }

    #[test]
    fn roundtrip_simple_string() {
        roundtrip(&RespValue::SimpleString("PONG".into()));
        roundtrip(&RespValue::ok());
    }

    // -- Errors --

    #[test]
    fn parse_error() {
        let val = parse(b"-ERR unknown command\r\n").unwrap();
        assert_eq!(val, RespValue::Error("ERR unknown command".into()));
    }

    #[test]
    fn serialize_error() {
        let val = RespValue::err("bad key");
        assert_eq!(val.serialize(), b"-ERR bad key\r\n");
    }

    #[test]
    fn roundtrip_error() {
        roundtrip(&RespValue::err("something went wrong"));
    }

    // -- Integers --

    #[test]
    fn parse_integer() {
        assert_eq!(parse(b":1000\r\n").unwrap(), RespValue::Integer(1000));
        assert_eq!(parse(b":-42\r\n").unwrap(), RespValue::Integer(-42));
        assert_eq!(parse(b":0\r\n").unwrap(), RespValue::Integer(0));
    }

    #[test]
    fn serialize_integer() {
        assert_eq!(RespValue::int(99).serialize(), b":99\r\n");
        assert_eq!(RespValue::int(-1).serialize(), b":-1\r\n");
    }

    #[test]
    fn roundtrip_integer() {
        roundtrip(&RespValue::int(0));
        roundtrip(&RespValue::int(i64::MAX));
        roundtrip(&RespValue::int(i64::MIN));
    }

    // -- Bulk strings --

    #[test]
    fn parse_bulk_string() {
        let val = parse(b"$5\r\nhello\r\n").unwrap();
        assert_eq!(val, RespValue::BulkString(Some("hello".into())));
    }

    #[test]
    fn parse_null_bulk_string() {
        let val = parse(b"$-1\r\n").unwrap();
        assert_eq!(val, RespValue::BulkString(None));
    }

    #[test]
    fn parse_empty_bulk_string() {
        let val = parse(b"$0\r\n\r\n").unwrap();
        assert_eq!(val, RespValue::BulkString(Some(String::new())));
    }

    #[test]
    fn serialize_bulk_string() {
        assert_eq!(RespValue::bulk("foo").serialize(), b"$3\r\nfoo\r\n");
    }

    #[test]
    fn serialize_null_bulk_string() {
        assert_eq!(RespValue::null().serialize(), b"$-1\r\n");
    }

    #[test]
    fn serialize_empty_bulk_string() {
        assert_eq!(
            RespValue::BulkString(Some(String::new())).serialize(),
            b"$0\r\n\r\n"
        );
    }

    #[test]
    fn roundtrip_bulk_string() {
        roundtrip(&RespValue::bulk("hello"));
        roundtrip(&RespValue::null());
        roundtrip(&RespValue::BulkString(Some(String::new())));
    }

    #[test]
    fn bulk_string_may_contain_crlf() {
        // Bulk strings are length-prefixed, so embedded CRLF is just data.
        let val = RespValue::bulk("a\r\nb");
        assert_eq!(val.serialize(), b"$4\r\na\r\nb\r\n");
        roundtrip(&val);
    }

    #[test]
    fn bulk_string_length_counts_bytes_not_chars() {
        let val = RespValue::bulk("h\u{e9}llo");
        assert_eq!(val.serialize(), "$6\r\nh\u{e9}llo\r\n".as_bytes());
        roundtrip(&val);
    }

    #[test]
    fn large_bulk_string() {
        let large = "x".repeat(100_000);
        let val = RespValue::bulk(&large);
        roundtrip(&val);
    }

    // -- Arrays --

    #[test]
    fn parse_array() {
        let input = b"*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n";
        let val = parse(input).unwrap();
        assert_eq!(
            val,
            RespValue::Array(Some(vec![RespValue::bulk("foo"), RespValue::bulk("bar")]))
        );
    }

    #[test]
    fn parse_null_array() {
        let val = parse(b"*-1\r\n").unwrap();
        assert_eq!(val, RespValue::Array(None));
    }

    #[test]
    fn parse_empty_array() {
        let val = parse(b"*0\r\n").unwrap();
        assert_eq!(val, RespValue::Array(Some(vec![])));
    }

    #[test]
    fn serialize_array() {
        let val = RespValue::array(vec![RespValue::bulk("a"), RespValue::int(1)]);
        let bytes = val.serialize();
        assert_eq!(bytes, b"*2\r\n$1\r\na\r\n:1\r\n");
    }

    #[test]
    fn serialize_null_array() {
        assert_eq!(RespValue::Array(None).serialize(), b"*-1\r\n");
    }

    #[test]
    fn roundtrip_array() {
        roundtrip(&RespValue::array(vec![
            RespValue::bulk("SET"),
            RespValue::bulk("key"),
            RespValue::bulk("value"),
        ]));
        roundtrip(&RespValue::Array(None));
        roundtrip(&RespValue::array(vec![]));
    }

    #[test]
    fn nested_arrays() {
        let inner = RespValue::array(vec![RespValue::int(1), RespValue::int(2)]);
        let outer = RespValue::array(vec![inner, RespValue::bulk("end")]);
        roundtrip(&outer);
    }

    #[test]
    fn deeply_nested_arrays() {
        let mut val = RespValue::int(42);
        for _ in 0..10 {
            val = RespValue::array(vec![val]);
        }
        roundtrip(&val);
    }

    // -- Mixed types in arrays --

    #[test]
    fn mixed_type_array() {
        let val = RespValue::array(vec![
            RespValue::SimpleString("OK".into()),
            RespValue::err("bad"),
            RespValue::int(42),
            RespValue::bulk("hello"),
            RespValue::null(),
        ]);
        roundtrip(&val);
    }

    // -- Error cases --

    #[test]
    fn empty_input() {
        assert!(parse(b"").is_err());
    }

    #[test]
    fn unknown_type_byte() {
        assert!(parse(b"x\r\n").is_err());
    }

    #[test]
    fn malformed_line() {
        // Missing \r before \n, and a bare terminator with no type byte.
        assert!(parse(b"+OK\n").is_err());
        assert!(parse(b"\r\n").is_err());
    }

    #[test]
    fn invalid_integer() {
        assert!(parse(b":notanumber\r\n").is_err());
    }

    #[test]
    fn invalid_bulk_length() {
        assert!(parse(b"$abc\r\n").is_err());
    }

    #[test]
    fn truncated_bulk_string() {
        // Says length 10 but only provides 3 bytes.
        assert!(parse(b"$10\r\nfoo\r\n").is_err());
    }

    #[test]
    fn missing_crlf_after_bulk_data() {
        assert!(parse(b"$3\r\nfooXY").is_err());
    }

    #[test]
    fn invalid_utf8_in_bulk_string() {
        assert!(parse(b"$2\r\n\xff\xfe\r\n").is_err());
    }

    #[test]
    fn truncated_array() {
        // Declares two elements but the stream ends after one.
        assert!(parse(b"*2\r\n$3\r\nfoo\r\n").is_err());
    }

    // -- Helpers --

    #[test]
    fn helper_constructors() {
        assert_eq!(RespValue::ok(), RespValue::SimpleString("OK".into()));
        assert_eq!(RespValue::null(), RespValue::BulkString(None));
        assert_eq!(RespValue::int(5), RespValue::Integer(5));
        assert_eq!(RespValue::err("fail"), RespValue::Error("ERR fail".into()));
        assert_eq!(
            RespValue::array(vec![RespValue::int(1)]),
            RespValue::Array(Some(vec![RespValue::Integer(1)]))
        );
    }

    // -- Multiple values in sequence --

    #[test]
    fn parse_multiple_values_from_stream() {
        let input = b"+OK\r\n:42\r\n$5\r\nhello\r\n";
        let mut reader = BufReader::new(&input[..]);

        let v1 = parse_resp(&mut reader).unwrap();
        assert_eq!(v1, RespValue::SimpleString("OK".into()));

        let v2 = parse_resp(&mut reader).unwrap();
        assert_eq!(v2, RespValue::Integer(42));

        let v3 = parse_resp(&mut reader).unwrap();
        assert_eq!(v3, RespValue::bulk("hello"));
    }
}