1#![allow(clippy::mutable_key_type)]
20
21use std::{collections::HashMap, fmt, slice::Iter, str::FromStr};
24
25use bytes::Bytes;
26use serde::{Deserialize, Serialize};
27
/// A multimap of NATS message headers.
///
/// Each [`HeaderName`] maps to a list of [`HeaderValue`]s kept in the order they were added.
/// The order in which different names are visited follows the underlying `HashMap` and is
/// therefore unspecified.
#[derive(Clone, PartialEq, Eq, Debug, Default, Deserialize, Serialize)]
pub struct HeaderMap {
    // `insert` replaces a name's whole list; `append` pushes onto it.
    inner: HashMap<HeaderName, Vec<HeaderValue>>,
}
51
52impl FromIterator<(HeaderName, HeaderValue)> for HeaderMap {
53 fn from_iter<T: IntoIterator<Item = (HeaderName, HeaderValue)>>(iter: T) -> Self {
54 let mut header_map = HeaderMap::new();
55 for (key, value) in iter {
56 header_map.insert(key, value);
57 }
58 header_map
59 }
60}
61
62impl HeaderMap {
63 pub fn iter(&self) -> std::collections::hash_map::Iter<'_, HeaderName, Vec<HeaderValue>> {
64 self.inner.iter()
65 }
66}
67
/// Iterator over all values stored for a single header name.
///
/// Returned by [`HeaderMap::get_all`]. It walks the stored value list front to back.
#[derive(Debug, Clone)]
pub struct GetAll<'a, T> {
    inner: Iter<'a, T>,
}

impl<'a, T> Iterator for GetAll<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next()
    }

    // Forward the exact bounds of the underlying slice iterator so adapters such as
    // `collect` can pre-size instead of falling back to the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<T> DoubleEndedIterator for GetAll<'_, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.inner.next_back()
    }
}

// Sound because `size_hint` above is forwarded from `slice::Iter`, which is exact.
impl<T> ExactSizeIterator for GetAll<'_, T> {}

