Skip to main content

byte_wrapper/
hex_array.rs

1// Copyright (c) The byte-wrapper Contributors
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! The [`HexArray`] newtype wrapper.
5
6use core::{
7    array::TryFromSliceError,
8    error,
9    fmt::{self, Write},
10    str::FromStr,
11};
12
/// A fixed-size byte array whose text form is hexadecimal.
///
/// `HexArray<N>` is a thin newtype over `[u8; N]`. Its
/// [`Display`](fmt::Display) and [`LowerHex`](fmt::LowerHex) output is
/// lower-case hex, its [`UpperHex`](fmt::UpperHex) output is upper-case
/// hex, and [`FromStr`] parses hex back into bytes.
///
/// When the **`serde`** feature is enabled, the type also implements
/// `Serialize` and `Deserialize`: human-readable formats see a hex
/// string, binary formats see raw bytes. Plain `[u8; N]` fields can opt
/// into the same encoding with `#[serde(with = "HexArray::<N>")]`.
///
/// # Examples
///
/// ```
/// use byte_wrapper::HexArray;
///
/// let bytes = HexArray::new([0x01, 0x02, 0xab, 0xff]);
/// assert_eq!(bytes.to_string(), "0102abff");
///
/// let round_tripped: HexArray<4> = "0102abff".parse().unwrap();
/// assert_eq!(round_tripped, bytes);
/// ```
#[must_use]
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct HexArray<const N: usize>(pub [u8; N]);
39
40impl<const N: usize> Default for HexArray<N> {
41    fn default() -> Self {
42        Self([0u8; N])
43    }
44}
45
46impl<const N: usize> HexArray<N> {
47    /// Creates a new `HexArray` from a byte array.
48    #[inline]
49    pub const fn new(bytes: [u8; N]) -> Self {
50        Self(bytes)
51    }
52
53    /// Returns the inner byte array.
54    #[inline]
55    #[must_use]
56    pub const fn into_inner(self) -> [u8; N] {
57        self.0
58    }
59}
60
/// Borrowed view of a byte slice that formats as lower-case hex.
///
/// `Display` writes two zero-padded digits per byte. `Debug` produces
/// exactly the same text, with no quotes or brackets. In this file the
/// `Display` form backs serialization, and the `Debug` form is the inner
/// value of `HexArray`'s `Debug` output.
struct HexDisplay<'a>(&'a [u8]);

impl fmt::Display for HexDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each `write!` carries its own `02x` spec, so the outer
        // formatter's width/fill settings are not applied here.
        self.0.iter().try_for_each(|byte| write!(f, "{byte:02x}"))
    }
}

