memsecurity/
random.rs

1use crate::MemSecurityResult;
2use borsh::{BorshDeserialize, BorshSerialize};
3use rand_chacha::ChaCha20Rng;
4use rand_core::{RngCore, SeedableRng};
5use std::ops::{Add, Sub};
6use zeroize::Zeroize;
7
8/// Generate Cryptographically secure random bytes of array size 8, 16, 24, 32 or 64
9pub struct CsprngArraySimple;
10
11impl CsprngArraySimple {
12    /// Generate an array of random bytes with maximum array size of 8
13    ///
14    /// #### Usage
15    /// ```rs
16    /// let bytes = Csprng::gen_u8_byte();
17    /// ```
18    pub fn gen_u8_byte() -> u8 {
19        CsprngArray::<1>::gen().0[0]
20    }
21    /// Generate an array of random bytes with maximum array size of 8
22    ///
23    /// #### Usage
24    /// ```rs
25    /// let bytes = Csprng::gen_u8();
26    /// assert_eq!(bytes.len(), 8);
27    /// ```
28    pub fn gen_u8_array() -> CsprngArray<8> {
29        CsprngArray::<8>::gen()
30    }
31
32    /// Generate an array of random bytes with maximum array size of 16
33    ///
34    /// #### Usage
35    /// ```rs
36    /// let bytes = CsprngArray::gen_16();
37    /// assert_eq!(bytes.len(), 16);
38    /// ```
39    pub fn gen_u16_array() -> CsprngArray<16> {
40        CsprngArray::<16>::gen()
41    }
42
43    /// Generate an array of random bytes with maximum array size of 24
44    ///
45    /// #### Usage
46    /// ```rs
47    /// let bytes = CsprngArray::gen_24();
48    /// assert_eq!(bytes.len(), 24);
49    /// ```
50    pub fn gen_u24_array() -> CsprngArray<24> {
51        CsprngArray::<24>::gen()
52    }
53
54    /// Generate an array of random bytes with maximum array size of 32
55    ///
56    /// #### Usage
57    /// ```rs
58    /// let bytes = CsprngArray::gen_32();
59    /// assert_eq!(bytes.len(), 32);
60    /// ```
61    pub fn gen_u32_array() -> CsprngArray<32> {
62        CsprngArray::<32>::gen()
63    }
64
65    /// Generate an array of random bytes with maximum array size of 64
66    ///
67    /// #### Usage
68    /// ```rs
69    /// let bytes = CsprngArray::gen_64();
70    /// assert_eq!(bytes.len(), 64);
71    /// ```
72    pub fn gen_u64_array() -> CsprngArray<64> {
73        CsprngArray::<64>::gen()
74    }
75}
76
77/// Generate Cryptographically secure random bytes of different sizes based on generic usize `N`
78/// #### Structure
79/// ```rs
80/// pub struct CsprngArray<const N: usize>([u8; N]);
81/// ```
82///
83/// #### Example
84/// ```rs
85/// let bytes = CsprngArray::<32>::gen(); // Generates 32 random bytes
86/// assert_eq!(bytes.len(), 32);
87/// ```
88#[derive(BorshSerialize, BorshDeserialize)]
89pub struct CsprngArray<const N: usize>([u8; N]);
90
91impl<const N: usize> AsRef<[u8]> for CsprngArray<N> {
92    fn as_ref(&self) -> &[u8] {
93        self.expose_borrowed()
94    }
95}
96
97impl<const N: usize> CsprngArray<N> {
98    /// Method to generate random cryptographically secure random bytes
99    /// #### Example
100    /// ```rs
101    /// let bytes = CsprngArray::<64>::gen(); // Generates 64 random bytes
102    /// assert_eq!(bytes.len(), 64);
103    /// ```
104    pub fn gen() -> Self {
105        let mut rng = ChaCha20Rng::from_entropy();
106        let mut buffer = [0u8; N];
107        rng.fill_bytes(&mut buffer);
108
109        let outcome = CsprngArray(buffer);
110
111        buffer.fill(0);
112
113        outcome
114    }
115
116    /// Copies the contents of the buffer
117    pub fn take(mut self, buffer: &mut [u8; N]) -> MemSecurityResult<()> {
118        // FIXME implement
119        let found = buffer.len();
120
121        if found != N {
122            Err(crate::MemSecurityErr::InvalidArrayLength { expected: N, found })
123        } else {
124            buffer[0..N].copy_from_slice(&self.0);
125
126            self.zeroize();
127
128            Ok(())
129        }
130    }
131
132    /// Copies the contents of the buffer
133    pub fn take_zeroize_on_error(mut self, buffer: &mut [u8; N]) -> MemSecurityResult<()> {
134        let found = buffer.len();
135
136        if found != N {
137            self.zeroize();
138
139            Err(crate::MemSecurityErr::InvalidArrayLength { expected: N, found })
140        } else {
141            buffer[0..N].copy_from_slice(&self.0);
142
143            self.zeroize();
144
145            Ok(())
146        }
147    }
148
149    /// Clone the data. Be careful with this as it retains the secret in memory.
150    /// It is recommended to call `Csprng::zeroize()` after consuming this in order to zeroize the memory
151    pub fn expose(&self) -> [u8; N] {
152        self.0
153    }
154
155    /// Clone the data. Be careful with this as it retains the secret in memory.
156    /// It is recommended to call `Csprng::zeroize()` after consuming this in order to zeroize the memory
157    pub fn expose_borrowed(&self) -> &[u8] {
158        self.0.as_ref()
159    }
160
161    /// Get the inner value of the struct. This is only available in a debug build and
162    /// is enforced by the flag `#[cfg(debug_assertions)]`
163    #[cfg(debug_assertions)]
164    pub fn dangerous_debug(&self) -> &[u8; N] {
165        &self.0
166    }
167}
168
169impl<const N: usize> Zeroize for CsprngArray<N> {
170    fn zeroize(&mut self) {
171        self.0.fill(0);
172
173        assert_eq!(self.0, [0u8; N]); //Must panic if memory cannot be zeroized
174    }
175}
176
177impl<const N: usize> core::fmt::Debug for CsprngArray<N> {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        f.debug_struct("CsprngArray(REDACTED)").finish()
180    }
181}
182
183impl<const N: usize> core::fmt::Display for CsprngArray<N> {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        f.debug_struct("CsprngArray(REDACTED)").finish()
186    }
187}
188
189impl<const N: usize> Drop for CsprngArray<N> {
190    fn drop(&mut self) {
191        self.zeroize()
192    }
193}
194
/// Numeric types that expose their smallest and largest representable values as
/// associated constants, so generic code can refer to a type's bounds.
///
/// Implemented in this module for `u8`, `u16`, `u32`, `u64`, `u128`, `i8`, `i16`, `i32`,
/// `i64`, `i128`, `f32` and `f64`. For floats the bounds are the most negative and most
/// positive finite values.
///
/// #### Example
/// ```rs
/// fn bounds<T: MinMaxNum>() -> (T, T) {
///     (T::MIN_VALUE, T::MAX_VALUE)
/// }
///
/// assert_eq!(bounds::<u8>(), (0, 255));
/// assert_eq!(bounds::<i8>(), (-128, 127));
/// ```
pub trait MinMaxNum: PartialOrd + Add + Sub + Copy {
    /// The minimum value that can be defined
    const MIN_VALUE: Self;
    /// The maximum value that can be defined
    const MAX_VALUE: Self;
}
215
216impl MinMaxNum for u8 {
217    const MIN_VALUE: u8 = core::u8::MIN;
218    const MAX_VALUE: u8 = core::u8::MAX;
219}
220
221impl MinMaxNum for u16 {
222    const MIN_VALUE: u16 = core::u16::MIN;
223    const MAX_VALUE: u16 = core::u16::MAX;
224}
225
226impl MinMaxNum for u32 {
227    const MIN_VALUE: u32 = core::u32::MIN;
228    const MAX_VALUE: u32 = core::u32::MAX;
229}
230
231impl MinMaxNum for u64 {
232    const MIN_VALUE: u64 = core::u64::MIN;
233    const MAX_VALUE: u64 = core::u64::MAX;
234}
235
236impl MinMaxNum for u128 {
237    const MIN_VALUE: u128 = core::u128::MIN;
238    const MAX_VALUE: u128 = core::u128::MAX;
239}
240
241impl MinMaxNum for f32 {
242    const MIN_VALUE: f32 = core::f32::MIN;
243    const MAX_VALUE: f32 = core::f32::MAX;
244}
245
246impl MinMaxNum for f64 {
247    const MIN_VALUE: f64 = core::f64::MIN;
248    const MAX_VALUE: f64 = core::f64::MAX;
249}
250
251impl MinMaxNum for i8 {
252    const MIN_VALUE: i8 = core::i8::MIN;
253    const MAX_VALUE: i8 = core::i8::MAX;
254}
255
256impl MinMaxNum for i16 {
257    const MIN_VALUE: i16 = core::i16::MIN;
258    const MAX_VALUE: i16 = core::i16::MAX;
259}
260
261impl MinMaxNum for i32 {
262    const MIN_VALUE: i32 = core::i32::MIN;
263    const MAX_VALUE: i32 = core::i32::MAX;
264}
265
266impl MinMaxNum for i64 {
267    const MIN_VALUE: i64 = core::i64::MIN;
268    const MAX_VALUE: i64 = core::i64::MAX;
269}
270
271impl MinMaxNum for i128 {
272    const MIN_VALUE: i128 = core::i128::MIN;
273    const MAX_VALUE: i128 = core::i128::MAX;
274}