trustformers-core 0.1.1

Core traits and utilities for TrustformeRS
Documentation
//! Custom Benchmarking Framework
//!
//! This module provides a flexible framework for creating and running custom benchmarks
//! for TrustformeRS models and components.

#![allow(dead_code)] // Module under development with reserved features

mod builder;
mod registry;
mod reporter;
mod runner;

pub use builder::{BenchmarkBuilder, BenchmarkDSL, BenchmarkSpec, BenchmarkStage};
pub use registry::{BenchmarkCategory, BenchmarkMetadata, BenchmarkRegistry};
pub use reporter::{BenchmarkReport, ReportFormat, Reporter};
pub use runner::{BenchmarkRunner, BenchmarkRunnerBuilder, RunConfig, RunMode};

use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::time::Duration;

/// Trait for implementing custom benchmarks
pub trait CustomBenchmark: Send + Sync {
    /// Unique name for the benchmark
    fn name(&self) -> &str;

    /// Description of what the benchmark measures
    fn description(&self) -> &str;

    /// Tags for categorizing the benchmark
    fn tags(&self) -> Vec<String> {
        vec![]
    }

    /// Setup phase - run once before iterations
    fn setup(&mut self) -> Result<()> {
        Ok(())
    }

    /// Warmup phase - run before measurement
    fn warmup(&mut self, iterations: usize) -> Result<()> {
        for _ in 0..iterations {
            self.run_iteration()?;
        }
        Ok(())
    }

    /// Run a single iteration of the benchmark
    fn run_iteration(&mut self) -> Result<BenchmarkIteration>;

    /// Teardown phase - cleanup after benchmark
    fn teardown(&mut self) -> Result<()> {
        Ok(())
    }

    /// Validate results for correctness
    fn validate(&self, #[allow(unused_variables)] iteration: &BenchmarkIteration) -> Result<bool> {
        Ok(true)
    }

    /// Get configuration parameters
    fn config(&self) -> BenchmarkConfig {
        BenchmarkConfig::default()
    }
}

/// Result of a single benchmark iteration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkIteration {
    /// Duration of the iteration
    pub duration: Duration,
    /// Custom metrics collected
    pub metrics: BenchmarkMetrics,
    /// Optional validation result
    pub validation_passed: Option<bool>,
    /// Any additional data
    pub metadata: Option<serde_json::Value>,
}

/// Custom metrics that can be collected
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BenchmarkMetrics {
    /// Throughput (items per second)
    pub throughput: Option<f64>,
    /// Latency percentiles
    pub latency_percentiles: Option<LatencyPercentiles>,
    /// Memory usage in bytes
    pub memory_bytes: Option<usize>,
    /// GPU utilization percentage
    pub gpu_utilization: Option<f64>,
    /// Model-specific metrics
    pub model_metrics: Option<ModelMetrics>,
    /// Custom numeric metrics
    pub custom: std::collections::HashMap<String, f64>,
}

/// Latency percentiles
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyPercentiles {
    pub p50: f64,
    pub p90: f64,
    pub p95: f64,
    pub p99: f64,
    pub p999: f64,
    pub min: f64,
    pub max: f64,
}

/// Model-specific metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetrics {
    /// Tokens per second
    pub tokens_per_second: Option<f64>,
    /// FLOPS utilization
    pub flops_utilization: Option<f64>,
    /// Batch size used
    pub batch_size: Option<usize>,
    /// Sequence length
    pub sequence_length: Option<usize>,
}

/// Configuration for benchmark execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkConfig {
    /// Number of warmup iterations
    pub warmup_iterations: usize,
    /// Number of measurement iterations
    pub iterations: usize,
    /// Minimum benchmark duration
    pub min_duration: Duration,
    /// Maximum benchmark duration
    pub max_duration: Duration,
    /// Whether to validate results
    pub validate_results: bool,
    /// Number of threads to use
    pub num_threads: Option<usize>,
    /// Device to run on
    pub device: DeviceConfig,
}

impl Default for BenchmarkConfig {
    fn default() -> Self {
        Self {
            warmup_iterations: 10,
            iterations: 100,
            min_duration: Duration::from_secs(10),
            max_duration: Duration::from_secs(300),
            validate_results: true,
            num_threads: None,
            device: DeviceConfig::default(),
        }
    }
}

/// Device configuration
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum DeviceConfig {
    /// Run on CPU
    #[default]
    Cpu,
    /// Run on specific GPU
    Gpu(usize),
    /// Run on all available GPUs
    AllGpus,
    /// Custom device string
    Custom(String),
}

/// Macro for easily creating custom benchmarks
#[macro_export]
macro_rules! create_benchmark {
    ($name:ident, $description:expr, $run_fn:expr) => {
        pub struct $name {
            config: BenchmarkConfig,
        }

        impl $name {
            pub fn new() -> Self {
                Self {
                    config: BenchmarkConfig::default(),
                }
            }

            pub fn with_config(config: BenchmarkConfig) -> Self {
                Self { config }
            }
        }

        impl CustomBenchmark for $name {
            fn name(&self) -> &str {
                stringify!($name)
            }

            fn description(&self) -> &str {
                $description
            }

            fn run_iteration(&mut self) -> Result<BenchmarkIteration> {
                $run_fn()
            }

            fn config(&self) -> BenchmarkConfig {
                self.config.clone()
            }
        }
    };
}

/// Example custom benchmark implementation
pub struct ExampleBenchmark {
    model_name: String,
    batch_size: usize,
    sequence_length: usize,
}

impl ExampleBenchmark {
    pub fn new(model_name: String, batch_size: usize, sequence_length: usize) -> Self {
        Self {
            model_name,
            batch_size,
            sequence_length,
        }
    }
}

impl CustomBenchmark for ExampleBenchmark {
    fn name(&self) -> &str {
        "example_benchmark"
    }

    fn description(&self) -> &str {
        "Example custom benchmark for demonstration"
    }

    fn tags(&self) -> Vec<String> {
        vec!["example".to_string(), "demo".to_string()]
    }

    fn setup(&mut self) -> Result<()> {
        println!("Setting up benchmark for model: {}", self.model_name);
        Ok(())
    }

    fn run_iteration(&mut self) -> Result<BenchmarkIteration> {
        use std::time::Instant;

        let start = Instant::now();

        // Simulate some work
        std::thread::sleep(Duration::from_millis(10));

        let duration = start.elapsed();

        let metrics = BenchmarkMetrics {
            throughput: Some(self.batch_size as f64 / duration.as_secs_f64()),
            model_metrics: Some(ModelMetrics {
                tokens_per_second: Some(
                    (self.batch_size * self.sequence_length) as f64 / duration.as_secs_f64(),
                ),
                flops_utilization: Some(0.85),
                batch_size: Some(self.batch_size),
                sequence_length: Some(self.sequence_length),
            }),
            ..Default::default()
        };

        Ok(BenchmarkIteration {
            duration,
            metrics,
            validation_passed: Some(true),
            metadata: None,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_example_benchmark() {
        let mut benchmark = ExampleBenchmark::new("test-model".to_string(), 32, 128);

        assert_eq!(benchmark.name(), "example_benchmark");
        assert!(benchmark.setup().is_ok());

        let iteration = benchmark.run_iteration().expect("operation failed in test");
        assert!(iteration.duration.as_millis() >= 10);
        assert!(iteration.metrics.throughput.is_some());
    }

    #[test]
    fn test_benchmark_config() {
        let config = BenchmarkConfig::default();
        assert_eq!(config.warmup_iterations, 10);
        assert_eq!(config.iterations, 100);
        assert!(config.validate_results);
    }
}