Skip to main content

async_nats/
header.rs

1// Copyright 2020-2023 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14// NOTE(thomastaylor312): This clippy lint is coming from serialize and deserialize and is likely a
15// false positive due to the bytes crate, see
16// https://rust-lang.github.io/rust-clippy/master/index.html#/mutable_key_type for more details.
17// Sorry to make this global to this module, rather than on the `HeaderMap` struct, but because it
18// is coming from the derive, it didn't work to set it on the struct.
19#![allow(clippy::mutable_key_type)]
20
21//! NATS [Message][crate::Message] headers, modeled loosely after the `http::header` crate.
22
23use std::{collections::HashMap, fmt, slice::Iter, str::FromStr};
24
25use bytes::Bytes;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27
28/// A struct for handling NATS headers.
29/// Has a similar API to `http::header`, but properly serializes and deserializes
30/// according to NATS requirements.
31///
32/// # Examples
33///
34/// ```
35/// # #[tokio::main]
36/// # async fn main() -> Result<(), async_nats::Error> {
37/// let client = async_nats::connect("demo.nats.io").await?;
38/// let mut headers = async_nats::HeaderMap::new();
39/// headers.insert("Key", "Value");
40/// client
41///     .publish_with_headers("subject", headers, "payload".into())
42///     .await?;
43/// # Ok(())
44/// # }
45/// ```
46
47#[derive(Clone, PartialEq, Eq, Debug, Default)]
48pub struct HeaderMap {
49    inner: HashMap<HeaderName, Vec<HeaderValue>>,
50}
51
52/// Helper enum for backward-compatible deserialization using serde's untagged feature
53/// This is required because of the bug #1470 where the client incorrectly serialized
54/// headers with an "inner" wrapper.
55#[derive(Deserialize)]
56#[serde(untagged)]
57enum HeaderMapHelper {
58    // Legacy format with "inner" wrapper
59    Legacy {
60        inner: HashMap<HeaderName, Vec<HeaderValue>>,
61    },
62    // Proper format - direct HashMap
63    Current(HashMap<HeaderName, Vec<HeaderValue>>),
64}
65
66impl<'de> Deserialize<'de> for HeaderMap {
67    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
68    where
69        D: Deserializer<'de>,
70    {
71        // Use the untagged enum to automatically try both formats
72        let helper = HeaderMapHelper::deserialize(deserializer)?;
73
74        Ok(match helper {
75            HeaderMapHelper::Legacy { inner } => HeaderMap { inner },
76            HeaderMapHelper::Current(inner) => HeaderMap { inner },
77        })
78    }
79}
80
81impl Serialize for HeaderMap {
82    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
83    where
84        S: Serializer,
85    {
86        // Serialize as the new format (direct HashMap without "inner" wrapper)
87        self.inner.serialize(serializer)
88    }
89}
90
91impl FromIterator<(HeaderName, HeaderValue)> for HeaderMap {
92    fn from_iter<T: IntoIterator<Item = (HeaderName, HeaderValue)>>(iter: T) -> Self {
93        let mut header_map = HeaderMap::new();
94        for (key, value) in iter {
95            header_map.insert(key, value);
96        }
97        header_map
98    }
99}
100
101impl HeaderMap {
102    pub fn iter(&self) -> std::collections::hash_map::Iter<'_, HeaderName, Vec<HeaderValue>> {
103        self.inner.iter()
104    }
105}
106
/// Iterator over all values stored under one header name, in the order they
/// were added.
///
/// Returned by [`HeaderMap::get_all`].
#[derive(Debug)]
pub struct GetAll<'a, T> {
    inner: Iter<'a, T>,
}

impl<'a, T> Iterator for GetAll<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next()
    }

    // Forward the slice iterator's exact bounds so `collect` and similar
    // adapters can size their allocations up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

