use crate::bigint::BigUint;
use core::mem::ManuallyDrop;
use core::ops::{Deref, DerefMut};
use core::sync::atomic::{compiler_fence, Ordering};
/// Types whose sensitive contents can be scrubbed in place.
///
/// Implementations are expected to overwrite secret data rather than merely
/// drop it; how strong that guarantee is varies per impl, so check the impl
/// for the concrete type before relying on it.
pub trait Zeroize {
    /// Overwrite the sensitive contents of `self`.
    fn zeroize(&mut self);
}
/// Store a zero into `b` with a volatile write.
///
/// A volatile store may not be removed by the optimizer even when the byte is
/// never read again, which is exactly the situation when scrubbing secrets.
#[inline]
fn vol_write_byte(b: &mut u8) {
    let target: *mut u8 = b;
    // SAFETY: `target` is derived from a live `&mut u8`, so it is non-null,
    // aligned, and valid for a single-byte write.
    unsafe { target.write_volatile(0) };
}
impl Zeroize for u8 {
fn zeroize(&mut self) {
vol_write_byte(self);
compiler_fence(Ordering::SeqCst);
}
}
impl<const N: usize> Zeroize for [u8; N] {
fn zeroize(&mut self) {
for b in self.iter_mut() {
vol_write_byte(b);
}
compiler_fence(Ordering::SeqCst);
}
}
impl Zeroize for [u8] {
fn zeroize(&mut self) {
for b in self.iter_mut() {
vol_write_byte(b);
}
compiler_fence(Ordering::SeqCst);
}
}
impl Zeroize for Vec<u8> {
    /// Volatile-zero every byte of the allocation (live elements and spare
    /// capacity alike), then reset the length to 0. The capacity is retained.
    fn zeroize(&mut self) {
        let capacity = self.capacity();
        if capacity != 0 {
            let base = self.as_mut_ptr();
            (0..capacity).for_each(|offset| {
                // SAFETY: `offset < capacity`, so `base.add(offset)` stays inside
                // the buffer this Vec owns; writing a `u8` into spare capacity is
                // allowed and merely initializes those bytes.
                unsafe { base.add(offset).write_volatile(0) };
            });
            compiler_fence(Ordering::SeqCst);
        }
        self.clear();
    }
}
impl Zeroize for BigUint {
    /// Replace `self` with `BigUint::zero()` and drop the previous value.
    ///
    /// NOTE(review): this does NOT overwrite the previous value's storage. The
    /// old value is simply dropped, so any heap memory it owned (if any) is
    /// released with its contents intact. A real scrub needs mutable access to
    /// `BigUint`'s internal representation (e.g. its limbs). TODO.
    fn zeroize(&mut self) {
        let me = core::mem::replace(self, BigUint::zero());
        drop(me);
    }
}
impl Zeroize for Vec<BigUint> {
    /// Zeroize and drop every element, then volatile-zero the raw bytes of all
    /// `capacity()` element slots so no stale struct bytes remain in the buffer.
    ///
    /// NOTE(review): the per-element step delegates to `<BigUint as Zeroize>`,
    /// so the overall scrub is only as thorough as that impl.
    fn zeroize(&mut self) {
        self.iter_mut().for_each(<BigUint as Zeroize>::zeroize);
        self.clear();

        let slot_bytes = self.capacity() * core::mem::size_of::<BigUint>();
        if slot_bytes == 0 {
            return;
        }
        let base = self.as_mut_ptr().cast::<u8>();
        for offset in 0..slot_bytes {
            // SAFETY: the Vec is now empty, so every slot is spare capacity; each
            // byte at `offset < capacity * size_of::<BigUint>()` lies inside the
            // allocation and no live value is overwritten.
            unsafe { base.add(offset).write_volatile(0) };
        }
        compiler_fence(Ordering::SeqCst);
    }
}
impl Zeroize for Vec<Vec<BigUint>> {
    /// Zeroize every inner vector (scrubbing its own buffer), drop them, then
    /// volatile-zero the raw bytes of all outer `capacity()` slots so the old
    /// inner-vector headers do not linger in the outer buffer.
    fn zeroize(&mut self) {
        self.iter_mut().for_each(<Vec<BigUint> as Zeroize>::zeroize);
        self.clear();

        let slot_bytes = self.capacity() * core::mem::size_of::<Vec<BigUint>>();
        if slot_bytes == 0 {
            return;
        }
        let base = self.as_mut_ptr().cast::<u8>();
        for offset in 0..slot_bytes {
            // SAFETY: the outer Vec is empty, so every slot is spare capacity and
            // `offset` stays within `capacity * size_of::<Vec<BigUint>>()` bytes
            // of the allocation; no live value is overwritten.
            unsafe { base.add(offset).write_volatile(0) };
        }
        compiler_fence(Ordering::SeqCst);
    }
}
/// Owning wrapper that runs [`Zeroize::zeroize`] on its contents before
/// dropping them.
///
/// The value lives in a [`ManuallyDrop`] so that `Drop` can scrub it first and
/// only afterwards run the inner value's own destructor.
pub struct Zeroizing<T: Zeroize> {
    inner: ManuallyDrop<T>,
}

