1use crate::types::*;
5use zeroize::Zeroize;
6
7use core::{
8 cell::Cell,
9 fmt::{self, Debug},
10 mem,
11 ptr::NonNull,
12 slice,
13};
14
15use libsodium_sys::{
16 sodium_allocarray, sodium_free, sodium_init, sodium_mlock, sodium_mprotect_noaccess, sodium_mprotect_readonly,
17 sodium_mprotect_readwrite,
18};
19
/// Protection level currently applied to the memory owned by a `Boxed`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Prot {
    /// The memory may be neither read nor written.
    NoAccess,
    /// The memory may be read but not written.
    ReadOnly,
    /// The memory may be both read and written.
    ReadWrite,
}
26
/// Counter type for the number of outstanding unlocks of a `Boxed`; its width caps how many
/// simultaneous read-only unlocks are possible.
type RefCount = u8;
28
29#[derive(Eq)]
31pub(crate) struct Boxed<T: Bytes> {
32 ptr: NonNull<T>,
34 len: usize,
36 prot: Cell<Prot>,
38 refs: Cell<RefCount>,
40}
41
42impl<T: Bytes> Boxed<T> {
43 pub(crate) fn new<F>(len: usize, init: F) -> Self
44 where
45 F: FnOnce(&mut Self),
46 {
47 let mut boxed = Self::new_unlocked(len);
48 unsafe { lock_memory(boxed.ptr.as_mut(), len) };
49
50 assert!(
51 boxed.ptr != core::ptr::NonNull::dangling(),
52 "Make sure pointer isn't dangling"
53 );
54 assert!(boxed.len == len);
55
56 init(&mut boxed);
57
58 boxed.lock();
59
60 boxed
61 }
62
63 #[allow(dead_code)]
64 pub(crate) fn try_new<R, E, F>(len: usize, init: F) -> Result<Self, E>
65 where
66 F: FnOnce(&mut Self) -> Result<R, E>,
67 {
68 let mut boxed = Self::new_unlocked(len);
69
70 assert!(
71 boxed.ptr != core::ptr::NonNull::dangling(),
72 "Make sure pointer isn't dangling"
73 );
74 assert!(boxed.len == len);
75
76 let res = init(&mut boxed);
77
78 boxed.lock();
79
80 res.map(|_| boxed)
81 }
82
83 pub(crate) fn len(&self) -> usize {
84 self.len
85 }
86
87 pub(crate) fn is_empty(&self) -> bool {
88 self.len == 0
89 }
90
91 pub(crate) fn size(&self) -> usize {
92 self.len * T::size()
93 }
94
95 pub(crate) fn unlock(&self) -> &Self {
96 self.retain(Prot::ReadOnly);
97 self
98 }
99
100 pub(crate) fn unlock_mut(&mut self) -> &mut Self {
101 self.retain(Prot::ReadWrite);
102 self
103 }
104
105 pub(crate) fn lock(&self) {
106 self.release()
107 }
108
109 #[allow(dead_code)]
110 pub(crate) fn as_ref(&self) -> &T {
111 assert!(!self.is_empty(), "Attempted to dereference a zero-length pointer");
112
113 assert!(self.prot.get() != Prot::NoAccess, "May not call Boxed while locked");
114
115 unsafe { self.ptr.as_ref() }
116 }
117
118 pub(crate) fn as_mut(&mut self) -> &mut T {
119 assert!(!self.is_empty(), "Attempted to dereference a zero-length pointer");
120
121 assert!(
122 self.prot.get() == Prot::ReadWrite,
123 "May not call Boxed unless mutably unlocked"
124 );
125
126 unsafe { self.ptr.as_mut() }
127 }
128
129 pub(crate) fn as_slice(&self) -> &[T] {
130 assert!(self.prot.get() != Prot::NoAccess, "May not call Boxed while locked");
131
132 unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
133 }
134
135 pub(crate) fn as_mut_slice(&mut self) -> &mut [T] {
136 assert!(
137 self.prot.get() == Prot::ReadWrite,
138 "May not call Boxed unless mutably unlocked"
139 );
140
141 unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
142 }
143
144 fn new_unlocked(len: usize) -> Self {
145 if unsafe { sodium_init() == -1 } {
146 panic!("Failed to initialize libsodium")
147 }
148
149 let ptr = NonNull::new(unsafe { sodium_allocarray(len, mem::size_of::<T>()) as *mut _ })
150 .expect("Failed to allocate memory");
151
152 Self {
153 ptr,
154 len,
155 prot: Cell::new(Prot::ReadWrite),
156 refs: Cell::new(1),
157 }
158 }
159
160 fn retain(&self, prot: Prot) {
161 let refs = self.refs.get();
162
163 if refs == 0 {
164 assert!(prot != Prot::NoAccess, "Must retain readably or writably");
165
166 self.prot.set(prot);
167 mprotect(self.ptr.as_ptr(), prot);
168 } else {
169 assert!(
170 Prot::NoAccess != self.prot.get(),
171 "Out-of-order retain/release detected"
172 );
173 assert!(
174 Prot::ReadWrite != self.prot.get(),
175 "Cannot unlock mutably more than once"
176 );
177 assert!(Prot::ReadOnly == prot, "Cannot unlock mutably while unlocked immutably");
178 }
179
180 match refs.checked_add(1) {
181 Some(v) => self.refs.set(v),
182 None if self.is_locked() => panic!("Out-of-order retain/release detected"),
183 None => panic!("Retained too many times"),
184 };
185 }
186
187 fn release(&self) {
188 assert!(self.refs.get() != 0, "Releases exceeded retains");
189
190 assert!(
191 self.prot.get() != Prot::NoAccess,
192 "Releasing memory that's already locked"
193 );
194
195 let refs = self.refs.get().wrapping_sub(1);
196
197 self.refs.set(refs);
198
199 if refs == 0 {
200 mprotect(self.ptr.as_ptr(), Prot::NoAccess);
201 self.prot.set(Prot::NoAccess);
202 }
203 }
204
205 fn is_locked(&self) -> bool {
206 self.prot.get() == Prot::NoAccess
207 }
208
209 #[cfg(test)]
210 #[allow(dead_code)]
211 pub fn get_ptr_address(&self) -> usize {
213 self.ptr.as_ptr() as *const _ as usize
214 }
215}
216
217impl<T: Bytes + Randomized> Boxed<T> {
218 #[allow(dead_code)]
219 pub(crate) fn random(len: usize) -> Self {
220 Self::new(len, |b| b.as_mut_slice().randomize())
221 }
222}
223
224impl<T: Bytes + Zeroed> Boxed<T> {
225 #[allow(dead_code)]
226 pub(crate) fn zero(len: usize) -> Self {
227 Self::new(len, |b| b.as_mut_slice().zero())
228 }
229}
230
231impl<T: Bytes> Zeroize for Boxed<T> {
234 fn zeroize(&mut self) {
235 self.unlock_mut();
236 self.as_mut_slice().zero();
237 self.lock();
238 self.refs.set(0);
239 self.prot.set(Prot::NoAccess);
240 self.len = 0;
241 }
242}
243
244impl<T: Bytes> Drop for Boxed<T> {
245 fn drop(&mut self) {
246 extern crate std;
247
248 use std::thread;
249
250 if !thread::panicking() {
251 assert!(self.refs.get() == 0, "Retains exceeded releases");
252
253 assert!(self.prot.get() == Prot::NoAccess, "Dropped secret was still accessible");
254 }
255
256 unsafe { free(self.ptr.as_mut()) }
257 }
258}
259
260impl<T: Bytes> Debug for Boxed<T> {
261 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
262 write!(fmt, "{{ size: {}, hidden }}", self.size())
263 }
264}
265
266impl<T: Bytes> Clone for Boxed<T> {
267 fn clone(&self) -> Self {
268 Self::new(self.len, |b| {
269 b.as_mut_slice().copy_from_slice(self.unlock().as_slice());
270 self.lock();
271 })
272 }
273}
274
275impl<T: Bytes + ConstEq> PartialEq for Boxed<T> {
276 fn eq(&self, other: &Self) -> bool {
277 if self.len != other.len {
278 return false;
279 }
280
281 let lhs = self.unlock().as_slice();
282 let rhs = other.unlock().as_slice();
283
284 let ret = lhs.const_eq(rhs);
285
286 self.lock();
287 other.lock();
288
289 ret
290 }
291}
292
293impl<T: Bytes + Zeroed> From<&mut T> for Boxed<T> {
294 fn from(data: &mut T) -> Self {
295 Self::new(1, |b| unsafe { data.copy_and_zero(b.as_mut()) })
296 }
297}
298
299impl<T: Bytes + Zeroed> From<&mut [T]> for Boxed<T> {
300 fn from(data: &mut [T]) -> Self {
301 Self::new(data.len(), |b| unsafe { data.copy_and_zero(b.as_mut_slice()) })
302 }
303}
304
305unsafe impl<T: Bytes + Send> Send for Boxed<T> {}
306unsafe impl<T: Bytes + Sync> Sync for Boxed<T> {}
307
308fn mprotect<T>(ptr: *mut T, prot: Prot) {
309 if !match prot {
310 Prot::NoAccess => unsafe { sodium_mprotect_noaccess(ptr as *mut _) == 0 },
311 Prot::ReadOnly => unsafe { sodium_mprotect_readonly(ptr as *mut _) == 0 },
312 Prot::ReadWrite => unsafe { sodium_mprotect_readwrite(ptr as *mut _) == 0 },
313 } {
314 panic!("Error setting memory protection to {:?}", prot);
315 }
316}
317
318pub(crate) unsafe fn free<T>(ptr: *mut T) {
319 sodium_free(ptr as *mut _)
320}
321
322pub(crate) unsafe fn lock_memory<T>(ptr: *mut T, len: usize) {
323 sodium_mlock(ptr as *mut _, len);
324}
325
#[cfg(test)]
mod test {
    extern crate alloc;

