use derive_deftly::{Deftly, define_derive_deftly};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
use zeroize::Zeroize;
#[cfg(feature = "memquota-memcost")]
use tor_memquota_cost::derive_deftly_template_HasMemoryCost;
// Derive-deftly template implementing `subtle::ConstantTimeEq` for a struct.
//
// Every field is compared with that field type's own `ct_eq`. The per-field
// `Choice`s are combined with `&`, which is a bitwise operator on `Choice`
// and does not short-circuit. Each field type must implement
// `ConstantTimeEq`; the extra where-clause bounds below enforce this. A
// struct with no fields always compares equal.
define_derive_deftly! {
export ConstantTimeEq for struct:
impl<$tgens> ConstantTimeEq for $ttype
where $twheres
// Require every field type to support constant-time comparison.
$( $ftype : ConstantTimeEq , )
{
fn ct_eq(&self, other: &Self) -> subtle::Choice {
// Destructure both operands so each field is bound as
// `self_<field>` / `other_<field>`. There is one arm per variant;
// a struct has exactly one.
match (self, other) {
$(
(${vpat fprefix=self_}, ${vpat fprefix=other_}) => {
// Expands to `f1 & f2 & ... & Choice::from(1)`.
// The trailing `Choice::from(1)` ("true") terminates the
// chain and makes a field-less struct compare equal.
$(
$<self_ $fname>.ct_eq($<other_ $fname>) &
)
subtle::Choice::from(1)
},
)
}
}
}
}
// Derive-deftly template implementing `PartialEq` on top of the type's
// existing `ConstantTimeEq` implementation.
//
// With this template, `==` goes through `ct_eq`. Only the final conversion
// of the resulting `Choice` into a `bool` depends on the outcome.
define_derive_deftly! {
export PartialEqFromCtEq:
impl<$tgens> PartialEq for $ttype
where $twheres
// The type must already implement `ConstantTimeEq`, for example via
// the `ConstantTimeEq` template defined in this file.
$ttype : ConstantTimeEq
{
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
}
pub(crate) use {derive_deftly_template_ConstantTimeEq, derive_deftly_template_PartialEqFromCtEq};
/// A fixed-size array of `N` bytes whose equality is checked with
/// `subtle`'s constant-time comparison.
///
/// `PartialEq`/`Eq` are written by hand on top of [`ConstantTimeEq`], so
/// `==` does not stop at the first differing byte. `Ord` is also written by
/// hand to avoid early exits while scanning.
///
/// The type derives `Zeroize`, so its bytes can be wiped explicitly.
/// NOTE: it is also `Copy`, so wiping one value does not wipe copies of it.
// `Hash` is derived while `PartialEq` is implemented manually. This is
// consistent: the manual `eq` compares exactly the wrapped bytes, and the
// derived `Hash` hashes exactly those bytes. Equal values therefore always
// hash equally, which is what this clippy lint guards against.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Copy, Debug, Hash, Zeroize)]
// With the `memquota-memcost` feature, also derive `HasMemoryCost`.
#[cfg_attr(
feature = "memquota-memcost",
derive(Deftly),
derive_deftly(HasMemoryCost)
)]
pub struct CtByteArray<const N: usize>([u8; N]);
/// Constant-time equality: delegates to `subtle`'s comparison of the two
/// wrapped byte arrays.
impl<const N: usize> ConstantTimeEq for CtByteArray<N> {
    fn ct_eq(&self, other: &Self) -> Choice {
        let (mine, theirs) = (&self.0, &other.0);
        mine.ct_eq(theirs)
    }
}
/// `==` is routed through [`ConstantTimeEq::ct_eq`], so the comparison
/// itself does not short-circuit. Only the final conversion to `bool`
/// depends on the result.
impl<const N: usize> PartialEq for CtByteArray<N> {
    fn eq(&self, other: &Self) -> bool {
        let verdict = self.ct_eq(other);
        bool::from(verdict)
    }
}

