use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
/// One row of a profiling report: aggregated timings for a single function name.
#[derive(Debug, Clone)]
pub struct ProfileEntry {
    /// Name the timings were recorded under.
    pub function: String,
    /// Sum of the recorded wall-clock durations, in nanoseconds.
    pub total_time_ns: u64,
    /// Number of recorded calls.
    pub call_count: u64,
    /// `total_time_ns` minus the time attributed to child calls, in nanoseconds.
    pub self_time_ns: u64,
}

impl ProfileEntry {
    /// Mean duration of a single call in nanoseconds.
    ///
    /// Returns `0.0` when no calls were recorded, so the division is never
    /// performed with a zero denominator.
    pub fn avg_time_ns(&self) -> f64 {
        match self.call_count {
            0 => 0.0,
            calls => self.total_time_ns as f64 / calls as f64,
        }
    }
}
/// Running totals accumulated for one function name.
#[derive(Debug, Default, Clone)]
struct FunctionStats {
    /// Sum of every `total_ns` passed to `record` for this name.
    total_ns: u64,
    /// Number of `record` calls for this name.
    call_count: u64,
    /// Sum of every `child_ns` passed to `record` for this name.
    child_ns: u64,
}

impl FunctionStats {
    /// Folds one measurement into the totals.
    ///
    /// Saturating arithmetic is used so that an overflowing counter pins at
    /// `u64::MAX` instead of panicking (debug builds) or wrapping (release
    /// builds); a panic here would happen while the stats lock is held and
    /// poison it.
    fn accumulate(&mut self, total_ns: u64, child_ns: u64) {
        self.total_ns = self.total_ns.saturating_add(total_ns);
        self.call_count = self.call_count.saturating_add(1);
        self.child_ns = self.child_ns.saturating_add(child_ns);
    }
}

/// Thread-safe accumulator of per-function timing statistics, keyed by name.
#[derive(Debug)]
pub struct HierarchicalProfiler {
    stats: Mutex<HashMap<String, FunctionStats>>,
}

impl HierarchicalProfiler {
    /// Creates a profiler with no recorded statistics.
    pub fn new() -> Self {
        Self {
            stats: Mutex::new(HashMap::new()),
        }
    }

    /// Locks the statistics map, recovering the guard if the lock is poisoned.
    ///
    /// The map only holds plain counters, so the data left behind by a thread
    /// that panicked while holding the lock is still usable. Recovering keeps
    /// the profiler working instead of silently discarding every later call.
    fn lock_stats(&self) -> std::sync::MutexGuard<'_, HashMap<String, FunctionStats>> {
        self.stats
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Adds one call of `name` lasting `total_ns`, of which `child_ns` is
    /// attributed to nested calls.
    pub fn record(&self, name: &str, total_ns: u64, child_ns: u64) {
        let mut stats = self.lock_stats();
        // Look up by `&str` first so the key is only allocated the first time
        // a name is seen, not on every recorded call.
        match stats.get_mut(name) {
            Some(existing) => existing.accumulate(total_ns, child_ns),
            None => stats
                .entry(name.to_owned())
                .or_default()
                .accumulate(total_ns, child_ns),
        }
    }

    /// Returns a copy of the current statistics so callers can work on it
    /// without holding the lock.
    fn snapshot(&self) -> HashMap<String, FunctionStats> {
        let stats = self.lock_stats();
        HashMap::clone(&stats)
    }

    /// Discards all recorded statistics.
    pub fn reset(&self) {
        self.lock_stats().clear();
    }
}

