torsh-profiler 0.1.2

Performance profiling and monitoring for ToRSh
Documentation
//! TensorBoard export functionality

use crate::{ProfileEvent, TorshResult};
use std::fs::File;
use std::io::{BufWriter, Write};
use torsh_core::TorshError;

/// Export events to a TensorBoard-style scalar log.
///
/// The file starts with a `#` header comment, followed by one line per metric:
/// `scalar,tag=<category>/<name>/<metric>,step=<step>,value=<value>`.
///
/// For every event a `duration_ms` scalar is written. `gflops`, `data_gb` and
/// `operations` are written when the matching optional field of the event is
/// set; the derived rates `gflops_per_sec` and `bandwidth_gbps` are written
/// only when the duration is non-zero.
///
/// Events are grouped by category. Categories are emitted in sorted order and
/// events keep their input order within a category, so the output (and the
/// `step` counter, which increases by one per event across all categories) is
/// identical for identical input.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` if the file at `path` cannot be
/// created, written to, or flushed.
pub fn export_tensorboard_scalars(events: &[ProfileEvent], path: &str) -> TorshResult<()> {
    let file = File::create(path).map_err(|e| {
        TorshError::InvalidArgument(format!("Failed to create TensorBoard log file {path}: {e}"))
    })?;

    let mut writer = BufWriter::new(file);

    // Write header comment
    writeln!(
        writer,
        "# TensorBoard scalars log generated by torsh-profiler"
    )
    .map_err(|e| TorshError::InvalidArgument(format!("Failed to write header: {e}")))?;

    // Group events by category. A BTreeMap (rather than a HashMap) keeps the
    // category order, and therefore the step numbering, stable across runs.
    let mut categories = std::collections::BTreeMap::new();
    for event in events {
        categories
            .entry(&event.category)
            .or_insert_with(Vec::new)
            .push(event);
    }

    let mut step: usize = 0;
    for (category, category_events) in categories {
        for event in category_events {
            let duration_ms = event.duration_us as f64 / 1000.0;
            let duration_s = duration_ms / 1000.0;
            // Rates are undefined for zero-length events, so they are skipped.
            let has_duration = duration_ms > 0.0;

            write_scalar(
                &mut writer,
                category,
                &event.name,
                "duration_ms",
                step,
                duration_ms,
                "duration",
            )?;

            // FLOPS and FLOPS per second, if available
            if let Some(flops) = event.flops {
                let gflops = flops as f64 / 1_000_000_000.0;
                write_scalar(
                    &mut writer,
                    category,
                    &event.name,
                    "gflops",
                    step,
                    gflops,
                    "FLOPS",
                )?;

                if has_duration {
                    write_scalar(
                        &mut writer,
                        category,
                        &event.name,
                        "gflops_per_sec",
                        step,
                        gflops / duration_s,
                        "FLOPS/sec",
                    )?;
                }
            }

            // Transferred data (GiB: bytes / 2^30) and bandwidth, if available
            if let Some(bytes) = event.bytes_transferred {
                let gb = bytes as f64 / 1_073_741_824.0;
                write_scalar(
                    &mut writer,
                    category,
                    &event.name,
                    "data_gb",
                    step,
                    gb,
                    "data size",
                )?;

                if has_duration {
                    write_scalar(
                        &mut writer,
                        category,
                        &event.name,
                        "bandwidth_gbps",
                        step,
                        gb / duration_s,
                        "bandwidth",
                    )?;
                }
            }

            // Operation count, if available
            if let Some(ops) = event.operation_count {
                write_scalar(
                    &mut writer,
                    category,
                    &event.name,
                    "operations",
                    step,
                    ops,
                    "operations",
                )?;
            }

            step += 1;
        }
    }

    writer.flush().map_err(|e| {
        TorshError::InvalidArgument(format!("Failed to flush TensorBoard writer: {e}"))
    })?;

    Ok(())
}

/// Writes a single `scalar,tag=<category>/<name>/<metric>,step=…,value=…` line.
///
/// `label` is only used to build the error message
/// (`Failed to write <label> scalar: …`) when the write fails.
fn write_scalar<W: Write>(
    writer: &mut W,
    category: &str,
    name: &str,
    metric: &str,
    step: usize,
    value: impl std::fmt::Display,
    label: &str,
) -> TorshResult<()> {
    writeln!(
        writer,
        "scalar,tag={category}/{name}/{metric},step={step},value={value}"
    )
    .map_err(|e| TorshError::InvalidArgument(format!("Failed to write {label} scalar: {e}")))?;

    Ok(())
}

