Skip to main content

async_nats/
header.rs

1// Copyright 2020-2023 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14// NOTE(thomastaylor312): This clippy lint is coming from serialize and deserialize and is likely a
15// false positive due to the bytes crate, see
16// https://rust-lang.github.io/rust-clippy/master/index.html#/mutable_key_type for more details.
17// Sorry to make this global to this module, rather than on the `HeaderMap` struct, but because it
18// is coming from the derive, it didn't work to set it on the struct.
19#![allow(clippy::mutable_key_type)]
20
21//! NATS [Message][crate::Message] headers, modeled loosely after the `http::header` crate.
22
23use std::{collections::HashMap, fmt, slice::Iter, str::FromStr};
24
25use bytes::Bytes;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27
28/// A struct for handling NATS headers.
29/// Has a similar API to `http::header`, but properly serializes and deserializes
30/// according to NATS requirements.
31///
32/// # Examples
33///
34/// ```
35/// # #[tokio::main]
36/// # async fn main() -> Result<(), async_nats::Error> {
37/// let client = async_nats::connect("demo.nats.io").await?;
38/// let mut headers = async_nats::HeaderMap::new();
39/// headers.insert("Key", "Value");
40/// client
41///     .publish_with_headers("subject", headers, "payload".into())
42///     .await?;
43/// # Ok(())
44/// # }
45/// ```
46
47#[derive(Clone, PartialEq, Eq, Debug, Default)]
48pub struct HeaderMap {
49    inner: HashMap<HeaderName, Vec<HeaderValue>>,
50}
51
52/// Helper enum for backward-compatible deserialization using serde's untagged feature
53/// This is required because of the bug #1470 where the client incorrectly serialized
54/// headers with an "inner" wrapper.
55#[derive(Deserialize)]
56#[serde(untagged)]
57enum HeaderMapHelper {
58    // Legacy format with "inner" wrapper
59    Legacy {
60        inner: HashMap<HeaderName, Vec<HeaderValue>>,
61    },
62    // Proper format - direct HashMap
63    Current(HashMap<HeaderName, Vec<HeaderValue>>),
64}
65
66impl<'de> Deserialize<'de> for HeaderMap {
67    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
68    where
69        D: Deserializer<'de>,
70    {
71        // Use the untagged enum to automatically try both formats
72        let helper = HeaderMapHelper::deserialize(deserializer)?;
73
74        Ok(match helper {
75            HeaderMapHelper::Legacy { inner } => HeaderMap { inner },
76            HeaderMapHelper::Current(inner) => HeaderMap { inner },
77        })
78    }
79}
80
81impl Serialize for HeaderMap {
82    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
83    where
84        S: Serializer,
85    {
86        // Serialize as the new format (direct HashMap without "inner" wrapper)
87        self.inner.serialize(serializer)
88    }
89}
90
91impl FromIterator<(HeaderName, HeaderValue)> for HeaderMap {
92    fn from_iter<T: IntoIterator<Item = (HeaderName, HeaderValue)>>(iter: T) -> Self {
93        let mut header_map = HeaderMap::new();
94        for (key, value) in iter {
95            header_map.insert(key, value);
96        }
97        header_map
98    }
99}
100
101impl HeaderMap {
102    pub fn iter(&self) -> std::collections::hash_map::Iter<'_, HeaderName, Vec<HeaderValue>> {
103        self.inner.iter()
104    }
105}
106
/// Iterator over all values stored under one header name.
///
/// Returned by [`HeaderMap::get_all`]; it yields the values in the order they
/// are stored, and yields nothing for an absent name.
#[derive(Debug)]
pub struct GetAll<'a, T> {
    inner: Iter<'a, T>,
}

impl<'a, T> Iterator for GetAll<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next()
    }

    // Forward the exact bound of the underlying slice iterator so `collect`
    // can pre-size and `ExactSizeIterator::len` is available.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<T> DoubleEndedIterator for GetAll<'_, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.inner.next_back()
    }
}

impl<T> ExactSizeIterator for GetAll<'_, T> {}