// `slice::Iter` keeps returning `None` once exhausted, so this wrapper does too.
impl<T> std::iter::FusedIterator for GetAll<'_, T> {}
79
80impl HeaderMap {
81 pub fn new() -> Self {
92 HeaderMap::default()
93 }
94
95 pub fn is_empty(&self) -> bool {
111 self.inner.is_empty()
112 }
113}
114
115impl HeaderMap {
116 pub fn insert<K: IntoHeaderName, V: IntoHeaderValue>(&mut self, name: K, value: V) {
127 self.inner
128 .insert(name.into_header_name(), vec![value.into_header_value()]);
129 }
130
131 pub fn append<K: IntoHeaderName, V: IntoHeaderValue>(&mut self, name: K, value: V) {
143 let key = name.into_header_name();
144 let v = self.inner.get_mut(&key);
145 match v {
146 Some(v) => {
147 v.push(value.into_header_value());
148 }
149 None => {
150 self.insert(key, value.into_header_value());
151 }
152 }
153 }
154
155 pub fn get<K: IntoHeaderName>(&self, key: K) -> Option<&HeaderValue> {
167 self.inner
168 .get(&key.into_header_name())
169 .and_then(|x| x.first())
170 }
171
172 pub fn get_last<K: IntoHeaderName>(&self, key: K) -> Option<&HeaderValue> {
184 self.inner
185 .get(&key.into_header_name())
186 .and_then(|x| x.last())
187 }
188
189 pub fn get_all<K: IntoHeaderName>(&self, key: K) -> GetAll<HeaderValue> {
204 let inner = self
205 .inner
206 .get(&key.into_header_name())
207 .map(|x| x.iter())
208 .unwrap_or([].iter());
209
210 GetAll { inner }
211 }
212
213 pub(crate) fn to_bytes(&self) -> Vec<u8> {
214 let mut buf = vec![];
215 buf.extend_from_slice(b"NATS/1.0\r\n");
216 for (k, vs) in &self.inner {
217 for v in vs.iter() {
218 buf.extend_from_slice(k.as_str().as_bytes());
219 buf.extend_from_slice(b": ");
220 buf.extend_from_slice(v.inner.as_bytes());
221 buf.extend_from_slice(b"\r\n");
222 }
223 }
224 buf.extend_from_slice(b"\r\n");
225 buf
226 }
227}
228
/// A single header value.
///
/// Values obtained through `str::parse` / [`FromStr`] never contain `\r` or `\n`. The
/// `From<&str>` and `IntoHeaderValue for &str` conversions copy their input without that
/// check.
#[derive(Clone, PartialEq, Eq, Debug, Default, Serialize, Deserialize)]
pub struct HeaderValue {
    inner: String,
}
244
245impl fmt::Display for HeaderValue {
246 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247 fmt::Display::fmt(&self.as_str(), f)
248 }
249}
250
251impl AsRef<[u8]> for HeaderValue {
252 fn as_ref(&self) -> &[u8] {
253 self.inner.as_ref()
254 }
255}
256
257impl AsRef<str> for HeaderValue {
258 fn as_ref(&self) -> &str {
259 self.as_str()
260 }
261}
262
263impl From<i16> for HeaderValue {
264 fn from(v: i16) -> Self {
265 Self {
266 inner: v.to_string(),
267 }
268 }
269}
270
271impl From<i32> for HeaderValue {
272 fn from(v: i32) -> Self {
273 Self {
274 inner: v.to_string(),
275 }
276 }
277}
278
279impl From<i64> for HeaderValue {
280 fn from(v: i64) -> Self {
281 Self {
282 inner: v.to_string(),
283 }
284 }
285}
286
287impl From<isize> for HeaderValue {
288 fn from(v: isize) -> Self {
289 Self {
290 inner: v.to_string(),
291 }
292 }
293}
294
295impl From<u16> for HeaderValue {
296 fn from(v: u16) -> Self {
297 Self {
298 inner: v.to_string(),
299 }
300 }
301}
302
303impl From<u32> for HeaderValue {
304 fn from(v: u32) -> Self {
305 Self {
306 inner: v.to_string(),
307 }
308 }
309}
310
311impl From<u64> for HeaderValue {
312 fn from(v: u64) -> Self {
313 Self {
314 inner: v.to_string(),
315 }
316 }
317}
318
319impl From<usize> for HeaderValue {
320 fn from(v: usize) -> Self {
321 Self {
322 inner: v.to_string(),
323 }
324 }
325}
326
327impl FromStr for HeaderValue {
328 type Err = ParseHeaderValueError;
329
330 fn from_str(s: &str) -> Result<Self, Self::Err> {
331 if s.contains(['\r', '\n']) {
332 return Err(ParseHeaderValueError);
333 }
334
335 Ok(HeaderValue {
336 inner: s.to_string(),
337 })
338 }
339}
340
341impl From<&str> for HeaderValue {
342 fn from(v: &str) -> Self {
343 Self {
344 inner: v.to_string(),
345 }
346 }
347}
348
349impl HeaderValue {
350 pub fn new() -> Self {
351 HeaderValue::default()
352 }
353
354 pub fn as_str(&self) -> &str {
355 self.inner.as_str()
356 }
357}
358
/// Error returned when parsing a [`HeaderValue`] from text that contains `\r` or `\n`.
#[derive(Debug, Clone)]
pub struct ParseHeaderValueError;

impl fmt::Display for ParseHeaderValueError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(
            r#"invalid character found in header value (value cannot contain '\r' or '\n')"#,
        )
    }
}

