1use redis::{RedisWrite, ToRedisArgs};
2use rmp::encode;
3use serde::{de, ser, Serialize};
4use std::{
5 fmt::{self, Display},
6 io::{self, Write},
7};
8
9trait RedisArgWrite: RedisWrite {
10 fn pack(&mut self);
11}
12
13#[doc(hidden)]
14pub struct ScriptArg {
15 buf: Vec<u8>,
16 pack: bool,
17}
18
19impl ScriptArg {
20 fn new() -> Self {
21 Self {
22 buf: Vec::with_capacity(128),
23 pack: false,
24 }
25 }
26
27 pub fn pack(&self) -> bool {
28 self.pack
29 }
30}
31
32impl RedisWrite for ScriptArg {
33 fn write_arg(&mut self, arg: &[u8]) {
34 self.buf.extend(arg);
35 }
36}
37
38impl RedisArgWrite for ScriptArg {
39 fn pack(&mut self) {
40 self.pack = true;
41 }
42}
43
44pub fn script_arg<T: Serialize + ?Sized>(value: &T) -> ScriptArg {
45 let mut arg = ScriptArg::new();
46 let mut ser = Serializer::new(&mut arg);
47 value.serialize(&mut ser).expect("Couldn't serialize");
48 arg
49}
50
51impl ToRedisArgs for ScriptArg {
52 fn write_redis_args<W: ?Sized>(&self, out: &mut W)
53 where
54 W: RedisWrite,
55 {
56 self.buf.write_redis_args(out);
57 }
58}
59
60type Result<T> = std::result::Result<T, Error>;
61
62#[derive(Debug)]
63struct Error(String);
64
65impl ser::Error for Error {
66 fn custom<T: Display>(msg: T) -> Self {
67 Error(msg.to_string())
68 }
69}
70
71impl de::Error for Error {
72 fn custom<T: Display>(msg: T) -> Self {
73 Error(msg.to_string())
74 }
75}
76
77impl Display for Error {
78 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
79 write!(f, "{}", self.0)
80 }
81}
82
83impl std::error::Error for Error {}
84
85impl From<rmp::encode::Error> for Error {
86 fn from(e: rmp::encode::Error) -> Self {
87 Self(e.to_string())
88 }
89}
90
91impl From<rmp::encode::ValueWriteError> for Error {
92 fn from(e: rmp::encode::ValueWriteError) -> Self {
93 Self(e.to_string())
94 }
95}
96
97struct Serializer<'a, W: ?Sized>(&'a mut W);
98
99impl<'a, W> Serializer<'a, W>
100where
101 W: RedisArgWrite + ?Sized,
102{
103 fn new(w: &'a mut W) -> Self {
104 Self(w)
105 }
106
107 fn write_null(&mut self) {
108 Vec::<u8>::new().write_redis_args(self.0);
109 }
110}
111
112impl<'a, 'b, W> ser::Serializer for &'a mut Serializer<'b, W>
113where
114 W: RedisArgWrite + ?Sized,
115{
116 type Ok = ();
117 type Error = Error;
118
119 type SerializeSeq = Compound<Arg<'a, W>, Seq>;
120 type SerializeTuple = Compound<Arg<'a, W>, Seq>;
121 type SerializeTupleStruct = Compound<Arg<'a, W>, Seq>;
122 type SerializeTupleVariant = Compound<Arg<'a, W>, Seq>;
123 type SerializeMap = Compound<Arg<'a, W>, Map>;
124 type SerializeStruct = Compound<Arg<'a, W>, Map>;
125 type SerializeStructVariant = Compound<Arg<'a, W>, Map>;
126
127 fn serialize_bool(self, v: bool) -> Result<()> {
128 Ok((v as usize).write_redis_args(self.0))
129 }
130
131 fn serialize_i8(self, v: i8) -> Result<()> {
132 Ok(v.write_redis_args(self.0))
133 }
134
135 fn serialize_i16(self, v: i16) -> Result<()> {
136 Ok(v.write_redis_args(self.0))
137 }
138
139 fn serialize_i32(self, v: i32) -> Result<()> {
140 Ok(v.write_redis_args(self.0))
141 }
142
143 fn serialize_i64(self, v: i64) -> Result<()> {
144 Ok(v.write_redis_args(self.0))
145 }
146
147 fn serialize_u8(self, v: u8) -> Result<()> {
148 Ok(v.write_redis_args(self.0))
149 }
150
151 fn serialize_u16(self, v: u16) -> Result<()> {
152 Ok(v.write_redis_args(self.0))
153 }
154
155 fn serialize_u32(self, v: u32) -> Result<()> {
156 Ok(v.write_redis_args(self.0))
157 }
158
159 fn serialize_u64(self, v: u64) -> Result<()> {
160 Ok(v.write_redis_args(self.0))
161 }
162
163 fn serialize_f32(self, v: f32) -> Result<()> {
164 Ok(v.write_redis_args(self.0))
165 }
166
167 fn serialize_f64(self, v: f64) -> Result<()> {
168 Ok(v.write_redis_args(self.0))
169 }
170
171 fn serialize_char(self, v: char) -> Result<()> {
172 let mut buf = [0; 4];
173 let len = v.encode_utf8(&mut buf).len();
174 Ok((&buf[..len]).write_redis_args(self.0))
175 }
176
177 fn serialize_str(self, v: &str) -> Result<()> {
178 Ok(v.write_redis_args(self.0))
179 }
180
181 fn serialize_bytes(self, v: &[u8]) -> Result<()> {
182 Ok(v.write_redis_args(self.0))
183 }
184
185 fn serialize_none(self) -> Result<()> {
186 Ok(self.write_null())
187 }
188
189 fn serialize_some<T>(self, value: &T) -> Result<()>
190 where
191 T: ?Sized + Serialize,
192 {
193 value.serialize(self)
194 }
195
196 fn serialize_unit(self) -> Result<()> {
197 Ok(self.write_null())
198 }
199
200 fn serialize_unit_struct(self, _name: &'static str) -> Result<()> {
201 Ok(self.write_null())
202 }
203
204 fn serialize_unit_variant(
205 self,
206 _name: &'static str,
207 _variant_index: u32,
208 variant: &'static str,
209 ) -> Result<()> {
210 variant.serialize(self)
211 }
212
213 fn serialize_newtype_struct<T>(self, _name: &'static str, value: &T) -> Result<()>
214 where
215 T: ?Sized + Serialize,
216 {
217 value.serialize(self)
218 }
219
220 fn serialize_newtype_variant<T>(
221 self,
222 _name: &'static str,
223 _variant_index: u32,
224 _variant: &'static str,
225 value: &T,
226 ) -> Result<()>
227 where
228 T: ?Sized + Serialize,
229 {
230 value.serialize(self)
231 }
232
233 fn serialize_seq(self, _: Option<usize>) -> Result<Self::SerializeSeq> {
234 Ok(Compound::new(Arg(self.0), Seq::new()))
235 }
236
237 fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple> {
238 self.serialize_seq(Some(len))
239 }
240
241 fn serialize_tuple_struct(
242 self,
243 _name: &'static str,
244 len: usize,
245 ) -> Result<Self::SerializeTupleStruct> {
246 self.serialize_seq(Some(len))
247 }
248
249 fn serialize_tuple_variant(
250 self,
251 _name: &'static str,
252 _variant_index: u32,
253 _variant: &'static str,
254 len: usize,
255 ) -> Result<Self::SerializeTupleVariant> {
256 self.serialize_seq(Some(len))
257 }
258
259 fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
260 Ok(Compound::new(Arg(self.0), Map::new()))
261 }
262
263 fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
264 self.serialize_map(Some(len))
265 }
266
267 fn serialize_struct_variant(
268 self,
269 _name: &'static str,
270 _variant_index: u32,
271 _variant: &'static str,
272 len: usize,
273 ) -> Result<Self::SerializeStructVariant> {
274 self.serialize_map(Some(len))
275 }
276}
277
278struct ComplexSerializer<W>(W, Option<u8>);
280
281impl<W> ComplexSerializer<W>
282where
283 W: io::Write,
284{
285 fn new(w: W) -> Self {
286 Self(w, None)
287 }
288
289 fn write_str(&mut self, s: &[u8]) -> Result<()> {
291 encode::write_str_len(&mut self.0, s.len() as u32)?;
292 Ok(self.0.write_all(s)?)
293 }
294
295 fn single_byte(&self) -> Option<u8> {
296 self.1
297 }
298}
299
300impl<'a, W> ser::Serializer for &'a mut ComplexSerializer<W>
301where
302 W: io::Write,
303{
304 type Ok = ();
305 type Error = Error;
306
307 type SerializeSeq = Compound<Buf<'a, W>, Seq>;
308 type SerializeTuple = Compound<Buf<'a, W>, Seq>;
309 type SerializeTupleStruct = Compound<Buf<'a, W>, Seq>;
310 type SerializeTupleVariant = Compound<Buf<'a, W>, Seq>;
311 type SerializeMap = Compound<Buf<'a, W>, Map>;
312 type SerializeStruct = Compound<Buf<'a, W>, Map>;
313 type SerializeStructVariant = Compound<Buf<'a, W>, Map>;
314
315 fn serialize_bool(self, v: bool) -> Result<()> {
316 Ok(encode::write_bool(&mut self.0, v)?)
317 }
318
319 fn serialize_i8(self, v: i8) -> Result<()> {
320 self.serialize_i64(v as i64)
321 }
322
323 fn serialize_i16(self, v: i16) -> Result<()> {
324 self.serialize_i64(v as i64)
325 }
326
327 fn serialize_i32(self, v: i32) -> Result<()> {
328 self.serialize_i64(v as i64)
329 }
330
331 fn serialize_i64(self, v: i64) -> Result<()> {
332 encode::write_sint(&mut self.0, v)?;
333 Ok(())
334 }
335
336 fn serialize_u8(self, v: u8) -> Result<()> {
337 self.1 = Some(v);
338 encode::write_uint(&mut self.0, v as u64)?;
339 Ok(())
340 }
341
342 fn serialize_u16(self, v: u16) -> Result<()> {
343 self.serialize_u64(v as u64)
344 }
345
346 fn serialize_u32(self, v: u32) -> Result<()> {
347 self.serialize_u64(v as u64)
348 }
349
350 fn serialize_u64(self, v: u64) -> Result<()> {
351 encode::write_uint(&mut self.0, v)?;
352 Ok(())
353 }
354
355 fn serialize_f32(self, v: f32) -> Result<()> {
356 encode::write_f32(&mut self.0, v)?;
357 Ok(())
358 }
359
360 fn serialize_f64(self, v: f64) -> Result<()> {
361 encode::write_f64(&mut self.0, v)?;
362 Ok(())
363 }
364
365 fn serialize_char(self, v: char) -> Result<()> {
366 let mut buf = [0; 4];
367 self.serialize_str(v.encode_utf8(&mut buf))
368 }
369
370 fn serialize_str(self, v: &str) -> Result<()> {
371 self.write_str(v.as_bytes())
372 }
373
374 fn serialize_bytes(self, v: &[u8]) -> Result<()> {
375 self.write_str(v)
376 }
377
378 fn serialize_none(self) -> Result<()> {
379 self.serialize_unit()
380 }
381
382 fn serialize_some<T>(self, value: &T) -> Result<()>
383 where
384 T: ?Sized + Serialize,
385 {
386 value.serialize(self)
387 }
388
389 fn serialize_unit(self) -> Result<()> {
390 encode::write_nil(&mut self.0)?;
391 Ok(())
392 }
393
394 fn serialize_unit_struct(self, _name: &'static str) -> Result<()> {
395 encode::write_array_len(&mut self.0, 0)?;
396 Ok(())
397 }
398
399 fn serialize_unit_variant(
400 self,
401 _name: &'static str,
402 variant_index: u32,
403 _variant: &'static str,
404 ) -> Result<()> {
405 variant_index.serialize(self)
406 }
407
408 fn serialize_newtype_struct<T>(self, _name: &'static str, value: &T) -> Result<()>
409 where
410 T: ?Sized + Serialize,
411 {
412 value.serialize(self)
413 }
414
415 fn serialize_newtype_variant<T>(
416 self,
417 _name: &'static str,
418 _variant_index: u32,
419 _variant: &'static str,
420 value: &T,
421 ) -> Result<()>
422 where
423 T: ?Sized + Serialize,
424 {
425 value.serialize(self)
426 }
427
428 fn serialize_seq(self, _: Option<usize>) -> Result<Self::SerializeSeq> {
429 Ok(Compound::new(Buf(&mut self.0), Seq::new()))
430 }
431
432 fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple> {
433 self.serialize_seq(Some(len))
434 }
435
436 fn serialize_tuple_struct(
437 self,
438 _name: &'static str,
439 len: usize,
440 ) -> Result<Self::SerializeTupleStruct> {
441 self.serialize_seq(Some(len))
442 }
443
444 fn serialize_tuple_variant(
445 self,
446 _name: &'static str,
447 _variant_index: u32,
448 _variant: &'static str,
449 len: usize,
450 ) -> Result<Self::SerializeTupleVariant> {
451 self.serialize_seq(Some(len))
452 }
453
454 fn serialize_map(self, _: Option<usize>) -> Result<Self::SerializeMap> {
455 Ok(Compound::new(Buf(&mut self.0), Map::new()))
456 }
457
458 fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
459 self.serialize_map(Some(len))
460 }
461
462 fn serialize_struct_variant(
463 self,
464 _name: &'static str,
465 _variant_index: u32,
466 _variant: &'static str,
467 len: usize,
468 ) -> Result<Self::SerializeStructVariant> {
469 self.serialize_map(Some(len))
470 }
471}
472
473trait CompoundWrite {
474 fn write_raw(&mut self, buf: &[u8]) -> Result<()>;
475
476 fn write_packed(&mut self, buf: &[u8]) -> Result<()>;
477
478 fn is_arg(&self) -> bool;
479}
480
481struct Arg<'a, W: ?Sized>(&'a mut W);
482
483struct Buf<'a, W>(&'a mut W);
484
485impl<'a, W> CompoundWrite for Arg<'a, W>
486where
487 W: RedisArgWrite + ?Sized,
488{
489 fn write_raw(&mut self, buf: &[u8]) -> Result<()> {
490 buf.write_redis_args(self.0);
491 Ok(())
492 }
493
494 fn write_packed(&mut self, buf: &[u8]) -> Result<()> {
495 self.0.pack();
496 buf.write_redis_args(self.0);
497 Ok(())
498 }
499
500 fn is_arg(&self) -> bool {
501 true
502 }
503}
504
505impl<'a, W> CompoundWrite for Buf<'a, W>
506where
507 W: io::Write,
508{
509 fn write_raw(&mut self, buf: &[u8]) -> Result<()> {
510 self.0.write_all(buf)?;
511 Ok(())
512 }
513
514 fn write_packed(&mut self, buf: &[u8]) -> Result<()> {
515 self.write_raw(buf)
516 }
517
518 fn is_arg(&self) -> bool {
519 false
520 }
521}
522
/// Describes the kind of compound value being serialized.
trait CompoundType {
    /// `true` for key/value compounds (MessagePack map), `false` for
    /// positional ones (array).
    fn is_map() -> bool;

