use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::{Mutex, OnceLock};
/// Environment variable naming the file the process-wide trace is written to.
/// Tracing is enabled only when it is set to a non-empty value.
const TRACE_OUT_ENV: &str = "FERRUM_TRACE_OUT";
/// Lazily initialized process-wide writer; access it through [`global_trace`].
static GLOBAL_TRACE: OnceLock<TraceWriter> = OnceLock::new();
/// Returns the process-wide trace writer, configuring it from the environment on first use.
///
/// Rust never runs destructors for `static` items, so the `Drop` flush of [`TraceWriter`]
/// does not fire for this instance: call [`flush_global_trace`] before the process exits.
pub fn global_trace() -> &'static TraceWriter {
    GLOBAL_TRACE.get_or_init(TraceWriter::from_env)
}
/// Writes the process-wide trace to disk if [`global_trace`] has ever been called.
///
/// This never initializes the writer. A write failure is reported on stderr rather than
/// propagated or silently dropped: tracing is diagnostics and must not break the caller,
/// but a user who asked for a trace should learn why it is missing.
pub fn flush_global_trace() {
    if let Some(writer) = GLOBAL_TRACE.get() {
        if let Err(err) = writer.flush() {
            use std::io::Write as _;
            // `writeln!` + ignored result instead of `eprintln!`, which panics if stderr
            // itself cannot be written.
            let _ = writeln!(
                std::io::stderr(),
                "warning: failed to write {TRACE_OUT_ENV} trace: {err}"
            );
        }
    }
}
/// One trace event. Field names mirror the Chrome Trace Event Format
/// (`name`, `cat`, `ph`, `ts`, `dur`, `pid`, `tid`, `args`).
///
/// `ts` and `dur` are expected in microseconds; [`TraceEvent::complete`] converts a
/// millisecond duration into that unit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceEvent {
    /// Human-readable event name.
    pub name: String,
    /// Category label used to group related events.
    pub cat: String,
    /// Phase character; events built by [`TraceEvent::complete`] use `'X'`.
    pub ph: char,
    /// Start timestamp, in microseconds.
    pub ts: u64,
    /// Duration, in microseconds.
    pub dur: u64,
    /// Process id (`0` for events built by [`TraceEvent::complete`]).
    pub pid: u32,
    /// Thread/track id supplied by the caller.
    pub tid: u32,
    /// Extra key/value payload. Left out of the JSON when empty and defaulted to an empty
    /// map when absent on input.
    #[serde(default)]
    #[serde(skip_serializing_if = "serde_json::Map::is_empty")]
    pub args: serde_json::Map<String, serde_json::Value>,
}
impl TraceEvent {
    /// Builds a "complete" event (`ph == 'X'`).
    ///
    /// * `start_ts_us` — start timestamp in microseconds, stored verbatim in `ts`.
    /// * `dur_ms` — duration in milliseconds, rounded to whole microseconds for `dur`.
    ///   Float-to-integer `as` casts saturate, so negative or NaN durations become `0`.
    /// * `tid` — thread/track id. `pid` is always `0` and `args` starts empty.
    pub fn complete(
        name: impl Into<String>,
        cat: impl Into<String>,
        start_ts_us: u64,
        dur_ms: f64,
        tid: u32,
    ) -> Self {
        let name: String = name.into();
        let cat: String = cat.into();
        let dur_us = (dur_ms * 1000.0).round() as u64;
        Self {
            name,
            cat,
            ph: 'X',
            ts: start_ts_us,
            dur: dur_us,
            pid: 0,
            tid,
            args: serde_json::Map::new(),
        }
    }
}
/// Collects [`TraceEvent`]s in memory and writes them to a JSON file on `flush`.
///
/// All state sits behind a `Mutex`, so a shared `&TraceWriter` (such as the one stored in
/// the `GLOBAL_TRACE` static) can be used from several threads.
pub struct TraceWriter {
inner: Mutex<TraceWriterInner>,
}
/// Internal state of a [`TraceWriter`].
enum TraceWriterInner {
/// Tracing is off: pushes are ignored and `flush` writes nothing.
Disabled,
/// Tracing is on.
Buffering {
/// File that `flush` writes the JSON array of events to.
out_path: PathBuf,
/// In-memory events, serialized to `out_path` by `flush`.
events: Vec<TraceEvent>,
/// Reference instant; event timestamps are microseconds elapsed since it.
epoch: std::time::Instant,
},
}
impl TraceWriter {
pub fn from_env() -> Self {
Self::from_env_vars(std::env::vars())
}
pub fn from_env_vars<I, K, V>(vars: I) -> Self
where
I: IntoIterator<Item = (K, V)>,
K: Into<String>,
V: Into<String>,
{
let out_path = vars.into_iter().find_map(|(name, value)| {
(name.into() == TRACE_OUT_ENV)
.then(|| value.into())
.filter(|value: &String| !value.is_empty())
});
out_path
.map(|path| Self::enabled(PathBuf::from(path)))
.unwrap_or_else(Self::disabled)
}
pub fn enabled(out_path: PathBuf) -> Self {
Self {
inner: Mutex::new(TraceWriterInner::Buffering {
out_path,
events: Vec::with_capacity(1024),
epoch: std::time::Instant::now(),
}),
}
}
pub fn disabled() -> Self {
Self {
inner: Mutex::new(TraceWriterInner::Disabled),
}
}
pub fn is_enabled(&self) -> bool {
matches!(
*self.inner.lock().unwrap(),
TraceWriterInner::Buffering { .. }
)
}
pub fn push(&self, name: impl Into<String>, cat: impl Into<String>, dur_ms: f64, tid: u32) {
let mut inner = self.inner.lock().unwrap();
if let TraceWriterInner::Buffering { events, epoch, .. } = &mut *inner {
let now = std::time::Instant::now();
let ts_us = now.duration_since(*epoch).as_micros() as u64;
let start_us = ts_us.saturating_sub((dur_ms * 1000.0) as u64);
events.push(TraceEvent::complete(name, cat, start_us, dur_ms, tid));
}
}
pub fn push_with_args(
&self,
name: impl Into<String>,
cat: impl Into<String>,
dur_ms: f64,
tid: u32,
args: serde_json::Map<String, serde_json::Value>,
) {
let mut inner = self.inner.lock().unwrap();
if let TraceWriterInner::Buffering { events, epoch, .. } = &mut *inner {
let now = std::time::Instant::now();
let ts_us = now.duration_since(*epoch).as_micros() as u64;
let start_us = ts_us.saturating_sub((dur_ms * 1000.0) as u64);
let mut e = TraceEvent::complete(name, cat, start_us, dur_ms, tid);
e.args = args;
events.push(e);
}
}
pub fn flush(&self) -> std::io::Result<()> {
let mut inner = self.inner.lock().unwrap();
if let TraceWriterInner::Buffering {
out_path, events, ..
} = &mut *inner
{
let json = serde_json::to_string(&events).expect("serialize trace");
std::fs::write(out_path, json)?;
events.clear();
}
Ok(())
}
}
impl Drop for TraceWriter {
    /// Best-effort final flush so a locally owned writer does not lose its events.
    ///
    /// `drop` cannot return an error, so a failure is reported on stderr instead of being
    /// silently ignored, and reporting never panics. Destructors do not run for `static`
    /// items, so the process-wide writer needs an explicit `flush_global_trace()` call.
    fn drop(&mut self) {
        if let Err(err) = self.flush() {
            use std::io::Write as _;
            // `writeln!` + ignored result instead of `eprintln!`, which panics if stderr is
            // unwritable; panicking in `drop` during unwinding would abort the process.
            let _ = writeln!(
                std::io::stderr(),
                "warning: failed to write trace file: {err}"
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn complete_event_round_trip() {
        let e = TraceEvent::complete("rms_norm", "norm", 1_000_000, 0.123, 1);
        assert_eq!(e.ph, 'X');
        // 0.123 ms rounds to 123 µs.
        assert_eq!(e.dur, 123);
        let j = serde_json::to_string(&e).unwrap();
        let back: TraceEvent = serde_json::from_str(&j).unwrap();
        assert_eq!(back.name, "rms_norm");
        assert_eq!(back.dur, 123);
    }

    #[test]
    fn disabled_writer_is_noop() {
        let w = TraceWriter::disabled();
        w.push("rms_norm", "norm", 1.0, 0);
        assert!(!w.is_enabled());
        w.flush().unwrap();
    }

    #[test]
    fn trace_writer_parses_env_snapshot() {
        let disabled = TraceWriter::from_env_vars([(TRACE_OUT_ENV, ""), ("OTHER", "1")]);
        assert!(!disabled.is_enabled());

        // Enabled writers flush to their path when dropped, so point them at a private temp
        // dir rather than a shared location such as /tmp.
        let dir = tempdir();
        let out_path = dir.join("trace.json");
        let out = out_path.to_str().expect("temp dir path is valid UTF-8");

        let enabled = TraceWriter::from_env_vars([(TRACE_OUT_ENV, out)]);
        assert!(enabled.is_enabled());

        // An empty value is skipped in favour of a later non-empty one.
        let later = TraceWriter::from_env_vars([
            (TRACE_OUT_ENV, ""),
            ("OTHER", "1"),
            (TRACE_OUT_ENV, out),
        ]);
        assert!(later.is_enabled());

        drop(enabled);
        drop(later);
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn enabled_writer_flushes_to_file() {
        let dir = tempdir();
        let path = dir.join("trace.json");
        let w = TraceWriter::enabled(path.clone());
        w.push("rms_norm", "norm", 1.0, 1);
        w.push("rope", "attn", 0.5, 1);
        w.flush().unwrap();
        let s = std::fs::read_to_string(&path).unwrap();
        let events: Vec<TraceEvent> = serde_json::from_str(&s).unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].name, "rms_norm");
        assert_eq!(events[1].cat, "attn");
        // Drop the writer (running its final flush) while its directory still exists.
        drop(w);
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Creates a fresh, uniquely named directory under the system temp dir.
    fn tempdir() -> std::path::PathBuf {
        use std::sync::atomic::{AtomicUsize, Ordering};
        // Process id + per-process counter keep concurrently running tests apart even when
        // the clock's resolution is too coarse to distinguish them.
        static NEXT: AtomicUsize = AtomicUsize::new(0);
        let d = std::env::temp_dir().join(format!(
            "ferrum-trace-test-{}-{}-{}",
            std::process::id(),
            NEXT.fetch_add(1, Ordering::Relaxed),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        std::fs::create_dir_all(&d).unwrap();
        d
    }
}