// Sound because `size_hint` forwards the exact length of the underlying slice iterator.
impl<T> ExactSizeIterator for GetAll<'_, T> {}
118
119impl HeaderMap {
120    /// Create an empty `HeaderMap`.
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// # use async_nats::HeaderMap;
126    /// let map = HeaderMap::new();
127    ///
128    /// assert!(map.is_empty());
129    /// ```
130    pub fn new() -> Self {
131        HeaderMap::default()
132    }
133
134    /// Returns true if the map contains no elements.
135    ///
136    /// # Examples
137    ///
138    /// ```
139    /// # use async_nats::HeaderMap;
140    /// # use async_nats::header::NATS_SUBJECT;
141    /// let mut map = HeaderMap::new();
142    ///
143    /// assert!(map.is_empty());
144    ///
145    /// map.insert(NATS_SUBJECT, "FOO.BAR");
146    ///
147    /// assert!(!map.is_empty());
148    /// ```
149    pub fn is_empty(&self) -> bool {
150        self.inner.is_empty()
151    }
152
153    pub fn len(&self) -> usize {
154        self.inner.len()
155    }
156}
157
158impl HeaderMap {
159    /// Inserts a new value to a [HeaderMap].
160    ///
161    /// # Examples
162    ///
163    /// ```
164    /// use async_nats::HeaderMap;
165    ///
166    /// let mut headers = HeaderMap::new();
167    /// headers.insert("Key", "Value");
168    /// ```
169    pub fn insert<K: IntoHeaderName, V: IntoHeaderValue>(&mut self, name: K, value: V) {
170        self.inner
171            .insert(name.into_header_name(), vec![value.into_header_value()]);
172    }
173
174    /// Appends a new value to the list of values to a given key.
175    /// If the key did not exist, it will be inserted with provided value.
176    ///
177    /// # Examples
178    ///
179    /// ```
180    /// use async_nats::HeaderMap;
181    ///
182    /// let mut headers = HeaderMap::new();
183    /// headers.append("Key", "Value");
184    /// headers.append("Key", "Another");
185    pub fn append<K: IntoHeaderName, V: IntoHeaderValue>(&mut self, name: K, value: V) {
186        let key = name.into_header_name();
187        let v = self.inner.get_mut(&key);
188        match v {
189            Some(v) => {
190                v.push(value.into_header_value());
191            }
192            None => {
193                self.insert(key, value.into_header_value());
194            }
195        }
196    }
197
198    /// Gets a value for a given key. If key is not found, [Option::None] is returned.
199    ///
200    /// # Examples
201    ///
202    /// ```
203    /// # use async_nats::HeaderMap;
204    ///
205    /// let mut headers = HeaderMap::new();
206    /// headers.append("Key", "Value");
207    /// let values = headers.get("Key").unwrap();
208    /// ```
209    pub fn get<K: IntoHeaderName>(&self, key: K) -> Option<&HeaderValue> {
210        self.inner
211            .get(&key.into_header_name())
212            .and_then(|x| x.first())
213    }
214
215    /// Gets a last value for a given key. If key is not found, [Option::None] is returned.
216    ///
217    /// # Examples
218    ///
219    /// ```
220    /// # use async_nats::HeaderMap;
221    ///
222    /// let mut headers = HeaderMap::new();
223    /// headers.append("Key", "Value");
224    /// let values = headers.get_last("Key").unwrap();
225    /// ```
226    pub fn get_last<K: IntoHeaderName>(&self, key: K) -> Option<&HeaderValue> {
227        self.inner
228            .get(&key.into_header_name())
229            .and_then(|x| x.last())
230    }
231
232    /// Gets an iterator to the values for a given key.
233    ///
234    /// # Examples
235    ///
236    /// ```
237    /// # use async_nats::HeaderMap;
238    ///
239    /// let mut headers = HeaderMap::new();
240    /// headers.append("Key", "Value1");
241    /// headers.append("Key", "Value2");
242    /// let mut values = headers.get_all("Key");
243    /// let value1 = values.next();
244    /// let value2 = values.next();
245    /// ```
246    pub fn get_all<K: IntoHeaderName>(&self, key: K) -> GetAll<'_, HeaderValue> {
247        let inner = self
248            .inner
249            .get(&key.into_header_name())
250            .map(|x| x.iter())
251            .unwrap_or([].iter());
252
253        GetAll { inner }
254    }
255
256    pub(crate) fn to_bytes(&self) -> Vec<u8> {
257        let mut buf = vec![];
258        buf.extend_from_slice(b"NATS/1.0\r\n");
259        for (k, vs) in &self.inner {
260            for v in vs.iter() {
261                buf.extend_from_slice(k.as_str().as_bytes());
262                buf.extend_from_slice(b": ");
263                buf.extend_from_slice(v.inner.as_bytes());
264                buf.extend_from_slice(b"\r\n");
265            }
266        }
267        buf.extend_from_slice(b"\r\n");
268        buf
269    }
270}
271
272/// Represents NATS header field value.
273///
274/// # Examples
275///
276/// ```
277/// # use async_nats::HeaderMap;
278///
279/// let mut headers = HeaderMap::new();
280/// headers.insert("Key", "Value");
281/// headers.insert("Another", "AnotherValue");
282/// ```
283#[derive(Clone, PartialEq, Eq, Debug, Default, Serialize, Deserialize)]
284#[serde(transparent)]
285pub struct HeaderValue {
286    inner: String,
287}
288
289impl fmt::Display for HeaderValue {
290    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291        fmt::Display::fmt(&self.as_str(), f)
292    }
293}
294
295impl AsRef<[u8]> for HeaderValue {
296    fn as_ref(&self) -> &[u8] {
297        self.inner.as_ref()
298    }
299}
300
301impl AsRef<str> for HeaderValue {
302    fn as_ref(&self) -> &str {
303        self.as_str()
304    }
305}
306
307impl From<i16> for HeaderValue {
308    fn from(v: i16) -> Self {
309        Self {
310            inner: v.to_string(),
311        }
312    }
313}
314
315impl From<i32> for HeaderValue {
316    fn from(v: i32) -> Self {
317        Self {
318            inner: v.to_string(),
319        }
320    }
321}
322
323impl From<i64> for HeaderValue {
324    fn from(v: i64) -> Self {
325        Self {
326            inner: v.to_string(),
327        }
328    }
329}
330
331impl From<isize> for HeaderValue {
332    fn from(v: isize) -> Self {
333        Self {
334            inner: v.to_string(),
335        }
336    }
337}
338
339impl From<u16> for HeaderValue {
340    fn from(v: u16) -> Self {
341        Self {
342            inner: v.to_string(),
343        }
344    }
345}
346
347impl From<u32> for HeaderValue {
348    fn from(v: u32) -> Self {
349        Self {
350            inner: v.to_string(),
351        }
352    }
353}
354
355impl From<u64> for HeaderValue {
356    fn from(v: u64) -> Self {
357        Self {
358            inner: v.to_string(),
359        }
360    }
361}
362
363impl From<usize> for HeaderValue {
364    fn from(v: usize) -> Self {
365        Self {
366            inner: v.to_string(),
367        }
368    }
369}
370
371impl FromStr for HeaderValue {
372    type Err = ParseHeaderValueError;
373
374    fn from_str(s: &str) -> Result<Self, Self::Err> {
375        if s.contains(['\r', '\n']) {
376            return Err(ParseHeaderValueError);
377        }
378
379        Ok(HeaderValue {
380            inner: s.to_string(),
381        })
382    }
383}
384
385impl From<&str> for HeaderValue {
386    fn from(v: &str) -> Self {
387        assert!(
388            !v.contains(['\r', '\n']),
389            "invalid header value: cannot contain '\\r' or '\\n'"
390        );
391        Self {
392            inner: v.to_string(),
393        }
394    }
395}
396
397impl From<String> for HeaderValue {
398    fn from(inner: String) -> Self {
399        assert!(
400            !inner.contains(['\r', '\n']),
401            "invalid header value: cannot contain '\\r' or '\\n'"
402        );
403        Self { inner }
404    }
405}
406
407impl HeaderValue {
408    pub fn new() -> Self {
409        HeaderValue::default()
410    }
411
412    pub fn as_str(&self) -> &str {
413        self.inner.as_str()
414    }
415}
416
/// Error returned when parsing a [`HeaderValue`] from text containing `\r` or `\n`.
#[derive(Debug, Clone)]
pub struct ParseHeaderValueError;

