1use core::fmt;
2use core::ptr::NonNull;
3use core::ops::{ Deref, DerefMut };
4
5#[cfg(feature = "use_os")]
6use core::cell::Cell;
7
8#[cfg(feature = "use_os")]
9use memsec::{ mprotect, Prot };
10
11
#[cfg(feature = "use_os")]
mod alloc {
    //! Allocation backend used when the `use_os` feature is enabled: a thin
    //! adapter over memsec's allocator.

    use std::ptr::NonNull;

    /// Requests `size` bytes from `memsec::malloc_sized`, returning the start of the region.
    ///
    /// # Safety
    ///
    /// The pointer must later be released only through [`free`].
    #[inline]
    pub unsafe fn malloc_sized(size: usize) -> Option<NonNull<u8>> {
        memsec::malloc_sized(size).map(NonNull::cast)
    }

    /// Hands the region back to `memsec::free`; memsec tracks the size itself,
    /// so the `_size` argument only exists to match the fallback backend's signature.
    ///
    /// # Safety
    ///
    /// `memptr` must come from [`malloc_sized`] and must not be used afterwards.
    #[inline]
    pub unsafe fn free(memptr: NonNull<u8>, _size: usize) {
        memsec::free(memptr);
    }
}
26
#[cfg(not(feature = "use_os"))]
mod alloc {
    //! Portable allocation backend used when the `use_os` feature is disabled.
    //!
    //! Memory comes from the global allocator, and this module applies no page
    //! protection of its own.

    use std::alloc::Layout;
    use std::ptr::NonNull;
    use std::sync::atomic::{compiler_fence, Ordering};

    /// Allocates `size` zero-initialised bytes with alignment 1.
    ///
    /// A zero-length request returns `NonNull::dangling()` without calling the
    /// allocator, because `std::alloc::alloc` must never be given a zero-sized layout.
    /// Returns `None` if `size` exceeds `isize::MAX` (no valid `Layout` exists) or the
    /// global allocator reports failure.
    ///
    /// # Safety
    ///
    /// The returned pointer must be released only through [`free`] with the same `size`.
    #[inline]
    pub unsafe fn malloc_sized(size: usize) -> Option<NonNull<u8>> {
        if size == 0 {
            return Some(NonNull::dangling());
        }

        let layout = Layout::from_size_align(size, 1).ok()?;

        // Zeroed rather than plain `alloc`. Callers expose the buffer as `&mut [u8]`
        // straight away, and reading uninitialised memory as `u8` is undefined behaviour.
        NonNull::new(std::alloc::alloc_zeroed(layout))
    }

    /// Wipes, then releases, memory obtained from [`malloc_sized`].
    ///
    /// Every byte is overwritten with zero using volatile writes, so the compiler
    /// cannot drop them as dead stores before the block goes back to the global allocator.
    /// A zero `size` is a no-op, matching the dangling pointer returned by `malloc_sized(0)`.
    ///
    /// # Safety
    ///
    /// `memptr` must come from `malloc_sized(size)` with the same `size`, and must not be
    /// used afterwards.
    #[inline]
    pub unsafe fn free(memptr: NonNull<u8>, size: usize) {
        if size == 0 {
            return;
        }

        let base = memptr.as_ptr();
        for offset in 0..size {
            std::ptr::write_volatile(base.add(offset), 0);
        }
        // Keep the wipe ordered before the deallocation below.
        compiler_fence(Ordering::SeqCst);

        // SAFETY: this is the same (size, align = 1) layout that `malloc_sized` validated.
        std::alloc::dealloc(base, Layout::from_size_align_unchecked(size, 1));
    }
}
42
/// Owned buffer of secret bytes, allocated through the `alloc` module.
///
/// After construction, the contents are reached through the guards returned by
/// `read` and `write`. With the `use_os` feature, the region is switched to
/// no-access whenever no guard is alive.
pub struct SecBytes {
    /// Start of the allocation returned by `alloc::malloc_sized`.
    ptr: NonNull<u8>,
    /// Size of the allocation, in bytes.
    len: usize,

    /// Number of live read guards; the region is readable while this is non-zero.
    #[cfg(feature = "use_os")]
    count: Cell<usize>,
}

