hex_conservative/
iter.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! Iterator that converts hex to bytes.
4
5use core::convert::TryInto;
6use core::iter::FusedIterator;
7use core::str;
8#[cfg(feature = "std")]
9use std::io;
10
11#[cfg(feature = "alloc")]
12use crate::alloc::vec::Vec;
13use crate::error::{InvalidCharError, OddLengthStringError};
14
/// Iterator yielding bytes decoded from an iterator of pairs of hex digits.
///
/// Every `[u8; 2]` item pulled from the wrapped iterator is interpreted as a (high, low) pair of
/// hex digit characters and decoded into one byte; a pair containing a non-hex character is
/// reported as an `InvalidCharError` instead.
#[derive(Debug)]
pub struct HexToBytesIter<I>
where
    I: Iterator<Item = [u8; 2]>,
{
    // Source of the digit pairs; one pair decodes to one byte.
    iter: I,
    // Initialized from `iter.len()` in `from_pairs`. Together with the number of pairs still left
    // in `iter` it is the reference point from which the character position stored in an
    // `InvalidCharError` is computed.
    original_len: usize,
}
24
25impl<'a> HexToBytesIter<HexDigitsIter<'a>> {
26    /// Constructs a new `HexToBytesIter` from a string slice.
27    ///
28    /// # Errors
29    ///
30    /// If the input string is of odd length.
31    #[inline]
32    #[allow(dead_code)] // Remove this when making HexToBytesIter public.
33    pub(crate) fn new(s: &'a str) -> Result<Self, OddLengthStringError> {
34        if s.len() % 2 != 0 {
35            Err(OddLengthStringError { len: s.len() })
36        } else {
37            Ok(Self::new_unchecked(s))
38        }
39    }
40
41    #[inline]
42    pub(crate) fn new_unchecked(s: &'a str) -> Self {
43        Self::from_pairs(HexDigitsIter::new_unchecked(s.as_bytes()))
44    }
45
46    /// Writes all the bytes yielded by this `HexToBytesIter` to the provided slice.
47    ///
48    /// Stops writing if this `HexToBytesIter` yields an `InvalidCharError`.
49    ///
50    /// # Panics
51    ///
52    /// Panics if the length of this `HexToBytesIter` is not equal to the length of the provided
53    /// slice.
54    pub(crate) fn drain_to_slice(self, buf: &mut [u8]) -> Result<(), InvalidCharError> {
55        assert_eq!(self.len(), buf.len());
56        let mut ptr = buf.as_mut_ptr();
57        for byte in self {
58            // SAFETY: for loop iterates `len` times, and `buf` has length `len`
59            unsafe {
60                core::ptr::write(ptr, byte?);
61                ptr = ptr.add(1);
62            }
63        }
64        Ok(())
65    }
66
67    /// Writes all the bytes yielded by this `HexToBytesIter` to a `Vec<u8>`.
68    ///
69    /// This is equivalent to the combinator chain `iter().map().collect()` but was found by
70    /// benchmarking to be faster.
71    #[cfg(feature = "alloc")]
72    pub(crate) fn drain_to_vec(self) -> Result<Vec<u8>, InvalidCharError> {
73        let len = self.len();
74        let mut ret = Vec::with_capacity(len);
75        let mut ptr = ret.as_mut_ptr();
76        for byte in self {
77            // SAFETY: for loop iterates `len` times, and `ret` has a capacity of at least `len`
78            unsafe {
79                // docs: "`core::ptr::write` is appropriate for initializing uninitialized memory"
80                core::ptr::write(ptr, byte?);
81                ptr = ptr.add(1);
82            }
83        }
84        // SAFETY: `len` elements have been initialized, and `ret` has a capacity of at least `len`
85        unsafe {
86            ret.set_len(len);
87        }
88        Ok(ret)
89    }
90}
91
92impl<I> HexToBytesIter<I>
93where
94    I: Iterator<Item = [u8; 2]> + ExactSizeIterator,
95{
96    /// Constructs a custom hex decoding iterator from another iterator.
97    #[inline]
98    pub fn from_pairs(iter: I) -> Self { Self { original_len: iter.len(), iter } }
99}
100
101impl<I> Iterator for HexToBytesIter<I>
102where
103    I: Iterator<Item = [u8; 2]> + ExactSizeIterator,
104{
105    type Item = Result<u8, InvalidCharError>;
106
107    #[inline]
108    fn next(&mut self) -> Option<Self::Item> {
109        let [hi, lo] = self.iter.next()?;
110        Some(hex_chars_to_byte(hi, lo).map_err(|(c, is_high)| InvalidCharError {
111            invalid: c,
112            pos: if is_high {
113                (self.original_len - self.iter.len() - 1) * 2
114            } else {
115                (self.original_len - self.iter.len() - 1) * 2 + 1
116            },
117        }))
118    }
119
120    #[inline]
121    fn size_hint(&self) -> (usize, Option<usize>) { self.iter.size_hint() }
122
123    #[inline]
124    fn nth(&mut self, n: usize) -> Option<Self::Item> {
125        let [hi, lo] = self.iter.nth(n)?;
126        Some(hex_chars_to_byte(hi, lo).map_err(|(c, is_high)| InvalidCharError {
127            invalid: c,
128            pos: if is_high {
129                (self.original_len - self.iter.len() - 1) * 2
130            } else {
131                (self.original_len - self.iter.len() - 1) * 2 + 1
132            },
133        }))
134    }
135}
136
137impl<I> DoubleEndedIterator for HexToBytesIter<I>
138where
139    I: Iterator<Item = [u8; 2]> + DoubleEndedIterator + ExactSizeIterator,
140{
141    #[inline]
142    fn next_back(&mut self) -> Option<Self::Item> {
143        let [hi, lo] = self.iter.next_back()?;
144        Some(hex_chars_to_byte(hi, lo).map_err(|(c, is_high)| InvalidCharError {
145            invalid: c,
146            pos: if is_high { self.iter.len() * 2 } else { self.iter.len() * 2 + 1 },
147        }))
148    }
149
150    #[inline]
151    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
152        let [hi, lo] = self.iter.nth_back(n)?;
153        Some(hex_chars_to_byte(hi, lo).map_err(|(c, is_high)| InvalidCharError {
154            invalid: c,
155            pos: if is_high { self.iter.len() * 2 } else { self.iter.len() * 2 + 1 },
156        }))
157    }
158}
159
// One byte is yielded per remaining pair and `size_hint` forwards to the inner
// `ExactSizeIterator`, so the default `len` implementation is used as is.
impl<I> ExactSizeIterator for HexToBytesIter<I> where I: Iterator<Item = [u8; 2]> + ExactSizeIterator
{}
162
// `next` returns `None` exactly when the inner iterator does, so the decoder is fused whenever
// the inner iterator is.
impl<I> FusedIterator for HexToBytesIter<I> where
    I: Iterator<Item = [u8; 2]> + ExactSizeIterator + FusedIterator
{
}
167
#[cfg(feature = "std")]
impl<I> io::Read for HexToBytesIter<I>
where
    I: Iterator<Item = [u8; 2]> + ExactSizeIterator + FusedIterator,
{
    /// Fills `buf` with decoded bytes until either `buf` is full or the input is exhausted.
    ///
    /// Returns the number of bytes stored in `buf`.
    ///
    /// # Errors
    ///
    /// An invalid hex character is reported as an `io::Error` of kind `InvalidData` wrapping the
    /// `InvalidCharError`.
    ///
    /// NOTE(review): bytes decoded before the invalid pair have already been written to `buf`
    /// and consumed from the iterator when the error is returned, and their count is not
    /// reported - confirm this is acceptable for `io::Read` callers.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut filled = 0usize;
        for slot in buf.iter_mut() {
            let byte = match self.next() {
                // Input exhausted: report what was decoded so far.
                None => break,
                Some(decoded) => {
                    decoded.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
                }
            };
            *slot = byte;
            filled += 1;
        }
        Ok(filled)
    }
}
189
/// An internal iterator returning hex digits from a string.
///
/// The input is handed out two bytes at a time as `[u8; 2]` items; the bytes are passed through
/// unchanged and are not checked for being valid hex digits here.
///
/// Generally you shouldn't need to refer to this type directly: construct the decoder with
/// [`HexToBytesIter::new`] and consume the returned value, and use `HexSliceToBytesIter` if you
/// need to name the iterator in your own types.
#[derive(Debug)]
pub struct HexDigitsIter<'a> {
    // Invariant: the length of the chunks is 2.
    // Technically, this is `iter::Map` but we can't use it because fn is anonymous.
    // We can swap this for actual `ArrayChunks` once it's stable.
    iter: core::slice::ChunksExact<'a, u8>,
}
202
203impl<'a> HexDigitsIter<'a> {
204    #[inline]
205    fn new_unchecked(digits: &'a [u8]) -> Self { Self { iter: digits.chunks_exact(2) } }
206}
207
208impl Iterator for HexDigitsIter<'_> {
209    type Item = [u8; 2];
210
211    #[inline]
212    fn next(&mut self) -> Option<Self::Item> {
213        self.iter.next().map(|digits| digits.try_into().expect("HexDigitsIter invariant"))
214    }
215
216    #[inline]
217    fn size_hint(&self) -> (usize, Option<usize>) { self.iter.size_hint() }
218
219    #[inline]
220    fn nth(&mut self, n: usize) -> Option<Self::Item> {
221        self.iter.nth(n).map(|digits| digits.try_into().expect("HexDigitsIter invariant"))
222    }
223}
224
225impl DoubleEndedIterator for HexDigitsIter<'_> {
226    #[inline]
227    fn next_back(&mut self) -> Option<Self::Item> {
228        self.iter.next_back().map(|digits| digits.try_into().expect("HexDigitsIter invariant"))
229    }
230
231    #[inline]
232    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
233        self.iter.nth_back(n).map(|digits| digits.try_into().expect("HexDigitsIter invariant"))
234    }
235}
236
// `size_hint` forwards to the underlying `ChunksExact`, so the default `len` applies unchanged.
impl ExactSizeIterator for HexDigitsIter<'_> {}
238
// Every method simply maps the result of the underlying `ChunksExact`, so `None` is returned
// whenever that iterator returns `None`.
impl core::iter::FusedIterator for HexDigitsIter<'_> {}
240
/// Combines two hex digit characters into the byte they encode.
///
/// `hi` is the character of the most significant nibble, `lo` that of the least significant one.
/// Upper- and lowercase digits are both accepted.
///
/// # Errors
///
/// Returns the offending input byte together with a flag that is `true` when it was `hi` and
/// `false` when it was `lo`. `hi` is checked first, so it is the one reported when both are
/// invalid.
fn hex_chars_to_byte(hi: u8, lo: u8) -> Result<u8, (u8, bool)> {
    let high_nibble = match char::from(hi).to_digit(16) {
        Some(value) => value,
        None => return Err((hi, true)),
    };
    let low_nibble = match char::from(lo).to_digit(16) {
        Some(value) => value,
        None => return Err((lo, false)),
    };

    // Both nibbles are below 16, so the combined value always fits into a `u8`.
    Ok(((high_nibble << 4) | low_nibble) as u8)
}
251
#[cfg(test)]
mod tests {
    use super::*;

