1#![allow(clippy::mutable_key_type)]
20
21use std::{collections::HashMap, fmt, slice::Iter, str::FromStr};
24
25use bytes::Bytes;
26use serde::{Deserialize, Serialize};
27
28#[derive(Clone, PartialEq, Eq, Debug, Default, Deserialize, Serialize)]
48pub struct HeaderMap {
49 inner: HashMap<HeaderName, Vec<HeaderValue>>,
50}
51
52impl FromIterator<(HeaderName, HeaderValue)> for HeaderMap {
53 fn from_iter<T: IntoIterator<Item = (HeaderName, HeaderValue)>>(iter: T) -> Self {
54 let mut header_map = HeaderMap::new();
55 for (key, value) in iter {
56 header_map.insert(key, value);
57 }
58 header_map
59 }
60}
61
62impl HeaderMap {
63 pub fn iter(&self) -> std::collections::hash_map::Iter<'_, HeaderName, Vec<HeaderValue>> {
64 self.inner.iter()
65 }
66}
67
/// Iterator over every value stored under one header name.
///
/// Returned by `HeaderMap::get_all`. It yields the values in the order they were
/// appended, and yields nothing when the name is absent.
pub struct GetAll<'a, T> {
    inner: Iter<'a, T>,
}

impl<'a, T> Iterator for GetAll<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next()
    }

    // Forward the exact bounds of the underlying slice iterator so that consumers
    // such as `collect` can preallocate. The default impl reports `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

// The wrapped slice iterator always knows its exact remaining length.
impl<T> ExactSizeIterator for GetAll<'_, T> {}
79
80impl HeaderMap {
81 pub fn new() -> Self {
92 HeaderMap::default()
93 }
94
95 pub fn is_empty(&self) -> bool {
111 self.inner.is_empty()
112 }
113
114 pub fn len(&self) -> usize {
115 self.inner.len()
116 }
117}
118
119impl HeaderMap {
120 pub fn insert<K: IntoHeaderName, V: IntoHeaderValue>(&mut self, name: K, value: V) {
131 self.inner
132 .insert(name.into_header_name(), vec![value.into_header_value()]);
133 }
134
135 pub fn append<K: IntoHeaderName, V: IntoHeaderValue>(&mut self, name: K, value: V) {
147 let key = name.into_header_name();
148 let v = self.inner.get_mut(&key);
149 match v {
150 Some(v) => {
151 v.push(value.into_header_value());
152 }
153 None => {
154 self.insert(key, value.into_header_value());
155 }
156 }
157 }
158
159 pub fn get<K: IntoHeaderName>(&self, key: K) -> Option<&HeaderValue> {
171 self.inner
172 .get(&key.into_header_name())
173 .and_then(|x| x.first())
174 }
175
176 pub fn get_last<K: IntoHeaderName>(&self, key: K) -> Option<&HeaderValue> {
188 self.inner
189 .get(&key.into_header_name())
190 .and_then(|x| x.last())
191 }
192
193 pub fn get_all<K: IntoHeaderName>(&self, key: K) -> GetAll<HeaderValue> {
208 let inner = self
209 .inner
210 .get(&key.into_header_name())
211 .map(|x| x.iter())
212 .unwrap_or([].iter());
213
214 GetAll { inner }
215 }
216
217 pub(crate) fn to_bytes(&self) -> Vec<u8> {
218 let mut buf = vec![];
219 buf.extend_from_slice(b"NATS/1.0\r\n");
220 for (k, vs) in &self.inner {
221 for v in vs.iter() {
222 buf.extend_from_slice(k.as_str().as_bytes());
223 buf.extend_from_slice(b": ");
224 buf.extend_from_slice(v.inner.as_bytes());
225 buf.extend_from_slice(b"\r\n");
226 }
227 }
228 buf.extend_from_slice(b"\r\n");
229 buf
230 }
231}
232
233#[derive(Clone, PartialEq, Eq, Debug, Default, Serialize, Deserialize)]
245pub struct HeaderValue {
246 inner: String,
247}
248
249impl fmt::Display for HeaderValue {
250 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251 fmt::Display::fmt(&self.as_str(), f)
252 }
253}
254
255impl AsRef<[u8]> for HeaderValue {
256 fn as_ref(&self) -> &[u8] {
257 self.inner.as_ref()
258 }
259}
260
261impl AsRef<str> for HeaderValue {
262 fn as_ref(&self) -> &str {
263 self.as_str()
264 }
265}
266
267impl From<i16> for HeaderValue {
268 fn from(v: i16) -> Self {
269 Self {
270 inner: v.to_string(),
271 }
272 }
273}
274
275impl From<i32> for HeaderValue {
276 fn from(v: i32) -> Self {
277 Self {
278 inner: v.to_string(),
279 }
280 }
281}
282
283impl From<i64> for HeaderValue {
284 fn from(v: i64) -> Self {
285 Self {
286 inner: v.to_string(),
287 }
288 }
289}
290
291impl From<isize> for HeaderValue {
292 fn from(v: isize) -> Self {
293 Self {
294 inner: v.to_string(),
295 }
296 }
297}
298
299impl From<u16> for HeaderValue {
300 fn from(v: u16) -> Self {
301 Self {
302 inner: v.to_string(),
303 }
304 }
305}
306
307impl From<u32> for HeaderValue {
308 fn from(v: u32) -> Self {
309 Self {
310 inner: v.to_string(),
311 }
312 }
313}
314
315impl From<u64> for HeaderValue {
316 fn from(v: u64) -> Self {
317 Self {
318 inner: v.to_string(),
319 }
320 }
321}
322
323impl From<usize> for HeaderValue {
324 fn from(v: usize) -> Self {
325 Self {
326 inner: v.to_string(),
327 }
328 }
329}
330
331impl FromStr for HeaderValue {
332 type Err = ParseHeaderValueError;
333
334 fn from_str(s: &str) -> Result<Self, Self::Err> {
335 if s.contains(['\r', '\n']) {
336 return Err(ParseHeaderValueError);
337 }
338
339 Ok(HeaderValue {
340 inner: s.to_string(),
341 })
342 }
343}
344
345impl From<&str> for HeaderValue {
346 fn from(v: &str) -> Self {
347 Self {
348 inner: v.to_string(),
349 }
350 }
351}
352
353impl From<String> for HeaderValue {
354 fn from(inner: String) -> Self {
355 Self { inner }
356 }
357}
358
359impl HeaderValue {
360 pub fn new() -> Self {
361 HeaderValue::default()
362 }
363
364 pub fn as_str(&self) -> &str {
365 self.inner.as_str()
366 }
367}
368
/// Error returned by `HeaderValue::from_str` when the input contains `'\r'` or `'\n'`.
#[derive(Clone, Debug)]
pub struct ParseHeaderValueError;