impl Default for HierarchicalProfiler {
    /// Equivalent to [`HierarchicalProfiler::new`].
    fn default() -> Self {
        Self::new()
    }
}
// Per-thread stack of the `profile_fn` calls currently in progress on this thread.
// Each frame is (function name, instant the call started, nanoseconds accumulated
// from direct child calls that have already finished). `profile_fn` pushes a frame
// before running the closure and pops it afterwards.
thread_local! {
static CALL_STACK: RefCell<Vec<(String, Instant, u64)>> = RefCell::new(Vec::new());
}
use std::cell::RefCell;
/// Runs `f`, measures its wall-clock duration and records it in `profiler`
/// under `name`.
///
/// Calls made on the same thread while `f` is running are treated as children:
/// when a nested `profile_fn` finishes, its total duration is added to the
/// child time of the enclosing frame, and that child time is passed to
/// [`HierarchicalProfiler::record`] alongside the total.
///
/// The frame is popped by a drop guard, so it is removed (and recorded) even
/// if `f` panics and unwinds. Without that, a panic caught further up the call
/// chain would leave a stale frame on the thread-local stack and every later
/// pop on that thread would record the wrong frame.
///
/// Returns whatever `f` returns.
#[inline]
pub fn profile_fn<F, T>(profiler: Arc<HierarchicalProfiler>, name: &str, f: F) -> T
where
    F: FnOnce() -> T,
{
    /// Pops and records the frame pushed by the enclosing `profile_fn` call.
    struct FrameGuard<'p> {
        profiler: &'p HierarchicalProfiler,
    }

    impl<'p> Drop for FrameGuard<'p> {
        fn drop(&mut self) {
            // `try_with` rather than `with`: this can run during thread
            // teardown, and panicking inside `drop` while unwinding aborts.
            let finished = CALL_STACK.try_with(|stack| {
                let mut frames = stack.borrow_mut();
                let (frame_name, start, child_ns) = frames.pop()?;
                // Clamp instead of truncating the u128 nanosecond count.
                let total_ns = start.elapsed().as_nanos().min(u128::from(u64::MAX)) as u64;
                if let Some(parent) = frames.last_mut() {
                    parent.2 = parent.2.saturating_add(total_ns);
                }
                Some((frame_name, total_ns, child_ns))
            });
            // Record after the stack borrow has been released.
            if let Ok(Some((frame_name, total_ns, child_ns))) = finished {
                self.profiler.record(&frame_name, total_ns, child_ns);
            }
        }
    }

    CALL_STACK.with(|stack| {
        stack
            .borrow_mut()
            .push((name.to_string(), Instant::now(), 0u64));
    });
    // Created only after the push succeeded, so the guard always has a frame to pop.
    let _guard = FrameGuard {
        profiler: &*profiler,
    };
    f()
}
/// Builds a report from the profiler's current statistics.
///
/// Each entry's self time is its total time minus the time attributed to child
/// calls, clamped at zero. Entries are ordered by total time, largest first.
/// Entries with equal total time are ordered by function name so that the
/// report does not depend on hash-map iteration order.
pub fn report_profile(profiler: &HierarchicalProfiler) -> Vec<ProfileEntry> {
    let mut entries: Vec<ProfileEntry> = profiler
        .snapshot()
        .into_iter()
        .map(|(function, stats)| ProfileEntry {
            self_time_ns: stats.total_ns.saturating_sub(stats.child_ns),
            total_time_ns: stats.total_ns,
            call_count: stats.call_count,
            function,
        })
        .collect();
    // Function names are unique map keys, so the comparison never reports two
    // entries as equal and an unstable sort still yields a deterministic order.
    entries.sort_unstable_by(|a, b| {
        b.total_time_ns
            .cmp(&a.total_time_ns)
            .then_with(|| a.function.cmp(&b.function))
    });
    entries
}
/// Operation kinds understood by [`FlopsEstimator::elementwise`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Counted as 1 FLOP per element.
    AddSub,
    /// Counted as 1 FLOP per element.
    Mul,
    /// Counted as 1 FLOP per element.
    Div,
    /// Counted as 1 FLOP per element.
    Sqrt,
    /// Counted as 1 FLOP per element.
    Exp,
    /// Counted as 1 FLOP per element.
    Log,
    /// Counted as 2 FLOPs per element.
    Fma,
    /// Caller-supplied number of FLOPs per element.
    Custom(u32),
}

/// Namespace for closed-form floating-point operation count estimates.
pub struct FlopsEstimator;

