use std::cell::RefCell;
use std::collections::HashMap;
use std::io::Write;
use std::path::PathBuf;
use std::sync::{Arc, LazyLock, Mutex};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
/// Identifier shared by every JSON file this process writes: the PID plus the
/// millisecond timestamp of first use (lazily initialized on first access).
static RUN_ID: LazyLock<String> =
    LazyLock::new(|| format!("{}_{}", std::process::id(), timestamp_ms()));
/// Aggregated timing for one instrumented function, produced by `aggregate`.
#[derive(Debug, Clone)]
pub struct FunctionRecord {
    // Function name as passed to `enter`/`register`.
    pub name: String,
    // Number of completed (guard-dropped) calls.
    pub calls: u64,
    // Total wall-clock time across all calls, in ms — includes child time.
    pub total_ms: f64,
    // Time attributed to the function itself: total minus instrumented
    // children, clamped at zero per call.
    pub self_ms: f64,
}
/// One live frame on the thread-local call stack; pushed by `enter`/`adopt`,
/// popped when the corresponding guard drops.
struct StackEntry {
    name: &'static str,
    // When the frame was pushed; elapsed time is computed at guard drop.
    start: Instant,
    // Wall-clock ms accumulated by instrumented children of this frame.
    children_ms: f64,
}
/// One completed call, recorded at guard drop; folded into `FunctionRecord`s
/// by `aggregate`.
struct RawRecord {
    name: &'static str,
    // Total wall-clock duration of the call, in ms.
    elapsed_ms: f64,
    // Portion of `elapsed_ms` spent in instrumented children.
    children_ms: f64,
}
/// Thread-local record buffer whose `Drop` impl writes accumulated data to
/// disk at thread exit, so profiles survive even without an explicit `flush()`.
struct AutoFlushRecords {
    records: Vec<RawRecord>,
}
impl AutoFlushRecords {
    // `const fn` so it can be used in `thread_local!`'s `const { .. }` form.
    const fn new() -> Self {
        Self {
            records: Vec::new(),
        }
    }
}
impl Drop for AutoFlushRecords {
    /// Flushes this thread's profiling data to a timestamped JSON file when
    /// the thread-local is torn down at thread exit.
    fn drop(&mut self) {
        // Tests drive `collect()`/`flush()` explicitly; auto-flushing here
        // would litter the runs directory during `cargo test`.
        if cfg!(test) {
            return;
        }
        // `try_with`: thread-local destruction order is unspecified, so
        // REGISTERED may already be destroyed while RECORDS is being dropped.
        let registered: Vec<&str> = REGISTERED
            .try_with(|reg| reg.borrow().clone())
            .unwrap_or_default();
        if self.records.is_empty() && registered.is_empty() {
            return;
        }
        // Fix: this argument was garbled as `®istered` (mojibake for the
        // two characters `&r` of `&registered`), which does not compile.
        let aggregated = aggregate(&self.records, &registered);
        if aggregated.is_empty() {
            return;
        }
        let Some(dir) = runs_dir() else { return };
        let path = dir.join(format!("{}.json", timestamp_ms()));
        // Best-effort: there is nowhere to report an I/O error at thread exit.
        let _ = write_json(&aggregated, &path);
    }
}
thread_local! {
    // Active call stack of the current thread; one entry per live guard.
    static STACK: RefCell<Vec<StackEntry>> = const { RefCell::new(Vec::new()) };
    // Completed calls; the wrapper's Drop auto-flushes them at thread exit.
    static RECORDS: RefCell<AutoFlushRecords> = const { RefCell::new(AutoFlushRecords::new()) };
    // Names registered up front so uncalled functions still appear in output.
    static REGISTERED: RefCell<Vec<&'static str>> = const { RefCell::new(Vec::new()) };
}
/// RAII timing guard returned by [`enter`]; records the call when dropped.
/// The private field prevents construction outside this module.
#[must_use = "dropping the guard immediately records ~0ms; bind it with `let _guard = ...`"]
pub struct Guard {
    _private: (),
}
impl Drop for Guard {
    /// Pops the matching stack frame, measures its wall-clock duration,
    /// credits that time to the parent frame's children bucket, and appends
    /// a raw record for later aggregation.
    fn drop(&mut self) {
        STACK.with(|cell| {
            let popped = cell.borrow_mut().pop();
            let Some(frame) = popped else {
                eprintln!("piano-runtime: guard dropped without matching stack entry (bug)");
                return;
            };
            let wall_ms = frame.start.elapsed().as_secs_f64() * 1000.0;
            // Separate borrow scope: the pop above has already released its
            // RefMut, so re-borrowing here cannot panic.
            if let Some(parent) = cell.borrow_mut().last_mut() {
                parent.children_ms += wall_ms;
            }
            RECORDS.with(|recs| {
                recs.borrow_mut().records.push(RawRecord {
                    name: frame.name,
                    elapsed_ms: wall_ms,
                    children_ms: frame.children_ms,
                });
            });
        });
    }
}
/// Starts timing `name` on the current thread.  The returned [`Guard`]
/// stops the clock and records the call when it is dropped.
pub fn enter(name: &'static str) -> Guard {
    STACK.with(|cell| {
        let mut frames = cell.borrow_mut();
        frames.push(StackEntry {
            name,
            start: Instant::now(),
            children_ms: 0.0,
        });
    });
    Guard { _private: () }
}
/// Registers `name` so it appears in the output (with zero calls) even if it
/// is never entered.  Duplicate registrations are ignored.
pub fn register(name: &'static str) {
    REGISTERED.with(|cell| {
        let mut names = cell.borrow_mut();
        let already_known = names.iter().any(|&n| n == name);
        if !already_known {
            names.push(name);
        }
    });
}
/// Folds raw per-call records into one [`FunctionRecord`] per name, sorted by
/// self time descending.
///
/// Registered names are seeded first so functions that were never called
/// still appear with zero counts.  Per-call self time is clamped at zero
/// because clock jitter can make `children_ms` slightly exceed `elapsed_ms`.
fn aggregate(raw: &[RawRecord], registered: &[&str]) -> Vec<FunctionRecord> {
    // name -> (calls, total_ms, self_ms)
    let mut map: HashMap<&str, (u64, f64, f64)> = HashMap::new();
    for name in registered {
        map.entry(name).or_insert((0, 0.0, 0.0));
    }
    for rec in raw {
        let entry = map.entry(rec.name).or_insert((0, 0.0, 0.0));
        entry.0 += 1;
        entry.1 += rec.elapsed_ms;
        entry.2 += (rec.elapsed_ms - rec.children_ms).max(0.0);
    }
    let mut result: Vec<FunctionRecord> = map
        .into_iter()
        .map(|(name, (calls, total_ms, self_ms))| FunctionRecord {
            name: name.to_owned(),
            calls,
            total_ms,
            self_ms,
        })
        .collect();
    // `total_cmp` is a total order over f64 (NaN-safe), unlike the previous
    // `partial_cmp(..).unwrap_or(Equal)` which silently treated NaN as equal.
    result.sort_by(|a, b| b.self_ms.total_cmp(&a.self_ms));
    result
}
/// Aggregates every record captured on the current thread into sorted
/// [`FunctionRecord`]s (registered-but-uncalled names included with zero
/// counts).  Does not clear state; see [`reset`] / [`flush`].
pub fn collect() -> Vec<FunctionRecord> {
    RECORDS.with(|records| {
        REGISTERED.with(|reg| {
            // Fix: the second argument was garbled as `®.borrow()` — mojibake
            // for `&reg.borrow()` — which does not compile.
            aggregate(&records.borrow().records, &reg.borrow())
        })
    })
}
/// Clears all thread-local profiler state: the live stack, the captured
/// records, and the registered-name list.
pub fn reset() {
    STACK.with(|s| s.borrow_mut().clear());
    RECORDS.with(|r| r.borrow_mut().records.clear());
    REGISTERED.with(|r| r.borrow_mut().clear());
}
/// Milliseconds since the Unix epoch; 0 if the system clock reads before the
/// epoch (same fallback as `unwrap_or_default` on the duration).
fn timestamp_ms() -> u128 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis(),
        Err(_) => 0,
    }
}
/// Directory where run JSON files are written: `$PIANO_RUNS_DIR` if set,
/// otherwise `~/.piano/runs`; `None` when neither is available.
fn runs_dir() -> Option<PathBuf> {
    match std::env::var("PIANO_RUNS_DIR") {
        Ok(dir) => Some(PathBuf::from(dir)),
        Err(_) => dirs_fallback().map(|home| home.join(".piano").join("runs")),
    }
}
/// Minimal home-directory lookup via `$HOME` (avoids a `dirs` dependency).
fn dirs_fallback() -> Option<PathBuf> {
    let home = std::env::var_os("HOME")?;
    Some(PathBuf::from(home))
}
/// Serializes `records` as one JSON object (`run_id`, `timestamp_ms`,
/// `functions[]`) at `path`, creating parent directories as needed.
/// JSON is emitted by hand to avoid a serde dependency.
fn write_json(records: &[FunctionRecord], path: &std::path::Path) -> std::io::Result<()> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    // Buffer the many small `write!` fragments below; an unbuffered `File`
    // would pay one syscall per fragment.
    let mut f = std::io::BufWriter::new(std::fs::File::create(path)?);
    let ts = timestamp_ms();
    let run_id = &*RUN_ID;
    write!(
        f,
        "{{\"run_id\":\"{run_id}\",\"timestamp_ms\":{ts},\"functions\":["
    )?;
    for (i, rec) in records.iter().enumerate() {
        if i > 0 {
            write!(f, ",")?;
        }
        // Minimal JSON string escaping (backslash and quote); names are
        // Rust identifiers in practice, so control characters do not occur.
        let name = rec.name.replace('\\', "\\\\").replace('"', "\\\"");
        write!(
            f,
            "{{\"name\":\"{name}\",\"calls\":{},\"total_ms\":{:.3},\"self_ms\":{:.3}}}",
            rec.calls, rec.total_ms, rec.self_ms
        )?;
    }
    writeln!(f, "]}}")?;
    // BufWriter's Drop swallows flush errors; flush explicitly so callers
    // see them through the Result.
    f.flush()?;
    Ok(())
}
/// Writes the current thread's aggregated records to a timestamped JSON file
/// in [`runs_dir`], then clears all profiler state.  A no-op when there is
/// nothing recorded or no output directory can be determined.
pub fn flush() {
    let snapshot = collect();
    if snapshot.is_empty() {
        return;
    }
    if let Some(dir) = runs_dir() {
        let target = dir.join(format!("{}.json", timestamp_ms()));
        // Best-effort write; profiling must never crash the host program.
        let _ = write_json(&snapshot, &target);
        reset();
    }
}
/// No-op. NOTE(review): presumably kept so callers have an explicit
/// initialization point / API stability hook — confirm before removing.
pub fn init() {}
/// Handle created on a parent frame by [`fork`]; worker threads [`adopt`] it
/// and their elapsed time accumulates in `children_ms`, which `finalize`
/// later credits back to the parent.
#[must_use = "dropping SpanContext without calling finalize() loses child attribution"]
pub struct SpanContext {
    // Name of the frame that was on top of the stack at fork time.
    parent_name: &'static str,
    // Shared accumulator: each AdoptGuard adds its wall time here.
    children_ms: Arc<Mutex<f64>>,
}
impl SpanContext {
    /// Folds the child time accumulated by adopters back into whatever frame
    /// is currently on top of this thread's stack.  Call on the forking
    /// thread after all adopters have finished.
    pub fn finalize(self) {
        // Recover from poisoning: a panicked adopter should not lose timing.
        let accumulated = *self
            .children_ms
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        STACK.with(|cell| {
            let mut frames = cell.borrow_mut();
            if let Some(current) = frames.last_mut() {
                current.children_ms += accumulated;
            }
        });
    }
}
/// RAII guard returned by [`adopt`]; on drop it adds the adopted span's wall
/// time to the originating [`SpanContext`]'s shared accumulator.
#[must_use = "dropping AdoptGuard immediately records ~0ms; bind it with `let _guard = ...`"]
pub struct AdoptGuard {
    // Shared with the SpanContext that this guard was adopted from.
    ctx_children_ms: Arc<Mutex<f64>>,
}
impl Drop for AdoptGuard {
    /// Pops the adopted frame and adds its wall time to the fork context's
    /// accumulator.  Unlike `Guard::drop`, no `RawRecord` is emitted: the
    /// adopted frame represents the parent's continuation, not a new call.
    fn drop(&mut self) {
        STACK.with(|cell| {
            let Some(frame) = cell.borrow_mut().pop() else { return };
            let wall_ms = frame.start.elapsed().as_secs_f64() * 1000.0;
            // Recover from poisoning so a panicked sibling doesn't lose time.
            let mut acc = self
                .ctx_children_ms
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            *acc += wall_ms;
        });
    }
}
/// Creates a [`SpanContext`] for the frame currently on top of this thread's
/// stack, so other threads can attribute their work to it via [`adopt`].
/// Returns `None` when nothing is being timed.
pub fn fork() -> Option<SpanContext> {
    STACK.with(|cell| {
        let frames = cell.borrow();
        frames.last().map(|top| SpanContext {
            parent_name: top.name,
            children_ms: Arc::new(Mutex::new(0.0)),
        })
    })
}
/// Begins timing on the current thread under the forked parent's name.  The
/// returned [`AdoptGuard`] reports elapsed time back to `ctx` when dropped.
pub fn adopt(ctx: &SpanContext) -> AdoptGuard {
    STACK.with(|cell| {
        let mut frames = cell.borrow_mut();
        frames.push(StackEntry {
            name: ctx.parent_name,
            start: Instant::now(),
            children_ms: 0.0,
        });
    });
    AdoptGuard {
        ctx_children_ms: Arc::clone(&ctx.children_ms),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests.  Each test runs on its own harness thread, so the
    //! thread-local STACK/RECORDS/REGISTERED state is naturally isolated per
    //! test; `reset()` is still called defensively at the start of each.
    use super::*;
    use std::thread;
    use std::time::Duration;
    // Deterministic CPU burner, used instead of sleeps where self-time
    // attribution (not just wall-clock) is what is being checked.
    fn burn_cpu(iterations: u64) {
        let mut buf = [0x42u8; 4096];
        for i in 0..iterations {
            for b in &mut buf {
                *b = b.wrapping_add(i as u8).wrapping_mul(31);
            }
        }
        std::hint::black_box(&buf);
    }
    // Sanity check on the burner itself, so conservation tests below are
    // comparing non-trivial durations.
    #[test]
    fn burn_cpu_takes_measurable_time() {
        let start = std::time::Instant::now();
        burn_cpu(50_000);
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_millis() >= 1,
            "burn_cpu(50_000) should take at least 1ms, took {:?}",
            elapsed
        );
    }
    // End-to-end: enter/drop, flush to a temp dir via PIANO_RUNS_DIR, then
    // check the JSON file exists and contains the expected fields.
    // NOTE(review): env vars are process-global; this races with other tests
    // that set PIANO_RUNS_DIR if they run concurrently — consider a shared
    // lock.
    #[test]
    fn flush_writes_valid_json_to_env_dir() {
        reset();
        {
            let _g = enter("flush_test");
            thread::sleep(Duration::from_millis(5));
        }
        let tmp = std::env::temp_dir().join(format!("piano_test_{}", std::process::id()));
        std::fs::create_dir_all(&tmp).unwrap();
        unsafe { std::env::set_var("PIANO_RUNS_DIR", &tmp) };
        flush();
        unsafe { std::env::remove_var("PIANO_RUNS_DIR") };
        let files: Vec<_> = std::fs::read_dir(&tmp)
            .unwrap()
            .filter_map(|e| e.ok())
            .filter(|e| e.path().extension().is_some_and(|ext| ext == "json"))
            .collect();
        assert!(!files.is_empty(), "expected at least one JSON file");
        let content = std::fs::read_to_string(files[0].path()).unwrap();
        assert!(
            content.contains("\"flush_test\""),
            "should contain function name"
        );
        assert!(
            content.contains("\"timestamp_ms\""),
            "should contain timestamp_ms"
        );
        assert!(content.contains("\"self_ms\""), "should contain self_ms");
        let _ = std::fs::remove_dir_all(&tmp);
    }
    // Checks the hand-rolled JSON serializer's structure on fixed records.
    #[test]
    fn write_json_produces_valid_format() {
        let records = vec![
            FunctionRecord {
                name: "walk".into(),
                calls: 3,
                total_ms: 12.5,
                self_ms: 8.3,
            },
            FunctionRecord {
                name: "resolve".into(),
                calls: 1,
                total_ms: 4.2,
                self_ms: 4.2,
            },
        ];
        let tmp = std::env::temp_dir().join(format!("piano_json_{}.json", std::process::id()));
        write_json(&records, &tmp).unwrap();
        let content = std::fs::read_to_string(&tmp).unwrap();
        assert!(
            content.starts_with("{\"run_id\":\""),
            "should start with run_id"
        );
        assert!(
            content.contains("\"timestamp_ms\":"),
            "should contain timestamp_ms"
        );
        assert!(
            content.contains("\"functions\":["),
            "should have functions array"
        );
        assert!(content.contains("\"walk\""), "should contain walk");
        assert!(content.contains("\"resolve\""), "should contain resolve");
        assert!(content.contains("\"calls\":3"), "should have calls count");
        let _ = std::fs::remove_file(&tmp);
    }
    // init() is a no-op; calling it repeatedly must be harmless.
    #[test]
    fn init_can_be_called_multiple_times() {
        init();
        init();
        init();
    }
    // A single timed call: one record, call count 1, total ≈ self.
    #[test]
    fn single_function_timing() {
        reset();
        {
            let _g = enter("work");
            thread::sleep(Duration::from_millis(10));
        }
        let records = collect();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "work");
        assert_eq!(records[0].calls, 1);
        // Thresholds are half the sleep to tolerate coarse timers/CI jitter.
        assert!(
            records[0].total_ms >= 5.0,
            "total_ms={}",
            records[0].total_ms
        );
        assert!(records[0].self_ms >= 5.0, "self_ms={}", records[0].self_ms);
    }
    // Nested calls: the inner call's time must be subtracted from the outer
    // call's self time, and a leaf's self ≈ total.
    #[test]
    fn nested_function_self_time() {
        reset();
        {
            let _outer = enter("outer");
            thread::sleep(Duration::from_millis(10));
            {
                let _inner = enter("inner");
                thread::sleep(Duration::from_millis(10));
            }
        }
        let records = collect();
        let outer = records
            .iter()
            .find(|r| r.name == "outer")
            .expect("outer not found");
        let inner = records
            .iter()
            .find(|r| r.name == "inner")
            .expect("inner not found");
        assert!(outer.total_ms >= 15.0, "outer.total_ms={}", outer.total_ms);
        assert!(outer.self_ms >= 5.0, "outer.self_ms={}", outer.self_ms);
        assert!(
            outer.self_ms < outer.total_ms,
            "self should be less than total"
        );
        let diff = (inner.self_ms - inner.total_ms).abs();
        assert!(
            diff < 2.0,
            "inner self_ms={} total_ms={}",
            inner.self_ms,
            inner.total_ms
        );
    }
    // Dropping the guard immediately still counts as a call.
    #[test]
    fn call_count_tracking() {
        reset();
        for _ in 0..5 {
            let _g = enter("repeated");
        }
        let records = collect();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "repeated");
        assert_eq!(records[0].calls, 5);
    }
    #[test]
    fn reset_clears_state() {
        reset();
        {
            let _g = enter("something");
            thread::sleep(Duration::from_millis(1));
        }
        reset();
        let records = collect();
        assert!(
            records.is_empty(),
            "expected empty after reset, got {} records",
            records.len()
        );
    }
    // aggregate() sorts descending by self_ms; the 15ms sleep dominates.
    #[test]
    fn collect_sorts_by_self_time_descending() {
        reset();
        {
            let _g = enter("fast");
            thread::sleep(Duration::from_millis(1));
        }
        {
            let _g = enter("slow");
            thread::sleep(Duration::from_millis(15));
        }
        let records = collect();
        assert_eq!(records.len(), 2);
        assert_eq!(
            records[0].name, "slow",
            "expected slow first, got {:?}",
            records[0].name
        );
        assert_eq!(
            records[1].name, "fast",
            "expected fast second, got {:?}",
            records[1].name
        );
    }
    // register() seeds names into the output even with zero recorded calls.
    #[test]
    fn registered_but_uncalled_functions_appear_with_zero_calls() {
        reset();
        register("never_called");
        {
            let _g = enter("called_once");
            thread::sleep(Duration::from_millis(1));
        }
        let records = collect();
        assert_eq!(records.len(), 2, "should have both functions");
        let never = records
            .iter()
            .find(|r| r.name == "never_called")
            .expect("never_called");
        assert_eq!(never.calls, 0);
        assert!((never.total_ms).abs() < f64::EPSILON);
        assert!((never.self_ms).abs() < f64::EPSILON);
        let called = records
            .iter()
            .find(|r| r.name == "called_once")
            .expect("called_once");
        assert_eq!(called.calls, 1);
    }
    // The RUN_ID field must make it into the flushed file.
    #[test]
    fn json_output_contains_run_id() {
        reset();
        {
            let _g = enter("rid_test");
            thread::sleep(Duration::from_millis(1));
        }
        let tmp = std::env::temp_dir().join(format!("piano_rid_{}", std::process::id()));
        std::fs::create_dir_all(&tmp).unwrap();
        unsafe { std::env::set_var("PIANO_RUNS_DIR", &tmp) };
        flush();
        unsafe { std::env::remove_var("PIANO_RUNS_DIR") };
        let files: Vec<_> = std::fs::read_dir(&tmp)
            .unwrap()
            .filter_map(|e| e.ok())
            .filter(|e| e.path().extension().is_some_and(|ext| ext == "json"))
            .collect();
        assert!(!files.is_empty());
        let content = std::fs::read_to_string(files[0].path()).unwrap();
        assert!(
            content.contains("\"run_id\":\""),
            "should contain run_id field: {content}"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }
    // Conservation invariant: the sum of all self times must equal the root's
    // total time (within measurement error).  Checked over several shapes.
    #[test]
    fn conservation_sequential_calls() {
        reset();
        {
            let _main = enter("main_seq");
            burn_cpu(10_000);
            {
                let _a = enter("a");
                burn_cpu(30_000);
            }
            {
                let _b = enter("b");
                burn_cpu(20_000);
            }
        }
        let records = collect();
        let main_r = records.iter().find(|r| r.name == "main_seq").unwrap();
        let a_r = records.iter().find(|r| r.name == "a").unwrap();
        let b_r = records.iter().find(|r| r.name == "b").unwrap();
        let sum_self = main_r.self_ms + a_r.self_ms + b_r.self_ms;
        let root_total = main_r.total_ms;
        let error_pct = ((sum_self - root_total) / root_total).abs() * 100.0;
        assert!(
            error_pct < 5.0,
            "conservation violated: sum_self={sum_self:.3}ms, root_total={root_total:.3}ms, error={error_pct:.1}%"
        );
    }
    #[test]
    fn conservation_nested_calls() {
        reset();
        {
            let _main = enter("main_nest");
            burn_cpu(5_000);
            {
                let _a = enter("a_nest");
                burn_cpu(5_000);
                {
                    let _b = enter("b_nest");
                    burn_cpu(30_000);
                }
            }
        }
        let records = collect();
        let main_r = records.iter().find(|r| r.name == "main_nest").unwrap();
        let a_r = records.iter().find(|r| r.name == "a_nest").unwrap();
        let b_r = records.iter().find(|r| r.name == "b_nest").unwrap();
        let sum_self = main_r.self_ms + a_r.self_ms + b_r.self_ms;
        let root_total = main_r.total_ms;
        let error_pct = ((sum_self - root_total) / root_total).abs() * 100.0;
        assert!(
            error_pct < 5.0,
            "conservation violated: sum_self={sum_self:.3}ms, root_total={root_total:.3}ms, error={error_pct:.1}%"
        );
        let b_diff = (b_r.self_ms - b_r.total_ms).abs();
        assert!(
            b_diff < 0.1,
            "leaf self_ms should equal total_ms: self={:.3}, total={:.3}",
            b_r.self_ms,
            b_r.total_ms
        );
    }
    #[test]
    fn conservation_mixed_topology() {
        reset();
        {
            let _main = enter("main_mix");
            burn_cpu(5_000);
            {
                let _a = enter("a_mix");
                burn_cpu(10_000);
                {
                    let _b = enter("b_mix");
                    burn_cpu(20_000);
                }
            }
            {
                let _c = enter("c_mix");
                burn_cpu(15_000);
            }
        }
        let records = collect();
        let main_r = records.iter().find(|r| r.name == "main_mix").unwrap();
        let sum_self: f64 = records.iter().map(|r| r.self_ms).sum();
        let root_total = main_r.total_ms;
        let error_pct = ((sum_self - root_total) / root_total).abs() * 100.0;
        assert!(
            error_pct < 5.0,
            "conservation violated: sum_self={sum_self:.3}ms, root_total={root_total:.3}ms, error={error_pct:.1}%"
        );
    }
    #[test]
    fn conservation_repeated_calls() {
        reset();
        {
            let _main = enter("main_rep");
            burn_cpu(5_000);
            for _ in 0..10 {
                let _a = enter("a_rep");
                burn_cpu(5_000);
            }
        }
        let records = collect();
        let main_r = records.iter().find(|r| r.name == "main_rep").unwrap();
        let a_r = records.iter().find(|r| r.name == "a_rep").unwrap();
        assert_eq!(a_r.calls, 10);
        let sum_self = main_r.self_ms + a_r.self_ms;
        let root_total = main_r.total_ms;
        let error_pct = ((sum_self - root_total) / root_total).abs() * 100.0;
        assert!(
            error_pct < 5.0,
            "conservation violated: sum_self={sum_self:.3}ms, root_total={root_total:.3}ms, error={error_pct:.1}%"
        );
    }
    // aggregate() clamps per-call self time at zero when clock jitter makes
    // children_ms exceed elapsed_ms.
    #[test]
    fn negative_self_time_clamped_to_zero() {
        let raw = vec![RawRecord {
            name: "drifted",
            elapsed_ms: 10.0,
            children_ms: 10.001,
        }];
        let result = aggregate(&raw, &[]);
        assert_eq!(result.len(), 1);
        assert_eq!(
            result[0].self_ms, 0.0,
            "negative self-time should be clamped to zero"
        );
    }
    // Coarse performance guard on enter()+Drop round trips.
    #[test]
    fn guard_overhead_under_1us() {
        reset();
        let iterations = 1_000_000u64;
        let start = std::time::Instant::now();
        for _ in 0..iterations {
            let _g = enter("overhead");
        }
        let elapsed = start.elapsed();
        let per_call_ns = elapsed.as_nanos() as f64 / iterations as f64;
        eprintln!("guard overhead: {per_call_ns:.1}ns per call ({iterations} iterations)");
        assert!(
            per_call_ns < 1000.0,
            "per-call overhead {per_call_ns:.1}ns exceeds 1us limit"
        );
        reset();
    }
    // Conservation must survive deep stacks; names are leaked to obtain the
    // &'static str that enter() requires (test-only, bounded leak).
    #[test]
    fn deep_nesting_100_levels() {
        reset();
        let names: Vec<&'static str> = (0..100)
            .map(|i| -> &'static str { Box::leak(format!("level_{i}").into_boxed_str()) })
            .collect();
        let mut guards = Vec::with_capacity(100);
        for name in &names {
            guards.push(enter(name));
            burn_cpu(1_000);
        }
        // Drop innermost-first, mirroring normal scope-exit order.
        while let Some(g) = guards.pop() {
            drop(g);
        }
        let records = collect();
        assert_eq!(records.len(), 100, "expected 100 functions");
        let root = records.iter().find(|r| r.name == "level_0").unwrap();
        let sum_self: f64 = records.iter().map(|r| r.self_ms).sum();
        let error_pct = ((sum_self - root.total_ms) / root.total_ms).abs() * 100.0;
        assert!(
            error_pct < 5.0,
            "conservation violated at 100 levels: sum_self={sum_self:.3}ms, root_total={:.3}ms, error={error_pct:.1}%",
            root.total_ms
        );
        for rec in &records {
            assert!(
                rec.self_ms >= 0.0,
                "{} has negative self_ms: {}",
                rec.name,
                rec.self_ms
            );
        }
        let innermost = records.iter().find(|r| r.name == "level_99").unwrap();
        let diff = (innermost.self_ms - innermost.total_ms).abs();
        assert!(
            diff < 0.5,
            "innermost level should have self ≈ total: self={:.3}, total={:.3}",
            innermost.self_ms,
            innermost.total_ms
        );
        reset();
    }
    #[test]
    fn fork_returns_none_with_empty_stack() {
        reset();
        assert!(fork().is_none(), "fork should return None with empty stack");
    }
    // fork/adopt/finalize on a single thread: child time must flow back into
    // the parent's children bucket so conservation still holds.
    #[test]
    fn fork_adopt_propagates_child_time_to_parent() {
        reset();
        {
            let _parent = enter("parent_fn");
            burn_cpu(5_000);
            let ctx = fork().expect("should have parent on stack");
            {
                let _adopt = adopt(&ctx);
                {
                    let _child = enter("child_fn");
                    burn_cpu(20_000);
                }
            }
            ctx.finalize();
        }
        let records = collect();
        let parent = records.iter().find(|r| r.name == "parent_fn").unwrap();
        let child = records.iter().find(|r| r.name == "child_fn").unwrap();
        assert!(
            parent.total_ms > child.total_ms,
            "parent total ({:.1}ms) should exceed child total ({:.1}ms)",
            parent.total_ms,
            child.total_ms
        );
        let sum_self: f64 = records.iter().map(|r| r.self_ms).sum();
        let error_pct = ((sum_self - parent.total_ms) / parent.total_ms).abs() * 100.0;
        assert!(
            error_pct < 10.0,
            "conservation: sum_self={sum_self:.1}ms, root_total={:.1}ms, error={error_pct:.1}%",
            parent.total_ms
        );
    }
    // An adopt with no work inside must not corrupt parent accounting.
    #[test]
    fn adopt_without_child_work_adds_minimal_overhead() {
        reset();
        {
            let _parent = enter("overhead_parent");
            let ctx = fork().unwrap();
            {
                let _adopt = adopt(&ctx);
            }
            ctx.finalize();
        }
        let records = collect();
        let parent = records
            .iter()
            .find(|r| r.name == "overhead_parent")
            .unwrap();
        assert!(parent.calls == 1);
        assert!(parent.total_ms >= 0.0);
    }
    // The shared Mutex accumulator must sum across multiple adoptions.
    #[test]
    fn multiple_children_accumulate_in_parent() {
        reset();
        {
            let _parent = enter("multi_parent");
            burn_cpu(5_000);
            let ctx = fork().unwrap();
            for _ in 0..3 {
                let _adopt = adopt(&ctx);
                {
                    let _child = enter("worker");
                    burn_cpu(10_000);
                }
            }
            ctx.finalize();
        }
        let records = collect();
        let parent = records.iter().find(|r| r.name == "multi_parent").unwrap();
        let worker = records.iter().find(|r| r.name == "worker").unwrap();
        assert_eq!(worker.calls, 3, "should have 3 worker calls");
        assert!(
            parent.total_ms > worker.total_ms,
            "parent total ({:.1}ms) should exceed single worker total ({:.1}ms)",
            parent.total_ms,
            worker.total_ms
        );
    }
    // Cross-thread attribution: work adopted on a spawned thread must reduce
    // the parent's SELF time relative to a baseline that did the same wait
    // inline.  thread::scope lets the closure borrow `ctx`.
    #[test]
    fn cross_thread_fork_adopt_propagates() {
        reset();
        {
            let _parent = enter("baseline");
            burn_cpu(5_000);
            thread::sleep(Duration::from_millis(50));
        }
        let baseline_records = collect();
        let baseline = baseline_records
            .iter()
            .find(|r| r.name == "baseline")
            .unwrap();
        let baseline_self = baseline.self_ms;
        reset();
        {
            let _parent = enter("parent_fn");
            burn_cpu(5_000);
            let ctx = fork().expect("should have parent on stack");
            thread::scope(|s| {
                s.spawn(|| {
                    let _adopt = adopt(&ctx);
                    {
                        let _child = enter("thread_child");
                        thread::sleep(Duration::from_millis(50));
                    }
                });
            });
            ctx.finalize();
        }
        let records = collect();
        let parent = records.iter().find(|r| r.name == "parent_fn").unwrap();
        assert!(
            parent.self_ms < baseline_self,
            "parent self ({:.1}ms) should be less than baseline self ({:.1}ms) \
             due to cross-thread attribution",
            parent.self_ms,
            baseline_self
        );
    }
}