    use alloc::vec;

    use super::*;
    use libsodium_sys::randombytes_buf;

    // Zeroizing must wipe the underlying memory.
    #[test]
    fn boxed_zeroize() {
        let mut boxed = Boxed::<u8>::random(4);
        let raw = boxed.ptr.as_ptr() as *const u8;

        // Copy the bytes out through the raw pointer on every read. Holding a `&[u8]` across
        // `zeroize` would alias memory that is overwritten through `&mut boxed`.
        let read = |p: *const u8| -> [u8; 4] {
            let mut out = [0u8; 4];
            out.copy_from_slice(unsafe { core::slice::from_raw_parts(p, 4) });
            out
        };

        boxed.unlock();
        assert_ne!(read(raw), [0u8; 4]);
        boxed.lock();

        boxed.zeroize();

        boxed.unlock();
        assert_eq!(read(raw), [0u8; 4]);
        boxed.lock();
    }

    // A box created with a no-op initializer holds the same fill byte as a fresh allocation.
    #[test]
    fn test_init_with_garbage() {
        let boxed = Boxed::<u8>::new(4, |_| {});
        let unboxed = boxed.unlock().as_slice();

        let garbage = unsafe {
            let garb_ptr = sodium_allocarray(1, mem::size_of::<u8>()) as *mut u8;
            let garb_byte = *garb_ptr;

            free(garb_ptr);

            vec![garb_byte; unboxed.len()]
        };

        assert_ne!(garbage, vec![0; garbage.len()]);
        assert_eq!(unboxed, &garbage[..]);

        boxed.lock();
    }

