1use prost::Message;
6use std::{collections::HashMap, fmt::Display};
7
8use crate::errors::{CommandError, StreamError};
9
/// Protobuf message types for the client/server wire protocol (`wire::Command`,
/// `wire::Response`, ...), pulled in via `tonic::include_proto!` for the `wire`
/// package (presumably generated by a `tonic-build` build script).
mod wire {
    tonic::include_proto!("wire");
}
13
/// Borrowed key argument(s) for a delete-style call: one key or a list of keys.
///
/// NOTE(review): presumably mapped onto `Command::DEL { keys }` by the caller.
#[derive(Clone, Debug, PartialEq)]
pub enum DelInput<'a> {
    /// Exactly one key.
    Single(&'a str),
    /// Any number of keys.
    Multiple(Vec<&'a str>),
}
22
/// Borrowed input for a hash-set style call: one pair or a list of pairs.
///
/// NOTE(review): the pairs are presumably `(field, value)`; confirm against callers.
#[derive(Clone, Debug, PartialEq)]
pub enum HSetInput<'a> {
    /// A single pair.
    Single(&'a str, &'a str),
    /// Several pairs.
    Multiple(Vec<(&'a str, &'a str)>),
}
33
/// The value stored by a `SET` command (`Command::SET { value, .. }`).
#[derive(Clone, Debug, PartialEq)]
pub enum SetInput {
    /// A string value.
    Str(String),
    /// A 64-bit signed integer value.
    Int(i64),
    /// A 64-bit float value.
    Float(f64),
}
44
45impl Into<ScalarValue> for SetInput {
46 fn into(self) -> ScalarValue {
47 match self {
48 SetInput::Str(s) => ScalarValue::VStr(s),
49 SetInput::Int(i) => ScalarValue::VInt(i),
50 SetInput::Float(f) => ScalarValue::VFloat(f),
51 }
52 }
53}
54
55impl TryInto<SetInput> for ScalarValue {
56 type Error = String;
57
58 fn try_into(self) -> Result<SetInput, Self::Error> {
59 match self {
60 ScalarValue::VStr(s) => Ok(SetInput::Str(s)),
61 ScalarValue::VInt(i) => Ok(SetInput::Int(i)),
62 ScalarValue::VFloat(f) => Ok(SetInput::Float(f)),
63 ScalarValue::VBool(_) => Err("Cannot convert Value::VBool to SetValue".to_string()),
64 ScalarValue::VNull => Err("Cannot convert Value::VNull to SetValue".to_string()),
65 }
66 }
67}
68
/// Generates `From<T> for SetInput` for each listed integer type, mapping to `SetInput::Int`.
///
/// NOTE(review): the conversion is an `as i64` cast, so a `u64` above `i64::MAX`
/// wraps to a negative number rather than failing.
macro_rules! impl_vint_setvalue_for_int {
    ($($int_ty:ty),*) => {
        $(
            impl From<$int_ty> for SetInput {
                fn from(n: $int_ty) -> Self {
                    Self::Int(n as i64)
                }
            }
        )*
    };
}
80
/// Generates `From<T> for ScalarValue` for each listed integer type, mapping to `ScalarValue::VInt`.
///
/// NOTE(review): the conversion is an `as i64` cast, so a `u64` above `i64::MAX`
/// wraps to a negative number rather than failing.
macro_rules! impl_vint_value_for_int {
    ($($int_ty:ty),*) => {
        $(
            impl From<$int_ty> for ScalarValue {
                fn from(n: $int_ty) -> Self {
                    Self::VInt(n as i64)
                }
            }
        )*
    };
}
92
/// Generates `From<T> for SetInput` for each listed float type, mapping to `SetInput::Float`
/// via an `as f64` cast.
macro_rules! impl_vint_setvalue_for_float {
    ($($float_ty:ty),*) => {
        $(
            impl From<$float_ty> for SetInput {
                fn from(x: $float_ty) -> Self {
                    Self::Float(x as f64)
                }
            }
        )*
    };
}
104
/// Generates `From<T> for ScalarValue` for each listed float type, mapping to
/// `ScalarValue::VFloat` via an `as f64` cast.
macro_rules! impl_vint_value_for_float {
    ($($float_ty:ty),*) => {
        $(
            impl From<$float_ty> for ScalarValue {
                fn from(x: $float_ty) -> Self {
                    Self::VFloat(x as f64)
                }
            }
        )*
    };
}
116
117impl_vint_setvalue_for_int!(i64, i32, i16, i8, u64, u32, u16, u8);
118impl_vint_value_for_int!(i64, i32, i16, i8, u64, u32, u16, u8);
119impl_vint_setvalue_for_float!(f64, f32);
120impl_vint_value_for_float!(f64, f32);
121
122impl Into<ScalarValue> for &str {
123 fn into(self) -> ScalarValue {
124 ScalarValue::VStr(self.to_string())
125 }
126}
127
128impl Into<SetInput> for &str {
129 fn into(self) -> SetInput {
130 SetInput::Str(self.to_string())
131 }
132}
133
/// A single scalar value, as decoded from a server response or supplied by the caller.
///
/// Variant order is significant: the derived `PartialOrd` ranks values by variant first
/// (`VStr < VInt < VFloat < VBool < VNull`), then by payload.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub enum ScalarValue {
    /// A UTF-8 string. Byte payloads from the wire are decoded lossily into this variant.
    VStr(String),
    /// A 64-bit signed integer.
    VInt(i64),
    /// A 64-bit float.
    VFloat(f64),
    /// A boolean.
    VBool(bool),
    /// No value: produced for a nil response or a response without a value.
    VNull,
}
148
149impl Display for ScalarValue {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 match self {
152 ScalarValue::VStr(s) => write!(f, "{}", s),
153 ScalarValue::VInt(i) => write!(f, "{}", i),
154 ScalarValue::VFloat(fl) => write!(f, "{}", fl),
155 ScalarValue::VBool(b) => write!(f, "{}", b),
156 ScalarValue::VNull => write!(f, "null"),
157 }
158 }
159}
160
161impl AsArg for ScalarValue {
162 fn as_arg(&self) -> String {
163 match self {
164 ScalarValue::VStr(s) => s.clone(),
165 ScalarValue::VInt(i) => i.to_string(),
166 ScalarValue::VFloat(f) => f.to_string(),
167 ScalarValue::VBool(b) => b.to_string(),
168 ScalarValue::VNull => "".to_string(),
169 }
170 }
171}
172
173impl Into<ScalarValue> for wire::response::Value {
174 fn into(self) -> ScalarValue {
175 match self {
176 wire::response::Value::VNil(_) => ScalarValue::VNull,
177 wire::response::Value::VInt(i) => ScalarValue::VInt(i),
178 wire::response::Value::VStr(s) => ScalarValue::VStr(s),
179 wire::response::Value::VFloat(f) => ScalarValue::VFloat(f),
180 wire::response::Value::VBytes(b) => {
181 ScalarValue::VStr(String::from_utf8_lossy(&b).to_string())
182 }
183 }
184 }
185}
186
187#[derive(Debug)]
189pub struct WatchValue {
190 pub value: ScalarValue,
192 pub fingerprint: String,
194}
195
196impl Into<ScalarValue> for WatchValue {
197 fn into(self) -> ScalarValue {
198 self.value
199 }
200}
201
202impl WatchValue {
203 pub(crate) fn decode_watchvalue(bytes: &[u8]) -> Result<Self, CommandError> {
204 match wire::Response::decode(bytes) {
205 Ok(v) => {
206 if v.err == "" {
207 let fingerprint = match v
208 .attrs
209 .ok_or(CommandError::WatchValueExpectationError(
210 "Missing attributes from response".to_string(),
211 ))?
212 .fields
213 .get("fingerprint")
214 .ok_or(CommandError::WatchValueExpectationError(
215 "Missing fingerprint from attributes".to_string(),
216 ))?
217 .kind
218 .clone()
219 .ok_or(CommandError::WatchValueExpectationError(
220 "Missing kind from fingerprint attribute".to_string(),
221 ))? {
222 prost_types::value::Kind::StringValue(s) => s,
223 _ => {
224 return Err(CommandError::WatchValueExpectationError(
225 "Fingerprint is not a string".to_string(),
226 ))
227 }
228 };
229 let value = v
230 .value
231 .ok_or(CommandError::WatchValueExpectationError(
232 "Missing value from response".to_string(),
233 ))?
234 .into();
235
236 Ok(WatchValue { value, fingerprint })
237 } else {
238 Err(CommandError::ServerError(v.err))
239 }
240 }
241 Err(e) => Err(CommandError::DecodeError(e)),
242 }
243 }
244}
245
/// A string-to-string map decoded from a response's `v_ss_map` field.
#[derive(Clone, Debug, PartialEq)]
pub struct HSetValue {
    /// Entries of the decoded map.
    pub fields: HashMap<String, String>,
}
252
253impl Into<HashMap<String, String>> for HSetValue {
254 fn into(self) -> HashMap<String, String> {
255 self.fields
256 }
257}
258
259impl HSetValue {
260 pub(crate) fn decode(bytes: &[u8]) -> Result<Self, CommandError> {
261 match wire::Response::decode(bytes) {
262 Ok(v) => {
263 if v.err == "" {
264 let fields = v.v_ss_map;
265 Ok(HSetValue { fields })
266 } else {
267 Err(CommandError::ServerError(v.err))
268 }
269 }
270 Err(e) => Err(CommandError::DecodeError(e)),
271 }
272 }
273}
274
275impl ScalarValue {
276 pub(crate) fn decode(bytes: &[u8]) -> Result<Self, CommandError> {
277 let decoded = match wire::Response::decode(bytes) {
278 Ok(v) => {
279 if v.err == "" {
280 match v.value {
281 Some(value) => Ok(value.into()),
282 None => Ok(ScalarValue::VNull),
283 }
284 } else {
285 Err(CommandError::ServerError(v.err))
286 }
287 }
288 Err(e) => Err(CommandError::DecodeError(e)),
289 };
290 eprintln!("Decoded value: {:?}", decoded);
291
292 decoded
293 }
294}
295
/// Renders a value as a single string argument of a wire command.
trait AsArg {
    /// Returns the argument text for `self`.
    fn as_arg(&self) -> String;
}
299
/// Renders a value as zero or more string arguments of a wire command
/// (e.g. a keyword followed by its operand).
trait AsArgs {
    /// Returns the argument list for `self`, possibly empty.
    fn as_args(&self) -> Vec<String>;
}
303
/// Something that can run a [`Command`] and return its decoded reply.
///
/// NOTE(review): the transport and connection handling are up to the implementor;
/// failures are reported as `StreamError`.
pub(crate) trait CommandExecutor {
    /// Runs `command` and returns a reply expected to be a single [`ScalarValue`].
    fn execute_scalar_command(&mut self, command: Command) -> Result<ScalarValue, StreamError>;
    /// Runs `command` and returns a reply expected to be a string map ([`HSetValue`]).
    fn execute_hset_command(&mut self, command: Command) -> Result<HSetValue, StreamError>;
}
308
/// Optional flag appended to an `EXPIRE` command.
#[derive(Clone, Copy, Debug)]
pub enum ExpireOption {
    /// Appends the `NX` flag.
    NX,
    /// Appends the `XX` flag.
    XX,
    /// Appends no flag.
    None,
}
319
320impl AsArg for ExpireOption {
321 fn as_arg(&self) -> String {
322 match self {
323 ExpireOption::NX => "NX".to_string(),
324 ExpireOption::XX => "XX".to_string(),
325 ExpireOption::None => "".to_string(),
326 }
327 }
328}
329
/// Optional flag appended to an `EXPIREAT` command.
#[derive(Clone, Copy, Debug)]
pub enum ExpireAtOption {
    /// Appends the `NX` flag.
    NX,
    /// Appends the `XX` flag.
    XX,
    /// Appends the `GT` flag.
    GT,
    /// Appends the `LT` flag.
    LT,
    /// Appends no flag.
    None,
}
344
345impl AsArg for ExpireAtOption {
346 fn as_arg(&self) -> String {
347 match self {
348 ExpireAtOption::NX => "NX".to_string(),
349 ExpireAtOption::XX => "XX".to_string(),
350 ExpireAtOption::GT => "GT".to_string(),
351 ExpireAtOption::LT => "LT".to_string(),
352 ExpireAtOption::None => "".to_string(),
353 }
354 }
355}
356
/// Expiry modifier appended to a `GETEX` command.
///
/// Variant order is significant for the derived `Ord`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum GetexOption {
    /// Sent as `EX <seconds>`.
    EX(u64),
    /// Sent as `PX <milliseconds>`.
    PX(u64),
    /// Sent as `EXAT <timestamp>`.
    EXAT(u64),
    /// Sent as `PXAT <timestamp>`.
    PXAT(u64),
    /// Sent as `PERSIST`.
    PERSIST,
}
371
372impl AsArgs for GetexOption {
373 fn as_args(&self) -> Vec<String> {
374 match self {
375 GetexOption::EX(seconds) => vec!["EX".to_string(), seconds.to_string()],
376 GetexOption::PX(milliseconds) => vec!["PX".to_string(), milliseconds.to_string()],
377 GetexOption::EXAT(timestamp) => vec!["EXAT".to_string(), timestamp.to_string()],
378 GetexOption::PXAT(timestamp) => vec!["PXAT".to_string(), timestamp.to_string()],
379 GetexOption::PERSIST => vec!["PERSIST".to_string()],
380 }
381 }
382}
383
/// Modifier appended to a `SET` command.
///
/// Variant order is significant for the derived `Ord`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SetOption {
    /// Sent as `EX <seconds>`.
    EX(u64),
    /// Sent as `PX <milliseconds>`.
    PX(u64),
    /// Sent as `EXAT <timestamp>`.
    EXAT(u64),
    /// Sent as `PXAT <timestamp>`.
    PXAT(u64),
    /// Sent as `XX`.
    XX,
    /// Sent as `NX`.
    NX,
    /// Sent as `KEEPTTL`.
    KEEPTTL,
    /// Adds no arguments.
    None,
}
404
405impl AsArgs for SetOption {
406 fn as_args(&self) -> Vec<String> {
407 match self {
408 SetOption::EX(seconds) => vec!["EX".to_string(), seconds.to_string()],
409 SetOption::PX(milliseconds) => vec!["PX".to_string(), milliseconds.to_string()],
410 SetOption::EXAT(timestamp) => vec!["EXAT".to_string(), timestamp.to_string()],
411 SetOption::PXAT(timestamp) => vec!["PXAT".to_string(), timestamp.to_string()],
412 SetOption::XX => vec!["XX".to_string()],
413 SetOption::NX => vec!["NX".to_string()],
414 SetOption::KEEPTTL => vec!["KEEPTTL".to_string()],
415 SetOption::None => vec![],
416 }
417 }
418}
419
420impl AsArg for SetInput {
421 fn as_arg(&self) -> String {
422 match self {
423 SetInput::Str(s) => s.clone(),
424 SetInput::Int(i) => i.to_string(),
425 SetInput::Float(f) => f.to_string(),
426 }
427 }
428}
429
430impl AsArg for String {
431 fn as_arg(&self) -> String {
432 self.clone()
433 }
434}
435
436impl AsArgs for Vec<(String, SetInput)> {
437 fn as_args(&self) -> Vec<String> {
438 let mut args = vec![];
439 for (field, value) in self {
440 args.push(field.clone());
441 args.push(value.as_arg());
442 }
443 args
444 }
445}
446
/// Mode announced to the server in the `HANDSHAKE` command.
#[derive(Debug)]
pub(crate) enum ExecutionMode {
    /// Rendered as `command`.
    Command,
    /// Rendered as `watch`.
    Watch,
}
452
453impl AsArg for ExecutionMode {
454 fn as_arg(&self) -> String {
455 match self {
456 ExecutionMode::Command => "command".to_string(),
457 ExecutionMode::Watch => "watch".to_string(),
458 }
459 }
460}
461
/// A typed client command.
///
/// Before sending, it is converted into a `wire::Command`: a command name plus a flat
/// list of string arguments.
#[derive(Debug)]
pub(crate) enum Command {
    /// `DECR key`.
    DECR {
        key: String,
    },
    /// `DECRBY key delta`.
    DECRBY {
        key: String,
        delta: i64,
    },
    /// `DEL key [key ...]`; `keys` are sent as-is.
    DEL {
        keys: Vec<String>,
    },
    /// `ECHO message`.
    ECHO {
        message: String,
    },
    /// `EXISTS key [additional_keys ...]`.
    EXISTS {
        key: String,
        additional_keys: Vec<String>,
    },
    /// `EXPIRE key seconds [NX|XX]`; `ExpireOption::None` adds no flag.
    EXPIRE {
        key: String,
        seconds: i64,
        option: ExpireOption,
    },
    /// `EXPIREAT key timestamp [NX|XX|GT|LT]`; `ExpireAtOption::None` adds no flag.
    EXPIREAT {
        key: String,
        timestamp: i64,
        option: ExpireAtOption,
    },
    /// `EXPIRETIME key`.
    EXPIRETIME {
        key: String,
    },
    /// `FLUSHDB` (no arguments).
    FLUSHDB,
    /// `GET key`.
    GET {
        key: String,
    },
    /// `GETDEL key`.
    GETDEL {
        key: String,
    },
    /// `GETEX key <option args>`.
    GETEX {
        key: String,
        ex: GetexOption,
    },
    /// `HSET key field value [field value ...]`.
    HSET {
        key: String,
        fields: Vec<(String, String)>,
    },
    /// `HGET key field`.
    HGET {
        key: String,
        field: String,
    },
    /// `HGETALL key`.
    HGETALL {
        key: String,
    },
    /// Sent on the wire under the command name `GET.WATCH` with `key` as its only argument.
    GETWATCH {
        key: String,
    },
    /// `HANDSHAKE client_id mode`, where mode is `command` or `watch`.
    HANDSHAKE {
        client_id: String,
        execution_mode: ExecutionMode,
    },
    /// `INCR key`.
    INCR {
        key: String,
    },
    /// `INCRBY key delta`.
    INCRBY {
        key: String,
        delta: i64,
    },
    /// `PING` (no arguments).
    PING,
    /// `SET key value <option args> [GET]`; the trailing `GET` is added when `get` is true.
    SET {
        key: String,
        value: SetInput,
        option: SetOption,
        get: bool,
    },
    /// `TTL key`.
    TTL {
        key: String,
    },
    /// `TYPE key`.
    TYPE {
        key: String,
    },
    /// `UNWATCH key`.
    UNWATCH {
        key: String,
    },
}
547
548impl Into<wire::Command> for Command {
549 fn into(self) -> wire::Command {
550 match self {
551 Command::DECR { key } => wire::Command {
552 cmd: "DECR".to_string(),
553 args: vec![key],
554 },
555 Command::DECRBY { key, delta } => wire::Command {
556 cmd: "DECRBY".to_string(),
557 args: vec![key, delta.to_string()],
558 },
559 Command::DEL { keys } => wire::Command {
560 cmd: "DEL".to_string(),
561 args: keys,
562 },
563 Command::ECHO { message } => wire::Command {
564 cmd: "ECHO".to_string(),
565 args: vec![message],
566 },
567 Command::EXISTS {
568 key,
569 additional_keys: keys,
570 } => {
571 let mut args = vec![key];
572 args.extend(keys);
573 wire::Command {
574 cmd: "EXISTS".to_string(),
575 args,
576 }
577 }
578 Command::EXPIRE {
579 key,
580 seconds,
581 option,
582 } => {
583 let mut args = vec![key, seconds.to_string()];
584 match option {
585 ExpireOption::NX => args.push("NX".to_string()),
586 ExpireOption::XX => args.push("XX".to_string()),
587 ExpireOption::None => {}
588 }
589 wire::Command {
590 cmd: "EXPIRE".to_string(),
591 args,
592 }
593 }
594 Command::EXPIREAT {
595 key,
596 timestamp,
597 option,
598 } => {
599 let mut args = vec![key, timestamp.to_string()];
600 match option {
601 ExpireAtOption::None => {}
602 option => args.push(option.as_arg()),
603 }
604 wire::Command {
605 cmd: "EXPIREAT".to_string(),
606 args,
607 }
608 }
609 Command::EXPIRETIME { key } => wire::Command {
610 cmd: "EXPIRETIME".to_string(),
611 args: vec![key],
612 },
613 Command::FLUSHDB => wire::Command {
614 cmd: "FLUSHDB".to_string(),
615 args: vec![],
616 },
617 Command::GET { key } => wire::Command {
618 cmd: "GET".to_string(),
619 args: vec![key],
620 },
621 Command::GETDEL { key } => wire::Command {
622 cmd: "GETDEL".to_string(),
623 args: vec![key],
624 },
625 Command::GETEX { key, ex } => {
626 let mut args = vec![key];
627 args.extend(ex.as_args());
628 wire::Command {
629 cmd: "GETEX".to_string(),
630 args,
631 }
632 }
633 Command::HSET { key, fields } => {
634 let mut args = vec![key];
635 for (field, value) in fields {
636 args.push(field);
637 args.push(value.as_arg());
638 }
639 wire::Command {
640 cmd: "HSET".to_string(),
641 args,
642 }
643 }
644 Command::HGET { key, field } => wire::Command {
645 cmd: "HGET".to_string(),
646 args: vec![key, field],
647 },
648 Command::HGETALL { key } => wire::Command {
649 cmd: "HGETALL".to_string(),
650 args: vec![key],
651 },
652 Command::GETWATCH { key } => wire::Command {
653 cmd: "GET.WATCH".to_string(),
654 args: vec![key],
655 },
656 Command::HANDSHAKE {
657 client_id,
658 execution_mode,
659 } => wire::Command {
660 cmd: "HANDSHAKE".to_string(),
661 args: vec![client_id, execution_mode.as_arg()],
662 },
663 Command::INCR { key } => wire::Command {
664 cmd: "INCR".to_string(),
665 args: vec![key],
666 },
667 Command::INCRBY { key, delta } => wire::Command {
668 cmd: "INCRBY".to_string(),
669 args: vec![key, delta.to_string()],
670 },
671 Command::PING => wire::Command {
672 cmd: "PING".to_string(),
673 args: vec![],
674 },
675 Command::SET {
676 key,
677 value,
678 option,
679 get,
680 } => {
681 let value: ScalarValue = value.into();
682 let mut args = vec![key, value.as_arg()];
683 args.extend(option.as_args());
684 match get {
685 true => args.push("GET".to_string()),
686 false => {}
687 }
688 wire::Command {
689 cmd: "SET".to_string(),
690 args,
691 }
692 }
693 Command::TTL { key } => wire::Command {
694 cmd: "TTL".to_string(),
695 args: vec![key],
696 },
697 Command::TYPE { key } => wire::Command {
698 cmd: "TYPE".to_string(),
699 args: vec![key],
700 },
701 Command::UNWATCH { key } => wire::Command {
702 cmd: "UNWATCH".to_string(),
703 args: vec![key],
704 },
705 }
706 }
707}
708
709impl Command {
710 pub(crate) fn encode(self) -> Vec<u8> {
711 let command: wire::Command = self.into();
712 eprintln!("Sending command: {:?}", command);
713 command.encode_to_vec()
714 }
715}
716
#[cfg(test)]
mod tests {

