use crate::{config::ProfileResult, error::*};
use std::path::PathBuf;
use std::time::Instant;
#[cfg(unix)]
mod unix {
use super::*;
use pprof::ProfilerGuard;
/// An in-progress CPU profiling session backed by the `pprof` crate.
///
/// Created by `ProfileSession::start` and consumed by `ProfileSession::stop`,
/// which renders the collected samples as a flamegraph at `output_path`.
pub struct ProfileSession {
/// Profiler guard built in `start`; `stop` builds its report from this guard.
guard: ProfilerGuard<'static>,
/// Instant captured right after the profiler guard was built; used to
/// compute the reported session duration.
start_time: Instant,
/// Destination file for the flamegraph written by `stop`.
output_path: PathBuf,
}
impl ProfileSession {
    /// Starts a `pprof` profiler sampling at `frequency` Hz.
    ///
    /// `output_path` is only stored here; the file is created by [`stop`](Self::stop).
    ///
    /// # Errors
    ///
    /// Returns `Error::StartFailed` carrying pprof's message when the profiler
    /// guard cannot be built.
    pub fn start(frequency: i32, output_path: PathBuf) -> Result<Self> {
        let guard = pprof::ProfilerGuardBuilder::default()
            .frequency(frequency)
            // NOTE(review): these names are handed to pprof's blocklist; see the
            // pprof documentation for the exact matching semantics.
            .blocklist(&["libc", "libgcc", "pthread", "vdso"])
            .build()
            .map_err(|e| Error::StartFailed(e.to_string()))?;

        Ok(Self {
            guard,
            start_time: Instant::now(),
            output_path,
        })
    }

