use std::{
alloc::{self, Layout},
ptr::NonNull,
slice,
};
/// Total size in bytes of the allocation behind a `PamBuffer`, taken from
/// `super::sys::PAM_MAX_RESP_SIZE`.
const SIZE: usize = super::sys::PAM_MAX_RESP_SIZE as usize;
/// Owning handle to a heap block of `SIZE` bytes.
///
/// The block is obtained zero-filled from `calloc` (see the `Default` impl), is exposed
/// through `Deref`/`DerefMut` as a slice of its first `SIZE - 1` bytes, and is overwritten
/// and then released with `free` when the handle is dropped — unless `leak` was called,
/// in which case the caller takes over the raw allocation.
// NOTE(review): presumably this carries PAM conversation responses such as passwords —
// confirm against the callers of this type.
pub struct PamBuffer(NonNull<[u8; SIZE]>);
/// Layout describing the buffer allocation (`SIZE` bytes, alignment 1).
///
/// It is only handed to `alloc::handle_alloc_error` to report a failed allocation.
const LAYOUT: Layout = if let Ok(layout) = Layout::from_size_align(SIZE, 1) {
    layout
} else {
    // Evaluated at compile time: an invalid size/alignment pair would fail the build
    // rather than surface at run time.
    unreachable!()
};
impl PamBuffer {
    /// Consumes the buffer and hands back the raw pointer to its allocation.
    ///
    /// The destructor is suppressed, so the memory is neither overwritten nor freed
    /// here; from this point on the caller is responsible for the allocation.
    pub fn leak(self) -> NonNull<u8> {
        // `ManuallyDrop` keeps `Drop::drop` from running when `this` goes out of scope.
        let this = std::mem::ManuallyDrop::new(self);
        this.0.cast()
    }

    /// Builds a buffer whose leading bytes are a copy of `src`, then wipes `src`.
    ///
    /// Only compiled for tests.
    ///
    /// # Panics
    ///
    /// Panics if `src` is longer than the `SIZE - 1` bytes the buffer exposes.
    #[cfg(test)]
    pub fn new(mut src: impl AsMut<[u8]>) -> Self {
        let mut filled = Self::default();
        let source = src.as_mut();
        let len = source.len();
        filled[..len].copy_from_slice(source);
        // The caller's copy is overwritten once its contents have been moved over.
        wipe_memory(source);
        filled
    }
}
impl Default for PamBuffer {
fn default() -> Self {
let res = unsafe { libc::calloc(1, SIZE) };
if let Some(nn) = NonNull::new(res) {
PamBuffer(nn.cast())
} else {
alloc::handle_alloc_error(LAYOUT)
}
}
}
impl std::ops::Deref for PamBuffer {
type Target = [u8];
fn deref(&self) -> &[u8] {
unsafe { slice::from_raw_parts(self.0.as_ptr().cast(), SIZE - 1) }
}
}
impl std::ops::DerefMut for PamBuffer {
fn deref_mut(&mut self) -> &mut [u8] {
unsafe { slice::from_raw_parts_mut(self.0.as_ptr().cast(), SIZE - 1) }
}
}
impl Drop for PamBuffer {
fn drop(&mut self) {
wipe_memory(unsafe { self.0.as_mut() });
unsafe { libc::free(self.0.as_ptr().cast()) }
}
}
/// Overwrites every byte of `memory` with the filler value `0x55`.
///
/// Each store is a volatile write so the compiler cannot discard it as a dead store;
/// a sequentially consistent fence and compiler fence follow the writes.
fn wipe_memory(memory: &mut [u8]) {
    use std::sync::atomic::{compiler_fence, fence, Ordering};

    const FILLER: u8 = 0x55;

    memory.iter_mut().for_each(|byte| {
        // SAFETY: `byte` comes from a live `&mut u8`, so it is valid and aligned for writes.
        unsafe { std::ptr::write_volatile(byte, FILLER) }
    });

    fence(Ordering::SeqCst);
    compiler_fence(Ordering::SeqCst);
}
#[allow(clippy::undocumented_unsafe_blocks)]
#[cfg(test)]
mod test {
    use super::PamBuffer;

    /// A leaked buffer must be readable as a C string and releasable with `free`.
    #[test]
    fn miri_test_leaky_cstring() {
        for &text in &["", "hello"] {
            let recovered = unsafe {
                let buf = PamBuffer::new(text.to_string().as_bytes_mut());
                assert_eq!(&buf[..text.len()], text.as_bytes());

                // After `leak` the allocation is ours: read it back, then free it by hand.
                let raw = buf.leak().as_ptr();
                let decoded = crate::cutils::string_from_ptr(raw.cast());
                libc::free(raw.cast());
                decoded
            };
            assert_eq!(recovered, text);
        }
    }

    /// `PamBuffer::new` must copy the source bytes and overwrite the source afterwards.
    #[test]
    fn miri_test_wipe() {
        let mut source: [u8; 3] = [1, 2, 3];
        let buffer = PamBuffer::new(&mut source);

        // The caller's copy has been overwritten with the filler byte.
        assert_eq!(source, [0x55; 3]);

        // The buffer holds the original bytes, followed only by zeroes.
        let (payload, padding) = buffer.split_at(3);
        assert_eq!(payload, [1, 2, 3]);
        assert!(padding.iter().all(|&byte| byte == 0));

        drop(buffer);
    }
}