ntex_http/
serde.rs

1use std::fmt;
2
3use serde::de::{self, Deserialize, Deserializer, MapAccess, Unexpected, Visitor};
4use serde::ser::{self, Serialize, SerializeMap, Serializer};
5
6use super::{HeaderMap, HeaderName, HeaderValue, Value};
7
8impl Serialize for HeaderMap {
9    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
10    where
11        S: Serializer,
12    {
13        let mut map = serializer.serialize_map(Some(self.len()))?;
14        for (name, value) in &self.inner {
15            map.serialize_entry(name.as_str(), value)?;
16        }
17        map.end()
18    }
19}
20
21impl<'de> Deserialize<'de> for HeaderMap {
22    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
23    where
24        D: Deserializer<'de>,
25    {
26        deserializer.deserialize_map(HeaderMapVisitor)
27    }
28}
29
30struct HeaderMapVisitor;
31
32impl<'de> Visitor<'de> for HeaderMapVisitor {
33    type Value = HeaderMap;
34
35    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
36        formatter.write_str("a header map")
37    }
38
39    fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
40    where
41        M: MapAccess<'de>,
42    {
43        let mut headers = HeaderMap::with_capacity(map.size_hint().unwrap_or(0));
44        while let Some((key, value)) = map.next_entry::<&str, Value>()? {
45            let name = HeaderName::from_bytes(key.as_bytes()).map_err(|_| {
46                de::Error::invalid_value(Unexpected::Str(key), &"a valid header name")
47            })?;
48            headers.inner.insert(name, value);
49        }
50        Ok(headers)
51    }
52}
53
54impl Serialize for Value {
55    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
56    where
57        S: Serializer,
58    {
59        match self {
60            Value::One(val) if serializer.is_human_readable() => val.serialize(serializer),
61            // For non-human-readable formats, always serialize as a sequence
62            Value::One(val) => [val].as_slice().serialize(serializer),
63            Value::Multi(vec) => vec.serialize(serializer),
64        }
65    }
66}
67
68impl<'de> Deserialize<'de> for Value {
69    fn deserialize<D>(deserializer: D) -> Result<Value, D::Error>
70    where
71        D: Deserializer<'de>,
72    {
73        if deserializer.is_human_readable() {
74            return deserializer.deserialize_any(ValueVisitor);
75        }
76        deserializer.deserialize_seq(ValueVisitor)
77    }
78}
79
80struct ValueVisitor;
81
82impl<'de> Visitor<'de> for ValueVisitor {
83    type Value = Value;
84
85    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
86        formatter.write_str("a single header value or sequence of values")
87    }
88
89    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
90    where
91        E: de::Error,
92    {
93        Ok(Value::One(HeaderValueVisitor.visit_str(v)?))
94    }
95
96    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
97    where
98        E: de::Error,
99    {
100        Ok(Value::One(HeaderValueVisitor.visit_string(v)?))
101    }
102
103    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
104    where
105        E: de::Error,
106    {
107        Ok(Value::One(HeaderValueVisitor.visit_bytes(v)?))
108    }
109
110    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
111    where
112        E: de::Error,
113    {
114        Ok(Value::One(HeaderValueVisitor.visit_byte_buf(v)?))
115    }
116
117    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
118    where
119        A: de::SeqAccess<'de>,
120    {
121        let mut value: Option<Value> = None;
122        while let Some(next_val) = seq.next_element()? {
123            match value.as_mut() {
124                Some(value) => value.append(next_val),
125                None => value = Some(Value::One(next_val)),
126            }
127        }
128        value.ok_or_else(|| de::Error::invalid_length(0, &"non-empty value"))
129    }
130}
131
132impl Serialize for HeaderValue {
133    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134    where
135        S: Serializer,
136    {
137        if serializer.is_human_readable() {
138            return serializer.serialize_str(
139                self.to_str()
140                    .map_err(|err| ser::Error::custom(err.to_string()))?,
141            );
142        }
143        serializer.serialize_bytes(self.as_bytes())
144    }
145}
146
147impl<'de> Deserialize<'de> for HeaderValue {
148    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
149    where
150        D: Deserializer<'de>,
151    {
152        if deserializer.is_human_readable() {
153            return deserializer.deserialize_string(HeaderValueVisitor);
154        }
155        deserializer.deserialize_byte_buf(HeaderValueVisitor)
156    }
157}
158
159struct HeaderValueVisitor;
160
161impl Visitor<'_> for HeaderValueVisitor {
162    type Value = HeaderValue;
163
164    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
165        formatter.write_str("a header value")
166    }
167
168    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
169    where
170        E: de::Error,
171    {
172        HeaderValue::from_str(v).map_err(|_| {
173            de::Error::invalid_value(Unexpected::Str(v), &"a valid header value")
174        })
175    }
176
177    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
178    where
179        E: de::Error,
180    {
181        HeaderValue::from_shared(v).map_err(|err| de::Error::custom(err.to_string()))
182    }
183
184    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
185    where
186        E: de::Error,
187    {
188        HeaderValue::from_bytes(v).map_err(|_| {
189            de::Error::invalid_value(Unexpected::Bytes(v), &"a valid header value")
190        })
191    }
192
193    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
194    where
195        E: de::Error,
196    {
197        HeaderValue::from_shared(v).map_err(|err| de::Error::custom(err.to_string()))
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::header::*;

    #[test]
    fn test_serde_json() {
        // A header with two values is emitted as a JSON array.
        let mut headers = HeaderMap::new();
        headers.insert(USER_AGENT, HeaderValue::from_static("hello"));
        headers.append(USER_AGENT, HeaderValue::from_static("world"));
        let encoded = serde_json::to_string(&headers).unwrap();
        assert_eq!(encoded, r#"{"user-agent":["hello","world"]}"#);

        // Single- and multi-valued headers survive a JSON roundtrip.
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        headers.insert(CONTENT_LENGTH, 0.into());
        let encoded = serde_json::to_string(&headers).unwrap();
        let decoded: HeaderMap = serde_json::from_str(&encoded).unwrap();
        assert_eq!(headers, decoded);

        // A mixed-case name is accepted on input and written back in lowercase.
        let mixed_case: HeaderMap = serde_json::from_str(r#"{"X-Foo":"BAR"}"#).unwrap();
        assert_eq!(mixed_case.get("x-foo").unwrap(), "BAR");
        let encoded = serde_json::to_string(&mixed_case).unwrap();
        assert_eq!(encoded, r#"{"x-foo":"BAR"}"#);

        // An empty value list is rejected rather than yielding a header with no values.
        let outcome = serde_json::from_str::<HeaderMap>(r#"{"user-agent":[]}"#);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("invalid length 0, expected non-empty value"));
    }

    #[test]
    fn test_serde_bincode() {
        // Binary roundtrip: values are written as sequences of raw bytes.
        let mut headers = HeaderMap::new();
        headers.insert(USER_AGENT, HeaderValue::from_static("hello"));
        headers.append(USER_AGENT, HeaderValue::from_static("world"));
        headers.insert(HeaderName::from_static("x-foo"), "bar".parse().unwrap());

        let encoded = bincode::serialize(&headers).unwrap();
        let decoded = bincode::deserialize::<HeaderMap>(&encoded).unwrap();
        assert_eq!(headers, decoded);
    }
}