Skip to main content

ember_protocol/
command.rs

1//! Command parsing from RESP3 frames.
2//!
3//! Converts a parsed [`Frame`] (expected to be an array) into a typed
4//! [`Command`] enum. This keeps protocol-level concerns separate from
5//! the engine that actually executes commands.
6
7use bytes::Bytes;
8
9use crate::error::ProtocolError;
10use crate::types::Frame;
11
/// Lifetime requested for a value written by `SET`.
///
/// Holds the raw amount that followed the `EX` or `PX` option. `parse_set`
/// rejects an amount of zero, so values it produces are always at least 1.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SetExpire {
    /// `EX <seconds>`: the key should expire after this many seconds.
    Ex(u64),
    /// `PX <milliseconds>`: the key should expire after this many milliseconds.
    Px(u64),
}
20
21/// A parsed client command, ready for execution.
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub enum Command {
24    /// PING with an optional message. Returns PONG or echoes the message.
25    Ping(Option<Bytes>),
26
27    /// ECHO <message>. Returns the message back to the client.
28    Echo(Bytes),
29
30    /// GET <key>. Returns the value or nil.
31    Get { key: String },
32
33    /// SET <key> <value> [EX seconds | PX milliseconds].
34    Set {
35        key: String,
36        value: Bytes,
37        expire: Option<SetExpire>,
38    },
39
40    /// DEL <key> [key ...]. Returns the number of keys removed.
41    Del { keys: Vec<String> },
42
43    /// EXISTS <key> [key ...]. Returns the number of keys that exist.
44    Exists { keys: Vec<String> },
45
46    /// EXPIRE <key> <seconds>. Sets a TTL on an existing key.
47    Expire { key: String, seconds: u64 },
48
49    /// TTL <key>. Returns remaining time-to-live in seconds.
50    Ttl { key: String },
51
52    /// DBSIZE. Returns the number of keys in the database.
53    DbSize,
54
55    /// INFO [section]. Returns server info. Currently only supports "keyspace".
56    Info { section: Option<String> },
57
58    /// A command we don't recognize (yet).
59    Unknown(String),
60}
61
62impl Command {
63    /// Parses a [`Frame`] into a [`Command`].
64    ///
65    /// Expects an array frame where the first element is the command name
66    /// (as a bulk or simple string) and the rest are arguments.
67    pub fn from_frame(frame: Frame) -> Result<Command, ProtocolError> {
68        let frames = match frame {
69            Frame::Array(frames) => frames,
70            _ => {
71                return Err(ProtocolError::InvalidCommandFrame(
72                    "expected array frame".into(),
73                ));
74            }
75        };
76
77        if frames.is_empty() {
78            return Err(ProtocolError::InvalidCommandFrame(
79                "empty command array".into(),
80            ));
81        }
82
83        let name = extract_string(&frames[0])?;
84        let name_upper = name.to_ascii_uppercase();
85
86        match name_upper.as_str() {
87            "PING" => parse_ping(&frames[1..]),
88            "ECHO" => parse_echo(&frames[1..]),
89            "GET" => parse_get(&frames[1..]),
90            "SET" => parse_set(&frames[1..]),
91            "DEL" => parse_del(&frames[1..]),
92            "EXISTS" => parse_exists(&frames[1..]),
93            "EXPIRE" => parse_expire(&frames[1..]),
94            "TTL" => parse_ttl(&frames[1..]),
95            "DBSIZE" => parse_dbsize(&frames[1..]),
96            "INFO" => parse_info(&frames[1..]),
97            _ => Ok(Command::Unknown(name)),
98        }
99    }
100}
101
102/// Extracts a UTF-8 string from a Bulk or Simple frame.
103fn extract_string(frame: &Frame) -> Result<String, ProtocolError> {
104    match frame {
105        Frame::Bulk(data) => String::from_utf8(data.to_vec()).map_err(|_| {
106            ProtocolError::InvalidCommandFrame("command name is not valid utf-8".into())
107        }),
108        Frame::Simple(s) => Ok(s.clone()),
109        _ => Err(ProtocolError::InvalidCommandFrame(
110            "expected bulk or simple string for command name".into(),
111        )),
112    }
113}
114
115/// Extracts raw bytes from a Bulk or Simple frame.
116fn extract_bytes(frame: &Frame) -> Result<Bytes, ProtocolError> {
117    match frame {
118        Frame::Bulk(data) => Ok(data.clone()),
119        Frame::Simple(s) => Ok(Bytes::from(s.clone().into_bytes())),
120        _ => Err(ProtocolError::InvalidCommandFrame(
121            "expected bulk or simple string argument".into(),
122        )),
123    }
124}
125
126/// Parses a string argument as a positive u64.
127fn parse_u64(frame: &Frame, cmd: &str) -> Result<u64, ProtocolError> {
128    let s = extract_string(frame)?;
129    s.parse::<u64>().map_err(|_| {
130        ProtocolError::InvalidCommandFrame(format!("value is not a valid integer for '{cmd}'"))
131    })
132}
133
134fn parse_ping(args: &[Frame]) -> Result<Command, ProtocolError> {
135    match args.len() {
136        0 => Ok(Command::Ping(None)),
137        1 => {
138            let msg = extract_bytes(&args[0])?;
139            Ok(Command::Ping(Some(msg)))
140        }
141        _ => Err(ProtocolError::WrongArity("PING".into())),
142    }
143}
144
145fn parse_echo(args: &[Frame]) -> Result<Command, ProtocolError> {
146    if args.len() != 1 {
147        return Err(ProtocolError::WrongArity("ECHO".into()));
148    }
149    let msg = extract_bytes(&args[0])?;
150    Ok(Command::Echo(msg))
151}
152
153fn parse_get(args: &[Frame]) -> Result<Command, ProtocolError> {
154    if args.len() != 1 {
155        return Err(ProtocolError::WrongArity("GET".into()));
156    }
157    let key = extract_string(&args[0])?;
158    Ok(Command::Get { key })
159}
160
161fn parse_set(args: &[Frame]) -> Result<Command, ProtocolError> {
162    if args.len() < 2 {
163        return Err(ProtocolError::WrongArity("SET".into()));
164    }
165
166    let key = extract_string(&args[0])?;
167    let value = extract_bytes(&args[1])?;
168
169    let expire = if args.len() > 2 {
170        // parse optional EX/PX
171        if args.len() != 4 {
172            return Err(ProtocolError::WrongArity("SET".into()));
173        }
174        let flag = extract_string(&args[2])?.to_ascii_uppercase();
175        let amount = parse_u64(&args[3], "SET")?;
176
177        if amount == 0 {
178            return Err(ProtocolError::InvalidCommandFrame(
179                "invalid expire time in 'SET' command".into(),
180            ));
181        }
182
183        match flag.as_str() {
184            "EX" => Some(SetExpire::Ex(amount)),
185            "PX" => Some(SetExpire::Px(amount)),
186            _ => {
187                return Err(ProtocolError::InvalidCommandFrame(format!(
188                    "unsupported SET option '{flag}'"
189                )));
190            }
191        }
192    } else {
193        None
194    };
195
196    Ok(Command::Set { key, value, expire })
197}
198
199fn parse_del(args: &[Frame]) -> Result<Command, ProtocolError> {
200    if args.is_empty() {
201        return Err(ProtocolError::WrongArity("DEL".into()));
202    }
203    let keys = args
204        .iter()
205        .map(extract_string)
206        .collect::<Result<Vec<_>, _>>()?;
207    Ok(Command::Del { keys })
208}
209
210fn parse_exists(args: &[Frame]) -> Result<Command, ProtocolError> {
211    if args.is_empty() {
212        return Err(ProtocolError::WrongArity("EXISTS".into()));
213    }
214    let keys = args
215        .iter()
216        .map(extract_string)
217        .collect::<Result<Vec<_>, _>>()?;
218    Ok(Command::Exists { keys })
219}
220
221fn parse_expire(args: &[Frame]) -> Result<Command, ProtocolError> {
222    if args.len() != 2 {
223        return Err(ProtocolError::WrongArity("EXPIRE".into()));
224    }
225    let key = extract_string(&args[0])?;
226    let seconds = parse_u64(&args[1], "EXPIRE")?;
227
228    if seconds == 0 {
229        return Err(ProtocolError::InvalidCommandFrame(
230            "invalid expire time in 'EXPIRE' command".into(),
231        ));
232    }
233
234    Ok(Command::Expire { key, seconds })
235}
236
237fn parse_ttl(args: &[Frame]) -> Result<Command, ProtocolError> {
238    if args.len() != 1 {
239        return Err(ProtocolError::WrongArity("TTL".into()));
240    }
241    let key = extract_string(&args[0])?;
242    Ok(Command::Ttl { key })
243}
244
245fn parse_dbsize(args: &[Frame]) -> Result<Command, ProtocolError> {
246    if !args.is_empty() {
247        return Err(ProtocolError::WrongArity("DBSIZE".into()));
248    }
249    Ok(Command::DbSize)
250}
251
252fn parse_info(args: &[Frame]) -> Result<Command, ProtocolError> {
253    match args.len() {
254        0 => Ok(Command::Info { section: None }),
255        1 => {
256            let section = extract_string(&args[0])?;
257            Ok(Command::Info {
258                section: Some(section),
259            })
260        }
261        _ => Err(ProtocolError::WrongArity("INFO".into())),
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: build an array frame from bulk strings.
    fn cmd(parts: &[&str]) -> Frame {
        Frame::Array(
            parts
                .iter()
                .map(|s| Frame::Bulk(Bytes::from(s.to_string())))
                .collect(),
        )
    }

    // --- ping ---

    #[test]
    fn ping_no_args() {
        assert_eq!(
            Command::from_frame(cmd(&["PING"])).unwrap(),
            Command::Ping(None),
        );
    }

    #[test]
    fn ping_with_message() {
        assert_eq!(
            Command::from_frame(cmd(&["PING", "hello"])).unwrap(),
            Command::Ping(Some(Bytes::from("hello"))),
        );
    }

    #[test]
    fn ping_case_insensitive() {
        assert_eq!(
            Command::from_frame(cmd(&["ping"])).unwrap(),
            Command::Ping(None),
        );
        assert_eq!(
            Command::from_frame(cmd(&["Ping"])).unwrap(),
            Command::Ping(None),
        );
    }

    #[test]
    fn ping_too_many_args() {
        let err = Command::from_frame(cmd(&["PING", "a", "b"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    // --- echo ---

    #[test]
    fn echo() {
        assert_eq!(
            Command::from_frame(cmd(&["ECHO", "test"])).unwrap(),
            Command::Echo(Bytes::from("test")),
        );
    }

    #[test]
    fn echo_missing_arg() {
        let err = Command::from_frame(cmd(&["ECHO"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    #[test]
    fn echo_too_many_args() {
        let err = Command::from_frame(cmd(&["ECHO", "a", "b"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    // --- get ---

    #[test]
    fn get_basic() {
        assert_eq!(
            Command::from_frame(cmd(&["GET", "mykey"])).unwrap(),
            Command::Get {
                key: "mykey".into()
            },
        );
    }

    #[test]
    fn get_no_args() {
        let err = Command::from_frame(cmd(&["GET"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    #[test]
    fn get_too_many_args() {
        let err = Command::from_frame(cmd(&["GET", "a", "b"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    #[test]
    fn get_case_insensitive() {
        assert_eq!(
            Command::from_frame(cmd(&["get", "k"])).unwrap(),
            Command::Get { key: "k".into() },
        );
    }

    #[test]
    fn get_non_utf8_key() {
        let frame = Frame::Array(vec![
            Frame::Bulk(Bytes::from("GET")),
            Frame::Bulk(Bytes::from_static(b"\xff\xfe")),
        ]);
        let err = Command::from_frame(frame).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn get_nested_array_argument() {
        let frame = Frame::Array(vec![
            Frame::Bulk(Bytes::from("GET")),
            Frame::Array(vec![]),
        ]);
        let err = Command::from_frame(frame).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    // --- set ---

    #[test]
    fn set_basic() {
        assert_eq!(
            Command::from_frame(cmd(&["SET", "key", "value"])).unwrap(),
            Command::Set {
                key: "key".into(),
                value: Bytes::from("value"),
                expire: None,
            },
        );
    }

    #[test]
    fn set_binary_value_is_preserved() {
        // SET values are raw bytes and need not be valid UTF-8.
        let frame = Frame::Array(vec![
            Frame::Bulk(Bytes::from("SET")),
            Frame::Bulk(Bytes::from("k")),
            Frame::Bulk(Bytes::from_static(b"\x00\xff")),
        ]);
        assert_eq!(
            Command::from_frame(frame).unwrap(),
            Command::Set {
                key: "k".into(),
                value: Bytes::from_static(b"\x00\xff"),
                expire: None,
            },
        );
    }

    #[test]
    fn set_with_ex() {
        assert_eq!(
            Command::from_frame(cmd(&["SET", "key", "val", "EX", "10"])).unwrap(),
            Command::Set {
                key: "key".into(),
                value: Bytes::from("val"),
                expire: Some(SetExpire::Ex(10)),
            },
        );
    }

    #[test]
    fn set_with_px() {
        assert_eq!(
            Command::from_frame(cmd(&["SET", "key", "val", "PX", "5000"])).unwrap(),
            Command::Set {
                key: "key".into(),
                value: Bytes::from("val"),
                expire: Some(SetExpire::Px(5000)),
            },
        );
    }

    #[test]
    fn set_ex_case_insensitive() {
        assert_eq!(
            Command::from_frame(cmd(&["set", "k", "v", "ex", "5"])).unwrap(),
            Command::Set {
                key: "k".into(),
                value: Bytes::from("v"),
                expire: Some(SetExpire::Ex(5)),
            },
        );
    }

    #[test]
    fn set_missing_value() {
        let err = Command::from_frame(cmd(&["SET", "key"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    #[test]
    fn set_invalid_expire_value() {
        let err = Command::from_frame(cmd(&["SET", "k", "v", "EX", "notanum"])).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn set_negative_expire() {
        let err = Command::from_frame(cmd(&["SET", "k", "v", "EX", "-5"])).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn set_zero_expire() {
        let err = Command::from_frame(cmd(&["SET", "k", "v", "EX", "0"])).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn set_unknown_flag() {
        let err = Command::from_frame(cmd(&["SET", "k", "v", "ZZ", "10"])).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn set_unknown_flag_with_non_numeric_amount() {
        let err = Command::from_frame(cmd(&["SET", "k", "v", "ZZ", "abc"])).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn set_incomplete_expire() {
        // EX without a value
        let err = Command::from_frame(cmd(&["SET", "k", "v", "EX"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    #[test]
    fn set_too_many_args() {
        let err = Command::from_frame(cmd(&["SET", "k", "v", "EX", "10", "NX"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    // --- del ---

    #[test]
    fn del_single() {
        assert_eq!(
            Command::from_frame(cmd(&["DEL", "key"])).unwrap(),
            Command::Del {
                keys: vec!["key".into()]
            },
        );
    }

    #[test]
    fn del_multiple() {
        assert_eq!(
            Command::from_frame(cmd(&["DEL", "a", "b", "c"])).unwrap(),
            Command::Del {
                keys: vec!["a".into(), "b".into(), "c".into()]
            },
        );
    }

    #[test]
    fn del_no_args() {
        let err = Command::from_frame(cmd(&["DEL"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    #[test]
    fn del_rejects_non_utf8_key() {
        let frame = Frame::Array(vec![
            Frame::Bulk(Bytes::from("DEL")),
            Frame::Bulk(Bytes::from("a")),
            Frame::Bulk(Bytes::from_static(b"\xff")),
        ]);
        let err = Command::from_frame(frame).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    // --- exists ---

    #[test]
    fn exists_single() {
        assert_eq!(
            Command::from_frame(cmd(&["EXISTS", "key"])).unwrap(),
            Command::Exists {
                keys: vec!["key".into()]
            },
        );
    }

    #[test]
    fn exists_multiple() {
        assert_eq!(
            Command::from_frame(cmd(&["EXISTS", "a", "b"])).unwrap(),
            Command::Exists {
                keys: vec!["a".into(), "b".into()]
            },
        );
    }

    #[test]
    fn exists_no_args() {
        let err = Command::from_frame(cmd(&["EXISTS"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    // --- expire ---

    #[test]
    fn expire_basic() {
        assert_eq!(
            Command::from_frame(cmd(&["EXPIRE", "key", "60"])).unwrap(),
            Command::Expire {
                key: "key".into(),
                seconds: 60,
            },
        );
    }

    #[test]
    fn expire_max_u64() {
        assert_eq!(
            Command::from_frame(cmd(&["EXPIRE", "key", "18446744073709551615"])).unwrap(),
            Command::Expire {
                key: "key".into(),
                seconds: u64::MAX,
            },
        );
    }

    #[test]
    fn expire_wrong_arity() {
        let err = Command::from_frame(cmd(&["EXPIRE", "key"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    #[test]
    fn expire_too_many_args() {
        let err = Command::from_frame(cmd(&["EXPIRE", "key", "10", "NX"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    #[test]
    fn expire_invalid_seconds() {
        let err = Command::from_frame(cmd(&["EXPIRE", "key", "abc"])).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn expire_negative_seconds() {
        let err = Command::from_frame(cmd(&["EXPIRE", "key", "-1"])).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn expire_overflowing_seconds() {
        // One past u64::MAX.
        let err =
            Command::from_frame(cmd(&["EXPIRE", "key", "18446744073709551616"])).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn expire_zero_seconds() {
        let err = Command::from_frame(cmd(&["EXPIRE", "key", "0"])).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    // --- ttl ---

    #[test]
    fn ttl_basic() {
        assert_eq!(
            Command::from_frame(cmd(&["TTL", "key"])).unwrap(),
            Command::Ttl { key: "key".into() },
        );
    }

    #[test]
    fn ttl_wrong_arity() {
        let err = Command::from_frame(cmd(&["TTL"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    #[test]
    fn ttl_too_many_args() {
        let err = Command::from_frame(cmd(&["TTL", "a", "b"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    // --- dbsize ---

    #[test]
    fn dbsize_basic() {
        assert_eq!(
            Command::from_frame(cmd(&["DBSIZE"])).unwrap(),
            Command::DbSize,
        );
    }

    #[test]
    fn dbsize_case_insensitive() {
        assert_eq!(
            Command::from_frame(cmd(&["dbsize"])).unwrap(),
            Command::DbSize,
        );
    }

    #[test]
    fn dbsize_extra_args() {
        let err = Command::from_frame(cmd(&["DBSIZE", "extra"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    // --- info ---

    #[test]
    fn info_no_section() {
        assert_eq!(
            Command::from_frame(cmd(&["INFO"])).unwrap(),
            Command::Info { section: None },
        );
    }

    #[test]
    fn info_with_section() {
        assert_eq!(
            Command::from_frame(cmd(&["INFO", "keyspace"])).unwrap(),
            Command::Info {
                section: Some("keyspace".into())
            },
        );
    }

    #[test]
    fn info_case_insensitive() {
        assert_eq!(
            Command::from_frame(cmd(&["info"])).unwrap(),
            Command::Info { section: None },
        );
    }

    #[test]
    fn info_too_many_args() {
        let err = Command::from_frame(cmd(&["INFO", "a", "b"])).unwrap_err();
        assert!(matches!(err, ProtocolError::WrongArity(_)));
    }

    // --- general ---

    #[test]
    fn unknown_command() {
        assert_eq!(
            Command::from_frame(cmd(&["FOOBAR", "arg"])).unwrap(),
            Command::Unknown("FOOBAR".into()),
        );
    }

    #[test]
    fn unknown_command_preserves_case() {
        assert_eq!(
            Command::from_frame(cmd(&["foobar"])).unwrap(),
            Command::Unknown("foobar".into()),
        );
    }

    #[test]
    fn simple_string_command_and_args() {
        let frame = Frame::Array(vec![
            Frame::Simple("ECHO".into()),
            Frame::Simple("hi".into()),
        ]);
        assert_eq!(
            Command::from_frame(frame).unwrap(),
            Command::Echo(Bytes::from("hi")),
        );
    }

    #[test]
    fn non_utf8_command_name() {
        let frame = Frame::Array(vec![Frame::Bulk(Bytes::from_static(b"\xff"))]);
        let err = Command::from_frame(frame).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn non_array_frame() {
        let err = Command::from_frame(Frame::Simple("PING".into())).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn empty_array() {
        let err = Command::from_frame(Frame::Array(vec![])).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }
}
618}