Skip to main content

scraper_trail/
multi_value.rs

1use std::{borrow::Cow, marker::PhantomData};
2
3#[derive(Copy, Clone, Debug, thiserror::Error)]
4pub enum Error {
5    #[error("Empty values")]
6    Empty,
7}
8
9/// A set of values for a response header.
10///
11/// Typically each header name will map to a single value, but the same name may appear more than
12/// once, so we wish to handle that case, while still making it convenient to work with singleton
13/// values.
14#[derive(Clone, Debug, Eq, PartialEq, bounded_static_derive_more::ToStatic)]
15pub struct MultiValue<'a> {
16    pub first: Cow<'a, str>,
17    rest: Option<Vec<Cow<'a, str>>>,
18}
19
20impl<'a> MultiValue<'a> {
21    pub fn new<S: Into<Cow<'a, str>>>(value: S) -> Self {
22        Self {
23            first: value.into(),
24            rest: None,
25        }
26    }
27
28    pub fn push<S: Into<Cow<'a, str>>>(&mut self, value: S) {
29        match &mut self.rest {
30            None => {
31                self.rest = Some(vec![value.into()]);
32            }
33            Some(rest) => {
34                rest.push(value.into());
35            }
36        }
37    }
38
39    #[must_use]
40    pub fn iter(&'a self) -> Iter<'a> {
41        Iter {
42            first: Some(Cow::Borrowed(self.first.as_ref())),
43            rest: self.rest.as_ref().map(|rest| rest.iter()),
44        }
45    }
46}
47
48impl<'a> AsRef<Cow<'a, str>> for MultiValue<'a> {
49    fn as_ref(&self) -> &Cow<'a, str> {
50        &self.first
51    }
52}
53
54impl<'a, S: Into<Cow<'a, str>>> TryFrom<Vec<S>> for MultiValue<'a> {
55    type Error = Error;
56
57    fn try_from(value: Vec<S>) -> Result<Self, Self::Error> {
58        let mut values = value.into_iter();
59
60        let first = values.next().ok_or(Error::Empty)?.into();
61        let rest: Vec<Cow<'a, str>> = values.map(Into::into).collect();
62
63        Ok(Self {
64            first,
65            rest: if rest.is_empty() { None } else { Some(rest) },
66        })
67    }
68}
69
70impl<'a> IntoIterator for &'a MultiValue<'a> {
71    type Item = std::borrow::Cow<'a, str>;
72    type IntoIter = Iter<'a>;
73
74    fn into_iter(self) -> Self::IntoIter {
75        self.iter()
76    }
77}
78
/// Iterator over the values of a multi-value header: the first value, then the remaining ones.
///
/// Every item is yielded as a borrow of the underlying data, so iterating never allocates.
#[derive(Clone, Debug)]
pub struct Iter<'a> {
    /// The first value; set to `None` once it has been yielded.
    first: Option<Cow<'a, str>>,
    /// Iterator over the values after the first, if there are any.
    rest: Option<std::slice::Iter<'a, Cow<'a, str>>>,
}

impl<'a> Iterator for Iter<'a> {
    type Item = Cow<'a, str>;

    fn next(&mut self) -> Option<Self::Item> {
        if let Some(first) = self.first.take() {
            return Some(first);
        }

        // Re-borrow instead of cloning: cloning an owned `Cow` would copy the whole string on
        // every iteration, while the slice already guarantees the data lives for `'a`.
        self.rest
            .as_mut()?
            .next()
            .map(|value| Cow::Borrowed(value.as_ref()))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: one for a not-yet-yielded first value plus whatever the slice iterator has left.
        let remaining = usize::from(self.first.is_some())
            + self.rest.as_ref().map_or(0, ExactSizeIterator::len);

        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for Iter<'_> {}
104
105impl<'a, 'de: 'a> serde::de::Deserialize<'de> for MultiValue<'a> {
106    fn deserialize<D: serde::de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
107        const EXPECTED: &str = "one or more header values";
108
109        struct MultiValueVisitor<'a> {
110            _lifetime: PhantomData<&'a ()>,
111        }
112
113        impl<'a, 'de: 'a> serde::de::Visitor<'de> for MultiValueVisitor<'a> {
114            type Value = MultiValue<'a>;
115
116            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117                formatter.write_str(EXPECTED)
118            }
119
120            fn visit_borrowed_str<E: serde::de::Error>(
121                self,
122                v: &'de str,
123            ) -> Result<Self::Value, E> {
124                Ok(MultiValue::new(v))
125            }
126
127            fn visit_seq<A: serde::de::SeqAccess<'de>>(
128                self,
129                mut seq: A,
130            ) -> Result<Self::Value, A::Error> {
131                let mut result: Option<MultiValue<'a>> = None;
132
133                while let Some(value) = seq.next_element::<Cow<'a, str>>()? {
134                    match result {
135                        Some(ref mut multi_value) => {
136                            multi_value.push(value);
137                        }
138                        None => {
139                            result = Some(MultiValue::new(value));
140                        }
141                    }
142                }
143
144                result.ok_or_else(|| {
145                    serde::de::Error::invalid_value(serde::de::Unexpected::Seq, &EXPECTED)
146                })
147            }
148
149            fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
150                Ok(MultiValue::new(v.to_string()))
151            }
152        }
153
154        deserializer.deserialize_any(MultiValueVisitor {
155            _lifetime: PhantomData,
156        })
157    }
158}
159
160impl serde::ser::Serialize for MultiValue<'_> {
161    fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
162        use serde::ser::SerializeSeq;
163
164        match self.rest.as_ref() {
165            Some(rest) => {
166                let mut seq = serializer.serialize_seq(Some(rest.len() + 1))?;
167                seq.serialize_element(&self.first)?;
168
169                for element in rest {
170                    seq.serialize_element(&element)?;
171                }
172
173                seq.end()
174            }
175            None => self.first.serialize(serializer),
176        }
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::MultiValue;

    /// Minimal document wrapping a `MultiValue` field, used to exercise deserialization.
    #[derive(Debug, Eq, PartialEq, serde::Deserialize)]
    struct Test<'a> {
        #[serde(borrow)]
        header_values: MultiValue<'a>,
    }

    /// A lone string and a list of strings both deserialize into the expected `MultiValue`.
    #[test]
    fn deserialize_multi_value() -> Result<(), Box<dyn std::error::Error>> {
        let parsed_singleton: Test<'_> = serde_json::from_str(r#"{ "header_values": "test" }"#)?;
        let parsed_multi: Test<'_> =
            serde_json::from_str(r#"{ "header_values": ["foo", "bar", "baz"] }"#)?;

        let expected_singleton = Test {
            header_values: MultiValue::new("test"),
        };
        let expected_multi = Test {
            header_values: vec!["foo", "bar", "baz"].try_into()?,
        };

        assert_eq!(parsed_singleton, expected_singleton);
        assert_eq!(parsed_multi, expected_multi);
        Ok(())
    }

    /// Iteration yields the first value followed by the remaining values, in order.
    #[test]
    fn iter() -> Result<(), Box<dyn std::error::Error>> {
        let single = MultiValue::new("test");
        let several: MultiValue<'_> = vec!["foo", "bar", "baz"].try_into()?;

        let single_items: Vec<_> = single.iter().collect();
        let several_items: Vec<_> = several.iter().collect();

        assert_eq!(single_items, vec!["test"]);
        assert_eq!(several_items, vec!["foo", "bar", "baz"]);
        Ok(())
    }
}