impl std::error::Error for ParseHeaderValueError {}
372
/// Conversion into a [`HeaderName`].
///
/// The [`HeaderMap`] accessors (`insert`, `append`, `get`, `get_last`, `get_all`) are generic
/// over this trait, so they accept both string slices and already-built names.
pub trait IntoHeaderName {
    /// Performs the conversion.
    fn into_header_name(self) -> HeaderName;
}
376
377impl IntoHeaderName for &str {
378 fn into_header_name(self) -> HeaderName {
379 HeaderName {
380 inner: HeaderRepr::Custom(self.into()),
381 }
382 }
383}
384
385impl IntoHeaderName for HeaderName {
386 fn into_header_name(self) -> HeaderName {
387 self
388 }
389}
390
/// Conversion into a [`HeaderValue`].
///
/// Accepted by the `value` parameter of [`HeaderMap::insert`] and [`HeaderMap::append`].
pub trait IntoHeaderValue {
    /// Performs the conversion.
    fn into_header_value(self) -> HeaderValue;
}
394
395impl IntoHeaderValue for &str {
396 fn into_header_value(self) -> HeaderValue {
397 HeaderValue {
398 inner: self.to_string(),
399 }
400 }
401}
402
403impl IntoHeaderValue for HeaderValue {
404 fn into_header_value(self) -> HeaderValue {
405 self
406 }
407}
408
/// Generates the set of well-known header names.
///
/// Each `(Variant, CONSTANT, b"Spelling");` entry (optionally preceded by doc attributes)
/// produces:
/// - a `StandardHeader::Variant` enum variant;
/// - a `pub const CONSTANT: HeaderName` wrapping that variant, carrying the entry's docs;
/// - arms in `StandardHeader::as_str` / `StandardHeader::from_bytes` that map the variant
///   to and from its byte spelling;
/// - a row in a test asserting that `HeaderName::from_str` of the spelling equals the
///   constant.
macro_rules! standard_headers {
    (
        $(
            $(#[$docs:meta])*
            ($variant:ident, $constant:ident, $bytes:literal);
        )+
    ) => {
        // Every variant shares the `Nats` prefix, which this lint would otherwise flag.
        #[allow(clippy::enum_variant_names)]
        #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
        enum StandardHeader {
            $(
                $variant,
            )+
        }

        $(
            $(#[$docs])*
            pub const $constant: HeaderName = HeaderName {
                inner: HeaderRepr::Standard(StandardHeader::$variant),
            };
        )+

        impl StandardHeader {
            #[inline]
            fn as_str(&self) -> &'static str {
                match *self {
                    $(
                        // SAFETY: `$bytes` is a byte-string literal from this macro's
                        // invocation; the entries in this file are all ASCII, hence valid
                        // UTF-8. Any new entry must keep that property.
                        StandardHeader::$variant => unsafe { std::str::from_utf8_unchecked( $bytes ) },
                    )+
                }
            }

            // `const` so that `HeaderName::from_static` (itself a `const fn`) can use it.
            // Matching is exact and case-sensitive.
            const fn from_bytes(bytes: &[u8]) -> Option<StandardHeader> {
                match bytes {
                    $(
                        $bytes => Some(StandardHeader::$variant),
                    )+
                    _ => None,
                }
            }
        }

        #[cfg(test)]
        mod standard_header_tests {
            use super::HeaderName;
            use std::str::{self, FromStr};

            const TEST_HEADERS: &'static [(&'static HeaderName, &'static [u8])] = &[
                $(
                    (&super::$constant, $bytes),
                )+
            ];

            #[test]
            fn from_str() {
                for &(header, bytes) in TEST_HEADERS {
                    let utf8 = str::from_utf8(bytes).expect("string constants isn't utf8");
                    assert_eq!(HeaderName::from_str(utf8).unwrap(), *header);
                }
            }
        }
    }
}
472
// Well-known NATS header names. Each entry becomes a `pub const` `HeaderName`.
// `HeaderName::from_static` and `HeaderName::from_str` map an exact, case-sensitive match of
// one of these spellings to the corresponding constant's representation.
standard_headers! {
    /// Standard `Nats-Stream` header name.
    (NatsStream, NATS_STREAM, b"Nats-Stream");
    /// Standard `Nats-Sequence` header name.
    (NatsSequence, NATS_SEQUENCE, b"Nats-Sequence");
    /// Standard `Nats-Time-Stamp` header name.
    (NatsTimeStamp, NATS_TIME_STAMP, b"Nats-Time-Stamp");
    /// Standard `Nats-Subject` header name.
    (NatsSubject, NATS_SUBJECT, b"Nats-Subject");
    /// Standard `Nats-Msg-Id` header name.
    (NatsMessageId, NATS_MESSAGE_ID, b"Nats-Msg-Id");
    /// Standard `Nats-Last-Stream` header name.
    (NatsLastStream, NATS_LAST_STREAM, b"Nats-Last-Stream");
    /// Standard `Nats-Last-Consumer` header name.
    (NatsLastConsumer, NATS_LAST_CONSUMER, b"Nats-Last-Consumer");
    /// Standard `Nats-Last-Sequence` header name.
    (NatsLastSequence, NATS_LAST_SEQUENCE, b"Nats-Last-Sequence");
    /// Standard `Nats-Expected-Last-Subject-Sequence` header name.
    (NatsExpectedLastSubjectSequence, NATS_EXPECTED_LAST_SUBJECT_SEQUENCE, b"Nats-Expected-Last-Subject-Sequence");
    /// Standard `Nats-Expected-Last-Msg-Id` header name.
    (NatsExpectedLastMessageId, NATS_EXPECTED_LAST_MESSAGE_ID, b"Nats-Expected-Last-Msg-Id");
    /// Standard `Nats-Expected-Last-Sequence` header name.
    (NatsExpectedLastSequence, NATS_EXPECTED_LAST_SEQUENCE, b"Nats-Expected-Last-Sequence");
    /// Standard `Nats-Expected-Stream` header name.
    (NatsExpectedStream, NATS_EXPECTED_STREAM, b"Nats-Expected-Stream");
}
500
501#[derive(Debug, Hash, PartialEq, Eq, Clone)]
502struct CustomHeader {
503 bytes: Bytes,
504}
505
506impl CustomHeader {
507 #[inline]
508 pub(crate) const fn from_static(value: &'static str) -> CustomHeader {
509 CustomHeader {
510 bytes: Bytes::from_static(value.as_bytes()),
511 }
512 }
513
514 #[inline]
515 pub(crate) fn as_str(&self) -> &str {
516 unsafe { std::str::from_utf8_unchecked(self.bytes.as_ref()) }
517 }
518}
519
520impl From<String> for CustomHeader {
521 #[inline]
522 fn from(value: String) -> CustomHeader {
523 CustomHeader {
524 bytes: Bytes::from(value),
525 }
526 }
527}
528
529impl<'a> From<&'a str> for CustomHeader {
530 #[inline]
531 fn from(value: &'a str) -> CustomHeader {
532 CustomHeader {
533 bytes: Bytes::copy_from_slice(value.as_bytes()),
534 }
535 }
536}
537
/// Internal representation of a [`HeaderName`].
///
/// A name is either one of the macro-generated `StandardHeader` variants or an arbitrary
/// custom spelling. The derived `PartialEq`/`Hash` take the variant into account, so a
/// `Standard` name and a `Custom` name with the same text are *not* equal.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
enum HeaderRepr {
    Standard(StandardHeader),
    Custom(CustomHeader),
}
543
/// The name of a NATS message header.
///
/// It can be obtained in several ways:
/// - from one of the generated constants (e.g. [`NATS_STREAM`]);
/// - via [`HeaderName::from_static`];
/// - via `str::parse`, which validates the input;
/// - implicitly through [`IntoHeaderName`] when a `&str` is passed to [`HeaderMap`] methods.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct HeaderName {
    inner: HeaderRepr,
}
557
558impl HeaderName {
559 #[inline]
561 pub const fn from_static(value: &'static str) -> HeaderName {
562 if let Some(standard) = StandardHeader::from_bytes(value.as_bytes()) {
563 return HeaderName {
564 inner: HeaderRepr::Standard(standard),
565 };
566 }
567
568 HeaderName {
569 inner: HeaderRepr::Custom(CustomHeader::from_static(value)),
570 }
571 }
572
573 #[inline]
575 fn as_str(&self) -> &str {
576 match self.inner {
577 HeaderRepr::Standard(v) => v.as_str(),
578 HeaderRepr::Custom(ref v) => v.as_str(),
579 }
580 }
581}
582
583impl FromStr for HeaderName {
584 type Err = ParseHeaderNameError;
585
586 fn from_str(s: &str) -> Result<Self, Self::Err> {
587 if s.contains(|c: char| c == ':' || (c as u8) < 33 || (c as u8) > 126) {
588 return Err(ParseHeaderNameError);
589 }
590
591 match StandardHeader::from_bytes(s.as_ref()) {
592 Some(v) => Ok(HeaderName {
593 inner: HeaderRepr::Standard(v),
594 }),
595 None => Ok(HeaderName {
596 inner: HeaderRepr::Custom(CustomHeader::from(s)),
597 }),
598 }
599 }
600}
601
602impl fmt::Display for HeaderName {
603 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
604 fmt::Display::fmt(&self.as_str(), f)
605 }
606}
607
608impl AsRef<[u8]> for HeaderName {
609 fn as_ref(&self) -> &[u8] {
610 self.as_str().as_bytes()
611 }
612}
613
614impl AsRef<str> for HeaderName {
615 fn as_ref(&self) -> &str {
616 self.as_str()
617 }
618}
619
620impl Serialize for HeaderName {
621 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
622 where
623 S: serde::Serializer,
624 {
625 serializer.serialize_str(self.as_str())
626 }
627}
628
629impl<'de> Deserialize<'de> for HeaderName {
630 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
631 where
632 D: serde::Deserializer<'de>,
633 {
634 String::deserialize(deserializer)?
635 .parse()
636 .map_err(serde::de::Error::custom)
637 }
638}
639
/// Error returned when a string is not a valid [`HeaderName`].
#[derive(Debug, Clone)]
pub struct ParseHeaderNameError;

