1#![allow(unsafe_code)]
5
6use std::ptr::NonNull;
7
8use rand::TryRngCore;
9use zeroize::Zeroize;
10
11use super::memcall::{os_alloc, os_free, os_lock, os_protect, os_unlock, page_size, Protection};
12use crate::error::Error;
13
/// Number of canary bytes placed at the start of each guard page (256 bits).
const CANARY_LEN: usize = 256 / 8;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
17pub(super) enum State {
18 Mutable,
19 Frozen,
20 Dead,
21}
22
23pub struct SecureBuffer {
30 alloc_ptr: NonNull<u8>,
32 alloc_len: usize,
34 inner_ptr: NonNull<u8>,
36 inner_len: usize,
38 pre_canary: [u8; CANARY_LEN],
40 post_canary: [u8; CANARY_LEN],
41 page_size: usize,
42 pub(super) state: State,
43 mlocked: bool,
44}
45
46unsafe impl Send for SecureBuffer {}
48
49impl std::fmt::Debug for SecureBuffer {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("SecureBuffer")
52 .field("inner_len", &self.inner_len)
53 .field("state", &self.state)
54 .finish()
55 }
56}
57
58impl SecureBuffer {
59 pub fn new(size: usize) -> crate::error::Result<Self> {
61 let ps = page_size();
62 let inner_rounded = size.div_ceil(ps) * ps;
64 let alloc_len = ps + inner_rounded + ps;
65
66 let alloc_ptr = unsafe { os_alloc(alloc_len) }
67 .map_err(|e| Error::Memory(format!("SecureBuffer::new alloc: {e}")))?;
68
69 let inner_ptr = unsafe { NonNull::new_unchecked(alloc_ptr.as_ptr().add(ps)) };
74
75 let mut pre_canary = [0_u8; CANARY_LEN];
77 let mut post_canary = [0_u8; CANARY_LEN];
78 if rand::rngs::OsRng.try_fill_bytes(&mut pre_canary).is_err() {
79 pre_canary.fill(0xAB);
80 }
81 if rand::rngs::OsRng.try_fill_bytes(&mut post_canary).is_err() {
82 post_canary.fill(0xCD);
83 }
84
85 unsafe {
87 let pre_guard = alloc_ptr.as_ptr();
88 std::ptr::copy_nonoverlapping(pre_canary.as_ptr(), pre_guard, CANARY_LEN.min(ps));
89 let post_guard = alloc_ptr.as_ptr().add(ps + inner_rounded);
90 std::ptr::copy_nonoverlapping(post_canary.as_ptr(), post_guard, CANARY_LEN.min(ps));
91 }
92
93 let mlocked = unsafe { os_lock(inner_ptr.as_ptr(), inner_rounded) }.is_ok();
95
96 drop(unsafe { os_protect(alloc_ptr.as_ptr(), ps, Protection::NoAccess) });
98 drop(unsafe {
99 os_protect(
100 alloc_ptr.as_ptr().add(ps + inner_rounded),
101 ps,
102 Protection::NoAccess,
103 )
104 });
105
106 Ok(Self {
107 alloc_ptr,
108 alloc_len,
109 inner_ptr,
110 inner_len: size,
111 pre_canary,
112 post_canary,
113 page_size: ps,
114 state: State::Mutable,
115 mlocked,
116 })
117 }
118
119 pub fn size(&self) -> usize {
120 self.inner_len
121 }
122
123 pub fn is_alive(&self) -> bool {
124 self.state != State::Dead
125 }
126
127 pub fn is_mutable(&self) -> bool {
128 self.state == State::Mutable
129 }
130
131 pub fn bytes(&mut self) -> &mut [u8] {
133 assert!(
134 self.state == State::Mutable,
135 "SecureBuffer: bytes() called in non-mutable state"
136 );
137 unsafe { std::slice::from_raw_parts_mut(self.inner_ptr.as_ptr(), self.inner_len) }
142 }
143
144 pub fn as_slice(&self) -> &[u8] {
146 assert!(
147 self.state != State::Dead,
148 "SecureBuffer: as_slice() on dead buffer"
149 );
150 unsafe { std::slice::from_raw_parts(self.inner_ptr.as_ptr(), self.inner_len) }
155 }
156
157 pub fn freeze(&mut self) -> crate::error::Result<()> {
159 if self.state == State::Dead {
160 return Err(Error::Memory("SecureBuffer::freeze on dead buffer".into()));
161 }
162 let inner_rounded = self.alloc_len - 2 * self.page_size;
163 unsafe { os_protect(self.inner_ptr.as_ptr(), inner_rounded, Protection::ReadOnly) }
164 .map_err(|e| Error::Memory(format!("freeze: {e}")))?;
165 self.state = State::Frozen;
166 Ok(())
167 }
168
169 pub fn melt(&mut self) -> crate::error::Result<()> {
171 if self.state == State::Dead {
172 return Err(Error::Memory("SecureBuffer::melt on dead buffer".into()));
173 }
174 let inner_rounded = self.alloc_len - 2 * self.page_size;
175 unsafe {
176 os_protect(
177 self.inner_ptr.as_ptr(),
178 inner_rounded,
179 Protection::ReadWrite,
180 )
181 }
182 .map_err(|e| Error::Memory(format!("melt: {e}")))?;
183 self.state = State::Mutable;
184 Ok(())
185 }
186
187 pub fn destroy(&mut self) -> crate::error::Result<()> {
191 if self.state == State::Dead {
192 return Ok(());
193 }
194
195 let ps = self.page_size;
196 let inner_rounded = self.alloc_len - 2 * ps;
197
198 let pre_guard = self.alloc_ptr.as_ptr();
200 let post_guard = unsafe { self.alloc_ptr.as_ptr().add(ps + inner_rounded) };
201
202 drop(unsafe { os_protect(pre_guard, ps, Protection::ReadOnly) });
203 drop(unsafe { os_protect(post_guard, ps, Protection::ReadOnly) });
204
205 let pre_guard_slice = unsafe { std::slice::from_raw_parts(pre_guard, CANARY_LEN) };
207 let post_guard_slice = unsafe { std::slice::from_raw_parts(post_guard, CANARY_LEN) };
208
209 let pre_ok = pre_guard_slice
211 .iter()
212 .zip(self.pre_canary.iter())
213 .fold(0_u8, |acc, (a, b)| acc | (a ^ b))
214 == 0;
215 let post_ok = post_guard_slice
216 .iter()
217 .zip(self.post_canary.iter())
218 .fold(0_u8, |acc, (a, b)| acc | (a ^ b))
219 == 0;
220
221 drop(unsafe {
223 os_protect(
224 self.inner_ptr.as_ptr(),
225 inner_rounded,
226 Protection::ReadWrite,
227 )
228 });
229
230 unsafe {
232 let s = std::slice::from_raw_parts_mut(self.inner_ptr.as_ptr(), inner_rounded);
233 s.zeroize();
234 }
235
236 if self.mlocked {
238 drop(unsafe { os_unlock(self.inner_ptr.as_ptr(), inner_rounded) });
239 }
240
241 drop(unsafe { os_free(self.alloc_ptr.as_ptr(), self.alloc_len) });
247
248 self.state = State::Dead;
249
250 if !pre_ok || !post_ok {
251 return Err(Error::Memory(
252 "SecureBuffer: guard page canary corrupted — buffer overflow detected".into(),
253 ));
254 }
255 Ok(())
256 }
257
258 pub fn scramble(&mut self) -> crate::error::Result<()> {
260 if self.state != State::Mutable {
261 self.melt()?;
262 }
263 let buf = self.bytes();
264 rand::rngs::OsRng
265 .try_fill_bytes(buf)
266 .map_err(|e| Error::Memory(format!("scramble OsRng: {e}")))
267 }
268}
269
270#[allow(clippy::panic)]
271impl Drop for SecureBuffer {
272 fn drop(&mut self) {
273 if let Err(e) = self.destroy() {
274 tracing::error!(error = %e, "SecureBuffer canary corruption detected — possible buffer overflow");
277 #[cfg(debug_assertions)]
281 panic!("SecureBuffer canary corrupted: {e}");
282 }
283 }
284}
285
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Flipping a byte of the post-guard canary must make `destroy` report corruption.
    #[test]
    fn canary_corruption_detected() {
        let mut buf = SecureBuffer::new(64).unwrap();

        let ps = page_size();
        // Mirrors the layout computed in `SecureBuffer::new`.
        let inner_rounded = 64_usize.div_ceil(ps) * ps;
        // SAFETY: `ps + inner_rounded` is the start of the post-guard page, inside the
        // live allocation.
        let post_guard = unsafe { buf.alloc_ptr.as_ptr().add(ps + inner_rounded) };

        // SAFETY: the page is made writable before the single-byte write.
        unsafe {
            os_protect(post_guard, ps, Protection::ReadWrite).unwrap();
            *post_guard = !*post_guard;
        }

        let result = buf.destroy();
        assert!(
            result.is_err(),
            "destroy should report canary failure but returned Ok"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("canary"),
            "error should mention canary, got: {msg}"
        );
    }

    #[test]
    fn new_buffer_is_mutable() {
        let buf = SecureBuffer::new(32).unwrap();
        assert!(buf.is_mutable());
        assert!(buf.is_alive());
    }

    #[test]
    fn freeze_and_melt() {
        let mut buf = SecureBuffer::new(32).unwrap();
        buf.freeze().unwrap();
        assert!(!buf.is_mutable());
        buf.melt().unwrap();
        assert!(buf.is_mutable());
    }

    #[test]
    fn bytes_writes_and_reads_back() {
        let mut buf = SecureBuffer::new(64).unwrap();
        buf.bytes()[0] = 0xAA_u8;
        buf.bytes()[63] = 0xBB_u8;
        assert_eq!(buf.as_slice()[0], 0xAA_u8);
        assert_eq!(buf.as_slice()[63], 0xBB_u8);
    }

    #[test]
    fn scramble_produces_non_zero() {
        let mut buf = SecureBuffer::new(64).unwrap();
        buf.scramble().unwrap();
        // 64 random bytes all being zero has probability 2^-512.
        let all_zero = buf.as_slice().iter().all(|&b| b == 0_u8);
        assert!(!all_zero, "scramble should produce non-zero bytes");
    }

    #[test]
    fn destroy_returns_ok_on_clean_buffer() {
        let mut buf = SecureBuffer::new(32).unwrap();
        buf.destroy().unwrap();
        assert!(!buf.is_alive());
    }

    #[test]
    fn drop_without_explicit_destroy_does_not_panic() {
        let mut buf = SecureBuffer::new(128).unwrap();
        buf.bytes()[0] = 1_u8;
        drop(buf);
    }

    #[test]
    fn freeze_twice_is_idempotent() {
        let mut buf = SecureBuffer::new(32).unwrap();
        buf.freeze().unwrap();
        buf.freeze().unwrap();
        assert!(!buf.is_mutable());
    }

    #[test]
    fn melt_twice_is_idempotent() {
        let mut buf = SecureBuffer::new(32).unwrap();
        buf.freeze().unwrap();
        buf.melt().unwrap();
        buf.melt().unwrap();
        assert!(buf.is_mutable());
    }

    #[test]
    fn frozen_buffer_is_readable() {
        let mut buf = SecureBuffer::new(16).unwrap();
        buf.bytes()[0] = 0x99;
        buf.freeze().unwrap();
        assert_eq!(buf.as_slice()[0], 0x99);
    }

    #[test]
    fn scramble_on_frozen_buffer_melts_first() {
        let mut buf = SecureBuffer::new(32).unwrap();
        buf.freeze().unwrap();
        buf.scramble().unwrap();
        assert!(buf.is_mutable());
    }

    #[test]
    fn destroy_twice_is_safe() {
        let mut buf = SecureBuffer::new(32).unwrap();
        buf.destroy().unwrap();
        assert!(!buf.is_alive());
        buf.destroy().unwrap();
        assert!(!buf.is_alive());
    }

    /// Sizes around the canary length and page boundaries must all round-trip.
    #[test]
    fn boundary_sizes() {
        let ps = page_size();
        for size in [
            1_usize,
            15,
            16,
            31,
            32,
            33,
            63,
            64,
            ps - 1,
            ps,
            ps + 1,
            ps * 2,
        ] {
            let mut buf = SecureBuffer::new(size).unwrap();
            assert_eq!(buf.size(), size);
            buf.bytes().fill(0xAB);
            assert!(buf.as_slice().iter().all(|&b| b == 0xAB));
            buf.destroy().unwrap();
        }
    }

    /// Flipping a byte of the pre-guard canary must make `destroy` report corruption.
    #[test]
    fn canary_pre_guard_corruption_detected() {
        let ps = page_size();
        let mut buf = SecureBuffer::new(64).unwrap();
        let pre_guard = buf.alloc_ptr.as_ptr();
        // SAFETY: `pre_guard` is the start of the live mapping; the page is made
        // writable before the single-byte write.
        unsafe {
            os_protect(pre_guard, ps, Protection::ReadWrite).unwrap();
            *pre_guard = !*pre_guard;
        }
        let result = buf.destroy();
        assert!(
            result.is_err(),
            "pre-guard canary corruption must be detected"
        );
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("canary"), "error must mention canary: {msg}");
    }

    /// The zeroing itself cannot be observed after the mapping is freed; this only
    /// checks that destroying a fully written buffer completes cleanly.
    #[test]
    fn drop_zeroes_inner_region() {
        let mut buf = SecureBuffer::new(64).unwrap();
        buf.bytes().fill(0xDE);
        buf.destroy().unwrap();
        assert!(!buf.is_alive());
    }

    #[test]
    fn operations_on_dead_buffer_return_errors() {
        let mut buf = SecureBuffer::new(32).unwrap();
        buf.destroy().unwrap();
        assert!(buf.freeze().is_err(), "freeze on a dead buffer must fail");
        assert!(buf.melt().is_err(), "melt on a dead buffer must fail");
        assert!(buf.scramble().is_err(), "scramble on a dead buffer must fail");
        assert!(!buf.is_alive());
    }

    #[test]
    #[should_panic(expected = "dead buffer")]
    fn as_slice_on_dead_buffer_panics() {
        let mut buf = SecureBuffer::new(32).unwrap();
        buf.destroy().unwrap();
        let _ = buf.as_slice();
    }

    #[test]
    #[should_panic(expected = "non-mutable")]
    fn bytes_on_frozen_buffer_panics() {
        let mut buf = SecureBuffer::new(32).unwrap();
        buf.freeze().unwrap();
        let _ = buf.bytes();
    }

    /// `Debug` must expose only metadata, never canaries or raw pointers.
    #[test]
    fn debug_output_omits_secret_material() {
        let mut buf = SecureBuffer::new(32).unwrap();
        buf.bytes().fill(0x5A);
        let rendered = format!("{buf:?}");
        assert!(rendered.contains("inner_len: 32"), "unexpected Debug output: {rendered}");
        assert!(!rendered.contains("canary"), "Debug leaks canaries: {rendered}");
        assert!(!rendered.contains("ptr"), "Debug leaks pointers: {rendered}");
    }
}