impl<T> std::iter::FusedIterator for GetAll<'_, T> {}
118
119impl HeaderMap {
120    /// Create an empty `HeaderMap`.
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// # use async_nats::HeaderMap;
126    /// let map = HeaderMap::new();
127    ///
128    /// assert!(map.is_empty());
129    /// ```
130    pub fn new() -> Self {
131        HeaderMap::default()
132    }
133
134    /// Returns true if the map contains no elements.
135    ///
136    /// # Examples
137    ///
138    /// ```
139    /// # use async_nats::HeaderMap;
140    /// # use async_nats::header::NATS_SUBJECT;
141    /// let mut map = HeaderMap::new();
142    ///
143    /// assert!(map.is_empty());
144    ///
145    /// map.insert(NATS_SUBJECT, "FOO.BAR");
146    ///
147    /// assert!(!map.is_empty());
148    /// ```
149    pub fn is_empty(&self) -> bool {
150        self.inner.is_empty()
151    }
152
153    pub fn len(&self) -> usize {
154        self.inner.len()
155    }
156
157    /// Serialized byte length of this header block (as [`HeaderMap::to_bytes`]
158    /// produces), computed without allocating.
159    pub(crate) fn wire_len(&self) -> usize {
160        // Mirrors `to_bytes`: "NATS/1.0\r\n" + each "<key>: <value>\r\n" + "\r\n".
161        let mut len = b"NATS/1.0\r\n".len() + b"\r\n".len();
162        for (k, vs) in &self.inner {
163            for v in vs.iter() {
164                len += k.as_str().len() + b": ".len() + v.inner.len() + b"\r\n".len();
165            }
166        }
167        len
168    }
169}
170
171impl HeaderMap {
172    /// Inserts a new value to a [HeaderMap].
173    ///
174    /// # Examples
175    ///
176    /// ```
177    /// use async_nats::HeaderMap;
178    ///
179    /// let mut headers = HeaderMap::new();
180    /// headers.insert("Key", "Value");
181    /// ```
182    pub fn insert<K: IntoHeaderName, V: IntoHeaderValue>(&mut self, name: K, value: V) {
183        self.inner
184            .insert(name.into_header_name(), vec![value.into_header_value()]);
185    }
186
187    /// Appends a new value to the list of values to a given key.
188    /// If the key did not exist, it will be inserted with provided value.
189    ///
190    /// # Examples
191    ///
192    /// ```
193    /// use async_nats::HeaderMap;
194    ///
195    /// let mut headers = HeaderMap::new();
196    /// headers.append("Key", "Value");
197    /// headers.append("Key", "Another");
198    pub fn append<K: IntoHeaderName, V: IntoHeaderValue>(&mut self, name: K, value: V) {
199        let key = name.into_header_name();
200        let v = self.inner.get_mut(&key);
201        match v {
202            Some(v) => {
203                v.push(value.into_header_value());
204            }
205            None => {
206                self.insert(key, value.into_header_value());
207            }
208        }
209    }
210
211    /// Gets a value for a given key. If key is not found, [Option::None] is returned.
212    ///
213    /// # Examples
214    ///
215    /// ```
216    /// # use async_nats::HeaderMap;
217    ///
218    /// let mut headers = HeaderMap::new();
219    /// headers.append("Key", "Value");
220    /// let values = headers.get("Key").unwrap();
221    /// ```
222    pub fn get<K: IntoHeaderName>(&self, key: K) -> Option<&HeaderValue> {
223        self.inner
224            .get(&key.into_header_name())
225            .and_then(|x| x.first())
226    }
227
228    /// Gets a last value for a given key. If key is not found, [Option::None] is returned.
229    ///
230    /// # Examples
231    ///
232    /// ```
233    /// # use async_nats::HeaderMap;
234    ///
235    /// let mut headers = HeaderMap::new();
236    /// headers.append("Key", "Value");
237    /// let values = headers.get_last("Key").unwrap();
238    /// ```
239    pub fn get_last<K: IntoHeaderName>(&self, key: K) -> Option<&HeaderValue> {
240        self.inner
241            .get(&key.into_header_name())
242            .and_then(|x| x.last())
243    }
244
245    /// Gets an iterator to the values for a given key.
246    ///
247    /// # Examples
248    ///
249    /// ```
250    /// # use async_nats::HeaderMap;
251    ///
252    /// let mut headers = HeaderMap::new();
253    /// headers.append("Key", "Value1");
254    /// headers.append("Key", "Value2");
255    /// let mut values = headers.get_all("Key");
256    /// let value1 = values.next();
257    /// let value2 = values.next();
258    /// ```
259    pub fn get_all<K: IntoHeaderName>(&self, key: K) -> GetAll<'_, HeaderValue> {
260        let inner = self
261            .inner
262            .get(&key.into_header_name())
263            .map(|x| x.iter())
264            .unwrap_or([].iter());
265
266        GetAll { inner }
267    }
268
269    pub(crate) fn to_bytes(&self) -> Vec<u8> {
270        let mut buf = vec![];
271        buf.extend_from_slice(b"NATS/1.0\r\n");
272        for (k, vs) in &self.inner {
273            for v in vs.iter() {
274                buf.extend_from_slice(k.as_str().as_bytes());
275                buf.extend_from_slice(b": ");
276                buf.extend_from_slice(v.inner.as_bytes());
277                buf.extend_from_slice(b"\r\n");
278            }
279        }
280        buf.extend_from_slice(b"\r\n");
281        buf
282    }
283}
284
285/// Represents NATS header field value.
286///
287/// # Examples
288///
289/// ```
290/// # use async_nats::HeaderMap;
291///
292/// let mut headers = HeaderMap::new();
293/// headers.insert("Key", "Value");
294/// headers.insert("Another", "AnotherValue");
295/// ```
296#[derive(Clone, PartialEq, Eq, Debug, Default, Serialize, Deserialize)]
297#[serde(transparent)]
298pub struct HeaderValue {
299    inner: String,
300}
301
302impl fmt::Display for HeaderValue {
303    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304        fmt::Display::fmt(&self.as_str(), f)
305    }
306}
307
308impl AsRef<[u8]> for HeaderValue {
309    fn as_ref(&self) -> &[u8] {
310        self.inner.as_ref()
311    }
312}
313
314impl AsRef<str> for HeaderValue {
315    fn as_ref(&self) -> &str {
316        self.as_str()
317    }
318}
319
320impl From<i16> for HeaderValue {
321    fn from(v: i16) -> Self {
322        Self {
323            inner: v.to_string(),
324        }
325    }
326}
327
328impl From<i32> for HeaderValue {
329    fn from(v: i32) -> Self {
330        Self {
331            inner: v.to_string(),
332        }
333    }
334}
335
336impl From<i64> for HeaderValue {
337    fn from(v: i64) -> Self {
338        Self {
339            inner: v.to_string(),
340        }
341    }
342}
343
344impl From<isize> for HeaderValue {
345    fn from(v: isize) -> Self {
346        Self {
347            inner: v.to_string(),
348        }
349    }
350}
351
352impl From<u16> for HeaderValue {
353    fn from(v: u16) -> Self {
354        Self {
355            inner: v.to_string(),
356        }
357    }
358}
359
360impl From<u32> for HeaderValue {
361    fn from(v: u32) -> Self {
362        Self {
363            inner: v.to_string(),
364        }
365    }
366}
367
368impl From<u64> for HeaderValue {
369    fn from(v: u64) -> Self {
370        Self {
371            inner: v.to_string(),
372        }
373    }
374}
375
376impl From<usize> for HeaderValue {
377    fn from(v: usize) -> Self {
378        Self {
379            inner: v.to_string(),
380        }
381    }
382}
383
384impl FromStr for HeaderValue {
385    type Err = ParseHeaderValueError;
386
387    fn from_str(s: &str) -> Result<Self, Self::Err> {
388        validate_header_value(s)?;
389        Ok(HeaderValue {
390            inner: s.to_string(),
391        })
392    }
393}
394
395impl From<&str> for HeaderValue {
396    fn from(v: &str) -> Self {
397        assert!(
398            validate_header_value(v).is_ok(),
399            "invalid header value: cannot contain '\\r' or '\\n'"
400        );
401        Self {
402            inner: v.to_string(),
403        }
404    }
405}
406
407impl From<String> for HeaderValue {
408    fn from(inner: String) -> Self {
409        assert!(
410            validate_header_value(&inner).is_ok(),
411            "invalid header value: cannot contain '\\r' or '\\n'"
412        );
413        Self { inner }
414    }
415}
416
417impl HeaderValue {
418    pub fn new() -> Self {
419        HeaderValue::default()
420    }
421
422    pub fn as_str(&self) -> &str {
423        self.inner.as_str()
424    }
425}
426
/// Error returned when header value text contains `\r` or `\n`.
#[derive(Debug, Clone)]
pub struct ParseHeaderValueError;

