use crate::error::{Error, Result};
use crate::random::OsRng;
use rand::RngCore;
use std::ops::{Deref, DerefMut};
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Fixed-size secret of exactly `N` bytes, stored inline as a `[u8; N]`.
///
/// The derived `ZeroizeOnDrop` wipes the bytes when the value is dropped.
/// `Clone` yields an independent copy that is wiped on its own drop. Moving the
/// value may leave bitwise copies behind that are not wiped.
#[derive(Clone, ZeroizeOnDrop)]
pub struct SecretBuffer<const N: usize> {
data: [u8; N],
}
impl<const N: usize> SecretBuffer<N> {
pub fn new() -> Self {
Self { data: [0u8; N] }
}
pub fn random() -> Self {
let mut data = [0u8; N];
OsRng.fill_bytes(&mut data);
Self { data }
}
pub fn from_array(data: [u8; N]) -> Self {
Self { data }
}
pub fn from_slice(slice: &[u8]) -> Result<Self> {
if slice.len() != N {
return Err(Error::InvalidParameter(format!(
"expected {} bytes, got {}",
N,
slice.len()
)));
}
let mut data = [0u8; N];
data.copy_from_slice(slice);
Ok(Self { data })
}
pub const fn len() -> usize {
N
}
pub fn as_array(&self) -> &[u8; N] {
&self.data
}
pub fn as_array_mut(&mut self) -> &mut [u8; N] {
&mut self.data
}
pub fn fill(&mut self, value: u8) {
self.data.fill(value);
}
pub fn copy_from(&mut self, other: &Self) {
self.data.copy_from_slice(&other.data);
}
}
impl<const N: usize> Default for SecretBuffer<N> {
fn default() -> Self {
Self::new()
}
}
impl<const N: usize> Deref for SecretBuffer<N> {
    type Target = [u8; N];

    /// Borrows the underlying array, so array/slice methods work directly.
    fn deref(&self) -> &[u8; N] {
        self.as_array()
    }
}
impl<const N: usize> DerefMut for SecretBuffer<N> {
    /// Mutably borrows the underlying array.
    fn deref_mut(&mut self) -> &mut [u8; N] {
        self.as_array_mut()
    }
}
impl<const N: usize> AsRef<[u8]> for SecretBuffer<N> {
    /// Views the secret as a byte slice of length `N`.
    fn as_ref(&self) -> &[u8] {
        &self.data[..]
    }
}
impl<const N: usize> AsMut<[u8]> for SecretBuffer<N> {
    /// Views the secret as a mutable byte slice of length `N`.
    fn as_mut(&mut self) -> &mut [u8] {
        &mut self.data[..]
    }
}
impl<const N: usize> std::fmt::Debug for SecretBuffer<N> {
    /// Prints only the size `N`; the secret bytes are never formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("SecretBuffer<{}>[REDACTED]", N))
    }
}
/// Growable heap buffer for secret bytes whose contents are zeroized on drop
/// (see the `Drop` impl for this type).
///
/// The derived `Clone` produces another `SecureVec`, so the clone is wiped on
/// its own drop as well.
#[derive(Clone)]
pub struct SecureVec {
data: Vec<u8>,
}
impl SecureVec {
pub fn new() -> Self {
Self { data: Vec::new() }
}
pub fn with_capacity(capacity: usize) -> Self {
Self {
data: Vec::with_capacity(capacity),
}
}
pub fn zeroed(len: usize) -> Self {
Self {
data: vec![0u8; len],
}
}
pub fn random(len: usize) -> Self {
let mut data = vec![0u8; len];
OsRng.fill_bytes(&mut data);
Self { data }
}
pub fn from_vec(data: Vec<u8>) -> Self {
Self { data }
}
pub fn from_slice(slice: &[u8]) -> Self {
Self {
data: slice.to_vec(),
}
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn capacity(&self) -> usize {
self.data.capacity()
}
pub fn reserve(&mut self, additional: usize) {
self.data.reserve(additional);
}
pub fn push(&mut self, byte: u8) {
self.data.push(byte);
}
pub fn extend_from_slice(&mut self, slice: &[u8]) {
self.data.extend_from_slice(slice);
}
pub fn clear(&mut self) {
self.data.zeroize();
self.data.clear();
}
pub fn resize(&mut self, new_len: usize, value: u8) {
if new_len < self.data.len() {
self.data[new_len..].zeroize();
}
self.data.resize(new_len, value);
}
pub fn truncate(&mut self, len: usize) {
if len < self.data.len() {
self.data[len..].zeroize();
}
self.data.truncate(len);
}
pub fn into_vec(mut self) -> Vec<u8> {
std::mem::take(&mut self.data)
}
pub fn split_off(&mut self, at: usize) -> Self {
Self {
data: self.data.split_off(at),
}
}
pub fn expose(&self) -> &[u8] {
&self.data
}
pub fn expose_mut(&mut self) -> &mut [u8] {
&mut self.data
}
}
/// Alternate name for [`SecureVec`]; the two are the same type.
pub type SecretBytes = SecureVec;
impl Default for SecureVec {
    /// An empty, unallocated buffer (same as [`SecureVec::new`]).
    fn default() -> Self {
        Self::from_vec(Vec::new())
    }
}
impl Drop for SecureVec {
fn drop(&mut self) {
self.data.zeroize();
}
}
impl Deref for SecureVec {
    type Target = [u8];