impl<T: Zeroize> Zeroizing<T> {
    /// Wrap `inner`; it will be zeroized when the wrapper is dropped.
    pub fn new(inner: T) -> Self {
        let inner = ManuallyDrop::new(inner);
        Zeroizing { inner }
    }
}

impl<T: Zeroize> Drop for Zeroizing<T> {
    fn drop(&mut self) {
        // Scrub while the value is still fully alive...
        T::zeroize(&mut self.inner);
        // ...then run its destructor.
        // SAFETY: `inner` is dropped exactly once, here, and never accessed
        // again because `self` is being destroyed.
        unsafe { ManuallyDrop::drop(&mut self.inner) };
    }
}

impl<T: Zeroize> Deref for Zeroizing<T> {
    type Target = T;

    fn deref(&self) -> &T {
        self.inner.deref()
    }
}

impl<T: Zeroize> DerefMut for Zeroizing<T> {
    fn deref_mut(&mut self) -> &mut T {
        self.inner.deref_mut()
    }
}
/// Redacted `Debug`: the wrapped value is never formatted, so secrets cannot
/// leak through `{:?}` / `{:#?}` logging.
///
/// No `T: Debug` bound is needed because the inner value is never printed.
/// Omitting it lets wrappers around non-`Debug` secrets be debug-printed too.
impl<T: Zeroize> core::fmt::Debug for Zeroizing<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Zeroizing(<elided>)")
    }
}
/// Compare two big integers for equality without exiting early on a mismatch.
///
/// Both values are encoded big-endian. The shorter encoding is left-padded with
/// zeros to a common width, and the bytes are folded with an OR-of-XOR
/// accumulator, so every byte is visited regardless of where the first
/// difference is.
///
/// The common width (and therefore the running time) still depends on the
/// operands' encoded lengths. Use `ct_eq_biguint_padded` with a fixed width
/// when the lengths themselves must not leak.
///
/// All temporary byte buffers are zeroized before returning.
#[must_use]
pub fn ct_eq_biguint(a: &BigUint, b: &BigUint) -> bool {
    let mut lhs_raw = a.to_be_bytes();
    let mut rhs_raw = b.to_be_bytes();
    let width = core::cmp::max(lhs_raw.len(), rhs_raw.len());

    let mut lhs = vec![0u8; width];
    let mut rhs = vec![0u8; width];
    lhs[width - lhs_raw.len()..].copy_from_slice(&lhs_raw);
    rhs[width - rhs_raw.len()..].copy_from_slice(&rhs_raw);
    lhs_raw.zeroize();
    rhs_raw.zeroize();

    let diff = lhs
        .iter()
        .zip(rhs.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    let equal = diff == 0;

    lhs.zeroize();
    rhs.zeroize();
    equal
}
/// Compare two big integers over a fixed big-endian width of `byte_width`
/// bytes, visiting every byte with an OR-of-XOR accumulator.
///
/// Because the width is supplied by the caller rather than derived from the
/// operands, the amount of work does not depend on the values' magnitudes.
/// Both temporary encodings are zeroized before returning.
///
/// # Panics
/// If either value does not fit in `byte_width` bytes. This panic is raised by
/// `BigUint::to_be_bytes_padded`, as exercised by this module's tests.
#[must_use]
pub fn ct_eq_biguint_padded(a: &BigUint, b: &BigUint, byte_width: usize) -> bool {
    let mut lhs = a.to_be_bytes_padded(byte_width);
    let mut rhs = b.to_be_bytes_padded(byte_width);
    let diff = lhs[..byte_width]
        .iter()
        .zip(&rhs[..byte_width])
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    lhs.zeroize();
    rhs.zeroize();
    diff == 0
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeroize_byte_array() {
        let mut a = [0xFFu8; 16];
        a.zeroize();
        assert_eq!(a, [0u8; 16]);
    }

    #[test]
    fn zeroize_single_byte_and_sub_slice() {
        let mut b = 0x5Au8;
        b.zeroize();
        assert_eq!(b, 0);

        // Only the addressed sub-slice is scrubbed; its neighbours are untouched.
        let mut buf = [0x11u8, 0x22, 0x33, 0x44];
        buf[1..3].zeroize();
        assert_eq!(buf, [0x11, 0, 0, 0x44]);
    }

    #[test]
    fn zeroize_vec_u8_clears_and_truncates() {
        let mut v = vec![0xABu8; 32];
        // Grow the allocation so there is spare capacity beyond `len` as well.
        v.reserve(16);
        let cap = v.capacity();
        v.zeroize();
        assert!(v.is_empty(), "Zeroize<Vec<u8>> should clear the logical length");
        assert_eq!(v.capacity(), cap, "Zeroize<Vec<u8>> should keep the allocation");
        // SAFETY: `zeroize` wrote 0 to every byte in `0..cap` through the
        // buffer pointer, so all `cap` bytes are initialized, and `cap` equals
        // the current capacity.
        unsafe { v.set_len(cap) };
        assert!(
            v.iter().all(|&b| b == 0),
            "every byte up to capacity must have been scrubbed"
        );
    }

    #[test]
    fn zeroize_vec_biguint_replaces_each_with_zero_then_clears() {
        let mut v: Vec<BigUint> = (1..=4).map(BigUint::from_u64).collect();
        v.zeroize();
        assert!(v.is_empty());
    }

    #[test]
    fn zeroizing_drops_with_scrub() {
        // The buffer is freed once `z` leaves scope, so the scrub itself cannot
        // be observed here; this exercises DerefMut and checks that Drop
        // (zeroize + inner drop) runs without panicking.
        {
            let mut z = Zeroizing::new(vec![0xCDu8; 64]);
            z[0] = 0xEF;
            assert_eq!(z[0], 0xEF);
        }
    }

    #[test]
    fn ct_eq_matches_eq_for_equal_values() {
        for n in [0u64, 1, 0xFF, 0xFFFF, 0xCAFE_BABE_DEAD_BEEF] {
            let a = BigUint::from_u64(n);
            let b = BigUint::from_u64(n);
            assert!(ct_eq_biguint(&a, &b), "equal values must compare true");
        }
    }

    #[test]
    fn ct_eq_distinguishes_unequal_values_of_same_bit_length() {
        let a = BigUint::from_u64(0x8000_0000);
        let b = BigUint::from_u64(0x8000_0001);
        assert!(!ct_eq_biguint(&a, &b));
    }

    #[test]
    fn ct_eq_handles_different_byte_lengths() {
        let small = BigUint::from_u64(7);
        let mut big = BigUint::one();
        big.shl_bits(80);
        big = big.add_ref(&small);
        assert!(!ct_eq_biguint(&small, &big));
        let small2 = BigUint::from_u64(7);
        assert!(ct_eq_biguint(&small, &small2));
    }

    #[test]
    fn ct_eq_padded_constant_byte_width() {
        let zero = BigUint::zero();
        let one = BigUint::one();
        assert!(ct_eq_biguint_padded(&zero, &zero, 16));
        assert!(!ct_eq_biguint_padded(&zero, &one, 16));
    }

    #[test]
    #[should_panic(expected = "value does not fit in 8 bytes")]
    fn ct_eq_padded_panics_on_oversize_operand() {
        let big = {
            let mut v = BigUint::one();
            v.shl_bits(80);
            v
        };
        let _ = ct_eq_biguint_padded(&big, &big, 8);
    }

    #[test]
    fn ct_eq_handles_zero_and_one() {
        assert!(ct_eq_biguint(&BigUint::zero(), &BigUint::zero()));
        assert!(ct_eq_biguint(&BigUint::one(), &BigUint::one()));
        assert!(!ct_eq_biguint(&BigUint::zero(), &BigUint::one()));
    }

    #[test]
    fn end_to_end_secret_round_trip_with_security_layer() {
        use crate::field::{mersenne127, PrimeField};
        use crate::{cgma_vss, csprng::ChaCha20Rng, proactive, shamir, vss};
        let f = PrimeField::new(mersenne127());
        let mut rng = ChaCha20Rng::from_seed(&[0x99u8; 32]);
        let secret = BigUint::from_u64(0xCAFE_BABE);
        let shares = shamir::split(&f, &mut rng, &secret, 3, 5);
        let recovered = shamir::reconstruct(&f, &shares, 3).expect("shamir round-trip");
        assert!(
            ct_eq_biguint(&recovered, &secret),
            "shamir secret survives Zeroizing+ct_eq layer"
        );
        let vss_shares = vss::deal(&f, &mut rng, &secret, 3, 5);
        assert!(vss::verify_consistent(&f, &vss_shares));
        let vss_recovered =
            vss::reconstruct(&f, &vss_shares[..3], 3).expect("vss round-trip");
        assert!(ct_eq_biguint(&vss_recovered, &secret));
        let group = cgma_vss::small_test_group();
        let small_secret = BigUint::from_u64(7);
        let (cgma_shares, commits) = cgma_vss::deal(&group, &mut rng, &small_secret, 3, 5);
        for s in &cgma_shares {
            assert!(cgma_vss::verify_share(&group, &commits, s));
        }
        let cgma_recovered =
            cgma_vss::reconstruct(&group, &cgma_shares[..3], 3).expect("cgma round-trip");
        assert!(ct_eq_biguint(&cgma_recovered, &small_secret));
        let refreshed = proactive::refresh(&f, &mut rng, &shares, 3);
        let refreshed_recovered =
            shamir::reconstruct(&f, &refreshed[..3], 3).expect("post-refresh round-trip");
        assert!(
            ct_eq_biguint(&refreshed_recovered, &secret),
            "secret preserved across refresh + reconstruct"
        );
    }

    #[test]
    fn debug_does_not_leak_inner_value() {
        let z = Zeroizing::new(vec![0xDEu8; 8]);
        let printed = format!("{:?}", z);
        assert_eq!(printed, "Zeroizing(<elided>)");
        assert!(printed.contains("elided"));
        assert!(!printed.contains("0xDE"));
        assert!(!printed.contains("222"));
    }
}