spideroak_crypto/
hex.rs

//! Constant-time hexadecimal encoding and decoding.
2
3use core::{fmt, result::Result, str};
4
5use subtle::{Choice, ConditionallySelectable};
6
/// Wrapper that renders the bytes of `T` as hexadecimal in constant time.
///
/// The wrapper itself places no bound on `T`; the formatting
/// implementations require `T: AsRef<[u8]>`.
#[derive(Copy, Clone)]
pub struct Hex<T>(T);

impl<T> Hex<T> {
    /// Creates a new `Hex` that wraps `value`.
    pub const fn new(value: T) -> Self {
        Self(value)
    }
}
17
18impl<T> fmt::Display for Hex<T>
19where
20    T: AsRef<[u8]>,
21{
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        fmt::LowerHex::fmt(self, f)
24    }
25}
26
27impl<T> fmt::Debug for Hex<T>
28where
29    T: AsRef<[u8]>,
30{
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        fmt::LowerHex::fmt(self, f)
33    }
34}
35
36impl<T> fmt::LowerHex for Hex<T>
37where
38    T: AsRef<[u8]>,
39{
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        ct_write_lower(f, self.0.as_ref())
42    }
43}
44
45impl<T> fmt::UpperHex for Hex<T>
46where
47    T: AsRef<[u8]>,
48{
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        ct_write_upper(f, self.0.as_ref())
51    }
52}
53
54/// Implemented by types that can encode themselves as hex in
55/// constant time.
56pub trait ToHex {
57    /// A hexadecimal string.
58    type Output: AsRef<[u8]>;
59
60    /// Encodes itself as a hexadecimal string.
61    fn to_hex(self) -> Hex<Self::Output>;
62}
63
64impl<T> ToHex for T
65where
66    T: AsRef<[u8]>,
67{
68    type Output = T;
69
70    fn to_hex(self) -> Hex<Self::Output> {
71        Hex::new(self)
72    }
73}
74
75/// Returned by [`ct_encode`] when `dst` is not twice as long as
76/// `src`.
77#[derive(Clone, Debug, thiserror::Error)]
78#[error("invalid length")]
79pub struct InvalidLength(());
80
81/// Encodes `src` into `dst` as hexadecimal in constant time and
82/// returns the number of bytes written.
83///
84/// `dst` must be at least twice as long as `src`.
85pub fn ct_encode(dst: &mut [u8], src: &[u8]) -> Result<(), InvalidLength> {
86    // The implementation is taken from
87    // https://github.com/ericlagergren/subtle/blob/890d697da01053c79157a7fdfbed548317eeb0a6/hex/constant_time.go
88
89    if dst.len() / 2 < src.len() {
90        return Err(InvalidLength(()));
91    }
92    for (v, chunk) in src.iter().zip(dst.chunks_mut(2)) {
93        chunk[0] = enc_nibble_lower(v >> 4);
94        chunk[1] = enc_nibble_lower(v & 0x0f);
95    }
96    Ok(())
97}
98
99/// Encodes `src` to `dst` as lowercase hexadecimal in constant
100/// time and returns the number of bytes written.
101pub fn ct_write_lower<W>(dst: &mut W, src: &[u8]) -> Result<(), fmt::Error>
102where
103    W: fmt::Write,
104{
105    // The implementation is taken from
106    // https://github.com/ericlagergren/subtle/blob/890d697da01053c79157a7fdfbed548317eeb0a6/hex/constant_time.go
107
108    for v in src {
109        dst.write_char(enc_nibble_lower(v >> 4) as char)?;
110        dst.write_char(enc_nibble_lower(v & 0x0f) as char)?;
111    }
112    Ok(())
113}
114
115/// Encodes `src` to `dst` as uppercase hexadecimal in constant
116/// time and returns the number of bytes written.
117pub fn ct_write_upper<W>(dst: &mut W, src: &[u8]) -> Result<(), fmt::Error>
118where
119    W: fmt::Write,
120{
121    // The implementation is taken from
122    // https://github.com/ericlagergren/subtle/blob/890d697da01053c79157a7fdfbed548317eeb0a6/hex/constant_time.go
123
124    for v in src {
125        dst.write_char(enc_nibble_upper(v >> 4) as char)?;
126        dst.write_char(enc_nibble_upper(v & 0x0f) as char)?;
127    }
128    Ok(())
129}
130
/// Encodes a nibble as a lowercase hexadecimal ASCII character
/// without branching on its value.
///
/// `c` is expected to be in `0..=15`; larger values do not map to
/// hexadecimal characters.
#[inline(always)]
const fn enc_nibble_lower(c: u8) -> u8 {
    let nibble = c as u16;
    // 0x00ff when `nibble` < 10 (the subtraction wraps and sets
    // the high byte), 0x0000 otherwise.
    let is_digit = nibble.wrapping_sub(10) >> 8;
    // `0xff & !38` is 217. Together with the 87 added below that
    // is 304, i.e. 48 (`b'0'`) once truncated to eight bits.
    let digit_fixup = is_digit & !38;
    // Letters: `nibble + 87` maps 10..=15 onto b'a'..=b'f'.
    // Digits: the fixup turns the offset into `nibble + b'0'`.
    nibble.wrapping_add(87).wrapping_add(digit_fixup) as u8
}
138
139/// Encodes a nibble as uppercase hexadecimal.
140#[inline(always)]
141const fn enc_nibble_upper(c: u8) -> u8 {
142    let c = enc_nibble_lower(c);
143    c ^ ((c & 0x40) >> 1)
144}
145
146/// Returned by [`ct_decode`] when one of the following occur:
147///
148/// - `src` is not a multiple of two.
149/// - `dst` is not at least half as long as `src`.
150/// - `src` contains invalid hexadecimal characters.
151#[derive(Clone, Debug, thiserror::Error)]
152#[error("invalid hexadecimal encoding: {0}")]
153pub struct InvalidEncoding(&'static str);
154
155/// Decodes `src` into `dst` from hexadecimal in constant time
156/// and returns the number of bytes written.
157///
158/// * The length of `src` must be a multiple of two.
159/// * `dst` must be half as long (or longer) as `src`.
160pub fn ct_decode(dst: &mut [u8], src: &[u8]) -> Result<usize, InvalidEncoding> {
161    // The implementation is taken from
162    // https://github.com/ericlagergren/subtle/blob/890d697da01053c79157a7fdfbed548317eeb0a6/hex/constant_time.go
163
164    if src.len() % 2 != 0 {
165        return Err(InvalidEncoding("`src` length not a multiple of two"));
166    }
167    if src.len() / 2 > dst.len() {
168        return Err(InvalidEncoding(
169            "`dst` length not at least half as long as `src`",
170        ));
171    }
172
173    let mut valid = Choice::from(1u8);
174    for (src, dst) in src.chunks_exact(2).zip(dst.iter_mut()) {
175        let (hi, hi_ok) = dec_nibble(src[0]);
176        let (lo, lo_ok) = dec_nibble(src[1]);
177
178        valid &= hi_ok & lo_ok;
179
180        let val = (hi << 4) | (lo & 0x0f);
181        // Out of paranoia, do not update `dst` if `valid` is
182        // false.
183        *dst = u8::conditional_select(dst, &val, valid);
184    }
185    if bool::from(valid) {
186        Ok(src.len() / 2)
187    } else {
188        Err(InvalidEncoding(
189            "`src` contains invalid hexadecimal characters",
190        ))
191    }
192}
193
194/// Decode a nibble from a hexadecimal character.
195#[inline(always)]
196fn dec_nibble(c: u8) -> (u8, Choice) {
197    let c = u16::from(c);
198    // Is c in '0' ... '9'?
199    //
200    // This is equivalent to
201    //
202    //    let mut n = c ^ b'0';
203    //    if n < 10 {
204    //        val = n;
205    //    }
206    //
207    // which is correct because
208    //     y^(16*i) < 10 ∀ y ∈ [y, y+10)
209    // and '0' == 48.
210    let num = c ^ u16::from(b'0');
211    // If `num` < 10, subtracting 10 produces the two's
212    // complement which flips the bits in [15:4] (which are all
213    // zero because `num` < 10) to all one. Shifting by 8 then
214    // ensures that bits [7:0] are all set to one, resulting
215    // in 0xff.
216    //
217    // If `num` >= 10, subtracting 10 doesn't set any bits in
218    // [15:8] (which are all zero because `c` < 256) and shifting
219    // by 8 shifts off any set bits, resulting in 0x00.
220    let num_ok = num.wrapping_sub(10) >> 8;
221
222    // Is c in 'a' ... 'f' or 'A' ... 'F'?
223    //
224    // This is equivalent to
225    //
226    //    const MASK: u32 = ^(1<<5); // 0b11011111
227    //    let a = c&MASK;
228    //    if a >= b'A' && a < b'F' {
229    //        val = a-55;
230    //    }
231    //
232    // The only difference between each uppercase and
233    // lowercase ASCII pair ('a'-'A', 'e'-'E', etc.) is 32,
234    // or bit #5. Masking that bit off folds the lowercase
235    // letters into uppercase. The the range check should
236    // then be obvious. Subtracting 55 converts the
237    // hexadecimal character to binary by making 'A' = 10,
238    // 'B' = 11, etc.
239    let alpha = (c & !32).wrapping_sub(55);
240    // If `alpha` is in [10, 15], subtracting 10 results in the
241    // correct binary number, less 10. Notably, the bits in
242    // [15:4] are all zero.
243    //
244    // If `alpha` is in [10, 15], subtracting 16 returns the
245    // two's complement, flipping the bits in [15:4] (which
246    // are all zero because `alpha` <= 15) to one.
247    //
248    // If `alpha` is in [10, 15], `(alpha-10)^(alpha-16)` sets
249    // the bits in [15:4] to one. Otherwise, if `alpha` <= 9 or
250    // `alpha` >= 16, both halves of the XOR have the same bits
251    // in [15:4], so the XOR sets them to zero.
252    //
253    // We shift away the irrelevant bits in [3:0], leaving only
254    // the interesting bits from the XOR.
255    let alpha_ok = (alpha.wrapping_sub(10) ^ alpha.wrapping_sub(16)) >> 8;
256
257    // Bits [3:0] are either 0xf or 0x0.
258    let ok = Choice::from(((num_ok ^ alpha_ok) & 1) as u8);
259
260    // For both `num_ok` and `alpha_ok` the bits in [3:0] are
261    // either 0xf or 0x0. Therefore, the bits in [3:0] are either
262    // `num` or `alpha`. The bits in [7:4] are (as mentioned
263    // above), either 0xf or 0x0.
264    //
265    // Bits [15:4] are irrelevant and should be all zero.
266    let result = ((num_ok & num) | (alpha_ok & alpha)) & 0xf;
267
268    (result as u8, ok)
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Straightforward (non-constant-time) reference decoder for a
    /// single hexadecimal character.
    fn from_hex_char(c: u8) -> Option<u8> {
        match c {
            b'0'..=b'9' => Some(c.wrapping_sub(b'0')),
            b'a'..=b'f' => Some(c.wrapping_sub(b'a').wrapping_add(10)),
            b'A'..=b'F' => Some(c.wrapping_sub(b'A').wrapping_add(10)),
            _ => None,
        }
    }

    /// Test every single byte.
    #[test]
    fn test_encode_lower_exhaustive() {
        for i in 0..256 {
            const TABLE: &[u8] = b"0123456789abcdef";
            let want = [TABLE[i >> 4], TABLE[i & 0x0f]];
            let got = [
                enc_nibble_lower((i as u8) >> 4),
                enc_nibble_lower((i as u8) & 0x0f),
            ];
            assert_eq!(want, got, "#{i}");
        }
    }

    /// Test every single byte.
    #[test]
    fn test_encode_upper_exhaustive() {
        for i in 0..256 {
            const TABLE: &[u8] = b"0123456789ABCDEF";
            let want = [TABLE[i >> 4], TABLE[i & 0x0f]];
            let got = [
                enc_nibble_upper((i as u8) >> 4),
                enc_nibble_upper((i as u8) & 0x0f),
            ];
            assert_eq!(want, got, "#{i}");
        }
    }

    /// Test every single hex character pair (fe, bb, a1, ...).
    #[test]
    fn test_decode_exhaustive() {
        for i in u16::MIN..=u16::MAX {
            let ci = i as u8;
            let cj = (i >> 8) as u8;
            let mut dst = [0u8; 1];
            let src = [ci, cj];
            let res = ct_decode(&mut dst, &src);
            match (from_hex_char(ci), from_hex_char(cj)) {
                (Some(hi), Some(lo)) => {
                    // A valid pair decodes to exactly one byte.
                    assert_eq!(
                        res.ok(),
                        Some(1),
                        "#{i}: should be able to decode pair '{ci:x}{cj:x}'"
                    );
                    assert_eq!(dst, [(hi << 4) | lo], "#{i}: {ci:x}{cj:x}");
                }
                _ => {
                    // The failure messages are only formatted if an
                    // assertion actually fails.
                    assert!(
                        res.is_err(),
                        "#{i}: should not have decoded pair '{src:?}'"
                    );
                    assert_eq!(dst, [0], "#{i}: {src:?}");
                }
            }
        }
    }
}