    /// Borrows the stored bytes as a slice.
    fn deref(&self) -> &[u8] {
        self.data.as_slice()
    }
}
impl DerefMut for SecureVec {
    /// Mutably borrows the stored bytes as a slice (length cannot change).
    fn deref_mut(&mut self) -> &mut [u8] {
        self.data.as_mut_slice()
    }
}
impl AsRef<[u8]> for SecureVec {
    /// Same view as [`SecureVec::expose`].
    fn as_ref(&self) -> &[u8] {
        self.expose()
    }
}
impl AsMut<[u8]> for SecureVec {
    /// Same view as [`SecureVec::expose_mut`].
    fn as_mut(&mut self) -> &mut [u8] {
        self.expose_mut()
    }
}
impl std::fmt::Debug for SecureVec {
    /// Prints only the byte count; the contents are never formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let byte_count = self.data.len();
        f.write_fmt(format_args!("SecureVec[{} bytes, REDACTED]", byte_count))
    }
}
impl From<Vec<u8>> for SecureVec {
fn from(data: Vec<u8>) -> Self {
Self::from_vec(data)
}
}
impl From<&[u8]> for SecureVec {
    /// Copies the slice into a freshly allocated buffer.
    fn from(slice: &[u8]) -> Self {
        Self {
            data: Vec::from(slice),
        }
    }
}
impl<const N: usize> From<[u8; N]> for SecureVec {
fn from(array: [u8; N]) -> Self {
Self::from_slice(&array)
}
}
impl FromIterator<u8> for SecureVec {
fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
Self {
data: iter.into_iter().collect(),
}
}
}
/// Fixed-length, zero-initialized secret buffer backed by a [`SecureVec`].
///
/// Only compiled with the `std` feature.
///
/// NOTE(review): despite the name, nothing visible here adds guard pages or
/// memory locking. Protection appears limited to `SecureVec`'s wipe-on-drop;
/// confirm before relying on anything stronger.
#[cfg(feature = "std")]
pub struct GuardedBuffer {
data: SecureVec,
}
#[cfg(feature = "std")]
impl GuardedBuffer {
    /// Allocates a buffer of `size` zero bytes.
    pub fn new(size: usize) -> Self {
        let data = SecureVec::zeroed(size);
        GuardedBuffer { data }
    }

    /// Borrows the contents.
    pub fn as_slice(&self) -> &[u8] {
        self.data.expose()
    }

    /// Mutably borrows the contents (length is fixed).
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.data.expose_mut()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secret_buffer() {
        let mut buf = SecretBuffer::<32>::new();
        buf.fill(0xAB);
        assert_eq!(buf.as_array(), &[0xAB; 32]);
    }