impl fmt::Display for ParseHeaderValueError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(r"invalid character found in header value (value cannot contain '\r' or '\n')")
    }
}

impl std::error::Error for ParseHeaderValueError {}
430
431pub trait IntoHeaderName {
432    fn into_header_name(self) -> HeaderName;
433}
434
435impl IntoHeaderName for &str {
436    fn into_header_name(self) -> HeaderName {
437        assert!(
438            !self.contains(|c: char| c == ':' || (c as u8) < 33 || (c as u8) > 126),
439            "invalid header name: cannot contain control characters, non-ASCII, or ':'"
440        );
441        match StandardHeader::from_bytes(self.as_bytes()) {
442            Some(v) => HeaderName {
443                inner: HeaderRepr::Standard(v),
444            },
445            None => HeaderName {
446                inner: HeaderRepr::Custom(self.into()),
447            },
448        }
449    }
450}
451
452impl IntoHeaderName for String {
453    fn into_header_name(self) -> HeaderName {
454        assert!(
455            !self.contains(|c: char| c == ':' || (c as u8) < 33 || (c as u8) > 126),
456            "invalid header name: cannot contain control characters, non-ASCII, or ':'"
457        );
458        match StandardHeader::from_bytes(self.as_bytes()) {
459            Some(v) => HeaderName {
460                inner: HeaderRepr::Standard(v),
461            },
462            None => HeaderName {
463                inner: HeaderRepr::Custom(self.into()),
464            },
465        }
466    }
467}
468
469impl IntoHeaderName for HeaderName {
470    fn into_header_name(self) -> HeaderName {
471        self
472    }
473}
474
475pub trait IntoHeaderValue {
476    fn into_header_value(self) -> HeaderValue;
477}
478
479impl IntoHeaderValue for &str {
480    fn into_header_value(self) -> HeaderValue {
481        HeaderValue::from(self)
482    }
483}
484
485impl IntoHeaderValue for String {
486    fn into_header_value(self) -> HeaderValue {
487        HeaderValue::from(self)
488    }
489}
490
491impl IntoHeaderValue for HeaderValue {
492    fn into_header_value(self) -> HeaderValue {
493        self
494    }
495}
496
/// Generates the `StandardHeader` enum, one public `HeaderName` constant per
/// standard header, the conversions between variants and their text, and a
/// test module checking each constant against its byte literal.
macro_rules! standard_headers {
    (
        $(
            $(#[$docs:meta])*
            ($variant:ident, $constant:ident, $bytes:literal);
        )+
    ) => {
        #[allow(clippy::enum_variant_names)]
        #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
        enum StandardHeader {
            $(
                $variant,
            )+
        }

        $(
            $(#[$docs])*
            pub const $constant: HeaderName = HeaderName {
                inner: HeaderRepr::Standard(StandardHeader::$variant),
            };
        )+

        impl StandardHeader {
            /// Returns the canonical text of this header name.
            #[inline]
            fn as_str(&self) -> &'static str {
                match *self {
                    $(
                    // SAFETY: each byte literal passed to this macro is an
                    // ASCII-only header name, hence valid UTF-8. The generated
                    // test `as_str_matches_bytes` re-checks this with
                    // `str::from_utf8` for every variant.
                    StandardHeader::$variant => unsafe { std::str::from_utf8_unchecked( $bytes ) },
                    )+
                }
            }

            /// Maps the exact (case-sensitive) bytes of a header name to its
            /// standard variant, or `None` for any other name.
            const fn from_bytes(bytes: &[u8]) -> Option<StandardHeader> {
                match bytes {
                    $(
                        $bytes => Some(StandardHeader::$variant),
                    )+
                    _ => None,
                }
            }
        }

        #[cfg(test)]
        mod standard_header_tests {
            use super::{HeaderName, StandardHeader};
            use std::str::{self, FromStr};

            const TEST_HEADERS: &[(&HeaderName, &[u8])] = &[
                $(
                (&super::$constant, $bytes),
                )+
            ];

            #[test]
            fn from_str() {
                for &(header, bytes) in TEST_HEADERS {
                    let utf8 = str::from_utf8(bytes).expect("string constants isn't utf8");
                    assert_eq!(HeaderName::from_str(utf8).unwrap(), *header);
                }
            }

            #[test]
            fn as_str_matches_bytes() {
                // Guards the invariant the `unsafe` in `StandardHeader::as_str` relies on.
                for &(header, bytes) in TEST_HEADERS {
                    let utf8 = str::from_utf8(bytes).expect("standard header is not valid UTF-8");
                    assert_eq!(header.as_str(), utf8);
                    assert!(StandardHeader::from_bytes(bytes).is_some());
                }
            }
        }
    }
}
560
561// Generate constants for all standard NATS headers.
562standard_headers! {
563    /// The name of the stream the message belongs to.
564    (NatsStream, NATS_STREAM, b"Nats-Stream");
565    /// The sequence number of the message within the stream.
566    (NatsSequence, NATS_SEQUENCE, b"Nats-Sequence");
567    /// The timestamp of when the message was sent.
568    (NatsTimeStamp, NATS_TIME_STAMP, b"Nats-Time-Stamp");
569    /// The subject of the message, used for routing and filtering messages.
570    (NatsSubject, NATS_SUBJECT, b"Nats-Subject");
571    /// A unique identifier for the message.
572    (NatsMessageId, NATS_MESSAGE_ID, b"Nats-Msg-Id");
573    /// The last known stream the message was part of.
574    (NatsLastStream, NATS_LAST_STREAM, b"Nats-Last-Stream");
575    /// The last known consumer that processed the message.
576    (NatsLastConsumer, NATS_LAST_CONSUMER, b"Nats-Last-Consumer");
577    /// The last known sequence number of the message.
578    (NatsLastSequence, NATS_LAST_SEQUENCE, b"Nats-Last-Sequence");
579    /// The expected last sequence number of the subject.
580    (NatsExpectedLastSubjectSequence, NATS_EXPECTED_LAST_SUBJECT_SEQUENCE, b"Nats-Expected-Last-Subject-Sequence");
581    /// The expected last message ID within the stream.
582    (NatsExpectedLastMessageId, NATS_EXPECTED_LAST_MESSAGE_ID, b"Nats-Expected-Last-Msg-Id");
583    /// The expected last sequence number within the stream.
584    (NatsExpectedLastSequence, NATS_EXPECTED_LAST_SEQUENCE, b"Nats-Expected-Last-Sequence");
585    /// The expected stream the message should be part of.
586    (NatsExpectedStream, NATS_EXPECTED_STREAM, b"Nats-Expected-Stream");
587    /// Sets the TTL for a single message.
588    (NatsMessageTtl, NATS_MESSAGE_TTL, b"Nats-TTL");
589    /// Reason why the delete marked on a stream with enabled markers was put.
590    (NatsMarkerReason, NATS_MARKER_REASON, b"Nats-Marker-Reason");
591    /// Initiates a rollup of the given subject(s); valid values are `sub` and `all`.
592    (NatsRollup, NATS_ROLLUP, b"Nats-Rollup");
593    /// Schedule expression for a JetStream message scheduler entry. ADR-51.
594    (NatsSchedule, NATS_SCHEDULE, b"Nats-Schedule");
595    /// Target subject the schedule publishes to.
596    (NatsScheduleTarget, NATS_SCHEDULE_TARGET, b"Nats-Schedule-Target");
597    /// TTL applied to messages produced by the schedule.
598    (NatsScheduleTtl, NATS_SCHEDULE_TTL, b"Nats-Schedule-TTL");
599    /// Source subject sampled into the schedule output.
600    (NatsScheduleSource, NATS_SCHEDULE_SOURCE, b"Nats-Schedule-Source");
601    /// Time zone for cron schedules. Accepts IANA names like `America/New_York`.
602    (NatsScheduleTimeZone, NATS_SCHEDULE_TIME_ZONE, b"Nats-Schedule-Time-Zone");
603    /// Auto-applies a rollup on the schedule target. Currently only `sub` is valid.
604    (NatsScheduleRollup, NATS_SCHEDULE_ROLLUP, b"Nats-Schedule-Rollup");
605    /// On schedule-produced messages: the subject of the originating schedule.
606    (NatsScheduler, NATS_SCHEDULER, b"Nats-Scheduler");
607    /// On schedule-produced messages: timestamp of next firing or `purge` for delayed schedules.
608    (NatsScheduleNext, NATS_SCHEDULE_NEXT, b"Nats-Schedule-Next");
609    /// Atomic batch publish: batch id (max 64 chars).
610    (NatsBatchId, NATS_BATCH_ID, b"Nats-Batch-Id");
611    /// Atomic batch publish: per-message sequence within the batch.
612    (NatsBatchSequence, NATS_BATCH_SEQUENCE, b"Nats-Batch-Sequence");
613    /// Atomic batch publish: commit marker. `1` to commit and store the final
614    /// message; `eob` to commit without storing it (end-of-batch).
615    (NatsBatchCommit, NATS_BATCH_COMMIT, b"Nats-Batch-Commit");
616    /// Minimum JetStream API level the publishing client requires; the server
617    /// will reject the message (and the enclosing batch, if any) when its own
618    /// level is below the value set here.
619    (NatsRequiredApiLevel, NATS_REQUIRED_API_LEVEL, b"Nats-Required-Api-Level");
620}
621
/// [`NATS_BATCH_COMMIT`] value: commit the batch and store its final message.
pub const NATS_BATCH_COMMIT_FINAL: &str = "1";
/// [`NATS_BATCH_COMMIT`] value: commit the batch without storing the final
/// message (end-of-batch). The server matches this string case-sensitively.
pub const NATS_BATCH_COMMIT_EOB: &str = "eob";
/// [`NATS_SCHEDULE_ROLLUP`] value: roll up the schedule's target subject.
/// This is currently the only accepted value.
pub const NATS_SCHEDULE_ROLLUP_SUB: &str = "sub";

/// Predefined [`NATS_SCHEDULE`] expression: once a year, at midnight on January 1.
pub const NATS_SCHEDULE_YEARLY: &str = "@yearly";
/// Predefined [`NATS_SCHEDULE`] expression: once a month, at midnight on the 1st.
pub const NATS_SCHEDULE_MONTHLY: &str = "@monthly";
/// Predefined [`NATS_SCHEDULE`] expression: once a week, at midnight between Saturday and Sunday.
pub const NATS_SCHEDULE_WEEKLY: &str = "@weekly";
/// Predefined [`NATS_SCHEDULE`] expression: once a day, at midnight.
pub const NATS_SCHEDULE_DAILY: &str = "@daily";
/// Predefined [`NATS_SCHEDULE`] expression: once an hour, at the top of the hour.
pub const NATS_SCHEDULE_HOURLY: &str = "@hourly";
641
642#[derive(Debug, Hash, PartialEq, Eq, Clone)]
643struct CustomHeader {
644    bytes: Bytes,
645}
646
647impl CustomHeader {
648    #[inline]
649    pub(crate) const fn from_static(value: &'static str) -> CustomHeader {
650        CustomHeader {
651            bytes: Bytes::from_static(value.as_bytes()),
652        }
653    }
654
655    #[inline]
656    pub(crate) fn as_str(&self) -> &str {
657        unsafe { std::str::from_utf8_unchecked(self.bytes.as_ref()) }
658    }
659}
660
661impl From<String> for CustomHeader {
662    #[inline]
663    fn from(value: String) -> CustomHeader {
664        CustomHeader {
665            bytes: Bytes::from(value),
666        }
667    }
668}
669
670impl<'a> From<&'a str> for CustomHeader {
671    #[inline]
672    fn from(value: &'a str) -> CustomHeader {
673        CustomHeader {
674            bytes: Bytes::copy_from_slice(value.as_bytes()),
675        }
676    }
677}
678
679#[derive(Debug, Hash, PartialEq, Eq, Clone)]
680enum HeaderRepr {
681    Standard(StandardHeader),
682    Custom(CustomHeader),
683}
684
685/// Defines a NATS header field name
686///
687/// Header field names identify the header. Header sets may include multiple
688/// headers with the same name.
689///
690/// # Representation
691///
692/// `HeaderName` represents standard header names using an `enum`, as such they
693/// will not require an allocation for storage.
694#[derive(Clone, PartialEq, Eq, Hash, Debug)]
695pub struct HeaderName {
696    inner: HeaderRepr,
697}
698
699impl HeaderName {
700    /// Converts a static string to a NATS header name.
701    #[inline]
702    pub const fn from_static(value: &'static str) -> HeaderName {
703        if let Some(standard) = StandardHeader::from_bytes(value.as_bytes()) {
704            return HeaderName {
705                inner: HeaderRepr::Standard(standard),
706            };
707        }
708
709        HeaderName {
710            inner: HeaderRepr::Custom(CustomHeader::from_static(value)),
711        }
712    }
713
714    /// Returns a `str` representation of the header.
715    #[inline]
716    fn as_str(&self) -> &str {
717        match self.inner {
718            HeaderRepr::Standard(v) => v.as_str(),
719            HeaderRepr::Custom(ref v) => v.as_str(),
720        }
721    }
722}
723
724impl FromStr for HeaderName {
725    type Err = ParseHeaderNameError;
726
727    fn from_str(s: &str) -> Result<Self, Self::Err> {
728        if s.contains(|c: char| c == ':' || (c as u8) < 33 || (c as u8) > 126) {
729            return Err(ParseHeaderNameError);
730        }
731
732        match StandardHeader::from_bytes(s.as_ref()) {
733            Some(v) => Ok(HeaderName {
734                inner: HeaderRepr::Standard(v),
735            }),
736            None => Ok(HeaderName {
737                inner: HeaderRepr::Custom(CustomHeader::from(s)),
738            }),
739        }
740    }
741}
742
743impl fmt::Display for HeaderName {
744    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
745        fmt::Display::fmt(&self.as_str(), f)
746    }
747}
748
749impl AsRef<[u8]> for HeaderName {
750    fn as_ref(&self) -> &[u8] {
751        self.as_str().as_bytes()
752    }
753}
754
755impl AsRef<str> for HeaderName {
756    fn as_ref(&self) -> &str {
757        self.as_str()
758    }
759}
760
761impl Serialize for HeaderName {
762    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
763    where
764        S: serde::Serializer,
765    {
766        serializer.serialize_str(self.as_str())
767    }
768}
769
770impl<'de> Deserialize<'de> for HeaderName {
771    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
772    where
773        D: serde::Deserializer<'de>,
774    {
775        String::deserialize(deserializer)?
776            .parse()
777            .map_err(serde::de::Error::custom)
778    }
779}
780
/// Error returned when parsing a [`HeaderName`] from text that contains `:` or
/// a character outside printable ASCII.
#[derive(Debug, Clone)]
pub struct ParseHeaderNameError;