    // The initializer's writes are visible after construction.
    #[test]
    fn test_custom_init() {
        let boxed = Boxed::<u8>::new(1, |secret| {
            secret.as_mut_slice().copy_from_slice(b"\x04");
        });

        assert_eq!(boxed.unlock().as_slice(), [0x04]);
        boxed.lock();
    }

    #[test]
    fn test_init_with_zero() {
        let boxed = Boxed::<u8>::zero(6);

        assert_eq!(boxed.unlock().as_slice(), [0, 0, 0, 0, 0, 0]);

        boxed.lock();
    }

    // Converting from a slice moves the data in and clears the source.
    #[test]
    fn test_init_with_values() {
        let mut value = [8u64];
        let boxed = Boxed::from(&mut value[..]);

        assert_eq!(value, [0]);
        assert_eq!(boxed.unlock().as_slice(), [8]);

        boxed.lock();
    }

    #[allow(clippy::redundant_clone)]
    #[test]
    fn test_eq() {
        let boxed_a = Boxed::<u8>::random(1);
        let boxed_b = boxed_a.clone();

        assert_eq!(boxed_a, boxed_b);
        assert_eq!(boxed_b, boxed_a);

        let boxed_a = Boxed::<u8>::random(16);
        let boxed_b = Boxed::<u8>::random(16);

        assert_ne!(boxed_a, boxed_b);
        assert_ne!(boxed_b, boxed_a);

        // Different lengths never compare equal.
        let boxed_b = Boxed::<u8>::random(12);

        assert_ne!(boxed_a, boxed_b);
        assert_ne!(boxed_b, boxed_a);
    }