// SAFETY: `SecBytes` is the sole owner of its allocation (it frees it in `Drop`), and every
// guard borrows the `SecBytes`, so no guard can outlive a move to another thread.
// No `Sync` impl is provided; with `use_os` the read counter is a non-atomic `Cell`.
unsafe impl Send for SecBytes {}
54
55impl SecBytes {
56 pub fn new(len: usize) -> SecBytes {
57 fn id(_: &mut [u8]) {}
58
59 SecBytes::with(len, id)
60 }
61
62 pub fn with<F>(len: usize, f: F) -> SecBytes
63 where F: FnOnce(&mut [u8])
64 {
65 let ptr = unsafe {
66 let memptr = alloc::malloc_sized(len).expect("seckey alloc failed");
67
68 {
69 let arr = std::slice::from_raw_parts_mut(memptr.as_ptr(), len);
70 f(arr);
71 }
72
73 #[cfg(feature = "use_os")]
74 mprotect(memptr, Prot::NoAccess);
75
76 memptr
77 };
78
79 SecBytes {
80 ptr, len,
81
82 #[cfg(feature = "use_os")]
83 count: Cell::new(0)
84 }
85 }
86
87 #[cfg_attr(not(feature = "use_os"), inline)]
96 pub fn read(&self) -> SecReadGuard<'_> {
97 #[cfg(feature = "use_os")] {
98 let count = self.count.get();
99 self.count.set(count + 1);
100 if count == 0 {
101 unsafe { mprotect(self.ptr, Prot::ReadOnly) };
102 }
103 }
104
105 SecReadGuard(self)
106 }
107
108 #[cfg_attr(not(feature = "use_os"), inline)]
119 pub fn write(&mut self) -> SecWriteGuard<'_> {
120 #[cfg(feature = "use_os")]
121 unsafe {
122 mprotect(self.ptr, Prot::ReadWrite)
123 };
124
125 SecWriteGuard(self)
126 }
127}
128
129impl fmt::Debug for SecBytes {
130 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
131 f.debug_tuple("SecBytes")
132 .field(&format_args!("{:p}", self.ptr))
133 .finish()
134 }
135}
136
137impl fmt::Pointer for SecBytes {
138 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
139 write!(f, "{:p}", self.ptr)
140 }
141}
142
143impl Drop for SecBytes {
144 fn drop(&mut self) {
145 unsafe {
146 #[cfg(feature = "use_os")]
147 mprotect(self.ptr, Prot::ReadWrite);
148
149 alloc::free(self.ptr, self.len);
150 }
151 }
152}
153
154
155pub struct SecReadGuard<'a>(&'a SecBytes);
157
158impl<'a> Deref for SecReadGuard<'a> {
159 type Target = [u8];
160
161 #[inline]
162 fn deref(&self) -> &[u8] {
163 unsafe {
164 std::slice::from_raw_parts(self.0.ptr.as_ptr(), self.0.len)
165 }
166 }
167}
168
169impl<'a> Drop for SecReadGuard<'a> {
170 fn drop(&mut self) {
171 #[cfg(feature = "use_os")]
172 unsafe {
173 let count = self.0.count.get();
174 self.0.count.set(count - 1);
175 if count <= 1 {
176 mprotect(self.0.ptr, Prot::NoAccess);
177 }
178 }
179 }
180}
181
182
183pub struct SecWriteGuard<'a>(&'a mut SecBytes);
185
186impl<'a> Deref for SecWriteGuard<'a> {
187 type Target = [u8];
188
189 #[inline]
190 fn deref(&self) -> &[u8] {
191 unsafe {
192 std::slice::from_raw_parts(self.0.ptr.as_ptr(), self.0.len)
193 }
194 }
195}
196
197impl<'a> DerefMut for SecWriteGuard<'a> {
198 #[inline]
199 fn deref_mut(&mut self) -> &mut [u8] {
200 unsafe {
201 std::slice::from_raw_parts_mut(self.0.ptr.as_ptr(), self.0.len)
202 }
203 }
204}
205
206impl<'a> Drop for SecWriteGuard<'a> {
207 fn drop(&mut self) {
208 #[cfg(feature = "use_os")]
209 unsafe {
210 mprotect(self.0.ptr, Prot::NoAccess);
211 }
212 }
213}