impl fmt::Display for ParseHeaderValueError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(
            r#"invalid character found in header value (value cannot contain '\r' or '\n')"#,
        )
    }
}

impl std::error::Error for ParseHeaderValueError {}

/// Checks that `s` can be written as a header value, i.e. contains neither a
/// carriage return nor a line feed.
fn validate_header_value(s: &str) -> Result<(), ParseHeaderValueError> {
    // CR and LF are ASCII, so a byte scan matches a char scan exactly.
    let has_line_break = s.bytes().any(|b| matches!(b, b'\r' | b'\n'));
    if has_line_break {
        return Err(ParseHeaderValueError);
    }
    Ok(())
}
448
449pub trait IntoHeaderName {
450    fn into_header_name(self) -> HeaderName;
451}
452
453impl IntoHeaderName for &str {
454    fn into_header_name(self) -> HeaderName {
455        assert!(
456            validate_header_name(self).is_ok(),
457            "invalid header name: cannot contain control characters, non-ASCII, or ':'"
458        );
459        match StandardHeader::from_bytes(self.as_bytes()) {
460            Some(v) => HeaderName {
461                inner: HeaderRepr::Standard(v),
462            },
463            None => HeaderName {
464                inner: HeaderRepr::Custom(self.into()),
465            },
466        }
467    }
468}
469
470impl IntoHeaderName for String {
471    fn into_header_name(self) -> HeaderName {
472        assert!(
473            validate_header_name(&self).is_ok(),
474            "invalid header name: cannot contain control characters, non-ASCII, or ':'"
475        );
476        match StandardHeader::from_bytes(self.as_bytes()) {
477            Some(v) => HeaderName {
478                inner: HeaderRepr::Standard(v),
479            },
480            None => HeaderName {
481                inner: HeaderRepr::Custom(self.into()),
482            },
483        }
484    }
485}
486
487impl IntoHeaderName for HeaderName {
488    fn into_header_name(self) -> HeaderName {
489        self
490    }
491}
492
493pub trait IntoHeaderValue {
494    fn into_header_value(self) -> HeaderValue;
495}
496
497impl IntoHeaderValue for &str {
498    fn into_header_value(self) -> HeaderValue {
499        HeaderValue::from(self)
500    }
501}
502
503impl IntoHeaderValue for String {
504    fn into_header_value(self) -> HeaderValue {
505        HeaderValue::from(self)
506    }
507}
508
509impl IntoHeaderValue for HeaderValue {
510    fn into_header_value(self) -> HeaderValue {
511        self
512    }
513}
514
// Defines every standard NATS header from a single table, so the enum, the
// public constants, the name <-> bytes mappings and a round-trip test are all
// generated from the same rows.
//
// Each row is `(EnumVariant, PUBLIC_CONSTANT, b"Wire-Name");`, optionally
// preceded by doc comments, which are forwarded onto the generated constant.
macro_rules! standard_headers {
    (
        $(
            $(#[$docs:meta])*
            ($variant:ident, $constant:ident, $bytes:literal);
        )+
    ) => {
        // One fieldless variant per row, so `HeaderRepr::Standard` holds no heap data.
        #[allow(clippy::enum_variant_names)]
        #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
        enum StandardHeader {
            $(
                $variant,
            )+
        }

        // A public `HeaderName` constant for each row, carrying that row's docs.
        $(
            $(#[$docs])*
            pub const $constant: HeaderName = HeaderName {
                inner: HeaderRepr::Standard(StandardHeader::$variant),
            };
        )+

        impl StandardHeader {
            // Returns the wire spelling of this header name.
            #[inline]
            fn as_str(&self) -> &'static str {
                match *self {
                    $(
                    // SAFETY: every `$bytes` row in this module's `standard_headers!`
                    // invocation is an ASCII byte-string literal, and ASCII is valid
                    // UTF-8. The `standard_header_tests::from_str` test below also
                    // checks each row with `str::from_utf8`.
                    StandardHeader::$variant => unsafe { std::str::from_utf8_unchecked( $bytes ) },
                    )+
                }
            }

            // Exact, case-sensitive match of `bytes` against the table rows
            // (e.g. `b"NATS-Stream"` does not match the `Nats-Stream` row).
            // `const` so that `HeaderName::from_static` can call it.
            const fn from_bytes(bytes: &[u8]) -> Option<StandardHeader> {
                match bytes {
                    $(
                        $bytes => Some(StandardHeader::$variant),
                    )+
                    _ => None,
                }
            }
        }

        // Checks that parsing each row's wire name yields that row's constant.
        #[cfg(test)]
        mod standard_header_tests {
            use super::HeaderName;
            use std::str::{self, FromStr};

            const TEST_HEADERS: &'static [(&'static HeaderName, &'static [u8])] = &[
                $(
                (&super::$constant, $bytes),
                )+
            ];

            #[test]
            fn from_str() {
                for &(header, bytes) in TEST_HEADERS {
                    let utf8 = str::from_utf8(bytes).expect("string constants isn't utf8");
                    assert_eq!(HeaderName::from_str(utf8).unwrap(), *header);
                }
            }
        }
    }
}
578
// Generate the `StandardHeader` enum and a public `HeaderName` constant for every
// standard NATS header. Each row is `(Variant, CONSTANT, b"Wire-Name")`; wire
// names are matched case-sensitively.
standard_headers! {
    /// The name of the stream the message belongs to.
    (NatsStream, NATS_STREAM, b"Nats-Stream");
    /// The sequence number of the message within the stream.
    (NatsSequence, NATS_SEQUENCE, b"Nats-Sequence");
    /// The timestamp of when the message was sent.
    (NatsTimeStamp, NATS_TIME_STAMP, b"Nats-Time-Stamp");
    /// The subject of the message, used for routing and filtering messages.
    (NatsSubject, NATS_SUBJECT, b"Nats-Subject");
    /// A unique identifier for the message.
    (NatsMessageId, NATS_MESSAGE_ID, b"Nats-Msg-Id");
    /// The last known stream the message was part of.
    (NatsLastStream, NATS_LAST_STREAM, b"Nats-Last-Stream");
    /// The last known consumer that processed the message.
    (NatsLastConsumer, NATS_LAST_CONSUMER, b"Nats-Last-Consumer");
    /// The last known sequence number of the message.
    (NatsLastSequence, NATS_LAST_SEQUENCE, b"Nats-Last-Sequence");
    /// The expected last sequence number for the message's subject.
    (NatsExpectedLastSubjectSequence, NATS_EXPECTED_LAST_SUBJECT_SEQUENCE, b"Nats-Expected-Last-Subject-Sequence");
    /// The expected ID of the last message in the stream.
    (NatsExpectedLastMessageId, NATS_EXPECTED_LAST_MESSAGE_ID, b"Nats-Expected-Last-Msg-Id");
    /// The expected last sequence number of the stream.
    (NatsExpectedLastSequence, NATS_EXPECTED_LAST_SEQUENCE, b"Nats-Expected-Last-Sequence");
    /// The stream the message is expected to be stored in.
    (NatsExpectedStream, NATS_EXPECTED_STREAM, b"Nats-Expected-Stream");
    /// Sets the TTL for a single message.
    (NatsMessageTtl, NATS_MESSAGE_TTL, b"Nats-TTL");
    /// Reason a delete marker was placed on a stream that has markers enabled.
    (NatsMarkerReason, NATS_MARKER_REASON, b"Nats-Marker-Reason");
    /// Initiates a rollup of the given subject(s); valid values are `sub` and `all`.
    (NatsRollup, NATS_ROLLUP, b"Nats-Rollup");
    /// Schedule expression for a JetStream message scheduler entry (see ADR-51).
    (NatsSchedule, NATS_SCHEDULE, b"Nats-Schedule");
    /// Target subject the schedule publishes to.
    (NatsScheduleTarget, NATS_SCHEDULE_TARGET, b"Nats-Schedule-Target");
    /// TTL applied to messages produced by the schedule.
    (NatsScheduleTtl, NATS_SCHEDULE_TTL, b"Nats-Schedule-TTL");
    /// Source subject sampled into the schedule output.
    (NatsScheduleSource, NATS_SCHEDULE_SOURCE, b"Nats-Schedule-Source");
    /// Time zone for cron schedules. Accepts IANA names such as `America/New_York`.
    (NatsScheduleTimeZone, NATS_SCHEDULE_TIME_ZONE, b"Nats-Schedule-Time-Zone");
    /// Automatically applies a rollup on the schedule target. Currently only `sub` is valid.
    (NatsScheduleRollup, NATS_SCHEDULE_ROLLUP, b"Nats-Schedule-Rollup");
    /// On schedule-produced messages: the subject of the originating schedule.
    (NatsScheduler, NATS_SCHEDULER, b"Nats-Scheduler");
    /// On schedule-produced messages: timestamp of the next firing, or `purge` for delayed schedules.
    (NatsScheduleNext, NATS_SCHEDULE_NEXT, b"Nats-Schedule-Next");
    /// Atomic batch publish: batch ID (at most 64 characters).
    (NatsBatchId, NATS_BATCH_ID, b"Nats-Batch-Id");
    /// Atomic batch publish: per-message sequence within the batch.
    (NatsBatchSequence, NATS_BATCH_SEQUENCE, b"Nats-Batch-Sequence");
    /// Atomic batch publish: commit marker. `1` commits and stores the final
    /// message; `eob` commits without storing it (end of batch).
    (NatsBatchCommit, NATS_BATCH_COMMIT, b"Nats-Batch-Commit");
    /// Minimum JetStream API level the publishing client requires. The server
    /// rejects the message (and the enclosing batch, if any) when its own
    /// level is below the value set here.
    (NatsRequiredApiLevel, NATS_REQUIRED_API_LEVEL, b"Nats-Required-Api-Level");
}
639
/// Value for [`NATS_BATCH_COMMIT`] that commits the batch and stores its final
/// message.
pub const NATS_BATCH_COMMIT_FINAL: &str = "1";
/// Value for [`NATS_BATCH_COMMIT`] that commits the batch without storing its
/// final message (end of batch). The server matches this string
/// case-sensitively.
pub const NATS_BATCH_COMMIT_EOB: &str = "eob";
/// Value for [`NATS_SCHEDULE_ROLLUP`] that rolls up the schedule's target
/// subject; currently the only accepted value.
pub const NATS_SCHEDULE_ROLLUP_SUB: &str = "sub";