impl FlopsEstimator {
    /// Estimate for a matrix multiply with dimensions `m`, `k`, `n`: `2 * m * k * n`.
    pub fn matmul(m: usize, k: usize, n: usize) -> f64 {
        let (rows, inner, cols) = (m as f64, k as f64, n as f64);
        2.0 * rows * inner * cols
    }

    /// Estimate for applying `op` to `n` elements: `n` times the per-element cost of `op`.
    pub fn elementwise(n: usize, op: UnaryOp) -> f64 {
        let per_element = match op {
            UnaryOp::Custom(cost) => f64::from(cost),
            UnaryOp::Fma => 2.0,
            UnaryOp::AddSub
            | UnaryOp::Mul
            | UnaryOp::Div
            | UnaryOp::Sqrt
            | UnaryOp::Exp
            | UnaryOp::Log => 1.0,
        };
        n as f64 * per_element
    }

    /// Estimate for a dot product over `n` elements: `2 * n`.
    pub fn dot_product(n: usize) -> f64 {
        let len = n as f64;
        2.0 * len
    }

    /// Estimate for a matrix-vector product with dimensions `m`, `n`: `2 * m * n`.
    pub fn gemv(m: usize, n: usize) -> f64 {
        let (rows, cols) = (m as f64, n as f64);
        2.0 * rows * cols
    }

    /// Estimate for an FFT of length `n`: `5 * n * log2(n)`, or `0.0` for `n == 0`.
    pub fn fft(n: usize) -> f64 {
        match n {
            // log2(0) is -inf; an empty transform is defined as zero work.
            0 => 0.0,
            _ => {
                let len = n as f64;
                5.0 * len * len.log2()
            }
        }
    }

    /// Estimate for batch normalisation over `n` elements: `5 * n`.
    pub fn batch_norm(n: usize) -> f64 {
        let len = n as f64;
        5.0 * len
    }

