use std::io::Write;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use crate::ProfilerReport;
/// A single named, timed zone recorded in a trace.
///
/// The end of the zone is derived from `timestamp_ns + duration_ns`
/// (see `end_timestamp_ns`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracyZone {
/// Name of the zone.
pub name: String,
/// Start of the zone, in nanoseconds. The reference point (epoch) is whatever
/// the producer of the zone uses; this type does not define one.
pub timestamp_ns: u64,
/// Length of the zone, in nanoseconds.
pub duration_ns: u64,
/// Thread identifier attached to the zone.
/// NOTE(review): presumably the thread that recorded the zone — confirm with producers.
pub thread_id: u32,
}
impl TracyZone {
    /// Returns the timestamp, in nanoseconds, at which this zone ends.
    ///
    /// This is the start timestamp plus the duration. If that sum does not fit
    /// in a `u64`, the result is clamped to `u64::MAX` instead of overflowing.
    pub fn end_timestamp_ns(&self) -> u64 {
        let start = self.timestamp_ns;
        let length = self.duration_ns;
        // For unsigned integers, "checked add or MAX" is exactly a saturating add.
        start.checked_add(length).unwrap_or(u64::MAX)
    }
}
/// In-memory collection of trace records: zones, messages, and plot samples.
///
/// Records are kept in insertion order within each category.
#[derive(Debug, Default)]
pub struct TracyTrace {
/// Recorded zones.
zones: Vec<TracyZone>,
/// Recorded messages as `(text, timestamp_ns)` pairs.
messages: Vec<(String, u64)>,
/// Recorded plot samples as `(name, value, timestamp_ns)` triples.
plots: Vec<(String, f64, u64)>,
}
impl TracyTrace {
    /// Creates an empty trace.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` when the trace holds no zones, messages, or plot samples.
    pub fn is_empty(&self) -> bool {
        self.zones.is_empty() && self.messages.is_empty() && self.plots.is_empty()
    }

    /// Returns the combined number of zones, messages, and plot samples.
    pub fn total_records(&self) -> usize {
        self.zones.len() + self.messages.len() + self.plots.len()
    }

    /// Appends a zone to the trace.
    pub fn add_zone(&mut self, zone: TracyZone) {
        self.zones.push(zone);
    }

    /// Appends a message with the given text and timestamp (nanoseconds).
    pub fn add_message(&mut self, msg: &str, timestamp_ns: u64) {
        self.messages.push((msg.to_string(), timestamp_ns));
    }

    /// Appends a plot sample for the series `name` at the given timestamp (nanoseconds).
    pub fn add_plot(&mut self, name: &str, value: f64, timestamp_ns: u64) {
        self.plots.push((name.to_string(), value, timestamp_ns));
    }

    /// Returns the recorded zones in insertion order.
    pub fn zones(&self) -> &[TracyZone] {
        &self.zones
    }

    /// Returns the recorded messages as `(text, timestamp_ns)` pairs.
    pub fn messages(&self) -> &[(String, u64)] {
        &self.messages
    }

    /// Returns the recorded plot samples as `(name, value, timestamp_ns)` triples.
    pub fn plots(&self) -> &[(String, f64, u64)] {
        &self.plots
    }

