redis_lua/
types.rs

1use redis::{RedisWrite, ToRedisArgs};
2use rmp::encode;
3use serde::{de, ser, Serialize};
4use std::{
5    fmt::{self, Display},
6    io::{self, Write},
7};
8
9trait RedisArgWrite: RedisWrite {
10    fn pack(&mut self);
11}
12
13#[doc(hidden)]
14pub struct ScriptArg {
15    buf: Vec<u8>,
16    pack: bool,
17}
18
19impl ScriptArg {
20    fn new() -> Self {
21        Self {
22            buf: Vec::with_capacity(128),
23            pack: false,
24        }
25    }
26
27    pub fn pack(&self) -> bool {
28        self.pack
29    }
30}
31
32impl RedisWrite for ScriptArg {
33    fn write_arg(&mut self, arg: &[u8]) {
34        self.buf.extend(arg);
35    }
36}
37
38impl RedisArgWrite for ScriptArg {
39    fn pack(&mut self) {
40        self.pack = true;
41    }
42}
43
44pub fn script_arg<T: Serialize + ?Sized>(value: &T) -> ScriptArg {
45    let mut arg = ScriptArg::new();
46    let mut ser = Serializer::new(&mut arg);
47    value.serialize(&mut ser).expect("Couldn't serialize");
48    arg
49}
50
51impl ToRedisArgs for ScriptArg {
52    fn write_redis_args<W: ?Sized>(&self, out: &mut W)
53    where
54        W: RedisWrite,
55    {
56        self.buf.write_redis_args(out);
57    }
58}
59
60type Result<T> = std::result::Result<T, Error>;
61
62#[derive(Debug)]
63struct Error(String);
64
65impl ser::Error for Error {
66    fn custom<T: Display>(msg: T) -> Self {
67        Error(msg.to_string())
68    }
69}
70
71impl de::Error for Error {
72    fn custom<T: Display>(msg: T) -> Self {
73        Error(msg.to_string())
74    }
75}
76
77impl Display for Error {
78    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
79        write!(f, "{}", self.0)
80    }
81}
82
83impl std::error::Error for Error {}
84
85impl From<rmp::encode::Error> for Error {
86    fn from(e: rmp::encode::Error) -> Self {
87        Self(e.to_string())
88    }
89}
90
91impl From<rmp::encode::ValueWriteError> for Error {
92    fn from(e: rmp::encode::ValueWriteError) -> Self {
93        Self(e.to_string())
94    }
95}
96
97struct Serializer<'a, W: ?Sized>(&'a mut W);
98
99impl<'a, W> Serializer<'a, W>
100where
101    W: RedisArgWrite + ?Sized,
102{
103    fn new(w: &'a mut W) -> Self {
104        Self(w)
105    }
106
107    fn write_null(&mut self) {
108        Vec::<u8>::new().write_redis_args(self.0);
109    }
110}
111
112impl<'a, 'b, W> ser::Serializer for &'a mut Serializer<'b, W>
113where
114    W: RedisArgWrite + ?Sized,
115{
116    type Ok = ();
117    type Error = Error;
118
119    type SerializeSeq = Compound<Arg<'a, W>, Seq>;
120    type SerializeTuple = Compound<Arg<'a, W>, Seq>;
121    type SerializeTupleStruct = Compound<Arg<'a, W>, Seq>;
122    type SerializeTupleVariant = Compound<Arg<'a, W>, Seq>;
123    type SerializeMap = Compound<Arg<'a, W>, Map>;
124    type SerializeStruct = Compound<Arg<'a, W>, Map>;
125    type SerializeStructVariant = Compound<Arg<'a, W>, Map>;
126
127    fn serialize_bool(self, v: bool) -> Result<()> {
128        Ok((v as usize).write_redis_args(self.0))
129    }
130
131    fn serialize_i8(self, v: i8) -> Result<()> {
132        Ok(v.write_redis_args(self.0))
133    }
134
135    fn serialize_i16(self, v: i16) -> Result<()> {
136        Ok(v.write_redis_args(self.0))
137    }
138
139    fn serialize_i32(self, v: i32) -> Result<()> {
140        Ok(v.write_redis_args(self.0))
141    }
142
143    fn serialize_i64(self, v: i64) -> Result<()> {
144        Ok(v.write_redis_args(self.0))
145    }
146
147    fn serialize_u8(self, v: u8) -> Result<()> {
148        Ok(v.write_redis_args(self.0))
149    }
150
151    fn serialize_u16(self, v: u16) -> Result<()> {
152        Ok(v.write_redis_args(self.0))
153    }
154
155    fn serialize_u32(self, v: u32) -> Result<()> {
156        Ok(v.write_redis_args(self.0))
157    }
158
159    fn serialize_u64(self, v: u64) -> Result<()> {
160        Ok(v.write_redis_args(self.0))
161    }
162
163    fn serialize_f32(self, v: f32) -> Result<()> {
164        Ok(v.write_redis_args(self.0))
165    }
166
167    fn serialize_f64(self, v: f64) -> Result<()> {
168        Ok(v.write_redis_args(self.0))
169    }
170
171    fn serialize_char(self, v: char) -> Result<()> {
172        let mut buf = [0; 4];
173        let len = v.encode_utf8(&mut buf).len();
174        Ok((&buf[..len]).write_redis_args(self.0))
175    }
176
177    fn serialize_str(self, v: &str) -> Result<()> {
178        Ok(v.write_redis_args(self.0))
179    }
180
181    fn serialize_bytes(self, v: &[u8]) -> Result<()> {
182        Ok(v.write_redis_args(self.0))
183    }
184
185    fn serialize_none(self) -> Result<()> {
186        Ok(self.write_null())
187    }
188
189    fn serialize_some<T>(self, value: &T) -> Result<()>
190    where
191        T: ?Sized + Serialize,
192    {
193        value.serialize(self)
194    }
195
196    fn serialize_unit(self) -> Result<()> {
197        Ok(self.write_null())
198    }
199
200    fn serialize_unit_struct(self, _name: &'static str) -> Result<()> {
201        Ok(self.write_null())
202    }
203
204    fn serialize_unit_variant(
205        self,
206        _name: &'static str,
207        _variant_index: u32,
208        variant: &'static str,
209    ) -> Result<()> {
210        variant.serialize(self)
211    }
212
213    fn serialize_newtype_struct<T>(self, _name: &'static str, value: &T) -> Result<()>
214    where
215        T: ?Sized + Serialize,
216    {
217        value.serialize(self)
218    }
219
220    fn serialize_newtype_variant<T>(
221        self,
222        _name: &'static str,
223        _variant_index: u32,
224        _variant: &'static str,
225        value: &T,
226    ) -> Result<()>
227    where
228        T: ?Sized + Serialize,
229    {
230        value.serialize(self)
231    }
232
233    fn serialize_seq(self, _: Option<usize>) -> Result<Self::SerializeSeq> {
234        Ok(Compound::new(Arg(self.0), Seq::new()))
235    }
236
237    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple> {
238        self.serialize_seq(Some(len))
239    }
240
241    fn serialize_tuple_struct(
242        self,
243        _name: &'static str,
244        len: usize,
245    ) -> Result<Self::SerializeTupleStruct> {
246        self.serialize_seq(Some(len))
247    }
248
249    fn serialize_tuple_variant(
250        self,
251        _name: &'static str,
252        _variant_index: u32,
253        _variant: &'static str,
254        len: usize,
255    ) -> Result<Self::SerializeTupleVariant> {
256        self.serialize_seq(Some(len))
257    }
258
259    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
260        Ok(Compound::new(Arg(self.0), Map::new()))
261    }
262
263    fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
264        self.serialize_map(Some(len))
265    }
266
267    fn serialize_struct_variant(
268        self,
269        _name: &'static str,
270        _variant_index: u32,
271        _variant: &'static str,
272        len: usize,
273    ) -> Result<Self::SerializeStructVariant> {
274        self.serialize_map(Some(len))
275    }
276}
277
278/// Complex serializer
279struct ComplexSerializer<W>(W, Option<u8>);
280
281impl<W> ComplexSerializer<W>
282where
283    W: io::Write,
284{
285    fn new(w: W) -> Self {
286        Self(w, None)
287    }
288
289    /// Write a binary-safe string
290    fn write_str(&mut self, s: &[u8]) -> Result<()> {
291        encode::write_str_len(&mut self.0, s.len() as u32)?;
292        Ok(self.0.write_all(s)?)
293    }
294
295    fn single_byte(&self) -> Option<u8> {
296        self.1
297    }
298}
299
300impl<'a, W> ser::Serializer for &'a mut ComplexSerializer<W>
301where
302    W: io::Write,
303{
304    type Ok = ();
305    type Error = Error;
306
307    type SerializeSeq = Compound<Buf<'a, W>, Seq>;
308    type SerializeTuple = Compound<Buf<'a, W>, Seq>;
309    type SerializeTupleStruct = Compound<Buf<'a, W>, Seq>;
310    type SerializeTupleVariant = Compound<Buf<'a, W>, Seq>;
311    type SerializeMap = Compound<Buf<'a, W>, Map>;
312    type SerializeStruct = Compound<Buf<'a, W>, Map>;
313    type SerializeStructVariant = Compound<Buf<'a, W>, Map>;
314
315    fn serialize_bool(self, v: bool) -> Result<()> {
316        Ok(encode::write_bool(&mut self.0, v)?)
317    }
318
319    fn serialize_i8(self, v: i8) -> Result<()> {
320        self.serialize_i64(v as i64)
321    }
322
323    fn serialize_i16(self, v: i16) -> Result<()> {
324        self.serialize_i64(v as i64)
325    }
326
327    fn serialize_i32(self, v: i32) -> Result<()> {
328        self.serialize_i64(v as i64)
329    }
330
331    fn serialize_i64(self, v: i64) -> Result<()> {
332        encode::write_sint(&mut self.0, v)?;
333        Ok(())
334    }
335
336    fn serialize_u8(self, v: u8) -> Result<()> {
337        self.1 = Some(v);
338        encode::write_uint(&mut self.0, v as u64)?;
339        Ok(())
340    }
341
342    fn serialize_u16(self, v: u16) -> Result<()> {
343        self.serialize_u64(v as u64)
344    }
345
346    fn serialize_u32(self, v: u32) -> Result<()> {
347        self.serialize_u64(v as u64)
348    }
349
350    fn serialize_u64(self, v: u64) -> Result<()> {
351        encode::write_uint(&mut self.0, v)?;
352        Ok(())
353    }
354
355    fn serialize_f32(self, v: f32) -> Result<()> {
356        encode::write_f32(&mut self.0, v)?;
357        Ok(())
358    }
359
360    fn serialize_f64(self, v: f64) -> Result<()> {
361        encode::write_f64(&mut self.0, v)?;
362        Ok(())
363    }
364
365    fn serialize_char(self, v: char) -> Result<()> {
366        let mut buf = [0; 4];
367        self.serialize_str(v.encode_utf8(&mut buf))
368    }
369
370    fn serialize_str(self, v: &str) -> Result<()> {
371        self.write_str(v.as_bytes())
372    }
373
374    fn serialize_bytes(self, v: &[u8]) -> Result<()> {
375        self.write_str(v)
376    }
377
378    fn serialize_none(self) -> Result<()> {
379        self.serialize_unit()
380    }
381
382    fn serialize_some<T>(self, value: &T) -> Result<()>
383    where
384        T: ?Sized + Serialize,
385    {
386        value.serialize(self)
387    }
388
389    fn serialize_unit(self) -> Result<()> {
390        encode::write_nil(&mut self.0)?;
391        Ok(())
392    }
393
394    fn serialize_unit_struct(self, _name: &'static str) -> Result<()> {
395        encode::write_array_len(&mut self.0, 0)?;
396        Ok(())
397    }
398
399    fn serialize_unit_variant(
400        self,
401        _name: &'static str,
402        variant_index: u32,
403        _variant: &'static str,
404    ) -> Result<()> {
405        variant_index.serialize(self)
406    }
407
408    fn serialize_newtype_struct<T>(self, _name: &'static str, value: &T) -> Result<()>
409    where
410        T: ?Sized + Serialize,
411    {
412        value.serialize(self)
413    }
414
415    fn serialize_newtype_variant<T>(
416        self,
417        _name: &'static str,
418        _variant_index: u32,
419        _variant: &'static str,
420        value: &T,
421    ) -> Result<()>
422    where
423        T: ?Sized + Serialize,
424    {
425        value.serialize(self)
426    }
427
428    fn serialize_seq(self, _: Option<usize>) -> Result<Self::SerializeSeq> {
429        Ok(Compound::new(Buf(&mut self.0), Seq::new()))
430    }
431
432    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple> {
433        self.serialize_seq(Some(len))
434    }
435
436    fn serialize_tuple_struct(
437        self,
438        _name: &'static str,
439        len: usize,
440    ) -> Result<Self::SerializeTupleStruct> {
441        self.serialize_seq(Some(len))
442    }
443
444    fn serialize_tuple_variant(
445        self,
446        _name: &'static str,
447        _variant_index: u32,
448        _variant: &'static str,
449        len: usize,
450    ) -> Result<Self::SerializeTupleVariant> {
451        self.serialize_seq(Some(len))
452    }
453
454    fn serialize_map(self, _: Option<usize>) -> Result<Self::SerializeMap> {
455        Ok(Compound::new(Buf(&mut self.0), Map::new()))
456    }
457
458    fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
459        self.serialize_map(Some(len))
460    }
461
462    fn serialize_struct_variant(
463        self,
464        _name: &'static str,
465        _variant_index: u32,
466        _variant: &'static str,
467        len: usize,
468    ) -> Result<Self::SerializeStructVariant> {
469        self.serialize_map(Some(len))
470    }
471}
472
473trait CompoundWrite {
474    fn write_raw(&mut self, buf: &[u8]) -> Result<()>;
475
476    fn write_packed(&mut self, buf: &[u8]) -> Result<()>;
477
478    fn is_arg(&self) -> bool;
479}
480
481struct Arg<'a, W: ?Sized>(&'a mut W);
482
483struct Buf<'a, W>(&'a mut W);
484
485impl<'a, W> CompoundWrite for Arg<'a, W>
486where
487    W: RedisArgWrite + ?Sized,
488{
489    fn write_raw(&mut self, buf: &[u8]) -> Result<()> {
490        buf.write_redis_args(self.0);
491        Ok(())
492    }
493
494    fn write_packed(&mut self, buf: &[u8]) -> Result<()> {
495        self.0.pack();
496        buf.write_redis_args(self.0);
497        Ok(())
498    }
499
500    fn is_arg(&self) -> bool {
501        true
502    }
503}
504
505impl<'a, W> CompoundWrite for Buf<'a, W>
506where
507    W: io::Write,
508{
509    fn write_raw(&mut self, buf: &[u8]) -> Result<()> {
510        self.0.write_all(buf)?;
511        Ok(())
512    }
513
514    fn write_packed(&mut self, buf: &[u8]) -> Result<()> {
515        self.write_raw(buf)
516    }
517
518    fn is_arg(&self) -> bool {
519        false
520    }
521}
522
/// Behaviour that differs between sequence-like and map-like compounds.
trait CompoundType {
    /// Whether the compound is encoded as a MessagePack map (`true`) or array (`false`).
    fn is_map() -> bool;