impl fmt::Debug for HexDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
81
/// Error returned by [`HexArray::from_str`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseHexError {
    /// The input string had the wrong length.
    ///
    /// Lengths are measured in bytes (`str::len`). Bytes equal
    /// characters only for ASCII input.
    InvalidLength {
        /// Expected length of the input in bytes (two hex digits per
        /// array byte).
        expected: usize,
        /// Actual length of the input in bytes.
        actual: usize,
    },
    /// The input contained an invalid hex character.
    InvalidHexCharacter {
        /// The invalid character.
        c: char,
        /// Byte index of the invalid character in the input.
        index: usize,
    },
}
100
101impl fmt::Display for ParseHexError {
102    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103        match self {
104            ParseHexError::InvalidLength { expected, actual } => {
105                write!(
106                    f,
107                    "expected {} hex characters, got {}",
108                    expected, actual,
109                )
110            }
111            ParseHexError::InvalidHexCharacter { c, index } => {
112                write!(f, "invalid hex character '{}' at index {}", c, index,)
113            }
114        }
115    }
116}
117
118impl error::Error for ParseHexError {}
119
120impl<const N: usize> FromStr for HexArray<N> {
121    type Err = ParseHexError;
122
123    fn from_str(s: &str) -> Result<Self, Self::Err> {
124        let expected = N * 2;
125        if s.len() != expected {
126            return Err(ParseHexError::InvalidLength {
127                expected,
128                actual: s.len(),
129            });
130        }
131        let mut out = [0u8; N];
132        hex::decode_to_slice(s, &mut out).map_err(|e| {
133            match e {
134                hex::FromHexError::InvalidHexCharacter { c, index } => {
135                    ParseHexError::InvalidHexCharacter { c, index }
136                }
137                // The length is already validated above, so this
138                // branch is unreachable in practice.
139                hex::FromHexError::OddLength
140                | hex::FromHexError::InvalidStringLength => {
141                    ParseHexError::InvalidLength { expected, actual: s.len() }
142                }
143            }
144        })?;
145        Ok(Self(out))
146    }
147}
148
149impl<const N: usize> fmt::Debug for HexArray<N> {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        f.debug_tuple("HexArray").field(&HexDisplay(&self.0)).finish()
152    }
153}
154
/// Writes `bytes` as hex, honouring the formatter's width, fill and
/// alignment.
///
/// `write_byte` emits the two digits for a single byte, which lets
/// callers choose between lower-case and upper-case hex. Padding is added
/// only when the requested width exceeds the `2 * bytes.len()` digits
/// produced. With no explicit alignment, the output is right-aligned.
/// Precision and the `+`, `#` and `0` flags are not consulted.
///
/// Takes a slice rather than `&[u8; N]`, so one copy of this function
/// serves every `HexArray<N>` instead of being monomorphized per `N`.
/// Callers can still pass `&[u8; N]`, which coerces to `&[u8]`.
fn fmt_hex_padded(
    bytes: &[u8],
    f: &mut fmt::Formatter<'_>,
    write_byte: fn(&mut fmt::Formatter<'_>, u8) -> fmt::Result,
) -> fmt::Result {
    // Each byte becomes exactly two ASCII digits, so bytes == chars here.
    let content_len = bytes.len() * 2;
    // Zero when no width was requested or the content already fills it.
    let padding = f
        .width()
        .map_or(0, |width| width.saturating_sub(content_len));

    let (pre, post) = match f.align() {
        Some(fmt::Alignment::Left) => (0, padding),
        Some(fmt::Alignment::Center) => (padding / 2, padding - padding / 2),
        Some(fmt::Alignment::Right) | None => (padding, 0),
    };

    let fill = f.fill();
    for _ in 0..pre {
        f.write_char(fill)?;
    }
    for &byte in bytes {
        write_byte(f, byte)?;
    }
    for _ in 0..post {
        f.write_char(fill)?;
    }
    Ok(())
}
196
197impl<const N: usize> fmt::Display for HexArray<N> {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        fmt_hex_padded(&self.0, f, |f, b| write!(f, "{b:02x}"))
200    }
201}
202
203impl<const N: usize> fmt::LowerHex for HexArray<N> {
204    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205        fmt_hex_padded(&self.0, f, |f, b| write!(f, "{b:02x}"))
206    }
207}
208
209impl<const N: usize> fmt::UpperHex for HexArray<N> {
210    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211        fmt_hex_padded(&self.0, f, |f, b| write!(f, "{b:02X}"))
212    }
213}
214
215impl<const N: usize> core::ops::Deref for HexArray<N> {
216    type Target = [u8; N];
217
218    #[inline]
219    fn deref(&self) -> &Self::Target {
220        &self.0
221    }
222}
223
224impl<const N: usize> core::ops::DerefMut for HexArray<N> {
225    #[inline]
226    fn deref_mut(&mut self) -> &mut Self::Target {
227        &mut self.0
228    }
229}
230
231impl<const N: usize> AsRef<[u8]> for HexArray<N> {
232    #[inline]
233    fn as_ref(&self) -> &[u8] {
234        &self.0
235    }
236}
237
238impl<const N: usize> AsMut<[u8]> for HexArray<N> {
239    #[inline]
240    fn as_mut(&mut self) -> &mut [u8] {
241        &mut self.0
242    }
243}
244
245impl<const N: usize> From<[u8; N]> for HexArray<N> {
246    #[inline]
247    fn from(bytes: [u8; N]) -> Self {
248        Self(bytes)
249    }
250}
251
252impl<const N: usize> From<HexArray<N>> for [u8; N] {
253    #[inline]
254    fn from(hex_array: HexArray<N>) -> Self {
255        hex_array.0
256    }
257}
258
259impl<const N: usize> TryFrom<&[u8]> for HexArray<N> {
260    type Error = TryFromSliceError;
261
262    #[inline]
263    fn try_from(slice: &[u8]) -> Result<Self, Self::Error> {
264        <[u8; N]>::try_from(slice).map(Self)
265    }
266}
267
268#[cfg(feature = "serde")]
269mod serde_impls {
270    use super::{HexArray, HexDisplay};
271    use core::fmt;
272    use serde_core::{
273        Deserializer,
274        de::{Expected, SeqAccess, Visitor},
275    };
276
277    /// Serializes a byte slice as lower-case hex if human-readable,
278    /// or as raw bytes if not.
279    fn serialize_lower<S>(
280        bytes: &[u8],
281        serializer: S,
282    ) -> Result<S::Ok, S::Error>
283    where
284        S: serde_core::Serializer,
285    {
286        if serializer.is_human_readable() {
287            serializer.collect_str(&HexDisplay(bytes))
288        } else {
289            serializer.serialize_bytes(bytes)
290        }
291    }
292
293    struct HexExpected<const N: usize>;
294
295    impl<const N: usize> Expected for HexExpected<N> {
296        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
297            write!(f, "a byte array [u8; {}]", N)
298        }
299    }
300
301    struct HexStrExpected<const N: usize>;
302
303    impl<const N: usize> Expected for HexStrExpected<N> {
304        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
305            write!(f, "a hex string {} hex digits long", N * 2)
306        }
307    }
308
309    /// Deserializes hex strings (if human-readable) or byte arrays
310    /// (if not) to `[u8; N]`.
311    fn deserialize<'de, D, const N: usize>(
312        deserializer: D,
313    ) -> Result<[u8; N], D::Error>
314    where
315        D: Deserializer<'de>,
316    {
317        use serde_core::de::Error;
318
319        if deserializer.is_human_readable() {
320            // hex::FromHex doesn't have an implementation for
321            // const-generic N, so do our own thing.
322            struct HexVisitor<const N: usize>;
323
324            impl<'de2, const N: usize> Visitor<'de2> for HexVisitor<N> {
325                type Value = [u8; N];
326
327                fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
328                    write!(f, "a hex string {} hex digits long", N * 2)
329                }
330
331                fn visit_str<E>(self, data: &str) -> Result<Self::Value, E>
332                where
333                    E: Error,
334                {
335                    let expected_len = N * 2;
336                    if data.len() != expected_len {
337                        return Err(E::invalid_length(
338                            data.len(),
339                            &HexStrExpected::<N>,
340                        ));
341                    }
342                    let mut out = [0u8; N];
343                    hex::decode_to_slice(data, &mut out)
344                        .map_err(Error::custom)?;
345                    Ok(out)
346                }
347            }
348
349            deserializer.deserialize_str(HexVisitor)
350        } else {
351            struct BytesVisitor<const N: usize>;
352
353            impl<'de2, const N: usize> Visitor<'de2> for BytesVisitor<N> {
354                type Value = [u8; N];
355
356                fn expecting(
357                    &self,
358                    formatter: &mut fmt::Formatter,
359                ) -> fmt::Result {
360                    write!(formatter, "a byte array [u8; {}]", N)
361                }
362
363                fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
364                where
365                    E: Error,
366                {
367                    v.try_into().map_err(|_| {
368                        E::invalid_length(v.len(), &HexExpected::<N>)
369                    })
370                }
371
372                fn visit_seq<A>(
373                    self,
374                    mut seq: A,
375                ) -> Result<Self::Value, A::Error>
376                where
377                    A: SeqAccess<'de2>,
378                {
379                    // Reject early if the sequence reports a
380                    // wrong length.
381                    if let Some(len) = seq.size_hint()
382                        && len != N
383                    {
384                        return Err(Error::invalid_length(
385                            len,
386                            &HexExpected::<N>,
387                        ));
388                    }
389                    let mut out = [0u8; N];
390                    for (i, byte) in out.iter_mut().enumerate() {
391                        *byte = seq.next_element()?.ok_or_else(|| {
392                            Error::invalid_length(i, &HexExpected::<N>)
393                        })?;
394                    }
395                    // Reject trailing elements rather than
396                    // silently discarding them.
397                    if seq.next_element::<u8>()?.is_some() {
398                        // We don't know the actual length, but
399                        // we know it's more than N.
400                        return Err(Error::invalid_length(
401                            N + 1,
402                            &HexExpected::<N>,
403                        ));
404                    }
405                    Ok(out)
406                }
407            }
408
409            deserializer.deserialize_bytes(BytesVisitor)
410        }
411    }
412
413    impl<const N: usize> HexArray<N> {
414        /// Serializes a byte array as hex in human-readable formats,
415        /// or as raw bytes otherwise.
416        ///
417        /// Intended for use with
418        /// `#[serde(with = "HexArray::<N>")]`.
419        ///
420        /// # Examples
421        ///
422        /// ```
423        /// use byte_wrapper::HexArray;
424        /// use serde::{Deserialize, Serialize};
425        ///
426        /// #[derive(Serialize, Deserialize)]
427        /// struct Record {
428        ///     #[serde(with = "HexArray::<4>")]
429        ///     id: [u8; 4],
430        /// }
431        ///
432        /// let r = Record { id: [0x01, 0x02, 0x03, 0x04] };
433        /// let json = serde_json::to_string(&r).unwrap();
434        /// assert_eq!(json, r#"{"id":"01020304"}"#);
435        /// ```
436        #[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
437        pub fn serialize<S>(
438            bytes: &[u8; N],
439            serializer: S,
440        ) -> Result<S::Ok, S::Error>
441        where
442            S: serde_core::Serializer,
443        {
444            serialize_lower(bytes, serializer)
445        }
446
447        /// Deserializes a byte array from hex if the format is
448        /// human-readable, or as raw bytes otherwise.
449        ///
450        /// Intended for use with
451        /// `#[serde(with = "HexArray::<N>")]`.
452        ///
453        /// # Examples
454        ///
455        /// ```
456        /// use byte_wrapper::HexArray;
457        /// use serde::{Deserialize, Serialize};
458        ///
459        /// #[derive(Serialize, Deserialize)]
460        /// struct Record {
461        ///     #[serde(with = "HexArray::<4>")]
462        ///     id: [u8; 4],
463        /// }
464        ///
465        /// let r: Record = serde_json::from_str(r#"{"id":"01020304"}"#).unwrap();
466        /// assert_eq!(r.id, [0x01, 0x02, 0x03, 0x04]);
467        /// ```
468        #[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
469        pub fn deserialize<'de, D>(deserializer: D) -> Result<[u8; N], D::Error>
470        where
471            D: Deserializer<'de>,
472        {
473            deserialize(deserializer)
474        }
475    }
476
477    #[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
478    impl<const N: usize> serde_core::Serialize for HexArray<N> {
479        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
480        where
481            S: serde_core::Serializer,
482        {
483            serialize_lower(&self.0, serializer)
484        }
485    }
486
487    #[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
488    impl<'de, const N: usize> serde_core::Deserialize<'de> for HexArray<N> {
489        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
490        where
491            D: Deserializer<'de>,
492        {
493            deserialize(deserializer).map(Self)
494        }
495    }
496}
497
#[cfg(feature = "schemars08")]
mod schemars_impls {
    use super::HexArray;
    use crate::schemars_util::x_rust_type_extension;
    use alloc::{boxed::Box, format, string::String};
    use schemars08::{
        JsonSchema,
        r#gen::SchemaGenerator,
        schema::{InstanceType, Schema, SchemaObject, StringValidation},
    };

    /// Describes `HexArray<N>` as a JSON string of exactly `2 * N` hex
    /// digits. The pattern accepts either case.
    impl<const N: usize> JsonSchema for HexArray<N> {
        fn schema_name() -> String {
            format!("HexArray_{N}")
        }

        // Inline this schema at each use site rather than emitting a
        // shared `$ref` definition.
        fn is_referenceable() -> bool {
            false
        }

        fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
            let hex_len = N * 2;
            // `StringValidation` lengths are `u32`. A checked conversion
            // avoids silently truncating to a wrong bound (e.g. 0) for an
            // absurdly large `N`. If the length does not fit, the length
            // bounds are omitted and the pattern alone constrains it.
            let length = u32::try_from(hex_len).ok();
            Schema::Object(SchemaObject {
                instance_type: Some(InstanceType::String.into()),
                string: Some(Box::new(StringValidation {
                    min_length: length,
                    max_length: length,
                    pattern: Some(format!("^[0-9a-fA-F]{{{hex_len}}}$")),
                })),
                extensions: x_rust_type_extension(&format!("HexArray::<{N}>")),
                ..Default::default()
            })
        }
    }
}