use alloc::string::{String, ToString};
use core::{hash::Hasher, marker::PhantomData, mem::size_of, ptr, slice};
use std::{
env::temp_dir,
fs::{self, File},
io::{Read, Write},
path::PathBuf,
ptr::read_volatile,
};
use ahash::AHasher;
use serde::{de::DeserializeOwned, Serialize};
use crate::{
bolts::{
shmem::{ShMem, ShMemProvider},
AsSlice,
},
Error,
};
#[repr(C)]
struct StateShMemContent {
is_disk: bool,
buf_len: usize,
buf: [u8; 0],
}
impl StateShMemContent {
pub fn tmpfile(&self, shmem_size: usize) -> Result<Option<PathBuf>, Error> {
Ok(if self.is_disk {
let bytes = unsafe {
slice::from_raw_parts(self.buf.as_ptr(), self.buf_len_checked(shmem_size)?)
};
let filename = postcard::from_bytes::<String>(bytes)?;
Some(temp_dir().join(filename))
} else {
None
})
}
pub fn buf_len_checked(&self, shmem_size: usize) -> Result<usize, Error> {
let buf_len = unsafe { read_volatile(&self.buf_len) };
if size_of::<StateShMemContent>() + buf_len > shmem_size {
Err(Error::illegal_state(format!("Stored buf_len is larger than the shared map! Shared data corrupted? Expected {shmem_size} bytes max, but got {} (buf_len {buf_len})", size_of::<StateShMemContent>() + buf_len)))
} else {
Ok(buf_len)
}
}
}
#[derive(Debug, Clone)]
pub struct StateRestorer<SP>
where
SP: ShMemProvider,
{
shmem: SP::ShMem,
phantom: PhantomData<*const SP>,
}
impl<SP> StateRestorer<SP>
where
SP: ShMemProvider,
{
pub fn mapsize(&self) -> usize {
self.shmem.len()
}
pub fn write_to_env(&self, env_name: &str) -> Result<(), Error> {
self.shmem.write_to_env(env_name)
}
pub fn from_env(shmem_provider: &mut SP, env_name: &str) -> Result<Self, Error> {
Ok(Self {
shmem: shmem_provider.existing_from_env(env_name)?,
phantom: PhantomData,
})
}
pub fn new(shmem: SP::ShMem) -> Self {
let mut ret = Self {
shmem,
phantom: PhantomData,
};
ret.reset();
ret
}
pub fn save<S>(&mut self, state: &S) -> Result<(), Error>
where
S: Serialize,
{
if self.has_content() {
return Err(Error::illegal_state(
"Trying to save state to a non-empty state map".to_string(),
));
}
let serialized = postcard::to_allocvec(state)?;
if size_of::<StateShMemContent>() + serialized.len() > self.shmem.len() {
let mut hasher = AHasher::new_with_keys(0, 0);
hasher.write(&serialized[serialized.len().saturating_sub(4096)..]);
let filename = format!("{:016x}.libafl_state", hasher.finish());
let tmpfile = temp_dir().join(&filename);
File::create(tmpfile)?.write_all(&serialized)?;
let filename_buf = postcard::to_allocvec(&filename)?;
let len = filename_buf.len();
if len > self.shmem.len() {
return Err(Error::illegal_state(format!(
"The state restorer map is too small to fit anything, even the filename!
It needs to be at least {} bytes.
The tmpfile was written to {:?}.",
len,
temp_dir().join(&filename)
)));
}
let shmem_content = self.content_mut();
unsafe {
ptr::copy_nonoverlapping(
filename_buf.as_ptr() as *const u8,
shmem_content.buf.as_mut_ptr(),
len,
);
}
shmem_content.buf_len = len;
shmem_content.is_disk = true;
} else {
let len = serialized.len();
let shmem_content = self.content_mut();
unsafe {
ptr::copy_nonoverlapping(
serialized.as_ptr() as *const u8,
shmem_content.buf.as_mut_ptr(),
len,
);
}
shmem_content.buf_len = len;
shmem_content.is_disk = false;
};
Ok(())
}
pub fn reset(&mut self) {
let mapsize = self.mapsize();
let content_mut = self.content_mut();
if let Ok(Some(tmpfile)) = content_mut.tmpfile(mapsize) {
drop(fs::remove_file(tmpfile));
}
content_mut.is_disk = false;
content_mut.buf_len = 0;
}
fn content_mut(&mut self) -> &mut StateShMemContent {
let ptr = self.shmem.as_slice().as_ptr();
#[allow(clippy::cast_ptr_alignment)] unsafe {
&mut *(ptr as *mut StateShMemContent)
}
}
fn content(&self) -> &StateShMemContent {
#[allow(clippy::cast_ptr_alignment)] let ptr = self.shmem.as_slice().as_ptr() as *const StateShMemContent;
unsafe { &*(ptr) }
}
pub fn has_content(&self) -> bool {
self.content().buf_len > 0
}
pub fn restore<S>(&self) -> Result<Option<S>, Error>
where
S: DeserializeOwned,
{
if !self.has_content() {
return Ok(None);
}
let state_shmem_content = self.content();
let bytes = unsafe {
slice::from_raw_parts(
state_shmem_content.buf.as_ptr(),
state_shmem_content.buf_len_checked(self.mapsize())?,
)
};
let mut state = bytes;
let mut file_content;
if state_shmem_content.buf_len == 0 {
return Ok(None);
} else if state_shmem_content.is_disk {
let filename: String = postcard::from_bytes(bytes)?;
let tmpfile = temp_dir().join(&filename);
file_content = vec![];
File::open(tmpfile)?.read_to_end(&mut file_content)?;
if file_content.is_empty() {
return Err(Error::illegal_state(format!(
"Colud not restore state from file {}",
&filename
)));
}
state = &file_content;
}
let deserialized = postcard::from_bytes(state)?;
Ok(Some(deserialized))
}
}
#[cfg(test)]
mod tests {
use alloc::{
string::{String, ToString},
vec::Vec,
};
use serial_test::serial;
use crate::bolts::{
shmem::{ShMemProvider, StdShMemProvider},
staterestore::StateRestorer,
};
#[test]
#[serial]
fn test_state_restore() {
const TESTMAP_SIZE: usize = 1024;
let mut shmem_provider = StdShMemProvider::new().unwrap();
let shmem = shmem_provider.new_shmem(TESTMAP_SIZE).unwrap();
let mut state_restorer = StateRestorer::<StdShMemProvider>::new(shmem);
let state = "hello world".to_string();
state_restorer.save(&state).unwrap();
assert!(state_restorer.has_content());
let restored = state_restorer.restore::<String>().unwrap().unwrap();
println!("Restored {restored}");
assert_eq!(restored, "hello world");
assert!(!state_restorer.content().is_disk);
state_restorer.reset();
assert!(!state_restorer.has_content());
assert!(!state_restorer.content().is_disk);
assert!(state_restorer.restore::<String>().unwrap().is_none());
let too_large = vec![4u8; TESTMAP_SIZE + 1];
state_restorer.save(&too_large).unwrap();
assert!(state_restorer.has_content());
let large_restored = state_restorer.restore::<Vec<u8>>().unwrap().unwrap();
assert_eq!(large_restored, too_large);
assert_eq!(large_restored.len(), too_large.len());
assert_eq!(large_restored[TESTMAP_SIZE], 4u8);
assert!(state_restorer.content().is_disk);
assert_ne!(state_restorer.content().buf_len, 0);
let state_shmem_content = state_restorer.content();
let tmpfile = state_shmem_content
.tmpfile(state_restorer.mapsize())
.unwrap()
.unwrap();
assert!(tmpfile.exists());
state_restorer.reset();
assert!(!state_restorer.has_content());
assert!(!tmpfile.exists());
}
}