mod parse;
mod string_parser;
use crate::{
Config,
config::DuplicateKeyBehavior,
error::{Error, Result},
};
use parse::{Key, ParsedValue};
use serde::de;
use serde::forward_to_deserialize_any;
use string_parser::StringParsingDeserializer;
use std::borrow::Cow;
pub fn from_bytes<'de, T: de::Deserialize<'de>>(input: &'de [u8]) -> Result<T> {
Config::default().deserialize_bytes(input)
}
/// Deserializes a `T` from a querystring given as a `&str`, using the default [`Config`].
///
/// The string's UTF-8 bytes are forwarded unchanged to [`from_bytes`].
///
/// # Errors
/// Returns the same errors as [`from_bytes`].
pub fn from_str<'de, T: de::Deserialize<'de>>(input: &'de str) -> Result<T> {
    let raw: &'de [u8] = input.as_bytes();
    from_bytes(raw)
}
/// A serde `Deserializer` over one already-parsed querystring value.
///
/// The top-level instance is built from raw input by `QsDeserializer::new` /
/// `QsDeserializer::with_config`. Nested values (map entries, sequence
/// elements, enum payloads) are wrapped in fresh instances that share the same
/// `Config`.
pub struct QsDeserializer<'a> {
// The parsed value that is handed to serde visitors.
value: ParsedValue<'a>,
// Options that are carried down to nested values; among other things they
// control how repeated keys are treated (`duplicate_key_behavior`).
config: Config,
}
impl<'a> QsDeserializer<'a> {
pub fn new(input: &'a [u8]) -> Result<Self> {
Self::with_config(Default::default(), input)
}
pub fn with_config(config: Config, input: &'a [u8]) -> Result<Self> {
let parsed = parse::parse(input, config)?;
Ok(Self {
value: parsed,
config,
})
}
fn from_value(value: ParsedValue<'a>, config: Config) -> Self {
Self { value, config }
}
}
/// `MapAccess` / `EnumAccess` / `VariantAccess` adapter over a borrowed parsed map.
///
/// Entries are removed from `parsed` as their keys are handed out.
struct MapDeserializer<'a, 'qs: 'a> {
// The map being drained.
parsed: &'a mut parse::ParsedMap<'qs>,
// Set to the declared field names by `deserialize_struct`; these are tried first, in
// order, before the remaining entries are drained. `None` for plain maps and enums.
field_order: Option<&'static [&'static str]>,
// The value belonging to the most recently returned key. It is consumed by the
// following value / variant request.
popped_value: Option<ParsedValue<'qs>>,
// Passed down to the `QsDeserializer` created for each value.
config: Config,
}
impl<'a, 'de: 'a> de::MapAccess<'de> for MapDeserializer<'a, 'de> {
    type Error = Error;

    /// Returns the next key, or `None` once the map is exhausted.
    ///
    /// When deserializing a struct (`field_order` is `Some`), the declared
    /// fields are tried first, in declaration order. Each field present in the
    /// map is removed from it and yielded. After that, any remaining
    /// (undeclared) entries are drained with `pop_first`.
    ///
    /// The value belonging to the yielded key is stashed in `popped_value`
    /// for the following `next_value_seed` call.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: de::DeserializeSeed<'de>,
    {
        if let Some(field_order) = &mut self.field_order {
            while let Some((field, rem)) = field_order.split_first() {
                // Advance past this field whether or not it is present, so that
                // absent fields are never retried.
                *field_order = rem;
                let field_key = (*field).into();
                if let Some(value) = crate::map::remove(self.parsed, &field_key) {
                    self.popped_value = Some(value);
                    return seed
                        .deserialize(StringParsingDeserializer::new_str(field))
                        .map(Some);
                }
            }
        }
        if let Some((key, value)) = crate::map::pop_first(self.parsed) {
            self.popped_value = Some(value);
            // A literal `[` left in a key is the case the hint below targets
            // (an encoded bracket); it suggests switching to form encoding.
            let has_bracket = matches!(key, Key::String(ref s) if s.contains(&b'['));
            key.deserialize_seed(seed).map(Some).map_err(|e| {
                if has_bracket {
                    Error::custom(
                        format!("{e}\nInvalid field contains an encoded bracket -- consider using form encoding mode\n https://docs.rs/serde_qs/latest/serde_qs/#query-string-vs-form-encoding"),
                        &self.parsed,
                    )
                } else {
                    e
                }
            })
        } else {
            Ok(None)
        }
    }