/// Export events to a TensorBoard-style histogram log.
///
/// Events are grouped by `<category>/<name>`; for each group the durations
/// (converted from microseconds to milliseconds) are binned and written as
///
/// ```text
/// histogram,tag=<category>/<name>/duration_ms_hist
/// bin,start=<ms>,end=<ms>,count=<n>
/// ```
///
/// Groups whose durations differ are split into 10 equal-width bins spanning
/// `[min, max]`. Groups in which every duration is identical (including groups
/// with a single event) are written as one zero-width bin holding all samples,
/// so they are not silently missing from the log.
///
/// Groups are emitted in sorted tag order, so the output is identical for
/// identical input.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` if the file at `path` cannot be
/// created, written to, or flushed.
pub fn export_tensorboard_histograms(events: &[ProfileEvent], path: &str) -> TorshResult<()> {
    const NUM_BINS: usize = 10;

    let file = File::create(path).map_err(|e| {
        TorshError::InvalidArgument(format!(
            "Failed to create TensorBoard histogram file {path}: {e}"
        ))
    })?;

    let mut writer = BufWriter::new(file);

    // Write header comment
    writeln!(
        writer,
        "# TensorBoard histograms log generated by torsh-profiler"
    )
    .map_err(|e| TorshError::InvalidArgument(format!("Failed to write header: {e}")))?;

    // Collect durations (ms) per "<category>/<name>" tag. A BTreeMap keeps the
    // tag order stable across runs, unlike a HashMap.
    let mut grouped_durations = std::collections::BTreeMap::new();
    for event in events {
        let key = format!("{}/{}", event.category, event.name);
        grouped_durations
            .entry(key)
            .or_insert_with(Vec::new)
            .push(event.duration_us as f64 / 1000.0);
    }

    // Every group holds at least one duration, because groups are only created
    // by the push above.
    for (name, durations) in grouped_durations {
        let min_val = durations.iter().copied().fold(f64::INFINITY, f64::min);
        let max_val = durations
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);

        writeln!(writer, "histogram,tag={name}/duration_ms_hist").map_err(|e| {
            TorshError::InvalidArgument(format!("Failed to write histogram header: {e}"))
        })?;

        if min_val < max_val {
            let bin_width = (max_val - min_val) / NUM_BINS as f64;
            let mut bins = [0usize; NUM_BINS];

            for &duration in &durations {
                // The maximum lands exactly on the upper edge; clamp it into
                // the last bin.
                let bin_index = ((duration - min_val) / bin_width).floor() as usize;
                bins[bin_index.min(NUM_BINS - 1)] += 1;
            }

            for (i, &count) in bins.iter().enumerate() {
                let bin_start = min_val + i as f64 * bin_width;
                let bin_end = bin_start + bin_width;
                writeln!(
                    writer,
                    "bin,start={bin_start:.3},end={bin_end:.3},count={count}"
                )
                .map_err(|e| {
                    TorshError::InvalidArgument(format!("Failed to write histogram bin: {e}"))
                })?;
            }
        } else {
            // All samples are equal: the bin width would be zero, so emit a
            // single degenerate bin instead of dividing by zero.
            let count = durations.len();
            writeln!(
                writer,
                "bin,start={min_val:.3},end={max_val:.3},count={count}"
            )
            .map_err(|e| {
                TorshError::InvalidArgument(format!("Failed to write histogram bin: {e}"))
            })?;
        }
    }

    writer.flush().map_err(|e| {
        TorshError::InvalidArgument(format!("Failed to flush TensorBoard histogram writer: {e}"))
    })?;

    Ok(())
}