/// Byte-wise equality is a full equivalence relation, so `Eq` holds.
impl<const N: usize> Eq for CtByteArray<N> {}
impl<const N: usize> From<[u8; N]> for CtByteArray<N> {
fn from(value: [u8; N]) -> Self {
Self(value)
}
}
impl<const N: usize> From<CtByteArray<N>> for [u8; N] {
fn from(value: CtByteArray<N>) -> Self {
value.0
}
}
/// Lexicographic ordering, matching the ordering of the underlying
/// `[u8; N]` arrays.
///
/// The scan visits every byte pair and never returns early. The first
/// nonzero byte difference is recorded with `subtle`'s `conditional_assign`
/// rather than a data-dependent branch. Only the final comparison of that
/// difference against zero branches on the data. The intent is that the
/// less/equal/greater result is revealed but not the position of the first
/// differing byte. This is subject to what the compiler actually emits.
impl<const N: usize> Ord for CtByteArray<N> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// Holds `a - b` for the first position where the bytes differ, and
// stays 0 if the arrays are identical. `i16` holds any difference of
// two `u8`s (-255..=255) without overflow.
let mut first_nonzero_difference = 0_i16;
for (a, b) in self.0.iter().zip(other.0.iter()) {
let difference = i16::from(*a) - i16::from(*b);
// Overwrite only while no difference has been recorded yet, so
// later differences never replace the first one.
first_nonzero_difference
.conditional_assign(&difference, first_nonzero_difference.ct_eq(&0));
}
// The sign of the first difference determines the ordering.
first_nonzero_difference.cmp(&0)
}
}
/// The ordering is total, so this always defers to [`Ord::cmp`].
impl<const N: usize> PartialOrd for CtByteArray<N> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl<const N: usize> AsRef<[u8; N]> for CtByteArray<N> {
fn as_ref(&self) -> &[u8; N] {
&self.0
}
}
impl<const N: usize> AsMut<[u8; N]> for CtByteArray<N> {
fn as_mut(&mut self) -> &mut [u8; N] {
&mut self.0
}
}
pub fn ct_lookup<T, F>(array: &[T], matches: F) -> Option<&T>
where
F: Fn(&T) -> Choice,
{
let mut idx: u64 = 0;
let mut found: Choice = 0.into();
for (i, x) in array.iter().enumerate() {
let equal = matches(x);
idx.conditional_assign(&(i as u64), equal);
found.conditional_assign(&equal, equal);
}
if found.into() {
Some(&array[idx as usize])
} else {
None
}
}
// Unit tests for `CtByteArray` equality/ordering and for `ct_lookup`.
#[cfg(test)]
mod test {
// Lints relaxed for test code, where unwraps, clones and debug output are acceptable.
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::mixed_attributes_style)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_time_subtraction)]
#![allow(clippy::useless_vec)]
#![allow(clippy::needless_pass_by_value)]
use super::*;
use rand::Rng;
use tor_basic_utils::test_rng;
// Sorts 200 random 32-byte arrays. It then checks that `==`, `<` and `>` on
// `CtByteArray` agree with each other, and that `cmp` agrees with the
// ordering of the underlying `[u8; 32]` arrays.
// The strict `<` / `assert_ne!` checks assume the random arrays are pairwise
// distinct; with 32 random bytes each, a collision is vanishingly unlikely.
// `!(a < a)` is written deliberately to exercise `<`/`>` directly, which is
// why `nonminimal_bool` is allowed.
#[allow(clippy::nonminimal_bool)]
#[test]
fn test_comparisons() {
let num = 200;
let mut rng = test_rng::testing_rng();
let mut array: Vec<CtByteArray<32>> =
(0..num).map(|_| rng.random::<[u8; 32]>().into()).collect();
array.sort();
for i in 0..num {
assert_eq!(array[i], array[i]);
assert!(!(array[i] < array[i]));
assert!(!(array[i] > array[i]));
for j in (i + 1)..num {
assert!(array[i] < array[j]);
assert_ne!(array[i], array[j]);
assert!(array[j] > array[i]);
// Cross-check against the plain byte-array ordering.
assert_eq!(
array[i].cmp(&array[j]),
array[j].as_ref().cmp(array[i].as_ref()).reverse()
);
}
}
}
// Every item has a distinct length (3, 4, 2, 5, 6), so each length query
// matches at most one item; length 99 matches none.
#[test]
fn test_lookup() {
use super::ct_lookup as lookup;
use subtle::ConstantTimeEq;
let items = vec![
"One".to_string(),
"word".to_string(),
"of".to_string(),
"every".to_string(),
"length".to_string(),
];
let of_word = lookup(&items[..], |i| i.len().ct_eq(&2));
let every_word = lookup(&items[..], |i| i.len().ct_eq(&5));
let no_word = lookup(&items[..], |i| i.len().ct_eq(&99));
assert_eq!(of_word.unwrap(), "of");
assert_eq!(every_word.unwrap(), "every");
assert_eq!(no_word, None);
}
}