use crate::MemSecurityResult;
use borsh::{BorshDeserialize, BorshSerialize};
use rand_chacha::ChaCha20Rng;
use rand_core::{RngCore, SeedableRng};
use std::ops::{Add, Sub};
use zeroize::Zeroize;
/// Convenience constructors for a few fixed [`CsprngArray`] sizes.
///
/// The number in each `gen_uNN_array` name is the number of **bytes** produced
/// (e.g. `gen_u16_array` yields a `CsprngArray<16>`), not the width of an integer type.
pub struct CsprngArraySimple;

impl CsprngArraySimple {
    /// Returns one random byte taken from a freshly generated `CsprngArray<1>`.
    pub fn gen_u8_byte() -> u8 {
        let single: CsprngArray<1> = CsprngArray::gen();
        single.0[0]
    }

    /// Returns a new `CsprngArray` holding 8 random bytes.
    pub fn gen_u8_array() -> CsprngArray<8> {
        CsprngArray::gen()
    }

    /// Returns a new `CsprngArray` holding 16 random bytes.
    pub fn gen_u16_array() -> CsprngArray<16> {
        CsprngArray::gen()
    }

    /// Returns a new `CsprngArray` holding 24 random bytes.
    pub fn gen_u24_array() -> CsprngArray<24> {
        CsprngArray::gen()
    }

    /// Returns a new `CsprngArray` holding 32 random bytes.
    pub fn gen_u32_array() -> CsprngArray<32> {
        CsprngArray::gen()
    }

    /// Returns a new `CsprngArray` holding 64 random bytes.
    pub fn gen_u64_array() -> CsprngArray<64> {
        CsprngArray::gen()
    }
}
/// A fixed-size array of `N` random bytes that is zeroized when dropped.
///
/// Normally created with `CsprngArray::gen`; the `BorshDeserialize` derive also allows
/// reconstructing one from serialized bytes.
///
/// NOTE(review): the Borsh derives serialize the raw secret bytes — confirm this is
/// intended for every caller that can reach them.
#[derive(BorshSerialize, BorshDeserialize)]
pub struct CsprngArray<const N: usize>([u8; N]);

impl<const N: usize> AsRef<[u8]> for CsprngArray<N> {
    /// Borrows the random bytes as a slice (no copy is made).
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}
impl<const N: usize> CsprngArray<N> {
    /// Generates a new array of `N` random bytes.
    ///
    /// A fresh [`ChaCha20Rng`] is seeded with `SeedableRng::from_entropy` on every call
    /// and its output is written directly into the returned value.
    ///
    /// NOTE(review): rand_core documents `from_entropy` as panicking if the entropy
    /// source fails — confirm for the pinned version. The RNG's internal state is
    /// dropped here without an explicit wipe; verify whether rand_chacha clears it.
    pub fn gen() -> Self {
        let mut rng = ChaCha20Rng::from_entropy();
        // Fill the returned value's own storage. Filling a separate local `[u8; N]` and
        // copying it in (arrays are `Copy`) would leave a second copy of the secret whose
        // plain `fill(0)` clean-up is a dead store the optimizer is allowed to remove.
        let mut outcome = CsprngArray([0u8; N]);
        rng.fill_bytes(&mut outcome.0);
        outcome
    }

    /// Copies the random bytes into `buffer`, then zeroizes `self`.
    ///
    /// # Errors
    ///
    /// Never returns `Err`: `buffer` is a `&mut [u8; N]`, so its length always equals
    /// `N` and no runtime length check is needed. The `MemSecurityResult` return type is
    /// kept so existing callers keep compiling.
    pub fn take(mut self, buffer: &mut [u8; N]) -> MemSecurityResult<()> {
        buffer.copy_from_slice(&self.0);
        // `Drop` wipes the bytes as well; zeroizing here makes the hand-off explicit.
        self.zeroize();
        Ok(())
    }

    /// Identical to [`CsprngArray::take`]: copies the bytes into `buffer` and zeroizes
    /// `self`.
    ///
    /// The length-mismatch error this variant's name refers to cannot occur, because
    /// `buffer`'s type fixes its length at `N`. `self` is zeroized on every path.
    pub fn take_zeroize_on_error(self, buffer: &mut [u8; N]) -> MemSecurityResult<()> {
        self.take(buffer)
    }