    use super::*;

    /// `ScalarValue` -> `SetInput` succeeds for str/int/float and fails with a message for bool/null.
    #[test]
    fn test_try_into() {
        let convertible = vec![
            (ScalarValue::VInt(42), SetInput::Int(42)),
            (
                ScalarValue::VStr("42".to_string()),
                SetInput::Str("42".to_string()),
            ),
            (ScalarValue::VFloat(42.0), SetInput::Float(42.0)),
        ];
        for (input, expected) in convertible {
            let converted: SetInput = input.try_into().unwrap();
            assert_eq!(converted, expected);
        }

        let rejected = vec![
            (ScalarValue::VBool(true), "Cannot convert Value::VBool to SetValue"),
            (ScalarValue::VNull, "Cannot convert Value::VNull to SetValue"),
        ];
        for (input, message) in rejected {
            let converted: Result<SetInput, String> = input.try_into();
            assert_eq!(converted, Err(message.to_string()));
        }
    }

    /// Primitive ints, floats and `&str` convert into both `SetInput` and `ScalarValue`.
    #[test]
    fn test_value_can_convert() {
        let int_input: i64 = 42;
        let int_set: SetInput = int_input.into();
        let int_scalar: ScalarValue = int_input.into();
        assert_eq!(int_set, SetInput::Int(42));
        assert_eq!(int_scalar, ScalarValue::VInt(42));

        let float_input: f64 = 42.0;
        let float_set: SetInput = float_input.into();
        let float_scalar: ScalarValue = float_input.into();
        assert_eq!(float_set, SetInput::Float(42.0));
        assert_eq!(float_scalar, ScalarValue::VFloat(42.0));

        let str_input: &str = "42";
        let str_set: SetInput = str_input.into();
        let str_scalar: ScalarValue = str_input.into();
        assert_eq!(str_set, SetInput::Str("42".to_string()));
        assert_eq!(str_scalar, ScalarValue::VStr("42".to_string()));
    }

    /// `Display` renders each variant as expected, with `VNull` as `null`.
    #[test]
    fn test_display_for_value() {
        let cases = [
            (ScalarValue::VInt(1), "1"),
            (ScalarValue::VStr("test".to_string()), "test"),
            (ScalarValue::VNull, "null"),
            (ScalarValue::VFloat(1.2), "1.2"),
            (ScalarValue::VBool(true), "true"),
        ];
        for (value, expected) in cases {
            assert_eq!(format!("{}", value), expected);
        }
    }
}