    // Forward iteration decodes the bytes in order and `len` counts down by one per item.
    #[test]
    fn decode_iter_forward() {
        let hex = "deadbeef";
        let bytes = [0xde, 0xad, 0xbe, 0xef];

        for (i, b) in HexToBytesIter::new(hex).unwrap().enumerate() {
            assert_eq!(b.unwrap(), bytes[i]);
        }

        let mut iter = HexToBytesIter::new(hex).unwrap();
        for i in (0..=bytes.len()).rev() {
            assert_eq!(iter.len(), i);
            let _ = iter.next();
        }
    }

    // Reverse iteration decodes the bytes last-to-first and `len` counts down by one per item.
    #[test]
    fn decode_iter_backward() {
        let hex = "deadbeef";
        let bytes = [0xef, 0xbe, 0xad, 0xde];

        for (i, b) in HexToBytesIter::new(hex).unwrap().rev().enumerate() {
            assert_eq!(b.unwrap(), bytes[i]);
        }

        let mut iter = HexToBytesIter::new(hex).unwrap().rev();
        for i in (0..=bytes.len()).rev() {
            assert_eq!(iter.len(), i);
            let _ = iter.next();
        }
    }

    #[test]
    fn hex_to_digits_size_hint() {
        let hex = "deadbeef";
        let iter = HexDigitsIter::new_unchecked(hex.as_bytes());
        // HexDigitsIter yields two digits at a time `[u8; 2]`.
        assert_eq!(iter.size_hint(), (4, Some(4)));
    }

