use crate::core::app_errors::{GenerationError, Result};
use rand::{CryptoRng, Rng};
use subtle::{Choice, ConstantTimeEq};
/// Namespace for helpers that try to avoid data-dependent timing: character selection,
/// string comparison, random index generation and shuffling.
///
/// This is a unit struct and carries no data; its functionality lives in associated
/// functions on its `impl` blocks.
pub struct TimingSafeOps;
impl TimingSafeOps {
    /// Number of rejection-sampling draws `secure_random_index` always performs before
    /// falling back to an open-ended loop.
    const HYBRID_ATTEMPTS: usize = 8;

    /// Returns the character at `index % chars.len()`, visiting every element of `chars`.
    ///
    /// Each slot is merged into the result through a mask-based select, so the loop does
    /// not branch on which element is wanted. Returns `None` when `chars` is empty.
    pub fn constant_time_select(chars: &[char], index: usize) -> Option<char> {
        if chars.is_empty() {
            return None;
        }
        let safe_index = index % chars.len();
        let mut result = chars[0];
        for (i, &ch) in chars.iter().enumerate() {
            // `ct_eq` produces the `Choice` without relying on the compiler to lower `==`
            // branch-free (same primitive used for lengths in `constant_time_compare_bytes`).
            let is_target = i.ct_eq(&safe_index);
            result = Self::conditional_select_char(result, ch, is_target);
        }
        Some(result)
    }

    /// Returns `b` when `choice` is set and `a` otherwise, using a bit mask instead of a branch.
    fn conditional_select_char(a: char, b: char, choice: Choice) -> char {
        // 1 -> 0xFFFF_FFFF (keep `b`), 0 -> 0 (keep `a`).
        let mask = u32::from(choice.unwrap_u8()).wrapping_neg();
        let selected = (u32::from(a) & !mask) | (u32::from(b) & mask);
        // `selected` is always exactly `a` or `b`, both valid scalar values, so the
        // fallback never triggers in practice.
        char::from_u32(selected).unwrap_or(a)
    }

    /// Compares two strings for byte equality without exiting early on the first mismatch.
    ///
    /// Equal-length inputs are compared with `subtle`'s slice `ct_eq`. Inputs of different
    /// lengths are still scanned up to the longer length before `false` is returned.
    /// Whether the lengths differ is itself not hidden.
    pub fn constant_time_compare(a: &str, b: &str) -> bool {
        if a.len() != b.len() {
            return Self::constant_time_compare_bytes(a.as_bytes(), b.as_bytes());
        }
        a.as_bytes().ct_eq(b.as_bytes()).unwrap_u8() == 1
    }

    /// Walks both slices to the longer length, treating missing bytes as zero, and folds
    /// the length comparison and every byte comparison into a single `Choice`.
    fn constant_time_compare_bytes(a: &[u8], b: &[u8]) -> bool {
        let len = a.len().max(b.len());
        let mut result = a.len().ct_eq(&b.len());
        for i in 0..len {
            let a_byte = a.get(i).copied().unwrap_or(0);
            let b_byte = b.get(i).copied().unwrap_or(0);
            result &= a_byte.ct_eq(&b_byte);
        }
        result.unwrap_u8() == 1
    }

    /// Performs a fixed amount of throwaway arithmetic.
    ///
    /// Runs four steps of a 64-bit linear congruential recurrence starting from zero.
    /// `black_box` keeps the optimizer from deleting the work. The sequence is
    /// deterministic, so this adds constant work rather than random jitter.
    pub fn add_timing_noise() {
        let mut dummy = 0u64;
        for _ in 0..4 {
            dummy = dummy
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            std::hint::black_box(dummy);
        }
    }