/// Predefined [`NATS_SCHEDULE`] expression: fires once a year, at midnight on January 1.
pub const NATS_SCHEDULE_YEARLY: &str = "@yearly";
/// Predefined [`NATS_SCHEDULE`] expression: fires once a month, at midnight on the 1st.
pub const NATS_SCHEDULE_MONTHLY: &str = "@monthly";
/// Predefined [`NATS_SCHEDULE`] expression: fires once a week, at midnight between Saturday and Sunday.
pub const NATS_SCHEDULE_WEEKLY: &str = "@weekly";
/// Predefined [`NATS_SCHEDULE`] expression: fires once a day, at midnight.
pub const NATS_SCHEDULE_DAILY: &str = "@daily";
/// Predefined [`NATS_SCHEDULE`] expression: fires once an hour, at the top of the hour.
pub const NATS_SCHEDULE_HOURLY: &str = "@hourly";
659
660#[derive(Debug, Hash, PartialEq, Eq, Clone)]
661struct CustomHeader {
662    bytes: Bytes,
663}
664
665impl CustomHeader {
666    #[inline]
667    pub(crate) const fn from_static(value: &'static str) -> CustomHeader {
668        CustomHeader {
669            bytes: Bytes::from_static(value.as_bytes()),
670        }
671    }
672
673    #[inline]
674    pub(crate) fn as_str(&self) -> &str {
675        unsafe { std::str::from_utf8_unchecked(self.bytes.as_ref()) }
676    }
677}
678
679impl From<String> for CustomHeader {
680    #[inline]
681    fn from(value: String) -> CustomHeader {
682        CustomHeader {
683            bytes: Bytes::from(value),
684        }
685    }
686}
687
688impl<'a> From<&'a str> for CustomHeader {
689    #[inline]
690    fn from(value: &'a str) -> CustomHeader {
691        CustomHeader {
692            bytes: Bytes::copy_from_slice(value.as_bytes()),
693        }
694    }
695}
696
697#[derive(Debug, Hash, PartialEq, Eq, Clone)]
698enum HeaderRepr {
699    Standard(StandardHeader),
700    Custom(CustomHeader),
701}
702
703/// Defines a NATS header field name
704///
705/// Header field names identify the header. Header sets may include multiple
706/// headers with the same name.
707///
708/// # Representation
709///
710/// `HeaderName` represents standard header names using an `enum`, as such they
711/// will not require an allocation for storage.
712#[derive(Clone, PartialEq, Eq, Hash, Debug)]
713pub struct HeaderName {
714    inner: HeaderRepr,
715}
716
717impl HeaderName {
718    /// Converts a static string to a NATS header name.
719    #[inline]
720    pub const fn from_static(value: &'static str) -> HeaderName {
721        if let Some(standard) = StandardHeader::from_bytes(value.as_bytes()) {
722            return HeaderName {
723                inner: HeaderRepr::Standard(standard),
724            };
725        }
726
727        HeaderName {
728            inner: HeaderRepr::Custom(CustomHeader::from_static(value)),
729        }
730    }
731
732    /// Returns a `str` representation of the header.
733    #[inline]
734    fn as_str(&self) -> &str {
735        match self.inner {
736            HeaderRepr::Standard(v) => v.as_str(),
737            HeaderRepr::Custom(ref v) => v.as_str(),
738        }
739    }
740}
741
742impl FromStr for HeaderName {
743    type Err = ParseHeaderNameError;
744
745    fn from_str(s: &str) -> Result<Self, Self::Err> {
746        validate_header_name(s)?;
747        Ok(match StandardHeader::from_bytes(s.as_bytes()) {
748            Some(v) => HeaderName {
749                inner: HeaderRepr::Standard(v),
750            },
751            None => HeaderName {
752                inner: HeaderRepr::Custom(CustomHeader::from(s)),
753            },
754        })
755    }
756}
757
758/// Fallible conversion from a borrowed string slice; validates the name without panicking.
759impl TryFrom<&str> for HeaderName {
760    type Error = ParseHeaderNameError;
761
762    fn try_from(value: &str) -> Result<Self, Self::Error> {
763        value.parse()
764    }
765}
766
767/// Fallible conversion from an owned `String`; validates without panicking and
768/// avoids re-allocating the bytes when the result is a custom (non-standard) name.
769impl TryFrom<String> for HeaderName {
770    type Error = ParseHeaderNameError;
771
772    fn try_from(value: String) -> Result<Self, Self::Error> {
773        validate_header_name(&value)?;
774        Ok(match StandardHeader::from_bytes(value.as_bytes()) {
775            Some(v) => HeaderName {
776                inner: HeaderRepr::Standard(v),
777            },
778            None => HeaderName {
779                inner: HeaderRepr::Custom(CustomHeader::from(value)),
780            },
781        })
782    }
783}
784
785impl fmt::Display for HeaderName {
786    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
787        fmt::Display::fmt(&self.as_str(), f)
788    }
789}
790
791impl AsRef<[u8]> for HeaderName {
792    fn as_ref(&self) -> &[u8] {
793        self.as_str().as_bytes()
794    }
795}
796
797impl AsRef<str> for HeaderName {
798    fn as_ref(&self) -> &str {
799        self.as_str()
800    }
801}
802
803impl Serialize for HeaderName {
804    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
805    where
806        S: serde::Serializer,
807    {
808        serializer.serialize_str(self.as_str())
809    }
810}
811
812impl<'de> Deserialize<'de> for HeaderName {
813    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
814    where
815        D: serde::Deserializer<'de>,
816    {
817        String::deserialize(deserializer)?
818            .parse()
819            .map_err(serde::de::Error::custom)
820    }
821}
822
/// Error returned when a string is not a valid [`HeaderName`].
#[derive(Debug, Clone)]
pub struct ParseHeaderNameError;

