use core::{hash::Hash, marker::PhantomData};
use hashbrown::HashMap;
use libafl_bolts::{
Named,
rands::Rand,
tuples::{Handle, Handled, MatchName},
};
use serde::{Deserialize, Serialize};
use crate::{
Error, HasMetadata,
corpus::{Corpus, CorpusId, HasTestcase, Testcase},
random_corpus_id,
schedulers::{
AflScheduler, HasQueueCycles, RemovableScheduler, Scheduler, on_add_metadata_default,
on_evaluation_metadata_default, on_next_metadata_default,
powersched::{BaseSchedule, PowerSchedule, SchedulerMetadata},
testcase_score::{CorpusWeightTestcaseScore, TestcaseScore},
},
state::{HasCorpus, HasRand},
};
/// The metadata kept by the [`WeightedScheduler`]: the position inside the
/// current queue cycle, plus the alias tables used for O(1) weighted random
/// corpus selection (built by `WeightedScheduler::create_alias_table`).
#[cfg_attr(
any(not(feature = "serdeany_autoreg"), miri),
expect(clippy::unsafe_derive_deserialize)
)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WeightedScheduleMetadata {
// How many testcases have been scheduled in the current queue cycle.
runs_in_current_cycle: usize,
// For each corpus id, the "alias" entry to fall back to when the
// probability check fails (alias-method table).
alias_table: HashMap<CorpusId, CorpusId>,
// For each corpus id, the probability of picking the id itself rather
// than its alias.
alias_probability: HashMap<CorpusId, f64>,
}
impl Default for WeightedScheduleMetadata {
fn default() -> Self {
Self::new()
}
}
impl WeightedScheduleMetadata {
    /// Creates a fresh metadata object with zero runs in the current cycle
    /// and empty alias tables.
    #[must_use]
    pub fn new() -> Self {
        Self {
            runs_in_current_cycle: 0,
            alias_table: HashMap::new(),
            alias_probability: HashMap::new(),
        }
    }

    /// The number of testcases scheduled so far in the current queue cycle.
    #[must_use]
    pub fn runs_in_current_cycle(&self) -> usize {
        self.runs_in_current_cycle
    }

    /// Sets the number of testcases scheduled in the current queue cycle.
    pub fn set_runs_current_cycle(&mut self, cycles: usize) {
        self.runs_in_current_cycle = cycles;
    }

    /// The alias table: maps each corpus id to its fallback entry.
    #[must_use]
    pub fn alias_table(&self) -> &HashMap<CorpusId, CorpusId> {
        &self.alias_table
    }

    /// Replaces the alias table wholesale.
    pub fn set_alias_table(&mut self, table: HashMap<CorpusId, CorpusId>) {
        self.alias_table = table;
    }

    /// The alias probabilities: chance of keeping an id instead of its alias.
    #[must_use]
    pub fn alias_probability(&self) -> &HashMap<CorpusId, f64> {
        &self.alias_probability
    }