    /// Record the single `u8` an element serialized to, or `None` if the element
    /// was anything else. Only sequences care; the default ignores it.
    fn add_byte(&mut self, _byte: Option<u8>) {}

    /// The collected bytes if every element so far was a `u8`, otherwise `None`.
    fn bytearray(&self) -> Option<&[u8]> {
        None
    }
}

/// Marker for map-like compounds (maps, structs, struct variants).
struct Map;

impl Map {
    fn new() -> Self {
        Self
    }
}

/// State for sequence-like compounds (seqs, tuples, tuple structs/variants).
///
/// Tracks whether every element is a `u8`, so that a byte sequence can be
/// emitted as a string instead of an array of integers.
struct Seq {
    // Bytes collected so far; only meaningful while `is_bytearray` is true.
    bytearray: Vec<u8>,
    // Starts true; cleared by the first non-`u8` element and never set again.
    is_bytearray: bool,
}

impl Seq {
    fn new() -> Self {
        Self {
            bytearray: vec![],
            is_bytearray: true,
        }
    }
}

impl CompoundType for Map {
    fn is_map() -> bool {
        true
    }
}

impl CompoundType for Seq {
    fn is_map() -> bool {
        false
    }

    fn add_byte(&mut self, byte: Option<u8>) {
        if !self.is_bytearray {
            // Already disqualified: nothing collected from here on can be used.
            return;
        }
        match byte {
            Some(b) => self.bytearray.push(b),
            None => {
                // The sequence is not a byte array after all; drop the bytes
                // gathered so far since `bytearray()` will never return them.
                self.is_bytearray = false;
                self.bytearray = Vec::new();
            }
        }
    }