impl fmt::Display for ParseHeaderValueError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(r#"invalid character found in header value (value cannot contain '\r' or '\n')"#)
    }
}

impl std::error::Error for ParseHeaderValueError {}
382
/// Conversion into a [`HeaderName`].
///
/// This is the key type accepted by the generic `HeaderMap` methods (`insert`,
/// `append`, `get`, `get_last`, `get_all`).
pub trait IntoHeaderName {
    /// Converts `self` into a header name.
    ///
    /// NOTE(review): the `&str`/`String` implementations perform no validation,
    /// unlike `HeaderName::from_str`.
    fn into_header_name(self) -> HeaderName;
}
386
387impl IntoHeaderName for &str {
388 fn into_header_name(self) -> HeaderName {
389 HeaderName {
390 inner: HeaderRepr::Custom(self.into()),
391 }
392 }
393}
394
395impl IntoHeaderName for String {
396 fn into_header_name(self) -> HeaderName {
397 HeaderName {
398 inner: HeaderRepr::Custom(self.into()),
399 }
400 }
401}
402
403impl IntoHeaderName for HeaderName {
404 fn into_header_name(self) -> HeaderName {
405 self
406 }
407}
408
/// Conversion into a [`HeaderValue`].
///
/// This is the value type accepted by `HeaderMap::insert` and `HeaderMap::append`.
pub trait IntoHeaderValue {
    /// Converts `self` into a header value.
    ///
    /// NOTE(review): the `&str`/`String` implementations do not reject `'\r'` or
    /// `'\n'`, unlike `HeaderValue::from_str`.
    fn into_header_value(self) -> HeaderValue;
}
412
413impl IntoHeaderValue for &str {
414 fn into_header_value(self) -> HeaderValue {
415 HeaderValue {
416 inner: self.to_string(),
417 }
418 }
419}
420
421impl IntoHeaderValue for String {
422 fn into_header_value(self) -> HeaderValue {
423 HeaderValue { inner: self }
424 }
425}
426
427impl IntoHeaderValue for HeaderValue {
428 fn into_header_value(self) -> HeaderValue {
429 self
430 }
431}
432
// Generates the set of well-known NATS header names.
//
// Each entry `(Variant, CONSTANT, b"Wire-Name");` (optionally preceded by doc
// comments) produces:
// - a `StandardHeader::Variant` enum variant;
// - a `pub const CONSTANT: HeaderName` wrapping that variant, carrying the entry's
//   doc attributes;
// - arms in `StandardHeader::as_str` / `StandardHeader::from_bytes` mapping the
//   variant to and from its exact wire bytes;
// - a unit test checking that `HeaderName::from_str` of each wire name equals the
//   corresponding constant.
macro_rules! standard_headers {
    (
        $(
            $(#[$docs:meta])*
            ($variant:ident, $constant:ident, $bytes:literal);
        )+
    ) => {
        // One variant per well-known header. Module-private; callers use the
        // public constants below.
        #[allow(clippy::enum_variant_names)]
        #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
        enum StandardHeader {
            $(
                $variant,
            )+
        }

        // A public `HeaderName` constant for each entry.
        $(
            $(#[$docs])*
            pub const $constant: HeaderName = HeaderName {
                inner: HeaderRepr::Standard(StandardHeader::$variant),
            };
        )+

        impl StandardHeader {
            // Returns the header's wire name.
            #[inline]
            fn as_str(&self) -> &'static str {
                match *self {
                    $(
                        // SAFETY: requires every `$bytes` argument to be valid UTF-8.
                        // All entries in the `standard_headers!` invocation in this
                        // file are plain ASCII text. Byte-string literals *can*
                        // contain `\x80`-`\xFF` escapes, so a non-UTF-8 entry would
                        // make this unsound. The generated `from_str` test below
                        // (`str::from_utf8(...).expect(...)`) would catch one.
                        StandardHeader::$variant => unsafe { std::str::from_utf8_unchecked( $bytes ) },
                    )+
                }
            }

            // Maps exact wire bytes to a variant. It is `const` so that
            // `HeaderName::from_static` can use it. Matching is byte-for-byte and
            // therefore case-sensitive: `NATS-Stream` is not `Nats-Stream`.
            const fn from_bytes(bytes: &[u8]) -> Option<StandardHeader> {
                match bytes {
                    $(
                        $bytes => Some(StandardHeader::$variant),
                    )+
                    _ => None,
                }
            }
        }

        // Checks that parsing each wire name yields its public constant.
        #[cfg(test)]
        mod standard_header_tests {
            use super::HeaderName;
            use std::str::{self, FromStr};

            const TEST_HEADERS: &'static [(&'static HeaderName, &'static [u8])] = &[
                $(
                    (&super::$constant, $bytes),
                )+
            ];

            #[test]
            fn from_str() {
                for &(header, bytes) in TEST_HEADERS {
                    let utf8 = str::from_utf8(bytes).expect("string constants isn't utf8");
                    assert_eq!(HeaderName::from_str(utf8).unwrap(), *header);
                }
            }
        }
    }
}
496
// The well-known NATS header names. Each entry is
// `(StandardHeader variant, public constant, exact wire bytes)`. The wire bytes
// are matched case-sensitively by `StandardHeader::from_bytes`.
standard_headers! {
    /// The `Nats-Stream` header.
    (NatsStream, NATS_STREAM, b"Nats-Stream");
    /// The `Nats-Sequence` header.
    (NatsSequence, NATS_SEQUENCE, b"Nats-Sequence");
    /// The `Nats-Time-Stamp` header.
    (NatsTimeStamp, NATS_TIME_STAMP, b"Nats-Time-Stamp");
    /// The `Nats-Subject` header.
    (NatsSubject, NATS_SUBJECT, b"Nats-Subject");
    /// The `Nats-Msg-Id` header (note the abbreviated wire name).
    (NatsMessageId, NATS_MESSAGE_ID, b"Nats-Msg-Id");
    /// The `Nats-Last-Stream` header.
    (NatsLastStream, NATS_LAST_STREAM, b"Nats-Last-Stream");
    /// The `Nats-Last-Consumer` header.
    (NatsLastConsumer, NATS_LAST_CONSUMER, b"Nats-Last-Consumer");
    /// The `Nats-Last-Sequence` header.
    (NatsLastSequence, NATS_LAST_SEQUENCE, b"Nats-Last-Sequence");
    /// The `Nats-Expected-Last-Subject-Sequence` header.
    (NatsExpectedLastSubjectSequence, NATS_EXPECTED_LAST_SUBJECT_SEQUENCE, b"Nats-Expected-Last-Subject-Sequence");
    /// The `Nats-Expected-Last-Msg-Id` header (note the abbreviated wire name).
    (NatsExpectedLastMessageId, NATS_EXPECTED_LAST_MESSAGE_ID, b"Nats-Expected-Last-Msg-Id");
    /// The `Nats-Expected-Last-Sequence` header.
    (NatsExpectedLastSequence, NATS_EXPECTED_LAST_SEQUENCE, b"Nats-Expected-Last-Sequence");
    /// The `Nats-Expected-Stream` header.
    (NatsExpectedStream, NATS_EXPECTED_STREAM, b"Nats-Expected-Stream");
    /// The `Nats-TTL` header.
    (NatsMessageTtl, NATS_MESSAGE_TTL, b"Nats-TTL");
    /// The `Nats-Marker-Reason` header.
    (NatsMarkerReason, NATS_MARKER_REASON, b"Nats-Marker-Reason");
}
528
529#[derive(Debug, Hash, PartialEq, Eq, Clone)]
530struct CustomHeader {
531 bytes: Bytes,
532}
533
impl CustomHeader {
    /// Wraps a `'static` string via `Bytes::from_static`.
    ///
    /// It is `const` so that `HeaderName::from_static` can call it.
    #[inline]
    pub(crate) const fn from_static(value: &'static str) -> CustomHeader {
        CustomHeader {
            bytes: Bytes::from_static(value.as_bytes()),
        }
    }

