watermelon_proto/headers/
map.rs1use alloc::{
2 collections::{BTreeMap, btree_map::Entry},
3 vec,
4 vec::Vec,
5};
6use core::{
7 fmt::{self, Debug},
8 mem,
9};
10
11use super::{HeaderName, HeaderValue};
12
13static EMPTY_HEADERS: OneOrMany = OneOrMany::Many(Vec::new());
14
15#[derive(Clone, PartialEq, Eq)]
19pub struct HeaderMap {
20 headers: BTreeMap<HeaderName, OneOrMany>,
21 len: usize,
22}
23
24#[derive(Clone, PartialEq, Eq)]
25enum OneOrMany {
26 One(HeaderValue),
27 Many(Vec<HeaderValue>),
28}
29
30impl HeaderMap {
31 #[must_use]
38 pub const fn new() -> Self {
39 Self {
40 headers: BTreeMap::new(),
41 len: 0,
42 }
43 }
44
45 pub fn get(&self, name: &HeaderName) -> Option<&HeaderValue> {
46 self.get_all(name).next()
47 }
48
49 pub fn get_all<'a>(
50 &'a self,
51 name: &HeaderName,
52 ) -> impl DoubleEndedIterator<Item = &'a HeaderValue> + use<'a> {
53 self.headers.get(name).unwrap_or(&EMPTY_HEADERS).iter()
54 }
55
56 pub fn insert(&mut self, name: HeaderName, value: HeaderValue) {
57 if let Some(prev) = self.headers.insert(name, OneOrMany::One(value)) {
58 self.len -= prev.len();
59 }
60 self.len += 1;
61 }
62
63 pub fn append(&mut self, name: HeaderName, value: HeaderValue) {
64 match self.headers.entry(name) {
65 Entry::Vacant(vacant) => {
66 vacant.insert(OneOrMany::One(value));
67 }
68 Entry::Occupied(mut occupied) => {
69 occupied.get_mut().push(value);
70 }
71 }
72 self.len += 1;
73 }
74
75 pub fn remove(&mut self, name: &HeaderName) {
76 if let Some(prev) = self.headers.remove(name) {
77 self.len -= prev.len();
78 }
79 }
80
81 #[must_use]
85 pub fn keys_len(&self) -> usize {
86 self.headers.len()
87 }
88
89 #[must_use]
94 pub fn len(&self) -> usize {
95 self.len
96 }
97
98 #[must_use]
100 pub fn is_empty(&self) -> bool {
101 self.headers.is_empty()
102 }
103
104 pub fn clear(&mut self) {
106 self.headers.clear();
107 self.len = 0;
108 }
109
110 #[cfg(test)]
111 fn keys(&self) -> impl Iterator<Item = &'_ HeaderName> {
112 self.headers.keys()
113 }
114
115 pub(crate) fn iter(
116 &self,
117 ) -> impl DoubleEndedIterator<Item = (&'_ HeaderName, impl Iterator<Item = &'_ HeaderValue>)>
118 {
119 self.headers
120 .iter()
121 .map(|(name, value)| (name, value.iter()))
122 }
123}
124
125impl Debug for HeaderMap {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_tuple("HeaderMap")
128 .field(&self.headers)
129 .finish()
131 }
132}
133
134impl FromIterator<(HeaderName, HeaderValue)> for HeaderMap {
135 fn from_iter<I: IntoIterator<Item = (HeaderName, HeaderValue)>>(iter: I) -> Self {
136 let mut this = Self::new();
137 this.extend(iter);
138 this
139 }
140}
141
142impl Extend<(HeaderName, HeaderValue)> for HeaderMap {
143 fn extend<T: IntoIterator<Item = (HeaderName, HeaderValue)>>(&mut self, iter: T) {
144 iter.into_iter().for_each(|(name, value)| {
145 self.append(name, value);
146 });
147 }
148}
149
150impl Default for HeaderMap {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
156impl Debug for OneOrMany {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 f.debug_set().entries(self.iter()).finish()
159 }
160}
161
162impl OneOrMany {
163 fn len(&self) -> usize {
164 match self {
165 Self::One(_) => 1,
166 Self::Many(vec) => vec.len(),
167 }
168 }
169
170 fn push(&mut self, item: HeaderValue) {
171 match self {
172 Self::One(current_item) => {
173 let current_item =
174 mem::replace(current_item, HeaderValue::from_static("replacing"));
175 *self = Self::Many(vec![current_item, item]);
176 }
177 Self::Many(vec) => {
178 debug_assert!(!vec.is_empty(), "OneOrMany can't be empty");
179 vec.push(item);
180 }
181 }
182 }
183
184 fn iter(&self) -> impl DoubleEndedIterator<Item = &'_ HeaderValue> {
185 match self {
188 Self::One(one) => Iterator::chain(Some(one).into_iter(), &[]),
189 Self::Many(many) => Iterator::chain(None.into_iter(), many),
190 }
191 }
192}
193
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use crate::headers::{HeaderName, HeaderValue};

    use super::HeaderMap;

    /// Pairs inserted by both construction tests, in insertion order.
    const SHARED_PAIRS: [(&str, &str); 3] = [
        ("Nats-Message-Id", "abcd"),
        ("Nats-Sequence", "1"),
        ("Nats-Message-Id", "1234"),
    ];

    /// Builds a `(HeaderName, HeaderValue)` pair from static strings.
    fn pair(name: &'static str, value: &'static str) -> (HeaderName, HeaderValue) {
        (HeaderName::from_static(name), HeaderValue::from_static(value))
    }

    #[test]
    fn manual() {
        let mut headers = HeaderMap::new();
        for (name, value) in SHARED_PAIRS {
            let (name, value) = pair(name, value);
            headers.append(name, value);
        }

        // An extra header that is added and then removed must leave no trace.
        let (time_stamp, zero) = pair("Nats-Time-Stamp", "0");
        headers.append(time_stamp.clone(), zero);
        headers.remove(&time_stamp);

        verify_header_map(&headers);
    }

    #[test]
    fn collect() {
        let headers = SHARED_PAIRS
            .into_iter()
            .map(|(name, value)| pair(name, value))
            .collect::<HeaderMap>();

        verify_header_map(&headers);
    }

    fn verify_header_map(headers: &HeaderMap) {
        let message_id = HeaderName::from_static("Nats-Message-Id");
        let sequence = HeaderName::from_static("Nats-Sequence");

        // Names come back once each, in the map's key order.
        let names: Vec<HeaderName> = headers.keys().cloned().collect();
        assert_eq!(
            [message_id.clone(), sequence.clone()].as_slice(),
            names.as_slice()
        );

        // Each name keeps its values in the order they were appended.
        let raw_headers: Vec<(HeaderName, Vec<HeaderValue>)> = headers
            .iter()
            .map(|(name, values)| (name.clone(), values.cloned().collect()))
            .collect();
        let expected = [
            (
                message_id,
                vec![
                    HeaderValue::from_static("abcd"),
                    HeaderValue::from_static("1234"),
                ],
            ),
            (sequence, vec![HeaderValue::from_static("1")]),
        ];
        assert_eq!(expected.as_slice(), raw_headers.as_slice());
    }
}