1use bytes::Bytes;
8
9use crate::error::ProtocolError;
10use crate::types::Frame;
11
/// Expiry option attached to a `SET` command (`EX <n>` or `PX <n>`).
///
/// When produced by `Command::from_frame`, the amount is always non-zero:
/// the `SET` parser rejects `0` before building either variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SetExpire {
    /// Relative expiry given with the `EX` option.
    /// NOTE(review): presumably seconds, as in Redis — confirm against the executor.
    Ex(u64),
    /// Relative expiry given with the `PX` option.
    /// NOTE(review): presumably milliseconds, as in Redis — confirm against the executor.
    Px(u64),
}
20
/// A client request decoded from a RESP array frame by [`Command::from_frame`].
///
/// Parsing only checks argument shape (arity, string-typed keys, integer
/// arguments); nothing in this module executes the command.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Command {
    /// `PING [message]` — `None` when no message was supplied.
    Ping(Option<Bytes>),

    /// `ECHO message` — the payload is kept as raw bytes (not required to be UTF-8).
    Echo(Bytes),

    /// `GET key`.
    Get { key: String },

    /// `SET key value [EX n | PX n]`.
    Set {
        /// Key to write; must be valid UTF-8.
        key: String,
        /// Value to store, kept as raw bytes.
        value: Bytes,
        /// Optional relative expiry; `None` when no option was given.
        expire: Option<SetExpire>,
    },

    /// `DEL key [key ...]` — contains at least one key when produced by the parser.
    Del { keys: Vec<String> },

    /// `EXISTS key [key ...]` — contains at least one key when produced by the parser.
    Exists { keys: Vec<String> },

    /// `EXPIRE key seconds` — `seconds` is non-zero when produced by the parser.
    Expire { key: String, seconds: u64 },

    /// `TTL key`.
    Ttl { key: String },

    /// `DBSIZE` (takes no arguments).
    DbSize,

    /// `INFO [section]` — the section name is kept exactly as sent (case is not normalized).
    Info { section: Option<String> },

    /// Any command name this parser does not recognize, in the client's original
    /// casing. The command's arguments are discarded.
    Unknown(String),
}
61
62impl Command {
63 pub fn from_frame(frame: Frame) -> Result<Command, ProtocolError> {
68 let frames = match frame {
69 Frame::Array(frames) => frames,
70 _ => {
71 return Err(ProtocolError::InvalidCommandFrame(
72 "expected array frame".into(),
73 ));
74 }
75 };
76
77 if frames.is_empty() {
78 return Err(ProtocolError::InvalidCommandFrame(
79 "empty command array".into(),
80 ));
81 }
82
83 let name = extract_string(&frames[0])?;
84 let name_upper = name.to_ascii_uppercase();
85
86 match name_upper.as_str() {
87 "PING" => parse_ping(&frames[1..]),
88 "ECHO" => parse_echo(&frames[1..]),
89 "GET" => parse_get(&frames[1..]),
90 "SET" => parse_set(&frames[1..]),
91 "DEL" => parse_del(&frames[1..]),
92 "EXISTS" => parse_exists(&frames[1..]),
93 "EXPIRE" => parse_expire(&frames[1..]),
94 "TTL" => parse_ttl(&frames[1..]),
95 "DBSIZE" => parse_dbsize(&frames[1..]),
96 "INFO" => parse_info(&frames[1..]),
97 _ => Ok(Command::Unknown(name)),
98 }
99 }
100}
101
102fn extract_string(frame: &Frame) -> Result<String, ProtocolError> {
104 match frame {
105 Frame::Bulk(data) => String::from_utf8(data.to_vec()).map_err(|_| {
106 ProtocolError::InvalidCommandFrame("command name is not valid utf-8".into())
107 }),
108 Frame::Simple(s) => Ok(s.clone()),
109 _ => Err(ProtocolError::InvalidCommandFrame(
110 "expected bulk or simple string for command name".into(),
111 )),
112 }
113}
114
115fn extract_bytes(frame: &Frame) -> Result<Bytes, ProtocolError> {
117 match frame {
118 Frame::Bulk(data) => Ok(data.clone()),
119 Frame::Simple(s) => Ok(Bytes::from(s.clone().into_bytes())),
120 _ => Err(ProtocolError::InvalidCommandFrame(
121 "expected bulk or simple string argument".into(),
122 )),
123 }
124}
125
126fn parse_u64(frame: &Frame, cmd: &str) -> Result<u64, ProtocolError> {
128 let s = extract_string(frame)?;
129 s.parse::<u64>().map_err(|_| {
130 ProtocolError::InvalidCommandFrame(format!("value is not a valid integer for '{cmd}'"))
131 })
132}
133
134fn parse_ping(args: &[Frame]) -> Result<Command, ProtocolError> {
135 match args.len() {
136 0 => Ok(Command::Ping(None)),
137 1 => {
138 let msg = extract_bytes(&args[0])?;
139 Ok(Command::Ping(Some(msg)))
140 }
141 _ => Err(ProtocolError::WrongArity("PING".into())),
142 }
143}
144
145fn parse_echo(args: &[Frame]) -> Result<Command, ProtocolError> {
146 if args.len() != 1 {
147 return Err(ProtocolError::WrongArity("ECHO".into()));
148 }
149 let msg = extract_bytes(&args[0])?;
150 Ok(Command::Echo(msg))
151}
152
153fn parse_get(args: &[Frame]) -> Result<Command, ProtocolError> {
154 if args.len() != 1 {
155 return Err(ProtocolError::WrongArity("GET".into()));
156 }
157 let key = extract_string(&args[0])?;
158 Ok(Command::Get { key })
159}
160
161fn parse_set(args: &[Frame]) -> Result<Command, ProtocolError> {
162 if args.len() < 2 {
163 return Err(ProtocolError::WrongArity("SET".into()));
164 }
165
166 let key = extract_string(&args[0])?;
167 let value = extract_bytes(&args[1])?;
168
169 let expire = if args.len() > 2 {
170 if args.len() != 4 {
172 return Err(ProtocolError::WrongArity("SET".into()));
173 }
174 let flag = extract_string(&args[2])?.to_ascii_uppercase();
175 let amount = parse_u64(&args[3], "SET")?;
176
177 if amount == 0 {
178 return Err(ProtocolError::InvalidCommandFrame(
179 "invalid expire time in 'SET' command".into(),
180 ));
181 }
182
183 match flag.as_str() {
184 "EX" => Some(SetExpire::Ex(amount)),
185 "PX" => Some(SetExpire::Px(amount)),
186 _ => {
187 return Err(ProtocolError::InvalidCommandFrame(format!(
188 "unsupported SET option '{flag}'"
189 )));
190 }
191 }
192 } else {
193 None
194 };
195
196 Ok(Command::Set { key, value, expire })
197}
198
199fn parse_del(args: &[Frame]) -> Result<Command, ProtocolError> {
200 if args.is_empty() {
201 return Err(ProtocolError::WrongArity("DEL".into()));
202 }
203 let keys = args
204 .iter()
205 .map(extract_string)
206 .collect::<Result<Vec<_>, _>>()?;
207 Ok(Command::Del { keys })
208}
209
210fn parse_exists(args: &[Frame]) -> Result<Command, ProtocolError> {
211 if args.is_empty() {
212 return Err(ProtocolError::WrongArity("EXISTS".into()));
213 }
214 let keys = args
215 .iter()
216 .map(extract_string)
217 .collect::<Result<Vec<_>, _>>()?;
218 Ok(Command::Exists { keys })
219}
220
221fn parse_expire(args: &[Frame]) -> Result<Command, ProtocolError> {
222 if args.len() != 2 {
223 return Err(ProtocolError::WrongArity("EXPIRE".into()));
224 }
225 let key = extract_string(&args[0])?;
226 let seconds = parse_u64(&args[1], "EXPIRE")?;
227
228 if seconds == 0 {
229 return Err(ProtocolError::InvalidCommandFrame(
230 "invalid expire time in 'EXPIRE' command".into(),
231 ));
232 }
233
234 Ok(Command::Expire { key, seconds })
235}
236
237fn parse_ttl(args: &[Frame]) -> Result<Command, ProtocolError> {
238 if args.len() != 1 {
239 return Err(ProtocolError::WrongArity("TTL".into()));
240 }
241 let key = extract_string(&args[0])?;
242 Ok(Command::Ttl { key })
243}
244
245fn parse_dbsize(args: &[Frame]) -> Result<Command, ProtocolError> {
246 if !args.is_empty() {
247 return Err(ProtocolError::WrongArity("DBSIZE".into()));
248 }
249 Ok(Command::DbSize)
250}
251
252fn parse_info(args: &[Frame]) -> Result<Command, ProtocolError> {
253 match args.len() {
254 0 => Ok(Command::Info { section: None }),
255 1 => {
256 let section = extract_string(&args[0])?;
257 Ok(Command::Info {
258 section: Some(section),
259 })
260 }
261 _ => Err(ProtocolError::WrongArity("INFO".into())),
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a command array whose elements are all bulk strings.
    fn cmd(parts: &[&str]) -> Frame {
        Frame::Array(
            parts
                .iter()
                .map(|s| Frame::Bulk(Bytes::from(s.to_string())))
                .collect(),
        )
    }

    /// Parses a bulk-string command and returns the successful result.
    #[track_caller]
    fn parse(parts: &[&str]) -> Command {
        Command::from_frame(cmd(parts)).unwrap()
    }

    /// Asserts that parsing the bulk-string command fails with `WrongArity`.
    #[track_caller]
    fn assert_wrong_arity(parts: &[&str]) {
        let err = Command::from_frame(cmd(parts)).unwrap_err();
        assert!(
            matches!(err, ProtocolError::WrongArity(_)),
            "{parts:?}: got {err:?}"
        );
    }

    /// Asserts that `frame` is rejected with `InvalidCommandFrame`.
    #[track_caller]
    fn assert_invalid_frame(frame: Frame) {
        let err = Command::from_frame(frame).unwrap_err();
        assert!(
            matches!(err, ProtocolError::InvalidCommandFrame(_)),
            "got {err:?}"
        );
    }

    // --- PING -------------------------------------------------------------

    #[test]
    fn ping_no_args() {
        assert_eq!(parse(&["PING"]), Command::Ping(None));
    }

    #[test]
    fn ping_with_message() {
        assert_eq!(
            parse(&["PING", "hello"]),
            Command::Ping(Some(Bytes::from("hello"))),
        );
    }

    #[test]
    fn ping_case_insensitive() {
        assert_eq!(parse(&["ping"]), Command::Ping(None));
        assert_eq!(parse(&["Ping"]), Command::Ping(None));
    }

    #[test]
    fn ping_too_many_args() {
        assert_wrong_arity(&["PING", "a", "b"]);
    }

    #[test]
    fn ping_simple_string_message() {
        let frame = Frame::Array(vec![
            Frame::Simple("PING".into()),
            Frame::Simple("hi".into()),
        ]);
        assert_eq!(
            Command::from_frame(frame).unwrap(),
            Command::Ping(Some(Bytes::from("hi"))),
        );
    }

    // --- ECHO -------------------------------------------------------------

    #[test]
    fn echo() {
        assert_eq!(
            parse(&["ECHO", "test"]),
            Command::Echo(Bytes::from("test")),
        );
    }

    #[test]
    fn echo_missing_arg() {
        assert_wrong_arity(&["ECHO"]);
    }

    #[test]
    fn echo_too_many_args() {
        assert_wrong_arity(&["ECHO", "a", "b"]);
    }

    #[test]
    fn echo_binary_payload() {
        // ECHO payloads are raw bytes and need not be valid UTF-8.
        let payload = Bytes::from(vec![0xffu8, 0x00, 0x80]);
        let frame = Frame::Array(vec![
            Frame::Bulk(Bytes::from("ECHO")),
            Frame::Bulk(payload.clone()),
        ]);
        assert_eq!(Command::from_frame(frame).unwrap(), Command::Echo(payload));
    }

    // --- GET --------------------------------------------------------------

    #[test]
    fn get_basic() {
        assert_eq!(
            parse(&["GET", "mykey"]),
            Command::Get {
                key: "mykey".into()
            },
        );
    }

    #[test]
    fn get_no_args() {
        assert_wrong_arity(&["GET"]);
    }

    #[test]
    fn get_too_many_args() {
        assert_wrong_arity(&["GET", "a", "b"]);
    }

    #[test]
    fn get_case_insensitive() {
        assert_eq!(parse(&["get", "k"]), Command::Get { key: "k".into() });
    }

    #[test]
    fn get_simple_string_key() {
        let frame = Frame::Array(vec![
            Frame::Simple("GET".into()),
            Frame::Simple("k".into()),
        ]);
        assert_eq!(
            Command::from_frame(frame).unwrap(),
            Command::Get { key: "k".into() },
        );
    }

    #[test]
    fn get_non_string_key() {
        assert_invalid_frame(Frame::Array(vec![
            Frame::Bulk(Bytes::from("GET")),
            Frame::Array(vec![]),
        ]));
    }

    // --- SET --------------------------------------------------------------

    #[test]
    fn set_basic() {
        assert_eq!(
            parse(&["SET", "key", "value"]),
            Command::Set {
                key: "key".into(),
                value: Bytes::from("value"),
                expire: None,
            },
        );
    }

    #[test]
    fn set_with_ex() {
        assert_eq!(
            parse(&["SET", "key", "val", "EX", "10"]),
            Command::Set {
                key: "key".into(),
                value: Bytes::from("val"),
                expire: Some(SetExpire::Ex(10)),
            },
        );
    }

    #[test]
    fn set_with_px() {
        assert_eq!(
            parse(&["SET", "key", "val", "PX", "5000"]),
            Command::Set {
                key: "key".into(),
                value: Bytes::from("val"),
                expire: Some(SetExpire::Px(5000)),
            },
        );
    }

    #[test]
    fn set_ex_case_insensitive() {
        assert_eq!(
            parse(&["set", "k", "v", "ex", "5"]),
            Command::Set {
                key: "k".into(),
                value: Bytes::from("v"),
                expire: Some(SetExpire::Ex(5)),
            },
        );
    }

    #[test]
    fn set_missing_value() {
        assert_wrong_arity(&["SET", "key"]);
    }

    #[test]
    fn set_invalid_expire_value() {
        assert_invalid_frame(cmd(&["SET", "k", "v", "EX", "notanum"]));
    }

    #[test]
    fn set_zero_expire() {
        assert_invalid_frame(cmd(&["SET", "k", "v", "EX", "0"]));
    }

    #[test]
    fn set_unknown_flag() {
        assert_invalid_frame(cmd(&["SET", "k", "v", "ZZ", "10"]));
    }

    #[test]
    fn set_unknown_flag_non_numeric_amount() {
        assert_invalid_frame(cmd(&["SET", "k", "v", "ZZ", "abc"]));
    }

    #[test]
    fn set_incomplete_expire() {
        // An option without its amount is treated as an arity error.
        assert_wrong_arity(&["SET", "k", "v", "EX"]);
    }

    #[test]
    fn set_too_many_args() {
        assert_wrong_arity(&["SET", "k", "v", "EX", "10", "NX"]);
    }

    // --- DEL --------------------------------------------------------------

    #[test]
    fn del_single() {
        assert_eq!(
            parse(&["DEL", "key"]),
            Command::Del {
                keys: vec!["key".into()]
            },
        );
    }

    #[test]
    fn del_multiple() {
        assert_eq!(
            parse(&["DEL", "a", "b", "c"]),
            Command::Del {
                keys: vec!["a".into(), "b".into(), "c".into()]
            },
        );
    }

    #[test]
    fn del_no_args() {
        assert_wrong_arity(&["DEL"]);
    }

    // --- EXISTS -----------------------------------------------------------

    #[test]
    fn exists_single() {
        assert_eq!(
            parse(&["EXISTS", "key"]),
            Command::Exists {
                keys: vec!["key".into()]
            },
        );
    }

    #[test]
    fn exists_multiple() {
        assert_eq!(
            parse(&["EXISTS", "a", "b"]),
            Command::Exists {
                keys: vec!["a".into(), "b".into()]
            },
        );
    }

    #[test]
    fn exists_no_args() {
        assert_wrong_arity(&["EXISTS"]);
    }

    // --- EXPIRE -----------------------------------------------------------

    #[test]
    fn expire_basic() {
        assert_eq!(
            parse(&["EXPIRE", "key", "60"]),
            Command::Expire {
                key: "key".into(),
                seconds: 60,
            },
        );
    }

    #[test]
    fn expire_wrong_arity() {
        assert_wrong_arity(&["EXPIRE", "key"]);
    }

    #[test]
    fn expire_invalid_seconds() {
        assert_invalid_frame(cmd(&["EXPIRE", "key", "abc"]));
    }

    #[test]
    fn expire_zero_seconds() {
        assert_invalid_frame(cmd(&["EXPIRE", "key", "0"]));
    }

    #[test]
    fn expire_negative_seconds() {
        assert_invalid_frame(cmd(&["EXPIRE", "key", "-5"]));
    }

    // --- TTL --------------------------------------------------------------

    #[test]
    fn ttl_basic() {
        assert_eq!(parse(&["TTL", "key"]), Command::Ttl { key: "key".into() });
    }

    #[test]
    fn ttl_wrong_arity() {
        assert_wrong_arity(&["TTL"]);
    }

    #[test]
    fn ttl_too_many_args() {
        assert_wrong_arity(&["TTL", "a", "b"]);
    }

    // --- DBSIZE -----------------------------------------------------------

    #[test]
    fn dbsize_basic() {
        assert_eq!(parse(&["DBSIZE"]), Command::DbSize);
    }

    #[test]
    fn dbsize_case_insensitive() {
        assert_eq!(parse(&["dbsize"]), Command::DbSize);
    }

    #[test]
    fn dbsize_extra_args() {
        assert_wrong_arity(&["DBSIZE", "extra"]);
    }

    // --- INFO -------------------------------------------------------------

    #[test]
    fn info_no_section() {
        assert_eq!(parse(&["INFO"]), Command::Info { section: None });
    }

    #[test]
    fn info_with_section() {
        assert_eq!(
            parse(&["INFO", "keyspace"]),
            Command::Info {
                section: Some("keyspace".into())
            },
        );
    }

    #[test]
    fn info_too_many_args() {
        assert_wrong_arity(&["INFO", "a", "b"]);
    }

    // --- Framing and dispatch ----------------------------------------------

    #[test]
    fn unknown_command() {
        assert_eq!(
            parse(&["FOOBAR", "arg"]),
            Command::Unknown("FOOBAR".into()),
        );
    }

    #[test]
    fn unknown_command_preserves_case() {
        assert_eq!(parse(&["fooBar"]), Command::Unknown("fooBar".into()));
    }

    #[test]
    fn non_array_frame() {
        assert_invalid_frame(Frame::Simple("PING".into()));
    }

    #[test]
    fn empty_array() {
        assert_invalid_frame(Frame::Array(vec![]));
    }

    #[test]
    fn command_name_not_a_string() {
        assert_invalid_frame(Frame::Array(vec![Frame::Array(vec![])]));
    }

    #[test]
    fn command_name_invalid_utf8() {
        assert_invalid_frame(Frame::Array(vec![Frame::Bulk(Bytes::from(vec![0xffu8]))]));
    }
}