    /// Builds the profiler report and writes it as a flamegraph to the
    /// session's output path.
    ///
    /// The returned `ProfileResult` carries the output path, the total number
    /// of samples collected and the elapsed time since `start` in milliseconds.
    ///
    /// # Errors
    ///
    /// * `Error::ReportFailed` when pprof cannot build the report.
    /// * `Error::FlamegraphFailed` when rendering the flamegraph fails.
    /// * An I/O error (via `?`) when the output file cannot be created or
    ///   flushed.
    pub fn stop(self) -> Result<ProfileResult> {
        use std::fs::File;
        use std::io::{BufWriter, Write};

        let duration = self.start_time.elapsed();
        let report = self
            .guard
            .report()
            .build()
            .map_err(|e| Error::ReportFailed(e.to_string()))?;

        // The report maps each distinct stack to its hit count, so the number
        // of samples is the sum of the counts, not the number of map entries.
        // This matches how the Windows backend computes `sample_count`.
        let sample_count: usize = report
            .data
            .values()
            .map(|&hits| usize::try_from(hits).unwrap_or(0))
            .sum();

        let file = File::create(&self.output_path)?;
        let mut writer = BufWriter::new(file);
        report
            .flamegraph(&mut writer)
            .map_err(|e| Error::FlamegraphFailed(e.to_string()))?;
        // Flush explicitly: dropping a `BufWriter` discards any write error,
        // which would leave a truncated SVG behind while reporting success.
        writer.flush()?;

        Ok(ProfileResult {
            flamegraph_path: self.output_path,
            sample_count,
            duration_ms: duration.as_millis() as u64,
        })
    }
}
}
#[cfg(windows)]
mod windows_impl {
use super::*;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread::{self, JoinHandle};
use std::time::Duration;
use windows::Win32::Foundation::{CloseHandle, HANDLE};
use windows::Win32::System::Diagnostics::Debug::{
AddrModeFlat, CONTEXT, CONTEXT_FLAGS, GetThreadContext, STACKFRAME64, SYMBOL_INFO,
StackWalk64, SymCleanup, SymFromAddr, SymInitialize,
};
use windows::Win32::System::Diagnostics::ToolHelp::{
CreateToolhelp32Snapshot, TH32CS_SNAPTHREAD, THREADENTRY32, Thread32First, Thread32Next,
};
use windows::Win32::System::Threading::{
GetCurrentProcess, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
SuspendThread, THREAD_GET_CONTEXT, THREAD_QUERY_INFORMATION, THREAD_SUSPEND_RESUME,
};
// Machine-type value passed to `StackWalk64`; 0x8664 is the PE machine
// identifier for x86-64.
const IMAGE_FILE_MACHINE_AMD64: u32 = 0x8664;
// Context flags requested from `GetThreadContext`.
// NOTE(review): 0x10001f should correspond to the x64 `CONTEXT_ALL` mask from
// winnt.h (control, integer, segments, floating point, debug registers) —
// confirm against the SDK headers.
const CONTEXT_ALL: CONTEXT_FLAGS = CONTEXT_FLAGS(0x10001f);
// Aggregated samples: each key is one captured stack (a list of frame names,
// innermost frame first, as produced by the stack walker) and the value is
// how many times that exact stack was observed.
type SampleData = HashMap<Vec<String>, usize>;
/// An in-progress CPU profiling session driven by a dedicated sampler thread.
///
/// Created by `ProfileSession::start` and consumed by `ProfileSession::stop`,
/// which joins the sampler and renders its samples as a flamegraph.
pub struct ProfileSession {
/// Join handle of the sampler thread; its return value is the aggregated
/// sample map. Held in an `Option` so `stop` can take it out for joining.
sampler_handle: Option<JoinHandle<SampleData>>,
/// Flag shared with the sampler thread; set to `true` by `stop` and polled
/// by the sampler loop to know when to exit.
stop_signal: Arc<AtomicBool>,
/// Instant captured after the sampler thread was spawned; used to compute
/// the reported session duration.
start_time: Instant,
/// Destination file for the flamegraph written by `stop`.
output_path: PathBuf,
}
impl ProfileSession {
pub fn start(frequency: i32, output_path: PathBuf) -> Result<Self> {
let stop_signal = Arc::new(AtomicBool::new(false));
let stop_signal_clone = stop_signal.clone();
let sample_interval = Duration::from_micros(1_000_000 / frequency as u64);
let profiler_thread_id = unsafe { GetCurrentThreadId() };
let process_id = unsafe { GetCurrentProcessId() };
let sampler_handle = thread::Builder::new()
.name("cpu-profiler".to_string())
.spawn(move || {
sampler_thread_main(
process_id,
profiler_thread_id,
sample_interval,
stop_signal_clone,
)
})
.map_err(|e| {
Error::StartFailed(format!("Failed to spawn sampler thread: {}", e))
})?;
Ok(Self {
sampler_handle: Some(sampler_handle),
stop_signal,
start_time: Instant::now(),
output_path,
})
}
pub fn stop(mut self) -> Result<ProfileResult> {
use inferno::flamegraph::{self, Options};
use std::fs::File;
use std::io::{BufWriter, Write};
let duration = self.start_time.elapsed();
self.stop_signal.store(true, Ordering::SeqCst);
let samples = self
.sampler_handle
.take()
.ok_or_else(|| Error::ReportFailed("Sampler thread already stopped".to_string()))?
.join()
.map_err(|_| Error::ReportFailed("Sampler thread panicked".to_string()))?;
let sample_count = samples.values().sum::<usize>();
let mut folded_stacks: Vec<String> = samples
.into_iter()
.map(|(frames, count)| {
let reversed: Vec<_> = frames.into_iter().rev().collect();
format!("{} {}", reversed.join(";"), count)
})
.collect();
folded_stacks.sort();
let file = File::create(&self.output_path)?;
let mut writer = BufWriter::new(file);
let mut opts = Options::default();
opts.title = "CPU Profile".to_string();
let folded_data = folded_stacks.join("\n");
flamegraph::from_reader(&mut opts, folded_data.as_bytes(), &mut writer)
.map_err(|e| Error::FlamegraphFailed(e.to_string()))?;
writer.flush()?;
Ok(ProfileResult {
flamegraph_path: self.output_path,
sample_count,
duration_ms: duration.as_millis() as u64,
})
}
}
fn sampler_thread_main(
process_id: u32,
profiler_thread_id: u32,
sample_interval: Duration,
stop_signal: Arc<AtomicBool>,
) -> SampleData {
let mut samples: SampleData = HashMap::new();
let process_handle = unsafe { GetCurrentProcess() };
let sym_initialized = unsafe { SymInitialize(process_handle, None, true).is_ok() };
if !sym_initialized {
log::warn!("Failed to initialize symbol handler");
}
while !stop_signal.load(Ordering::SeqCst) {
if let Ok(thread_ids) = enumerate_threads(process_id, profiler_thread_id) {
for thread_id in thread_ids {
if let Ok(frames) = sample_thread(process_handle, thread_id, sym_initialized)
&& !frames.is_empty()
{
*samples.entry(frames).or_insert(0) += 1;
}
}
}
thread::sleep(sample_interval);
}
if sym_initialized {
unsafe {
let _ = SymCleanup(process_handle);
}
}
samples
}
/// Lists the ids of all threads owned by `process_id`, leaving out
/// `exclude_thread_id`.
///
/// The ids come from a ToolHelp thread snapshot, which covers every process
/// on the system, so entries are filtered by owner process id.
///
/// # Errors
///
/// Returns `Error::StartFailed` when the snapshot cannot be created.
fn enumerate_threads(process_id: u32, exclude_thread_id: u32) -> Result<Vec<u32>> {
    let entry_size = std::mem::size_of::<THREADENTRY32>() as u32;

    let snapshot = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0) }
        .map_err(|e| Error::StartFailed(format!("CreateToolhelp32Snapshot failed: {}", e)))?;

    let mut entry = THREADENTRY32 {
        dwSize: entry_size,
        ..Default::default()
    };
    let mut thread_ids = Vec::new();

    unsafe {
        let mut has_entry = Thread32First(snapshot, &mut entry).is_ok();
        while has_entry {
            let belongs_to_process = entry.th32OwnerProcessID == process_id;
            if belongs_to_process && entry.th32ThreadID != exclude_thread_id {
                thread_ids.push(entry.th32ThreadID);
            }
            // Reset the size field before asking for the next entry.
            entry.dwSize = entry_size;
            has_entry = Thread32Next(snapshot, &mut entry).is_ok();
        }
        let _ = CloseHandle(snapshot);
    }

    Ok(thread_ids)
}
/// Upper bound on the number of frames recorded for a single sample.
const MAX_FRAMES: usize = 128;