    /// Replaces the alias probability table wholesale.
    pub fn set_alias_probability(&mut self, probability: HashMap<CorpusId, f64>) {
        self.alias_probability = probability;
    }
}
// Register `WeightedScheduleMetadata` with the SerdeAny machinery so it can be
// stored and (de)serialized as dynamic state metadata.
libafl_bolts::impl_serdeany!(WeightedScheduleMetadata);
/// A corpus scheduler that selects the next testcase by weighted random
/// sampling over the whole corpus, using an alias table for O(1) draws.
/// Weights are computed per testcase by the `F: TestcaseScore` parameter.
#[derive(Debug, Clone)]
pub struct WeightedScheduler<C, F, O> {
// True when the alias table must be rebuilt before the next scheduling
// decision (set on corpus add/remove/replace and at construction).
table_invalidated: bool,
// The optional power-schedule strategy; also mirrored into `SchedulerMetadata`.
strat: Option<PowerSchedule>,
// Handle to the observer used by the shared AFL-scheduler metadata helpers.
observer_handle: Handle<C>,
// Last hash value stored via the `AflScheduler` accessors
// (semantics defined by the trait's callers, not visible here).
last_hash: usize,
// Number of completed queue cycles (exposed via `HasQueueCycles`).
queue_cycles: u64,
phantom: PhantomData<(F, O)>,
// If true, rotate the base power schedule after each full queue cycle.
cycle_schedules: bool,
}
impl<C, F, O> WeightedScheduler<C, F, O>
where
C: Named,
{
/// Creates a new [`WeightedScheduler`] with no [`PowerSchedule`] strategy.
#[must_use]
pub fn new<S>(state: &mut S, observer: &C) -> Self
where
S: HasMetadata,
{
Self::with_schedule(state, observer, None)
}
/// Creates a new [`WeightedScheduler`] with the given optional
/// [`PowerSchedule`] strategy, registering [`SchedulerMetadata`] and
/// [`WeightedScheduleMetadata`] in the state if not already present.
#[must_use]
pub fn with_schedule<S>(state: &mut S, observer: &C, strat: Option<PowerSchedule>) -> Self
where
S: HasMetadata,
{
// Insert only when missing: pre-existing metadata (and its strategy) wins.
let _ = state.metadata_or_insert_with(|| SchedulerMetadata::new(strat));
let _ = state.metadata_or_insert_with(WeightedScheduleMetadata::new);
Self {
strat,
observer_handle: observer.handle(),
last_hash: 0,
queue_cycles: 0,
// Force an alias-table build on the first call to `next()`.
table_invalidated: true,
cycle_schedules: false,
phantom: PhantomData,
}
}
/// Builder-style toggle: after each completed queue cycle, rotate to the
/// next base power schedule (see `cycle_schedule`).
#[must_use]
pub fn cycling_scheduler(mut self) -> Self {
self.cycle_schedules = true;
self
}
/// The configured [`PowerSchedule`] strategy, if any.
#[must_use]
pub fn strat(&self) -> &Option<PowerSchedule> {
&self.strat
}
/// Rebuilds the alias table over the current corpus using the Walker/Vose
/// alias method, then stores it into the [`WeightedScheduleMetadata`].
/// Afterwards, a weighted draw costs one random id plus one random float.
///
/// # Errors
/// Propagates errors from corpus access, `F::compute`, and metadata lookup.
#[expect(clippy::cast_precision_loss)]
pub fn create_alias_table<I, S>(&self, state: &mut S) -> Result<(), Error>
where
F: TestcaseScore<I, S>,
S: HasCorpus<I> + HasMetadata,
{
let n = state.corpus().count();
let mut alias_table: HashMap<CorpusId, CorpusId> = HashMap::default();
let mut alias_probability: HashMap<CorpusId, f64> = HashMap::default();
let mut weights: HashMap<CorpusId, f64> = HashMap::default();
// p_arr: normalized probability per id; s_arr/l_arr: index -> id worklists
// of "small" (p < 1.0) and "large" (p >= 1.0) entries.
let mut p_arr: HashMap<CorpusId, f64> = HashMap::default();
let mut s_arr: HashMap<usize, CorpusId> = HashMap::default();
let mut l_arr: HashMap<usize, CorpusId> = HashMap::default();
let mut sum: f64 = 0.0;
// Compute each testcase's weight via `F` and accumulate the total.
for i in state.corpus().ids() {
let mut testcase = state.corpus().get(i)?.borrow_mut();
let weight = F::compute(state, &mut *testcase)?;
weights.insert(i, weight);
sum += weight;
}
// Normalize so probabilities average 1.0: p[i] = w[i] * n / sum.
for (i, w) in &weights {
p_arr.insert(*i, w * (n as f64) / sum);
}
// Partition ids into the small and large worklists.
let mut n_s = 0;
let mut n_l = 0;
for s in state.corpus().ids().rev() {
if *p_arr.get(&s).unwrap() < 1.0 {
s_arr.insert(n_s, s);
n_s += 1;
} else {
l_arr.insert(n_l, s);
n_l += 1;
}
}
// Pair each small entry with a large one: the small entry keeps its
// probability and aliases to the large id; the large id donates the
// deficit (p_a - 1.0) and is re-filed as small or large accordingly.
while n_s > 0 && n_l > 0 {
n_s -= 1;
n_l -= 1;
let a = *s_arr.get(&n_s).unwrap();
let g = *l_arr.get(&n_l).unwrap();
alias_probability.insert(a, *p_arr.get(&a).unwrap());
alias_table.insert(a, g);
*p_arr.get_mut(&g).unwrap() += p_arr.get(&a).unwrap() - 1.0;
if *p_arr.get(&g).unwrap() < 1.0 {
*s_arr.get_mut(&n_s).unwrap() = g;
n_s += 1;
} else {
*l_arr.get_mut(&n_l).unwrap() = g;
n_l += 1;
}
}
// Leftover entries (from either list) get probability 1.0, i.e. they are
// never redirected to an alias.
while n_l > 0 {
n_l -= 1;
alias_probability.insert(*l_arr.get(&n_l).unwrap(), 1.0);
}
while n_s > 0 {
n_s -= 1;
alias_probability.insert(*s_arr.get(&n_s).unwrap(), 1.0);
}
// Publish the freshly built tables into the state metadata.
let wsmeta = state.metadata_mut::<WeightedScheduleMetadata>()?;
wsmeta.set_alias_probability(alias_probability);
wsmeta.set_alias_table(alias_table);
Ok(())
}
/// Rotates the base power schedule (EXPLORE <-> EXPLOIT, and the
/// COE -> LIN -> QUAD -> FAST -> COE cycle) and invalidates the alias
/// table so weights are recomputed under the new schedule.
///
/// # Errors
/// Returns [`Error::illegal_argument`] if no strategy was configured.
fn cycle_schedule(&mut self, metadata: &mut SchedulerMetadata) -> Result<(), Error> {
let mut ps = metadata.strat().ok_or(Error::illegal_argument(
"No strategy specified when initializing scheduler; cannot cycle!",
))?;
let new_base = match ps.base() {
BaseSchedule::EXPLORE => BaseSchedule::EXPLOIT,
BaseSchedule::COE => BaseSchedule::LIN,
BaseSchedule::LIN => BaseSchedule::QUAD,
BaseSchedule::FAST => BaseSchedule::COE,
BaseSchedule::QUAD => BaseSchedule::FAST,
BaseSchedule::EXPLOIT => BaseSchedule::EXPLORE,
};
ps.set_base(new_base);
metadata.set_strat(Some(ps));
self.table_invalidated = true;
Ok(())
}
}
impl<C, F, I, O, S> RemovableScheduler<I, S> for WeightedScheduler<C, F, O> {
    /// Replacing a testcase may change its weight, so the alias table must
    /// be rebuilt before the next scheduling decision.
    fn on_replace(
        &mut self,
        _state: &mut S,
        _id: CorpusId,
        _prev: &Testcase<I>,
    ) -> Result<(), Error> {
        self.table_invalidated = true;
        Ok(())
    }