    /// Writes the trace to `path` as comma-separated text, creating the file or
    /// truncating an existing one.
    ///
    /// Rows are written in this order:
    /// - one `#` header comment line;
    /// - per zone: `ZoneBegin,<name>,<name>,0,<timestamp_ns>` followed by
    ///   `ZoneEnd,<end_timestamp_ns>`;
    /// - per message: `Message,<text>,<timestamp_ns>`;
    /// - per plot sample: `Plot,<name>,<value>,<timestamp_ns>`.
    ///
    /// Commas inside zone names, message text, and plot names are written as
    /// `\,` so they cannot be mistaken for field separators. A zone's
    /// `thread_id` is not part of the output.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created or if any write or the
    /// final flush fails.
    pub fn export_to_file(&self, path: &std::path::Path) -> Result<()> {
        // Buffer the many small row writes; an unbuffered `File` would issue a
        // system call per `writeln!`.
        let mut out = std::io::BufWriter::new(std::fs::File::create(path)?);

        writeln!(out, "# TracyTrace export — generated by trustformers-debug")?;

        for zone in &self.zones {
            let name = Self::escape_field(&zone.name);
            writeln!(out, "ZoneBegin,{},{},0,{}", name, name, zone.timestamp_ns)?;
            writeln!(out, "ZoneEnd,{}", zone.end_timestamp_ns())?;
        }

        for (msg, ts) in &self.messages {
            writeln!(out, "Message,{},{}", Self::escape_field(msg), ts)?;
        }

        for (name, value, ts) in &self.plots {
            writeln!(out, "Plot,{},{},{}", Self::escape_field(name), value, ts)?;
        }

        // Flush explicitly: dropping a `BufWriter` flushes too, but discards any error.
        out.flush()?;

        tracing::debug!("Tracy trace written to {}", path.display());
        Ok(())
    }

    /// Escapes field separators in `field` by rewriting each `,` as `\,`.
    ///
    /// Borrows the input unchanged when it contains no comma, so the common
    /// case does not allocate.
    fn escape_field(field: &str) -> std::borrow::Cow<'_, str> {
        if field.contains(',') {
            std::borrow::Cow::Owned(field.replace(',', "\\,"))
        } else {
            std::borrow::Cow::Borrowed(field)
        }
    }
}
/// Stateless entry point for converting profiler output into a [`TracyTrace`]
/// export file.
pub struct TracyExporter;

impl TracyExporter {
    /// Converts `report` into a [`TracyTrace`] and writes it to `path`.
    ///
    /// The trace is assembled as follows:
    /// - every entry of `report.slowest_layers` becomes a zone on thread `0`;
    ///   zones are laid out back to back starting at timestamp `0`, each one
    ///   beginning where the previous one ended;
    /// - every entry of `report.recommendations` becomes a message, stamped at
    ///   `index * 1_000` nanoseconds;
    /// - every entry of `report.statistics` becomes one plot sample holding the
    ///   average duration in whole microseconds, stamped at the end of the last
    ///   zone.
    ///
    /// All timestamp arithmetic saturates at `u64::MAX` rather than wrapping or
    /// panicking on overflow.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`TracyTrace::export_to_file`] reports.
    pub fn export_profiler_report(
        report: &ProfilerReport,
        path: &std::path::Path,
    ) -> Result<()> {
        let mut trace = TracyTrace::new();
        let mut cursor_ns: u64 = 0;

        for (layer_name, duration) in &report.slowest_layers {
            // `as_nanos` yields a `u128`; clamp first so the narrowing cast is
            // lossless instead of silently dropping the high bits.
            let dur_ns = duration.as_nanos().min(u128::from(u64::MAX)) as u64;
            trace.add_zone(TracyZone {
                name: layer_name.clone(),
                timestamp_ns: cursor_ns,
                duration_ns: dur_ns,
                thread_id: 0,
            });
            cursor_ns = cursor_ns.saturating_add(dur_ns);
        }

        for (idx, rec) in report.recommendations.iter().enumerate() {
            trace.add_message(rec.as_str(), (idx as u64).saturating_mul(1_000));
        }

        for (name, stats) in &report.statistics {
            let avg_us = stats.avg_duration.as_micros() as f64;
            trace.add_plot(name, avg_us, cursor_ns);
        }

        trace.export_to_file(path)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Builds the path of `file_name` inside the system temporary directory.
    fn temp_file(file_name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(file_name)
    }

    /// Shorthand constructor for a zone used by several tests.
    fn zone(name: &str, timestamp_ns: u64, duration_ns: u64, thread_id: u32) -> TracyZone {
        TracyZone {
            name: name.to_string(),
            timestamp_ns,
            duration_ns,
            thread_id,
        }
    }

    /// Asserts that every fragment in `fragments` occurs somewhere in `content`.
    fn assert_contains_all(content: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(content.contains(*fragment));
        }
    }