    /// Returns a uniformly distributed index in `0..max` drawn from a cryptographic RNG.
    ///
    /// Uses masked rejection sampling. Each draw is reduced with the smallest all-ones
    /// mask covering `max` and discarded when it is `>= max`, which keeps the result
    /// unbiased. Every draw is accepted with probability above 1/2.
    ///
    /// Exactly `HYBRID_ATTEMPTS` draws are always made, and the first valid one is kept
    /// via a mask rather than a branch. Only if all of them are rejected does an
    /// open-ended rejection loop run.
    ///
    /// Returns `Ok(0)` when `max == 0`, even though 0 is not a valid index into an empty
    /// range; callers must not rely on it as one.
    ///
    /// # Errors
    /// Propagates an RNG failure, mapped to `GenerationError::RngFailure`.
    pub fn secure_random_index<R>(rng: &mut R, max: usize) -> Result<usize>
    where
        R: Rng + CryptoRng + ?Sized,
    {
        if max == 0 {
            return Ok(0);
        }
        // `None` means `max` exceeds the largest power of two representable in `usize`;
        // every bit is then significant.
        let mask = match max.checked_next_power_of_two() {
            Some(power) => power - 1,
            None => usize::MAX,
        };
        let mut selected = 0usize;
        let mut selected_flag = 0u8;
        for _ in 0..Self::HYBRID_ATTEMPTS {
            let random_value = Self::next_random_usize(rng)?;
            let candidate = random_value & mask;
            let valid = (candidate < max) as u8;
            // Take this candidate only if it is valid and nothing was selected yet.
            let take = valid & (selected_flag ^ 1);
            let choose_mask = (take as usize).wrapping_neg();
            selected = (selected & !choose_mask) | (candidate & choose_mask);
            selected_flag |= valid;
        }
        if selected_flag == 1 {
            return Ok(selected);
        }
        loop {
            let random_value = Self::next_random_usize(rng)?;
            let candidate = random_value & mask;
            if candidate < max {
                return Ok(candidate);
            }
        }
    }

    /// Reads `size_of::<usize>()` random bytes from `rng` and interprets them little-endian.
    ///
    /// # Errors
    /// Returns `GenerationError::RngFailure` (converted by `?`) if the RNG fails.
    fn next_random_usize<R>(rng: &mut R) -> Result<usize>
    where
        R: Rng + CryptoRng + ?Sized,
    {
        const USIZE_BYTES: usize = std::mem::size_of::<usize>();
        let mut bytes = [0u8; USIZE_BYTES];
        rng.try_fill_bytes(&mut bytes)
            .map_err(|err| GenerationError::RngFailure(err.to_string()))?;
        Ok(usize::from_le_bytes(bytes))
    }
}
impl TimingSafeOps {
    /// Concatenates `s1` and `s2`, truncated to at most `max_len` characters.
    ///
    /// The loop always runs `max_len` iterations regardless of the input lengths. Once
    /// both inputs are exhausted, the remaining iterations perform a dummy read.
    ///
    /// Note that `max_len` counts `char`s while the preallocated capacity is in bytes, so
    /// multi-byte input may still reallocate.
    pub fn constant_time_concat(s1: &str, s2: &str, max_len: usize) -> String {
        let mut result = String::with_capacity(max_len);
        let mut iter = s1.chars().chain(s2.chars());
        for _ in 0..max_len {
            match iter.next() {
                Some(ch) => result.push(ch),
                None => {
                    // A bare `let _ = result.capacity();` has no observable effect and is
                    // deleted by the optimizer; `black_box` keeps the padding work alive.
                    std::hint::black_box(result.capacity());
                }
            }
        }
        result
    }

    /// Shuffles `items` in place with the Fisher–Yates algorithm.
    ///
    /// For each position `i` from the end down to 1, swaps it with a uniformly chosen
    /// `j` in `0..=i` drawn via `secure_random_index`. Empty and single-element slices
    /// are left unchanged.
    ///
    /// # Errors
    /// Propagates any RNG failure reported by `secure_random_index`; `items` may then be
    /// partially shuffled.
    pub fn secure_shuffle<T, R>(items: &mut [T], rng: &mut R) -> Result<()>
    where
        R: Rng + CryptoRng + ?Sized,
    {
        for i in (1..items.len()).rev() {
            let j = Self::secure_random_index(rng, i + 1)?;
            items.swap(i, j);
        }
        Ok(())
    }
}