Skip to main content

apollo_router/plugin/
serde.rs

1//! serde support for commonly used data structures.
2
3use std::fmt::Formatter;
4use std::str::FromStr;
5
6use access_json::JSONQuery;
7use http::HeaderValue;
8use http::header::HeaderName;
9use regex::Regex;
10use serde::Deserializer;
11use serde::de;
12use serde::de::Error;
13use serde::de::SeqAccess;
14use serde::de::Visitor;
15
16/// De-serialize an optional [`HeaderName`].
17pub fn deserialize_option_header_name<'de, D>(
18    deserializer: D,
19) -> Result<Option<HeaderName>, D::Error>
20where
21    D: Deserializer<'de>,
22{
23    struct OptionHeaderNameVisitor;
24
25    impl<'de> Visitor<'de> for OptionHeaderNameVisitor {
26        type Value = Option<HeaderName>;
27
28        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
29            formatter.write_str("struct HeaderName")
30        }
31
32        fn visit_none<E>(self) -> Result<Self::Value, E>
33        where
34            E: de::Error,
35        {
36            Ok(None)
37        }
38
39        fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
40        where
41            D: de::Deserializer<'de>,
42        {
43            Ok(Some(deserializer.deserialize_str(HeaderNameVisitor)?))
44        }
45    }
46    deserializer.deserialize_option(OptionHeaderNameVisitor)
47}
48
49/// De-serialize a vector of [`HeaderName`].
50pub fn deserialize_vec_header_name<'de, D>(deserializer: D) -> Result<Vec<HeaderName>, D::Error>
51where
52    D: Deserializer<'de>,
53{
54    struct VecHeaderNameVisitor;
55
56    impl<'de> Visitor<'de> for VecHeaderNameVisitor {
57        type Value = Vec<HeaderName>;
58
59        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
60            formatter.write_str("struct HeaderName")
61        }
62
63        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
64        where
65            A: SeqAccess<'de>,
66        {
67            let mut result = Vec::new();
68            while let Some(element) = seq.next_element::<String>()? {
69                let header_name = HeaderNameVisitor.visit_string(element)?;
70                result.push(header_name);
71            }
72            Ok(result)
73        }
74    }
75    deserializer.deserialize_seq(VecHeaderNameVisitor)
76}
77
78/// De-serialize an optional [`HeaderValue`].
79pub fn deserialize_option_header_value<'de, D>(
80    deserializer: D,
81) -> Result<Option<HeaderValue>, D::Error>
82where
83    D: Deserializer<'de>,
84{
85    struct OptionHeaderValueVisitor;
86
87    impl<'de> Visitor<'de> for OptionHeaderValueVisitor {
88        type Value = Option<HeaderValue>;
89
90        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
91            formatter.write_str("struct HeaderValue")
92        }
93
94        fn visit_none<E>(self) -> Result<Self::Value, E>
95        where
96            E: de::Error,
97        {
98            Ok(None)
99        }
100
101        fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
102        where
103            D: de::Deserializer<'de>,
104        {
105            Ok(Some(deserializer.deserialize_str(HeaderValueVisitor)?))
106        }
107    }
108
109    deserializer.deserialize_option(OptionHeaderValueVisitor)
110}
111
112#[derive(Default)]
113struct HeaderNameVisitor;
114
115impl Visitor<'_> for HeaderNameVisitor {
116    type Value = HeaderName;
117
118    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
119        formatter.write_str("struct HeaderName")
120    }
121
122    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
123    where
124        E: Error,
125    {
126        HeaderName::try_from(v).map_err(|e| de::Error::custom(format!("Invalid header name {e}")))
127    }
128}
129
130/// De-serialize a [`HeaderName`].
131pub fn deserialize_header_name<'de, D>(deserializer: D) -> Result<HeaderName, D::Error>
132where
133    D: Deserializer<'de>,
134{
135    deserializer.deserialize_str(HeaderNameVisitor)
136}
137
138struct JSONQueryVisitor;
139
140impl Visitor<'_> for JSONQueryVisitor {
141    type Value = JSONQuery;
142
143    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
144        formatter.write_str("struct JSONQuery")
145    }
146
147    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
148    where
149        E: Error,
150    {
151        JSONQuery::parse(v)
152            .map_err(|e| de::Error::custom(format!("Invalid JSON query path for '{v}' {e}")))
153    }
154}
155
156/// De-serialize a [`JSONQuery`].
157pub fn deserialize_json_query<'de, D>(deserializer: D) -> Result<JSONQuery, D::Error>
158where
159    D: Deserializer<'de>,
160{
161    deserializer.deserialize_str(JSONQueryVisitor)
162}
163
164struct HeaderValueVisitor;
165
166impl Visitor<'_> for HeaderValueVisitor {
167    type Value = HeaderValue;
168
169    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
170        formatter.write_str("struct HeaderValue")
171    }
172
173    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
174    where
175        E: Error,
176    {
177        HeaderValue::try_from(v).map_err(|e| de::Error::custom(format!("Invalid header value {e}")))
178    }
179}
180
181/// De-serialize a [`HeaderValue`].
182pub fn deserialize_header_value<'de, D>(deserializer: D) -> Result<HeaderValue, D::Error>
183where
184    D: Deserializer<'de>,
185{
186    deserializer.deserialize_str(HeaderValueVisitor)
187}
188
189/// De-serialize a [`Regex`].
190pub fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
191where
192    D: Deserializer<'de>,
193{
194    struct RegexVisitor;
195
196    impl Visitor<'_> for RegexVisitor {
197        type Value = Regex;
198
199        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
200            formatter.write_str("struct Regex")
201        }
202
203        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
204        where
205            E: Error,
206        {
207            Regex::from_str(v).map_err(|e| de::Error::custom(format!("{e}")))
208        }
209    }
210    deserializer.deserialize_str(RegexVisitor)
211}
212
213pub(crate) fn deserialize_jsonpath<'de, D>(
214    deserializer: D,
215) -> Result<serde_json_bytes::path::JsonPathInst, D::Error>
216where
217    D: serde::Deserializer<'de>,
218{
219    deserializer.deserialize_str(JSONPathVisitor)
220}
221
222struct JSONPathVisitor;
223
224impl serde::de::Visitor<'_> for JSONPathVisitor {
225    type Value = serde_json_bytes::path::JsonPathInst;
226
227    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
228        write!(formatter, "a JSON path")
229    }
230
231    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
232    where
233        E: serde::de::Error,
234    {
235        serde_json_bytes::path::JsonPathInst::from_str(s).map_err(serde::de::Error::custom)
236    }
237}