/// Export the full TensorBoard profiling data set for `events`.
///
/// Writes two files derived from `base_path`:
/// `<base_path>_scalars.log` via [`export_tensorboard_scalars`], then
/// `<base_path>_histograms.log` via [`export_tensorboard_histograms`].
///
/// # Errors
///
/// Returns the first error reported by either exporter; if the scalar export
/// fails, the histogram export is not attempted.
pub fn export_tensorboard_profile(events: &[ProfileEvent], base_path: &str) -> TorshResult<()> {
    export_tensorboard_scalars(events, &format!("{base_path}_scalars.log"))?;
    export_tensorboard_histograms(events, &format!("{base_path}_histograms.log"))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Two events in distinct categories with every optional metric populated.
    fn create_test_events() -> Vec<ProfileEvent> {
        vec![
            ProfileEvent {
                name: "matrix_mul".to_string(),
                category: "compute".to_string(),
                start_us: 1000,
                duration_us: 2000,
                thread_id: 1,
                operation_count: Some(1000000),
                flops: Some(2000000000),
                bytes_transferred: Some(4096),
                stack_trace: None,
            },
            ProfileEvent {
                name: "conv2d".to_string(),
                category: "neural_net".to_string(),
                start_us: 3000,
                duration_us: 1500,
                thread_id: 2,
                operation_count: Some(500000),
                flops: Some(1000000000),
                bytes_transferred: Some(2048),
                stack_trace: None,
            },
        ]
    }

    /// Builds a path in the system temp directory that is unique per process,
    /// so concurrent test runs do not overwrite each other's files.
    fn temp_log_path(file_name: &str) -> String {
        std::env::temp_dir()
            .join(format!("{}_{file_name}", std::process::id()))
            .display()
            .to_string()
    }

    /// Reads the exported file and removes it straight away, so the temp file
    /// is cleaned up even when a later assertion fails.
    fn read_and_remove(path: &str) -> String {
        let contents = std::fs::read_to_string(path).expect("export should create the file");
        let _ = std::fs::remove_file(path);
        contents
    }

    #[test]
    fn test_tensorboard_scalars_export() {
        let events = create_test_events();
        let path = temp_log_path("test_tensorboard_scalars.log");

        export_tensorboard_scalars(&events, &path).unwrap();
        let contents = read_and_remove(&path);

        assert!(contents.starts_with("# TensorBoard scalars log generated by torsh-profiler\n"));

        // Both events have a non-zero duration and all optional metrics set,
        // so each one must produce all six scalar lines.
        for event_tag in ["compute/matrix_mul", "neural_net/conv2d"] {
            for metric in [
                "duration_ms",
                "gflops",
                "gflops_per_sec",
                "data_gb",
                "bandwidth_gbps",
                "operations",
            ] {
                let prefix = format!("scalar,tag={event_tag}/{metric},step=");
                assert!(
                    contents.lines().any(|line| line.starts_with(&prefix)),
                    "missing scalar line starting with {prefix}"
                );
            }
        }
        assert_eq!(
            contents
                .lines()
                .filter(|line| line.starts_with("scalar,"))
                .count(),
            12
        );

        // Integer metrics are written verbatim.
        assert!(contents.lines().any(|line| {
            line.starts_with("scalar,tag=compute/matrix_mul/operations,step=")
                && line.ends_with(",value=1000000")
        }));
    }

    #[test]
    fn test_tensorboard_histograms_export() {
        let mut events = create_test_events();
        // A second "compute/matrix_mul" sample with a different duration, so
        // that group spans a non-empty range (2 ms .. 4 ms).
        events.push(ProfileEvent {
            name: "matrix_mul".to_string(),
            category: "compute".to_string(),
            start_us: 5000,
            duration_us: 4000,
            thread_id: 1,
            operation_count: None,
            flops: None,
            bytes_transferred: None,
            stack_trace: None,
        });
        let path = temp_log_path("test_tensorboard_histograms.log");

        export_tensorboard_histograms(&events, &path).unwrap();
        let contents = read_and_remove(&path);

        assert!(contents.starts_with("# TensorBoard histograms log generated by torsh-profiler\n"));
        assert!(contents.contains("histogram,tag=compute/matrix_mul/duration_ms_hist\n"));
        // The minimum falls in the first bin, the maximum in the last one.
        assert!(contents.contains("bin,start=2.000,end=2.200,count=1\n"));
        assert!(contents.contains("bin,start=3.800,end=4.000,count=1\n"));
    }

    #[test]
    fn test_tensorboard_profile_export() {
        let events = create_test_events();
        let base_path = temp_log_path("test_tensorboard_profile");

        export_tensorboard_profile(&events, &base_path).unwrap();

        let scalars = read_and_remove(&format!("{base_path}_scalars.log"));
        let histograms = read_and_remove(&format!("{base_path}_histograms.log"));

        assert!(scalars.starts_with("# TensorBoard scalars log generated by torsh-profiler\n"));
        assert!(scalars.contains("scalar,tag=compute/matrix_mul/duration_ms,step="));
        assert!(histograms.starts_with("# TensorBoard histograms log generated by torsh-profiler\n"));
    }
}