    /// Removing a testcase likewise invalidates the alias table.
    fn on_remove(
        &mut self,
        _state: &mut S,
        _id: CorpusId,
        _prev: &Option<Testcase<I>>,
    ) -> Result<(), Error> {
        self.table_invalidated = true;
        Ok(())
    }
}
/// AFL-style scheduler plumbing: exposes the observer handle and the
/// last-hash accessors used by the shared metadata helper functions.
impl<C, F, O> AflScheduler for WeightedScheduler<C, F, O> {
    type ObserverRef = C;

    /// The handle to the observer this scheduler was constructed with.
    fn observer_handle(&self) -> &Handle<C> {
        &self.observer_handle
    }

    /// The most recently stored hash value.
    fn last_hash(&self) -> usize {
        self.last_hash
    }

    /// Stores a new hash value.
    fn set_last_hash(&mut self, hash: usize) {
        self.last_hash = hash;
    }
}
impl<C, F, O> HasQueueCycles for WeightedScheduler<C, F, O> {
/// The number of queue cycles completed so far (incremented in `next()`).
fn queue_cycles(&self) -> u64 {
self.queue_cycles
}
}
impl<C, F, I, O, S> Scheduler<I, S> for WeightedScheduler<C, F, O>
where
C: AsRef<O> + Named,
F: TestcaseScore<I, S>,
O: Hash,
S: HasCorpus<I> + HasMetadata + HasRand + HasTestcase<I>,
{
/// Runs the default on-add metadata bookkeeping, then marks the alias
/// table stale: a new entry changes the weight distribution.
fn on_add(&mut self, state: &mut S, id: CorpusId) -> Result<(), Error> {
on_add_metadata_default(self, state, id)?;
self.table_invalidated = true;
Ok(())
}
/// Delegates to the shared AFL-scheduler evaluation bookkeeping.
fn on_evaluation<OT>(&mut self, state: &mut S, _input: &I, observers: &OT) -> Result<(), Error>
where
OT: MatchName,
{
on_evaluation_metadata_default(self, state, observers)
}
/// Picks the next corpus entry by a weighted random draw:
/// pick a uniform random id `s`, then keep `s` with probability
/// `alias_probability[s]`, otherwise take its alias `alias_table[s]`.
///
/// # Errors
/// Returns [`Error::empty`] if the corpus has no entries; propagates
/// alias-table construction and metadata errors.
fn next(&mut self, state: &mut S) -> Result<CorpusId, Error> {
// Lazily rebuild the alias table if anything invalidated it.
if self.table_invalidated {
self.create_alias_table(state)?;
self.table_invalidated = false;
}
let corpus_counts = state.corpus().count();
if corpus_counts == 0 {
Err(Error::empty(
"No entries in corpus. This often implies the target is not properly instrumented.",
))
} else {
// Uniform base draw and the acceptance probability for the alias method.
let s = random_corpus_id!(state.corpus(), state.rand_mut());
let probability = state.rand_mut().next_float();
let wsmeta = state.metadata_mut::<WeightedScheduleMetadata>()?;
// Snapshot the pre-update run counter; it is reused below to detect
// the end of a queue cycle.
let runs_in_current_cycle = wsmeta.runs_in_current_cycle();
if runs_in_current_cycle >= corpus_counts {
wsmeta.set_runs_current_cycle(0);
} else {
wsmeta.set_runs_current_cycle(runs_in_current_cycle + 1);
}
// Alias-method draw. The unwraps rely on the table having been built
// for the current corpus (it is rebuilt above whenever invalidated).
let idx = if probability < *wsmeta.alias_probability().get(&s).unwrap() {
s
} else {
*wsmeta.alias_table().get(&s).unwrap()
};
// End of a queue cycle: bump the cycle counter, mirror it into the
// scheduler metadata, and optionally rotate the power schedule.
if runs_in_current_cycle >= corpus_counts {
self.queue_cycles += 1;
let psmeta = state.metadata_mut::<SchedulerMetadata>()?;
psmeta.set_queue_cycles(self.queue_cycles());
if self.cycle_schedules {
self.cycle_schedule(psmeta)?;
}
}
self.set_current_scheduled(state, Some(idx))?;
Ok(idx)
}
}
/// Runs the default on-next metadata bookkeeping, then records `next_id`
/// as the currently scheduled corpus entry.
fn set_current_scheduled(
&mut self,
state: &mut S,
next_id: Option<CorpusId>,
) -> Result<(), Error> {
on_next_metadata_default(state)?;
*state.corpus_mut().current_mut() = next_id;
Ok(())
}
}
/// The standard [`WeightedScheduler`], scoring testcases with
/// [`CorpusWeightTestcaseScore`].
pub type StdWeightedScheduler<C, O> = WeightedScheduler<C, CorpusWeightTestcaseScore, O>;