use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
/// Width, in bytes, of each zero-padded chunk compared by `ash_timing_safe_equal`.
const CHUNK_SIZE: usize = 256;
/// Number of chunks `ash_timing_safe_equal` processes for every input no
/// longer than `FIXED_WORK_SIZE` bytes.
const FIXED_ITERATIONS: usize = 8;
/// Byte span covered by `FIXED_ITERATIONS` chunks (2048). Inputs no longer
/// than this are compared with an identical amount of work.
const FIXED_WORK_SIZE: usize = FIXED_ITERATIONS * CHUNK_SIZE;
/// Compares two byte slices for equality without short-circuiting on the
/// first mismatching byte.
///
/// Both inputs are walked in `CHUNK_SIZE`-byte chunks. Each chunk is
/// zero-padded to full width and compared with `subtle`'s constant-time
/// equality, and the per-chunk results are accumulated in a `Choice`.
/// At least `FIXED_ITERATIONS` chunks (`FIXED_WORK_SIZE` bytes) are always
/// processed, so every input up to that size costs the same amount of work.
/// Longer inputs are compared in full, with work proportional to the longer
/// of the two lengths.
///
/// Zero padding alone would make e.g. `b"ab"` and `b"ab\0"` chunk-equal, so
/// the lengths are compared separately and folded into the final result.
///
/// Lengths are treated as public information: once an input exceeds
/// `FIXED_WORK_SIZE`, the number of chunks processed depends on
/// `max(a.len(), b.len())`.
pub fn ash_timing_safe_equal(a: &[u8], b: &[u8]) -> bool {
    let lengths_equal: Choice = (a.len() as u64).ct_eq(&(b.len() as u64));

    // Cover every byte of the longer input, but never less than
    // FIXED_WORK_SIZE. Capping the walk at FIXED_WORK_SIZE would let
    // equal-length inputs that differ only beyond that point compare equal.
    let work_len = std::cmp::max(std::cmp::max(a.len(), b.len()), FIXED_WORK_SIZE);

    let mut result = Choice::from(1u8);
    for pos in (0..work_len).step_by(CHUNK_SIZE) {
        // When the lengths are equal, chunks past the end are all-zero on both
        // sides and compare equal. When they differ, `lengths_equal` already
        // forces `false`, so every chunk can be folded in unconditionally.
        result &= padded_chunk(a, pos).ct_eq(&padded_chunk(b, pos));
    }
    (lengths_equal & result).into()
}

/// Copies up to `CHUNK_SIZE` bytes of `data` starting at `pos` into a
/// zero-filled buffer. Positions at or past the end of `data` stay zero.
fn padded_chunk(data: &[u8], pos: usize) -> [u8; CHUNK_SIZE] {
    let mut buf = [0u8; CHUNK_SIZE];
    if let Some(rest) = data.get(pos..) {
        let take = rest.len().min(CHUNK_SIZE);
        buf[..take].copy_from_slice(&rest[..take]);
    }
    buf
}
/// Constant-time equality for two byte slices whose lengths are expected to
/// match.
///
/// Returns `true` only when the slices have the same length and identical
/// contents. The content comparison uses `subtle`'s `ConstantTimeEq`.
///
/// # Panics
///
/// In debug builds, panics via `debug_assert_eq!` if `a.len() != b.len()`,
/// because the caller violated the fixed-length contract. In release builds
/// the assertion is compiled out, and mismatched lengths simply return `false`.
#[inline]
// NOTE(review): presumably silences the unused-function lint when this module
// is not exported from the crate; confirm before removing.
#[allow(dead_code)]
pub fn ash_timing_safe_equal_fixed_length(a: &[u8], b: &[u8]) -> bool {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "ash_timing_safe_equal_fixed_length called with different lengths"
    );
    // Length is treated as public: mismatched lengths return immediately
    // without touching the contents.
    if a.len() != b.len() {
        return false;
    }
    a.ct_eq(b).into()
}
/// String convenience wrapper around `ash_timing_safe_equal`.
///
/// Compares the UTF-8 byte representations of `a` and `b` exactly. No Unicode
/// normalization or case folding is applied, so strings made of different
/// code point sequences compare unequal.
pub fn ash_timing_safe_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    ash_timing_safe_equal(lhs, rhs)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timing_safe_equal_same() {
        assert!(ash_timing_safe_equal(b"hello world", b"hello world"));
    }

    #[test]
    fn test_timing_safe_equal_different() {
        // Same length, only the final byte differs.
        assert!(!ash_timing_safe_equal(b"hello world", b"hello worle"));
    }

    #[test]
    fn test_timing_safe_equal_different_length() {
        // A strict prefix must never compare equal to the longer input.
        assert!(!ash_timing_safe_equal(b"hello", b"hello world"));
    }

    #[test]
    fn test_timing_safe_equal_empty() {
        assert!(ash_timing_safe_equal(b"", b""));
    }

    #[test]
    fn test_ash_timing_safe_compare() {
        assert!(ash_timing_safe_compare("test", "test"));
        // The comparison is byte-exact, hence case-sensitive.
        assert!(!ash_timing_safe_compare("test", "Test"));
    }

    #[test]
    fn test_timing_safe_equal_fixed_length() {
        // Two distinct 64-character hex strings of equal length.
        let digest = b"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        let same_digest = b"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        let other_digest = b"d7a8fbb307d7809469ca9abcb0082e4f8d5651e46d3cdb762d02d0bf37c9e592";
        assert!(ash_timing_safe_equal_fixed_length(digest, same_digest));
        assert!(!ash_timing_safe_equal_fixed_length(digest, other_digest));
    }

    #[test]
    fn test_empty_vs_nonempty() {
        let (empty, single): (&[u8], &[u8]) = (b"", b"x");
        // Must be false regardless of argument order.
        assert!(!ash_timing_safe_equal(empty, single));
        assert!(!ash_timing_safe_equal(single, empty));
    }

    #[test]
    fn test_single_byte_difference() {
        assert!(!ash_timing_safe_equal(b"aaaaaaaaaa", b"aaaaaaaaab"));
    }
}