impl std::fmt::Display for ParseHeaderNameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Describes the rule actually enforced by `validate_header_name`, which
        // also accepts punctuation such as '$' and '_' (not just alphanumerics and '-').
        write!(
            f,
            "invalid header name (only visible ASCII characters other than ':' are allowed)"
        )
    }
}

impl std::error::Error for ParseHeaderNameError {}

/// Checks that `s` is usable as a header name. Every character must be visible
/// ASCII (`'!'..='~'`), which excludes spaces, control characters and non-ASCII,
/// and must not be `':'`, which separates name from value on the wire.
///
/// NOTE(review): the empty string is currently accepted.
fn validate_header_name(s: &str) -> Result<(), ParseHeaderNameError> {
    // Non-ASCII characters encode to bytes >= 0x80, which are never
    // ASCII-graphic, so checking bytes is equivalent to checking chars.
    if s.bytes().all(|b| b.is_ascii_graphic() && b != b':') {
        Ok(())
    } else {
        Err(ParseHeaderNameError)
    }
}
841
842#[cfg(test)]
843mod tests {
844    use super::{HeaderMap, HeaderName, HeaderValue, IntoHeaderName, IntoHeaderValue};
845    use std::str::{from_utf8, FromStr};
846
847    #[test]
848    fn try_from() {
849        let mut headers = HeaderMap::new();
850        headers.insert("name", "something".parse::<HeaderValue>().unwrap());
851        headers.insert("name", "something2");
852    }
853
854    #[test]
855    fn append() {
856        let mut headers = HeaderMap::new();
857        headers.append("Key", "value");
858        headers.append("Key", "second_value");
859
860        let mut result = headers.get_all("Key");
861
862        assert_eq!(
863            result.next().unwrap(),
864            &HeaderValue::from_str("value").unwrap()
865        );
866
867        assert_eq!(
868            result.next().unwrap(),
869            &HeaderValue::from_str("second_value").unwrap()
870        );
871
872        assert_eq!(result.next(), None);
873    }
874
875    #[test]
876    fn get_string() {
877        let mut headers = HeaderMap::new();
878        headers.append("Key", "value");
879        headers.append("Key", "other");
880
881        assert_eq!(headers.get("Key").unwrap().to_string(), "value");
882
883        let key: String = headers.get("Key").unwrap().as_str().into();
884        assert_eq!(key, "value".to_string());
885
886        let key: String = headers.get("Key").unwrap().as_str().to_owned();
887        assert_eq!(key, "value".to_string());
888
889        assert_eq!(headers.get("Key").unwrap().as_str(), "value");
890
891        let key: String = headers.get_last("Key").unwrap().as_str().into();
892        assert_eq!(key, "other".to_string());
893    }
894
895    #[test]
896    fn insert() {
897        let mut headers = HeaderMap::new();
898        headers.insert("Key", "Value");
899
900        let mut result = headers.get_all("Key");
901
902        assert_eq!(
903            result.next().unwrap(),
904            &HeaderValue::from_str("Value").unwrap()
905        );
906        assert_eq!(result.next(), None);
907    }
908
909    #[test]
910    fn serialize() {
911        let mut headers = HeaderMap::new();
912        headers.append("Key", "value");
913        headers.append("Key", "second_value");
914        headers.insert("Second", "SecondValue");
915
916        let bytes = headers.to_bytes();
917
918        println!("bytes: {:?}", from_utf8(&bytes));
919    }
920
921    #[test]
922    fn wire_len_matches_serialized() {
923        let mut headers = HeaderMap::new();
924        headers.append("Key", "value");
925        headers.append("Key", "second_value");
926        headers.insert("Second", "SecondValue");
927        assert_eq!(headers.wire_len(), headers.to_bytes().len());
928
929        // An empty block still serializes the framing bytes.
930        let empty = HeaderMap::new();
931        assert_eq!(empty.wire_len(), empty.to_bytes().len());
932    }
933
934    #[test]
935    fn is_empty() {
936        let mut headers = HeaderMap::new();
937        assert!(headers.is_empty());
938
939        headers.append("Key", "value");
940        headers.append("Key", "second_value");
941        headers.insert("Second", "SecondValue");
942        assert!(!headers.is_empty());
943    }
944
945    #[test]
946    fn parse_value() {
947        assert!("Foo\r".parse::<HeaderValue>().is_err());
948        assert!("Foo\n".parse::<HeaderValue>().is_err());
949        assert!("Foo\r\n".parse::<HeaderValue>().is_err());
950    }
951
952    #[test]
953    fn valid_header_name() {
954        let valid_header_name = "X-Custom-Header";
955        let parsed_header = HeaderName::from_str(valid_header_name);
956
957        assert!(
958            parsed_header.is_ok(),
959            "Expected Ok(HeaderName), but got an error: {:?}",
960            parsed_header.err()
961        );
962    }
963
964    #[test]
965    fn dollar_header_name() {
966        let valid_header_name = "$X_Custom_Header";
967        let parsed_header = HeaderName::from_str(valid_header_name);
968
969        assert!(
970            parsed_header.is_ok(),
971            "Expected Ok(HeaderName), but got an error: {:?}",
972            parsed_header.err()
973        );
974    }
975
/// Spaces inside a header name make it unparseable.
#[test]
fn invalid_header_name_with_space() {
    let outcome = "X Custom Header".parse::<HeaderName>();

    assert!(
        outcome.is_err(),
        "Expected Err(InvalidHeaderNameError), but got Ok: {:?}",
        outcome.ok()
    );
}
987
/// A colon inside a header name makes it unparseable.
#[test]
fn invalid_header_name_with_special_chars() {
    let outcome = "X-Header:".parse::<HeaderName>();

    assert!(
        outcome.is_err(),
        "Expected Err(InvalidHeaderNameError), but got Ok: {:?}",
        outcome.ok()
    );
}
999
/// Two `HeaderName`s built from the same static string compare equal.
#[test]
fn from_static_eq() {
    const NAME: &str = "NATS-Stream";

    assert_eq!(HeaderName::from_static(NAME), HeaderName::from_static(NAME));
}
1007
/// `HeaderName` serializes to, and deserializes from, a plain JSON string, exactly as
/// the raw name would.
#[test]
fn header_name_serde() {
    let expected = HeaderName::from_static("Nats-Stream");
    let json = "\"Nats-Stream\"";

    let encoded = serde_json::to_string(&expected).unwrap();
    assert_eq!(encoded, json);

    let decoded: HeaderName = serde_json::from_str(json).unwrap();
    assert_eq!(decoded, expected);
}
1021
/// An owned `String` converts into a `HeaderName` through `into_header_name`,
/// preserving its text.
#[test]
fn header_name_from_string() {
    let name = String::from("NATS-Stream").into_header_name();

    assert_eq!("NATS-Stream", name.as_str());
}
1029
/// An owned `String` converts into a `HeaderValue` through `into_header_value`,
/// preserving its text.
#[test]
fn header_value_from_string_with_trait() {
    let value = String::from("some value").into_header_value();

    assert_eq!("some value", value.as_str());
}
1038
/// An owned `String` converts into a `HeaderValue` via `Into`, preserving its text.
#[test]
fn header_value_from_string() {
    let value: HeaderValue = String::from("some value").into();

    assert_eq!("some value", value.as_str());
}
1047
/// `HeaderMap` accepts both JSON layouts:
/// - the current bare map of header name to value list;
/// - the legacy layout, where that map sits under an `"inner"` key.
///
/// Both inputs must decode to equal maps.
#[test]
fn header_map_backward_compatible_deserialization() {
    let expected = [
        ("Content-Type", "application/json"),
        ("Authorization", "Bearer token"),
    ];

    // Current format: the map itself is the top-level JSON object.
    let current: HeaderMap = serde_json::from_str(
        r#"{"Content-Type": ["application/json"], "Authorization": ["Bearer token"]}"#,
    )
    .unwrap();
    for &(name, value) in expected.iter() {
        assert_eq!(current.get(name).unwrap().as_str(), value);
    }

    // Legacy format: the same map wrapped in an "inner" field, still accepted.
    let legacy: HeaderMap = serde_json::from_str(
        r#"{"inner": {"Content-Type": ["application/json"], "Authorization": ["Bearer token"]}}"#,
    )
    .unwrap();
    for &(name, value) in expected.iter() {
        assert_eq!(legacy.get(name).unwrap().as_str(), value);
    }

    // Both representations must decode to the same map.
    assert_eq!(current, legacy);
}
1080
/// Serialization emits the current bare-map format, without the legacy `"inner"`
/// wrapper, and the output deserializes back to an equal map.
#[test]
fn header_map_serialization_new_format() {
    let mut original = HeaderMap::new();
    original.insert("Content-Type", "application/json");
    original.insert("Authorization", "Bearer token");

    let json = serde_json::to_string(&original).unwrap();
    assert!(!json.contains("inner"));

    let restored: HeaderMap = serde_json::from_str(&json).unwrap();
    assert_eq!(original, restored);
}
1097
/// A populated map survives a serialize/deserialize round trip. The same payload also
/// decodes correctly when wrapped by hand in the legacy `"inner"` envelope.
#[test]
fn header_map_roundtrip_compatibility() {
    let mut original = HeaderMap::new();
    original.insert("X-Custom-Header", "custom-value");
    // Two values under one name exercise the multi-value representation.
    original.append("Multi-Value", "value1");
    original.append("Multi-Value", "value2");

    let json = serde_json::to_string(&original).unwrap();

    let from_current: HeaderMap = serde_json::from_str(&json).unwrap();
    assert_eq!(original, from_current);

    let wrapped = format!(r#"{{"inner": {}}}"#, json);
    let from_legacy: HeaderMap = serde_json::from_str(&wrapped).unwrap();
    assert_eq!(original, from_legacy);
}
1116
/// JSON that fits neither the current bare-map format nor the legacy `{"inner": ...}`
/// wrapper must fail to deserialize.
#[test]
fn header_map_invalid_format_error() {
    // The input is syntactically valid JSON with an unrecognised shape:
    // - the top-level key maps to an object rather than a list of values, so the
    //   current format does not apply;
    // - the key is not "inner", so the legacy format does not apply either.
    let unrecognised = r#"{"not_inner_or_direct": {"Content-Type": ["application/json"]}}"#;
    let err = serde_json::from_str::<HeaderMap>(unrecognised)
        .expect_err("a map with an unrecognised shape must fail to deserialize");

    // With an untagged enum, serde reports that the data matched none of the variants.
    let message = err.to_string();
    assert!(
        message.contains("did not match any variant"),
        "unexpected error message: {}",
        message
    );
}
1128
/// An empty map serializes to `{}`. Both `{}` and the legacy `{"inner": {}}`
/// deserialize to an empty map.
#[test]
fn header_map_empty_cases() {
    let json = serde_json::to_string(&HeaderMap::new()).unwrap();
    assert_eq!(json, "{}");

    let from_current: HeaderMap = serde_json::from_str("{}").unwrap();
    assert!(from_current.is_empty());

    let from_legacy: HeaderMap = serde_json::from_str(r#"{"inner": {}}"#).unwrap();
    assert!(from_legacy.is_empty());
}
1144
/// A header literally named `inner` whose value is a list is an ordinary header in
/// the current format. It must not be mistaken for the legacy wrapper.
#[test]
fn header_map_mixed_legacy_detection() {
    let json = r#"{"inner": ["some-value"], "Other-Header": ["other-value"]}"#;
    let map: HeaderMap = serde_json::from_str(json).unwrap();

    assert_eq!(map.len(), 2);
    for &(name, value) in [("inner", "some-value"), ("Other-Header", "other-value")].iter() {
        assert_eq!(map.get(name).unwrap().as_str(), value);
    }
}
1161
/// Converting a `&str` containing bare CRs into a `HeaderValue` panics.
#[test]
#[should_panic(expected = "invalid header value")]
fn header_value_from_str_rejects_cr() {
    let input = "value\rwith\rcr";
    let _: HeaderValue = input.into();
}
1167
/// Converting a `&str` containing bare LFs into a `HeaderValue` panics.
#[test]
#[should_panic(expected = "invalid header value")]
fn header_value_from_str_rejects_lf() {
    let input = "value\nwith\nlf";
    let _: HeaderValue = input.into();
}
1173
/// Converting an owned `String` that embeds CRLF-delimited protocol text into a
/// `HeaderValue` panics.
#[test]
#[should_panic(expected = "invalid header value")]
fn header_value_from_string_rejects_crlf() {
    let payload = String::from("injected\r\nPUB attack 0\r\n\r\n");
    let _: HeaderValue = payload.into();
}
1179
/// `HeaderMap::insert` panics when the value embeds CRLF-delimited protocol text.
#[test]
#[should_panic(expected = "invalid header value")]
fn header_value_into_trait_rejects_crlf() {
    let injected = "value\r\nPUB attack 0\r\n\r\n";
    let mut map = HeaderMap::new();
    map.insert("Key", injected);
}
1186
/// `HeaderMap::insert` panics when the header name contains a CR.
#[test]
#[should_panic(expected = "invalid header name")]
fn header_name_into_trait_rejects_cr() {
    let bad_name = "Bad\rName";
    let mut map = HeaderMap::new();
    map.insert(bad_name, "value");
}
1193
/// `HeaderMap::insert` panics when the header name contains an LF.
#[test]
#[should_panic(expected = "invalid header name")]
fn header_name_into_trait_rejects_lf() {
    let bad_name = "Bad\nName";
    let mut map = HeaderMap::new();
    map.insert(bad_name, "value");
}
1200
/// `HeaderMap::insert` panics when the header name contains a space.
#[test]
#[should_panic(expected = "invalid header name")]
fn header_name_into_trait_rejects_space() {
    let bad_name = "Bad Name";
    let mut map = HeaderMap::new();
    map.insert(bad_name, "value");
}
1207
/// `HeaderMap::insert` panics when the header name contains a colon.
#[test]
#[should_panic(expected = "invalid header name")]
fn header_name_into_trait_rejects_colon() {
    let bad_name = "Bad:Name";
    let mut map = HeaderMap::new();
    map.insert(bad_name, "value");
}
1214
/// Converting an owned `String` containing a NUL byte into a `HeaderName` panics.
#[test]
#[should_panic(expected = "invalid header name")]
fn header_name_from_string_rejects_control_chars() {
    let _ = String::from("Bad\x00Name").into_header_name();
}
1221
/// Values free of CR/LF convert without panicking. This covers punctuation, the empty
/// string, and owned `String`s.
#[test]
fn valid_header_values_still_work() {
    for &raw in ["normal value", "value with special chars !@#$%^&*()", ""].iter() {
        let _: HeaderValue = raw.into();
    }
    let _: HeaderValue = String::from("string value").into();
}
1229
/// `TryFrom<&str>` accepts a well-formed name and returns `Err` for names containing a
/// space or a colon, rather than panicking.
#[test]
fn header_name_try_from_str() {
    let accepted = HeaderName::try_from("X-Custom").unwrap();
    assert_eq!(accepted.as_str(), "X-Custom");

    for bad in ["Bad Name", "Bad:Name"].iter() {
        assert!(HeaderName::try_from(*bad).is_err());
    }
}
1239
/// `TryFrom<String>` accepts a well-formed name and returns `Err` for a name
/// containing an LF.
#[test]
fn header_name_try_from_string() {
    let accepted = HeaderName::try_from(String::from("Nats-Stream")).unwrap();
    assert_eq!(accepted.as_str(), "Nats-Stream");

    let rejected = HeaderName::try_from(String::from("Bad\nName"));
    assert!(rejected.is_err());
}
1250
/// Typical header names, including one with a leading `$`, can be inserted without
/// panicking. Each one can then be read back with its value intact.
#[test]
fn valid_header_names_still_work() {
    let names = ["X-Custom-Header", "Another-Header", "$dollar", "Nats-Stream"];

    let mut headers = HeaderMap::new();
    for name in names.iter() {
        headers.insert(*name, "value");
    }

    // Check every inserted name, not just one. Otherwise a name dropped or mangled on
    // insert (e.g. the `$`-prefixed one) would go unnoticed.
    for name in names.iter() {
        assert_eq!(
            headers.get(*name).unwrap().as_str(),
            "value",
            "header {:?} was not stored",
            name
        );
    }
    assert_eq!(headers.len(), names.len());
}
1260
/// A map holding many distinct headers round-trips through JSON with every name/value
/// pair intact.
#[test]
fn header_map_large_headers() {
    const COUNT: usize = 100;

    let mut headers = HeaderMap::new();
    for idx in 0..COUNT {
        let name = format!("Header-{}", idx);
        let value = format!("Value-{}", idx);
        headers.insert(name, value);
    }

    let json = serde_json::to_string(&headers).unwrap();
    let restored: HeaderMap = serde_json::from_str(&json).unwrap();

    assert_eq!(headers.len(), restored.len());
    for idx in 0..COUNT {
        let name = format!("Header-{}", idx);
        let stored = restored.get(name.as_str()).unwrap();
        assert_eq!(stored.as_str(), format!("Value-{}", idx));
    }
}
1283}