1use std::io::{ErrorKind, Result};
2
/// A heap buffer allocated with libsodium's `sodium_malloc`.
///
/// Field 0 is the length in bytes exposed through guards (the underlying
/// allocation may be rounded up past it, see `LockedArray::new`); field 1 is the
/// pointer returned by `sodium_malloc`. Outside of a [`LockedArrayGuard`] the
/// memory is kept in no-access mode, and it is released with `sodium_free` on
/// drop.
pub struct LockedArray(usize, *mut libc::c_void);

// SAFETY: `LockedArray` exclusively owns its allocation (it has no `Clone`
// impl), and every access to the bytes goes through `&mut self`, so moving the
// value to another thread cannot create aliased access. It is not marked `Sync`.
unsafe impl Send for LockedArray {}
8
9impl Drop for LockedArray {
15 fn drop(&mut self) {
16 unsafe {
17 libsodium_sys::sodium_free(self.1);
18 }
19 }
20}
21
22impl LockedArray {
23 pub fn new(size: usize) -> Result<Self> {
25 crate::sodium_init();
26
27 let z = unsafe {
28 let align_size = (size + 7) & !7;
31 let z = libsodium_sys::sodium_malloc(align_size);
32 if z.is_null() {
33 return Err(ErrorKind::OutOfMemory.into());
34 }
35 libsodium_sys::sodium_memzero(z, align_size);
36 libsodium_sys::sodium_mprotect_noaccess(z);
37 z
38 };
39
40 Ok(Self(size, z))
41 }
42
43 pub fn lock(&mut self) -> LockedArrayGuard<'_> {
45 LockedArrayGuard::new(self)
46 }
47}
48
49impl From<Vec<u8>> for LockedArray {
50 fn from(s: Vec<u8>) -> Self {
51 let s = zeroize::Zeroizing::new(s);
52
53 let mut out = Self::new(s.len()).unwrap();
54 out.lock().copy_from_slice(s.as_slice());
55
56 out
57 }
58}
59
60impl From<Box<[u8]>> for LockedArray {
61 fn from(value: Box<[u8]>) -> Self {
62 let value = zeroize::Zeroizing::new(value);
63
64 let mut out = Self::new(value.len()).unwrap();
65 out.lock().copy_from_slice(value.as_ref());
66
67 out
68 }
69}
70
/// Scoped read/write access to a [`LockedArray`].
///
/// Created by [`LockedArray::lock`], which switches the memory to read/write;
/// dropping the guard switches it back to no-access. Dereferences to `[u8]`.
pub struct LockedArrayGuard<'g>(&'g mut LockedArray);
73
74impl<'g> LockedArrayGuard<'g> {
75 fn new(l: &'g mut LockedArray) -> Self {
76 unsafe {
77 libsodium_sys::sodium_mprotect_readwrite(l.1);
78 }
79 Self(l)
80 }
81}
82
83impl Drop for LockedArrayGuard<'_> {
84 fn drop(&mut self) {
85 unsafe {
86 libsodium_sys::sodium_mprotect_noaccess(self.0 .1);
87 }
88 }
89}
90
91impl std::ops::Deref for LockedArrayGuard<'_> {
92 type Target = [u8];
93
94 fn deref(&self) -> &Self::Target {
95 unsafe { std::slice::from_raw_parts(self.0 .1 as *const u8, self.0 .0) }
96 }
97}
98
99impl std::ops::DerefMut for LockedArrayGuard<'_> {
100 fn deref_mut(&mut self) -> &mut Self::Target {
101 unsafe {
102 std::slice::from_raw_parts_mut(self.0 .1 as *mut u8, self.0 .0)
103 }
104 }
105}
106
/// A [`LockedArray`] whose length is fixed at compile time to `N` bytes; its
/// guard dereferences to `[u8; N]` instead of `[u8]`.
pub struct SizedLockedArray<const N: usize>(LockedArray);
109
110impl<const N: usize> SizedLockedArray<N> {
111 pub fn new() -> Result<Self> {
113 Ok(Self(LockedArray::new(N)?))
114 }
115
116 pub fn lock(&mut self) -> SizedLockedArrayGuard<'_, N> {
118 SizedLockedArrayGuard::new(self)
119 }
120}
121
/// Guard returned by [`SizedLockedArray::lock`]. It wraps a [`LockedArrayGuard`]
/// (so the memory is re-protected when it drops) and dereferences to `[u8; N]`.
pub struct SizedLockedArrayGuard<'g, const N: usize>(LockedArrayGuard<'g>);
124
125impl<'g, const N: usize> SizedLockedArrayGuard<'g, N> {
126 fn new(l: &'g mut SizedLockedArray<N>) -> Self {
127 Self(l.0.lock())
128 }
129}
130
131impl<const N: usize> std::ops::Deref for SizedLockedArrayGuard<'_, N> {
132 type Target = [u8; N];
133
134 fn deref(&self) -> &Self::Target {
135 unsafe {
136 &*(std::slice::from_raw_parts(self.0 .0 .1 as *const u8, N)[..N]
137 .as_ptr() as *const [u8; N])
138 }
139 }
140}
141
142impl<const N: usize> std::ops::DerefMut for SizedLockedArrayGuard<'_, N> {
143 fn deref_mut(&mut self) -> &mut Self::Target {
144 unsafe {
145 &mut *(std::slice::from_raw_parts_mut(self.0 .0 .1 as *mut u8, N)
146 [..N]
147 .as_mut_ptr() as *mut [u8; N])
148 }
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    // Writes through the guard, reads back through `&[u8]`, then mutates
    // through `&mut [u8]` and reads again.
    #[test]
    fn use_locked_array_as_array() {
        let mut locked = LockedArray::new(32).unwrap();
        locked.lock().copy_from_slice(&[5; 32]);

        fn call_me(input: &[u8]) -> u8 {
            input[0]
        }
        assert_eq!(5, call_me(&*locked.lock()));

        fn call_me_mut(input: &mut [u8]) {
            input[0] = 8;
        }
        call_me_mut(&mut *locked.lock());

        assert_eq!(8, call_me(&*locked.lock()));
    }

    // Same flow as above, but checks that `SizedLockedArray`'s guard coerces
    // to the fixed-size `&[u8; 32]` / `&mut [u8; 32]` types.
    #[test]
    fn use_sized_locked_array_as_array() {
        let mut locked = SizedLockedArray::<32>::new().unwrap();
        locked.lock().copy_from_slice(&[5; 32]);

        fn call_me(input: &[u8; 32]) -> u8 {
            input[0]
        }
        assert_eq!(5, call_me(&*locked.lock()));

        fn call_me_mut(input: &mut [u8; 32]) {
            input[0] = 8;
        }
        call_me_mut(&mut *locked.lock());

        assert_eq!(8, call_me(&*locked.lock()));
    }

    // Checks that `LockedArray::from(Vec<u8>)` wipes the source buffer and
    // keeps a copy of the data.
    //
    // NOTE(review): by the time `from` returns, its `Zeroizing` wrapper has
    // dropped the `Vec`, so `ptr` refers to freed memory and reading it below
    // is undefined behavior. The outcome depends on the allocator leaving the
    // freed block untouched (allocators may store bookkeeping in freed chunks),
    // and Miri would presumably reject it. Consider gating or replacing it.
    #[test]
    fn clears_input_buffer_from_vec() {
        let mut input = vec![1, 2, 3];
        // NOTE(review): this is a no-op; `input` already has length 3.
        input.resize(3, 0);

        let ptr = input.as_ptr();

        let mut locked = LockedArray::from(input);

        // Rebuilds a `Vec` over the old buffer only to compare its bytes. It is
        // forgotten below so it is not freed a second time; if the assert
        // fails, unwinding drops it and double-frees.
        let vec = unsafe { Vec::from_raw_parts(ptr as *mut u8, 3, 3) };

        assert_eq!(vec, vec![0, 0, 0]);
        std::mem::forget(vec);

        assert_eq!(&*locked.lock(), &[1, 2, 3]);
    }

    // Checks that `LockedArray::from(Box<[u8]>)` wipes the source buffer and
    // keeps a copy of the data.
    //
    // NOTE(review): same caveat as `clears_input_buffer_from_vec`. The box is
    // dropped inside `from`, so the read below is of freed memory (undefined
    // behavior, allocator-dependent).
    #[test]
    fn clears_input_buffer_from_box() {
        let input: Box<[u8]> = vec![1, 2, 3].into();

        let ptr = input.as_ptr();

        let mut locked = LockedArray::from(input);

        // Reinterprets the old 3-byte buffer as a boxed array just to compare
        // it; forgotten afterwards to avoid a second free.
        let vec = unsafe { Box::from_raw(ptr as *mut [u8; 3]) };

        assert_eq!(vec, Box::new([0, 0, 0]));
        std::mem::forget(vec);

        assert_eq!(&*locked.lock(), &[1, 2, 3]);
    }
}