    /// Borrows the header name as a string slice.
    #[inline]
    pub(crate) fn as_str(&self) -> &str {
        // SAFETY: relies on `bytes` always holding valid UTF-8. Every constructor
        // visible in this module (`from_static`, `From<String>`, `From<&str>`)
        // wraps or copies a `str`/`String`.
        // NOTE(review): confirm no other code builds a `CustomHeader` from
        // arbitrary bytes.
        unsafe { std::str::from_utf8_unchecked(self.bytes.as_ref()) }
    }
}
547
548impl From<String> for CustomHeader {
549 #[inline]
550 fn from(value: String) -> CustomHeader {
551 CustomHeader {
552 bytes: Bytes::from(value),
553 }
554 }
555}
556
557impl<'a> From<&'a str> for CustomHeader {
558 #[inline]
559 fn from(value: &'a str) -> CustomHeader {
560 CustomHeader {
561 bytes: Bytes::copy_from_slice(value.as_bytes()),
562 }
563 }
564}
565
566#[derive(Debug, Hash, PartialEq, Eq, Clone)]
567enum HeaderRepr {
568 Standard(StandardHeader),
569 Custom(CustomHeader),
570}
571
572#[derive(Clone, PartialEq, Eq, Hash, Debug)]
582pub struct HeaderName {
583 inner: HeaderRepr,
584}
585
586impl HeaderName {
587 #[inline]
589 pub const fn from_static(value: &'static str) -> HeaderName {
590 if let Some(standard) = StandardHeader::from_bytes(value.as_bytes()) {
591 return HeaderName {
592 inner: HeaderRepr::Standard(standard),
593 };
594 }
595
596 HeaderName {
597 inner: HeaderRepr::Custom(CustomHeader::from_static(value)),
598 }
599 }
600
601 #[inline]
603 fn as_str(&self) -> &str {
604 match self.inner {
605 HeaderRepr::Standard(v) => v.as_str(),
606 HeaderRepr::Custom(ref v) => v.as_str(),
607 }
608 }
609}
610
611impl FromStr for HeaderName {
612 type Err = ParseHeaderNameError;
613
614 fn from_str(s: &str) -> Result<Self, Self::Err> {
615 if s.contains(|c: char| c == ':' || (c as u8) < 33 || (c as u8) > 126) {
616 return Err(ParseHeaderNameError);
617 }
618
619 match StandardHeader::from_bytes(s.as_ref()) {
620 Some(v) => Ok(HeaderName {
621 inner: HeaderRepr::Standard(v),
622 }),
623 None => Ok(HeaderName {
624 inner: HeaderRepr::Custom(CustomHeader::from(s)),
625 }),
626 }
627 }
628}
629
630impl fmt::Display for HeaderName {
631 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
632 fmt::Display::fmt(&self.as_str(), f)
633 }
634}
635
636impl AsRef<[u8]> for HeaderName {
637 fn as_ref(&self) -> &[u8] {
638 self.as_str().as_bytes()
639 }
640}
641
642impl AsRef<str> for HeaderName {
643 fn as_ref(&self) -> &str {
644 self.as_str()
645 }
646}
647
648impl Serialize for HeaderName {
649 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
650 where
651 S: serde::Serializer,
652 {
653 serializer.serialize_str(self.as_str())
654 }
655}
656
657impl<'de> Deserialize<'de> for HeaderName {
658 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
659 where
660 D: serde::Deserializer<'de>,
661 {
662 String::deserialize(deserializer)?
663 .parse()
664 .map_err(serde::de::Error::custom)
665 }
666}
667
/// Error returned when a string cannot be parsed as a [`HeaderName`].
///
/// It is produced by `HeaderName::from_str`, and therefore also by `HeaderName`'s
/// `Deserialize` impl.
#[derive(Debug, Clone)]
pub struct ParseHeaderNameError;