    // A freshly created trace holds nothing.
    #[test]
    fn test_empty_trace() {
        let empty = TracyTrace::new();
        assert_eq!(empty.total_records(), 0);
        assert!(empty.is_empty());
    }

    // A pushed zone is retrievable and its end timestamp is start + duration.
    #[test]
    fn test_add_zone() {
        let mut trace = TracyTrace::new();
        trace.add_zone(zone("attention", 1000, 500, 0));

        let recorded = trace.zones();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].end_timestamp_ns(), 1500);
    }

    // Messages and plot samples are counted separately and together.
    #[test]
    fn test_add_message_and_plot() {
        let mut trace = TracyTrace::new();
        trace.add_message("hello", 42);
        trace.add_plot("loss", 0.5, 100);

        assert_eq!(trace.messages().len(), 1);
        assert_eq!(trace.plots().len(), 1);
        assert_eq!(trace.total_records(), 2);
    }

    // Exporting writes one row family per record kind.
    #[test]
    fn test_export_to_file() {
        let out_path = temp_file("tracy_test_trace.csv");

        let mut trace = TracyTrace::new();
        trace.add_zone(zone("ffn", 0, 2_000_000, 1));
        trace.add_message("epoch start", 0);
        trace.add_plot("loss", 0.42, 2_000_000);
        trace.export_to_file(&out_path).unwrap();
        assert!(out_path.exists());

        let written = std::fs::read_to_string(&out_path).unwrap();
        assert_contains_all(
            &written,
            &[
                "ZoneBegin,ffn",
                "ZoneEnd,2000000",
                "Message,epoch start,0",
                "Plot,loss,0.42,2000000",
            ],
        );

        // Best-effort cleanup; a leftover temp file is not a test failure.
        std::fs::remove_file(&out_path).ok();
    }

    // A profiler report is turned into zones and messages in the export.
    #[test]
    fn test_exporter_from_profiler_report() {
        use crate::profiler::MemoryEfficiencyAnalysis;

        let out_path = temp_file("tracy_profiler_report.csv");
        let report = ProfilerReport {
            total_events: 3,
            total_runtime: Duration::from_millis(50),
            statistics: HashMap::new(),
            bottlenecks: vec![],
            slowest_layers: vec![
                ("attn".to_string(), Duration::from_millis(20)),
                ("ffn".to_string(), Duration::from_millis(30)),
            ],
            memory_efficiency: MemoryEfficiencyAnalysis::default(),
            recommendations: vec!["Use flash attention".to_string()],
        };

        TracyExporter::export_profiler_report(&report, &out_path).unwrap();
        assert!(out_path.exists());

        let written = std::fs::read_to_string(&out_path).unwrap();
        assert_contains_all(
            &written,
            &[
                "ZoneBegin,attn",
                "ZoneBegin,ffn",
                "Message,Use flash attention",
            ],
        );

        // Best-effort cleanup; a leftover temp file is not a test failure.
        std::fs::remove_file(&out_path).ok();
    }

    // Commas inside message text are escaped with a backslash on export.
    #[test]
    fn test_comma_escaping_in_message() {
        let out_path = temp_file("tracy_comma_test.csv");

        let mut trace = TracyTrace::new();
        trace.add_message("loss, accuracy: 0.9", 0);
        trace.export_to_file(&out_path).unwrap();

        let written = std::fs::read_to_string(&out_path).unwrap();
        assert!(written.contains("loss\\, accuracy: 0.9"));

        // Best-effort cleanup; a leftover temp file is not a test failure.
        std::fs::remove_file(&out_path).ok();
    }
}