use subtle::{Choice, ConstantTimeEq};
/// Compares two byte slices for equality.
///
/// Slices of different lengths are rejected immediately, so the *lengths* of
/// the inputs are not hidden by this function. When the lengths match, the
/// content comparison is delegated to [`ConstantTimeEq::ct_eq`] from `subtle`.
///
/// Returns `true` only when both slices have the same length and contents.
#[inline]
pub fn ct_eq(a: &[u8], b: &[u8]) -> bool {
    // `&&` short-circuits, so the content comparison is skipped on a length mismatch.
    let same_len = a.len() == b.len();
    same_len && bool::from(a.ct_eq(b))
}
#[inline]
pub fn ct_eq_32(a: &[u8; 32], b: &[u8; 32]) -> bool {
a.ct_eq(b).into()
}
#[inline]
pub fn ct_eq_64(a: &[u8; 64], b: &[u8; 64]) -> bool {
a.ct_eq(b).into()
}
/// Compares a byte slice of arbitrary length against a 32-byte array.
///
/// Returns `false` straight away when `slice` is not exactly 32 bytes long
/// (so the slice length is not hidden). Otherwise the content comparison is
/// delegated to [`ConstantTimeEq::ct_eq`] from `subtle`.
#[inline]
pub fn ct_eq_slice_32(slice: &[u8], array: &[u8; 32]) -> bool {
    match slice.len() {
        32 => bool::from(slice.ct_eq(array.as_slice())),
        _ => false,
    }
}
/// Returns `a` when `condition` is `true`, otherwise `b`.
///
/// NOTE(review): despite the `ct_` prefix, this implementation selects the
/// value with an ordinary branch on the unwrapped [`Choice`], so it gives no
/// constant-time guarantee. A branch-free version would need a bound such as
/// `subtle::ConditionallySelectable` on `T`, which would change the public
/// signature.
#[inline]
pub fn ct_select<T: Copy + Default>(condition: bool, a: T, b: T) -> T {
    let selector = Choice::from(condition as u8);
    match selector.unwrap_u8() {
        1 => a,
        _ => b,
    }
}
/// Owned byte buffer intended to hold secret material.
///
/// The bytes live in a private `Vec<u8>` and are only exposed read-only
/// through [`SecretBytes::as_bytes`].
#[derive(Clone)]
pub struct SecretBytes(Vec<u8>);

impl SecretBytes {
    /// Wraps an existing vector, taking ownership of it without copying.
    pub fn new(bytes: Vec<u8>) -> Self {
        SecretBytes(bytes)
    }

    /// Creates a new instance holding a copy of `slice`.
    pub fn from_slice(slice: &[u8]) -> Self {
        Self::new(Vec::from(slice))
    }

    /// Borrows the stored bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.0.as_slice()
    }

    /// Number of stored bytes.
    pub fn len(&self) -> usize {
        self.as_bytes().len()
    }

    /// Returns `true` when no bytes are stored.
    pub fn is_empty(&self) -> bool {
        self.as_bytes().is_empty()
    }
}
/// Equality on the wrapped bytes, delegated to the `subtle` byte-slice
/// implementation of [`ConstantTimeEq`].
impl ConstantTimeEq for SecretBytes {
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs.ct_eq(rhs)
    }
}
impl PartialEq for SecretBytes {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
impl Eq for SecretBytes {}
/// Best-effort wipe of the secret bytes when the value is dropped.
///
/// The whole allocation (`capacity()`, not only `len()`) is overwritten with
/// zeros, so bytes left in the vector's spare capacity (for example after the
/// caller truncated the vector before handing it to `SecretBytes::new`) are
/// cleared as well.
impl Drop for SecretBytes {
    fn drop(&mut self) {
        let capacity = self.0.capacity();
        let base = self.0.as_mut_ptr();
        for offset in 0..capacity {
            // SAFETY: `base` points to the vector's allocation, which is valid for
            // writes of `capacity` bytes; `offset < capacity`, so the write stays in
            // bounds. `u8` has no drop glue and any byte value is valid, so writing
            // into the (possibly uninitialized) spare capacity is sound. When the
            // capacity is zero the loop body never runs and the dangling pointer is
            // never dereferenced.
            unsafe {
                std::ptr::write_volatile(base.add(offset), 0);
            }
        }
        // Volatile writes are used so the stores are not optimized away; the fence
        // is intended to keep the compiler from reordering them past the
        // deallocation that follows.
        std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
    }
}
/// Debug output that reports only the byte count, never the contents.
impl std::fmt::Debug for SecretBytes {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let byte_count = self.0.len();
        f.write_fmt(format_args!(
            "SecretBytes([REDACTED; {} bytes])",
            byte_count
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical slices compare equal.
    #[test]
    fn test_ct_eq_equal() {
        let left = [1u8, 2, 3, 4, 5];
        let right = [1u8, 2, 3, 4, 5];
        assert!(ct_eq(&left, &right));
    }

    /// Slices of equal length that differ in one byte compare unequal.
    #[test]
    fn test_ct_eq_not_equal() {
        let left = [1u8, 2, 3, 4, 5];
        let right = [1u8, 2, 3, 4, 6];
        assert!(!ct_eq(&left, &right));
    }

    /// A length mismatch is always unequal.
    #[test]
    fn test_ct_eq_different_length() {
        let longer = [1u8, 2, 3, 4, 5];
        let shorter = [1u8, 2, 3, 4];
        assert!(!ct_eq(&longer, &shorter));
    }

    /// Fixed-size 32-byte comparison: equal and unequal cases.
    #[test]
    fn test_ct_eq_32() {
        let reference = [42u8; 32];
        let same = [42u8; 32];
        let different = [43u8; 32];
        assert!(ct_eq_32(&reference, &same));
        assert!(!ct_eq_32(&reference, &different));
    }

    /// `==` / `!=` on `SecretBytes` follow the byte contents.
    #[test]
    fn test_secret_bytes_equality() {
        let first = SecretBytes::new(vec![1, 2, 3, 4]);
        let twin = SecretBytes::new(vec![1, 2, 3, 4]);
        let other = SecretBytes::new(vec![1, 2, 3, 5]);
        assert_eq!(first, twin);
        assert_ne!(first, other);
    }

    /// Debug formatting redacts the contents.
    #[test]
    fn test_secret_bytes_debug() {
        let secret = SecretBytes::new(vec![1, 2, 3, 4]);
        let rendered = format!("{:?}", secret);
        assert!(rendered.contains("REDACTED"));
        assert!(!rendered.contains("1, 2, 3, 4"));
    }
}