Skip to main content

apollo_router/plugin/
serde.rs

//! Serde deserialization helpers for commonly used data structures: HTTP header names and values, regexes, and JSON paths.
2
3use std::fmt::Formatter;
4use std::str::FromStr;
5
6use http::HeaderValue;
7use http::header::HeaderName;
8use regex::Regex;
9use serde::Deserializer;
10use serde::de;
11use serde::de::Error;
12use serde::de::SeqAccess;
13use serde::de::Visitor;
14
15/// De-serialize an optional [`HeaderName`].
16pub fn deserialize_option_header_name<'de, D>(
17    deserializer: D,
18) -> Result<Option<HeaderName>, D::Error>
19where
20    D: Deserializer<'de>,
21{
22    struct OptionHeaderNameVisitor;
23
24    impl<'de> Visitor<'de> for OptionHeaderNameVisitor {
25        type Value = Option<HeaderName>;
26
27        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
28            formatter.write_str("struct HeaderName")
29        }
30
31        fn visit_none<E>(self) -> Result<Self::Value, E>
32        where
33            E: de::Error,
34        {
35            Ok(None)
36        }
37
38        fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
39        where
40            D: de::Deserializer<'de>,
41        {
42            Ok(Some(deserializer.deserialize_str(HeaderNameVisitor)?))
43        }
44    }
45    deserializer.deserialize_option(OptionHeaderNameVisitor)
46}
47
48/// De-serialize a vector of [`HeaderName`].
49pub fn deserialize_vec_header_name<'de, D>(deserializer: D) -> Result<Vec<HeaderName>, D::Error>
50where
51    D: Deserializer<'de>,
52{
53    struct VecHeaderNameVisitor;
54
55    impl<'de> Visitor<'de> for VecHeaderNameVisitor {
56        type Value = Vec<HeaderName>;
57
58        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
59            formatter.write_str("struct HeaderName")
60        }
61
62        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
63        where
64            A: SeqAccess<'de>,
65        {
66            let mut result = Vec::new();
67            while let Some(element) = seq.next_element::<String>()? {
68                let header_name = HeaderNameVisitor.visit_string(element)?;
69                result.push(header_name);
70            }
71            Ok(result)
72        }
73    }
74    deserializer.deserialize_seq(VecHeaderNameVisitor)
75}
76
77/// De-serialize an optional [`HeaderValue`].
78pub fn deserialize_option_header_value<'de, D>(
79    deserializer: D,
80) -> Result<Option<HeaderValue>, D::Error>
81where
82    D: Deserializer<'de>,
83{
84    struct OptionHeaderValueVisitor;
85
86    impl<'de> Visitor<'de> for OptionHeaderValueVisitor {
87        type Value = Option<HeaderValue>;
88
89        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
90            formatter.write_str("struct HeaderValue")
91        }
92
93        fn visit_none<E>(self) -> Result<Self::Value, E>
94        where
95            E: de::Error,
96        {
97            Ok(None)
98        }
99
100        fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
101        where
102            D: de::Deserializer<'de>,
103        {
104            Ok(Some(deserializer.deserialize_str(HeaderValueVisitor)?))
105        }
106    }
107
108    deserializer.deserialize_option(OptionHeaderValueVisitor)
109}
110
111#[derive(Default)]
112struct HeaderNameVisitor;
113
114impl Visitor<'_> for HeaderNameVisitor {
115    type Value = HeaderName;
116
117    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
118        formatter.write_str("struct HeaderName")
119    }
120
121    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
122    where
123        E: Error,
124    {
125        HeaderName::try_from(v).map_err(|e| de::Error::custom(format!("Invalid header name {e}")))
126    }
127}
128
129/// De-serialize a [`HeaderName`].
130pub fn deserialize_header_name<'de, D>(deserializer: D) -> Result<HeaderName, D::Error>
131where
132    D: Deserializer<'de>,
133{
134    deserializer.deserialize_str(HeaderNameVisitor)
135}
136
137struct HeaderValueVisitor;
138
139impl Visitor<'_> for HeaderValueVisitor {
140    type Value = HeaderValue;
141
142    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
143        formatter.write_str("struct HeaderValue")
144    }
145
146    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
147    where
148        E: Error,
149    {
150        HeaderValue::try_from(v).map_err(|e| de::Error::custom(format!("Invalid header value {e}")))
151    }
152}
153
154/// De-serialize a [`HeaderValue`].
155pub fn deserialize_header_value<'de, D>(deserializer: D) -> Result<HeaderValue, D::Error>
156where
157    D: Deserializer<'de>,
158{
159    deserializer.deserialize_str(HeaderValueVisitor)
160}
161
162/// De-serialize a [`Regex`].
163pub fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
164where
165    D: Deserializer<'de>,
166{
167    struct RegexVisitor;
168
169    impl Visitor<'_> for RegexVisitor {
170        type Value = Regex;
171
172        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
173            formatter.write_str("struct Regex")
174        }
175
176        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
177        where
178            E: Error,
179        {
180            Regex::from_str(v).map_err(|e| de::Error::custom(format!("{e}")))
181        }
182    }
183    deserializer.deserialize_str(RegexVisitor)
184}
185
186pub(crate) fn deserialize_jsonpath<'de, D>(
187    deserializer: D,
188) -> Result<serde_json_bytes::path::JsonPathInst, D::Error>
189where
190    D: serde::Deserializer<'de>,
191{
192    deserializer.deserialize_str(JSONPathVisitor)
193}
194
195struct JSONPathVisitor;
196
197impl serde::de::Visitor<'_> for JSONPathVisitor {
198    type Value = serde_json_bytes::path::JsonPathInst;
199
200    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
201        write!(formatter, "a JSON path")
202    }
203
204    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
205    where
206        E: serde::de::Error,
207    {
208        serde_json_bytes::path::JsonPathInst::from_str(s).map_err(serde::de::Error::custom)
209    }
210}