use subtle::ConstantTimeEq;
/// Serializes four 64-bit words into a single 32-byte buffer.
///
/// Each word is written in little-endian order, and the words are laid out
/// back to back in argument order: `a` fills bytes `0..8`, `b` fills `8..16`,
/// `c` fills `16..24` and `d` fills `24..32`.
#[inline]
fn u64x4_to_bytes(a: u64, b: u64, c: u64, d: u64) -> [u8; 32] {
    let mut buffer = [0u8; 32];
    let words = [a, b, c, d];
    for (slot, word) in buffer.chunks_exact_mut(8).zip(words.iter()) {
        slot.copy_from_slice(&word.to_le_bytes());
    }
    buffer
}
/// Compares two checksums using `subtle`'s `ConstantTimeEq`.
///
/// Both checksums are flattened to 32 bytes: the four words returned by
/// `first()`, `second()`, `third()` and `fourth()`, each in little-endian
/// order. The two buffers are then compared with `ct_eq`, so the comparison
/// inherits `subtle`'s constant-time properties.
///
/// Returns `true` only when all four words of `computed` equal those of `stored`.
pub fn constant_time_checksum_eq(
    computed: &crate::integrity::checksum::Checksum,
    stored: &crate::integrity::checksum::Checksum,
) -> bool {
    // One serializer shared by both sides keeps the word order identical.
    let serialize = |checksum: &crate::integrity::checksum::Checksum| {
        u64x4_to_bytes(checksum.first(), checksum.second(), checksum.third(), checksum.fourth())
    };
    let lhs = serialize(computed);
    let rhs = serialize(stored);
    bool::from(lhs[..].ct_eq(&rhs[..]))
}
/// Compares two four-word arrays using `subtle`'s `ConstantTimeEq`.
///
/// Each array is serialized to 32 little-endian bytes, and the two buffers are
/// compared with `ct_eq`.
///
/// Returns `true` only when every element of `computed` equals the element at
/// the same index in `stored`.
pub fn constant_time_u64_array_eq(computed: &[u64; 4], stored: &[u64; 4]) -> bool {
    let [c0, c1, c2, c3] = *computed;
    let [s0, s1, s2, s3] = *stored;
    let lhs = u64x4_to_bytes(c0, c1, c2, c3);
    let rhs = u64x4_to_bytes(s0, s1, s2, s3);
    bool::from(lhs[..].ct_eq(&rhs[..]))
}
pub fn constant_time_bytes_eq(a: &[u8], b: &[u8]) -> bool {
assert_eq!(
a.len(),
b.len(),
"constant_time_bytes_eq: slices must have equal length"
);
a.ct_eq(b).into()
}
/// Unit tests for the constant-time comparison helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_u64x4_to_bytes_little_endian_layout() {
        // Each word is serialized little-endian and the words appear in argument order.
        let expected: [u8; 32] = [
            0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, //
            0x18, 0x17, 0x16, 0x15, 0x14, 0x13, 0x12, 0x11, //
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //
            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        ];
        assert_eq!(
            u64x4_to_bytes(0x0102030405060708, 0x1112131415161718, 0, u64::MAX),
            expected
        );
    }

    #[test]
    fn test_constant_time_u64_array_eq_matching() {
        let a = [
            0x1234567890abcdef,
            0xfedcba0987654321,
            0x1111222233334444,
            0x5555666677778888,
        ];
        let b = [
            0x1234567890abcdef,
            0xfedcba0987654321,
            0x1111222233334444,
            0x5555666677778888,
        ];
        assert!(constant_time_u64_array_eq(&a, &b));
    }

    #[test]
    fn test_constant_time_u64_array_eq_mismatch_first() {
        let a = [
            0xFFFFFFFFFFFFFFFF,
            0xfedcba0987654321,
            0x1111222233334444,
            0x5555666677778888,
        ];
        let b = [
            0x1234567890abcdef,
            0xfedcba0987654321,
            0x1111222233334444,
            0x5555666677778888,
        ];
        assert!(!constant_time_u64_array_eq(&a, &b));
    }

    #[test]
    fn test_constant_time_u64_array_eq_mismatch_middle() {
        // Differences in the second and third words must also be detected.
        let base = [1u64, 2, 3, 4];
        assert!(!constant_time_u64_array_eq(&base, &[1, 0xFF, 3, 4]));
        assert!(!constant_time_u64_array_eq(&base, &[1, 2, 0xFF, 4]));
    }

    #[test]
    fn test_constant_time_u64_array_eq_mismatch_last() {
        let a = [
            0x1234567890abcdef,
            0xfedcba0987654321,
            0x1111222233334444,
            0x5555666677778888,
        ];
        let b = [
            0x1234567890abcdef,
            0xfedcba0987654321,
            0x1111222233334444,
            0xFFFFFFFFFFFFFFFF,
        ];
        assert!(!constant_time_u64_array_eq(&a, &b));
    }

    #[test]
    fn test_constant_time_u64_array_eq_single_bit_difference() {
        // Flipping only the most significant bit of the last word must be detected.
        let a = [1u64, 2, 3, 4];
        let b = [1u64, 2, 3, 4 ^ (1u64 << 63)];
        assert!(!constant_time_u64_array_eq(&a, &b));
    }

    #[test]
    fn test_constant_time_bytes_eq_matching() {
        let a = b"constant_time_comparison_test";
        let b = b"constant_time_comparison_test";
        assert!(constant_time_bytes_eq(a, b));
    }

    #[test]
    fn test_constant_time_bytes_eq_empty() {
        // Two empty slices have equal length and no differing bytes.
        assert!(constant_time_bytes_eq(&[], &[]));
    }

    #[test]
    fn test_constant_time_bytes_eq_mismatch() {
        // Both literals are 28 bytes long, so this exercises a content mismatch,
        // not the length assertion.
        let a = b"constant_time_compare_test_1";
        let b = b"different_compare_test_data2";
        assert!(!constant_time_bytes_eq(a, b));
    }

    #[test]
    #[should_panic(expected = "slices must have equal length")]
    fn test_constant_time_bytes_eq_length_mismatch() {
        let a = b"short";
        let b = b"much_longer_string";
        constant_time_bytes_eq(a, b);
    }

    /// Checks that a difference is reported whether it sits in the first or the
    /// last byte. This is a functional check only; it does not measure timing.
    #[test]
    fn test_constant_time_comparison_is_not_short_circuit() {
        let a = [0u8; 32];
        let mut b = [0u8; 32];
        b[0] = 0xFF;
        let result: bool = a.ct_eq(&b).into();
        assert!(!result);
        b[0] = 0;
        b[31] = 0xFF;
        let result: bool = a.ct_eq(&b).into();
        assert!(!result);
    }
}