    fn bytearray(&self) -> Option<&[u8]> {
        if self.is_bytearray {
            Some(self.bytearray.as_slice())
        } else {
            None
        }
    }
}
581
582struct Compound<W, C> {
583    buf: Vec<u8>,
584    len: usize,
585    wr: W,
586    inner: C,
587}
588
589impl<W, C> Compound<W, C>
590where
591    W: CompoundWrite,
592    C: CompoundType,
593{
594    fn new(wr: W, inner: C) -> Self {
595        Self {
596            buf: vec![],
597            len: 0,
598            wr,
599            inner,
600        }
601    }
602
603    fn add<T: ?Sized>(&mut self, value: &T) -> Result<()>
604    where
605        T: Serialize,
606    {
607        self.len += 1;
608
609        self.inner.add_byte({
610            let mut ser = ComplexSerializer::new(&mut self.buf);
611            value.serialize(&mut ser)?;
612            ser.single_byte()
613        });
614
615        Ok(())
616    }
617
618    fn end(mut self) -> Result<()> {
619        let mut v = Vec::new();
620
621        if C::is_map() {
622            // Here divide the map length by 2
623            // because `add` is called twice per a key/value pair
624            encode::write_map_len(&mut v, self.len as u32 / 2)?;
625            v.write_all(&self.buf)?;
626            self.wr.write_packed(&v)?;
627        } else {
628            if let Some(bytearray) = self.inner.bytearray() {
629                if self.len > 0 {
630                    // Non-empty u8 sequence becomes a string
631                    if self.wr.is_arg() {
632                        self.wr.write_raw(bytearray)?;
633                    } else {
634                        encode::write_str_len(&mut v, self.len as u32)?;
635                        v.write_all(bytearray)?;
636                        self.wr.write_packed(&v)?;
637                    }
638                    return Ok(());
639                }
640            }
641            encode::write_array_len(&mut v, self.len as u32)?;
642            v.write_all(&self.buf)?;
643            self.wr.write_packed(&v)?;
644        }
645
646        Ok(())
647    }
648}
649
650impl<W, C> ser::SerializeSeq for Compound<W, C>
651where
652    W: CompoundWrite,
653    C: CompoundType,
654{
655    type Ok = ();
656    type Error = Error;
657
658    fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<()>
659    where
660        T: Serialize,
661    {
662        self.add(value)
663    }
664
665    fn end(self) -> Result<Self::Ok> {
666        self.end()
667    }
668}
669
670impl<W, C> ser::SerializeTuple for Compound<W, C>
671where
672    W: CompoundWrite,
673    C: CompoundType,
674{
675    type Ok = ();
676    type Error = Error;
677
678    fn serialize_element<T>(&mut self, value: &T) -> Result<()>
679    where
680        T: ?Sized + Serialize,
681    {
682        self.add(value)
683    }
684
685    fn end(self) -> Result<()> {
686        self.end()
687    }
688}
689
690impl<W, C> ser::SerializeTupleStruct for Compound<W, C>
691where
692    W: CompoundWrite,
693    C: CompoundType,
694{
695    type Ok = ();
696    type Error = Error;
697
698    fn serialize_field<T>(&mut self, value: &T) -> Result<()>
699    where
700        T: ?Sized + Serialize,
701    {
702        self.add(value)
703    }
704
705    fn end(self) -> Result<()> {
706        self.end()
707    }
708}
709
710impl<W, C> ser::SerializeTupleVariant for Compound<W, C>
711where
712    W: CompoundWrite,
713    C: CompoundType,
714{
715    type Ok = ();
716    type Error = Error;
717
718    fn serialize_field<T>(&mut self, value: &T) -> Result<()>
719    where
720        T: ?Sized + Serialize,
721    {
722        self.add(value)
723    }
724
725    fn end(self) -> Result<()> {
726        self.end()
727    }
728}
729
730impl<W, C> ser::SerializeMap for Compound<W, C>
731where
732    W: CompoundWrite,
733    C: CompoundType,
734{
735    type Ok = ();
736    type Error = Error;
737
738    fn serialize_key<T>(&mut self, key: &T) -> Result<()>
739    where
740        T: ?Sized + Serialize,
741    {
742        self.add(key)
743    }
744
745    fn serialize_value<T>(&mut self, value: &T) -> Result<()>
746    where
747        T: ?Sized + Serialize,
748    {
749        self.add(value)
750    }
751
752    fn end(self) -> Result<()> {
753        self.end()
754    }
755}
756
757impl<W, C> ser::SerializeStruct for Compound<W, C>
758where
759    W: CompoundWrite,
760    C: CompoundType,
761{
762    type Ok = ();
763    type Error = Error;
764
765    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
766    where
767        T: ?Sized + Serialize,
768    {
769        self.add(key)?;
770        self.add(value)
771    }
772
773    fn end(self) -> Result<()> {
774        self.end()
775    }
776}
777
778impl<W, C> ser::SerializeStructVariant for Compound<W, C>
779where
780    W: CompoundWrite,
781    C: CompoundType,
782{
783    type Ok = ();
784    type Error = Error;
785
786    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
787    where
788        T: ?Sized + Serialize,
789    {
790        self.add(key)?;
791        self.add(value)
792    }
793
794    fn end(self) -> Result<()> {
795        self.end()
796    }
797}