Skip to main content

fast_hex_lite/
decode.rs

1//! Scalar hex decoder.
2
3use crate::Error;
4
5/// Returns the number of bytes produced from a hex string of `hex_len` bytes.
6///
7/// Returns [`Error::OddLength`] if `hex_len` is odd.
8///
9/// # Examples
10/// ```
11/// assert_eq!(fast_hex_lite::decoded_len(8).unwrap(), 4);
12/// ```
13#[inline]
14pub fn decoded_len(hex_len: usize) -> Result<usize, Error> {
15    if hex_len.is_multiple_of(2) {
16        Ok(hex_len / 2)
17    } else {
18        Err(Error::OddLength)
19    }
20}
21
22/// Decode ASCII-hex bytes `src_hex` into `dst`.
23///
24/// `src_hex` must contain an even number of bytes, all valid hex characters
25/// (`0-9`, `a-f`, `A-F`). `dst` must be at least `src_hex.len() / 2` bytes.
26///
27/// Returns the number of bytes written.
28#[inline]
29pub fn decode_to_slice(src_hex: &[u8], dst: &mut [u8]) -> Result<usize, Error> {
30    let out_len = decoded_len(src_hex.len())?;
31    if dst.len() < out_len {
32        return Err(Error::OutputTooSmall);
33    }
34    #[cfg(feature = "simd")]
35    {
36        crate::simd::decode_to_slice_simd(src_hex, &mut dst[..out_len])
37    }
38    #[cfg(not(feature = "simd"))]
39    {
40        decode_scalar(src_hex, &mut dst[..out_len])
41    }
42}
43
44/// Decode exactly `N` bytes from a hex string of length `2*N`.
45///
46/// Returns [`Error::OutputTooSmall`] if `src_hex.len() / 2 != N`.
47pub fn decode_to_array<const N: usize>(src_hex: &[u8]) -> Result<[u8; N], Error> {
48    let out_len = decoded_len(src_hex.len())?;
49    if out_len != N {
50        return Err(Error::OutputTooSmall);
51    }
52    let mut arr = [0u8; N];
53    decode_to_slice(src_hex, &mut arr)?;
54    Ok(arr)
55}
56
57/// Decode hex bytes in-place: `buf` initially contains ASCII hex; after
58/// decoding, the first `buf.len() / 2` bytes hold the result.
59///
60/// Returns the number of bytes written.
61#[inline]
62pub fn decode_in_place(buf: &mut [u8]) -> Result<usize, Error> {
63    let out_len = decoded_len(buf.len())?;
64
65    // Pass 1: validate without writing, so on error the buffer is unchanged.
66    // Also lets us keep the fast decode loop branch-free.
67    for i in 0..out_len {
68        let hi = buf[2 * i];
69        let lo = buf[2 * i + 1];
70
71        if unhex_byte(hi).is_none() {
72            return Err(Error::InvalidByte {
73                index: 2 * i,
74                byte: hi,
75            });
76        }
77        if unhex_byte(lo).is_none() {
78            return Err(Error::InvalidByte {
79                index: 2 * i + 1,
80                byte: lo,
81            });
82        }
83    }
84
85    // Pass 2: decode. Safe to write now.
86    for i in 0..out_len {
87        let hi = buf[2 * i];
88        let lo = buf[2 * i + 1];
89        // Validation above guarantees `decode_pair` returns 0x00..=0xFF.
90        buf[i] = u8::try_from(decode_pair(hi, lo)).unwrap();
91    }
92
93    Ok(out_len)
94}
95
96// ── Scalar decoder ─────────────────────────────────────────────────────────
97
98#[inline]
99pub(crate) fn decode_scalar(src_hex: &[u8], dst: &mut [u8]) -> Result<usize, Error> {
100    // `src_hex` is already even-length checked by the caller.
101    // `dst` is already sized-checked by the caller.
102    let out_len = src_hex.len() >> 1;
103
104    // Hot loop: single 16-bit table lookup per output byte.
105    // Use a tight index-based loop so LLVM can eliminate bounds checks.
106    let mut j = 0usize;
107    for out in dst.iter_mut().take(out_len) {
108        let hi = src_hex[j];
109        let lo = src_hex[j + 1];
110
111        let v = decode_pair(hi, lo);
112        if (v & 0x0100) != 0 {
113            // Slow-path only on error: identify which byte is invalid so we
114            // can report the correct index/byte.
115            if unhex_byte(hi).is_none() {
116                return Err(Error::InvalidByte { index: j, byte: hi });
117            }
118            return Err(Error::InvalidByte {
119                index: j + 1,
120                byte: lo,
121            });
122        }
123
124        // `decode_pair` returns 0x00..=0xFF for valid pairs.
125        *out = u8::try_from(v).unwrap();
126        j += 2;
127    }
128
129    Ok(out_len)
130}
131
132/// Map a single ASCII hex digit to its nibble value (0..=15).
133/// Returns `None` for non-hex bytes.
134///
135/// Fast table lookup.
136#[inline]
137pub(crate) fn unhex_byte(b: u8) -> Option<u8> {
138    let v = UNHEX_TABLE[b as usize];
139    if v == 0xFF {
140        None
141    } else {
142        Some(v)
143    }
144}
145
146// 256-entry nibble table (0..=15) or 0xFF for invalid.
147const UNHEX_TABLE: [u8; 256] = make_unhex_table();
148
149// 65_536-entry pair table. Each entry encodes either:
150// - valid: 0x0000..=0x00FF (decoded byte)
151// - invalid: 0x0100 (flag set)
152//
153// This lets the scalar decoder process 2 input bytes per iteration with a
154// single table lookup.
155static HEXPAIR_TABLE: [u16; 65536] = make_hexpair_table();
156
157#[inline]
158fn decode_pair(hi: u8, lo: u8) -> u16 {
159    // Index is the two ASCII bytes.
160    let idx = ((hi as usize) << 8) | (lo as usize);
161    HEXPAIR_TABLE[idx]
162}
163
/// Builds the 256-entry nibble table at compile time.
///
/// Entry `b` holds the value (`0..=15`) of ASCII hex digit `b`
/// (`0-9`, `a-f`, `A-F`), or `0xFF` when `b` is not a hex digit.
const fn make_unhex_table() -> [u8; 256] {
    let mut table = [0xFFu8; 256];

    // The counter stays a `u8` so no narrowing cast is ever needed;
    // `checked_add` ends the loop after 0xFF has been handled.
    let mut byte = u8::MIN;
    loop {
        table[byte as usize] = match byte {
            b'0'..=b'9' => byte - b'0',
            b'a'..=b'f' => byte - b'a' + 10,
            b'A'..=b'F' => byte - b'A' + 10,
            _ => 0xFF,
        };

        match byte.checked_add(1) {
            Some(next) => byte = next,
            None => break,
        }
    }

    table
}
188
189#[allow(clippy::large_stack_arrays)]
190const fn make_hexpair_table() -> [u16; 65536] {
191    let mut t = [0x0100u16; 65536];
192    let unhex = make_unhex_table();
193
194    let mut hi = 0u32;
195    while hi < 256 {
196        let mut lo = 0u32;
197        while lo < 256 {
198            let hn = unhex[hi as usize];
199            let ln = unhex[lo as usize];
200            if hn != 0xFF && ln != 0xFF {
201                let out = ((hn as u16) << 4) | (ln as u16);
202                t[((hi as usize) << 8) | (lo as usize)] = out;
203            }
204            lo += 1;
205        }
206        hi += 1;
207    }
208
209    t
210}
211
212#[cfg(test)]
213#[path = "decode/tests.rs"]
214mod tests;