sodoken/
locked_mem.rs

1use std::io::{ErrorKind, Result};
2
3/// Locked memory.
4pub struct LockedArray(usize, *mut libc::c_void);
5
6// the sodium_malloc-ed c_void is safe to Send
7unsafe impl Send for LockedArray {}
8
// Note: it might technically be ok to implement Sync,
// but not to lock/unlock it across threads... so we
// need to keep the lock function `&mut self`, and then
// Sync wouldn't help usage in any way.
13
14impl Drop for LockedArray {
15    fn drop(&mut self) {
16        unsafe {
17            libsodium_sys::sodium_free(self.1);
18        }
19    }
20}
21
22impl LockedArray {
23    /// Create a new locked memory buffer.
24    pub fn new(size: usize) -> Result<Self> {
25        crate::sodium_init();
26
27        let z = unsafe {
28            // sodium_malloc requires memory-aligned sizes,
29            // round up to the nearest 8 bytes.
30            let align_size = (size + 7) & !7;
31            let z = libsodium_sys::sodium_malloc(align_size);
32            if z.is_null() {
33                return Err(ErrorKind::OutOfMemory.into());
34            }
35            libsodium_sys::sodium_memzero(z, align_size);
36            libsodium_sys::sodium_mprotect_noaccess(z);
37            z
38        };
39
40        Ok(Self(size, z))
41    }
42
43    /// Get access to this memory.
44    pub fn lock(&mut self) -> LockedArrayGuard<'_> {
45        LockedArrayGuard::new(self)
46    }
47}
48
49impl From<Vec<u8>> for LockedArray {
50    fn from(s: Vec<u8>) -> Self {
51        let s = zeroize::Zeroizing::new(s);
52
53        let mut out = Self::new(s.len()).unwrap();
54        out.lock().copy_from_slice(s.as_slice());
55
56        out
57    }
58}
59
60impl From<Box<[u8]>> for LockedArray {
61    fn from(value: Box<[u8]>) -> Self {
62        let value = zeroize::Zeroizing::new(value);
63
64        let mut out = Self::new(value.len()).unwrap();
65        out.lock().copy_from_slice(value.as_ref());
66
67        out
68    }
69}
70
71/// Locked memory that is unlocked for access.
72pub struct LockedArrayGuard<'g>(&'g mut LockedArray);
73
74impl<'g> LockedArrayGuard<'g> {
75    fn new(l: &'g mut LockedArray) -> Self {
76        unsafe {
77            libsodium_sys::sodium_mprotect_readwrite(l.1);
78        }
79        Self(l)
80    }
81}
82
83impl Drop for LockedArrayGuard<'_> {
84    fn drop(&mut self) {
85        unsafe {
86            libsodium_sys::sodium_mprotect_noaccess(self.0 .1);
87        }
88    }
89}
90
91impl std::ops::Deref for LockedArrayGuard<'_> {
92    type Target = [u8];
93
94    fn deref(&self) -> &Self::Target {
95        unsafe { std::slice::from_raw_parts(self.0 .1 as *const u8, self.0 .0) }
96    }
97}
98
99impl std::ops::DerefMut for LockedArrayGuard<'_> {
100    fn deref_mut(&mut self) -> &mut Self::Target {
101        unsafe {
102            std::slice::from_raw_parts_mut(self.0 .1 as *mut u8, self.0 .0)
103        }
104    }
105}
106
107/// Locked memory with a size known at compile time.
108pub struct SizedLockedArray<const N: usize>(LockedArray);
109
110impl<const N: usize> SizedLockedArray<N> {
111    /// Create a new locked memory buffer.
112    pub fn new() -> Result<Self> {
113        Ok(Self(LockedArray::new(N)?))
114    }
115
116    /// Get access to this memory.
117    pub fn lock(&mut self) -> SizedLockedArrayGuard<'_, N> {
118        SizedLockedArrayGuard::new(self)
119    }
120}
121
122/// Locked memory that is unlocked for access.
123pub struct SizedLockedArrayGuard<'g, const N: usize>(LockedArrayGuard<'g>);
124
125impl<'g, const N: usize> SizedLockedArrayGuard<'g, N> {
126    fn new(l: &'g mut SizedLockedArray<N>) -> Self {
127        Self(l.0.lock())
128    }
129}
130
131impl<const N: usize> std::ops::Deref for SizedLockedArrayGuard<'_, N> {
132    type Target = [u8; N];
133
134    fn deref(&self) -> &Self::Target {
135        unsafe {
136            &*(std::slice::from_raw_parts(self.0 .0 .1 as *const u8, N)[..N]
137                .as_ptr() as *const [u8; N])
138        }
139    }
140}
141
142impl<const N: usize> std::ops::DerefMut for SizedLockedArrayGuard<'_, N> {
143    fn deref_mut(&mut self) -> &mut Self::Target {
144        unsafe {
145            &mut *(std::slice::from_raw_parts_mut(self.0 .0 .1 as *mut u8, N)
146                [..N]
147                .as_mut_ptr() as *mut [u8; N])
148        }
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn use_locked_array_as_array() {
        fn read_first(bytes: &[u8]) -> u8 {
            bytes[0]
        }

        fn overwrite_first(bytes: &mut [u8]) {
            bytes[0] = 8;
        }

        let mut locked = LockedArray::new(32).unwrap();
        locked.lock().copy_from_slice(&[5; 32]);

        // The unsized guard coerces to `&[u8]` / `&mut [u8]`.
        assert_eq!(5, read_first(&*locked.lock()));
        overwrite_first(&mut *locked.lock());
        assert_eq!(8, read_first(&*locked.lock()));
    }

    #[test]
    fn use_sized_locked_array_as_array() {
        fn read_first(bytes: &[u8; 32]) -> u8 {
            bytes[0]
        }

        fn overwrite_first(bytes: &mut [u8; 32]) {
            bytes[0] = 8;
        }

        let mut locked = SizedLockedArray::<32>::new().unwrap();
        locked.lock().copy_from_slice(&[5; 32]);

        // The sized guard coerces to `&[u8; 32]` / `&mut [u8; 32]`.
        assert_eq!(5, read_first(&*locked.lock()));
        overwrite_first(&mut *locked.lock());
        assert_eq!(8, read_first(&*locked.lock()));
    }

    // This test relies on being able to read back memory that has been
    // deallocated, which is undefined behavior and allocator-dependent.
    // If this breaks, it may be best to remove it.
    // NOTE(review): the expected value `vec![0, 0, 0]` is itself a heap
    // allocation made after the input was freed; if the allocator hands back
    // that same chunk, the comparison passes regardless of zeroization.
    #[test]
    fn clears_input_buffer_from_vec() {
        // `vec![..]` with three elements has capacity exactly 3, as required
        // by the `from_raw_parts(.., 3, 3)` below.
        let input = vec![1, 2, 3];
        let addr = input.as_ptr();

        let mut locked = LockedArray::from(input);

        // Intentionally reconstructs a Vec over freed memory (see note above);
        // `forget` prevents a double free.
        let stale = unsafe { Vec::from_raw_parts(addr as *mut u8, 3, 3) };
        assert_eq!(stale, vec![0, 0, 0]);
        std::mem::forget(stale);

        assert_eq!(&*locked.lock(), &[1, 2, 3]);
    }

    // This test relies on being able to read back memory that has been
    // deallocated, which is undefined behavior and allocator-dependent.
    // If this breaks, it may be best to remove it.
    // NOTE(review): the expected value `Box::new([0, 0, 0])` is itself a heap
    // allocation made after the input was freed; if the allocator hands back
    // that same chunk, the comparison passes regardless of zeroization.
    #[test]
    fn clears_input_buffer_from_box() {
        let input: Box<[u8]> = vec![1, 2, 3].into();
        let addr = input.as_ptr();

        let mut locked = LockedArray::from(input);

        // Intentionally reconstructs a Box over freed memory (see note above);
        // `forget` prevents a double free.
        let stale = unsafe { Box::from_raw(addr as *mut [u8; 3]) };
        assert_eq!(stale, Box::new([0, 0, 0]));
        std::mem::forget(stale);

        assert_eq!(&*locked.lock(), &[1, 2, 3]);
    }
}