    /// Deserializes the value stashed by the preceding `next_key_seed` call.
    ///
    /// # Errors
    /// Fails if no value is stashed, i.e. `next_key_seed` was not called
    /// first (or returned `None`).
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: de::DeserializeSeed<'de>,
    {
        if let Some(v) = self.popped_value.take() {
            seed.deserialize(QsDeserializer::from_value(v, self.config))
        } else {
            Err(Error::custom(
                "Somehow the map was empty after a non-empty key was returned",
                &self.parsed,
            ))
        }
    }

    /// Number of entries still to be yielded.
    ///
    /// Every entry left in `parsed` is yielded exactly once. Declared struct
    /// fields are removed as they are returned, and any leftover keys are
    /// drained afterwards. So the map's current length is exact.
    ///
    /// The previous value (remaining `field_order` length) over-counted
    /// declared fields missing from the query and ignored undeclared extra keys.
    fn size_hint(&self) -> Option<usize> {
        Some(self.parsed.len())
    }
}
impl<'a, 'de: 'a> de::EnumAccess<'de> for MapDeserializer<'a, 'de> {
type Error = Error;
type Variant = Self;
fn variant_seed<V>(mut self, seed: V) -> Result<(V::Value, Self::Variant)>
where
V: de::DeserializeSeed<'de>,
{
if let Some((key, value)) = crate::map::pop_first(self.parsed) {
self.popped_value = Some(value);
Ok((key.deserialize_seed(seed)?, self))
} else {
Err(Error::custom("No more values", &self.parsed))
}
}
}
impl<'a, 'de: 'a> de::VariantAccess<'de> for MapDeserializer<'a, 'de> {
type Error = Error;
fn unit_variant(self) -> Result<()> {
Ok(())
}
fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
where
T: de::DeserializeSeed<'de>,
{
if let Some(value) = self.popped_value {
seed.deserialize(QsDeserializer::from_value(value, self.config))
} else {
Err(Error::custom("no value to deserialize", &self.parsed))
}
}
fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
if let Some(value) = self.popped_value {
de::Deserializer::deserialize_seq(
QsDeserializer::from_value(value, self.config),
visitor,
)
} else {
Err(Error::custom("no value to deserialize", &self.parsed))
}
}
fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
if let Some(value) = self.popped_value {
de::Deserializer::deserialize_map(
QsDeserializer::from_value(value, self.config),
visitor,
)
} else {
Err(Error::custom("no value to deserialize", &self.parsed))
}
}
}
/// `SeqAccess` adapter that yields each parsed value of an iterator as one element.
struct Seq<'a, I: Iterator<Item = ParsedValue<'a>>> {
// Source of the remaining elements.
iter: I,
// Passed down to the `QsDeserializer` created for each element.
config: Config,
}
impl<'de, I: Iterator<Item = ParsedValue<'de>>> de::SeqAccess<'de> for Seq<'de, I> {
    type Error = Error;