impl std::fmt::Display for ParseHeaderNameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The parser accepts any printable ASCII except ':' (including '$' and '_'), so the
        // message states that rule rather than claiming only alphanumerics and '-' are valid.
        f.write_str(
            "invalid header name (name must consist of printable ASCII characters other than ':')",
        )
    }
}

impl std::error::Error for ParseHeaderNameError {}
650
// Unit tests for `HeaderMap`, `HeaderName` and `HeaderValue`.
#[cfg(test)]
mod tests {
    use super::{HeaderMap, HeaderName, HeaderValue};
    use std::str::{from_utf8, FromStr};

    #[test]
    fn try_from() {
        let mut headers = HeaderMap::new();
        headers.insert("name", "something".parse::<HeaderValue>().unwrap());
        headers.insert("name", "something2");

        // `insert` replaces, so only the second value may remain.
        assert_eq!(headers.get("name").unwrap().as_str(), "something2");
        assert_eq!(headers.get_all("name").count(), 1);
    }

    #[test]
    fn append() {
        let mut headers = HeaderMap::new();
        headers.append("Key", "value");
        headers.append("Key", "second_value");

        let mut result = headers.get_all("Key");

        assert_eq!(
            result.next().unwrap(),
            &HeaderValue::from_str("value").unwrap()
        );

        assert_eq!(
            result.next().unwrap(),
            &HeaderValue::from_str("second_value").unwrap()
        );

        assert_eq!(result.next(), None);
    }