    /// Records the element just added. `byte` is `Some(v)` only when the
    /// element was a lone `u8`.
    fn add_byte(&mut self, _byte: Option<u8>) {}

    /// The raw bytes when every element added so far was a lone `u8`.
    fn bytearray(&self) -> Option<&[u8]> {
        None
    }
}

/// Marker for maps, structs and struct variants.
struct Map;

impl Map {
    fn new() -> Self {
        Self
    }
}

/// State for sequences, tuples and tuple structs/variants.
struct Seq {
    // `Some` while every element seen so far was a lone `u8`. The first other
    // element drops it for good. That frees the buffer and stops further
    // copying, where a separate bool flag would keep accumulating bytes.
    bytes: Option<Vec<u8>>,
}

impl Seq {
    fn new() -> Self {
        Self {
            bytes: Some(Vec::new()),
        }
    }
}

impl CompoundType for Map {
    fn is_map() -> bool {
        true
    }
}

impl CompoundType for Seq {
    fn is_map() -> bool {
        false
    }

    fn add_byte(&mut self, byte: Option<u8>) {
        match byte {
            Some(b) => {
                if let Some(bytes) = self.bytes.as_mut() {
                    bytes.push(b);
                }
            }
            None => self.bytes = None,
        }
    }

    fn bytearray(&self) -> Option<&[u8]> {
        self.bytes.as_deref()
    }
}
581
582struct Compound<W, C> {
583 buf: Vec<u8>,
584 len: usize,
585 wr: W,
586 inner: C,
587}
588
589impl<W, C> Compound<W, C>
590where
591 W: CompoundWrite,
592 C: CompoundType,
593{
594 fn new(wr: W, inner: C) -> Self {
595 Self {
596 buf: vec![],
597 len: 0,
598 wr,
599 inner,
600 }
601 }
602
603 fn add<T: ?Sized>(&mut self, value: &T) -> Result<()>
604 where
605 T: Serialize,
606 {
607 self.len += 1;
608
609 self.inner.add_byte({
610 let mut ser = ComplexSerializer::new(&mut self.buf);
611 value.serialize(&mut ser)?;
612 ser.single_byte()
613 });
614
615 Ok(())
616 }
617
618 fn end(mut self) -> Result<()> {
619 let mut v = Vec::new();
620
621 if C::is_map() {
622 encode::write_map_len(&mut v, self.len as u32 / 2)?;
625 v.write_all(&self.buf)?;
626 self.wr.write_packed(&v)?;
627 } else {
628 if let Some(bytearray) = self.inner.bytearray() {
629 if self.len > 0 {
630 if self.wr.is_arg() {
632 self.wr.write_raw(bytearray)?;
633 } else {
634 encode::write_str_len(&mut v, self.len as u32)?;
635 v.write_all(bytearray)?;
636 self.wr.write_packed(&v)?;
637 }
638 return Ok(());
639 }
640 }
641 encode::write_array_len(&mut v, self.len as u32)?;
642 v.write_all(&self.buf)?;
643 self.wr.write_packed(&v)?;
644 }
645
646 Ok(())
647 }
648}
649
650impl<W, C> ser::SerializeSeq for Compound<W, C>
651where
652 W: CompoundWrite,
653 C: CompoundType,
654{
655 type Ok = ();
656 type Error = Error;
657
658 fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<()>
659 where
660 T: Serialize,
661 {
662 self.add(value)
663 }
664
665 fn end(self) -> Result<Self::Ok> {
666 self.end()
667 }
668}
669
670impl<W, C> ser::SerializeTuple for Compound<W, C>
671where
672 W: CompoundWrite,
673 C: CompoundType,
674{
675 type Ok = ();
676 type Error = Error;
677
678 fn serialize_element<T>(&mut self, value: &T) -> Result<()>
679 where
680 T: ?Sized + Serialize,
681 {
682 self.add(value)
683 }
684
685 fn end(self) -> Result<()> {
686 self.end()
687 }
688}
689
690impl<W, C> ser::SerializeTupleStruct for Compound<W, C>
691where
692 W: CompoundWrite,
693 C: CompoundType,
694{
695 type Ok = ();
696 type Error = Error;
697
698 fn serialize_field<T>(&mut self, value: &T) -> Result<()>
699 where
700 T: ?Sized + Serialize,
701 {
702 self.add(value)
703 }
704
705 fn end(self) -> Result<()> {
706 self.end()
707 }
708}
709
710impl<W, C> ser::SerializeTupleVariant for Compound<W, C>
711where
712 W: CompoundWrite,
713 C: CompoundType,
714{
715 type Ok = ();
716 type Error = Error;
717
718 fn serialize_field<T>(&mut self, value: &T) -> Result<()>
719 where
720 T: ?Sized + Serialize,
721 {
722 self.add(value)
723 }
724
725 fn end(self) -> Result<()> {
726 self.end()
727 }
728}
729
730impl<W, C> ser::SerializeMap for Compound<W, C>
731where
732 W: CompoundWrite,
733 C: CompoundType,
734{
735 type Ok = ();
736 type Error = Error;
737
738 fn serialize_key<T>(&mut self, key: &T) -> Result<()>
739 where
740 T: ?Sized + Serialize,
741 {
742 self.add(key)
743 }
744
745 fn serialize_value<T>(&mut self, value: &T) -> Result<()>
746 where
747 T: ?Sized + Serialize,
748 {
749 self.add(value)
750 }
751
752 fn end(self) -> Result<()> {
753 self.end()
754 }
755}
756
757impl<W, C> ser::SerializeStruct for Compound<W, C>
758where
759 W: CompoundWrite,
760 C: CompoundType,
761{
762 type Ok = ();
763 type Error = Error;
764
765 fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
766 where
767 T: ?Sized + Serialize,
768 {
769 self.add(key)?;
770 self.add(value)
771 }
772
773 fn end(self) -> Result<()> {
774 self.end()
775 }
776}
777
778impl<W, C> ser::SerializeStructVariant for Compound<W, C>
779where
780 W: CompoundWrite,
781 C: CompoundType,
782{
783 type Ok = ();
784 type Error = Error;
785
786 fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
787 where
788 T: ?Sized + Serialize,
789 {
790 self.add(key)?;
791 self.add(value)
792 }
793
794 fn end(self) -> Result<()> {
795 self.end()
796 }
797}