/// Captures one stack sample from the thread identified by `thread_id`.
///
/// The thread is suspended only while its register context is read and its
/// stack is walked into a fixed-size buffer of raw addresses. It is resumed
/// before any symbol lookup or heap allocation takes place: if the suspended
/// thread happened to hold the allocator lock, allocating here would block
/// forever because that thread could never be resumed.
///
/// Returns the frames innermost-first. Each frame is a symbol name when
/// `symbolize` is `true` and the lookup succeeds, otherwise the address
/// formatted as `0x…`. The result is empty when the thread could not be
/// suspended or its context could not be read.
///
/// # Errors
///
/// Returns `Error::StartFailed` when the thread cannot be opened.
fn sample_thread(
    process_handle: HANDLE,
    thread_id: u32,
    symbolize: bool,
) -> Result<Vec<String>> {
    let thread_handle = unsafe {
        OpenThread(
            THREAD_SUSPEND_RESUME | THREAD_GET_CONTEXT | THREAD_QUERY_INFORMATION,
            false,
            thread_id,
        )
    }
    .map_err(|e| Error::StartFailed(format!("OpenThread failed: {}", e)))?;

    // Stack-allocated on purpose: nothing between suspend and resume may
    // touch the heap.
    let mut addresses = [0u64; MAX_FRAMES];
    let mut depth = 0;

    unsafe {
        if SuspendThread(thread_handle) == u32::MAX {
            let _ = CloseHandle(thread_handle);
            // `Vec::new()` does not allocate.
            return Ok(Vec::new());
        }

        // Wrapper forcing 16-byte alignment of the context record.
        #[repr(align(16))]
        struct AlignedContext {
            context: CONTEXT,
        }
        let mut aligned_ctx = AlignedContext {
            context: std::mem::zeroed(),
        };
        aligned_ctx.context.ContextFlags = CONTEXT_ALL;

        if GetThreadContext(thread_handle, &mut aligned_ctx.context).is_ok() {
            depth = walk_stack(
                process_handle,
                thread_handle,
                &mut aligned_ctx.context,
                &mut addresses,
            );
        }

        let _ = ResumeThread(thread_handle);
        let _ = CloseHandle(thread_handle);
    }

    // The target thread is running again; allocation and symbol lookup are
    // safe from here on.
    let frames = addresses
        .iter()
        .take(depth)
        .map(|&address| {
            symbolize
                .then(|| get_symbol_name(process_handle, address))
                .flatten()
                .unwrap_or_else(|| format!("0x{:x}", address))
        })
        .collect();
    Ok(frames)
}

