Skip to main content

rustis/resp/
util.rs

1use serde::{
2    Deserializer, Serialize, Serializer,
3    de::{self, DeserializeOwned, DeserializeSeed, Visitor},
4};
5use smallvec::SmallVec;
6use std::{fmt, marker::PhantomData};
7
8/// Deserialize a Vec of pairs from a sequence
9pub fn deserialize_vec_of_pairs<'de, D, T1, T2>(
10    deserializer: D,
11) -> std::result::Result<Vec<(T1, T2)>, D::Error>
12where
13    D: Deserializer<'de>,
14    T1: DeserializeOwned,
15    T2: DeserializeOwned,
16{
17    struct VecOfPairsVisitor<T1, T2>
18    where
19        T1: DeserializeOwned,
20        T2: DeserializeOwned,
21    {
22        phantom: PhantomData<(T1, T2)>,
23    }
24
25    impl<'de, T1, T2> Visitor<'de> for VecOfPairsVisitor<T1, T2>
26    where
27        T1: DeserializeOwned,
28        T2: DeserializeOwned,
29    {
30        type Value = Vec<(T1, T2)>;
31
32        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
33            formatter.write_str("Vec<(T1, T2)>")
34        }
35
36        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
37        where
38            A: serde::de::SeqAccess<'de>,
39        {
40            let mut v = if let Some(size) = seq.size_hint() {
41                Vec::with_capacity(size / 2)
42            } else {
43                Vec::new()
44            };
45
46            while let Some(first) = seq.next_element()? {
47                let Some(second) = seq.next_element()? else {
48                    return Err(de::Error::custom("invalid length"));
49                };
50
51                v.push((first, second));
52            }
53
54            Ok(v)
55        }
56    }
57
58    deserializer.deserialize_seq(VecOfPairsVisitor {
59        phantom: PhantomData,
60    })
61}
62
63/// Deserialize a Vec of triplets from a sequence
64pub fn deserialize_vec_of_triplets<'de, D, T1, T2, T3>(
65    deserializer: D,
66) -> std::result::Result<Vec<(T1, T2, T3)>, D::Error>
67where
68    D: Deserializer<'de>,
69    T1: DeserializeOwned,
70    T2: DeserializeOwned,
71    T3: DeserializeOwned,
72{
73    struct VecOfTripletVisitor<T1, T2, T3>
74    where
75        T1: DeserializeOwned,
76        T2: DeserializeOwned,
77        T3: DeserializeOwned,
78    {
79        phantom: PhantomData<(T1, T2, T3)>,
80    }
81
82    impl<'de, T1, T2, T3> Visitor<'de> for VecOfTripletVisitor<T1, T2, T3>
83    where
84        T1: DeserializeOwned,
85        T2: DeserializeOwned,
86        T3: DeserializeOwned,
87    {
88        type Value = Vec<(T1, T2, T3)>;
89
90        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
91            formatter.write_str("Vec<(T1, T2, T3)>")
92        }
93
94        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
95        where
96            A: serde::de::SeqAccess<'de>,
97        {
98            let mut v = if let Some(size) = seq.size_hint() {
99                Vec::with_capacity(size / 3)
100            } else {
101                Vec::new()
102            };
103
104            while let Some(first) = seq.next_element()? {
105                let Some(second) = seq.next_element()? else {
106                    return Err(de::Error::custom("invalid length"));
107                };
108
109                let Some(third) = seq.next_element()? else {
110                    return Err(de::Error::custom("invalid length"));
111                };
112
113                v.push((first, second, third));
114            }
115
116            Ok(v)
117        }
118    }
119
120    deserializer.deserialize_seq(VecOfTripletVisitor {
121        phantom: PhantomData,
122    })
123}
124
125/// Deserialize a byte buffer (Vec\<u8\>)
126pub fn deserialize_byte_buf<'de, D>(deserializer: D) -> std::result::Result<Vec<u8>, D::Error>
127where
128    D: Deserializer<'de>,
129{
130    struct ByteBufVisitor;
131
132    impl Visitor<'_> for ByteBufVisitor {
133        type Value = Vec<u8>;
134
135        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
136            formatter.write_str("Vec<u8>")
137        }
138
139        fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
140        where
141            E: serde::de::Error,
142        {
143            Ok(v)
144        }
145    }
146
147    deserializer.deserialize_byte_buf(ByteBufVisitor)
148}
149
150/// Serialize a byte buffer (&\[u8\])
151pub fn serialize_byte_buf<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
152where
153    S: Serializer,
154{
155    serializer.serialize_bytes(bytes)
156}
157
158/// Serialize a byte buffer (&\[u8\]) option
159pub fn serialize_byte_buf_option<S>(bytes: &Option<&[u8]>, serializer: S) -> Result<S::Ok, S::Error>
160where
161    S: Serializer,
162{
163    if let Some(bytes) = bytes {
164        serializer.serialize_bytes(bytes)
165    } else {
166        serializer.serialize_none()
167    }
168}
169
170pub(crate) struct ByteBufSeed;
171
172impl<'de> DeserializeSeed<'de> for ByteBufSeed {
173    type Value = Vec<u8>;
174
175    fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
176    where
177        D: Deserializer<'de>,
178    {
179        deserialize_byte_buf(deserializer)
180    }
181}
182
183/// Deserialize a byte slice (&\[u8\])
184pub fn deserialize_bytes<'de, D>(deserializer: D) -> std::result::Result<&'de [u8], D::Error>
185where
186    D: Deserializer<'de>,
187{
188    struct ByteBufVisitor;
189
190    impl<'de> Visitor<'de> for ByteBufVisitor {
191        type Value = &'de [u8];
192
193        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
194            formatter.write_str("&'de [u8]")
195        }
196
197        fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
198        where
199            E: de::Error,
200        {
201            Ok(v)
202        }
203    }
204
205    deserializer.deserialize_bytes(ByteBufVisitor)
206}
207
208#[derive(Default)]
209pub(crate) struct VecOfPairsSeed<T1, T2>
210where
211    T1: DeserializeOwned,
212    T2: DeserializeOwned,
213{
214    phatom: PhantomData<(T1, T2)>,
215}
216
217impl<T1, T2> VecOfPairsSeed<T1, T2>
218where
219    T1: DeserializeOwned,
220    T2: DeserializeOwned,
221{
222    #[allow(dead_code)]
223    pub fn new() -> Self {
224        Self {
225            phatom: PhantomData,
226        }
227    }
228}
229
230impl<'de, T1, T2> DeserializeSeed<'de> for VecOfPairsSeed<T1, T2>
231where
232    T1: DeserializeOwned,
233    T2: DeserializeOwned,
234{
235    type Value = Vec<(T1, T2)>;
236
237    #[inline]
238    fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
239    where
240        D: Deserializer<'de>,
241    {
242        deserialize_vec_of_pairs(deserializer)
243    }
244}
245
246/// Serialize field name only and skip the boolean value
247pub(crate) fn serialize_flag<S: serde::Serializer>(
248    _: &bool,
249    serializer: S,
250) -> std::result::Result<S::Ok, S::Error> {
251    serializer.serialize_unit()
252}
253
254/// Serializes a slice prefixed by its length.
255/// Use with #[serde(serialize_with = "serialize_slice_with_len")]
256pub(crate) fn serialize_slice_with_len<S, T>(slice: &[T], serializer: S) -> Result<S::Ok, S::Error>
257where
258    S: Serializer,
259    T: Serialize,
260{
261    // Astuce : Le tuple (usize, &[T]) est sérialisé séquentiellement
262    (slice.len(), slice).serialize(serializer)
263}
264
265pub struct SmallVecWithCounter<T, const N: usize>(usize, SmallVec<[T; N]>);
266
267impl<T, const N: usize> SmallVecWithCounter<T, N> {
268    pub fn push(&mut self, value: T) {
269        self.0 += 1;
270        self.1.push(value);
271    }
272}