#[cfg(feature = "std")]
use alloc::{
string::{String, ToString},
vec::Vec,
};
use core::{mem::ManuallyDrop, ptr::addr_of};
#[cfg(target_vendor = "apple")]
use std::fs;
use std::{
cell::RefCell,
env,
io::{Read, Write},
marker::PhantomData,
rc::{Rc, Weak},
sync::{Arc, Condvar, Mutex},
thread::JoinHandle,
};
#[cfg(all(feature = "std", unix))]
use std::{
os::unix::{
io::{AsRawFd, RawFd},
net::{UnixListener, UnixStream},
},
thread,
};
use hashbrown::HashMap;
#[cfg(all(feature = "std", unix))]
use nix::poll::{poll, PollFd, PollFlags};
use serde::{Deserialize, Serialize};
#[cfg(all(unix, feature = "std"))]
use uds::{UnixListenerExt, UnixSocketAddr, UnixStreamExt};
use crate::{
bolts::{
shmem::{ShMem, ShMemDescription, ShMemId, ShMemProvider},
AsMutSlice, AsSlice,
},
Error,
};
/// Address of the unix socket the ShMem service listens on (non-Apple unix).
/// NOTE(review): the leading `@` presumably makes `uds` use an abstract-namespace socket — confirm.
#[cfg(all(unix, not(target_vendor = "apple")))]
const UNIX_SERVER_NAME: &str = "@libafl_unix_shmem_server";
/// Address of the unix socket the ShMem service listens on (Apple targets).
/// This is a filesystem path; `ShMemServiceThread`'s `Drop` removes the file again.
#[cfg(target_vendor = "apple")]
const UNIX_SERVER_NAME: &str = "./libafl_unix_shmem_server";
/// Environment variable set once this process has attempted to start the ShMem service,
/// so later attempts are skipped.
const AFL_SHMEM_SERVICE_STARTED: &str = "AFL_SHMEM_SERVICE_STARTED";
/// A [`ShMemProvider`] that gets its maps from a central ShMem service over a unix socket.
///
/// The service allocates (or looks up) each map and passes a file descriptor back.
/// The wrapped `inner` provider then opens that descriptor in this process.
#[derive(Debug)]
pub struct ServedShMemProvider<SP>
where
SP: ShMemProvider,
{
/// Connection to the ShMem service.
stream: UnixStream,
/// Local provider used to map the file descriptors received from the service.
inner: SP,
/// Client id assigned by the service (`-1` until the `Hello` handshake in `new` completes).
id: i32,
/// Handle to the service thread started by this process, or the reason it was not started.
service: ShMemService<SP>,
}
/// A shared memory map handed out by [`ServedShMemProvider`].
///
/// `inner` is wrapped in [`ManuallyDrop`] because it is not dropped automatically.
/// `ServedShMemProvider::release_shmem` drops it explicitly, and only once the service
/// reports a reference count of 1.
#[derive(Clone, Debug)]
pub struct ServedShMem<SH>
where
SH: ShMem,
{
/// The locally mapped memory.
inner: ManuallyDrop<SH>,
/// Id of this map on the service side, as returned by `send_receive`.
server_fd: i32,
}
impl<SH> ShMem for ServedShMem<SH>
where
    SH: ShMem,
{
    /// Returns a combined id of the form `"<server_fd>:<local id>"`.
    ///
    /// `ServedShMemProvider::shmem_from_id_and_size` splits the server part back out.
    fn id(&self) -> ShMemId {
        let combined = format!("{}:{}", self.server_fd, self.inner.id());
        ShMemId::from_string(&combined)
    }

    /// Length reported by the wrapped local map.
    fn len(&self) -> usize {
        let mapped: &SH = &self.inner;
        mapped.len()
    }
}
impl<SH> AsSlice for ServedShMem<SH>
where
    SH: ShMem,
{
    type Entry = u8;

    /// Borrows the mapped bytes by delegating to the wrapped local map.
    fn as_slice(&self) -> &[u8] {
        let mapped: &SH = &self.inner;
        mapped.as_slice()
    }
}
impl<SH> AsMutSlice for ServedShMem<SH>
where
    SH: ShMem,
{
    type Entry = u8;

    /// Mutably borrows the mapped bytes by delegating to the wrapped local map.
    fn as_mut_slice(&mut self) -> &mut [u8] {
        let mapped: &mut SH = &mut self.inner;
        mapped.as_mut_slice()
    }
}
impl<SP> ServedShMemProvider<SP>
where
SP: ShMemProvider,
{
#[allow(clippy::similar_names)] fn send_receive(&mut self, request: ServedShMemRequest) -> Result<(i32, i32), Error> {
let body = postcard::to_allocvec(&request)?;
let header = (body.len() as u32).to_be_bytes();
let mut message = header.to_vec();
message.extend(body);
self.stream
.write_all(&message)
.expect("Failed to send message");
let mut shm_slice = [0_u8; 20];
let mut fd_buf = [-1; 1];
self.stream
.recv_fds(&mut shm_slice, &mut fd_buf)
.expect("Did not receive a response");
let server_id = ShMemId::from_array(&shm_slice);
let server_fd: i32 = server_id.into();
Ok((server_fd, fd_buf[0]))
}
}
impl<SP> Default for ServedShMemProvider<SP>
where
    SP: ShMemProvider,
{
    /// Connects to the ShMem service, starting it first when this process has not tried yet.
    ///
    /// # Panics
    /// Panics if [`ShMemProvider::new`] fails.
    fn default() -> Self {
        <Self as ShMemProvider>::new().unwrap()
    }
}
impl<SP> Clone for ServedShMemProvider<SP>
where
    SP: ShMemProvider,
{
    /// Opens a fresh connection to the ShMem service that shares this provider's service handle.
    ///
    /// # Panics
    /// Panics if creating the new connection via [`ShMemProvider::new`] fails.
    fn clone(&self) -> Self {
        let shared_service = self.service.clone();
        let mut fresh = <Self as ShMemProvider>::new().unwrap();
        fresh.service = shared_service;
        fresh
    }
}
impl<SP> ShMemProvider for ServedShMemProvider<SP>
where
    SP: ShMemProvider,
{
    type ShMem = ServedShMem<SP::ShMem>;

    /// Starts the ShMem service, connects to it and performs the `Hello` handshake to
    /// learn this client's id.
    ///
    /// Starting is skipped if this process already tried.
    fn new() -> Result<Self, Error> {
        let service = ShMemService::<SP>::start();
        let mut res = Self {
            stream: UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?,
            inner: SP::new()?,
            id: -1,
            service,
        };
        let (id, _) = res.send_receive(ServedShMemRequest::Hello())?;
        res.id = id;
        Ok(res)
    }

    /// Asks the service for a new map of `map_size` bytes.
    ///
    /// The file descriptor it sends back is mapped through the local `inner` provider.
    fn new_shmem(&mut self, map_size: usize) -> Result<Self::ShMem, Error> {
        let (server_fd, client_fd) = self.send_receive(ServedShMemRequest::NewMap(map_size))?;
        let local = self
            .inner
            .shmem_from_id_and_size(ShMemId::from_string(&format!("{client_fd}")), map_size)?;
        Ok(ServedShMem {
            inner: ManuallyDrop::new(local),
            server_fd,
        })
    }

    /// Re-opens a map from an id produced by `ServedShMem::id` (`"<server_fd>:<local id>"`).
    ///
    /// Only the server part (before the first `:`) is sent to the service. The file
    /// descriptor it replies with is then mapped locally.
    fn shmem_from_id_and_size(&mut self, id: ShMemId, size: usize) -> Result<Self::ShMem, Error> {
        // `split` always yields at least one item, so the default is never used.
        let server_id_str = id.as_str().split(':').next().unwrap_or_default();
        let (server_fd, client_fd) = self.send_receive(ServedShMemRequest::ExistingMap(
            ShMemDescription::from_string_and_size(server_id_str, size),
        ))?;
        let local = self
            .inner
            .shmem_from_id_and_size(ShMemId::from_string(&format!("{client_fd}")), size)?;
        Ok(ServedShMem {
            inner: ManuallyDrop::new(local),
            server_fd,
        })
    }

    /// Tells the service this client is about to fork, so a child can inherit its maps.
    fn pre_fork(&mut self) -> Result<(), Error> {
        self.send_receive(ServedShMemRequest::PreFork())?;
        Ok(())
    }

    /// In the child, detaches from the parent's service thread and opens a new connection.
    ///
    /// The new connection inherits the parent's maps. Does nothing in the parent.
    fn post_fork(&mut self, is_child: bool) -> Result<(), Error> {
        if !is_child {
            return Ok(());
        }
        // Clearing the handle stops the child's `ShMemServiceThread::drop` from trying to
        // stop and join the service thread.
        if let ShMemService::Started { bg_thread, .. } = &self.service {
            bg_thread.lock().unwrap().join_handle = None;
        }
        self.stream = UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?;
        // `self.id` still holds the parent's id. The service moves the parent's
        // `PreFork` snapshot onto this new connection.
        let (id, _) = self.send_receive(ServedShMemRequest::PostForkChildHello(self.id))?;
        self.id = id;
        Ok(())
    }

    /// Drops one reference to `map` on the service side.
    ///
    /// The local mapping is released only when the service reports a reference count of 1.
    ///
    /// # Panics
    /// Panics if the service cannot be reached.
    fn release_shmem(&mut self, map: &mut Self::ShMem) {
        let (refcount, _) = self
            .send_receive(ServedShMemRequest::Deregister(map.server_fd))
            .expect("Could not communicate with ServedShMem server!");
        if refcount == 1 {
            // SAFETY: `inner` is only dropped here, and only after the service reports that no
            // other reference remains. NOTE(review): nothing stops a second release of the
            // same `map`; callers presumably release each map at most once.
            unsafe {
                ManuallyDrop::drop(&mut map.inner);
            }
        }
    }
}
/// A request sent from a [`ServedShMemProvider`] to the ShMem service.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub enum ServedShMemRequest {
/// Allocate a new map of the given size.
NewMap(usize),
/// Look up an already allocated map by its description.
ExistingMap(ShMemDescription),
/// Drop one of this client's references to the map with the given server-side id.
/// The service replies with the reference count it observed.
Deregister(i32),
/// Initial handshake; the service replies with this client's id.
Hello(),
/// Sent before forking so the service snapshots this client's maps.
PreFork(),
/// Sent by a forked child with its parent's client id, to inherit the parent's snapshot.
PostForkChildHello(i32),
/// Ask the service thread to shut down.
Exit,
}
/// Per-connection state kept by the ShMem service.
#[derive(Debug)]
struct SharedShMemClient<SH>
where
SH: ShMem,
{
/// Connection to the client.
stream: UnixStream,
/// Maps handed to this client, keyed by server-side id.
/// Each `Mapping` response pushes one entry and each `Deregister` pops one.
maps: HashMap<i32, Vec<Rc<RefCell<SH>>>>,
}
impl<SH> SharedShMemClient<SH>
where
    SH: ShMem,
{
    /// Wraps a freshly accepted connection. The client starts out holding no maps.
    fn new(stream: UnixStream) -> Self {
        let maps = HashMap::new();
        Self { stream, maps }
    }
}
/// The service's answer to a [`ServedShMemRequest`], before it is written to the client.
#[derive(Debug)]
enum ServedShMemResponse<SP>
where
SP: ShMemProvider,
{
/// A map. Its id is sent as text, together with a file descriptor.
Mapping(Rc<RefCell<SP::ShMem>>),
/// A plain id, e.g. the client's id.
Id(i32),
/// The reference count observed while handling `Deregister`.
RefCount(u32),
}
/// Startup state of the service thread.
/// It is published to [`ShMemService::start`] through a `Mutex` + `Condvar` pair.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ShMemServiceStatus {
/// The thread has not yet reported whether it is listening.
Starting,
/// The thread bound the socket and is serving clients.
Started,
/// The thread could not create its provider or bind the socket, and returns an error.
Failed,
}
/// Outcome of trying to start the ShMem service in this process.
#[derive(Debug, Clone)]
pub enum ShMemService<SP>
where
SP: ShMemProvider,
{
/// The service thread runs in this process.
Started {
/// Shared handle to the background thread. Dropping the last clone runs
/// `ShMemServiceThread`'s `Drop`, which stops the thread.
bg_thread: Arc<Mutex<ShMemServiceThread>>,
phantom: PhantomData<SP>,
},
/// The service was not started here.
/// Either a start was already attempted, or the socket could not be bound.
Failed {
/// Why starting failed.
err_msg: String,
phantom: PhantomData<SP>,
},
}
/// Owns the join handle of the ShMem service's background thread.
#[derive(Debug)]
pub struct ShMemServiceThread {
/// `None` once the thread was stopped, or after `post_fork` cleared it in a forked child.
join_handle: Option<JoinHandle<Result<(), Error>>>,
}
impl Drop for ShMemServiceThread {
    /// Stops the background service thread if this handle still owns it.
    ///
    /// Sends [`ServedShMemRequest::Exit`] over a fresh connection and joins the thread.
    /// Then it clears `AFL_SHMEM_SERVICE_STARTED` and, on Apple targets, tries to remove
    /// the socket file. If the service cannot be reached, the thread is left detached and
    /// nothing else happens.
    ///
    /// # Panics
    /// Panics if the exit message cannot be sent, or if the thread panicked or returned an
    /// error.
    fn drop(&mut self) {
        let Some(join_handle) = self.join_handle.take() else {
            return;
        };
        println!("Stopping ShMemService");
        let Ok(mut stream) =
            UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME).unwrap())
        else {
            return;
        };

        // Framing: big-endian u32 body length followed by the postcard-encoded body.
        let body = postcard::to_allocvec(&ServedShMemRequest::Exit).unwrap();
        let header = (body.len() as u32).to_be_bytes();
        let mut message = header.to_vec();
        message.extend(body);
        stream
            .write_all(&message)
            .expect("Failed to send bye-message to ShMemService");

        join_handle
            .join()
            .expect("Failed to join ShMemService thread!")
            .expect("Error in ShMemService background thread!");

        // Removing the socket file is cleanup only. Panicking inside `drop` (possibly while
        // already unwinding) would abort the process, so failures are just reported.
        #[cfg(target_vendor = "apple")]
        {
            if let Err(err) = fs::remove_file(UNIX_SERVER_NAME) {
                println!("Failed to remove {UNIX_SERVER_NAME}: {err:?}");
            }
        }
        env::remove_var(AFL_SHMEM_SERVICE_STARTED);
    }
}
impl<SP> ShMemService<SP>
where
    SP: ShMemProvider,
{
    /// Starts the ShMem service on a background thread and waits until it reports its status.
    ///
    /// Returns [`ShMemService::Failed`] in either of these cases:
    /// - `AFL_SHMEM_SERVICE_STARTED` is already set;
    /// - the thread could not create its provider or bind the server socket, e.g. because
    ///   another process already serves it.
    ///
    /// The environment variable is set after the first attempt regardless of the outcome.
    /// Later calls therefore do not retry.
    ///
    /// # Panics
    /// Panics if the status mutex is poisoned or the worker thread panics.
    #[must_use]
    pub fn start() -> Self {
        if env::var(AFL_SHMEM_SERVICE_STARTED).is_ok() {
            return Self::Failed {
                err_msg: "ShMemService already started".to_string(),
                phantom: PhantomData,
            };
        }

        // The worker thread moves the status from `Starting` to `Started` or `Failed` once.
        #[allow(clippy::mutex_atomic)] // a Mutex (not an atomic) is required by the Condvar
        let syncpair = Arc::new((Mutex::new(ShMemServiceStatus::Starting), Condvar::new()));
        let childsyncpair = Arc::clone(&syncpair);
        let join_handle = thread::spawn(move || {
            // Reports `Failed` unless a status was already published. Without this, an early
            // error inside `listen` would leave `start` waiting forever.
            let report_failure = || {
                let (lock, cvar) = &*childsyncpair;
                let mut status = lock.lock().unwrap();
                if *status == ShMemServiceStatus::Starting {
                    *status = ShMemServiceStatus::Failed;
                    cvar.notify_one();
                }
            };

            let mut worker = match ServedShMemServiceWorker::<SP>::new() {
                Ok(worker) => worker,
                Err(e) => {
                    report_failure();
                    println!("Error creating ShMemService: {e:?}");
                    return Err(e);
                }
            };
            if let Err(e) = worker.listen(UNIX_SERVER_NAME, &childsyncpair) {
                report_failure();
                println!("Error spawning ShMemService: {e:?}");
                return Err(e);
            }
            Ok(())
        });

        let (lock, cvar) = &*syncpair;
        let guard = cvar
            .wait_while(lock.lock().unwrap(), |status| {
                *status == ShMemServiceStatus::Starting
            })
            .unwrap();
        let status = *guard;
        // Release the lock before (possibly) joining the worker thread below.
        drop(guard);

        // Set even on failure: a failed start will not succeed on retry, so later calls
        // skip the attempt.
        env::set_var(AFL_SHMEM_SERVICE_STARTED, "true");
        match status {
            ShMemServiceStatus::Starting => {
                unreachable!("wait_while only returns once the status left `Starting`")
            }
            ShMemServiceStatus::Started => {
                println!("Started ShMem Service");
                Self::Started {
                    bg_thread: Arc::new(Mutex::new(ShMemServiceThread {
                        join_handle: Some(join_handle),
                    })),
                    phantom: PhantomData,
                }
            }
            ShMemServiceStatus::Failed => {
                let err = join_handle
                    .join()
                    .expect("Failed to join ShMemService thread!")
                    .expect_err("Expected service start to have failed, but it didn't?");
                Self::Failed {
                    err_msg: format!("{err}"),
                    phantom: PhantomData,
                }
            }
        }
    }
}
/// State owned by the ShMem service's background thread.
#[allow(clippy::type_complexity)]
struct ServedShMemServiceWorker<SP>
where
SP: ShMemProvider,
{
/// Provider used to allocate the maps handed out to clients.
provider: SP,
/// Connected clients, keyed by the raw fd of their socket. That fd also serves as the client id.
clients: HashMap<RawFd, SharedShMemClient<SP::ShMem>>,
/// Map snapshots taken on `PreFork`, keyed by the forking client's id.
/// A child's `PostForkChildHello` consumes them.
forking_clients: HashMap<RawFd, HashMap<i32, Vec<Rc<RefCell<SP::ShMem>>>>>,
/// Weak references to every map allocated by `provider`, keyed by server-side id.
all_shmems: HashMap<i32, Weak<RefCell<SP::ShMem>>>,
}
impl<SP> ServedShMemServiceWorker<SP>
where
SP: ShMemProvider,
{
fn new() -> Result<Self, Error> {
Ok(Self {
provider: SP::new()?,
clients: HashMap::new(),
all_shmems: HashMap::new(),
forking_clients: HashMap::new(),
})
}
fn upgrade_shmem_with_id(&mut self, description_id: i32) -> Rc<RefCell<SP::ShMem>> {
self.all_shmems
.get_mut(&description_id)
.unwrap()
.clone()
.upgrade()
.unwrap()
}
fn handle_request(&mut self, client_id: RawFd) -> Result<ServedShMemResponse<SP>, Error> {
let request = self.read_request(client_id)?;
let response = match request {
ServedShMemRequest::Hello() => Ok(ServedShMemResponse::Id(client_id)),
ServedShMemRequest::PreFork() => {
self.forking_clients
.insert(client_id, self.clients[&client_id].maps.clone());
Ok(ServedShMemResponse::Id(client_id))
}
ServedShMemRequest::PostForkChildHello(other_id) => {
let client = self.clients.get_mut(&client_id).unwrap();
client.maps = self.forking_clients.remove(&other_id).unwrap();
Ok(ServedShMemResponse::Id(client_id))
}
ServedShMemRequest::NewMap(map_size) => {
let new_shmem = self.provider.new_shmem(map_size)?;
let description = new_shmem.description();
let new_rc = Rc::new(RefCell::new(new_shmem));
self.all_shmems
.insert(description.id.into(), Rc::downgrade(&new_rc));
Ok(ServedShMemResponse::Mapping(new_rc))
}
ServedShMemRequest::ExistingMap(description) => {
let client = self.clients.get_mut(&client_id).unwrap();
let description_id: i32 = description.id.into();
if client.maps.contains_key(&description_id) {
#[allow(clippy::option_if_let_else)]
Ok(ServedShMemResponse::Mapping(
if let Some(map) = client
.maps
.get_mut(&description_id)
.as_mut()
.unwrap()
.first()
.as_mut()
{
map.clone()
} else {
self.upgrade_shmem_with_id(description_id)
},
))
} else {
Ok(ServedShMemResponse::Mapping(
self.upgrade_shmem_with_id(description_id),
))
}
}
ServedShMemRequest::Deregister(map_id) => {
let client = self.clients.get_mut(&client_id).unwrap();
let maps = client.maps.entry(map_id).or_default();
if maps.is_empty() {
Ok(ServedShMemResponse::RefCount(0_u32))
} else {
Ok(ServedShMemResponse::RefCount(
Rc::strong_count(&maps.pop().unwrap()) as u32,
))
}
}
ServedShMemRequest::Exit => {
println!("ShMemService - Exiting");
return Err(Error::shutting_down());
}
};
response
}
fn read_request(&mut self, client_id: RawFd) -> Result<ServedShMemRequest, Error> {
let client = self.clients.get_mut(&client_id).unwrap();
let mut size_bytes = [0_u8; 4];
client.stream.read_exact(&mut size_bytes)?;
let size = u32::from_be_bytes(size_bytes);
let mut bytes = vec![];
bytes.resize(size as usize, 0_u8);
client
.stream
.read_exact(&mut bytes)
.expect("Failed to read message body");
let request: ServedShMemRequest = postcard::from_bytes(&bytes)?;
Ok(request)
}
fn handle_client(&mut self, client_id: RawFd) -> Result<(), Error> {
let response = self.handle_request(client_id)?;
match response {
ServedShMemResponse::Mapping(mapping) => {
let id = mapping.as_ref().borrow().id();
let server_fd: i32 = id.to_string().parse().unwrap();
let client = self.clients.get_mut(&client_id).unwrap();
client
.stream
.send_fds(id.to_string().as_bytes(), &[server_fd])?;
client.maps.entry(server_fd).or_default().push(mapping);
}
ServedShMemResponse::Id(id) => {
let client = self.clients.get_mut(&client_id).unwrap();
client.stream.send_fds(id.to_string().as_bytes(), &[])?;
}
ServedShMemResponse::RefCount(refcount) => {
let client = self.clients.get_mut(&client_id).unwrap();
client
.stream
.send_fds(refcount.to_string().as_bytes(), &[])?;
}
}
Ok(())
}
fn listen(
&mut self,
filename: &str,
syncpair: &Arc<(Mutex<ShMemServiceStatus>, Condvar)>,
) -> Result<(), Error> {
let listener = match UnixListener::bind_unix_addr(&UnixSocketAddr::new(filename)?) {
Ok(listener) => listener,
Err(err) => {
let (lock, cvar) = &**syncpair;
*lock.lock().unwrap() = ShMemServiceStatus::Failed;
cvar.notify_one();
return Err(Error::unknown(format!(
"The ShMem server appears to already be running. We are probably a client. Error: {err:?}")));
}
};
let mut poll_fds: Vec<PollFd> = vec![PollFd::new(
listener.as_raw_fd(),
PollFlags::POLLIN | PollFlags::POLLRDNORM | PollFlags::POLLRDBAND,
)];
let (lock, cvar) = &**syncpair;
*lock.lock().unwrap() = ShMemServiceStatus::Started;
cvar.notify_one();
loop {
match poll(&mut poll_fds, -1) {
Ok(num_fds) if num_fds > 0 => (),
Ok(_) => continue,
Err(e) => {
println!("Error polling for activity: {e:?}");
continue;
}
};
let copied_poll_fds: Vec<PollFd> = poll_fds.clone();
for poll_fd in copied_poll_fds {
let revents = poll_fd.revents().expect("revents should not be None");
let raw_polled_fd = unsafe { *((addr_of!(poll_fd)) as *const libc::pollfd) }.fd;
if revents.contains(PollFlags::POLLHUP) {
poll_fds.remove(poll_fds.iter().position(|item| *item == poll_fd).unwrap());
self.clients.remove(&raw_polled_fd);
} else if revents.contains(PollFlags::POLLIN) {
if self.clients.contains_key(&raw_polled_fd) {
match self.handle_client(raw_polled_fd) {
Ok(()) => (),
Err(e) => {
dbg!("Ignoring failed read from client", e, poll_fd);
continue;
}
};
} else {
let (stream, _addr) = match listener.accept_unix_addr() {
Ok(stream_val) => stream_val,
Err(e) => {
println!("Error accepting client: {e:?}");
continue;
}
};
println!("Recieved connection from {_addr:?}");
let pollfd = PollFd::new(
stream.as_raw_fd(),
PollFlags::POLLIN | PollFlags::POLLRDNORM | PollFlags::POLLRDBAND,
);
poll_fds.push(pollfd);
let client = SharedShMemClient::new(stream);
let client_id = client.stream.as_raw_fd();
self.clients.insert(client_id, client);
match self.handle_client(client_id) {
Ok(()) => (),
Err(Error::ShuttingDown) => {
println!("Shutting down");
return Ok(());
}
Err(e) => {
dbg!("Ignoring failed read from client", e);
}
};
}
} else {
}
}
}
}
}