    /// Converts `flops` performed in `elapsed_ns` nanoseconds to GFLOP/s.
    ///
    /// Returns `0.0` when `elapsed_ns` is zero.
    pub fn gflops(flops: f64, elapsed_ns: u64) -> f64 {
        if elapsed_ns == 0 {
            return 0.0;
        }
        let seconds = elapsed_ns as f64 * 1e-9;
        let flops_per_second = flops / seconds;
        flops_per_second / 1e9
    }
}
/// Measures the throughput of `f` in gigabytes (1e9 bytes) per second.
///
/// `f` is called once without timing and then `n_iter` times under a timer.
/// The result is `n_bytes * n_iter` divided by the elapsed seconds and by 1e9.
///
/// Returns `0.0` without calling `f` when `n_iter` is zero, and `0.0` when the
/// measured elapsed time is not positive.
pub fn throughput_benchmark<F>(f: F, n_bytes: usize, n_iter: usize) -> f64
where
    F: Fn(),
{
    if n_iter == 0 {
        return 0.0;
    }

    // One untimed call before the clock starts.
    f();

    let timer = Instant::now();
    (0..n_iter).for_each(|_| f());
    let seconds = timer.elapsed().as_secs_f64();

    if seconds <= 0.0 {
        0.0
    } else {
        let bytes_moved = n_bytes as f64 * n_iter as f64;
        bytes_moved / seconds / 1e9
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the profiler, the FLOPs estimator and the throughput benchmark.

    use super::*;
    use std::sync::Arc;
    use std::time::Duration;

    /// A single profiled call produces exactly one entry with a non-zero duration.
    #[test]
    fn test_profile_fn_records() {
        let p = Arc::new(HierarchicalProfiler::new());
        let _sum = profile_fn(Arc::clone(&p), "my_fn", || {
            std::thread::sleep(Duration::from_millis(1));
            42u64
        });
        let entries = report_profile(&p);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].function, "my_fn");
        assert_eq!(entries[0].call_count, 1);
        assert!(entries[0].total_time_ns > 0);
    }

    /// Repeated calls under the same name are aggregated into one entry.
    #[test]
    fn test_profile_fn_call_count() {
        let p = Arc::new(HierarchicalProfiler::new());
        for _ in 0..5 {
            profile_fn(Arc::clone(&p), "repeated", || ());
        }
        let entries = report_profile(&p);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].call_count, 5);
    }

    /// The report lists the entry with the larger total time first.
    #[test]
    fn test_report_sorted_by_total() {
        let p = Arc::new(HierarchicalProfiler::new());
        profile_fn(Arc::clone(&p), "fast", || ());
        profile_fn(Arc::clone(&p), "slow", || {
            std::thread::sleep(Duration::from_millis(5));
        });
        let entries = report_profile(&p);
        // Checked before indexing so a missing entry fails with a clear message.
        assert_eq!(entries.len(), 2);
        assert!(entries[0].total_time_ns >= entries[1].total_time_ns);
    }

    /// A nested call's total time is charged to the outer frame as child time,
    /// so the outer frame's self time excludes it.
    #[test]
    fn test_nested_self_time() {
        let p = Arc::new(HierarchicalProfiler::new());
        profile_fn(Arc::clone(&p), "outer", || {
            profile_fn(Arc::clone(&p), "inner", || {
                std::thread::sleep(Duration::from_millis(5));
            });
        });
        let entries = report_profile(&p);
        let o = entries
            .iter()
            .find(|e| e.function == "outer")
            .expect("outer missing");
        let i = entries
            .iter()
            .find(|e| e.function == "inner")
            .expect("inner missing");
        assert!(o.total_time_ns >= i.total_time_ns);
        // The inner frame has no children, so all of its time is self time.
        assert_eq!(i.self_time_ns, i.total_time_ns);
        // The outer frame's only child is the single inner call.
        assert_eq!(o.self_time_ns, o.total_time_ns - i.total_time_ns);
    }

    #[test]
    fn test_flops_matmul() {
        let f = FlopsEstimator::matmul(4, 4, 4);
        assert_eq!(f as u64, 128);
    }

    #[test]
    fn test_flops_elementwise_fma() {
        let f = FlopsEstimator::elementwise(100, UnaryOp::Fma);
        assert_eq!(f as u64, 200);
    }

    #[test]
    fn test_flops_dot_product() {
        let f = FlopsEstimator::dot_product(512);
        assert_eq!(f as u64, 1024);
    }

    #[test]
    fn test_flops_fft_zero() {
        assert_eq!(FlopsEstimator::fft(0) as u64, 0);
    }

    #[test]
    fn test_flops_fft_positive() {
        let f = FlopsEstimator::fft(1024);
        assert!(f > 0.0);
    }

    #[test]
    fn test_throughput_benchmark_positive() {
        let gbps = throughput_benchmark(
            || {
                let _: Vec<u8> = vec![0u8; 4096];
            },
            4096,
            50,
        );
        assert!(gbps >= 0.0);
    }

    #[test]
    fn test_throughput_benchmark_zero_iter() {
        let gbps = throughput_benchmark(|| {}, 1024, 0);
        assert_eq!(gbps, 0.0);
    }

    /// `reset` removes everything that was recorded.
    #[test]
    fn test_profiler_reset() {
        let p = Arc::new(HierarchicalProfiler::new());
        profile_fn(Arc::clone(&p), "to_reset", || ());
        assert!(!report_profile(&p).is_empty());
        p.reset();
        assert!(report_profile(&p).is_empty());
    }

    /// The average is the total divided by the call count.
    #[test]
    fn test_profile_entry_avg() {
        let p = Arc::new(HierarchicalProfiler::new());
        for _ in 0..4 {
            profile_fn(Arc::clone(&p), "avg_test", || {
                std::thread::sleep(Duration::from_millis(1));
            });
        }
        let entries = report_profile(&p);
        let entry = &entries[0];
        assert_eq!(entry.call_count, 4);
        let avg = entry.avg_time_ns();
        assert!(avg > 0.0);
        assert!((avg * 4.0 - entry.total_time_ns as f64).abs() < 1.0);
    }

    #[test]
    fn test_gflops_estimate() {
        let g = FlopsEstimator::gflops(1e9, 1_000_000_000);
        assert!((g - 1.0).abs() < 1e-6);
    }
}