impl std::fmt::Display for ParseHeaderNameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The previous text claimed that only ASCII alphanumerics and '-' are
        // allowed. The parser's rule is printable ASCII except ':', so names such
        // as `$X_Custom_Header` are valid.
        f.write_str(
            "invalid header name (name must consist of printable ASCII characters other than ':')",
        )
    }
}

impl std::error::Error for ParseHeaderNameError {}
678
#[cfg(test)]
mod tests {
    use super::{HeaderMap, HeaderName, HeaderValue, IntoHeaderName, IntoHeaderValue};
    use std::str::{from_utf8, FromStr};

    #[test]
    fn try_from() {
        let mut headers = HeaderMap::new();
        headers.insert("name", "something".parse::<HeaderValue>().unwrap());
        headers.insert("name", "something2");

        // `insert` replaces the previous value instead of appending.
        let mut values = headers.get_all("name");
        assert_eq!(values.next().unwrap().as_str(), "something2");
        assert_eq!(values.next(), None);
    }

    #[test]
    fn append() {
        let mut headers = HeaderMap::new();
        headers.append("Key", "value");
        headers.append("Key", "second_value");

        let mut result = headers.get_all("Key");

        assert_eq!(
            result.next().unwrap(),
            &HeaderValue::from_str("value").unwrap()
        );

        assert_eq!(
            result.next().unwrap(),
            &HeaderValue::from_str("second_value").unwrap()
        );

        assert_eq!(result.next(), None);
    }

