1use std::{borrow::Cow, str::FromStr};
2
3use serde::de::{self, Unexpected};
4
5use crate::parser::{self, parse_array, parse_bytes, parse_err, parse_int_loose, parse_str_loose};
6
7use super::{Enum, Error, WithLen};
8
/// Deserializer that reads values directly out of a borrowed byte buffer.
pub struct Deserializer<'input> {
    /// The bytes that have not been consumed yet; each parse step replaces this
    /// slice with the remainder left after the value it read.
    pub input: &'input [u8],
}
13
14impl<'de> Deserializer<'de> {
15 fn parse_error(&mut self) -> Result<&'de str, Error<'de>> {
16 let (rem, str) = parse_err(self.input)?;
17 self.input = rem;
18
19 Ok(str)
20 }
21
22 fn parse_str(&mut self) -> Result<&'de str, Error<'de>> {
23 self.check_error()?;
24
25 let (rem, str) = parse_str_loose(self.input)?;
26 self.input = rem;
27
28 Ok(str)
29 }
30
31 fn parse_str_into<T>(&mut self) -> Result<T, Error<'de>>
32 where
33 T: FromStr,
34 <T as FromStr>::Err: std::fmt::Display,
35 {
36 self.parse_str()?
37 .parse()
38 .map_err::<Error, _>(de::Error::custom)
39 }
40
41 fn parse_int(&mut self) -> Result<i64, Error<'de>> {
42 self.check_error()?;
43
44 let (rem, int) = parse_int_loose(self.input)?;
45 self.input = rem;
46
47 Ok(int)
48 }
49
50 fn parse_int_into<T>(&mut self) -> Result<T, Error<'de>>
51 where
52 T: TryFrom<i64>,
53 <T as TryFrom<i64>>::Error: std::fmt::Display,
54 {
55 self.parse_int()?
56 .try_into()
57 .map_err::<Error, _>(de::Error::custom)
58 }
59
60 fn parse_bytes(&mut self) -> Result<Option<&'de [u8]>, Error<'de>> {
61 self.check_error()?;
62
63 let (rem, bytes) = parse_bytes(self.input)?;
64 self.input = rem;
65
66 Ok(bytes)
67 }
68
69 fn parse_array(&mut self) -> Result<i64, Error<'de>> {
70 self.check_error()?;
71
72 let (rem, len) = parse_array(self.input)?;
73 self.input = rem;
74
75 Ok(len)
76 }
77
78 fn parse_array_len(
79 &mut self,
80 exp: usize,
81 visitor: &impl de::Visitor<'de>,
82 ) -> Result<i64, Error<'de>> {
83 let len = self.parse_array()?;
84 let maybe_exp_signed: Result<i64, _> = exp.try_into();
85
86 match maybe_exp_signed {
87 Ok(exp_signed) if exp_signed == len => Ok(len),
88 _ => Err(de::Error::invalid_length(len as usize, visitor)),
89 }
90 }
91
92 fn check_error(&mut self) -> Result<(), Error<'de>> {
93 if self.input.get(0).copied() == Some(b'-') {
94 Err(Error::Redis(Cow::Borrowed(self.parse_error()?)))
95 } else {
96 Ok(())
97 }
98 }
99}
100
101impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
102 type Error = Error<'de>;
103
104 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
105 where
106 V: de::Visitor<'de>,
107 {
108 match self.input.get(0) {
109 Some(b'+') => self.deserialize_str(visitor),
110 Some(b'-') => Err(Error::Redis(Cow::Borrowed(self.parse_error()?))),
111 Some(b':') => self.deserialize_i64(visitor),
112 Some(b'$') => self.deserialize_bytes(visitor),
113 Some(b'*') => self.deserialize_seq(visitor),
114 Some(b) => Err(de::Error::invalid_value(
115 Unexpected::Unsigned(*b as u64),
116 &visitor,
117 )),
118 None => Err(Error::Parse(parser::Error::Incomplete(
119 nom::Needed::Unknown,
120 ))),
121 }
122 }
123
124 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
125 where
126 V: de::Visitor<'de>,
127 {
128 visitor.visit_bool(self.parse_str_into()?)
129 }
130
131 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
132 where
133 V: de::Visitor<'de>,
134 {
135 visitor.visit_i8(self.parse_int_into()?)
136 }
137
138 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
139 where
140 V: de::Visitor<'de>,
141 {
142 visitor.visit_i16(self.parse_int_into()?)
143 }
144
145 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
146 where
147 V: de::Visitor<'de>,
148 {
149 visitor.visit_i32(self.parse_int_into()?)
150 }
151
152 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
153 where
154 V: de::Visitor<'de>,
155 {
156 visitor.visit_i64(self.parse_int()?)
157 }
158
159 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
160 where
161 V: de::Visitor<'de>,
162 {
163 visitor.visit_u8(self.parse_int_into()?)
164 }
165
166 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
167 where
168 V: de::Visitor<'de>,
169 {
170 visitor.visit_u16(self.parse_int_into()?)
171 }
172
173 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
174 where
175 V: de::Visitor<'de>,
176 {
177 visitor.visit_u32(self.parse_int_into()?)
178 }
179
180 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
181 where
182 V: de::Visitor<'de>,
183 {
184 visitor.visit_u64(self.parse_int_into()?)
185 }
186
187 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
188 where
189 V: de::Visitor<'de>,
190 {
191 visitor.visit_f32(self.parse_str_into()?)
192 }
193
194 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
195 where
196 V: de::Visitor<'de>,
197 {
198 visitor.visit_f64(self.parse_str_into()?)
199 }
200
201 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
202 where
203 V: de::Visitor<'de>,
204 {
205 visitor.visit_char(self.parse_str_into()?)
206 }
207
208 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
209 where
210 V: de::Visitor<'de>,
211 {
212 visitor.visit_borrowed_str(self.parse_str()?)
213 }
214
215 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
216 where
217 V: de::Visitor<'de>,
218 {
219 self.deserialize_str(visitor)
220 }
221
222 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
223 where
224 V: de::Visitor<'de>,
225 {
226 match self.parse_bytes()? {
227 Some(d) => visitor.visit_borrowed_bytes(d),
228 None => visitor.visit_none(),
229 }
230 }
231
232 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
233 where
234 V: de::Visitor<'de>,
235 {
236 self.deserialize_bytes(visitor)
237 }
238
239 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
240 where
241 V: de::Visitor<'de>,
242 {
243 match self.input.get(0..5) {
244 Some(b"*-1\r\n") | Some(b"$-1\r\n") => {
245 self.input = &self.input[5..];
246 visitor.visit_none()
247 }
248 _ => visitor.visit_some(self),
249 }
250 }
251
252 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
253 where
254 V: de::Visitor<'de>,
255 {
256 self.check_error()?;
257 visitor.visit_none()
258 }
259
260 fn deserialize_unit_struct<V>(
261 self,
262 _name: &'static str,
263 visitor: V,
264 ) -> Result<V::Value, Self::Error>
265 where
266 V: de::Visitor<'de>,
267 {
268 self.deserialize_unit(visitor)
269 }
270
271 fn deserialize_newtype_struct<V>(
272 self,
273 _name: &'static str,
274 visitor: V,
275 ) -> Result<V::Value, Self::Error>
276 where
277 V: de::Visitor<'de>,
278 {
279 self.check_error()?;
280 visitor.visit_newtype_struct(self)
281 }
282
283 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
284 where
285 V: de::Visitor<'de>,
286 {
287 let len = self.parse_array()?;
288
289 if len < 0 {
290 visitor.visit_none()
291 } else {
292 visitor.visit_seq(WithLen {
293 de: self,
294 cur: 0,
295 len,
296 })
297 }
298 }
299
300 fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
301 where
302 V: de::Visitor<'de>,
303 {
304 let len = self.parse_array_len(len, &visitor)?;
305
306 visitor.visit_seq(WithLen {
307 de: self,
308 cur: 0,
309 len,
310 })
311 }
312
313 fn deserialize_tuple_struct<V>(
314 self,
315 _name: &'static str,
316 len: usize,
317 visitor: V,
318 ) -> Result<V::Value, Self::Error>
319 where
320 V: de::Visitor<'de>,
321 {
322 let len = self.parse_array_len(len, &visitor)?;
323
324 visitor.visit_seq(WithLen {
325 de: self,
326 cur: 0,
327 len,
328 })
329 }
330
331 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
332 where
333 V: de::Visitor<'de>,
334 {
335 let len = self.parse_array()?;
336
337 if len < 0 {
338 visitor.visit_none()
339 } else {
340 visitor.visit_map(WithLen {
341 de: self,
342 cur: 0,
343 len: len / 2,
344 })
345 }
346 }
347
348 fn deserialize_struct<V>(
349 self,
350 _name: &'static str,
351 _fields: &'static [&'static str],
352 visitor: V,
353 ) -> Result<V::Value, Self::Error>
354 where
355 V: de::Visitor<'de>,
356 {
357 self.deserialize_map(visitor)
358 }
359
360 fn deserialize_enum<V>(
361 self,
362 _name: &'static str,
363 _variants: &'static [&'static str],
364 visitor: V,
365 ) -> Result<V::Value, Self::Error>
366 where
367 V: de::Visitor<'de>,
368 {
369 visitor.visit_enum(Enum { de: self })
370 }
371
372 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
373 where
374 V: de::Visitor<'de>,
375 {
376 self.deserialize_str(visitor)
377 }
378
379 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
380 where
381 V: de::Visitor<'de>,
382 {
383 self.deserialize_any(visitor)
384 }
385}