dicedb_rs/
commands.rs

//! # Commands Module
//! Contains the structures and options used to interact with the server.
//! It defines structures for all the commands, value types and options.
4
5use prost::Message;
6use std::{collections::HashMap, fmt::Display};
7
8use crate::errors::{CommandError, StreamError};
9
10mod wire {
11    tonic::include_proto!("wire");
12}
13
/// A special input type for the DEL operation.
///
/// Lets callers pass either a single key or several keys to delete in one call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DelInput<'a> {
    /// A single key to delete.
    Single(&'a str),
    /// Multiple keys to delete.
    Multiple(Vec<&'a str>),
}
22
/// Convenience input for the HSET operation.
///
/// Callers can hand over either one field/value pair or a whole list of pairs.
#[derive(Clone, Debug, PartialEq)]
pub enum HSetInput<'a> {
    /// Exactly one `(field, value)` pair.
    Single(&'a str, &'a str),
    /// Any number of `(field, value)` pairs.
    Multiple(Vec<(&'a str, &'a str)>),
}
33
/// A payload accepted by the SET operation.
#[derive(Clone, Debug, PartialEq)]
pub enum SetInput {
    /// A text payload.
    Str(String),
    /// A signed 64-bit integer payload.
    Int(i64),
    /// A 64-bit floating point payload.
    Float(f64),
}
44
45impl Into<ScalarValue> for SetInput {
46    fn into(self) -> ScalarValue {
47        match self {
48            SetInput::Str(s) => ScalarValue::VStr(s),
49            SetInput::Int(i) => ScalarValue::VInt(i),
50            SetInput::Float(f) => ScalarValue::VFloat(f),
51        }
52    }
53}
54
55impl TryInto<SetInput> for ScalarValue {
56    type Error = String;
57
58    fn try_into(self) -> Result<SetInput, Self::Error> {
59        match self {
60            ScalarValue::VStr(s) => Ok(SetInput::Str(s)),
61            ScalarValue::VInt(i) => Ok(SetInput::Int(i)),
62            ScalarValue::VFloat(f) => Ok(SetInput::Float(f)),
63            ScalarValue::VBool(_) => Err("Cannot convert Value::VBool to SetValue".to_string()),
64            ScalarValue::VNull => Err("Cannot convert Value::VNull to SetValue".to_string()),
65        }
66    }
67}
68
69macro_rules! impl_vint_setvalue_for_int {
70    ($($t:ty),*) => {
71        $(
72            impl From<$t> for SetInput {
73                fn from(value: $t) -> Self {
74                    SetInput::Int(value as i64)
75                }
76            }
77        )*
78    };
79}
80
81macro_rules! impl_vint_value_for_int {
82    ($($t:ty),*) => {
83        $(
84            impl From<$t> for ScalarValue {
85                fn from(value: $t) -> Self {
86                    ScalarValue::VInt(value as i64)
87                }
88            }
89        )*
90    };
91}
92
93macro_rules! impl_vint_setvalue_for_float {
94    ($($t:ty),*) => {
95        $(
96            impl From<$t> for SetInput {
97                fn from(value: $t) -> Self {
98                    SetInput::Float(value as f64)
99                }
100            }
101        )*
102    };
103}
104
105macro_rules! impl_vint_value_for_float {
106    ($($t:ty),*) => {
107        $(
108            impl From<$t> for ScalarValue {
109                fn from(value: $t) -> Self {
110                    ScalarValue::VFloat(value as f64)
111                }
112            }
113        )*
114    };
115}
116
117impl_vint_setvalue_for_int!(i64, i32, i16, i8, u64, u32, u16, u8);
118impl_vint_value_for_int!(i64, i32, i16, i8, u64, u32, u16, u8);
119impl_vint_setvalue_for_float!(f64, f32);
120impl_vint_value_for_float!(f64, f32);
121
122impl Into<ScalarValue> for &str {
123    fn into(self) -> ScalarValue {
124        ScalarValue::VStr(self.to_string())
125    }
126}
127
128impl Into<SetInput> for &str {
129    fn into(self) -> SetInput {
130        SetInput::Str(self.to_string())
131    }
132}
133
/// A single value returned by the server.
///
/// The derived `PartialOrd` compares variants in declaration order before it compares
/// payloads, so the variant order below is part of the type's behaviour.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub enum ScalarValue {
    /// Text.
    VStr(String),
    /// Signed 64-bit integer.
    VInt(i64),
    /// 64-bit floating point number.
    VFloat(f64),
    /// Boolean.
    VBool(bool),
    /// No value. This does not signal failure; it only means nothing is there.
    VNull,
}
148
149impl Display for ScalarValue {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        match self {
152            ScalarValue::VStr(s) => write!(f, "{}", s),
153            ScalarValue::VInt(i) => write!(f, "{}", i),
154            ScalarValue::VFloat(fl) => write!(f, "{}", fl),
155            ScalarValue::VBool(b) => write!(f, "{}", b),
156            ScalarValue::VNull => write!(f, "null"),
157        }
158    }
159}
160
161impl AsArg for ScalarValue {
162    fn as_arg(&self) -> String {
163        match self {
164            ScalarValue::VStr(s) => s.clone(),
165            ScalarValue::VInt(i) => i.to_string(),
166            ScalarValue::VFloat(f) => f.to_string(),
167            ScalarValue::VBool(b) => b.to_string(),
168            ScalarValue::VNull => "".to_string(),
169        }
170    }
171}
172
173impl Into<ScalarValue> for wire::response::Value {
174    fn into(self) -> ScalarValue {
175        match self {
176            wire::response::Value::VNil(_) => ScalarValue::VNull,
177            wire::response::Value::VInt(i) => ScalarValue::VInt(i),
178            wire::response::Value::VStr(s) => ScalarValue::VStr(s),
179            wire::response::Value::VFloat(f) => ScalarValue::VFloat(f),
180            wire::response::Value::VBytes(b) => {
181                ScalarValue::VStr(String::from_utf8_lossy(&b).to_string())
182            }
183        }
184    }
185}
186
187/// A watch value is a value that originates from a GET.WATCH command.
188#[derive(Debug)]
189pub struct WatchValue {
190    /// The value from the watch session, it indicates a change in a watched key.
191    pub value: ScalarValue,
192    /// The fingerprint of the value, which is a unique identifier for the value.
193    pub fingerprint: String,
194}
195
196impl Into<ScalarValue> for WatchValue {
197    fn into(self) -> ScalarValue {
198        self.value
199    }
200}
201
202impl WatchValue {
203    pub(crate) fn decode_watchvalue(bytes: &[u8]) -> Result<Self, CommandError> {
204        match wire::Response::decode(bytes) {
205            Ok(v) => {
206                if v.err == "" {
207                    let fingerprint = match v
208                        .attrs
209                        .ok_or(CommandError::WatchValueExpectationError(
210                            "Missing attributes from response".to_string(),
211                        ))?
212                        .fields
213                        .get("fingerprint")
214                        .ok_or(CommandError::WatchValueExpectationError(
215                            "Missing fingerprint from attributes".to_string(),
216                        ))?
217                        .kind
218                        .clone()
219                        .ok_or(CommandError::WatchValueExpectationError(
220                            "Missing kind from fingerprint attribute".to_string(),
221                        ))? {
222                        prost_types::value::Kind::StringValue(s) => s,
223                        _ => {
224                            return Err(CommandError::WatchValueExpectationError(
225                                "Fingerprint is not a string".to_string(),
226                            ))
227                        }
228                    };
229                    let value = v
230                        .value
231                        .ok_or(CommandError::WatchValueExpectationError(
232                            "Missing value from response".to_string(),
233                        ))?
234                        .into();
235
236                    Ok(WatchValue { value, fingerprint })
237                } else {
238                    Err(CommandError::ServerError(v.err))
239                }
240            }
241            Err(e) => Err(CommandError::DecodeError(e)),
242        }
243    }
244}
245
/// The field/value map returned by an HGETALL command.
#[derive(Clone, Debug, PartialEq)]
pub struct HSetValue {
    /// Field names mapped to their values.
    pub fields: HashMap<String, String>,
}
252
253impl Into<HashMap<String, String>> for HSetValue {
254    fn into(self) -> HashMap<String, String> {
255        self.fields
256    }
257}
258
259impl HSetValue {
260    pub(crate) fn decode(bytes: &[u8]) -> Result<Self, CommandError> {
261        match wire::Response::decode(bytes) {
262            Ok(v) => {
263                if v.err == "" {
264                    let fields = v.v_ss_map;
265                    Ok(HSetValue { fields })
266                } else {
267                    Err(CommandError::ServerError(v.err))
268                }
269            }
270            Err(e) => Err(CommandError::DecodeError(e)),
271        }
272    }
273}
274
275impl ScalarValue {
276    pub(crate) fn decode(bytes: &[u8]) -> Result<Self, CommandError> {
277        let decoded = match wire::Response::decode(bytes) {
278            Ok(v) => {
279                if v.err == "" {
280                    match v.value {
281                        Some(value) => Ok(value.into()),
282                        None => Ok(ScalarValue::VNull),
283                    }
284                } else {
285                    Err(CommandError::ServerError(v.err))
286                }
287            }
288            Err(e) => Err(CommandError::DecodeError(e)),
289        };
290        eprintln!("Decoded value: {:?}", decoded);
291
292        decoded
293    }
294}
295
/// Types that render as exactly one command argument.
trait AsArg {
    /// Returns `self` as the string placed in a command's argument list.
    fn as_arg(&self) -> String;
}
299
/// Types that render as zero or more command arguments.
trait AsArgs {
    /// Returns the argument strings for `self`, in order.
    fn as_args(&self) -> Vec<String>;
}
303
304pub(crate) trait CommandExecutor {
305    fn execute_scalar_command(&mut self, command: Command) -> Result<ScalarValue, StreamError>;
306    fn execute_hset_command(&mut self, command: Command) -> Result<HSetValue, StreamError>;
307}
308
/// Condition flag for the EXPIRE command.
#[derive(Clone, Copy, Debug)]
pub enum ExpireOption {
    /// `NX`: do not overwrite an existing expiration time.
    NX,
    /// `XX`: only set the expiration time if one already exists.
    XX,
    /// No flag: always set the expiration time.
    None,
}
319
320impl AsArg for ExpireOption {
321    fn as_arg(&self) -> String {
322        match self {
323            ExpireOption::NX => "NX".to_string(),
324            ExpireOption::XX => "XX".to_string(),
325            ExpireOption::None => "".to_string(),
326        }
327    }
328}
329
/// Condition flag for the EXPIREAT command.
#[derive(Clone, Copy, Debug)]
pub enum ExpireAtOption {
    /// `NX`: do not overwrite an existing expiration time.
    NX,
    /// `XX`: only set the expiration time if one already exists.
    XX,
    /// `GT`: only set it if the new time is greater than the existing one.
    GT,
    /// `LT`: only set it if the new time is less than the existing one.
    LT,
    /// No flag: always set the expiration time.
    None,
}
344
345impl AsArg for ExpireAtOption {
346    fn as_arg(&self) -> String {
347        match self {
348            ExpireAtOption::NX => "NX".to_string(),
349            ExpireAtOption::XX => "XX".to_string(),
350            ExpireAtOption::GT => "GT".to_string(),
351            ExpireAtOption::LT => "LT".to_string(),
352            ExpireAtOption::None => "".to_string(),
353        }
354    }
355}
356
/// Expiration behaviour for the GETEX command.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum GetexOption {
    /// `EX`: expire this many seconds from now.
    EX(u64),
    /// `PX`: expire this many milliseconds from now.
    PX(u64),
    /// `EXAT`: expire at this Unix timestamp.
    EXAT(u64),
    /// `PXAT`: expire at this Unix timestamp, in milliseconds.
    PXAT(u64),
    /// `PERSIST`: remove the key's expiration.
    PERSIST,
}
371
372impl AsArgs for GetexOption {
373    fn as_args(&self) -> Vec<String> {
374        match self {
375            GetexOption::EX(seconds) => vec!["EX".to_string(), seconds.to_string()],
376            GetexOption::PX(milliseconds) => vec!["PX".to_string(), milliseconds.to_string()],
377            GetexOption::EXAT(timestamp) => vec!["EXAT".to_string(), timestamp.to_string()],
378            GetexOption::PXAT(timestamp) => vec!["PXAT".to_string(), timestamp.to_string()],
379            GetexOption::PERSIST => vec!["PERSIST".to_string()],
380        }
381    }
382}
383
/// Modifier for the SET command.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SetOption {
    /// `EX`: expiration time in seconds.
    EX(u64),
    /// `PX`: expiration time in milliseconds.
    PX(u64),
    /// `EXAT`: expiration time in seconds since the epoch.
    EXAT(u64),
    /// `PXAT`: expiration time in milliseconds since the epoch.
    PXAT(u64),
    /// `XX`: only set the key if it already exists.
    XX,
    /// `NX`: only set the key if it does not exist yet.
    NX,
    /// `KEEPTTL`: keep the key's existing TTL.
    KEEPTTL,
    /// No special option (the default choice); contributes no arguments.
    None,
}
404
405impl AsArgs for SetOption {
406    fn as_args(&self) -> Vec<String> {
407        match self {
408            SetOption::EX(seconds) => vec!["EX".to_string(), seconds.to_string()],
409            SetOption::PX(milliseconds) => vec!["PX".to_string(), milliseconds.to_string()],
410            SetOption::EXAT(timestamp) => vec!["EXAT".to_string(), timestamp.to_string()],
411            SetOption::PXAT(timestamp) => vec!["PXAT".to_string(), timestamp.to_string()],
412            SetOption::XX => vec!["XX".to_string()],
413            SetOption::NX => vec!["NX".to_string()],
414            SetOption::KEEPTTL => vec!["KEEPTTL".to_string()],
415            SetOption::None => vec![],
416        }
417    }
418}
419
420impl AsArg for SetInput {
421    fn as_arg(&self) -> String {
422        match self {
423            SetInput::Str(s) => s.clone(),
424            SetInput::Int(i) => i.to_string(),
425            SetInput::Float(f) => f.to_string(),
426        }
427    }
428}
429
430impl AsArg for String {
431    fn as_arg(&self) -> String {
432        self.clone()
433    }
434}
435
436impl AsArgs for Vec<(String, SetInput)> {
437    fn as_args(&self) -> Vec<String> {
438        let mut args = vec![];
439        for (field, value) in self {
440            args.push(field.clone());
441            args.push(value.as_arg());
442        }
443        args
444    }
445}
446
/// The mode a client declares in its HANDSHAKE command.
#[derive(Debug)]
pub(crate) enum ExecutionMode {
    /// Rendered as `command`.
    Command,
    /// Rendered as `watch`.
    Watch,
}
452
453impl AsArg for ExecutionMode {
454    fn as_arg(&self) -> String {
455        match self {
456            ExecutionMode::Command => "command".to_string(),
457            ExecutionMode::Watch => "watch".to_string(),
458        }
459    }
460}
461
462#[derive(Debug)]
463pub(crate) enum Command {
464    DECR {
465        key: String,
466    },
467    DECRBY {
468        key: String,
469        delta: i64,
470    },
471    DEL {
472        keys: Vec<String>,
473    },
474    ECHO {
475        message: String,
476    },
477    EXISTS {
478        key: String,
479        additional_keys: Vec<String>,
480    },
481    EXPIRE {
482        key: String,
483        seconds: i64,
484        option: ExpireOption,
485    },
486    EXPIREAT {
487        key: String,
488        timestamp: i64,
489        option: ExpireAtOption,
490    },
491    EXPIRETIME {
492        key: String,
493    },
494    FLUSHDB,
495    GET {
496        key: String,
497    },
498    GETDEL {
499        key: String,
500    },
501    GETEX {
502        key: String,
503        ex: GetexOption,
504    },
505    HSET {
506        key: String,
507        fields: Vec<(String, String)>,
508    },
509    HGET {
510        key: String,
511        field: String,
512    },
513    HGETALL {
514        key: String,
515    },
516    GETWATCH {
517        key: String,
518    },
519    HANDSHAKE {
520        client_id: String,
521        execution_mode: ExecutionMode,
522    },
523    INCR {
524        key: String,
525    },
526    INCRBY {
527        key: String,
528        delta: i64,
529    },
530    PING,
531    SET {
532        key: String,
533        value: SetInput,
534        option: SetOption,
535        get: bool,
536    },
537    TTL {
538        key: String,
539    },
540    TYPE {
541        key: String,
542    },
543    UNWATCH {
544        key: String,
545    },
546}
547
548impl Into<wire::Command> for Command {
549    fn into(self) -> wire::Command {
550        match self {
551            Command::DECR { key } => wire::Command {
552                cmd: "DECR".to_string(),
553                args: vec![key],
554            },
555            Command::DECRBY { key, delta } => wire::Command {
556                cmd: "DECRBY".to_string(),
557                args: vec![key, delta.to_string()],
558            },
559            Command::DEL { keys } => wire::Command {
560                cmd: "DEL".to_string(),
561                args: keys,
562            },
563            Command::ECHO { message } => wire::Command {
564                cmd: "ECHO".to_string(),
565                args: vec![message],
566            },
567            Command::EXISTS {
568                key,
569                additional_keys: keys,
570            } => {
571                let mut args = vec![key];
572                args.extend(keys);
573                wire::Command {
574                    cmd: "EXISTS".to_string(),
575                    args,
576                }
577            }
578            Command::EXPIRE {
579                key,
580                seconds,
581                option,
582            } => {
583                let mut args = vec![key, seconds.to_string()];
584                match option {
585                    ExpireOption::NX => args.push("NX".to_string()),
586                    ExpireOption::XX => args.push("XX".to_string()),
587                    ExpireOption::None => {}
588                }
589                wire::Command {
590                    cmd: "EXPIRE".to_string(),
591                    args,
592                }
593            }
594            Command::EXPIREAT {
595                key,
596                timestamp,
597                option,
598            } => {
599                let mut args = vec![key, timestamp.to_string()];
600                match option {
601                    ExpireAtOption::None => {}
602                    option => args.push(option.as_arg()),
603                }
604                wire::Command {
605                    cmd: "EXPIREAT".to_string(),
606                    args,
607                }
608            }
609            Command::EXPIRETIME { key } => wire::Command {
610                cmd: "EXPIRETIME".to_string(),
611                args: vec![key],
612            },
613            Command::FLUSHDB => wire::Command {
614                cmd: "FLUSHDB".to_string(),
615                args: vec![],
616            },
617            Command::GET { key } => wire::Command {
618                cmd: "GET".to_string(),
619                args: vec![key],
620            },
621            Command::GETDEL { key } => wire::Command {
622                cmd: "GETDEL".to_string(),
623                args: vec![key],
624            },
625            Command::GETEX { key, ex } => {
626                let mut args = vec![key];
627                args.extend(ex.as_args());
628                wire::Command {
629                    cmd: "GETEX".to_string(),
630                    args,
631                }
632            }
633            Command::HSET { key, fields } => {
634                let mut args = vec![key];
635                for (field, value) in fields {
636                    args.push(field);
637                    args.push(value.as_arg());
638                }
639                wire::Command {
640                    cmd: "HSET".to_string(),
641                    args,
642                }
643            }
644            Command::HGET { key, field } => wire::Command {
645                cmd: "HGET".to_string(),
646                args: vec![key, field],
647            },
648            Command::HGETALL { key } => wire::Command {
649                cmd: "HGETALL".to_string(),
650                args: vec![key],
651            },
652            Command::GETWATCH { key } => wire::Command {
653                cmd: "GET.WATCH".to_string(),
654                args: vec![key],
655            },
656            Command::HANDSHAKE {
657                client_id,
658                execution_mode,
659            } => wire::Command {
660                cmd: "HANDSHAKE".to_string(),
661                args: vec![client_id, execution_mode.as_arg()],
662            },
663            Command::INCR { key } => wire::Command {
664                cmd: "INCR".to_string(),
665                args: vec![key],
666            },
667            Command::INCRBY { key, delta } => wire::Command {
668                cmd: "INCRBY".to_string(),
669                args: vec![key, delta.to_string()],
670            },
671            Command::PING => wire::Command {
672                cmd: "PING".to_string(),
673                args: vec![],
674            },
675            Command::SET {
676                key,
677                value,
678                option,
679                get,
680            } => {
681                let value: ScalarValue = value.into();
682                let mut args = vec![key, value.as_arg()];
683                args.extend(option.as_args());
684                match get {
685                    true => args.push("GET".to_string()),
686                    false => {}
687                }
688                wire::Command {
689                    cmd: "SET".to_string(),
690                    args,
691                }
692            }
693            Command::TTL { key } => wire::Command {
694                cmd: "TTL".to_string(),
695                args: vec![key],
696            },
697            Command::TYPE { key } => wire::Command {
698                cmd: "TYPE".to_string(),
699                args: vec![key],
700            },
701            Command::UNWATCH { key } => wire::Command {
702                cmd: "UNWATCH".to_string(),
703                args: vec![key],
704            },
705        }
706    }
707}
708
709impl Command {
710    pub(crate) fn encode(self) -> Vec<u8> {
711        let command: wire::Command = self.into();
712        eprintln!("Sending command: {:?}", command);
713        command.encode_to_vec()
714    }
715}
716
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_try_into() {
        let v: ScalarValue = ScalarValue::VInt(42);
        let v_setval: SetInput = v.try_into().unwrap();
        assert_eq!(v_setval, SetInput::Int(42));
        let v: ScalarValue = ScalarValue::VStr("42".to_string());
        let v_setval: SetInput = v.try_into().unwrap();
        assert_eq!(v_setval, SetInput::Str("42".to_string()));
        let v: ScalarValue = ScalarValue::VFloat(42.0);
        let v_setval: SetInput = v.try_into().unwrap();
        assert_eq!(v_setval, SetInput::Float(42.0));
        let v: ScalarValue = ScalarValue::VBool(true);
        let v_setval: Result<SetInput, String> = v.try_into();
        assert_eq!(
            v_setval,
            Err("Cannot convert Value::VBool to SetValue".to_string())
        );
        let v: ScalarValue = ScalarValue::VNull;
        let v_setval: Result<SetInput, String> = v.try_into();
        assert_eq!(
            v_setval,
            Err("Cannot convert Value::VNull to SetValue".to_string())
        );
    }

    #[test]
    fn test_value_can_convert() {
        let v: i64 = 42;
        let v_setval: SetInput = v.into();
        let v_value: ScalarValue = v.into();
        assert_eq!(v_setval, SetInput::Int(42));
        assert_eq!(v_value, ScalarValue::VInt(42));

        let v_f: f64 = 42.0;
        let v_setval: SetInput = v_f.into();
        let v_value: ScalarValue = v_f.into();
        assert_eq!(v_setval, SetInput::Float(42.0));
        assert_eq!(v_value, ScalarValue::VFloat(42.0));

        let v_str: &str = "42";
        let v_setval: SetInput = v_str.into();
        let v_value: ScalarValue = v_str.into();
        assert_eq!(v_setval, SetInput::Str("42".to_string()));
        assert_eq!(v_value, ScalarValue::VStr("42".to_string()));
    }

    #[test]
    fn test_display_for_value() {
        let value = ScalarValue::VInt(1);
        assert_eq!(format!("{}", value), "1");
        let value = ScalarValue::VStr("test".to_string());
        assert_eq!(format!("{}", value), "test");
        let value = ScalarValue::VNull;
        assert_eq!(format!("{}", value), "null");
        let value = ScalarValue::VFloat(1.2);
        assert_eq!(format!("{}", value), "1.2");
        let value = ScalarValue::VBool(true);
        assert_eq!(format!("{}", value), "true");
    }

    #[test]
    fn test_as_arg_rendering() {
        assert_eq!(ScalarValue::VNull.as_arg(), "");
        assert_eq!(ScalarValue::VBool(false).as_arg(), "false");
        assert_eq!(SetInput::Float(2.5).as_arg(), "2.5");
        assert_eq!(ExpireOption::NX.as_arg(), "NX");
        assert_eq!(ExpireOption::None.as_arg(), "");
        assert_eq!(ExpireAtOption::GT.as_arg(), "GT");
        assert_eq!(ExecutionMode::Watch.as_arg(), "watch");
    }

    #[test]
    fn test_option_as_args() {
        assert_eq!(SetOption::EX(10).as_args(), vec!["EX", "10"]);
        assert_eq!(SetOption::KEEPTTL.as_args(), vec!["KEEPTTL"]);
        assert!(SetOption::None.as_args().is_empty());
        assert_eq!(GetexOption::PERSIST.as_args(), vec!["PERSIST"]);
        assert_eq!(GetexOption::PXAT(5).as_args(), vec!["PXAT", "5"]);
    }

    #[test]
    fn test_command_into_wire() {
        let set: wire::Command = Command::SET {
            key: "k".to_string(),
            value: SetInput::Int(5),
            option: SetOption::EX(60),
            get: true,
        }
        .into();
        assert_eq!(set.cmd, "SET");
        assert_eq!(set.args, vec!["k", "5", "EX", "60", "GET"]);

        let expire: wire::Command = Command::EXPIRE {
            key: "k".to_string(),
            seconds: 10,
            option: ExpireOption::None,
        }
        .into();
        assert_eq!(expire.cmd, "EXPIRE");
        assert_eq!(expire.args, vec!["k", "10"]);

        let expire_at: wire::Command = Command::EXPIREAT {
            key: "k".to_string(),
            timestamp: 99,
            option: ExpireAtOption::LT,
        }
        .into();
        assert_eq!(expire_at.args, vec!["k", "99", "LT"]);

        let watch: wire::Command = Command::GETWATCH {
            key: "k".to_string(),
        }
        .into();
        assert_eq!(watch.cmd, "GET.WATCH");
        assert_eq!(watch.args, vec!["k"]);
    }
}