#[cfg(llvm_ir_check)]
extern crate core;
use core::hint::black_box;
#[cfg(not(llvm_ir_check))]
use crate::opaque_res::Res;
#[cfg(llvm_ir_check)]
/// Local stand-in for `crate::opaque_res::Res`, compiled only when the `llvm_ir_check` cfg is
/// set (the crate import above is compiled out in that configuration).
///
/// Wraps a single public `bool`; `add_no_wrap` stores `true` in it when no overflow occurred.
pub struct Res(pub bool);
#[cfg(not(llvm_ir_check))]
use crate::buf::InvalidSize;
#[cfg(llvm_ir_check)]
/// Local stand-in for `crate::buf::InvalidSize`, compiled only when the `llvm_ir_check` cfg is
/// set (the crate import above is compiled out in that configuration).
///
/// Unit marker returned by `hex_encode` when the output buffer is too small.
pub struct InvalidSize;
/// Propagates the highest set bit of a mutable integer binding into every lower bit position.
///
/// The shift sequence 1, 2, 4, 8, 16 covers a 32-bit value: after expansion every bit at or
/// below the highest originally-set bit is `1`. A zero value stays zero.
macro_rules! smear {
    ($word:ident) => {{
        $word = $word | ($word >> 1);
        $word = $word | ($word >> 2);
        $word = $word | ($word >> 4);
        $word = $word | ($word >> 8);
        $word = $word | ($word >> 16);
    }};
}
/// Branch-free `left > right` for `u32`, returning `1` when true and `0` otherwise.
///
/// The comparison is decided by the most significant bit in which the operands differ:
/// `left` is greater exactly when that bit is set in `left`.
#[cfg_attr(llvm_ir_check, no_mangle)]
pub const fn gt(left: u32, right: u32) -> u32 {
    // Bits set only in `right`, smeared downwards: this marks every position at or below the
    // highest bit where `right` wins.
    let mut right_wins = right & !left;
    smear!(right_wins);

    // Bits set only in `left` that sit strictly above every position where `right` wins.
    let left_only = left & !right;
    let mut verdict = left_only & !right_wins;

    // If any such bit survives, smearing carries it down to bit 0.
    smear!(verdict);
    verdict & 1
}
/// Converts an overflow flag into an addition mask.
///
/// `0` (no overflow) yields `u32::MAX`; `1` (overflow) yields `0`.
#[inline(always)]
const fn create_mask(overflow: u32) -> u32 {
    // Two's-complement negation maps 1 -> all ones and 0 -> 0 ...
    let spread = overflow.wrapping_neg();
    // ... and the inversion flips that so the mask is "open" only when there is no overflow.
    !spread
}
/// Adds `right` to `left` (wrapping) after gating `right` through `mask`.
///
/// With an all-ones mask this is a plain wrapping add; with a zero mask it returns `left`.
#[inline(always)]
const fn mask_add(left: u32, right: u32, mask: u32) -> u32 {
    let gated = right & mask;
    left.wrapping_add(gated)
}
/// Adds `a` and `b` without wrapping and without branching on the operands.
///
/// Returns the sum and a `Res` that wraps `true` when the addition did not overflow. On
/// overflow the addend is masked to zero, so the returned value is `a` unchanged and the `Res`
/// wraps `false`.
#[cfg_attr(not(llvm_ir_check), inline)]
#[cfg_attr(llvm_ir_check, no_mangle)]
#[cfg_attr(not(feature = "allow-non-fips"), allow(dead_code))]
pub fn add_no_wrap(a: u32, b: u32) -> (u32, Res) {
    // `a + b` overflows exactly when `b` exceeds the headroom left above `a`.
    let headroom = u32::MAX.wrapping_sub(a);
    let overflow = gt(b, headroom);

    // The mask is passed through `black_box` before it is used in the masked addition.
    let mask = black_box(create_mask(overflow));
    let sum = mask_add(a, b, mask);

    (sum, Res(overflow as u8 == 0))
}
/// Returns `byte` unchanged, routing the value through a volatile read of a local copy.
#[inline(always)]
fn volatile(byte: u8) -> u8 {
    let src: *const u8 = &byte;
    // SAFETY: `src` points at the live, initialized local `byte`, which is valid and aligned
    // for a `u8` read for the duration of this call.
    unsafe { core::ptr::read_volatile(src) }
}
/// Collapses an XOR difference into a single bit: `1` when `xor` is non-zero, `0` when it is
/// zero.
///
/// For any non-zero byte, either the byte itself or its two's-complement negation has the top
/// bit set, so OR-ing the two and taking bit 7 yields `1`; for zero both are zero.
#[inline(always)]
fn eq_hsb(xor: u8) -> u8 {
    let negated = volatile(xor.wrapping_neg());
    let merged = volatile(xor | negated);
    merged >> 7
}
/// Branch-free byte equality: returns `1` when `a == b` and `0` otherwise.
#[cfg_attr(not(llvm_ir_check), inline(always))]
#[cfg_attr(llvm_ir_check, no_mangle)]
pub fn byte_eq(a: u8, b: u8) -> u8 {
    // `differs` is 1 when the bytes differ, 0 when they match ...
    let differs = eq_hsb(b ^ a);
    // ... so flipping the low bit turns it into an equality flag.
    differs ^ volatile(1)
}
// Unrolled byte-wise comparison of two slices, accumulated into `$result`.
//
// Every arm ultimately expands to `byte_eq` calls on elements read with `get_unchecked`, and each
// outcome is AND-ed into `$result`. Because of the unchecked reads the macro must be expanded in
// an unsafe context, and the caller must guarantee that every index touched is in bounds for
// both `$left` and `$right`:
//
// * `g2`  touches `$start` and `$start + 1`
// * `g4`  touches `$start ..= $start + 3`
// * `g8`  touches `$start ..= $start + 7`
// * `g16` touches `$start ..= $start + 15`
macro_rules! unroll_ct_cmp {
// Base case: compare the two elements at `$start` and `$start + 1`.
(g2 $start:expr, $result:ident, $left:ident, $right:ident) => {{
$result &= byte_eq(*$left.get_unchecked($start), *$right.get_unchecked($start));
$result &= byte_eq(
*$left.get_unchecked($start.wrapping_add(1)),
*$right.get_unchecked($start.wrapping_add(1))
);
}};
// Four elements: two `g2` groups. This arm expands to two statements (single braces), unlike
// the other arms which expand to one block.
(g4 $start:expr, $result:ident, $left:ident, $right:ident) => {
unroll_ct_cmp!(g2 $start, $result, $left, $right);
unroll_ct_cmp!(g2 $start.wrapping_add(2) as usize, $result, $left, $right);
};
// Eight elements: two `g4` groups.
(g8 $start:expr, $result:ident, $left:ident, $right:ident) => {{
unroll_ct_cmp!(g4 $start, $result, $left, $right);
unroll_ct_cmp!(g4 $start + 4, $result, $left, $right);
}};
// Sixteen elements: two `g8` groups.
(g16 $start:expr, $result:ident, $left:ident, $right:ident) => {{
unroll_ct_cmp!(g8 $start, $result, $left, $right);
unroll_ct_cmp!(g8 $start + 8, $result, $left, $right);
}}
}
/// Compares the first four bytes of `a` and `b` and folds the outcome into `res`.
///
/// Returns `res` AND-ed with the `byte_eq` result of each of the four byte pairs, so a `res`
/// of `1` stays `1` only if all four pairs match.
///
/// # Safety
///
/// `a` and `b` must each contain at least four elements; indices `0..=3` of both slices are
/// read with `get_unchecked`.
#[cfg_attr(llvm_ir_check, no_mangle)]
pub unsafe fn cmp_bytes_4_unchecked(mut res: u8, a: &[u8], b: &[u8]) -> u8 {
unroll_ct_cmp!(g4 0usize, res, a, b);
res
}
/// Compares two byte slices without branching on their contents.
///
/// Returns `1` when the slices are equal and `0` otherwise. A length mismatch returns `0`
/// immediately, so only the contents (not the lengths) are compared branch-free.
#[cfg_attr(llvm_ir_check, no_mangle)]
#[must_use]
pub fn cmp_slice(a: &[u8], b: &[u8]) -> u8 {
if a.len() != b.len() { return 0 }
// `rem` counts the bytes not yet compared; the loop consumes 4-byte chunks from the END of the
// slices, leaving the first `rem` (0..=3) bytes for the `match` below.
let mut rem = a.len();
let mut res = volatile(1u8);
while rem >= 4 {
debug_assert!(rem <= a.len());
let next = rem.wrapping_sub(4);
// SAFETY: `4 <= rem <= a.len() == b.len()`, so `next..rem` is an in-bounds range of exactly
// four elements in both slices, which is what `cmp_bytes_4_unchecked` requires.
res &= unsafe {
cmp_bytes_4_unchecked(
res,
a.get_unchecked(next..rem),
b.get_unchecked(next..rem)
)
};
rem = next;
}
// At this point `rem < 4` and the uncompared bytes are indices `0..rem`.
match rem {
// SAFETY: `rem == 3`, so indices 0, 1 and 2 are in bounds (g2 starts at index 1).
3 => unsafe {
res &= byte_eq(*a.get_unchecked(0), *b.get_unchecked(0));
unroll_ct_cmp!(g2 rem.wrapping_sub(2), res, a, b);
}
// SAFETY: `rem == 2`, so indices 0 and 1 are in bounds (g2 starts at index 0).
2 => unsafe { unroll_ct_cmp!(g2 rem.wrapping_sub(2), res, a, b) },
// SAFETY: `rem == 1`, so index 0 is in bounds.
1 => unsafe { res &= byte_eq(*a.get_unchecked(0), *b.get_unchecked(0)) },
_ => {}
}
res
}
/// Compares two byte-like values with `cmp_slice` and reports the outcome as a `bool`.
///
/// Returns `true` when both inputs have the same length and contents.
#[must_use]
pub fn ct_eq<A: AsRef<[u8]>, B: AsRef<[u8]>>(a: A, b: B) -> bool {
    let lhs: &[u8] = a.as_ref();
    let rhs: &[u8] = b.as_ref();
    let verdict = cmp_slice(lhs, rhs);
    verdict != 0
}
/// Number of output bytes needed to hex-encode `len` input bytes (two characters per byte).
///
/// The doubling is a left shift by one, so the top bit of `len` is discarded rather than
/// triggering an overflow panic.
#[must_use]
#[inline]
pub const fn hex_encode_len(len: usize) -> usize {
    len.wrapping_shl(1)
}
/// Writes the two lowercase hex characters for `byte` into `output[0]` (high nibble) and
/// `output[1]` (low nibble) using only arithmetic, with no lookup table or data-dependent branch.
///
/// # Panics
///
/// Panics if `output` holds fewer than two bytes.
#[inline]
fn encode_byte(byte: u8, output: &mut [u8]) {
    /// Maps a nibble (`0..=15`) to a word whose low byte is its ASCII hex digit.
    ///
    /// `87 + n` is the letter form (`87 + 10 == b'a'`). For `n < 10` the subtraction wraps, the
    /// shift leaves the low 24 bits set, and adding that value masked with `!38` lowers the low
    /// byte by 39, giving `48 + n` (`b'0' + n`).
    #[inline(always)]
    const fn ascii_hex(nibble: u32) -> u32 {
        let digit_fixup = nibble.wrapping_sub(10u32).wrapping_shr(8) & !38u32;
        87u32.wrapping_add(nibble).wrapping_add(digit_fixup)
    }

    let low_char = ascii_hex((byte & 0xf) as u32);
    let high_char = ascii_hex((byte >> 4) as u32);

    // Pack both characters into one word: high nibble in byte 0, low nibble in byte 1.
    let packed = low_char.wrapping_shl(8) | high_char;

    output[0] = packed as u8;
    output[1] = packed.wrapping_shr(8) as u8;
}
/// Hex-encodes `input` into the front of `output` as lowercase ASCII.
///
/// Returns the number of bytes written, which is `hex_encode_len(input.len())`.
///
/// # Errors
///
/// Returns `InvalidSize` if `output` is shorter than `hex_encode_len(input.len())`; nothing is
/// written in that case.
#[cfg_attr(llvm_ir_check, no_mangle)]
pub fn hex_encode(input: &[u8], output: &mut [u8]) -> Result<usize, InvalidSize> {
let hex_len = hex_encode_len(input.len());
// Verification-only sanity checks: the encoded length is zero exactly when the input is empty.
#[cfg(any(check, kani, test))] {
ensure!((hex_len == 0) <==> (input.len() == 0));
ensure!((hex_len != 0) <==> (input.len() != 0));
}
if output.len() < hex_len { return Err(InvalidSize) }
// Verification-only: index of the last output byte written so far.
#[cfg(any(check, kani, test))]
let mut post_len = 0usize;
for (pos, byte) in input.iter().enumerate() {
// Each input byte occupies two output bytes, starting at `2 * pos`.
let o_pos = pos.wrapping_shl(1);
#[cfg(any(check, kani, test))] {
// Loop invariant: the pair written by the previous iteration is valid hex.
ensure!((pos != 0) ==> (is_valid_hex_2(output[o_pos - 2], output[o_pos - 1])));
post_len = o_pos + 1;
}
encode_byte(*byte, &mut output[o_pos..o_pos + 2]);
}
// Verification-only postconditions: the final pair is valid hex and the last written index
// lines up with the reported length.
#[cfg(any(check, kani, test))] {
ensure!((hex_len != 0) ==> (is_valid_hex_2(output[post_len - 1], output[post_len])));
ensure!((hex_len != 0) ==> (post_len + 1 == hex_len));
}
Ok(hex_len)
}
/// Hex-encodes `input` into `output` and returns the written prefix of `output` as a `&str`.
///
/// # Errors
///
/// Returns `InvalidSize` if `output` is too small to hold the encoding (see `hex_encode`).
#[inline]
pub fn hex_encode_str<'o>(input: &[u8], output: &'o mut [u8]) -> Result<&'o str, InvalidSize> {
    let written = hex_encode(input, output)?;
    let encoded: &'o [u8] = &output[..written];
    // SAFETY: `hex_encode` filled `output[..written]` exclusively with the ASCII characters
    // produced by `encode_byte` (`0-9` and `a-f`), which are valid UTF-8.
    Ok(unsafe { core::str::from_utf8_unchecked(encoded) })
}
// Allocating variant of `hex_encode`, wrapped in the project's `alloc!` macro (presumably gating
// it on allocator support — confirm against the macro definition).
#[cfg(not(llvm_ir_check))]
alloc! {
// Hex-encodes `input` into a freshly allocated `String` of exactly
// `hex_encode_len(input.len())` bytes.
#[allow(clippy::missing_panics_doc)] pub fn hex_encode_alloc(input: &[u8]) -> alloc::string::String {
let mut output = vec![0u8; hex_encode_len(input.len())];
// Cannot fail: `hex_encode` only errors when the output is shorter than the encoded length,
// and the buffer was sized to exactly that length.
hex_encode(input, output.as_mut_slice()).unwrap();
// SAFETY: every byte of `output` was written by `hex_encode`, which emits only the ASCII
// characters `0-9` and `a-f`; an empty buffer is trivially valid UTF-8.
unsafe {
alloc::string::String::from_utf8_unchecked(output)
}
}
}
/// Decodes one ASCII hex character into its value using only arithmetic.
///
/// Returns `0..=15` for `b'0'..=b'9'` and lowercase `b'a'..=b'f'`. Every other byte, including
/// uppercase `A-F`, yields `0xFFFF` (i.e. `-1` as `i16`), which callers detect through the
/// non-zero high byte.
#[inline]
const fn decode_nibble(first: u8) -> u16 {
    let ch = first as i16;

    // Each range mask is all ones (`-1`) when `ch` lies strictly between the two bounds and
    // zero otherwise: both differences are negative only inside the range, and the arithmetic
    // shift spreads the resulting sign bit.
    let in_digits = (0x2fi16.wrapping_sub(ch) & ch.wrapping_sub(0x3a)).wrapping_shr(8);
    let in_lower = (0x60i16.wrapping_sub(ch) & ch.wrapping_sub(0x67)).wrapping_shr(8);

    // Contributions are `value + 1` so that adding them to the `-1` sentinel yields `value`.
    let digit_part = in_digits & ch.wrapping_sub(47);
    let lower_part = in_lower & ch.wrapping_sub(86);

    let decoded = (-1i16).wrapping_add(digit_part).wrapping_add(lower_part);
    decoded as u16
}
/// Checks the size preconditions for decoding `inp_len` hex characters into `out_len` bytes.
///
/// Returns `(ok, decoded_len)` where `decoded_len` is `inp_len / 2` and `ok` is `true` only when
/// `inp_len` is even and the output can hold `decoded_len` bytes.
#[must_use]
const fn decode_predicate(inp_len: usize, out_len: usize) -> (bool, usize) {
    let needed = inp_len / 2;
    let even_input = inp_len % 2 == 0;
    (even_input && needed <= out_len, needed)
}
/// Errors reported by the hex decoding functions in this file.
pub enum HexError {
/// The input contained a byte that is not decoded as a hex digit.
Encoding,
/// The input length was odd or the output buffer was too small; also produced when
/// converting from `InvalidSize`.
Size
}
impl From<InvalidSize> for HexError {
fn from(_value: InvalidSize) -> Self {
Self::Size
}
}
/// Collapses any `HexError` into the crate's `Unspecified` error, discarding the variant.
#[cfg(not(llvm_ir_check))]
impl From<HexError> for crate::Unspecified {
    fn from(_: HexError) -> Self {
        Self
    }
}
/// Formats the error as its qualified variant name (`HexError::Size` / `HexError::Encoding`).
#[cfg(not(llvm_ir_check))]
impl core::fmt::Display for HexError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = match self {
            Self::Encoding => "HexError::Encoding",
            Self::Size => "HexError::Size",
        };
        f.write_str(label)
    }
}
/// `Debug` output is identical to the `Display` output.
#[cfg(not(llvm_ir_check))]
impl core::fmt::Debug for HexError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Display::fmt(self, f)
    }
}
// Marks `HexError` as a standard error type. The impl is wrapped in the project's `std!` macro
// (presumably gating it on `std` availability — confirm against the macro definition).
#[cfg(not(llvm_ir_check))]
std! { impl std::error::Error for HexError {} }
/// Decodes hex text from `input` into the front of `output`.
///
/// Accepts the characters `0-9` and lowercase `a-f`. Every character pair is decoded and
/// written before validity is reported, so the loop never exits early on a bad character and
/// `output` may have been overwritten even when an error is returned.
///
/// Returns the number of bytes written (`input.len() / 2`).
///
/// # Errors
///
/// * [`HexError::Size`] if `input.len()` is odd or `output` is shorter than `input.len() / 2`.
/// * [`HexError::Encoding`] if any input byte is not an accepted hex character.
pub fn hex_decode(input: &[u8], output: &mut [u8]) -> Result<usize, HexError> {
    let (sizes_ok, dec_len) = decode_predicate(input.len(), output.len());
    if !sizes_ok {
        return Err(HexError::Size);
    }

    // Accumulates the high byte of every decoded pair; it stays zero only if all nibbles were
    // valid (an invalid nibble decodes to 0xFFFF).
    let mut invalid: u16 = 0;

    // `input` has exactly `2 * dec_len` bytes and `output` at least `dec_len`, so the zip visits
    // exactly `dec_len` pairs.
    for (pair, slot) in input.chunks_exact(2).zip(output.iter_mut()) {
        let high = decode_nibble(pair[0]);
        let low = decode_nibble(pair[1]);
        let decoded = high.wrapping_shl(4) | low;
        invalid |= decoded >> 8;
        *slot = decoded as u8;
    }

    match invalid {
        0 => Ok(dec_len),
        _ => Err(HexError::Encoding),
    }
}
// Allocating variant of `hex_decode`, wrapped in the project's `alloc!` macro (presumably gating
// it on allocator support — confirm against the macro definition).
#[cfg(not(llvm_ir_check))]
alloc! {
// Decodes `input` into a freshly allocated `Vec<u8>` of `input.len() / 2` bytes.
// Errors are those of `hex_decode`: `HexError::Size` for odd-length input and
// `HexError::Encoding` for a non-hex character.
pub fn hex_decode_alloc(input: &[u8]) -> Result<Vec<u8>, HexError> {
let mut output = vec![0u8; input.len() >> 1];
// On success the decoded count is dropped; the buffer is already exactly that long.
hex_decode(input, output.as_mut_slice()).map(move |_| output)
}
}
/// Verification helper: returns `true` only when `byte` is an ASCII hex digit
/// (`0-9`, `a-f` or `A-F`).
///
/// Letters outside `a-f` / `A-F` are rejected, so assertions built on this predicate prove that
/// the encoder emitted hex digits rather than merely alphanumeric characters.
#[cfg(any(test, kani, check))]
const fn is_valid_hex(byte: u8) -> bool {
    byte.is_ascii_hexdigit()
}
/// Verification helper: `true` when both bytes pass `is_valid_hex`.
#[cfg(any(test, kani, check))]
const fn is_valid_hex_2(first: u8, second: u8) -> bool {
    if !is_valid_hex(first) {
        return false;
    }
    is_valid_hex(second)
}
/// Verification helper: `true` when all four bytes pass `is_valid_hex`.
#[cfg(any(kani, check))]
const fn is_valid_hex_4(a: u8, b: u8, c: u8, d: u8) -> bool {
    if !is_valid_hex_2(a, b) {
        return false;
    }
    is_valid_hex_2(c, d)
}
/// Example-based smoke tests for the comparison and hex helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Plaintext fixture shared by the hex tests (11 bytes).
    const PLAIN: &[u8] = b"hello world";
    /// Lowercase hex rendering of `PLAIN` (22 characters).
    const PLAIN_HEX: &str = "68656c6c6f20776f726c64";

    /// Two identical 17-byte buffers compare as equal, covering the 4-byte loop and a
    /// 1-byte tail.
    #[test]
    fn cmp_slice_eq_smoke() {
        let lhs = [3u8; 17];
        let rhs = [3u8; 17];
        assert_eq!(cmp_slice(&lhs, &rhs), 1);
    }

    /// Encoding a known input produces the expected length and text.
    #[test]
    fn hex_encode_works() {
        let mut encoded = [0u8; 22];
        let written = hex_encode(PLAIN, &mut encoded).unwrap();
        assert_eq!(written, 22);
        assert_eq!(core::str::from_utf8(&encoded).unwrap(), PLAIN_HEX);
    }

    /// Decoding the encoder's output returns the original bytes.
    #[test]
    fn hex_encode_to_decode() {
        let mut encoded = [0u8; 22];
        hex_encode(PLAIN, &mut encoded).unwrap();

        let mut decoded = [0u8; 11];
        let read = hex_decode(&encoded, &mut decoded).unwrap();
        assert_eq!(read, 11);
        assert_eq!(decoded.as_slice(), PLAIN);
    }

    /// Even-length input containing non-hex characters is rejected.
    #[test]
    fn invalid_hex() {
        let mut sink = [0u8; 69];
        let outcome = hex_decode(b"hello world I am not valid hex !!!", &mut sink);
        assert!(outcome.is_err());
    }
}
// Property-based tests (proptest) for the arithmetic, comparison and hex helpers.
// `BoundList<N>` is a project test utility from `crate::aes::test_utils`; it is used here as a
// byte list with `as_slice` / `as_mut_slice` / `len` (presumably bounded by `N` — confirm).
#[cfg(test)]
mod property_tests {
use super::*;
use proptest::prelude::*;
use crate::aes::test_utils::BoundList;
proptest! {
#![proptest_config(ProptestConfig::with_cases(100_000))]
// `add_no_wrap` must report an error exactly when `checked_add` overflows, and must leave
// `a` unchanged exactly when it overflowed or `b` was zero.
// NOTE(review): the test name is spelled `enusre_...` (likely meant `ensure_...`).
#[test]
fn enusre_ct_add_no_wrap(a in any::<u32>(), b in any::<u32>()) {
let (out, res) = add_no_wrap(a, b);
ensure!(( res.is_err() ) <==> ( a.checked_add(b).is_none() ));
ensure!(( out == a ) <==> ( res.is_err() || b == 0 ));
ensure!(( res.is_ok() ) <==> ( out != a || b == 0 ));
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10_000))]
// `cmp_slice` must agree with ordinary slice equality.
#[test]
fn ensure_cmp_slice(a in any::<BoundList<1024>>(), b in any::<BoundList<1024>>()) {
let ct_res = cmp_slice(a.as_slice(), b.as_slice()) == 1;
let res = a.as_slice() == b.as_slice();
ensure!((ct_res) <==> (res));
ensure!((!ct_res) <==> (!res));
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(50_000))]
// `hex_encode` must produce the same bytes and length as the `hex` crate's encoder.
#[test]
fn hex_encode_is_hex_crate(
bin in any::<BoundList<1024>>()
) {
let output_len = hex_encode_len(bin.len());
let mut output = BoundList::<2048>::new_zeroes(output_len);
let res = hex::encode(bin.as_slice());
let len = hex_encode(bin.as_slice(), output.as_mut_slice()).unwrap();
prop_assert_eq!(res.len(), len);
prop_assert_eq!(res.as_bytes(), output.as_slice());
}
// Round trip: decoding the encoder's output must return the original bytes and length.
#[test]
fn hex_is_bijective(
bin in any::<BoundList<1024>>()
) {
let output_len = hex_encode_len(bin.len());
let mut output = BoundList::<2048>::new_zeroes(output_len);
let len = hex_encode(bin.as_slice(), output.as_mut_slice()).unwrap();
prop_assert_eq!(len, output.len());
let mut decoded = bin.create_self();
let len = hex_decode(output.as_slice(), decoded.as_mut_slice()).unwrap();
prop_assert_eq!(len, bin.len());
prop_assert_eq!(decoded.as_slice(), bin.as_slice());
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(1_000_000))]
// The encoder's output must always be valid UTF-8, which is what `hex_encode_str` and
// `hex_encode_alloc` rely on for their unchecked conversions.
#[test]
fn hex_encode_is_valid_utf8(
bin in any::<BoundList<256>>()
) {
let output_len = hex_encode_len(bin.len());
let mut output = BoundList::<512>::new_zeroes(output_len);
hex_encode(bin.as_slice(), output.as_mut_slice()).unwrap();
prop_assert!(core::str::from_utf8(output.as_slice()).is_ok());
}
}
}
// Kani proof harnesses, compiled only under the `kani` cfg. Each harness checks one of the
// helpers above over symbolic (`kani::any`) inputs.
#[cfg(kani)]
mod verify {
use super::*;
use kani::proof;
// For all `a`, `b`: `add_no_wrap` errors exactly when `checked_add` overflows, and returns `a`
// unchanged exactly when it overflowed or `b` is zero.
#[proof]
fn check_ct_add_no_wrap() {
let a = kani::any();
let b = kani::any();
let (out, res) = add_no_wrap(a, b);
ensure!(( res.is_err() ) <==> ( a.checked_add(b).is_none() ));
ensure!(( out == a ) <==> ( res.is_err() || b == 0 ));
ensure!(( res.is_ok() ) <==> ( out != a || b == 0 ));
}
// For all `a`, `b`: `gt(a, b) == 1` exactly when `a > b`.
#[proof]
fn check_ct_gt() {
let a = kani::any();
let b = kani::any();
let is_gt = gt(a, b) == 1;
ensure!((is_gt) <==> (a > b));
ensure!((!is_gt) <==> (a <= b));
}
// For symbolic vectors created with `any_vec::<u8, 7>`: `cmp_slice` agrees with `==`.
#[proof]
#[kani::unwind(12)]
fn check_slice_cmp() {
let a = kani::vec::any_vec::<u8, 7>();
let b = kani::vec::any_vec::<u8, 7>();
let res = a == b;
let ct_res = cmp_slice(a.as_slice(), b.as_slice()) == 1;
ensure!((res) <==> (ct_res));
ensure!((!res) <==> (!ct_res));
}
// Base case: for every byte, `encode_byte` writes two characters accepted by `is_valid_hex`.
#[proof]
fn encode_byte_is_always_utf8() {
let byte: u8 = kani::any();
let mut out = [0u8; 2];
encode_byte(byte, &mut out);
kani::assert(
is_valid_hex_2(out[0], out[1]),
"For all bytes, the output must always be valid UTF8"
);
}
// Step case: encoding a second byte into the following two slots leaves the first pair intact
// and produces a second valid pair.
#[proof]
fn verify_encode_n_plus_1_bytes_symbolic() {
let byte: u8 = kani::any();
let mut out = [0u8; 4];
encode_byte(byte, &mut out);
let byte: u8 = kani::any();
encode_byte(byte, &mut out[2..]);
kani::assert(
is_valid_hex_2(out[0], out[1]) && is_valid_hex_2(out[2], out[3]),
"Encoding an additional byte preserves UTF-8 validity"
);
}
// End-to-end check on a fixed-size symbolic input: all 12 output characters of `hex_encode`
// for any 6-byte input are accepted by `is_valid_hex`.
#[proof]
fn verify_hex_encode_inductive_step_symbolic() {
let input: [u8; 6] = kani::any();
let mut output = [0u8; 12];
hex_encode(&input, &mut output).unwrap();
kani::assert(
is_valid_hex_4(output[0], output[1], output[2], output[3])
&& is_valid_hex_4(output[4], output[5], output[6], output[7])
&& is_valid_hex_4(output[8], output[9], output[10], output[11]),
"Encoding an additional byte preserves UTF-8 validity"
)
}
}