seckey/
bytes.rs

1use core::fmt;
2use core::ptr::NonNull;
3use core::ops::{ Deref, DerefMut };
4
5#[cfg(feature = "use_os")]
6use core::cell::Cell;
7
8#[cfg(feature = "use_os")]
9use memsec::{ mprotect, Prot };
10
11
#[cfg(feature = "use_os")]
mod alloc {
    use std::ptr::NonNull;

    /// Allocate `size` bytes through `memsec::malloc_sized` and return them as a thin
    /// byte pointer.
    ///
    /// Returns `None` whenever `memsec::malloc_sized` does.
    #[inline]
    pub unsafe fn malloc_sized(size: usize) -> Option<NonNull<u8>> {
        memsec::malloc_sized(size).map(|region| region.cast::<u8>())
    }

    /// Release a pointer obtained from [`malloc_sized`] via `memsec::free`.
    ///
    /// `_size` is accepted only so this backend has the same signature as the
    /// non-`use_os` one; `memsec::free` is called with the pointer alone.
    pub unsafe fn free(memptr: NonNull<u8>, _size: usize) {
        memsec::free(memptr);
    }
}
26
#[cfg(not(feature = "use_os"))]
mod alloc {
    use std::alloc::Layout;
    use std::ptr::{self, NonNull};
    use std::sync::atomic::{compiler_fence, Ordering};

    /// Byte-aligned layout for `size` bytes, or `None` if `size` is too large for a
    /// `Layout` (more than `isize::MAX`).
    #[inline]
    fn byte_layout(size: usize) -> Option<Layout> {
        Layout::from_size_align(size, 1).ok()
    }

    /// Allocate `size` zero-initialized bytes from the global allocator.
    ///
    /// A zero-size request returns a dangling (non-null, aligned) pointer without
    /// calling the allocator, because passing a zero-size layout to it is undefined
    /// behavior. Returns `None` if `size` cannot be described by a `Layout` or the
    /// allocator reports failure.
    ///
    /// # Safety
    ///
    /// The pointer must be released exactly once with [`free`] and the same `size`.
    #[inline]
    pub unsafe fn malloc_sized(size: usize) -> Option<NonNull<u8>> {
        if size == 0 {
            return Some(NonNull::dangling());
        }
        let layout = byte_layout(size)?;
        // Zeroed rather than merely allocated: the buffer is exposed as `&[u8]` /
        // `&mut [u8]`, which must never point at uninitialized memory.
        NonNull::new(std::alloc::alloc_zeroed(layout))
    }

    /// Overwrite a buffer from [`malloc_sized`] with zeros, then release it.
    ///
    /// # Safety
    ///
    /// `memptr` must have been returned by `malloc_sized(size)` with this same `size`
    /// and must not be used afterwards.
    #[inline]
    pub unsafe fn free(memptr: NonNull<u8>, size: usize) {
        if size == 0 {
            // Nothing was allocated for a zero-size request.
            return;
        }
        let base = memptr.as_ptr();
        // Volatile writes are not removed as dead stores, so the secret is actually
        // wiped before the memory goes back to the allocator.
        for offset in 0..size {
            ptr::write_volatile(base.add(offset), 0u8);
        }
        compiler_fence(Ordering::SeqCst);
        // `size` already passed `byte_layout` in `malloc_sized`, so this layout is valid.
        std::alloc::dealloc(base, Layout::from_size_align_unchecked(size, 1));
    }
}
42
/// An owned buffer of secret bytes.
///
/// The contents are reachable only through the guards returned by
/// [`SecBytes::read`] and [`SecBytes::write`]. With the `use_os` feature the pages are
/// held at `Prot::NoAccess` while no guard is alive.
pub struct SecBytes {
    /// Start of the allocation produced by `alloc::malloc_sized`.
    ptr: NonNull<u8>,
    /// Size of the allocation in bytes.
    len: usize,
    /// Number of live read guards (`use_os` only); the pages are readable while this
    /// is non-zero.
    #[cfg(feature = "use_os")]
    count: Cell<usize>,
}

