Skip to main content

entrenar/cli/commands/
bench.rs

1//! Bench command implementation
2
3use crate::cli::logging::log;
4use crate::cli::LogLevel;
5use crate::config::{BenchArgs, OutputFormat};
6use std::time::Instant;
7
8pub fn run_bench(args: BenchArgs, level: LogLevel) -> Result<(), String> {
9    log(level, LogLevel::Normal, &format!("Running benchmark: {}", args.input.display()));
10
11    // Parse batch sizes
12    let batch_sizes: Vec<usize> = args
13        .batch_sizes
14        .split(',')
15        .map(|s| s.trim().parse::<usize>())
16        .collect::<Result<Vec<_>, _>>()
17        .map_err(|e| format!("Invalid batch sizes: {e}"))?;
18
19    log(level, LogLevel::Normal, &format!("  Warmup: {} iterations", args.warmup));
20    log(level, LogLevel::Normal, &format!("  Iterations: {}", args.iterations));
21    log(level, LogLevel::Normal, &format!("  Batch sizes: {batch_sizes:?}"));
22
23    // Run benchmarks for each batch size
24    for batch_size in &batch_sizes {
25        log(level, LogLevel::Normal, &format!("\nBatch size: {batch_size}"));
26
27        // Warmup
28        for _ in 0..args.warmup {
29            // Simulate inference with small sleep
30            std::thread::sleep(std::time::Duration::from_micros(100));
31        }
32
33        // Measure latency
34        let mut latencies: Vec<f64> = Vec::with_capacity(args.iterations);
35        for _ in 0..args.iterations {
36            let start = Instant::now();
37            // Simulate inference - in real impl would run model forward pass
38            std::thread::sleep(std::time::Duration::from_micros(50 + *batch_size as u64 * 10));
39            let elapsed = start.elapsed().as_secs_f64() * 1000.0; // ms
40            latencies.push(elapsed);
41        }
42
43        // Sort for percentile calculation
44        latencies.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
45
46        let p50 = latencies[latencies.len() * 50 / 100];
47        let p95 = latencies[latencies.len() * 95 / 100];
48        let p99 = latencies[latencies.len() * 99 / 100];
49        let mean = latencies.iter().sum::<f64>() / latencies.len().max(1) as f64;
50        let throughput = 1000.0 / mean * *batch_size as f64;
51
52        if args.format == OutputFormat::Json {
53            let result = serde_json::json!({
54                "batch_size": batch_size,
55                "iterations": args.iterations,
56                "latency_ms": {
57                    "p50": p50,
58                    "p95": p95,
59                    "p99": p99,
60                    "mean": mean
61                },
62                "throughput_samples_per_sec": throughput
63            });
64            if let Ok(json_str) = serde_json::to_string_pretty(&result) {
65                println!("{json_str}");
66            }
67        } else {
68            log(level, LogLevel::Normal, &format!("  p50: {p50:.2}ms"));
69            log(level, LogLevel::Normal, &format!("  p95: {p95:.2}ms"));
70            log(level, LogLevel::Normal, &format!("  p99: {p99:.2}ms"));
71            log(level, LogLevel::Normal, &format!("  mean: {mean:.2}ms"));
72            log(level, LogLevel::Normal, &format!("  throughput: {throughput:.1} samples/sec"));
73        }
74    }
75
76    Ok(())
77}