use std::cell::RefCell;
use std::sync::Arc;
use std::time::Instant;
/// Receiver for timing events produced by [`profile_op_scope`].
///
/// A profiler is installed per thread with [`set_current`] and shared as an
/// `Arc<dyn OpProfiler>`. The `Send + Sync` bound lets a single instance be
/// installed on several threads at the same time.
pub trait OpProfiler: Send + Sync {
    /// Called once after each profiled operation returns.
    ///
    /// * `name`, `category` — the labels given to [`profile_op_scope`].
    /// * `shapes` — the shape slices given to [`profile_op_scope`], passed through
    ///   unchanged (presumably one slice per operand — confirm with callers).
    /// * `duration_us` — time measured around the operation, in microseconds.
    ///
    /// Takes `&self`, so an implementation that accumulates state needs interior
    /// mutability (e.g. a `Mutex`).
    fn record_op(&self, name: &str, category: &str, shapes: &[&[usize]], duration_us: u64);
}
thread_local! {
    /// Per-thread slot holding the profiler that `profile_op_scope` reports to
    /// (read through `current()`, written through `set_current()`).
    ///
    /// `None`, the initial value, means profiling is disabled on this thread.
    /// Every thread has its own slot, so installing a profiler on one thread
    /// does not affect any other thread, including threads it spawns later.
    /// The `const { .. }` initializer allows std to use a TLS implementation
    /// without a lazy-initialization check.
    static CURRENT: RefCell<Option<Arc<dyn OpProfiler>>> = const { RefCell::new(None) };
}
/// Installs `profiler` as this thread's profiler, or disables profiling with `None`.
///
/// The setting is thread-local: other threads, including ones spawned from this
/// one afterwards, are unaffected. This thread's handle to any previously
/// installed profiler is dropped.
///
/// # Panics
///
/// Panics if called after this thread's thread-local storage has been torn down.
pub fn set_current(profiler: Option<Arc<dyn OpProfiler>>) {
    // Swap the new value in and let the `RefCell` borrow end *before* the previous
    // profiler is dropped. Assigning through `borrow_mut()` dropped the old value
    // while the mutable borrow was still live. A `Drop` impl that re-entered
    // `current()` / `set_current()` (e.g. to flush through a profiled op) then
    // panicked with a `BorrowError`.
    let previous = CURRENT.with(|slot| slot.replace(profiler));
    drop(previous);
}
/// Returns the profiler installed on the current thread, if any.
///
/// The result is a clone of the thread-local `Arc`. The caller can therefore keep
/// using it even if [`set_current`] later replaces or clears the installed profiler.
///
/// Returns `None` when no profiler is installed. It also returns `None`, instead of
/// panicking, when this thread's thread-local storage has already been destroyed
/// (for example when called from another thread-local's destructor during thread exit).
pub fn current() -> Option<Arc<dyn OpProfiler>> {
    CURRENT
        .try_with(|slot| slot.borrow().clone())
        .ok()
        .flatten()
}
/// Runs `f` and, if a profiler is installed on this thread, reports how long it took.
///
/// The profiler is looked up once via [`current`] *before* `f` runs. That same
/// profiler receives exactly one [`OpProfiler::record_op`] call after `f` returns,
/// even if `f` installs or clears a profiler in the meantime. When no profiler is
/// installed, `f` is simply called and the clock is never read.
///
/// * `name`, `category`, `shapes` — forwarded unchanged to `record_op`.
/// * `f` — the operation to run; its return value is handed back to the caller.
///
/// Nested scopes each record their own event. The inner event is recorded first,
/// and the outer duration includes the inner one. If `f` panics, nothing is
/// recorded and the panic propagates. The reported duration saturates at
/// `u64::MAX` microseconds instead of wrapping.
pub fn profile_op_scope<F, R>(name: &str, category: &str, shapes: &[&[usize]], f: F) -> R
where
    F: FnOnce() -> R,
{
    match current() {
        None => f(),
        Some(profiler) => {
            let start = Instant::now();
            let result = f();
            // `as_micros()` returns a `u128`. A plain `as u64` cast would silently keep
            // only the low 64 bits, so saturate instead of wrapping.
            let elapsed_us = u64::try_from(start.elapsed().as_micros()).unwrap_or(u64::MAX);
            profiler.record_op(name, category, shapes, elapsed_us);
            result
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Owned copy of one `record_op` call: (name, category, shapes, duration_us).
    type Event = (String, String, Vec<Vec<usize>>, u64);

    /// Profiler that remembers every event it receives, in arrival order.
    #[derive(Default)]
    struct TestProfiler {
        events: Mutex<Vec<Event>>,
    }

    impl OpProfiler for TestProfiler {
        fn record_op(&self, name: &str, category: &str, shapes: &[&[usize]], duration_us: u64) {
            let event = (
                String::from(name),
                String::from(category),
                shapes.iter().map(|dims| dims.to_vec()).collect(),
                duration_us,
            );
            self.events.lock().unwrap().push(event);
        }
    }

    /// Runs `body` on a brand-new thread, so it starts with an empty thread-local
    /// profiler slot, and re-raises any panic from it in the calling test.
    fn on_fresh_thread(body: impl FnOnce() + Send + 'static) {
        std::thread::spawn(body).join().unwrap();
    }

    /// Installs a new `TestProfiler` on the calling thread and returns a typed handle.
    fn install_test_profiler() -> Arc<TestProfiler> {
        let profiler = Arc::new(TestProfiler::default());
        set_current(Some(Arc::clone(&profiler) as Arc<dyn OpProfiler>));
        profiler
    }

    #[test]
    fn test_no_profiler_active_by_default() {
        on_fresh_thread(|| assert!(current().is_none()));
    }

    #[test]
    fn test_profile_op_scope_no_profiler_runs_closure() {
        on_fresh_thread(|| {
            assert_eq!(profile_op_scope("test_op", "test", &[], || 42), 42);
            assert!(current().is_none());
        });
    }

    #[test]
    fn test_profile_op_scope_records_when_active() {
        on_fresh_thread(|| {
            let profiler = install_test_profiler();
            let outcome = profile_op_scope("matmul", "linalg", &[&[2, 3], &[3, 4]], || {
                std::thread::sleep(std::time::Duration::from_micros(1));
                "ok"
            });
            assert_eq!(outcome, "ok");
            set_current(None);

            let events = profiler.events.lock().unwrap();
            assert_eq!(events.len(), 1);
            let (name, category, shapes, _duration_us) = &events[0];
            assert_eq!(name, "matmul");
            assert_eq!(category, "linalg");
            assert_eq!(*shapes, vec![vec![2, 3], vec![3, 4]]);
        });
    }

    #[test]
    fn test_set_current_can_be_cleared() {
        on_fresh_thread(|| {
            let _handle = install_test_profiler();
            assert!(current().is_some());
            set_current(None);
            assert!(current().is_none());
        });
    }

    #[test]
    fn test_nested_profile_op_scope_records_inner_op() {
        on_fresh_thread(|| {
            let profiler = install_test_profiler();
            profile_op_scope("outer", "test", &[&[2, 2]], || {
                profile_op_scope("inner", "test", &[&[2, 2]], || {});
            });
            set_current(None);

            let recorded: Vec<String> = profiler
                .events
                .lock()
                .unwrap()
                .iter()
                .map(|(name, ..)| name.clone())
                .collect();
            // The inner scope finishes first, so its event arrives first.
            assert_eq!(recorded, ["inner", "outer"]);
        });
    }

    #[test]
    fn test_thread_local_isolation() {
        on_fresh_thread(|| {
            let _handle = install_test_profiler();
            assert!(current().is_some());
            // A child thread must not inherit its parent's profiler.
            on_fresh_thread(|| assert!(current().is_none()));
            set_current(None);
        });
    }
}