    #[test]
    fn get_string() {
        let mut headers = HeaderMap::new();
        headers.append("Key", "value");
        headers.append("Key", "other");

        assert_eq!(headers.get("Key").unwrap().to_string(), "value");

        let key: String = headers.get("Key").unwrap().as_str().into();
        assert_eq!(key, "value".to_string());

        let key: String = headers.get("Key").unwrap().as_str().to_owned();
        assert_eq!(key, "value".to_string());

        assert_eq!(headers.get("Key").unwrap().as_str(), "value");

        let key: String = headers.get_last("Key").unwrap().as_str().into();
        assert_eq!(key, "other".to_string());
    }

    #[test]
    fn insert() {
        let mut headers = HeaderMap::new();
        headers.insert("Key", "Value");

        let mut result = headers.get_all("Key");

        assert_eq!(
            result.next().unwrap(),
            &HeaderValue::from_str("Value").unwrap()
        );
        assert_eq!(result.next(), None);
    }

    #[test]
    fn serialize() {
        let mut headers = HeaderMap::new();
        headers.append("Key", "value");
        headers.append("Key", "second_value");
        headers.insert("Second", "SecondValue");

        let bytes = headers.to_bytes();
        let text = from_utf8(&bytes).expect("serialized headers are UTF-8");

        // The order across names follows the HashMap and is unspecified, so compare header
        // lines order-insensitively; the values of a single name must keep their order.
        let body = text
            .strip_prefix("NATS/1.0\r\n")
            .and_then(|rest| rest.strip_suffix("\r\n"))
            .expect("status line and terminating blank line");
        let lines: Vec<&str> = body.split_terminator("\r\n").collect();
        assert_eq!(lines.len(), 3);
        assert!(lines.contains(&"Second: SecondValue"));

        let key_lines: Vec<&str> = lines
            .iter()
            .copied()
            .filter(|line| line.starts_with("Key: "))
            .collect();
        assert_eq!(key_lines, ["Key: value", "Key: second_value"]);
    }

