Skip to main content

allsource_core/infrastructure/resp/
protocol.rs

//! RESP3 protocol parser and serializer.
//!
//! Implements a subset of the Redis RESP3 wire protocol sufficient for:
//! - Parsing client commands (arrays of bulk strings)
//! - Serializing responses (simple strings, errors, integers, bulk strings,
//!   arrays, maps, and null)
//!
//! Reference: <https://github.com/redis/redis-specifications/blob/master/protocol/RESP3.md>

10use std::collections::BTreeMap;
11use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
12
/// A value in the RESP3 protocol.
///
/// Only the subset of RESP3 types this module reads or writes is modelled.
#[derive(Debug, Clone, PartialEq)]
pub enum RespValue {
    /// Simple string, e.g. `+OK\r\n`. The protocol does not allow `\r` or `\n`
    /// inside it.
    SimpleString(String),
    /// Error reply, e.g. `-ERR message\r\n`.
    Error(String),
    /// Signed 64-bit integer, `:<number>\r\n`.
    Integer(i64),
    /// Binary-safe bulk string, `$<len>\r\n<data>\r\n`.
    ///
    /// A RESP2 null bulk string (`$-1\r\n`) is NOT represented by this variant;
    /// the parser maps it to [`RespValue::Null`].
    BulkString(Vec<u8>),
    /// Array of nested values, `*<count>\r\n` followed by `count` elements.
    Array(Vec<RespValue>),
    /// RESP3 map, `%<count>\r\n` followed by `count` key/value pairs.
    ///
    /// Keys are encoded as bulk strings, in the sorted order `BTreeMap`
    /// iterates in. Currently encode-only: the parser does not accept `%`.
    Map(BTreeMap<String, RespValue>),
    /// RESP3 null, `_\r\n`. The parser also yields this for RESP2 null bulk
    /// strings and null arrays (negative lengths such as `$-1` / `*-1`).
    Null,
}
31
32impl RespValue {
33    /// Interpret this value as a UTF-8 string (for command parsing).
34    pub fn as_str(&self) -> Option<&str> {
35        match self {
36            RespValue::BulkString(b) => std::str::from_utf8(b).ok(),
37            RespValue::SimpleString(s) => Some(s.as_str()),
38            _ => None,
39        }
40    }
41
42    /// Encode this value into RESP3 wire format.
43    pub fn encode(&self) -> Vec<u8> {
44        let mut buf = Vec::new();
45        self.encode_into(&mut buf);
46        buf
47    }
48
49    fn encode_into(&self, buf: &mut Vec<u8>) {
50        match self {
51            RespValue::SimpleString(s) => {
52                buf.push(b'+');
53                buf.extend_from_slice(s.as_bytes());
54                buf.extend_from_slice(b"\r\n");
55            }
56            RespValue::Error(s) => {
57                buf.push(b'-');
58                buf.extend_from_slice(s.as_bytes());
59                buf.extend_from_slice(b"\r\n");
60            }
61            RespValue::Integer(n) => {
62                buf.push(b':');
63                buf.extend_from_slice(n.to_string().as_bytes());
64                buf.extend_from_slice(b"\r\n");
65            }
66            RespValue::BulkString(data) => {
67                buf.push(b'$');
68                buf.extend_from_slice(data.len().to_string().as_bytes());
69                buf.extend_from_slice(b"\r\n");
70                buf.extend_from_slice(data);
71                buf.extend_from_slice(b"\r\n");
72            }
73            RespValue::Array(items) => {
74                buf.push(b'*');
75                buf.extend_from_slice(items.len().to_string().as_bytes());
76                buf.extend_from_slice(b"\r\n");
77                for item in items {
78                    item.encode_into(buf);
79                }
80            }
81            RespValue::Map(map) => {
82                buf.push(b'%');
83                buf.extend_from_slice(map.len().to_string().as_bytes());
84                buf.extend_from_slice(b"\r\n");
85                for (k, v) in map {
86                    RespValue::BulkString(k.as_bytes().to_vec()).encode_into(buf);
87                    v.encode_into(buf);
88                }
89            }
90            RespValue::Null => {
91                buf.extend_from_slice(b"_\r\n");
92            }
93        }
94    }
95}
96
97/// Parse a single RESP value from an async buffered reader.
98///
99/// Returns `None` if the connection was closed (EOF).
100pub async fn parse_value<R: AsyncBufRead + Unpin>(
101    reader: &mut R,
102) -> std::io::Result<Option<RespValue>> {
103    let mut line = String::new();
104    let n = reader.read_line(&mut line).await?;
105    if n == 0 {
106        return Ok(None); // EOF
107    }
108    let line = line.trim_end_matches("\r\n").trim_end_matches('\n');
109
110    if line.is_empty() {
111        return Ok(None);
112    }
113
114    let type_byte = line.as_bytes()[0];
115    let rest = &line[1..];
116
117    match type_byte {
118        b'+' => Ok(Some(RespValue::SimpleString(rest.to_string()))),
119        b'-' => Ok(Some(RespValue::Error(rest.to_string()))),
120        b':' => {
121            let n: i64 = rest.parse().map_err(|_| {
122                std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid integer")
123            })?;
124            Ok(Some(RespValue::Integer(n)))
125        }
126        b'$' => {
127            let len: i64 = rest.parse().map_err(|_| {
128                std::io::Error::new(
129                    std::io::ErrorKind::InvalidData,
130                    "invalid bulk string length",
131                )
132            })?;
133            if len < 0 {
134                return Ok(Some(RespValue::Null));
135            }
136            let len = len as usize;
137            let mut data = vec![0u8; len + 2]; // +2 for trailing \r\n
138            reader.read_exact(&mut data).await?;
139            data.truncate(len); // remove \r\n
140            Ok(Some(RespValue::BulkString(data)))
141        }
142        b'*' => {
143            let count: i64 = rest.parse().map_err(|_| {
144                std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid array length")
145            })?;
146            if count < 0 {
147                return Ok(Some(RespValue::Null));
148            }
149            let mut items = Vec::with_capacity(count as usize);
150            for _ in 0..count {
151                match Box::pin(parse_value(reader)).await? {
152                    Some(v) => items.push(v),
153                    None => {
154                        return Err(std::io::Error::new(
155                            std::io::ErrorKind::UnexpectedEof,
156                            "unexpected EOF in array",
157                        ));
158                    }
159                }
160            }
161            Ok(Some(RespValue::Array(items)))
162        }
163        b'_' => Ok(Some(RespValue::Null)),
164        _ => Err(std::io::Error::new(
165            std::io::ErrorKind::InvalidData,
166            format!("unknown RESP type byte: {}", type_byte as char),
167        )),
168    }
169}
170
171/// Write a RESP value to an async writer.
172pub async fn write_value<W: tokio::io::AsyncWrite + Unpin>(
173    writer: &mut W,
174    value: &RespValue,
175) -> std::io::Result<()> {
176    writer.write_all(&value.encode()).await?;
177    writer.flush().await
178}
179
180// ── Helper constructors ─────────────────────────────────────────────────────
181
182impl RespValue {
183    pub fn ok() -> Self {
184        RespValue::SimpleString("OK".to_string())
185    }
186
187    pub fn err(msg: impl Into<String>) -> Self {
188        RespValue::Error(format!("ERR {}", msg.into()))
189    }
190
191    pub fn bulk(s: impl Into<Vec<u8>>) -> Self {
192        RespValue::BulkString(s.into())
193    }
194
195    pub fn bulk_string(s: &str) -> Self {
196        RespValue::BulkString(s.as_bytes().to_vec())
197    }
198
199    pub fn integer(n: i64) -> Self {
200        RespValue::Integer(n)
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::BufReader;

    /// Run `parse_value` over `input`, unwrapping the I/O result but keeping
    /// the `Option` so EOF can be asserted on.
    async fn parse_bytes(input: &[u8]) -> Option<RespValue> {
        let mut reader = BufReader::new(input);
        parse_value(&mut reader).await.unwrap()
    }

    /// Shorthand for a bulk string built from raw bytes.
    fn bulk_of(bytes: &[u8]) -> RespValue {
        RespValue::BulkString(bytes.to_vec())
    }

    #[test]
    fn test_encode_simple_string() {
        let encoded = RespValue::SimpleString(String::from("OK")).encode();
        assert_eq!(encoded, b"+OK\r\n");
    }

    #[test]
    fn test_encode_error() {
        let encoded = RespValue::Error(String::from("ERR bad command")).encode();
        assert_eq!(encoded, b"-ERR bad command\r\n");
    }

    #[test]
    fn test_encode_integer() {
        assert_eq!(RespValue::Integer(42).encode(), b":42\r\n");
    }

    #[test]
    fn test_encode_bulk_string() {
        assert_eq!(bulk_of(b"hello").encode(), b"$5\r\nhello\r\n");
    }

    #[test]
    fn test_encode_null() {
        assert_eq!(RespValue::Null.encode(), b"_\r\n");
    }

    #[test]
    fn test_encode_array() {
        let command = RespValue::Array(vec![bulk_of(b"SET"), bulk_of(b"key")]);
        assert_eq!(command.encode(), b"*2\r\n$3\r\nSET\r\n$3\r\nkey\r\n");
    }

    #[test]
    fn test_encode_map() {
        let entries: BTreeMap<String, RespValue> =
            [(String::from("key"), bulk_of(b"value"))].into_iter().collect();
        assert_eq!(
            RespValue::Map(entries).encode(),
            b"%1\r\n$3\r\nkey\r\n$5\r\nvalue\r\n"
        );
    }

    #[tokio::test]
    async fn test_parse_simple_string() {
        let parsed = parse_bytes(b"+OK\r\n").await.unwrap();
        assert_eq!(parsed, RespValue::SimpleString(String::from("OK")));
    }

    #[tokio::test]
    async fn test_parse_bulk_string() {
        let parsed = parse_bytes(b"$5\r\nhello\r\n").await.unwrap();
        assert_eq!(parsed, bulk_of(b"hello"));
    }

    #[tokio::test]
    async fn test_parse_array() {
        let parsed = parse_bytes(b"*2\r\n$4\r\nPING\r\n$5\r\nhello\r\n")
            .await
            .unwrap();
        assert_eq!(
            parsed,
            RespValue::Array(vec![bulk_of(b"PING"), bulk_of(b"hello")])
        );
    }

    #[tokio::test]
    async fn test_parse_integer() {
        let parsed = parse_bytes(b":1000\r\n").await.unwrap();
        assert_eq!(parsed, RespValue::Integer(1000));
    }

    #[tokio::test]
    async fn test_parse_null_bulk_string() {
        let parsed = parse_bytes(b"$-1\r\n").await.unwrap();
        assert_eq!(parsed, RespValue::Null);
    }

    #[tokio::test]
    async fn test_parse_eof() {
        assert!(parse_bytes(b"").await.is_none());
    }

    #[tokio::test]
    async fn test_roundtrip_array() {
        let parts: [&[u8]; 5] = [b"XADD", b"stream", b"*", b"field", b"value"];
        let original = RespValue::Array(parts.iter().copied().map(bulk_of).collect());
        let wire = original.encode();
        let parsed = parse_bytes(&wire).await.unwrap();
        assert_eq!(parsed, original);
    }

    #[test]
    fn test_as_str() {
        assert_eq!(RespValue::bulk_string("hello").as_str(), Some("hello"));
        let simple = RespValue::SimpleString(String::from("world"));
        assert_eq!(simple.as_str(), Some("world"));
        assert_eq!(RespValue::Integer(42).as_str(), None);
    }
}