use crate::{ProfileEvent, TorshResult};
use std::fs::File;
use std::io::{BufWriter, Write};
use torsh_core::TorshError;
/// Writes per-event scalar metrics for `events` to a line-oriented text log at `path`.
///
/// Events are grouped by category. Categories are visited in sorted order, so the output is
/// reproducible across runs. Events within a category keep their input order. A single `step`
/// counter is shared by all categories and is incremented once per event. Each event emits:
///
/// - `duration_ms`: always.
/// - `gflops` (flops / 1e9): when `flops` is set. Also `gflops_per_sec` when the duration is
///   non-zero.
/// - `data_gb` (bytes / 2^30): when `bytes_transferred` is set. Also `bandwidth_gbps` (that
///   same figure divided by the duration in seconds) when the duration is non-zero.
/// - `operations`: when `operation_count` is set.
///
/// Every line has the form `scalar,tag={category}/{name}/{metric},step={step},value={value}`.
/// Category and event names are written verbatim. A name containing `,` or `=` will make the
/// line ambiguous for a parser that splits on those characters.
///
/// # Errors
///
/// Returns [`TorshError::InvalidArgument`] if the file cannot be created, or if any write or
/// the final flush fails.
pub fn export_tensorboard_scalars(events: &[ProfileEvent], path: &str) -> TorshResult<()> {
    let file = File::create(path).map_err(|e| {
        TorshError::InvalidArgument(format!("Failed to create TensorBoard log file {path}: {e}"))
    })?;
    let mut writer = BufWriter::new(file);
    writeln!(writer, "# TensorBoard scalars log generated by torsh-profiler")
        .map_err(|e| TorshError::InvalidArgument(format!("Failed to write header: {e}")))?;

    // BTreeMap instead of HashMap: HashMap iteration order is randomized per process. That made
    // both the line order and the step assigned to each event differ between runs.
    let mut categories: std::collections::BTreeMap<&str, Vec<&ProfileEvent>> =
        std::collections::BTreeMap::new();
    for event in events {
        categories
            .entry(event.category.as_str())
            .or_default()
            .push(event);
    }

    let mut step: u64 = 0;
    for (category, category_events) in categories {
        for event in category_events {
            let name = event.name.as_str();
            let duration_ms = event.duration_us as f64 / 1000.0;
            let duration_s = duration_ms / 1000.0;
            write_tensorboard_scalar(
                &mut writer,
                category,
                name,
                "duration_ms",
                step,
                duration_ms,
                "duration scalar",
            )?;

            if let Some(flops) = event.flops {
                let gflops = flops as f64 / 1_000_000_000.0;
                write_tensorboard_scalar(
                    &mut writer,
                    category,
                    name,
                    "gflops",
                    step,
                    gflops,
                    "FLOPS scalar",
                )?;
                // Skip the rate for zero-length events to avoid emitting inf/NaN.
                if duration_ms > 0.0 {
                    write_tensorboard_scalar(
                        &mut writer,
                        category,
                        name,
                        "gflops_per_sec",
                        step,
                        gflops / duration_s,
                        "FLOPS/sec scalar",
                    )?;
                }
            }

            if let Some(bytes) = event.bytes_transferred {
                let gb = bytes as f64 / 1_073_741_824.0;
                write_tensorboard_scalar(
                    &mut writer,
                    category,
                    name,
                    "data_gb",
                    step,
                    gb,
                    "data size scalar",
                )?;
                if duration_ms > 0.0 {
                    write_tensorboard_scalar(
                        &mut writer,
                        category,
                        name,
                        "bandwidth_gbps",
                        step,
                        gb / duration_s,
                        "bandwidth scalar",
                    )?;
                }
            }

            if let Some(ops) = event.operation_count {
                write_tensorboard_scalar(
                    &mut writer,
                    category,
                    name,
                    "operations",
                    step,
                    ops,
                    "operations scalar",
                )?;
            }

            step += 1;
        }
    }

    writer.flush().map_err(|e| {
        TorshError::InvalidArgument(format!("Failed to flush TensorBoard writer: {e}"))
    })?;
    Ok(())
}