    #[test]
    fn is_empty() {
        let mut headers = HeaderMap::new();
        assert!(headers.is_empty());

        headers.append("Key", "value");
        headers.append("Key", "second_value");
        headers.insert("Second", "SecondValue");
        assert!(!headers.is_empty());
    }

    #[test]
    fn parse_value() {
        assert!("Foo\r".parse::<HeaderValue>().is_err());
        assert!("Foo\n".parse::<HeaderValue>().is_err());
        assert!("Foo\r\n".parse::<HeaderValue>().is_err());
    }

    #[test]
    fn valid_header_name() {
        let valid_header_name = "X-Custom-Header";
        let parsed_header = HeaderName::from_str(valid_header_name);

        assert!(
            parsed_header.is_ok(),
            "Expected Ok(HeaderName), but got an error: {:?}",
            parsed_header.err()
        );
    }

    #[test]
    fn dollar_header_name() {
        let valid_header_name = "$X_Custom_Header";
        let parsed_header = HeaderName::from_str(valid_header_name);

        assert!(
            parsed_header.is_ok(),
            "Expected Ok(HeaderName), but got an error: {:?}",
            parsed_header.err()
        );
    }

    #[test]
    fn invalid_header_name_with_space() {
        let invalid_header_name = "X Custom Header";
        let parsed_header = HeaderName::from_str(invalid_header_name);

        assert!(
            parsed_header.is_err(),
            "Expected Err(InvalidHeaderNameError), but got Ok: {:?}",
            parsed_header.ok()
        );
    }

    #[test]
    fn invalid_header_name_with_special_chars() {
        let invalid_header_name = "X-Header:";
        let parsed_header = HeaderName::from_str(invalid_header_name);

        assert!(
            parsed_header.is_err(),
            "Expected Err(InvalidHeaderNameError), but got Ok: {:?}",
            parsed_header.ok()
        );
    }

    #[test]
    fn from_static_eq() {
        let a = HeaderName::from_static("NATS-Stream");
        let b = HeaderName::from_static("NATS-Stream");

        assert_eq!(a, b);
    }

    #[test]
    fn header_name_serde() {
        let raw = "Nats-Stream";
        let raw_json = "\"Nats-Stream\"";
        let header = HeaderName::from_static(raw);

        assert_eq!(serde_json::to_string(&header).unwrap(), raw_json);
        assert_eq!(
            serde_json::from_str::<HeaderName>(raw_json).unwrap(),
            header
        );
    }
}