use crate::Program;
use gecko_profile::{Frame, ProfileBuilder, StringIndex, ThreadBuilder};
use hashbrown::HashMap;
use indicatif::{ProgressBar, ProgressStyle};
#[derive(Debug, thiserror::Error)]
pub enum ProfilerError {
#[error("Failed to read ELF file {}", .0)]
Io(#[from] std::io::Error),
#[error("Failed to parse ELF file {}", .0)]
Elf(#[from] eyre::Error),
#[error("Failed to serialize samples {}", .0)]
Serde(#[from] serde_json::Error),
}
/// Collects call-stack samples from an executing program and turns them into a
/// Gecko-format profile.
pub struct Profiler {
    /// Clock interval between samples: a snapshot of the stack is stored whenever the
    /// clock passed to `record` is a multiple of this value.
    sample_rate: u64,
    /// Maps a function's start address to its index in `function_ranges`.
    start_lookup: HashMap<u64, usize>,
    /// One `(start, end, frame)` entry per function symbol of the program.
    function_ranges: Vec<(u64, u64, Frame)>,
    /// Frames currently on the shadow call stack, outermost first.
    function_stack: Vec<Frame>,
    /// Indices into `function_ranges`, parallel to `function_stack`.
    function_stack_indices: Vec<usize>,
    /// `(start, end)` address ranges, parallel to `function_stack`.
    function_stack_ranges: Vec<(u64, u64)>,
    /// Range `(start, end]` checked first by `record` so it can skip stack
    /// maintenance; `(0, 0)` initially, which matches no program counter.
    current_function_range: (u64, u64),
    /// Interned string index of the `main` symbol, if the program has one.
    main_idx: Option<StringIndex>,
    /// Thread builder that owns the interned function names and receives the samples.
    builder: ThreadBuilder,
    /// Stack snapshots collected so far.
    samples: Vec<Sample>,
}
/// A single snapshot of the shadow call stack.
struct Sample {
    /// Frames on the stack when the sample was taken, outermost first.
    stack: Vec<Frame>,
}
impl Profiler {
    /// Builds a profiler from the function symbols of `program`.
    ///
    /// Each `(name, start, size)` symbol becomes a labelled frame whose range is
    /// `(start, start + size - 4]` (the upper bound saturates at zero). The first
    /// symbol named `main` is remembered so `write` can sanity-check the trace.
    ///
    /// `sample_rate` is the clock interval between recorded samples.
    pub(super) fn from_program(program: &Program, sample_rate: u64) -> Self {
        let mut start_lookup = HashMap::new();
        let mut function_ranges = Vec::new();
        let mut builder = ThreadBuilder::new(1, 0, std::time::Instant::now(), false, false);
        let mut main_idx = None;

        for (demangled_name, start_address, size) in &program.function_symbols {
            // `saturating_sub` avoids an underflow panic for zero-sized symbols near
            // address 0; such symbols produce an empty range either way.
            let end_address = (start_address + size).saturating_sub(4);
            let string_idx = builder.intern_string(demangled_name);
            if main_idx.is_none() && demangled_name == "main" {
                main_idx = Some(string_idx);
            }
            let start_idx = function_ranges.len();
            function_ranges.push((*start_address, end_address, Frame::Label(string_idx)));
            start_lookup.insert(*start_address, start_idx);
        }

        Self {
            builder,
            main_idx,
            sample_rate,
            samples: Vec::new(),
            start_lookup,
            function_ranges,
            function_stack: Vec::new(),
            function_stack_indices: Vec::new(),
            function_stack_ranges: Vec::new(),
            current_function_range: (0, 0),
        }
    }

    /// Records one executed instruction at program counter `pc` and clock `clk`.
    ///
    /// The shadow call stack is maintained heuristically:
    /// - `pc` strictly inside the current range: the stack is unchanged.
    /// - `pc` equal to a known function start: that function is pushed, unless it is
    ///   already somewhere on the stack (so recursion is collapsed into one frame).
    /// - otherwise: the stack is truncated so that the outermost frame whose range
    ///   contains `pc` becomes the top (a return); if no frame matches, the stack is
    ///   left as is.
    ///
    /// A snapshot of the stack is stored whenever `clk` is a multiple of `sample_rate`.
    pub(super) fn record(&mut self, clk: u64, pc: u64) {
        let (cur_start, cur_end) = self.current_function_range;
        if pc > cur_start && pc <= cur_end {
            self.sample_if_due(clk);
            return;
        }

        if let Some(&f) = self.start_lookup.get(&pc) {
            if !self.function_stack_indices.contains(&f) {
                // `start_lookup` only holds indices produced while filling
                // `function_ranges`, so this index is always in bounds.
                let (start, end, name) = &self.function_ranges[f];
                self.current_function_range = (*start, *end);
                self.function_stack_indices.push(f);
                self.function_stack_ranges.push((*start, *end));
                self.function_stack.push(name.clone());
            }
        } else if let Some(unwind_point) =
            self.function_stack_ranges.iter().position(|&(s, e)| pc > s && pc <= e)
        {
            let keep = unwind_point + 1;
            self.function_stack.truncate(keep);
            self.function_stack_ranges.truncate(keep);
            self.function_stack_indices.truncate(keep);
            // Make the frame we returned into the current one. Otherwise the fast path
            // above keeps pointing at the popped callee and every later instruction of
            // the caller falls back to this linear scan.
            self.current_function_range = self.function_stack_ranges[unwind_point];
        }

        self.sample_if_due(clk);
    }

    /// Stores a snapshot of the current call stack when `clk` falls on a sampling
    /// boundary (a multiple of `sample_rate`).
    fn sample_if_due(&mut self, clk: u64) {
        if clk.is_multiple_of(self.sample_rate) {
            self.samples.push(Sample { stack: self.function_stack.clone() });
        }
    }

    /// Converts the collected samples into a Gecko profile and serializes it as JSON
    /// into `writer`.
    ///
    /// The samples are first sanity-checked (warnings only). Consecutive samples are
    /// placed `sample_rate` microseconds apart, each with a duration of `sample_rate`
    /// microseconds. Progress is shown on a progress bar and status messages go to
    /// stderr.
    ///
    /// # Errors
    ///
    /// Returns [`ProfilerError::Serde`] if serializing the profile into `writer` fails.
    pub(super) fn write(mut self, writer: impl std::io::Write) -> Result<(), ProfilerError> {
        self.check_samples();

        let start_time = std::time::Instant::now();
        let mut profile_builder = ProfileBuilder::new(
            start_time,
            std::time::SystemTime::now(),
            "SP1 ZKVM",
            0,
            std::time::Duration::from_micros(1),
        );

        let pb = ProgressBar::new(self.samples.len() as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template(
                    // The `{...}` placeholders are indicatif template keys, not
                    // `format!` arguments, so this clippy lint is a false positive.
                    #[allow(clippy::literal_string_with_formatting_args)]
                    "{msg} \n {spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
                )
                .expect("progress bar template is a valid constant")
                .progress_chars("#>-"),
        );
        pb.set_message("Creating profile");

        let sample_duration = std::time::Duration::from_micros(self.sample_rate);
        let mut last_known_time = std::time::Instant::now();
        for sample in self.samples.drain(..) {
            pb.inc(1);
            self.builder.add_sample(last_known_time, sample.stack.into_iter(), sample_duration);
            last_known_time += sample_duration;
        }

        profile_builder.add_thread(self.builder);
        pb.finish();

        eprintln!("Writing profile, this can take awhile");
        serde_json::to_writer(writer, &profile_builder.to_serializable())?;
        eprintln!("Profile written successfully");
        Ok(())
    }

    /// Prints warnings to stderr when the trace looks inconsistent with the ELF file.
    ///
    /// It warns if the program has no `main` symbol, or if `main` appears in fewer
    /// than 90% of the recorded samples. With no samples, nothing is checked.
    fn check_samples(&self) {
        let Some(main_idx) = self.main_idx else {
            eprintln!(
                "Warning: The `main` function is not present in the Elf file, this is likely caused by using the wrong Elf file"
            );
            return;
        };

        // The ratio below is undefined (0 / 0) without samples; nothing to check.
        if self.samples.is_empty() {
            return;
        }

        let main_count = self
            .samples
            .iter()
            .filter(|s| s.stack.iter().any(|f| matches!(f, Frame::Label(idx) if *idx == main_idx)))
            .count();

        // Sample counts are far below 2^52, so the f64 conversion is exact in practice.
        #[allow(clippy::cast_precision_loss)]
        let main_ratio = main_count as f64 / self.samples.len() as f64;
        if main_ratio < 0.9 {
            eprintln!(
                "Warning: This trace appears to be invalid. The `main` function is present in only {:.2}% of the samples, this is likely caused by using the wrong Elf file",
                main_ratio * 100.0
            );
        }
    }
}