use crate::mutator::Mutator;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
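/// How the driver runs: `Run` fuzzes until an exit is signalled, while
/// `Reproduce` replays a fixed range of iteration numbers and then stops.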
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum DriverMode {
Reproduce,
Run,
}
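/// Shared state for a multi-threaded fuzzing session: worker thread handles,
/// global iteration counters, the root seed, an optional global context of
/// type `T` shared by every worker, and per-thread heartbeat timestamps used
/// for stall detection.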
pub struct FuzzerDriver<T> {
thread_count: usize,
threads: RwLock<Vec<thread::JoinHandle<()>>>,
num_iterations: AtomicUsize,
num_failed_iterations: AtomicUsize,
exit: AtomicBool,
seed: u64,
global_context: Option<Arc<RwLock<T>>>,
mode: DriverMode,
start_iteration: u64,
end_iteration: u64,
thread_last_execution_time: Vec<AtomicUsize>,
thread_timeout: Duration,
}
impl<T: 'static + Send + Sync> Default for FuzzerDriver<T> {
fn default() -> Self {
FuzzerDriver::<T>::new(1)
}
}
impl<T: 'static + Send + Sync> FuzzerDriver<T> {
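/// Creates a driver for `num_threads` workers with a random root seed, a
/// 10-second stall timeout, and every per-thread heartbeat initialised to
/// the current time.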
pub fn new(num_threads: usize) -> Self {
let mut last_execution_times = Vec::with_capacity(num_threads);
let since_the_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
// Initialise every per-thread heartbeat to the current time so a worker is
// not reported as stalled before it has run a single iteration.
for _ in 0..num_threads {
last_execution_times.push(AtomicUsize::new(since_the_epoch.as_secs() as usize));
}
FuzzerDriver {
thread_count: num_threads,
threads: RwLock::new(Vec::with_capacity(num_threads)),
num_iterations: Default::default(),
num_failed_iterations: Default::default(),
exit: Default::default(),
seed: rand::random(),
global_context: Default::default(),
mode: DriverMode::Run,
start_iteration: 0,
end_iteration: 0,
thread_last_execution_time: last_execution_times,
thread_timeout: Duration::from_secs(10u64),
}
}
pub fn thread_count(&self) -> usize {
self.thread_count
}
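/// Switches the driver into `DriverMode::Reproduce`, replaying iterations
/// `start_iteration..end_iteration`. A replay is only deterministic when the
/// seed (and thread count) of the original run are restored first.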
pub fn set_to_reproduce_mode(&mut self, start_iteration: u64, end_iteration: u64) {
self.mode = DriverMode::Reproduce;
self.start_iteration = start_iteration;
self.end_iteration = end_iteration;
self.num_iterations
.store(start_iteration as usize, Ordering::SeqCst);
}
pub fn num_iterations(&self) -> usize {
self.num_iterations.load(Ordering::SeqCst)
}
pub fn set_iterations(&self, iterations: usize) {
self.num_iterations.store(iterations, Ordering::SeqCst);
}
pub fn num_failed_iterations(&self) -> usize {
self.num_failed_iterations.load(Ordering::SeqCst)
}
pub fn set_global_context(&mut self, context: Arc<RwLock<T>>) {
self.global_context = Some(context);
}
pub fn global_context(&self) -> Option<Arc<RwLock<T>>> {
self.global_context.clone()
}
pub fn set_seed(&mut self, seed: u64) {
self.seed = seed;
}
pub fn seed(&self) -> u64 {
self.seed
}
pub fn signal_exit(&self) {
self.exit.store(true, Ordering::SeqCst);
}
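/// Pops and joins every spawned worker thread. In `DriverMode::Run` the
/// workers only stop after `signal_exit` has been called, so signal the exit
/// first or this call will block indefinitely.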
pub fn join_threads(&self) {
let mut threads = self.threads.write().unwrap();
while let Some(handle) = threads.pop() {
let thread_name = handle.thread().name().map_or(
String::from("UNNAMED_THREAD"),
std::borrow::ToOwned::to_owned,
);
handle
.join()
.unwrap_or_else(|_| println!("thread {} failed to join", thread_name));
}
}
pub(crate) fn should_exit(&self) -> bool {
if self.mode == DriverMode::Reproduce {
// Multiple workers can step the counter past the target between checks,
// so compare with >= to guarantee the reproduce run terminates.
return self.num_iterations() >= self.end_iteration as usize;
}
self.exit.load(Ordering::SeqCst)
}
pub fn mode(&self) -> DriverMode {
self.mode
}
pub(crate) fn set_thread_last_execution_time(&self, thread_index: usize) {
let since_the_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
self.thread_last_execution_time[thread_index]
.store(since_the_epoch.as_secs() as usize, Ordering::SeqCst);
}
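/// Returns `true` if any worker has not updated its heartbeat within
/// `thread_timeout`. Intended to be polled periodically by the embedding
/// application; a stalled thread is only logged here, never killed.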
pub fn check_for_stalled_threads(&self) -> bool {
let mut threads_have_stalled = false;
let since_the_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
let threads = self.threads.read().unwrap();
for i in 0..self.thread_count {
let last_update = self.thread_last_execution_time[i].load(Ordering::SeqCst) as u64;
let last_update = Duration::from_secs(last_update);
// A heartbeat in the future can only come from clock skew; skip it.
if last_update > since_the_epoch {
continue;
}
if since_the_epoch - last_update > self.thread_timeout {
log::error!("{:?} has stalled!", threads[i].thread().id());
threads_have_stalled = true;
}
}
threads_have_stalled
}
pub fn set_thread_timeout(&mut self, duration: Duration) {
self.thread_timeout = duration
}
}
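/// Spawns `driver.thread_count()` worker threads. Each worker builds a fresh
/// per-thread context `C`, reseeds its `Mutator` every iteration from the
/// driver's root seed and the shared iteration counter, and invokes
/// `callback` until the driver signals an exit.
///
/// A minimal usage sketch, assuming a unit global context and a no-op
/// callback (both illustrative, not part of this crate):
///
/// ```ignore
/// use std::sync::Arc;
///
/// let driver = Arc::new(FuzzerDriver::<()>::new(4));
/// start_fuzzer(Arc::clone(&driver), |_mutator, _ctx: &mut (), _global| {
///     // Generate and run one test case here; return Err(()) to record a failure.
///     Ok(())
/// });
/// // From the controlling thread, once fuzzing should stop:
/// driver.signal_exit();
/// driver.join_threads();
/// ```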
pub fn start_fuzzer<F: 'static, C: 'static, T: 'static + Send + Sync>(
driver: Arc<FuzzerDriver<T>>,
callback: F,
) where
F: Fn(&mut Mutator<StdRng>, &mut C, Option<Arc<RwLock<T>>>) -> Result<(), ()>
+ Send
+ Sync
+ Copy,
C: Default,
{
let mut root_rng = StdRng::seed_from_u64(driver.seed());
let mut threads = driver.threads.write().unwrap();
// Spawn exactly `thread_count` workers: Vec::with_capacity may reserve more
// than requested, and the heartbeat vector only has `thread_count` slots.
for i in 0..driver.thread_count() {
let thread_driver = driver.clone();
let thread_name = format!("Fuzzer thread {}", i);
let thread_seed: u64 = root_rng.gen();
let join_handle = thread::Builder::new()
.name(thread_name)
.spawn(move || {
// mutator.rng is reassigned with a fresh per-iteration seed at the top of
// every loop iteration below, so this constant seed only covers construction.
let thread_rng = StdRng::seed_from_u64(0u64);
let mut mutator = Mutator::new(thread_rng);
let mut context = C::default();
loop {
thread_driver.set_thread_last_execution_time(i);
// Derive a deterministic per-iteration seed from this thread's seed and the
// shared iteration counter so iterations can be replayed in reproduce mode.
let new_seed = thread_seed.wrapping_add(thread_driver.num_iterations() as u64);
mutator.rng = StdRng::seed_from_u64(new_seed);
if thread_driver.should_exit() {
log::info!("{} exiting", thread::current().name().unwrap());
return;
}
mutator.random_flags();
if (callback)(&mut mutator, &mut context, thread_driver.global_context()).is_err()
{
thread_driver
.num_failed_iterations
.fetch_add(1, Ordering::SeqCst);
}
thread_driver.num_iterations.fetch_add(1, Ordering::SeqCst);
}
})
.expect("could not create new thread");
threads.push(join_handle);
}
}