    #[test]
    fn test_secret_buffer_from_slice() {
        let data = [1u8; 16];
        let buf = SecretBuffer::<16>::from_slice(&data).unwrap();
        assert_eq!(buf.as_array(), &data);
        assert!(SecretBuffer::<32>::from_slice(&data).is_err());
    }

    #[test]
    fn test_secret_buffer_len_and_copy_from() {
        assert_eq!(SecretBuffer::<24>::len(), 24);
        let src = SecretBuffer::<8>::from_array([5; 8]);
        let mut dst = SecretBuffer::<8>::new();
        dst.copy_from(&src);
        assert_eq!(dst.as_array(), src.as_array());
    }

    #[test]
    fn test_secure_vec() {
        let mut vec = SecureVec::new();
        vec.extend_from_slice(&[1, 2, 3, 4]);
        assert_eq!(vec.len(), 4);
        assert_eq!(&*vec, &[1, 2, 3, 4]);
        vec.clear();
        assert!(vec.is_empty());
    }

    #[test]
    fn test_secure_vec_truncate() {
        let mut vec = SecureVec::from_slice(&[1, 2, 3, 4, 5]);
        vec.truncate(3);
        assert_eq!(vec.len(), 3);
        assert_eq!(&*vec, &[1, 2, 3]);
    }

    #[test]
    fn test_secure_vec_growth_preserves_contents() {
        // Start small so push/extend must grow past the initial capacity.
        let mut vec = SecureVec::with_capacity(2);
        for byte in 0u8..100 {
            vec.push(byte);
        }
        vec.extend_from_slice(&[0xFF; 50]);
        assert_eq!(vec.len(), 150);
        assert!(vec.capacity() >= 150);
        let expected: Vec<u8> = (0u8..100).collect();
        assert_eq!(&vec[..100], expected.as_slice());
        assert!(vec[100..].iter().all(|&b| b == 0xFF));
        vec.reserve(1000);
        assert!(vec.capacity() >= 1150);
        assert_eq!(vec.len(), 150);
    }

    #[test]
    fn test_secure_vec_resize_and_split_off() {
        let mut vec = SecureVec::from_slice(&[1, 2, 3]);
        vec.resize(5, 9);
        assert_eq!(&*vec, &[1, 2, 3, 9, 9]);
        vec.resize(2, 0);
        assert_eq!(&*vec, &[1, 2]);

        let mut vec = SecureVec::from_slice(&[1, 2, 3, 4]);
        let tail = vec.split_off(1);
        assert_eq!(&*vec, &[1]);
        assert_eq!(&*tail, &[2, 3, 4]);
    }

    #[test]
    fn test_secure_vec_from_iter_and_array() {
        let collected: SecureVec = (1u8..=20).collect();
        assert_eq!(collected.len(), 20);
        assert_eq!(collected[19], 20);
        let from_array = SecureVec::from([7u8, 8, 9]);
        assert_eq!(from_array.expose(), &[7, 8, 9]);
    }

    #[test]
    fn test_debug_is_redacted() {
        let buf = SecretBuffer::<4>::from_array([1, 2, 3, 4]);
        assert_eq!(format!("{:?}", buf), "SecretBuffer<4>[REDACTED]");
        let vec = SecureVec::from_slice(&[1, 2, 3]);
        assert_eq!(format!("{:?}", vec), "SecureVec[3 bytes, REDACTED]");
    }

    #[test]
    fn test_secret_buffer_random() {
        let buf1 = SecretBuffer::<32>::random();
        let buf2 = SecretBuffer::<32>::random();
        assert_ne!(buf1.as_array(), buf2.as_array());
    }

    #[test]
    fn test_secure_vec_random() {
        let vec1 = SecureVec::random(32);
        let vec2 = SecureVec::random(32);
        assert_ne!(vec1.expose(), vec2.expose());
        assert_eq!(vec1.len(), 32);
    }

    #[test]
    fn test_secret_bytes_alias() {
        let key: SecretBytes = SecretBytes::random(32);
        assert_eq!(key.len(), 32);
        assert_eq!(key.expose().len(), 32);
    }
}