    /// Returns a copy of the random bytes.
    ///
    /// The returned `[u8; N]` is a plain array: unlike `CsprngArray`, it is not zeroized
    /// when dropped, so the caller is responsible for wiping it.
    pub fn expose(&self) -> [u8; N] {
        self.0
    }

    /// Borrows the random bytes as a slice without copying them.
    pub fn expose_borrowed(&self) -> &[u8] {
        self.0.as_ref()
    }

    /// Borrows the underlying array; only compiled when `debug_assertions` is enabled.
    #[cfg(debug_assertions)]
    pub fn dangerous_debug(&self) -> &[u8; N] {
        &self.0
    }
}
impl<const N: usize> Zeroize for CsprngArray<N> {
    /// Overwrites every byte of the array with zero.
    ///
    /// Delegates to the `zeroize` crate's `[u8]` implementation, which that crate
    /// documents as using volatile writes followed by a compiler fence so the wipe is
    /// not optimized away. A plain `fill(0)` carries no such guarantee, and a follow-up
    /// equality assertion is only an incidental, always-true read at runtime cost.
    fn zeroize(&mut self) {
        self.0.as_mut_slice().zeroize();
    }
}
impl<const N: usize> core::fmt::Debug for CsprngArray<N> {
    /// Always emits the fixed marker `CsprngArray(REDACTED)`; the bytes are never printed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("CsprngArray(REDACTED)")
    }
}
impl<const N: usize> core::fmt::Display for CsprngArray<N> {
    /// Writes the fixed text `CsprngArray(REDACTED)` so user-facing output never
    /// reveals the random bytes.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("CsprngArray(REDACTED)")
    }
}
impl<const N: usize> Drop for CsprngArray<N> {
fn drop(&mut self) {
self.zeroize()
}
}
/// A numeric type that exposes its smallest and largest values as associated constants.
///
/// Implementors must be comparable (`PartialOrd`), support `+` and `-` (the `Add`/`Sub`
/// output types are left unconstrained), and be `Copy`.
pub trait MinMaxNum
where
    Self: PartialOrd + Add + Sub + Copy,
{
    /// Smallest value of the type.
    const MIN_VALUE: Self;
    /// Largest value of the type.
    const MAX_VALUE: Self;
}
impl MinMaxNum for u8 {
const MIN_VALUE: u8 = core::u8::MIN;
const MAX_VALUE: u8 = core::u8::MAX;
}
impl MinMaxNum for u16 {
const MIN_VALUE: u16 = core::u16::MIN;
const MAX_VALUE: u16 = core::u16::MAX;
}
impl MinMaxNum for u32 {
const MIN_VALUE: u32 = core::u32::MIN;
const MAX_VALUE: u32 = core::u32::MAX;
}
impl MinMaxNum for u64 {
const MIN_VALUE: u64 = core::u64::MIN;
const MAX_VALUE: u64 = core::u64::MAX;
}
impl MinMaxNum for u128 {
const MIN_VALUE: u128 = core::u128::MIN;
const MAX_VALUE: u128 = core::u128::MAX;
}
impl MinMaxNum for f32 {
const MIN_VALUE: f32 = core::f32::MIN;
const MAX_VALUE: f32 = core::f32::MAX;
}
impl MinMaxNum for f64 {
const MIN_VALUE: f64 = core::f64::MIN;
const MAX_VALUE: f64 = core::f64::MAX;
}
impl MinMaxNum for i8 {
const MIN_VALUE: i8 = core::i8::MIN;
const MAX_VALUE: i8 = core::i8::MAX;
}
impl MinMaxNum for i16 {
const MIN_VALUE: i16 = core::i16::MIN;
const MAX_VALUE: i16 = core::i16::MAX;
}
impl MinMaxNum for i32 {
const MIN_VALUE: i32 = core::i32::MIN;
const MAX_VALUE: i32 = core::i32::MAX;
}
impl MinMaxNum for i64 {
const MIN_VALUE: i64 = core::i64::MIN;
const MAX_VALUE: i64 = core::i64::MAX;
}
impl MinMaxNum for i128 {
const MIN_VALUE: i128 = core::i128::MIN;
const MAX_VALUE: i128 = core::i128::MAX;
}