use core::{fmt, result::Result, str};
use subtle::{Choice, ConditionallySelectable};
/// Wrapper that formats a byte buffer as hexadecimal.
///
/// `Hex<T>` implements [`fmt::Display`], [`fmt::Debug`], [`fmt::LowerHex`]
/// and [`fmt::UpperHex`] whenever `T: AsRef<[u8]>`; `Display` and `Debug`
/// produce lower-case digits.
#[derive(Copy, Clone)]
pub struct Hex<T>(T);

impl<T> Hex<T> {
    /// Wraps `value` so it can be formatted as hexadecimal.
    pub const fn new(value: T) -> Self {
        Hex(value)
    }
}
/// `Display` output is the lower-case hexadecimal form (identical to `{:x}`).
impl<T> fmt::Display for Hex<T>
where
    T: AsRef<[u8]>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::LowerHex>::fmt(self, f)
    }
}
/// `Debug` output is the bare lower-case hexadecimal string (identical to
/// `{:x}`), with no type name or quotes around it.
impl<T> fmt::Debug for Hex<T>
where
    T: AsRef<[u8]>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::LowerHex>::fmt(self, f)
    }
}
/// Writes the wrapped bytes as lower-case hexadecimal via [`ct_write_lower`].
///
/// Characters are written directly to the formatter, so formatter flags
/// (width, fill, `#`, ...) have no effect on the output.
impl<T> fmt::LowerHex for Hex<T>
where
    T: AsRef<[u8]>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bytes: &[u8] = self.0.as_ref();
        ct_write_lower(f, bytes)
    }
}
/// Writes the wrapped bytes as upper-case hexadecimal via [`ct_write_upper`].
///
/// Characters are written directly to the formatter, so formatter flags
/// (width, fill, `#`, ...) have no effect on the output.
impl<T> fmt::UpperHex for Hex<T>
where
    T: AsRef<[u8]>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bytes: &[u8] = self.0.as_ref();
        ct_write_upper(f, bytes)
    }
}
/// Conversion of a value into a [`Hex`] formatting wrapper.
pub trait ToHex {
/// Byte-buffer type stored inside the returned [`Hex`].
type Output: AsRef<[u8]>;
/// Consumes `self` and wraps it in a [`Hex`] so it can be formatted as
/// hexadecimal.
fn to_hex(self) -> Hex<Self::Output>;
}
impl<T> ToHex for T
where
T: AsRef<[u8]>,
{
type Output = T;
fn to_hex(self) -> Hex<Self::Output> {
Hex::new(self)
}
}
#[derive(Clone, Debug, thiserror::Error)]
#[error("invalid length")]
pub struct InvalidLength(());
/// Encodes `src` as lower-case hexadecimal into `dst`.
///
/// Each input byte becomes two ASCII characters, high nibble first, produced
/// by the branch-free [`enc_nibble_lower`] helper. Only the first
/// `2 * src.len()` bytes of `dst` are written; any bytes after that are left
/// untouched.
///
/// # Errors
///
/// Returns [`InvalidLength`] if `dst` is shorter than `2 * src.len()`. `dst`
/// is not modified in that case.
pub fn ct_encode(dst: &mut [u8], src: &[u8]) -> Result<(), InvalidLength> {
    // Written as a division so the check cannot overflow for huge `src`.
    if dst.len() / 2 < src.len() {
        return Err(InvalidLength(()));
    }
    // `chunks_exact_mut` guarantees every chunk has length 2, so the indexing
    // below can never hit a short trailing chunk and the compiler can elide
    // the per-iteration bounds checks. A leftover odd byte of `dst` is simply
    // never visited.
    for (byte, pair) in src.iter().zip(dst.chunks_exact_mut(2)) {
        pair[0] = enc_nibble_lower(byte >> 4);
        pair[1] = enc_nibble_lower(byte & 0x0f);
    }
    Ok(())
}
pub fn ct_write_lower<W>(dst: &mut W, src: &[u8]) -> Result<(), fmt::Error>
where
W: fmt::Write,
{
for v in src {
dst.write_char(enc_nibble_lower(v >> 4) as char)?;
dst.write_char(enc_nibble_lower(v & 0x0f) as char)?;
}
Ok(())
}
pub fn ct_write_upper<W>(dst: &mut W, src: &[u8]) -> Result<(), fmt::Error>
where
W: fmt::Write,
{
for v in src {
dst.write_char(enc_nibble_upper(v >> 4) as char)?;
dst.write_char(enc_nibble_upper(v & 0x0f) as char)?;
}
Ok(())
}
/// Maps a nibble (`0..=15`) to its lower-case ASCII hex digit using only
/// arithmetic, with no data-dependent branches or table lookups.
///
/// The work is done in `u16` so that `c - 10` borrows into the high byte
/// exactly when `c < 10`:
/// - `c >= 10`: the correction term is 0, giving `c + 87` (`10 + 87 == b'a'`).
/// - `c < 10`: the correction term is `0xff & !38 == 217`, giving
///   `c + 87 + 217 == c + 304`, which truncates to `c + 48` (`b'0' + c`).
///
/// Inputs above 15 do not produce meaningful digits; callers pass values
/// already masked to four bits.
#[inline(always)]
const fn enc_nibble_lower(c: u8) -> u8 {
    let wide = c as u16;
    // 0x00ff when `c < 10` (the subtraction wrapped), 0x0000 otherwise.
    let below_ten = wide.wrapping_sub(10) >> 8;
    let digit_fixup = below_ten & !38;
    wide.wrapping_add(87).wrapping_add(digit_fixup) as u8
}
/// Maps a nibble (`0..=15`) to its upper-case ASCII hex digit without
/// data-dependent branches.
///
/// Starts from the lower-case digit and flips bit 5 (0x20, the ASCII case
/// bit) only when bit 6 (0x40) is set. Among the lower-case outputs, bit 6 is
/// set exactly for `a..=f` (0x61..=0x66), which become `A..=F`. The digits
/// `0..=9` (0x30..=0x39) pass through unchanged.
#[inline(always)]
const fn enc_nibble_upper(c: u8) -> u8 {
    let lower = enc_nibble_lower(c);
    // 0x20 for letters, 0x00 for digits.
    let case_flip = (lower & 0x40) >> 1;
    lower ^ case_flip
}
#[derive(Clone, Debug, thiserror::Error)]
#[error("invalid hexadecimal encoding: {0}")]
pub struct InvalidEncoding(&'static str);
/// Decodes the hexadecimal string `src` into `dst`.
///
/// Upper-case, lower-case and mixed-case digits are accepted. Each pair of
/// input characters produces one output byte, written to the front of `dst`.
/// Bytes of `dst` beyond `src.len() / 2` are left untouched. On success,
/// returns the number of bytes written (`src.len() / 2`).
///
/// The per-character work avoids data-dependent branches:
/// - validity is accumulated in a [`Choice`];
/// - each output byte is stored via
///   [`ConditionallySelectable::conditional_select`].
///
/// Only the final overall validity result is branched on.
///
/// # Errors
///
/// - If `src` has odd length, or `dst` is shorter than `src.len() / 2`, an
///   error is returned before `dst` is touched.
/// - If `src` contains a non-hex character, an error is returned. Bytes from
///   pairs decoded *before* the first invalid pair have already been written
///   to `dst`. The invalid pair and every pair after it leave `dst`
///   unchanged.
pub fn ct_decode(dst: &mut [u8], src: &[u8]) -> Result<usize, InvalidEncoding> {
    let decoded_len = src.len() / 2;
    if src.len() % 2 != 0 {
        return Err(InvalidEncoding("`src` length not a multiple of two"));
    }
    if dst.len() < decoded_len {
        return Err(InvalidEncoding(
            "`dst` length not at least half as long as `src`",
        ));
    }

    // Starts as "valid" and is cleared, permanently, by the first bad digit.
    let mut all_valid = Choice::from(1u8);
    for (out, pair) in dst.iter_mut().zip(src.chunks_exact(2)) {
        let (high, high_ok) = dec_nibble(pair[0]);
        let (low, low_ok) = dec_nibble(pair[1]);
        all_valid &= high_ok & low_ok;
        let byte = (high << 4) | (low & 0x0f);
        // Keep the previous contents of `out` unless every pair so far was valid.
        *out = u8::conditional_select(out, &byte, all_valid);
    }

    if !bool::from(all_valid) {
        return Err(InvalidEncoding(
            "`src` contains invalid hexadecimal characters",
        ));
    }
    Ok(decoded_len)
}
/// Decodes one ASCII hex digit (`0-9`, `a-f`, `A-F`) using only arithmetic
/// and masking, with no data-dependent branches.
///
/// Returns `(value, ok)`:
/// - `value` is the digit's 4-bit value, or 0 when `c` is not a hex digit.
/// - `ok` is a [`Choice`] that is 1 iff `c` is a valid hex digit.
#[inline(always)]
fn dec_nibble(c: u8) -> (u8, Choice) {
// Work in u16 so each `wrapping_sub` below borrows into the high byte;
// `>> 8` then yields an all-ones (0x00ff) or all-zeros mask.
let c = u16::from(c);
// For '0'..='9' (0x30..=0x39) this yields 0..=9; every other byte yields >= 10.
let num = c ^ u16::from(b'0');
// 0x00ff iff num < 10 (the subtraction wrapped), else 0.
let num_ok = num.wrapping_sub(10) >> 8;
// Clearing bit 5 (0x20) folds 'a'..='f' onto 'A'..='F'; subtracting 55
// (b'A' - 10) then maps them to 10..=15.
let alpha = (c & !32).wrapping_sub(55);
// 0x00ff iff 10 <= alpha < 16: only in that range does `alpha - 16` wrap
// while `alpha - 10` does not, so the high bytes differ.
let alpha_ok = (alpha.wrapping_sub(10) ^ alpha.wrapping_sub(16)) >> 8;
// The two masks are never both set, so XOR acts as OR; `& 1` reduces the
// mask to the 0/1 value `Choice` expects.
let ok = Choice::from(((num_ok ^ alpha_ok) & 1) as u8);
// Mask-select whichever candidate matched (0 if neither did).
let result = ((num_ok & num) | (alpha_ok & alpha)) & 0xf;
(result as u8, ok)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Straightforward (branchy, non-constant-time) reference decoder for a
    /// single hex digit, used as the oracle for the constant-time code.
    fn from_hex_char(c: u8) -> Option<u8> {
        match c {
            b'0'..=b'9' => Some(c.wrapping_sub(b'0')),
            b'a'..=b'f' => Some(c.wrapping_sub(b'a').wrapping_add(10)),
            b'A'..=b'F' => Some(c.wrapping_sub(b'A').wrapping_add(10)),
            _ => None,
        }
    }

    fn valid_hex_char(c: u8) -> bool {
        from_hex_char(c).is_some()
    }

    fn must_from_hex_char(c: u8) -> u8 {
        from_hex_char(c).expect("should be a valid hex char")
    }

    #[test]
    fn test_encode_lower_exhaustive() {
        for i in 0..256 {
            const TABLE: &[u8] = b"0123456789abcdef";
            let want = [TABLE[i >> 4], TABLE[i & 0x0f]];
            let got = [
                enc_nibble_lower((i as u8) >> 4),
                enc_nibble_lower((i as u8) & 0x0f),
            ];
            assert_eq!(want, got, "#{i}");
        }
    }

    #[test]
    fn test_encode_upper_exhaustive() {
        for i in 0..256 {
            const TABLE: &[u8] = b"0123456789ABCDEF";
            let want = [TABLE[i >> 4], TABLE[i & 0x0f]];
            let got = [
                enc_nibble_upper((i as u8) >> 4),
                enc_nibble_upper((i as u8) & 0x0f),
            ];
            assert_eq!(want, got, "#{i}");
        }
    }

    #[test]
    fn test_decode_exhaustive() {
        for i in u16::MIN..=u16::MAX {
            let ci = i as u8;
            let cj = (i >> 8) as u8;
            let mut dst = [0u8; 1];
            let src = &[ci, cj];
            let res = ct_decode(&mut dst, src);
            if valid_hex_char(ci) && valid_hex_char(cj) {
                // Allowed here: a failed decode of a valid pair is a test
                // failure, and the message names the offending pair.
                #[allow(clippy::panic)]
                let n = res.unwrap_or_else(|_| {
                    panic!("#{i}: should be able to decode pair '{ci:x}{cj:x}'")
                });
                assert_eq!(n, 1, "#{i}: {ci:x}{cj:x}");
                let want = (must_from_hex_char(ci) << 4) | must_from_hex_char(cj);
                assert_eq!(&dst, &[want], "#{i}: {ci:x}{cj:x}");
            } else {
                // `assert!` formats its message only on failure; the previous
                // `expect_err(&format!(..))` allocated on every invalid pair.
                assert!(res.is_err(), "#{i}: should not have decoded pair '{src:?}'");
                assert_eq!(&dst, &[0], "#{i}: {src:?}");
            }
        }
    }

    #[test]
    fn test_ct_encode_known_vector() {
        let mut dst = [0u8; 8];
        ct_encode(&mut dst, &[0x00, 0x9f, 0xa0, 0xff]).expect("dst is exactly 2x src");
        assert_eq!(&dst, b"009fa0ff");
    }

    #[test]
    fn test_ct_encode_leaves_tail_untouched() {
        let mut dst = [b'.'; 5];
        ct_encode(&mut dst, &[0xab, 0xcd]).expect("dst is longer than 2x src");
        assert_eq!(&dst, b"abcd.");
    }

    #[test]
    fn test_ct_encode_rejects_short_dst() {
        let mut dst = [b'.'; 3];
        assert!(ct_encode(&mut dst, &[0x01, 0x02]).is_err());
        assert_eq!(&dst, b"...");
    }

    #[test]
    fn test_encode_decode_round_trip() {
        let mut bytes = [0u8; 256];
        for (i, b) in bytes.iter_mut().enumerate() {
            *b = i as u8;
        }
        let mut hex = [0u8; 512];
        ct_encode(&mut hex, &bytes).expect("dst is exactly 2x src");
        let mut decoded = [0u8; 256];
        let n = ct_decode(&mut decoded, &hex).expect("encoder output must decode");
        assert_eq!(n, bytes.len());
        assert_eq!(decoded, bytes);
    }

    #[test]
    fn test_ct_decode_mixed_case_multi_byte() {
        let mut dst = [0u8; 3];
        let n = ct_decode(&mut dst, b"DeAd").expect("valid hex");
        assert_eq!(n, 2);
        assert_eq!(dst, [0xde, 0xad, 0x00]);
    }

    #[test]
    fn test_ct_decode_length_errors() {
        let mut dst = [0u8; 4];
        assert!(ct_decode(&mut dst, b"abc").is_err());
        assert_eq!(dst, [0u8; 4]);
        let mut short = [0u8; 1];
        assert!(ct_decode(&mut short, b"abcd").is_err());
        assert_eq!(short, [0u8; 1]);
    }

    #[test]
    fn test_ct_decode_stops_writing_at_first_invalid_pair() {
        let mut dst = [0xaa; 3];
        assert!(ct_decode(&mut dst, b"01zz23").is_err());
        // Pairs before the bad one are written; the bad one and later are not.
        assert_eq!(dst, [0x01, 0xaa, 0xaa]);
    }

    #[test]
    fn test_hex_formatting() {
        let h = Hex::new([0xabu8, 0x01]);
        assert_eq!(format!("{h}"), "ab01");
        assert_eq!(format!("{h:?}"), "ab01");
        assert_eq!(format!("{h:x}"), "ab01");
        assert_eq!(format!("{h:X}"), "AB01");
        assert_eq!(format!("{}", [0xffu8, 0x0a].to_hex()), "ff0a");
    }
}