    #[test]
    fn get_string() {
        let mut headers = HeaderMap::new();
        headers.append("Key", "value");
        headers.append("Key", "other");

        assert_eq!(headers.get("Key").unwrap().to_string(), "value");

        let key: String = headers.get("Key").unwrap().as_str().into();
        assert_eq!(key, "value".to_string());

        let key: String = headers.get("Key").unwrap().as_str().to_owned();
        assert_eq!(key, "value".to_string());

        assert_eq!(headers.get("Key").unwrap().as_str(), "value");

        let key: String = headers.get_last("Key").unwrap().as_str().into();
        assert_eq!(key, "other".to_string());
    }

    #[test]
    fn insert() {
        let mut headers = HeaderMap::new();
        headers.insert("Key", "Value");

        let mut result = headers.get_all("Key");

        assert_eq!(
            result.next().unwrap(),
            &HeaderValue::from_str("Value").unwrap()
        );
        assert_eq!(result.next(), None);
    }

    #[test]
    fn serialize() {
        let mut headers = HeaderMap::new();
        headers.append("Key", "value");
        headers.append("Key", "second_value");
        headers.insert("Second", "SecondValue");

        let bytes = headers.to_bytes();
        let text = from_utf8(&bytes).expect("serialized headers must be UTF-8");

        assert!(text.starts_with("NATS/1.0\r\n"));
        assert!(text.ends_with("\r\n\r\n"));
        // Values sharing a name are written consecutively, in append order. The
        // order of distinct names follows the HashMap and is not asserted.
        assert!(text.contains("Key: value\r\nKey: second_value\r\n"));
        assert!(text.contains("Second: SecondValue\r\n"));
        // Preamble + three header lines + terminating blank line.
        assert_eq!(text.matches("\r\n").count(), 5);
        assert_eq!(bytes.len(), 64);
    }

