1use serde::{Deserialize, Serialize};
32use std::path::PathBuf;
33use std::sync::{Mutex, OnceLock};
34
35const TRACE_OUT_ENV: &str = "FERRUM_TRACE_OUT";
36
37static GLOBAL_TRACE: OnceLock<TraceWriter> = OnceLock::new();
46
47pub fn global_trace() -> &'static TraceWriter {
50 GLOBAL_TRACE.get_or_init(TraceWriter::from_env)
51}
52
53pub fn flush_global_trace() {
56 if let Some(w) = GLOBAL_TRACE.get() {
57 let _ = w.flush();
58 }
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct TraceEvent {
64 pub name: String,
65 pub cat: String,
66 pub ph: char, pub ts: u64,
69 pub dur: u64,
71 pub pid: u32,
72 pub tid: u32,
73 #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
75 pub args: serde_json::Map<String, serde_json::Value>,
76}
77
78impl TraceEvent {
79 pub fn complete(
83 name: impl Into<String>,
84 cat: impl Into<String>,
85 start_ts_us: u64,
86 dur_ms: f64,
87 tid: u32,
88 ) -> Self {
89 Self {
90 name: name.into(),
91 cat: cat.into(),
92 ph: 'X',
93 ts: start_ts_us,
94 dur: (dur_ms * 1000.0).round() as u64,
95 pid: 0,
96 tid,
97 args: serde_json::Map::new(),
98 }
99 }
100}
101
102pub struct TraceWriter {
110 inner: Mutex<TraceWriterInner>,
111}
112
113enum TraceWriterInner {
114 Disabled,
115 Buffering {
116 out_path: PathBuf,
117 events: Vec<TraceEvent>,
118 epoch: std::time::Instant,
119 },
120}
121
122impl TraceWriter {
123 pub fn from_env() -> Self {
126 Self::from_env_vars(std::env::vars())
127 }
128
129 pub fn from_env_vars<I, K, V>(vars: I) -> Self
130 where
131 I: IntoIterator<Item = (K, V)>,
132 K: Into<String>,
133 V: Into<String>,
134 {
135 let out_path = vars.into_iter().find_map(|(name, value)| {
136 (name.into() == TRACE_OUT_ENV)
137 .then(|| value.into())
138 .filter(|value: &String| !value.is_empty())
139 });
140 out_path
141 .map(|path| Self::enabled(PathBuf::from(path)))
142 .unwrap_or_else(Self::disabled)
143 }
144
145 pub fn enabled(out_path: PathBuf) -> Self {
146 Self {
147 inner: Mutex::new(TraceWriterInner::Buffering {
148 out_path,
149 events: Vec::with_capacity(1024),
150 epoch: std::time::Instant::now(),
151 }),
152 }
153 }
154
155 pub fn disabled() -> Self {
156 Self {
157 inner: Mutex::new(TraceWriterInner::Disabled),
158 }
159 }
160
161 pub fn is_enabled(&self) -> bool {
164 matches!(
165 *self.inner.lock().unwrap(),
166 TraceWriterInner::Buffering { .. }
167 )
168 }
169
170 pub fn push(&self, name: impl Into<String>, cat: impl Into<String>, dur_ms: f64, tid: u32) {
173 let mut inner = self.inner.lock().unwrap();
174 if let TraceWriterInner::Buffering { events, epoch, .. } = &mut *inner {
175 let now = std::time::Instant::now();
176 let ts_us = now.duration_since(*epoch).as_micros() as u64;
177 let start_us = ts_us.saturating_sub((dur_ms * 1000.0) as u64);
180 events.push(TraceEvent::complete(name, cat, start_us, dur_ms, tid));
181 }
182 }
183
184 pub fn push_with_args(
186 &self,
187 name: impl Into<String>,
188 cat: impl Into<String>,
189 dur_ms: f64,
190 tid: u32,
191 args: serde_json::Map<String, serde_json::Value>,
192 ) {
193 let mut inner = self.inner.lock().unwrap();
194 if let TraceWriterInner::Buffering { events, epoch, .. } = &mut *inner {
195 let now = std::time::Instant::now();
196 let ts_us = now.duration_since(*epoch).as_micros() as u64;
197 let start_us = ts_us.saturating_sub((dur_ms * 1000.0) as u64);
198 let mut e = TraceEvent::complete(name, cat, start_us, dur_ms, tid);
199 e.args = args;
200 events.push(e);
201 }
202 }
203
204 pub fn flush(&self) -> std::io::Result<()> {
208 let mut inner = self.inner.lock().unwrap();
209 if let TraceWriterInner::Buffering {
210 out_path, events, ..
211 } = &mut *inner
212 {
213 let json = serde_json::to_string(&events).expect("serialize trace");
214 std::fs::write(out_path, json)?;
215 events.clear();
216 }
217 Ok(())
218 }
219}
220
221impl Drop for TraceWriter {
222 fn drop(&mut self) {
223 let _ = self.flush();
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn complete_event_round_trip() {
        let e = TraceEvent::complete("rms_norm", "norm", 1_000_000, 0.123, 1);
        assert_eq!(e.ph, 'X');
        assert_eq!(e.dur, 123);
        let j = serde_json::to_string(&e).unwrap();
        let back: TraceEvent = serde_json::from_str(&j).unwrap();
        assert_eq!(back.name, "rms_norm");
        assert_eq!(back.dur, 123);
    }

    #[test]
    fn disabled_writer_is_noop() {
        let w = TraceWriter::disabled();
        w.push("rms_norm", "norm", 1.0, 0);
        assert!(!w.is_enabled());
        w.flush().unwrap();
    }

    #[test]
    fn trace_writer_parses_env_snapshot() {
        let disabled = TraceWriter::from_env_vars([(TRACE_OUT_ENV, ""), ("OTHER", "1")]);
        assert!(!disabled.is_enabled());

        // Dropping an enabled writer flushes it. A fixed path such as
        // /tmp/ferrum-trace.json would clobber a real trace (and does not
        // exist on every platform), so point the writer into a private dir.
        let dir = tempdir();
        let path = dir.join("trace.json").to_string_lossy().into_owned();
        let enabled = TraceWriter::from_env_vars([(TRACE_OUT_ENV, path)]);
        assert!(enabled.is_enabled());
        drop(enabled);
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn enabled_writer_flushes_to_file() {
        let dir = tempdir();
        let path = dir.join("trace.json");
        let w = TraceWriter::enabled(path.clone());
        w.push("rms_norm", "norm", 1.0, 1);
        w.push("rope", "attn", 0.5, 1);
        w.flush().unwrap();
        let s = std::fs::read_to_string(&path).unwrap();
        let events: Vec<TraceEvent> = serde_json::from_str(&s).unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].name, "rms_norm");
        assert_eq!(events[1].cat, "attn");
        // Let the writer's final flush in `Drop` run before the directory is
        // removed.
        drop(w);
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Creates a fresh, uniquely named directory under the system temp dir.
    fn tempdir() -> std::path::PathBuf {
        use std::sync::atomic::{AtomicUsize, Ordering};

        // Tests run in parallel, possibly in several processes at once. The
        // pid and a per-process counter keep names unique even when two
        // calls see the same clock reading.
        static NEXT: AtomicUsize = AtomicUsize::new(0);
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let d = std::env::temp_dir().join(format!(
            "ferrum-trace-test-{}-{}-{}",
            std::process::id(),
            nanos,
            NEXT.fetch_add(1, Ordering::Relaxed)
        ));
        std::fs::create_dir_all(&d).unwrap();
        d
    }
}