/// Writes one `scalar,tag={category}/{name}/{metric},step={step},value={value}` line.
///
/// `what` names the metric in the error message, e.g. `"FLOPS scalar"`.
fn write_tensorboard_scalar(
    writer: &mut impl Write,
    category: &str,
    name: &str,
    metric: &str,
    step: u64,
    value: impl std::fmt::Display,
    what: &str,
) -> TorshResult<()> {
    writeln!(
        writer,
        "scalar,tag={category}/{name}/{metric},step={step},value={value}"
    )
    .map_err(|e| TorshError::InvalidArgument(format!("Failed to write {what}: {e}")))
}
/// Writes a fixed-width duration histogram for every `category/name` group of `events` to `path`.
///
/// Events are grouped by `"{category}/{name}"`, and groups are written in sorted key order so
/// the output is reproducible.
///
/// For each group:
/// - Durations are converted from microseconds to milliseconds.
/// - They are split into `NUM_BINS` (10) equal-width bins over `[min, max]`. The maximum falls
///   into the last bin.
/// - If every duration in the group is equal (including a group with a single event), the
///   group is written as one zero-width bin holding all samples. It is not silently skipped.
///
/// Output per group:
///
/// ```text
/// histogram,tag={category}/{name}/duration_ms_hist
/// bin,start={start:.3},end={end:.3},count={count}
/// ...
/// ```
///
/// # Errors
///
/// Returns [`TorshError::InvalidArgument`] if the file cannot be created, or if any write or
/// the final flush fails.
pub fn export_tensorboard_histograms(events: &[ProfileEvent], path: &str) -> TorshResult<()> {
    const NUM_BINS: usize = 10;

    let file = File::create(path).map_err(|e| {
        TorshError::InvalidArgument(format!(
            "Failed to create TensorBoard histogram file {path}: {e}"
        ))
    })?;
    let mut writer = BufWriter::new(file);
    writeln!(writer, "# TensorBoard histograms log generated by torsh-profiler")
        .map_err(|e| TorshError::InvalidArgument(format!("Failed to write header: {e}")))?;

    // BTreeMap instead of HashMap so groups are emitted in a stable order. Durations (ms) are
    // collected directly, since nothing else about the events is needed.
    let mut grouped_durations: std::collections::BTreeMap<String, Vec<f64>> =
        std::collections::BTreeMap::new();
    for event in events {
        grouped_durations
            .entry(format!("{}/{}", event.category, event.name))
            .or_default()
            .push(event.duration_us as f64 / 1000.0);
    }

    // Every group holds at least one sample: an entry is only created when a value is pushed.
    for (name, durations) in &grouped_durations {
        let min_val = durations.iter().copied().fold(f64::INFINITY, f64::min);
        let max_val = durations.iter().copied().fold(f64::NEG_INFINITY, f64::max);

        writeln!(writer, "histogram,tag={name}/duration_ms_hist").map_err(|e| {
            TorshError::InvalidArgument(format!("Failed to write histogram header: {e}"))
        })?;

        if min_val < max_val {
            let bin_width = (max_val - min_val) / NUM_BINS as f64;
            let mut bins = [0usize; NUM_BINS];
            for &duration in durations {
                // Clamp so the maximum value (offset == NUM_BINS * width) lands in the last bin.
                let bin_index =
                    (((duration - min_val) / bin_width).floor() as usize).min(NUM_BINS - 1);
                bins[bin_index] += 1;
            }
            for (i, &count) in bins.iter().enumerate() {
                let bin_start = min_val + i as f64 * bin_width;
                let bin_end = bin_start + bin_width;
                writeln!(
                    writer,
                    "bin,start={bin_start:.3},end={bin_end:.3},count={count}"
                )
                .map_err(|e| {
                    TorshError::InvalidArgument(format!("Failed to write histogram bin: {e}"))
                })?;
            }
        } else {
            // All samples are identical, so there is no range to split. Emit a single
            // zero-width bin instead of dropping the group.
            let count = durations.len();
            writeln!(
                writer,
                "bin,start={min_val:.3},end={max_val:.3},count={count}"
            )
            .map_err(|e| {
                TorshError::InvalidArgument(format!("Failed to write histogram bin: {e}"))
            })?;
        }
    }

    writer.flush().map_err(|e| {
        TorshError::InvalidArgument(format!("Failed to flush TensorBoard histogram writer: {e}"))
    })?;
    Ok(())
}
/// Exports both the scalar log and the histogram log for `events`.
///
/// Two files are produced next to each other:
/// - `{base_path}_scalars.log`, written by [`export_tensorboard_scalars`].
/// - `{base_path}_histograms.log`, written by [`export_tensorboard_histograms`].
///
/// # Errors
///
/// Propagates the first error from either exporter. The histogram file is not attempted if
/// the scalar export fails.
pub fn export_tensorboard_profile(events: &[ProfileEvent], base_path: &str) -> TorshResult<()> {
    export_tensorboard_scalars(events, &format!("{base_path}_scalars.log"))?;
    export_tensorboard_histograms(events, &format!("{base_path}_histograms.log"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two events with all optional metrics populated and non-zero durations.
    fn create_test_events() -> Vec<ProfileEvent> {
        vec![
            ProfileEvent {
                name: "matrix_mul".to_string(),
                category: "compute".to_string(),
                start_us: 1000,
                duration_us: 2000,
                thread_id: 1,
                operation_count: Some(1000000),
                flops: Some(2000000000),
                bytes_transferred: Some(4096),
                stack_trace: None,
            },
            ProfileEvent {
                name: "conv2d".to_string(),
                category: "neural_net".to_string(),
                start_us: 3000,
                duration_us: 1500,
                thread_id: 2,
                operation_count: Some(500000),
                flops: Some(1000000000),
                bytes_transferred: Some(2048),
                stack_trace: None,
            },
        ]
    }

    /// Temp-dir path prefixed with the process id so concurrent test runs don't share files.
    fn temp_path(file_name: &str) -> String {
        std::env::temp_dir()
            .join(format!("{}_{file_name}", std::process::id()))
            .display()
            .to_string()
    }

    /// Reads `path` and deletes it before returning, so a later failed assertion leaks no file.
    fn read_and_remove(path: &str) -> String {
        let contents = std::fs::read_to_string(path).unwrap();
        let _ = std::fs::remove_file(path);
        contents
    }

    #[test]
    fn test_tensorboard_scalars_export() {
        let events = create_test_events();
        let path = temp_path("test_tensorboard_scalars.log");
        export_tensorboard_scalars(&events, &path).unwrap();
        let contents = read_and_remove(&path);

        assert!(contents.starts_with("# TensorBoard scalars log generated by torsh-profiler"));
        // The trailing comma keeps "gflops," from also matching "gflops_per_sec,".
        for prefix in ["compute/matrix_mul", "neural_net/conv2d"] {
            for metric in [
                "duration_ms",
                "gflops",
                "gflops_per_sec",
                "data_gb",
                "bandwidth_gbps",
                "operations",
            ] {
                let tag = format!("scalar,tag={prefix}/{metric},");
                assert_eq!(contents.matches(&tag).count(), 1, "expected exactly one {tag}");
            }
        }
        // Header + 2 events x 6 metrics.
        assert_eq!(contents.lines().count(), 13);
        // 2000 us == 2 ms exactly.
        assert!(contents.lines().any(|line| {
            line.starts_with("scalar,tag=compute/matrix_mul/duration_ms,")
                && line.ends_with(",value=2")
        }));
    }

    #[test]
    fn test_tensorboard_histograms_export() {
        let events = create_test_events();
        let path = temp_path("test_tensorboard_histograms.log");
        export_tensorboard_histograms(&events, &path).unwrap();
        let contents = read_and_remove(&path);

        assert!(contents.starts_with("# TensorBoard histograms log generated by torsh-profiler"));
    }

    #[test]
    fn test_tensorboard_profile_export() {
        let events = create_test_events();
        let base_path = temp_path("test_tensorboard_profile");
        export_tensorboard_profile(&events, &base_path).unwrap();
        let scalars = read_and_remove(&format!("{base_path}_scalars.log"));
        let histograms = read_and_remove(&format!("{base_path}_histograms.log"));

        assert!(scalars.starts_with("# TensorBoard scalars log generated by torsh-profiler"));
        assert!(histograms.starts_with("# TensorBoard histograms log generated by torsh-profiler"));
    }
}