    /// Deserializes the next value from the iterator, or returns `Ok(None)` when it is exhausted.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        let config = self.config;
        self.iter
            .next()
            .map(|element| seed.deserialize(QsDeserializer::from_value(element, config)))
            .transpose()
    }

    /// Reports a length only when the iterator's bounds pin it down exactly.
    fn size_hint(&self) -> Option<usize> {
        let (lower, upper) = self.iter.size_hint();
        upper.filter(|&upper| upper == lower)
    }
}
impl<'a, I> OrderedSeq<'a, I>
where
    I: Iterator<Item = (Key<'a>, ParsedValue<'a>)>,
{
    /// Wraps `iter`, expecting its keys to be the integer indices 0, 1, 2, … in order.
    pub fn new(iter: I, config: Config) -> Self {
        let counter = 0;
        Self {
            config,
            counter,
            iter,
        }
    }
}
/// `SeqAccess` adapter over `(key, value)` pairs whose keys must be consecutive
/// integer indices starting at 0.
struct OrderedSeq<'a, I: Iterator<Item = (Key<'a>, ParsedValue<'a>)>> {
// Source of the remaining `(index, value)` pairs.
iter: I,
// The index the next yielded key must equal.
counter: u32,
// Passed down to the `QsDeserializer` created for each element.
config: Config,
}
impl<'de, I: Iterator<Item = (Key<'de>, ParsedValue<'de>)>> de::SeqAccess<'de>
for OrderedSeq<'de, I>
{
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
where
T: de::DeserializeSeed<'de>,
{
if let Some((k, v)) = self.iter.next() {
match k {
Key::Int(i) if i == self.counter => {
self.counter = self.counter.checked_add(1).ok_or_else(|| {
Error::custom("cannot deserialize more than u32::MAX elements", &(k, &v))
})?;
seed.deserialize(QsDeserializer::from_value(v, self.config))
.map(Some)
}
Key::Int(i) => Err(Error::custom(
format!("missing index, expected: {} got {i}", self.counter),
&(k, v),
)),
Key::String(ref bytes) => {
let key = std::str::from_utf8(bytes).unwrap_or("<non-utf8>");
Err(Error::custom(
format!("expected an integer index, found a string key `{key}`"),
&(k, v),
))
}
}
} else {
Ok(None)
}
}
fn size_hint(&self) -> Option<usize> {
None
}
}
fn get_last_string_value<'a>(
seq: &mut Vec<ParsedValue<'a>>,
config: Config,
) -> Result<Option<Cow<'a, [u8]>>> {
if config.duplicate_key_behavior == DuplicateKeyBehavior::Error && seq.len() > 1 {
return Err(Error::custom(
"multiple values provided for non-sequence field",
seq,
));
}
Ok(match seq.last() {
None => None,
Some(ParsedValue::NoValue | ParsedValue::Null) => {
Some(Cow::Borrowed(b""))
}
Some(ParsedValue::String(_)) => {
if let Some(ParsedValue::String(s)) = seq.pop() {
Some(s)
} else {
None
}
}
Some(_) => {
None
}
})
}
/// Generates `Deserializer` methods for scalar types that are parsed from text.
///
/// Each generated method behaves as follows:
/// - a `String` value goes straight to `StringParsingDeserializer`;
/// - a `Sequence` (a repeated key) is reduced to its last string value via
///   `get_last_string_value`, falling back to `deserialize_any` when no such
///   value exists;
/// - every other value falls back to `deserialize_any`.
///
/// The `$ty` token only documents which type each method handles.
macro_rules! forward_to_string_parser {
    ($($ty:ident => $meth:ident,)*) => {
        $(
            fn $meth<V>(self, visitor: V) -> Result<V::Value>
            where
                V: de::Visitor<'de>,
            {
                let raw = match self.value {
                    ParsedValue::String(raw) => raw,
                    ParsedValue::Sequence(mut values) => {
                        match get_last_string_value(&mut values, self.config)? {
                            Some(last) => last,
                            None => {
                                return Self::from_value(ParsedValue::Sequence(values), self.config)
                                    .deserialize_any(visitor);
                            }
                        }
                    }
                    _ => return self.deserialize_any(visitor),
                };
                StringParsingDeserializer::new(raw)?.$meth(visitor)
            }
        )*
    };
}
/// Maps serde's type hints onto the shape of the parsed querystring value.
///
/// Scalar hints (`bool`, `str`, bytes, integers, floats) reduce a
/// `ParsedValue::Sequence` of repeated keys to its last string value, via
/// `get_last_string_value`, which honours `duplicate_key_behavior`. Values that
/// do not fit the requested hint fall back to `deserialize_any`.
impl<'de> de::Deserializer<'de> for QsDeserializer<'de> {
    type Error = Error;

    /// Deserializes based only on the parsed value's shape:
    /// - a map whose keys are all integers is visited as a sequence ordered by
    ///   index. An empty map also takes this path, since `all` holds for it.
    /// - any other map is visited as a map.
    /// - sequences are visited element by element.
    /// - strings go to the string parser.
    /// - `Null` is parsed as an empty string.
    /// - `NoValue` is visited as unit.
    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        match self.value {
            ParsedValue::Map(mut parsed) => {
                if parsed.keys().all(|k| matches!(k, Key::Int(_))) {
                    // `OrderedSeq` requires ascending indices. Under the `indexmap`
                    // feature the keys are sorted explicitly; the default map type is
                    // presumably already key-ordered (TODO confirm).
                    #[cfg(feature = "indexmap")]
                    parsed.sort_unstable_keys();
                    visitor.visit_seq(OrderedSeq::new(parsed.into_iter(), self.config))
                } else {
                    visitor.visit_map(MapDeserializer {
                        parsed: &mut parsed,
                        field_order: None,
                        popped_value: None,
                        config: self.config,
                    })
                }
            }
            ParsedValue::Sequence(seq) => visitor.visit_seq(Seq {
                iter: seq.into_iter(),
                config: self.config,
            }),
            ParsedValue::String(x) => StringParsingDeserializer::new(x)?.deserialize_any(visitor),
            ParsedValue::Uninitialized => Err(Error::custom(
                "internal error: attempted to deserialize uninitialised value",
                &self.value,
            )),
            ParsedValue::Null => {
                StringParsingDeserializer::new(Cow::Borrowed(b""))?.deserialize_any(visitor)
            }
            ParsedValue::NoValue => visitor.visit_unit(),
        }
    }

    /// `Null` / `NoValue` become an empty sequence and a lone string a
    /// one-element sequence; anything else is handled by `deserialize_any`.
    fn deserialize_seq<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        match self.value {
            ParsedValue::Null | ParsedValue::NoValue => visitor.visit_seq(Seq {
                iter: std::iter::empty(),
                config: self.config,
            }),
            ParsedValue::String(s) => visitor.visit_seq(Seq {
                iter: std::iter::once(ParsedValue::String(s)),
                config: self.config,
            }),
            _ => self.deserialize_any(visitor),
        }
    }

    /// Tuples are deserialized exactly like sequences; the length is not checked here.
    fn deserialize_tuple<V>(
        self,
        _len: usize,
        visitor: V,
    ) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }

    /// Tuple structs are deserialized exactly like sequences.
    fn deserialize_tuple_struct<V>(
        self,
        _name: &'static str,
        _len: usize,
        visitor: V,
    ) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }

    /// Visits a map, yielding the declared `fields` first (in declaration
    /// order) and then any remaining keys. `Null` / `NoValue` become an empty
    /// map; other shapes fall back to `deserialize_any`.
    fn deserialize_struct<V>(
        self,
        _name: &'static str,
        fields: &'static [&'static str],
        visitor: V,
    ) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        let mut map = match self.value {
            ParsedValue::Map(map) => map,
            ParsedValue::Null | ParsedValue::NoValue => parse::ParsedMap::default(),
            _ => return self.deserialize_any(visitor),
        };
        visitor.visit_map(MapDeserializer {
            parsed: &mut map,
            field_order: Some(fields),
            popped_value: None,
            config: self.config,
        })
    }

    /// Newtype structs are transparent: the inner value is deserialized from `self`.
    fn deserialize_newtype_struct<V>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        visitor.visit_newtype_struct(self)
    }

    /// Options are handled as follows:
    /// - `NoValue` is `None`;
    /// - `Null` is `Some` of a `NoValue` deserializer;
    /// - everything else is `Some` of this value.
    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        match self.value {
            ParsedValue::NoValue => visitor.visit_none(),
            ParsedValue::Null => visitor.visit_some(QsDeserializer::from_value(
                ParsedValue::NoValue,
                self.config,
            )),
            _ => visitor.visit_some(self),
        }
    }

    /// `Null` and empty strings are unit; everything else goes through
    /// `deserialize_any` (which also maps `NoValue` to unit).
    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        if matches!(self.value, ParsedValue::Null)
            || matches!(self.value, ParsedValue::String(ref s) if s.is_empty())
        {
            visitor.visit_unit()
        } else {
            self.deserialize_any(visitor)
        }
    }

    /// Unit structs are deserialized exactly like unit.
    fn deserialize_unit_struct<V>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        self.deserialize_unit(visitor)
    }

    /// Enums are read as follows:
    /// - from a map, the first entry's key names the variant and its value is
    ///   the payload;
    /// - from a string, the string names a variant and is handled by the
    ///   string parser;
    /// - other shapes fall back to `deserialize_any`.
    fn deserialize_enum<V>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        match self.value {
            ParsedValue::Map(mut parsed) => visitor.visit_enum(MapDeserializer {
                parsed: &mut parsed,
                field_order: None,
                popped_value: None,
                config: self.config,
            }),
            ParsedValue::String(s) => visitor.visit_enum(StringParsingDeserializer::new(s)?),
            _ => self.deserialize_any(visitor),
        }
    }

    /// Maps are visited in map order. `Null` / `NoValue` become an empty map,
    /// and a lone string becomes a single entry under the empty key.
    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        match self.value {
            ParsedValue::Map(mut parsed) => visitor.visit_map(MapDeserializer {
                parsed: &mut parsed,
                field_order: None,
                popped_value: None,
                config: self.config,
            }),
            ParsedValue::Null | ParsedValue::NoValue => {
                let mut empty_map = parse::ParsedMap::default();
                visitor.visit_map(MapDeserializer {
                    parsed: &mut empty_map,
                    field_order: None,
                    popped_value: None,
                    config: self.config,
                })
            }
            ParsedValue::String(s) => {
                let mut parsed = parse::ParsedMap::default();
                parsed.insert(Key::String(Cow::Borrowed(b"")), ParsedValue::String(s));
                visitor.visit_map(MapDeserializer {
                    parsed: &mut parsed,
                    field_order: None,
                    popped_value: None,
                    config: self.config,
                })
            }
            _ => self.deserialize_any(visitor),
        }
    }

    /// Strings are UTF-8 decoded via `string_parser::decode_utf8`. The result
    /// is borrowed from the input when possible and owned otherwise.
    /// `Null` / `NoValue` become `""`, and repeated keys use the last string value.
    fn deserialize_str<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        let s = match self.value {
            ParsedValue::String(s) => s,
            ParsedValue::Sequence(mut seq) => match get_last_string_value(&mut seq, self.config)? {
                Some(v) => v,
                None => {
                    return Self::from_value(ParsedValue::Sequence(seq), self.config)
                        .deserialize_any(visitor);
                }
            },
            ParsedValue::Null | ParsedValue::NoValue => {
                return visitor.visit_str("");
            }
            _ => return self.deserialize_any(visitor),
        };
        match string_parser::decode_utf8(s)? {
            Cow::Borrowed(string) => visitor.visit_borrowed_str(string),
            Cow::Owned(string) => visitor.visit_string(string),
        }
    }

    /// Owned strings use the same path as `deserialize_str`.
    fn deserialize_string<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        self.deserialize_str(visitor)
    }

    /// Raw bytes are passed through without UTF-8 decoding; they are borrowed
    /// when possible. `Null` / `NoValue` become an empty slice, and repeated
    /// keys use the last string value.
    fn deserialize_bytes<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        let s = match self.value {
            ParsedValue::String(s) => s,
            ParsedValue::Sequence(mut seq) => match get_last_string_value(&mut seq, self.config)? {
                Some(v) => v,
                None => {
                    return Self::from_value(ParsedValue::Sequence(seq), self.config)
                        .deserialize_any(visitor);
                }
            },
            ParsedValue::Null | ParsedValue::NoValue => {
                return visitor.visit_bytes(&[]);
            }
            _ => return self.deserialize_any(visitor),
        };
        match s {
            Cow::Borrowed(s) => visitor.visit_borrowed_bytes(s),
            Cow::Owned(s) => visitor.visit_byte_buf(s),
        }
    }

    /// Owned byte buffers use the same path as `deserialize_bytes`.
    fn deserialize_byte_buf<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        self.deserialize_bytes(visitor)
    }

    /// Ignored strings are handed over as raw bytes, which skips UTF-8
    /// decoding and string parsing. Other shapes go through `deserialize_any`.
    fn deserialize_ignored_any<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        match self.value {
            ParsedValue::String(cow) => match cow {
                Cow::Borrowed(s) => visitor.visit_borrowed_bytes(s),
                Cow::Owned(s) => visitor.visit_byte_buf(s),
            },
            _ => self.deserialize_any(visitor),
        }
    }

    /// Strings are parsed as booleans, and repeated keys use the last string
    /// value. A bare key (`Null` / `NoValue`) counts as `true`.
    fn deserialize_bool<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
    where
        V: de::Visitor<'de>,
    {
        match self.value {
            ParsedValue::String(s) => StringParsingDeserializer::new(s)?.deserialize_bool(visitor),
            ParsedValue::Sequence(mut seq) => {
                match get_last_string_value(&mut seq, self.config)? {
                    Some(last_value) => {
                        StringParsingDeserializer::new(last_value)?.deserialize_bool(visitor)
                    }
                    None => Self::from_value(ParsedValue::Sequence(seq), self.config)
                        .deserialize_any(visitor),
                }
            }
            ParsedValue::Null | ParsedValue::NoValue => visitor.visit_bool(true),
            _ => self.deserialize_any(visitor),
        }
    }

    forward_to_deserialize_any! {
        char
        identifier
    }

    forward_to_string_parser! {
        u8 => deserialize_u8,
        u16 => deserialize_u16,
        u32 => deserialize_u32,
        u64 => deserialize_u64,
        i8 => deserialize_i8,
        i16 => deserialize_i16,
        i32 => deserialize_i32,
        i64 => deserialize_i64,
        f32 => deserialize_f32,
        f64 => deserialize_f64,
    }
}