use core::{
fmt::{Debug, Display, Formatter},
mem::ManuallyDrop,
};
use ctutils::CtEq;
use zeroize::{Zeroize, ZeroizeOnDrop};
#[inline]
unsafe fn zeroize_ptr<T>(ptr: *mut T) {
let slice = core::slice::from_raw_parts_mut(ptr as *mut u8, core::mem::size_of::<T>());
slice.zeroize();
}
/// Wrapper around a sensitive value of type `T`.
///
/// - The `Debug` and `Display` impls print a fixed `[REDACTED]` placeholder instead of the value.
/// - The `Drop` impl runs `T`'s destructor and then zeroizes the `size_of::<T>()` bytes that held
///   the value. Only those inline bytes are wiped; memory that `T` merely points to (for example
///   a heap buffer) is not touched by this wrapper.
///
/// The field is a `ManuallyDrop<T>` so that this type's own code decides when `T`'s destructor
/// runs relative to the wipe.
pub struct Secret<T>(ManuallyDrop<T>);
impl<T> Secret<T> {
    /// Wraps `value` in a `Secret`.
    #[inline]
    pub const fn new(value: T) -> Self {
        Self(ManuallyDrop::new(value))
    }

    /// Lends the wrapped value to `f` and returns whatever `f` returns.
    ///
    /// Because `R` is a single type chosen independently of the borrow's lifetime, `f` cannot
    /// hand back the reference it was given. It can still copy or clone the data, in which case
    /// the copy is the caller's responsibility.
    #[inline]
    pub fn expose<R>(&self, f: impl for<'a> FnOnce(&'a T) -> R) -> R {
        f(&self.0)
    }

    /// Consumes the wrapper and returns the inner value, wiping the storage it was moved out of.
    ///
    /// `T`'s destructor is *not* run here: ownership of the value passes to the caller, and
    /// the returned `T` is no longer protected by this wrapper.
    #[inline]
    pub fn expose_unwrap(self) -> T {
        // Disable `Secret`'s own `Drop` *before* deriving any pointer. After this line the
        // wrapper is never moved or re-borrowed again, so the pointer below stays valid for
        // both the read and the wipe.
        let mut this = ManuallyDrop::new(self);
        let ptr: *mut T = &raw mut *this.0;
        // SAFETY: `ptr` comes from a live, exclusively owned `T` inside `this`. The value is
        // read out exactly once, and `this` is never dropped or read as a `T` afterwards, so
        // the value is neither duplicated nor observed after its bytes are zeroed.
        unsafe {
            let value = ptr.read();
            zeroize_ptr(ptr);
            value
        }
    }
}
impl<T> Drop for Secret<T> {
fn drop(&mut self) {
let ptr = &raw mut *self.0;
unsafe {
ManuallyDrop::drop(&mut self.0);
zeroize_ptr(ptr);
}
}
}
/// Debug output is a fixed placeholder; the wrapped value is never formatted.
impl<T> Debug for Secret<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "Secret([REDACTED])")
    }
}
/// Display output is a fixed placeholder; the wrapped value is never formatted.
impl<T> Display for Secret<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        const PLACEHOLDER: &str = "[REDACTED]";
        Formatter::write_str(f, PLACEHOLDER)
    }
}
// Marker impl with no methods: the `Drop` impl for `Secret<T>` is what zeroizes the bytes
// that held the wrapped value.
impl<T> ZeroizeOnDrop for Secret<T> {}
/// Cloning produces an independent `Secret` holding a clone of the wrapped value.
impl<T: Clone> Clone for Secret<T> {
    fn clone(&self) -> Self {
        // `&self.0` is a `&ManuallyDrop<T>`; it deref-coerces to the `&T` that `T::clone` expects.
        let duplicate = T::clone(&self.0);
        Self::new(duplicate)
    }
}
/// Equality delegates to `CtEq::ct_eq` on the wrapped values and converts the result to `bool`.
impl<T: CtEq> PartialEq for Secret<T> {
    fn eq(&self, other: &Self) -> bool {
        let lhs: &T = &self.0;
        let rhs: &T = &other.0;
        let same: bool = lhs.ct_eq(rhs).into();
        same
    }
}
// `Eq` adds no methods; comparisons go through the `PartialEq` impl for `Secret<T>`.
// NOTE(review): this assumes `CtEq::ct_eq` is reflexive for every `T` used here — confirm.
impl<T: CtEq> Eq for Secret<T> {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-byte payload shared by the tests that just need some secret value.
    const SAMPLE: [u8; 4] = [1, 2, 3, 4];

    /// `{:?}` prints the placeholder, not the wrapped bytes.
    #[test]
    fn test_debug_redacted() {
        let rendered = format!("{:?}", Secret::new(SAMPLE));
        assert_eq!(rendered, "Secret([REDACTED])");
    }

    /// `{}` prints the placeholder, not the wrapped bytes.
    #[test]
    fn test_display_redacted() {
        let rendered = format!("{}", Secret::new(SAMPLE));
        assert_eq!(rendered, "[REDACTED]");
    }

    /// `expose` hands the closure a reference to the wrapped value.
    #[test]
    fn test_expose() {
        let wrapped = Secret::new(SAMPLE);
        wrapped.expose(|bytes| assert_eq!(bytes, &[1u8, 2, 3, 4]));
    }

    /// `expose_unwrap` returns the wrapped value unchanged.
    #[test]
    fn test_expose_unwrap() {
        let unwrapped = Secret::new(SAMPLE).expose_unwrap();
        assert_eq!(unwrapped, [1u8, 2, 3, 4]);
    }

    /// A clone holds the same contents as the original.
    #[test]
    fn test_clone() {
        let original = Secret::new(SAMPLE);
        let duplicate = original.clone();
        original.expose(|lhs| duplicate.expose(|rhs| assert_eq!(lhs, rhs)));
    }

    /// Secrets compare equal exactly when their contents match.
    #[test]
    fn test_equality() {
        let base = Secret::new(SAMPLE);
        let twin = Secret::new(SAMPLE);
        let different = Secret::new([5u8, 6, 7, 8]);
        assert_eq!(base, twin);
        assert_ne!(base, different);
    }

    /// The same secret can be exposed repeatedly (first and last byte checked in turn).
    #[test]
    fn test_multiple_expose() {
        let wrapped = Secret::new([42u8; 32]);
        for index in [0usize, 31] {
            wrapped.expose(|bytes| assert_eq!(bytes[index], 42));
        }
    }
}