Skip to main content

hardware_enclave/memory/
locked_buffer.rs

1// Copyright 2026 Jay Gowdy
2// SPDX-License-Identifier: MIT
3
4#![allow(unsafe_code)]
5
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex, OnceLock, Weak};
8
9use zeroize::{Zeroize, Zeroizing};
10
11use super::secure_buffer::SecureBuffer;
12use crate::error::Result;
13
14// Global registry for centralized shutdown cleanup.
15type Registry = Mutex<HashMap<usize, Weak<Mutex<SecureBuffer>>>>;
16
17fn registry() -> &'static Registry {
18    static REGISTRY: OnceLock<Registry> = OnceLock::new();
19    REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
20}
21
22fn register(id: usize, weak: Weak<Mutex<SecureBuffer>>) {
23    if let Ok(mut r) = registry().lock() {
24        r.insert(id, weak);
25    }
26}
27
28fn unregister(id: usize) {
29    if let Ok(mut r) = registry().lock() {
30        r.remove(&id);
31    }
32}
33
34/// Zeroize the contents of all registered LockedBuffers.
35///
36/// **Call only at process shutdown.** Any LockedBuffer user still holding
37/// a reference after this call will read zeroed data. The buffers are not
38/// destroyed — they remain live with zeroed content until normal Drop runs.
39///
40/// # Panics (debug only)
41/// In debug builds, panics if any LockedBuffer has a strong reference count > 2
42/// at the time of the call (i.e. a caller outside the registry still holds a clone).
43pub fn zeroize_all_registered_at_shutdown() {
44    if let Ok(r) = registry().lock() {
45        for weak in r.values() {
46            if let Some(arc) = weak.upgrade() {
47                // In debug mode, assert this is the only strong reference
48                // (registry holds one weak ref; the upgrade here is the second strong ref,
49                // so count == 2 means no external holders).
50                debug_assert!(
51                    Arc::strong_count(&arc) <= 2,
52                    "zeroize_all_registered_at_shutdown called while LockedBuffer still in use"
53                );
54                if let Ok(mut buf) = arc.lock() {
55                    drop(buf.melt());
56                    if buf.is_alive() {
57                        buf.bytes().zeroize();
58                    }
59                }
60            }
61        }
62    }
63}
64
65/// Arc-wrapped, Mutex-guarded SecureBuffer for sharing across threads.
66pub struct LockedBuffer {
67    inner: Arc<Mutex<SecureBuffer>>,
68    id: usize,
69}
70
71impl std::fmt::Debug for LockedBuffer {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("LockedBuffer")
74            .field("id", &self.id)
75            .finish()
76    }
77}
78
79impl LockedBuffer {
80    fn from_buffer(buf: SecureBuffer) -> Result<Self> {
81        static ID: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(1);
82        let id = ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
83        let arc = Arc::new(Mutex::new(buf));
84        register(id, Arc::downgrade(&arc));
85        Ok(Self { inner: arc, id })
86    }
87
88    /// Allocate a new zeroed buffer.
89    pub fn new(size: usize) -> Result<Self> {
90        Self::from_buffer(SecureBuffer::new(size)?)
91    }
92
93    /// Allocate and fill with OsRng random bytes.
94    pub fn random(size: usize) -> Result<Self> {
95        let mut buf = SecureBuffer::new(size)?;
96        buf.scramble()?;
97        Self::from_buffer(buf)
98    }
99
100    /// Create from an existing byte slice (copies into locked memory).
101    pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self> {
102        let src = bytes.as_ref();
103        let mut buf = SecureBuffer::new(src.len())?;
104        buf.bytes().copy_from_slice(src);
105        Self::from_buffer(buf)
106    }
107
108    pub fn freeze(&self) -> Result<()> {
109        self.inner
110            .lock()
111            .unwrap_or_else(|e| e.into_inner())
112            .freeze()
113    }
114
115    pub fn melt(&self) -> Result<()> {
116        self.inner.lock().unwrap_or_else(|e| e.into_inner()).melt()
117    }
118
119    pub fn scramble(&self) -> Result<()> {
120        self.inner
121            .lock()
122            .unwrap_or_else(|e| e.into_inner())
123            .scramble()
124    }
125
126    pub fn wipe(&self) {
127        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
128        drop(guard.melt());
129        if guard.is_alive() {
130            guard.bytes().zeroize();
131        }
132    }
133
134    /// Copy contents to a Zeroizing heap allocation.
135    pub fn bytes_zeroizing(&self) -> Zeroizing<Vec<u8>> {
136        let guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
137        Zeroizing::new(guard.as_slice().to_vec())
138    }
139
140    pub fn size(&self) -> usize {
141        self.inner.lock().unwrap_or_else(|e| e.into_inner()).size()
142    }
143}
144
145impl Drop for LockedBuffer {
146    fn drop(&mut self) {
147        unregister(self.id);
148        self.wipe();
149    }
150}
151
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
    use std::thread;

    /// Serializes tests against `zeroize_all_registered_at_shutdown`.
    ///
    /// That function wipes *every* registered buffer in the process. Run in
    /// parallel with another test (the `cargo test` default), it could zero that
    /// test's buffer mid-assertion. Ordinary tests take the shared guard and
    /// shutdown tests take the exclusive guard. Buffers created by tests in
    /// other modules are not covered by this lock.
    static SHUTDOWN_LOCK: RwLock<()> = RwLock::new(());

    /// Guard for tests that only touch their own buffers.
    fn shared() -> RwLockReadGuard<'static, ()> {
        SHUTDOWN_LOCK.read().unwrap_or_else(|e| e.into_inner())
    }

    /// Guard for tests that call the process-wide shutdown wipe.
    fn exclusive() -> RwLockWriteGuard<'static, ()> {
        SHUTDOWN_LOCK.write().unwrap_or_else(|e| e.into_inner())
    }

    #[test]
    fn zeroize_at_shutdown_does_not_panic_on_dead_weaks() {
        let _guard = exclusive();
        let buf = LockedBuffer::new(32).unwrap();
        {
            let mut g = buf.inner.lock().unwrap_or_else(|e| e.into_inner());
            g.bytes().fill(0xAB_u8);
        }
        let id = buf.id;
        let weak = Arc::downgrade(&buf.inner);
        drop(buf);
        assert!(weak.upgrade().is_none(), "dropped buffer must be freed");
        assert!(
            !registry()
                .lock()
                .unwrap_or_else(|e| e.into_inner())
                .contains_key(&id),
            "Drop must unregister the buffer"
        );

        // Drop removes the entry, so plant the dead Weak directly to exercise
        // the failed-upgrade path. usize::MAX is never handed out in practice.
        let dead_id = usize::MAX;
        register(dead_id, weak);
        zeroize_all_registered_at_shutdown();
        unregister(dead_id);
    }

    #[test]
    fn zeroize_at_shutdown_zeroes_live_buffer() {
        let _guard = exclusive();
        let buf = LockedBuffer::new(32).unwrap();
        {
            let mut g = buf.inner.lock().unwrap_or_else(|e| e.into_inner());
            g.bytes().fill(0xAB_u8);
        }
        // Still holding buf — strong count is 1 (user) + upgrade in the function = 2.
        zeroize_all_registered_at_shutdown();
        // After the call the buffer contents must be zeroed.
        let g = buf.inner.lock().unwrap_or_else(|e| e.into_inner());
        assert!(
            g.as_slice().iter().all(|&b| b == 0),
            "buffer must be zeroed after shutdown call"
        );
    }

    #[test]
    fn new_buffer_is_zeroed() {
        let _guard = shared();
        let buf = LockedBuffer::new(32).unwrap();
        let bytes = buf.bytes_zeroizing();
        assert!(bytes.iter().all(|&b| b == 0), "new buffer must be zeroed");
    }

    #[test]
    fn new_buffer_has_correct_size() {
        let _guard = shared();
        for &size in &[1_usize, 16, 32, 64, 128] {
            let buf = LockedBuffer::new(size).unwrap();
            assert_eq!(buf.size(), size);
        }
    }

    #[test]
    fn random_buffer_is_nonzero() {
        let _guard = shared();
        let buf = LockedBuffer::random(32).unwrap();
        let bytes = buf.bytes_zeroizing();
        // Statistically certain with a real CSPRNG.
        assert!(
            bytes.iter().any(|&b| b != 0),
            "random buffer must not be all zeros"
        );
    }

    #[test]
    fn from_bytes_copies_data() {
        let _guard = shared();
        let input = b"hello locked world";
        let buf = LockedBuffer::from_bytes(input.as_ref()).unwrap();
        assert_eq!(buf.size(), input.len());
        let bytes = buf.bytes_zeroizing();
        assert_eq!(bytes.as_slice(), input.as_ref());
    }

    #[test]
    fn wipe_zeroes_contents() {
        let _guard = shared();
        let buf = LockedBuffer::from_bytes(b"secret to wipe").unwrap();
        buf.wipe();
        let bytes = buf.bytes_zeroizing();
        assert!(
            bytes.iter().all(|&b| b == 0),
            "after wipe(), buffer must be zeroed"
        );
    }

    #[test]
    fn bytes_zeroizing_returns_independent_copy() {
        let _guard = shared();
        let buf = LockedBuffer::from_bytes(b"copy test").unwrap();
        let copy = buf.bytes_zeroizing();
        // Wipe the original.
        buf.wipe();
        // The copy should still hold the original data.
        assert_eq!(copy.as_slice(), b"copy test");
    }

    #[test]
    fn freeze_and_melt_through_handle() {
        let _guard = shared();
        let buf = LockedBuffer::new(16).unwrap();
        buf.freeze().unwrap();
        buf.melt().unwrap();
        // After melt, should still be usable.
        let bytes = buf.bytes_zeroizing();
        assert_eq!(bytes.len(), 16);
    }

    #[test]
    fn scramble_produces_nonzero() {
        let _guard = shared();
        let buf = LockedBuffer::new(64).unwrap();
        buf.scramble().unwrap();
        let bytes = buf.bytes_zeroizing();
        assert!(
            bytes.iter().any(|&b| b != 0),
            "scramble must fill with non-zero bytes"
        );
    }

    #[test]
    fn zeroize_all_at_shutdown_zeroes_all_registered() {
        let _guard = exclusive();
        // Register multiple buffers, call zeroize_all, verify all are zeroed.
        let b1 = LockedBuffer::from_bytes(b"secret1").unwrap();
        let b2 = LockedBuffer::from_bytes(b"secret2").unwrap();
        // Still holding strong refs to b1 and b2 while calling zeroize_all.
        zeroize_all_registered_at_shutdown();
        let bytes1 = b1.bytes_zeroizing();
        let bytes2 = b2.bytes_zeroizing();
        assert!(
            bytes1.iter().all(|&b| b == 0),
            "b1 must be zeroed after shutdown call"
        );
        assert!(
            bytes2.iter().all(|&b| b == 0),
            "b2 must be zeroed after shutdown call"
        );
    }

    #[test]
    fn concurrent_access_from_multiple_threads() {
        let _guard = shared();
        let buf = Arc::new(LockedBuffer::new(64).unwrap());
        let handles: Vec<_> = (0..8)
            .map(|_| {
                let b = Arc::clone(&buf);
                thread::spawn(move || {
                    // Each thread freezes, reads size, melts.
                    b.freeze().ok();
                    assert_eq!(b.size(), 64);
                    b.melt().ok();
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
    }
}
307}