hardware_enclave/memory/
locked_buffer.rs1#![allow(unsafe_code)]
5
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex, OnceLock, Weak};
8
9use zeroize::{Zeroize, Zeroizing};
10
11use super::secure_buffer::SecureBuffer;
12use crate::error::Result;
13
14type Registry = Mutex<HashMap<usize, Weak<Mutex<SecureBuffer>>>>;
16
17fn registry() -> &'static Registry {
18 static REGISTRY: OnceLock<Registry> = OnceLock::new();
19 REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
20}
21
22fn register(id: usize, weak: Weak<Mutex<SecureBuffer>>) {
23 if let Ok(mut r) = registry().lock() {
24 r.insert(id, weak);
25 }
26}
27
28fn unregister(id: usize) {
29 if let Ok(mut r) = registry().lock() {
30 r.remove(&id);
31 }
32}
33
34pub fn zeroize_all_registered_at_shutdown() {
44 if let Ok(r) = registry().lock() {
45 for weak in r.values() {
46 if let Some(arc) = weak.upgrade() {
47 debug_assert!(
51 Arc::strong_count(&arc) <= 2,
52 "zeroize_all_registered_at_shutdown called while LockedBuffer still in use"
53 );
54 if let Ok(mut buf) = arc.lock() {
55 drop(buf.melt());
56 if buf.is_alive() {
57 buf.bytes().zeroize();
58 }
59 }
60 }
61 }
62 }
63}
64
65pub struct LockedBuffer {
67 inner: Arc<Mutex<SecureBuffer>>,
68 id: usize,
69}
70
71impl std::fmt::Debug for LockedBuffer {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("LockedBuffer")
74 .field("id", &self.id)
75 .finish()
76 }
77}
78
79impl LockedBuffer {
80 fn from_buffer(buf: SecureBuffer) -> Result<Self> {
81 static ID: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(1);
82 let id = ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
83 let arc = Arc::new(Mutex::new(buf));
84 register(id, Arc::downgrade(&arc));
85 Ok(Self { inner: arc, id })
86 }
87
88 pub fn new(size: usize) -> Result<Self> {
90 Self::from_buffer(SecureBuffer::new(size)?)
91 }
92
93 pub fn random(size: usize) -> Result<Self> {
95 let mut buf = SecureBuffer::new(size)?;
96 buf.scramble()?;
97 Self::from_buffer(buf)
98 }
99
100 pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self> {
102 let src = bytes.as_ref();
103 let mut buf = SecureBuffer::new(src.len())?;
104 buf.bytes().copy_from_slice(src);
105 Self::from_buffer(buf)
106 }
107
108 pub fn freeze(&self) -> Result<()> {
109 self.inner
110 .lock()
111 .unwrap_or_else(|e| e.into_inner())
112 .freeze()
113 }
114
115 pub fn melt(&self) -> Result<()> {
116 self.inner.lock().unwrap_or_else(|e| e.into_inner()).melt()
117 }
118
119 pub fn scramble(&self) -> Result<()> {
120 self.inner
121 .lock()
122 .unwrap_or_else(|e| e.into_inner())
123 .scramble()
124 }
125
126 pub fn wipe(&self) {
127 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
128 drop(guard.melt());
129 if guard.is_alive() {
130 guard.bytes().zeroize();
131 }
132 }
133
134 pub fn bytes_zeroizing(&self) -> Zeroizing<Vec<u8>> {
136 let guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
137 Zeroizing::new(guard.as_slice().to_vec())
138 }
139
140 pub fn size(&self) -> usize {
141 self.inner.lock().unwrap_or_else(|e| e.into_inner()).size()
142 }
143}
144
145impl Drop for LockedBuffer {
146 fn drop(&mut self) {
147 unregister(self.id);
148 self.wipe();
149 }
150}
151
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};

    // `zeroize_all_registered_at_shutdown` wipes every live buffer in the
    // process-wide registry, including buffers owned by other tests, which
    // libtest runs in parallel by default. Without coordination, a shutdown
    // test can zero another test's buffer between its write and its assertion
    // and make content checks fail intermittently.
    //
    // Tests that call the shutdown wipe hold this gate exclusively. Every other
    // test holds it shared for its whole body. Bind the guard to a named
    // variable (`_gate`); `let _ = ...` would release it immediately.
    static SHUTDOWN_GATE: RwLock<()> = RwLock::new(());

    /// Shared access: may overlap other shared holders, never a shutdown wipe.
    fn shared_gate() -> RwLockReadGuard<'static, ()> {
        // A failed assertion in another test poisons the gate; recover so the
        // failure does not cascade into unrelated tests.
        SHUTDOWN_GATE.read().unwrap_or_else(|e| e.into_inner())
    }

    /// Exclusive access, for tests that call the shutdown wipe.
    fn exclusive_gate() -> RwLockWriteGuard<'static, ()> {
        SHUTDOWN_GATE.write().unwrap_or_else(|e| e.into_inner())
    }

    #[test]
    fn zeroize_at_shutdown_does_not_panic_on_dead_weaks() {
        let _gate = exclusive_gate();
        // Dropping a handle unregisters it, so a dropped handle alone never
        // leaves a dead weak behind. Plant one directly to exercise that path.
        let dead_id = usize::MAX;
        register(dead_id, std::sync::Weak::new());
        let buf = LockedBuffer::new(32).unwrap();
        {
            let mut g = buf.inner.lock().unwrap_or_else(|e| e.into_inner());
            g.bytes().fill(0xAB_u8);
        }
        drop(buf);
        zeroize_all_registered_at_shutdown();
        unregister(dead_id);
    }

    #[test]
    fn zeroize_at_shutdown_zeroes_live_buffer() {
        let _gate = exclusive_gate();
        let buf = LockedBuffer::new(32).unwrap();
        {
            let mut g = buf.inner.lock().unwrap_or_else(|e| e.into_inner());
            g.bytes().fill(0xAB_u8);
        }
        zeroize_all_registered_at_shutdown();
        let g = buf.inner.lock().unwrap_or_else(|e| e.into_inner());
        assert!(
            g.as_slice().iter().all(|&b| b == 0),
            "buffer must be zeroed after shutdown call"
        );
    }

    #[test]
    fn new_buffer_is_zeroed() {
        let _gate = shared_gate();
        let buf = LockedBuffer::new(32).unwrap();
        let bytes = buf.bytes_zeroizing();
        assert!(bytes.iter().all(|&b| b == 0), "new buffer must be zeroed");
    }

    #[test]
    fn new_buffer_has_correct_size() {
        let _gate = shared_gate();
        for &size in &[1_usize, 16, 32, 64, 128] {
            let buf = LockedBuffer::new(size).unwrap();
            assert_eq!(buf.size(), size);
        }
    }

    #[test]
    fn random_buffer_is_nonzero() {
        let _gate = shared_gate();
        let buf = LockedBuffer::random(32).unwrap();
        let bytes = buf.bytes_zeroizing();
        assert!(
            bytes.iter().any(|&b| b != 0),
            "random buffer must not be all zeros"
        );
    }

    #[test]
    fn from_bytes_copies_data() {
        let _gate = shared_gate();
        let input = b"hello locked world";
        let buf = LockedBuffer::from_bytes(input.as_ref()).unwrap();
        assert_eq!(buf.size(), input.len());
        let bytes = buf.bytes_zeroizing();
        assert_eq!(bytes.as_slice(), input.as_ref());
    }

    #[test]
    fn wipe_zeroes_contents() {
        let _gate = shared_gate();
        let buf = LockedBuffer::from_bytes(b"secret to wipe").unwrap();
        buf.wipe();
        let bytes = buf.bytes_zeroizing();
        assert!(
            bytes.iter().all(|&b| b == 0),
            "after wipe(), buffer must be zeroed"
        );
    }

    #[test]
    fn bytes_zeroizing_returns_independent_copy() {
        let _gate = shared_gate();
        let buf = LockedBuffer::from_bytes(b"copy test").unwrap();
        let copy = buf.bytes_zeroizing();
        buf.wipe();
        assert_eq!(copy.as_slice(), b"copy test");
    }

    #[test]
    fn freeze_and_melt_through_handle() {
        let _gate = shared_gate();
        let buf = LockedBuffer::new(16).unwrap();
        buf.freeze().unwrap();
        buf.melt().unwrap();
        let bytes = buf.bytes_zeroizing();
        assert_eq!(bytes.len(), 16);
    }

    #[test]
    fn scramble_produces_nonzero() {
        let _gate = shared_gate();
        let buf = LockedBuffer::new(64).unwrap();
        buf.scramble().unwrap();
        let bytes = buf.bytes_zeroizing();
        assert!(
            bytes.iter().any(|&b| b != 0),
            "scramble must fill with non-zero bytes"
        );
    }

    #[test]
    fn zeroize_all_at_shutdown_zeroes_all_registered() {
        let _gate = exclusive_gate();
        let b1 = LockedBuffer::from_bytes(b"secret1").unwrap();
        let b2 = LockedBuffer::from_bytes(b"secret2").unwrap();
        zeroize_all_registered_at_shutdown();
        let bytes1 = b1.bytes_zeroizing();
        let bytes2 = b2.bytes_zeroizing();
        assert!(
            bytes1.iter().all(|&b| b == 0),
            "b1 must be zeroed after shutdown call"
        );
        assert!(
            bytes2.iter().all(|&b| b == 0),
            "b2 must be zeroed after shutdown call"
        );
    }

    #[test]
    fn concurrent_access_from_multiple_threads() {
        use std::sync::Arc;
        use std::thread;
        let _gate = shared_gate();
        let buf = Arc::new(LockedBuffer::new(64).unwrap());
        let handles: Vec<_> = (0..8)
            .map(|_| {
                let b = Arc::clone(&buf);
                thread::spawn(move || {
                    b.freeze().ok();
                    let _ = b.size();
                    b.melt().ok();
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
    }
}
307}