    // The unlock counter follows unlock/lock calls.
    #[test]
    fn test_refs() {
        let mut boxed = Boxed::<u8>::zero(8);

        assert_eq!(0, boxed.refs.get());

        let _ = boxed.unlock();
        let _ = boxed.unlock();

        assert_eq!(2, boxed.refs.get());

        boxed.lock();
        boxed.lock();

        assert_eq!(0, boxed.refs.get());

        let _ = boxed.unlock_mut();

        assert_eq!(1, boxed.refs.get());

        boxed.lock();

        assert_eq!(0, boxed.refs.get());
    }

    // Exactly `u8::MAX` simultaneous unlocks are still allowed.
    #[test]
    fn test_ref_overflow() {
        let boxed = Boxed::<u8>::zero(8);

        for _ in 0..u8::max_value() {
            let _ = boxed.unlock();
        }

        for _ in 0..u8::max_value() {
            boxed.lock()
        }
    }

    // Any number of balanced unlock/lock pairs within the counter range works.
    #[test]
    fn test_random_borrow_amounts() {
        let boxed = Boxed::<u8>::zero(1);
        let mut counter = 0u8;

        unsafe {
            randombytes_buf(
                counter.as_mut_bytes().as_mut_ptr() as *mut _,
                counter.as_mut_bytes().len(),
            );
        }

        for _ in 0..counter {
            let _ = boxed.unlock();
        }

        for _ in 0..counter {
            boxed.lock()
        }
    }

    // A box unlocked on one thread can be sent to and read from another.
    #[test]
    fn test_threading() {
        extern crate std;

        use std::{sync::mpsc, thread};

        let (tx, rx) = mpsc::channel();

        let ch = thread::spawn(move || {
            let boxed = Boxed::<u64>::random(1);
            let val = boxed.unlock().as_slice().to_vec();

            tx.send((boxed, val)).expect("failed to send via channel");
        });

        let (boxed, val) = rx.recv().expect("failed to read from channel");

        assert_eq!(Prot::ReadOnly, boxed.prot.get());
        assert_eq!(val, boxed.as_slice());

        ch.join().expect("child thread terminated.");
        boxed.lock();
    }

    // One unlock more than the counter can hold must panic.
    #[test]
    #[should_panic(expected = "Retained too many times")]
    fn test_overflow_refs() {
        let boxed = Boxed::<[u8; 4]>::zero(4);

        for _ in 0..=u8::max_value() {
            let _ = boxed.unlock();
        }

        for _ in 0..boxed.refs.get() {
            boxed.lock()
        }
    }

    // A non-zero counter combined with locked memory is reported as out-of-order use.
    #[test]
    #[should_panic(expected = "Out-of-order retain/release detected")]
    fn test_out_of_order() {
        let boxed = Boxed::<u8>::zero(3);

        boxed.refs.set(boxed.refs.get().wrapping_sub(1));
        boxed.prot.set(Prot::NoAccess);

        boxed.retain(Prot::ReadOnly);
    }

    #[test]
    #[should_panic(expected = "Attempted to dereference a zero-length pointer")]
    fn test_zero_length() {
        let boxed = Boxed::<u8>::zero(0);

        let _ = boxed.as_ref();
    }

    #[test]
    #[should_panic(expected = "Cannot unlock mutably more than once")]
    fn test_multiple_writers() {
        let mut boxed = Boxed::<u64>::zero(1);

        let _ = boxed.unlock_mut();
        let _ = boxed.unlock_mut();
    }

    // Locking a box that has no outstanding unlock must panic.
    #[test]
    #[should_panic(expected = "Releases exceeded retains")]
    fn test_release_vs_retain() {
        Boxed::<u64>::zero(2).lock();
    }
}