redust_resp/de/
deserializer.rs

1use std::{borrow::Cow, str::FromStr};
2
3use serde::de::{self, Unexpected};
4
5use crate::parser::{self, parse_array, parse_bytes, parse_err, parse_int_loose, parse_str_loose};
6
7use super::{Enum, Error, WithLen};
8
/// Deserializer that decodes RESP values straight out of a borrowed byte buffer.
///
/// Decoded strings and byte slices borrow from `input` where possible; each
/// successful read moves `input` forward past the bytes it consumed.
pub struct Deserializer<'de> {
	/// The bytes that have not been consumed yet.
	pub input: &'de [u8],
}
13
14impl<'de> Deserializer<'de> {
15	fn parse_error(&mut self) -> Result<&'de str, Error<'de>> {
16		let (rem, str) = parse_err(self.input)?;
17		self.input = rem;
18
19		Ok(str)
20	}
21
22	fn parse_str(&mut self) -> Result<&'de str, Error<'de>> {
23		self.check_error()?;
24
25		let (rem, str) = parse_str_loose(self.input)?;
26		self.input = rem;
27
28		Ok(str)
29	}
30
31	fn parse_str_into<T>(&mut self) -> Result<T, Error<'de>>
32	where
33		T: FromStr,
34		<T as FromStr>::Err: std::fmt::Display,
35	{
36		self.parse_str()?
37			.parse()
38			.map_err::<Error, _>(de::Error::custom)
39	}
40
41	fn parse_int(&mut self) -> Result<i64, Error<'de>> {
42		self.check_error()?;
43
44		let (rem, int) = parse_int_loose(self.input)?;
45		self.input = rem;
46
47		Ok(int)
48	}
49
50	fn parse_int_into<T>(&mut self) -> Result<T, Error<'de>>
51	where
52		T: TryFrom<i64>,
53		<T as TryFrom<i64>>::Error: std::fmt::Display,
54	{
55		self.parse_int()?
56			.try_into()
57			.map_err::<Error, _>(de::Error::custom)
58	}
59
60	fn parse_bytes(&mut self) -> Result<Option<&'de [u8]>, Error<'de>> {
61		self.check_error()?;
62
63		let (rem, bytes) = parse_bytes(self.input)?;
64		self.input = rem;
65
66		Ok(bytes)
67	}
68
69	fn parse_array(&mut self) -> Result<i64, Error<'de>> {
70		self.check_error()?;
71
72		let (rem, len) = parse_array(self.input)?;
73		self.input = rem;
74
75		Ok(len)
76	}
77
78	fn parse_array_len(
79		&mut self,
80		exp: usize,
81		visitor: &impl de::Visitor<'de>,
82	) -> Result<i64, Error<'de>> {
83		let len = self.parse_array()?;
84		let maybe_exp_signed: Result<i64, _> = exp.try_into();
85
86		match maybe_exp_signed {
87			Ok(exp_signed) if exp_signed == len => Ok(len),
88			_ => Err(de::Error::invalid_length(len as usize, visitor)),
89		}
90	}
91
92	fn check_error(&mut self) -> Result<(), Error<'de>> {
93		if self.input.get(0).copied() == Some(b'-') {
94			Err(Error::Redis(Cow::Borrowed(self.parse_error()?)))
95		} else {
96			Ok(())
97		}
98	}
99}
100
101impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
102	type Error = Error<'de>;
103
104	fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
105	where
106		V: de::Visitor<'de>,
107	{
108		match self.input.get(0) {
109			Some(b'+') => self.deserialize_str(visitor),
110			Some(b'-') => Err(Error::Redis(Cow::Borrowed(self.parse_error()?))),
111			Some(b':') => self.deserialize_i64(visitor),
112			Some(b'$') => self.deserialize_bytes(visitor),
113			Some(b'*') => self.deserialize_seq(visitor),
114			Some(b) => Err(de::Error::invalid_value(
115				Unexpected::Unsigned(*b as u64),
116				&visitor,
117			)),
118			None => Err(Error::Parse(parser::Error::Incomplete(
119				nom::Needed::Unknown,
120			))),
121		}
122	}
123
124	fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
125	where
126		V: de::Visitor<'de>,
127	{
128		visitor.visit_bool(self.parse_str_into()?)
129	}
130
131	fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
132	where
133		V: de::Visitor<'de>,
134	{
135		visitor.visit_i8(self.parse_int_into()?)
136	}
137
138	fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
139	where
140		V: de::Visitor<'de>,
141	{
142		visitor.visit_i16(self.parse_int_into()?)
143	}
144
145	fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
146	where
147		V: de::Visitor<'de>,
148	{
149		visitor.visit_i32(self.parse_int_into()?)
150	}
151
152	fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
153	where
154		V: de::Visitor<'de>,
155	{
156		visitor.visit_i64(self.parse_int()?)
157	}
158
159	fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
160	where
161		V: de::Visitor<'de>,
162	{
163		visitor.visit_u8(self.parse_int_into()?)
164	}
165
166	fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
167	where
168		V: de::Visitor<'de>,
169	{
170		visitor.visit_u16(self.parse_int_into()?)
171	}
172
173	fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
174	where
175		V: de::Visitor<'de>,
176	{
177		visitor.visit_u32(self.parse_int_into()?)
178	}
179
180	fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
181	where
182		V: de::Visitor<'de>,
183	{
184		visitor.visit_u64(self.parse_int_into()?)
185	}
186
187	fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
188	where
189		V: de::Visitor<'de>,
190	{
191		visitor.visit_f32(self.parse_str_into()?)
192	}
193
194	fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
195	where
196		V: de::Visitor<'de>,
197	{
198		visitor.visit_f64(self.parse_str_into()?)
199	}
200
201	fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
202	where
203		V: de::Visitor<'de>,
204	{
205		visitor.visit_char(self.parse_str_into()?)
206	}
207
208	fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
209	where
210		V: de::Visitor<'de>,
211	{
212		visitor.visit_borrowed_str(self.parse_str()?)
213	}
214
215	fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
216	where
217		V: de::Visitor<'de>,
218	{
219		self.deserialize_str(visitor)
220	}
221
222	fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
223	where
224		V: de::Visitor<'de>,
225	{
226		match self.parse_bytes()? {
227			Some(d) => visitor.visit_borrowed_bytes(d),
228			None => visitor.visit_none(),
229		}
230	}
231
232	fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
233	where
234		V: de::Visitor<'de>,
235	{
236		self.deserialize_bytes(visitor)
237	}
238
239	fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
240	where
241		V: de::Visitor<'de>,
242	{
243		match self.input.get(0..5) {
244			Some(b"*-1\r\n") | Some(b"$-1\r\n") => {
245				self.input = &self.input[5..];
246				visitor.visit_none()
247			}
248			_ => visitor.visit_some(self),
249		}
250	}
251
252	fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
253	where
254		V: de::Visitor<'de>,
255	{
256		self.check_error()?;
257		visitor.visit_none()
258	}
259
260	fn deserialize_unit_struct<V>(
261		self,
262		_name: &'static str,
263		visitor: V,
264	) -> Result<V::Value, Self::Error>
265	where
266		V: de::Visitor<'de>,
267	{
268		self.deserialize_unit(visitor)
269	}
270
271	fn deserialize_newtype_struct<V>(
272		self,
273		_name: &'static str,
274		visitor: V,
275	) -> Result<V::Value, Self::Error>
276	where
277		V: de::Visitor<'de>,
278	{
279		self.check_error()?;
280		visitor.visit_newtype_struct(self)
281	}
282
283	fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
284	where
285		V: de::Visitor<'de>,
286	{
287		let len = self.parse_array()?;
288
289		if len < 0 {
290			visitor.visit_none()
291		} else {
292			visitor.visit_seq(WithLen {
293				de: self,
294				cur: 0,
295				len,
296			})
297		}
298	}
299
300	fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
301	where
302		V: de::Visitor<'de>,
303	{
304		let len = self.parse_array_len(len, &visitor)?;
305
306		visitor.visit_seq(WithLen {
307			de: self,
308			cur: 0,
309			len,
310		})
311	}
312
313	fn deserialize_tuple_struct<V>(
314		self,
315		_name: &'static str,
316		len: usize,
317		visitor: V,
318	) -> Result<V::Value, Self::Error>
319	where
320		V: de::Visitor<'de>,
321	{
322		let len = self.parse_array_len(len, &visitor)?;
323
324		visitor.visit_seq(WithLen {
325			de: self,
326			cur: 0,
327			len,
328		})
329	}
330
331	fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
332	where
333		V: de::Visitor<'de>,
334	{
335		let len = self.parse_array()?;
336
337		if len < 0 {
338			visitor.visit_none()
339		} else {
340			visitor.visit_map(WithLen {
341				de: self,
342				cur: 0,
343				len: len / 2,
344			})
345		}
346	}
347
348	fn deserialize_struct<V>(
349		self,
350		_name: &'static str,
351		_fields: &'static [&'static str],
352		visitor: V,
353	) -> Result<V::Value, Self::Error>
354	where
355		V: de::Visitor<'de>,
356	{
357		self.deserialize_map(visitor)
358	}
359
360	fn deserialize_enum<V>(
361		self,
362		_name: &'static str,
363		_variants: &'static [&'static str],
364		visitor: V,
365	) -> Result<V::Value, Self::Error>
366	where
367		V: de::Visitor<'de>,
368	{
369		visitor.visit_enum(Enum { de: self })
370	}
371
372	fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
373	where
374		V: de::Visitor<'de>,
375	{
376		self.deserialize_str(visitor)
377	}
378
379	fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
380	where
381		V: de::Visitor<'de>,
382	{
383		self.deserialize_any(visitor)
384	}
385}