/// Walks the stack described by `context` and stores the program-counter
/// address of each frame into `addresses`, innermost frame first.
///
/// Returns the number of entries written, which never exceeds
/// `addresses.len()`. The walk stops early when `StackWalk64` reports failure
/// or yields a zero program counter.
///
/// This function performs no heap allocation of its own, so it can be called
/// while the walked thread is suspended.
///
/// NOTE(review): the walk is seeded from `Rip`/`Rsp`/`Rbp`, so this only
/// compiles for x86-64 targets. All `StackWalk64` callbacks are passed as
/// `None`; confirm against the DbgHelp documentation that unwinding works
/// without function-table and module-base callbacks.
fn walk_stack(
    process_handle: HANDLE,
    thread_handle: HANDLE,
    context: &mut CONTEXT,
    addresses: &mut [u64],
) -> usize {
    let mut depth = 0;
    unsafe {
        let mut stack_frame: STACKFRAME64 = std::mem::zeroed();
        stack_frame.AddrPC.Offset = context.Rip;
        stack_frame.AddrPC.Mode = AddrModeFlat;
        stack_frame.AddrStack.Offset = context.Rsp;
        stack_frame.AddrStack.Mode = AddrModeFlat;
        stack_frame.AddrFrame.Offset = context.Rbp;
        stack_frame.AddrFrame.Mode = AddrModeFlat;

        while depth < addresses.len() {
            let result = StackWalk64(
                IMAGE_FILE_MACHINE_AMD64,
                process_handle,
                thread_handle,
                &mut stack_frame,
                std::ptr::from_mut(context).cast(),
                None,
                None,
                None,
                None,
            );
            if !result.as_bool() || stack_frame.AddrPC.Offset == 0 {
                break;
            }
            addresses[depth] = stack_frame.AddrPC.Offset;
            depth += 1;
        }
    }
    depth
}
/// Resolves `address` to a symbol name through DbgHelp's `SymFromAddr`.
///
/// Returns `None` when the lookup fails, when the reported name is empty, or
/// when the reported name length exceeds the 256-byte name buffer. Bytes that
/// are not valid UTF-8 are replaced lossily.
fn get_symbol_name(process_handle: HANDLE, address: u64) -> Option<String> {
    const MAX_NAME_LEN: usize = 256;

    // `SYMBOL_INFO` ends in a one-element `Name` array that the API treats as
    // the start of a longer buffer; `name_buf` provides the extra storage
    // directly behind it.
    #[repr(C)]
    struct SymbolBuffer {
        info: SYMBOL_INFO,
        name_buf: [u8; MAX_NAME_LEN],
    }

    // SAFETY: same zero-initialisation the structure has always used here;
    // both members are plain data.
    let mut symbol_buf: SymbolBuffer = unsafe { std::mem::zeroed() };
    symbol_buf.info.SizeOfStruct = std::mem::size_of::<SYMBOL_INFO>() as u32;
    symbol_buf.info.MaxNameLen = MAX_NAME_LEN as u32;
    let mut displacement: u64 = 0;

    // Hand the API a pointer derived from the *whole* buffer. A pointer taken
    // from `&mut symbol_buf.info` would only be valid for `SYMBOL_INFO`
    // itself, not for the name bytes written past its end. `info` is the
    // first field of a `repr(C)` struct, so the addresses coincide.
    let symbol_ptr = std::ptr::addr_of_mut!(symbol_buf).cast::<SYMBOL_INFO>();
    // SAFETY: `symbol_ptr` points to a live, correctly sized and initialised
    // buffer, and `displacement` outlives the call.
    let lookup =
        unsafe { SymFromAddr(process_handle, address, Some(&mut displacement), symbol_ptr) };
    if lookup.is_err() {
        return None;
    }

    let name_len = symbol_buf.info.NameLen as usize;
    if name_len == 0 || name_len > MAX_NAME_LEN {
        return None;
    }

    let name_offset =
        std::mem::offset_of!(SymbolBuffer, info) + std::mem::offset_of!(SYMBOL_INFO, Name);
    // SAFETY: the range starts at the `Name` field and spans at most
    // `MAX_NAME_LEN` bytes, all of which lie inside `symbol_buf` because
    // `name_buf` alone is `MAX_NAME_LEN` bytes long and follows `info`. The
    // base pointer is derived from the whole struct, so it is valid for that
    // entire range.
    let name_bytes = unsafe {
        let base = std::ptr::addr_of!(symbol_buf).cast::<u8>();
        std::slice::from_raw_parts(base.add(name_offset), name_len)
    };
    Some(String::from_utf8_lossy(name_bytes).into_owned())
}
}
#[cfg(not(any(unix, windows)))]
mod unsupported {
use super::*;
/// Placeholder session type for platforms that are neither unix nor windows.
///
/// Its `start` always returns an error, so no value of this type is ever
/// handed out by this module.
pub struct ProfileSession {
// Private unit field: keeps the struct from being constructed outside this module.
_private: (),
}
impl ProfileSession {
pub fn start(_frequency: i32, _output_path: PathBuf) -> Result<Self> {
Err(Error::UnsupportedPlatform)
}
pub fn stop(self) -> Result<ProfileResult> {
Err(Error::UnsupportedPlatform)
}
}
}
#[cfg(unix)]
pub use unix::ProfileSession;
#[cfg(windows)]
pub use windows_impl::ProfileSession;
#[cfg(not(any(unix, windows)))]
pub use unsupported::ProfileSession;
/// Builds the path `<base_dir>/<prefix>_<unix-seconds>.svg`.
///
/// The timestamp is the current system time in whole seconds since the Unix
/// epoch, or `0` when the system clock reports a time before the epoch.
///
/// Because the timestamp has one-second resolution, two calls with the same
/// `base_dir` and `prefix` within the same second yield the same path.
pub fn generate_output_path(base_dir: &std::path::Path, prefix: &str) -> PathBuf {
    let epoch_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    };

    let mut path = base_dir.to_path_buf();
    path.push(format!("{prefix}_{epoch_secs}.svg"));
    path
}