    // The decoder yields one byte per pair of digits, so eight digits give a hint of four.
    #[test]
    fn hex_to_bytes_size_hint() {
        let hex = "deadbeef";
        let iter = HexToBytesIter::new_unchecked(hex);
        assert_eq!(iter.size_hint(), (4, Some(4)));
    }

    // Draining into a slice of matching length works for both non-empty and empty input.
    #[test]
    fn hex_to_bytes_slice_drain() {
        let hex = "deadbeef";
        let want = [0xde, 0xad, 0xbe, 0xef];
        let iter = HexToBytesIter::new_unchecked(hex);
        let mut got = [0u8; 4];
        iter.drain_to_slice(&mut got).unwrap();
        assert_eq!(got, want);

        let hex = "";
        let want = [];
        let iter = HexToBytesIter::new_unchecked(hex);
        let mut got = [];
        iter.drain_to_slice(&mut got).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    #[should_panic]
    // Don't test the panic message because it comes from the `assert_eq!` in `drain_to_slice`.
    #[allow(clippy::should_panic_without_expect)]
    fn hex_to_bytes_slice_drain_panic_empty() {
        let hex = "deadbeef";
        let iter = HexToBytesIter::new_unchecked(hex);
        let mut got = [];
        iter.drain_to_slice(&mut got).unwrap();
    }

    #[test]
    #[should_panic]
    // Don't test the panic message because it comes from the `assert_eq!` in `drain_to_slice`.
    #[allow(clippy::should_panic_without_expect)]
    fn hex_to_bytes_slice_drain_panic_too_small() {
        let hex = "deadbeef";
        let iter = HexToBytesIter::new_unchecked(hex);
        let mut got = [0u8; 3];
        iter.drain_to_slice(&mut got).unwrap();
    }

    #[test]
    #[should_panic]
    // Don't test the panic message because it comes from the `assert_eq!` in `drain_to_slice`.
    #[allow(clippy::should_panic_without_expect)]
    fn hex_to_bytes_slice_drain_panic_too_big() {
        let hex = "deadbeef";
        let iter = HexToBytesIter::new_unchecked(hex);
        let mut got = [0u8; 5];
        iter.drain_to_slice(&mut got).unwrap();
    }

    // An invalid first character is reported at position 0.
    #[test]
    fn hex_to_bytes_slice_drain_first_char_error() {
        let hex = "geadbeef";
        let iter = HexToBytesIter::new_unchecked(hex);
        let mut got = [0u8; 4];
        assert_eq!(
            iter.drain_to_slice(&mut got).unwrap_err(),
            InvalidCharError { invalid: b'g', pos: 0 }
        );
    }

    // An invalid high digit in the middle is reported at its character index.
    #[test]
    fn hex_to_bytes_slice_drain_middle_char_error() {
        let hex = "deadgeef";
        let iter = HexToBytesIter::new_unchecked(hex);
        let mut got = [0u8; 4];
        assert_eq!(
            iter.drain_to_slice(&mut got).unwrap_err(),
            InvalidCharError { invalid: b'g', pos: 4 }
        );
    }

    // An invalid low digit at the very end is reported at the last character index.
    #[test]
    fn hex_to_bytes_slice_drain_end_char_error() {
        let hex = "deadbeeg";
        let iter = HexToBytesIter::new_unchecked(hex);
        let mut got = [0u8; 4];
        assert_eq!(
            iter.drain_to_slice(&mut got).unwrap_err(),
            InvalidCharError { invalid: b'g', pos: 7 }
        );
    }

    // Draining into a `Vec` works for both non-empty and empty input.
    #[cfg(feature = "alloc")]
    #[test]
    fn hex_to_bytes_vec_drain() {
        let hex = "deadbeef";
        let want = [0xde, 0xad, 0xbe, 0xef];
        let iter = HexToBytesIter::new_unchecked(hex);
        let got = iter.drain_to_vec().unwrap();
        assert_eq!(got, want);

        let hex = "";
        let iter = HexToBytesIter::new_unchecked(hex);
        let got = iter.drain_to_vec().unwrap();
        assert!(got.is_empty());
    }

    // The `Vec` variant reports the same error positions as the slice variant.
    #[cfg(feature = "alloc")]
    #[test]
    fn hex_to_bytes_vec_drain_first_char_error() {
        let hex = "geadbeef";
        let iter = HexToBytesIter::new_unchecked(hex);
        assert_eq!(iter.drain_to_vec().unwrap_err(), InvalidCharError { invalid: b'g', pos: 0 });
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn hex_to_bytes_vec_drain_middle_char_error() {
        let hex = "deadgeef";
        let iter = HexToBytesIter::new_unchecked(hex);
        assert_eq!(iter.drain_to_vec().unwrap_err(), InvalidCharError { invalid: b'g', pos: 4 });
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn hex_to_bytes_vec_drain_end_char_error() {
        let hex = "deadbeeg";
        let iter = HexToBytesIter::new_unchecked(hex);
        assert_eq!(iter.drain_to_vec().unwrap_err(), InvalidCharError { invalid: b'g', pos: 7 });
    }

    // Exercises the `io::Read` impl with a buffer that is exactly as long as, shorter than and
    // longer than the decoded data, and with input containing invalid characters.
    #[test]
    #[cfg(feature = "std")]
    fn hex_to_bytes_iter_read() {
        use std::io::Read;

        let hex = "deadbeef";
        let mut iter = HexToBytesIter::new(hex).unwrap();
        let mut buf = [0u8; 4];
        let bytes_read = iter.read(&mut buf).unwrap();
        assert_eq!(bytes_read, 4);
        assert_eq!(buf, [0xde, 0xad, 0xbe, 0xef]);

        let hex = "deadbeef";
        let mut iter = HexToBytesIter::new(hex).unwrap();
        let mut buf = [0u8; 2];
        let bytes_read = iter.read(&mut buf).unwrap();
        assert_eq!(bytes_read, 2);
        assert_eq!(buf, [0xde, 0xad]);

        let hex = "deadbeef";
        let mut iter = HexToBytesIter::new(hex).unwrap();
        let mut buf = [0u8; 6];
        let bytes_read = iter.read(&mut buf).unwrap();
        assert_eq!(bytes_read, 4);
        assert_eq!(buf[..4], [0xde, 0xad, 0xbe, 0xef]);

        let hex = "deadbeefXX";
        let mut iter = HexToBytesIter::new(hex).unwrap();
        let mut buf = [0u8; 6];
        let err = iter.read(&mut buf).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}