    #[test]
    fn is_empty() {
        let mut headers = HeaderMap::new();
        assert!(headers.is_empty());

        headers.append("Key", "value");
        headers.append("Key", "second_value");
        headers.insert("Second", "SecondValue");
        assert!(!headers.is_empty());
    }

    #[test]
    fn parse_value() {
        assert!("Foo\r".parse::<HeaderValue>().is_err());
        assert!("Foo\n".parse::<HeaderValue>().is_err());
        assert!("Foo\r\n".parse::<HeaderValue>().is_err());
    }

    #[test]
    fn valid_header_name() {
        let valid_header_name = "X-Custom-Header";
        let parsed_header = HeaderName::from_str(valid_header_name);

        assert!(
            parsed_header.is_ok(),
            "Expected Ok(HeaderName), but got an error: {:?}",
            parsed_header.err()
        );
    }

    #[test]
    fn dollar_header_name() {
        let valid_header_name = "$X_Custom_Header";
        let parsed_header = HeaderName::from_str(valid_header_name);

        assert!(
            parsed_header.is_ok(),
            "Expected Ok(HeaderName), but got an error: {:?}",
            parsed_header.err()
        );
    }

    #[test]
    fn invalid_header_name_with_space() {
        let invalid_header_name = "X Custom Header";
        let parsed_header = HeaderName::from_str(invalid_header_name);

        assert!(
            parsed_header.is_err(),
            "Expected Err(InvalidHeaderNameError), but got Ok: {:?}",
            parsed_header.ok()
        );
    }

    #[test]
    fn invalid_header_name_with_special_chars() {
        let invalid_header_name = "X-Header:";
        let parsed_header = HeaderName::from_str(invalid_header_name);

        assert!(
            parsed_header.is_err(),
            "Expected Err(InvalidHeaderNameError), but got Ok: {:?}",
            parsed_header.ok()
        );
    }

    #[test]
    fn from_static_eq() {
        let a = HeaderName::from_static("NATS-Stream");
        let b = HeaderName::from_static("NATS-Stream");

        assert_eq!(a, b);
    }

    #[test]
    fn header_name_serde() {
        let raw = "Nats-Stream";
        let raw_json = "\"Nats-Stream\"";
        let header = HeaderName::from_static(raw);

        assert_eq!(serde_json::to_string(&header).unwrap(), raw_json);
        assert_eq!(
            serde_json::from_str::<HeaderName>(raw_json).unwrap(),
            header
        );
    }

    #[test]
    fn header_name_from_string() {
        let string = "NATS-Stream".to_string();
        let name = string.into_header_name();

        assert_eq!("NATS-Stream", name.as_str());
    }

    #[test]
    fn header_value_from_string_with_trait() {
        let string = "some value".to_string();

        let value = string.into_header_value();

        assert_eq!("some value", value.as_str());
    }

    #[test]
    fn header_value_from_string() {
        let string = "some value".to_string();

        let value: HeaderValue = string.into();

        assert_eq!("some value", value.as_str());
    }
}
871}