impl std::fmt::Display for ParseHeaderNameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // State the rule `HeaderName::from_str` is meant to enforce. Names are
        // not limited to alphanumerics and '-': `_`, `$`, digits and other
        // printable punctuation are accepted too, only ':' is excluded.
        f.write_str(
            "invalid header name (name can only contain printable ASCII characters other than ':')",
        )
    }
}

impl std::error::Error for ParseHeaderNameError {}
791
792#[cfg(test)]
793mod tests {
794    use super::{HeaderMap, HeaderName, HeaderValue, IntoHeaderName, IntoHeaderValue};
795    use std::str::{from_utf8, FromStr};
796
797    #[test]
798    fn try_from() {
799        let mut headers = HeaderMap::new();
800        headers.insert("name", "something".parse::<HeaderValue>().unwrap());
801        headers.insert("name", "something2");
802    }
803
804    #[test]
805    fn append() {
806        let mut headers = HeaderMap::new();
807        headers.append("Key", "value");
808        headers.append("Key", "second_value");
809
810        let mut result = headers.get_all("Key");
811
812        assert_eq!(
813            result.next().unwrap(),
814            &HeaderValue::from_str("value").unwrap()
815        );
816
817        assert_eq!(
818            result.next().unwrap(),
819            &HeaderValue::from_str("second_value").unwrap()
820        );
821
822        assert_eq!(result.next(), None);
823    }
824
825    #[test]
826    fn get_string() {
827        let mut headers = HeaderMap::new();
828        headers.append("Key", "value");
829        headers.append("Key", "other");
830
831        assert_eq!(headers.get("Key").unwrap().to_string(), "value");
832
833        let key: String = headers.get("Key").unwrap().as_str().into();
834        assert_eq!(key, "value".to_string());
835
836        let key: String = headers.get("Key").unwrap().as_str().to_owned();
837        assert_eq!(key, "value".to_string());
838
839        assert_eq!(headers.get("Key").unwrap().as_str(), "value");
840
841        let key: String = headers.get_last("Key").unwrap().as_str().into();
842        assert_eq!(key, "other".to_string());
843    }
844
845    #[test]
846    fn insert() {
847        let mut headers = HeaderMap::new();
848        headers.insert("Key", "Value");
849
850        let mut result = headers.get_all("Key");
851
852        assert_eq!(
853            result.next().unwrap(),
854            &HeaderValue::from_str("Value").unwrap()
855        );
856        assert_eq!(result.next(), None);
857    }
858
859    #[test]
860    fn serialize() {
861        let mut headers = HeaderMap::new();
862        headers.append("Key", "value");
863        headers.append("Key", "second_value");
864        headers.insert("Second", "SecondValue");
865
866        let bytes = headers.to_bytes();
867
868        println!("bytes: {:?}", from_utf8(&bytes));
869    }
870
871    #[test]
872    fn is_empty() {
873        let mut headers = HeaderMap::new();
874        assert!(headers.is_empty());
875
876        headers.append("Key", "value");
877        headers.append("Key", "second_value");
878        headers.insert("Second", "SecondValue");
879        assert!(!headers.is_empty());
880    }
881
882    #[test]
883    fn parse_value() {
884        assert!("Foo\r".parse::<HeaderValue>().is_err());
885        assert!("Foo\n".parse::<HeaderValue>().is_err());
886        assert!("Foo\r\n".parse::<HeaderValue>().is_err());
887    }
888
889    #[test]
890    fn valid_header_name() {
891        let valid_header_name = "X-Custom-Header";
892        let parsed_header = HeaderName::from_str(valid_header_name);
893
894        assert!(
895            parsed_header.is_ok(),
896            "Expected Ok(HeaderName), but got an error: {:?}",
897            parsed_header.err()
898        );
899    }
900
901    #[test]
902    fn dollar_header_name() {
903        let valid_header_name = "$X_Custom_Header";
904        let parsed_header = HeaderName::from_str(valid_header_name);
905
906        assert!(
907            parsed_header.is_ok(),
908            "Expected Ok(HeaderName), but got an error: {:?}",
909            parsed_header.err()
910        );
911    }
912
913    #[test]
914    fn invalid_header_name_with_space() {
915        let invalid_header_name = "X Custom Header";
916        let parsed_header = HeaderName::from_str(invalid_header_name);
917
918        assert!(
919            parsed_header.is_err(),
920            "Expected Err(InvalidHeaderNameError), but got Ok: {:?}",
921            parsed_header.ok()
922        );
923    }
924
925    #[test]
926    fn invalid_header_name_with_special_chars() {
927        let invalid_header_name = "X-Header:";
928        let parsed_header = HeaderName::from_str(invalid_header_name);
929
930        assert!(
931            parsed_header.is_err(),
932            "Expected Err(InvalidHeaderNameError), but got Ok: {:?}",
933            parsed_header.ok()
934        );
935    }
936
937    #[test]
938    fn from_static_eq() {
939        let a = HeaderName::from_static("NATS-Stream");
940        let b = HeaderName::from_static("NATS-Stream");
941
942        assert_eq!(a, b);
943    }
944
945    #[test]
946    fn header_name_serde() {
947        let raw = "Nats-Stream";
948        let raw_json = "\"Nats-Stream\"";
949        let header = HeaderName::from_static(raw);
950
951        // ser/de of HeaderName should be the same as raw string
952        assert_eq!(serde_json::to_string(&header).unwrap(), raw_json);
953        assert_eq!(
954            serde_json::from_str::<HeaderName>(raw_json).unwrap(),
955            header
956        );
957    }
958
959    #[test]
960    fn header_name_from_string() {
961        let string = "NATS-Stream".to_string();
962        let name = string.into_header_name();
963
964        assert_eq!("NATS-Stream", name.as_str());
965    }
966
967    #[test]
968    fn header_value_from_string_with_trait() {
969        let string = "some value".to_string();
970
971        let value = string.into_header_value();
972
973        assert_eq!("some value", value.as_str());
974    }
975
976    #[test]
977    fn header_value_from_string() {
978        let string = "some value".to_string();
979
980        let value: HeaderValue = string.into();
981
982        assert_eq!("some value", value.as_str());
983    }
984
985    #[test]
986    fn header_map_backward_compatible_deserialization() {
987        // Test new format (direct HashMap) - this is how it should serialize now
988        let new_format_json =
989            r#"{"Content-Type": ["application/json"], "Authorization": ["Bearer token"]}"#;
990        let header_map: HeaderMap = serde_json::from_str(new_format_json).unwrap();
991
992        assert_eq!(
993            header_map.get("Content-Type").unwrap().as_str(),
994            "application/json"
995        );
996        assert_eq!(
997            header_map.get("Authorization").unwrap().as_str(),
998            "Bearer token"
999        );
1000
1001        // Test legacy format (with "inner" wrapper) - this is the old format that should still work
1002        let legacy_format_json = r#"{"inner": {"Content-Type": ["application/json"], "Authorization": ["Bearer token"]}}"#;
1003        let header_map_legacy: HeaderMap = serde_json::from_str(legacy_format_json).unwrap();
1004
1005        assert_eq!(
1006            header_map_legacy.get("Content-Type").unwrap().as_str(),
1007            "application/json"
1008        );
1009        assert_eq!(
1010            header_map_legacy.get("Authorization").unwrap().as_str(),
1011            "Bearer token"
1012        );
1013
1014        // Both should be equal after deserialization
1015        assert_eq!(header_map, header_map_legacy);
1016    }
1017
1018    #[test]
1019    fn header_map_serialization_new_format() {
1020        // Test that serialization uses the new format (no "inner" wrapper)
1021        let mut headers = HeaderMap::new();
1022        headers.insert("Content-Type", "application/json");
1023        headers.insert("Authorization", "Bearer token");
1024
1025        let serialized = serde_json::to_string(&headers).unwrap();
1026
1027        // Should not contain "inner" key
1028        assert!(!serialized.contains("inner"));
1029
1030        // Should be able to deserialize back
1031        let deserialized: HeaderMap = serde_json::from_str(&serialized).unwrap();
1032        assert_eq!(headers, deserialized);
1033    }
1034
1035    #[test]
1036    fn header_map_roundtrip_compatibility() {
1037        // Test that we can roundtrip both formats
1038        let mut original = HeaderMap::new();
1039        original.insert("X-Custom-Header", "custom-value");
1040        original.append("Multi-Value", "value1");
1041        original.append("Multi-Value", "value2");
1042
1043        // Serialize using new format
1044        let new_serialized = serde_json::to_string(&original).unwrap();
1045        let new_deserialized: HeaderMap = serde_json::from_str(&new_serialized).unwrap();
1046        assert_eq!(original, new_deserialized);
1047
1048        // Manually create legacy format JSON
1049        let legacy_json = format!(r#"{{"inner": {}}}"#, new_serialized);
1050        let legacy_deserialized: HeaderMap = serde_json::from_str(&legacy_json).unwrap();
1051        assert_eq!(original, legacy_deserialized);
1052    }
1053
1054    #[test]
1055    fn header_map_invalid_format_error() {
1056        // Test that invalid JSON returns proper error
1057        let invalid_json = r#"{"not_inner_or_direct": {"Content-Type": ["application/json"]}}"#;
1058        let result = serde_json::from_str::<HeaderMap>(invalid_json);
1059        assert!(result.is_err());
1060
1061        let error_message = result.unwrap_err().to_string();
1062        // With untagged enum, serde will report that data doesn't match any variant
1063        assert!(error_message.contains("did not match any variant"));
1064    }
1065
1066    #[test]
1067    fn header_map_empty_cases() {
1068        // Test empty HeaderMap serialization/deserialization
1069        let empty = HeaderMap::new();
1070        let serialized = serde_json::to_string(&empty).unwrap();
1071        assert_eq!(serialized, "{}");
1072
1073        let deserialized: HeaderMap = serde_json::from_str("{}").unwrap();
1074        assert!(deserialized.is_empty());
1075
1076        // Test legacy empty format
1077        let legacy_empty = r#"{"inner": {}}"#;
1078        let legacy_deserialized: HeaderMap = serde_json::from_str(legacy_empty).unwrap();
1079        assert!(legacy_deserialized.is_empty());
1080    }
1081
1082    #[test]
1083    fn header_map_mixed_legacy_detection() {
1084        // Test that an "inner" header name doesn't confuse the deserializer
1085        // This would be a new format where "inner" is a legitimate header name
1086        let json_with_inner_header =
1087            r#"{"inner": ["some-value"], "Other-Header": ["other-value"]}"#;
1088        let header_map: HeaderMap = serde_json::from_str(json_with_inner_header).unwrap();
1089
1090        // Should have two headers
1091        assert_eq!(header_map.len(), 2);
1092        assert_eq!(header_map.get("inner").unwrap().as_str(), "some-value");
1093        assert_eq!(
1094            header_map.get("Other-Header").unwrap().as_str(),
1095            "other-value"
1096        );
1097    }
1098
1099    #[test]
1100    #[should_panic(expected = "invalid header value")]
1101    fn header_value_from_str_rejects_cr() {
1102        let _: HeaderValue = "value\rwith\rcr".into();
1103    }
1104
1105    #[test]
1106    #[should_panic(expected = "invalid header value")]
1107    fn header_value_from_str_rejects_lf() {
1108        let _: HeaderValue = "value\nwith\nlf".into();
1109    }
1110
1111    #[test]
1112    #[should_panic(expected = "invalid header value")]
1113    fn header_value_from_string_rejects_crlf() {
1114        let _: HeaderValue = "injected\r\nPUB attack 0\r\n\r\n".to_string().into();
1115    }
1116
1117    #[test]
1118    #[should_panic(expected = "invalid header value")]
1119    fn header_value_into_trait_rejects_crlf() {
1120        let mut headers = HeaderMap::new();
1121        headers.insert("Key", "value\r\nPUB attack 0\r\n\r\n");
1122    }
1123
1124    #[test]
1125    #[should_panic(expected = "invalid header name")]
1126    fn header_name_into_trait_rejects_cr() {
1127        let mut headers = HeaderMap::new();
1128        headers.insert("Bad\rName", "value");
1129    }
1130
1131    #[test]
1132    #[should_panic(expected = "invalid header name")]
1133    fn header_name_into_trait_rejects_lf() {
1134        let mut headers = HeaderMap::new();
1135        headers.insert("Bad\nName", "value");
1136    }
1137
1138    #[test]
1139    #[should_panic(expected = "invalid header name")]
1140    fn header_name_into_trait_rejects_space() {
1141        let mut headers = HeaderMap::new();
1142        headers.insert("Bad Name", "value");
1143    }
1144
1145    #[test]
1146    #[should_panic(expected = "invalid header name")]
1147    fn header_name_into_trait_rejects_colon() {
1148        let mut headers = HeaderMap::new();
1149        headers.insert("Bad:Name", "value");
1150    }
1151
1152    #[test]
1153    #[should_panic(expected = "invalid header name")]
1154    fn header_name_from_string_rejects_control_chars() {
1155        let name = "Bad\x00Name".to_string();
1156        name.into_header_name();
1157    }
1158
1159    #[test]
1160    fn valid_header_values_still_work() {
1161        let _: HeaderValue = "normal value".into();
1162        let _: HeaderValue = "value with special chars !@#$%^&*()".into();
1163        let _: HeaderValue = "".into();
1164        let _: HeaderValue = String::from("string value").into();
1165    }
1166
1167    #[test]
1168    fn valid_header_names_still_work() {
1169        let mut headers = HeaderMap::new();
1170        headers.insert("X-Custom-Header", "value");
1171        headers.insert("Another-Header", "value");
1172        headers.insert("$dollar", "value");
1173        headers.insert("Nats-Stream", "value");
1174        assert_eq!(headers.get("Nats-Stream").unwrap().as_str(), "value");
1175    }
1176
1177    #[test]
1178    fn header_map_large_headers() {
1179        // Test with many headers
1180        let mut headers = HeaderMap::new();
1181        for i in 0..100 {
1182            headers.insert(format!("Header-{}", i), format!("Value-{}", i));
1183        }
1184
1185        let serialized = serde_json::to_string(&headers).unwrap();
1186        let deserialized: HeaderMap = serde_json::from_str(&serialized).unwrap();
1187
1188        assert_eq!(headers.len(), deserialized.len());
1189        for i in 0..100 {
1190            assert_eq!(
1191                deserialized
1192                    .get(format!("Header-{}", i).as_str())
1193                    .unwrap()
1194                    .as_str(),
1195                format!("Value-{}", i)
1196            );
1197        }
1198    }
1199}