// SAFETY: `SecBytes` is the sole owner of the allocation behind `ptr`, and nothing
//         about that memory is tied to the thread that allocated it, so the owner
//         may move between threads. No `Sync` impl is provided: `read` updates
//         `count` through `&self`.
unsafe impl Send for SecBytes {}
54
55impl SecBytes {
56    pub fn new(len: usize) -> SecBytes {
57        fn id(_: &mut [u8]) {}
58
59        SecBytes::with(len, id)
60    }
61
62    pub fn with<F>(len: usize, f: F) -> SecBytes
63        where F: FnOnce(&mut [u8])
64    {
65        let ptr = unsafe {
66            let memptr = alloc::malloc_sized(len).expect("seckey alloc failed");
67
68            {
69                let arr = std::slice::from_raw_parts_mut(memptr.as_ptr(), len);
70                f(arr);
71            }
72
73            #[cfg(feature = "use_os")]
74            mprotect(memptr, Prot::NoAccess);
75
76            memptr
77        };
78
79        SecBytes {
80            ptr, len,
81
82            #[cfg(feature = "use_os")]
83            count: Cell::new(0)
84        }
85    }
86
87    /// Borrow Read
88    ///
89    /// ```
90    /// use seckey::SecBytes;
91    ///
92    /// let secpass = SecBytes::with(8, |buf| buf.copy_from_slice(&[8u8; 8][..]));
93    /// assert_eq!([8u8; 8], *secpass.read());
94    /// ```
95    #[cfg_attr(not(feature = "use_os"), inline)]
96    pub fn read(&self) -> SecReadGuard<'_> {
97        #[cfg(feature = "use_os")] {
98            let count = self.count.get();
99            self.count.set(count + 1);
100            if count == 0 {
101                unsafe { mprotect(self.ptr, Prot::ReadOnly) };
102            }
103        }
104
105        SecReadGuard(self)
106    }
107
108    /// Borrow Write
109    ///
110    /// ```
111    /// # use seckey::SecBytes;
112    /// #
113    /// # let mut secpass = SecBytes::with(8, |buf| buf.copy_from_slice(&[8u8; 8][..]));
114    /// let mut wpass = secpass.write();
115    /// wpass[0] = 0;
116    /// assert_eq!([0, 8, 8, 8, 8, 8, 8, 8], *wpass);
117    /// ```
118    #[cfg_attr(not(feature = "use_os"), inline)]
119    pub fn write(&mut self) -> SecWriteGuard<'_> {
120        #[cfg(feature = "use_os")]
121        unsafe {
122            mprotect(self.ptr, Prot::ReadWrite)
123        };
124
125        SecWriteGuard(self)
126    }
127}
128
129impl fmt::Debug for SecBytes {
130    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
131        f.debug_tuple("SecBytes")
132            .field(&format_args!("{:p}", self.ptr))
133            .finish()
134    }
135}
136
137impl fmt::Pointer for SecBytes {
138    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
139        write!(f, "{:p}", self.ptr)
140    }
141}
142
143impl Drop for SecBytes {
144    fn drop(&mut self) {
145        unsafe {
146            #[cfg(feature = "use_os")]
147            mprotect(self.ptr, Prot::ReadWrite);
148
149            alloc::free(self.ptr, self.len);
150        }
151    }
152}
153
154
155/// Read Guard
156pub struct SecReadGuard<'a>(&'a SecBytes);
157
158impl<'a> Deref for SecReadGuard<'a> {
159    type Target = [u8];
160
161    #[inline]
162    fn deref(&self) -> &[u8] {
163        unsafe {
164            std::slice::from_raw_parts(self.0.ptr.as_ptr(), self.0.len)
165        }
166    }
167}
168
169impl<'a> Drop for SecReadGuard<'a> {
170    fn drop(&mut self) {
171        #[cfg(feature = "use_os")]
172        unsafe {
173            let count = self.0.count.get();
174            self.0.count.set(count - 1);
175            if count <= 1 {
176                mprotect(self.0.ptr, Prot::NoAccess);
177            }
178        }
179    }
180}
181
182
183/// Write Guard
184pub struct SecWriteGuard<'a>(&'a mut SecBytes);
185
186impl<'a> Deref for SecWriteGuard<'a> {
187    type Target = [u8];
188
189    #[inline]
190    fn deref(&self) -> &[u8] {
191        unsafe {
192            std::slice::from_raw_parts(self.0.ptr.as_ptr(), self.0.len)
193        }
194    }
195}
196
197impl<'a> DerefMut for SecWriteGuard<'a> {
198    #[inline]
199    fn deref_mut(&mut self) -> &mut [u8] {
200        unsafe {
201            std::slice::from_raw_parts_mut(self.0.ptr.as_ptr(), self.0.len)
202        }
203    }
204}
205
206impl<'a> Drop for SecWriteGuard<'a> {
207    fn drop(&mut self) {
208        #[cfg(feature = "use_os")]
209        unsafe {
210            mprotect(self.0.ptr, Prot::NoAccess);
211        }
212    }
213}