use proc_macro2::TokenStream;
use quote::quote;
use quote::ToTokens;
use crate::types::trident_flow_executor::TridentFlowExecutorImpl;
/// Lets a parsed flow-executor definition be interpolated with `quote!`.
///
/// The emitted tokens are the original impl block followed by the generated
/// `FlowExecutor` trait implementation.
impl ToTokens for TridentFlowExecutorImpl {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        tokens.extend(Self::generate_flow_executor_impl(self));
    }
}
impl TridentFlowExecutorImpl {
fn generate_flow_executor_impl(&self) -> TokenStream {
let type_name = &self.type_name;
let impl_items = &self.impl_block;
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
let flow_executor_impl = self.generate_flow_executor_trait_impl();
quote! {
impl #impl_generics #type_name #ty_generics #where_clause {
#(#impl_items)*
}
#flow_executor_impl
}
}
fn generate_flow_executor_trait_impl(&self) -> TokenStream {
let type_name = &self.type_name;
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
let execute_flows_method = self.generate_execute_flows_method();
let coverage_method = self.generate_coverage_method();
quote! {
impl #impl_generics FlowExecutor for #type_name #ty_generics #where_clause {
fn new() -> Self {
Self::new()
}
fn execute_flows(&mut self, flow_calls_per_iteration: u64) -> std::result::Result<(), FuzzingError> {
#execute_flows_method
Ok(())
}
fn trident_mut(&mut self) -> &mut Trident {
&mut self.trident
}
fn reset_fuzz_accounts(&mut self) {
let _ = std::mem::take(&mut self.fuzz_accounts);
}
fn handle_llvm_coverage(&mut self, current_iteration: u64) {
#coverage_method
}
}
}
}
fn generate_execute_flows_method(&self) -> TokenStream {
let init_call = self.generate_init_call();
let flow_execution_logic = self.generate_flow_execution_logic();
let end_call = self.generate_end_call();
quote! {
#init_call
#flow_execution_logic
#end_call
}
}
fn generate_init_call(&self) -> TokenStream {
if let Some(init_method) = &self.init_method {
quote! {
self.#init_method();
}
} else {
quote! {}
}
}
fn generate_end_call(&self) -> TokenStream {
if let Some(end_method) = &self.end_method {
quote! {
self.#end_method();
}
} else {
quote! {}
}
}
fn generate_flow_execution_logic(&self) -> TokenStream {
let active_methods: Vec<_> = self
.flow_methods
.iter()
.filter(|method| !method.constraints.ignore)
.collect();
if active_methods.is_empty() {
quote! {
}
} else {
let flow_selection_logic = self.generate_flow_selection_logic(&active_methods);
quote! {
#flow_selection_logic
}
}
}
fn generate_flow_selection_logic(
&self,
active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
) -> TokenStream {
let has_weights = active_methods
.iter()
.any(|method| method.constraints.weight.is_some());
if has_weights {
self.generate_weighted_flow_selection(active_methods)
} else {
self.generate_uniform_flow_selection(active_methods)
}
}
fn generate_uniform_flow_selection(
&self,
active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
) -> TokenStream {
let flow_match_arms = active_methods.iter().enumerate().map(|(index, method)| {
let method_ident = &method.ident;
quote! {
#index => self.#method_ident(),
}
});
let num_flows = active_methods.len();
quote! {
for _ in 0..flow_calls_per_iteration {
let flow_index = self.trident.random_from_range(0..#num_flows);
match flow_index {
#(#flow_match_arms)*
_ => unreachable!("Invalid flow index"),
}
}
}
}
fn generate_weighted_flow_selection(
&self,
active_methods: &[&crate::types::trident_flow_executor::FlowMethod],
) -> TokenStream {
let weighted_methods: Vec<_> = active_methods
.iter()
.filter(|method| method.constraints.weight.unwrap_or(0) > 0)
.collect();
if weighted_methods.is_empty() {
return quote! {
};
}
let total_weight: u32 = weighted_methods
.iter()
.map(|method| method.constraints.weight.unwrap())
.sum();
let mut cumulative_weight = 0u32;
let weight_ranges: Vec<_> = weighted_methods
.iter()
.map(|method| {
let weight = method.constraints.weight.unwrap();
let _start = cumulative_weight;
cumulative_weight += weight;
let end = cumulative_weight;
let method_ident = &method.ident;
quote! {
if random_weight < #end {
self.#method_ident();
continue;
}
}
})
.collect();
quote! {
for _ in 0..flow_calls_per_iteration {
let random_weight = self.trident.random_from_range(0..#total_weight);
#(#weight_ranges)*
}
}
}
fn generate_coverage_method(&self) -> TokenStream {
let rustflags = std::env::var("RUSTFLAGS").unwrap_or_default();
let coverage_enabled = rustflags.contains("-C instrument-coverage");
if coverage_enabled {
quote! {
unsafe {
let filename = format!("target/fuzz-cov-run-{}.profraw", current_iteration);
if let Ok(filename_cstr) = std::ffi::CString::new(filename) {
trident_fuzz::fuzzing::__llvm_profile_set_filename(filename_cstr.as_ptr());
let _ = trident_fuzz::fuzzing::__llvm_profile_write_file();
trident_fuzz::fuzzing::__llvm_profile_reset_counters();
if let Ok(final_filename_cstr) = std::ffi::CString::new("target/fuzz-cov-run-final.profraw") {
trident_fuzz::fuzzing::__llvm_profile_set_filename(final_filename_cstr.as_ptr());
}
}
}
}
} else {
quote! {
}
}
}
}