// entrenar/cli/commands/bench.rs

use std::time::{Duration, Instant};

use crate::cli::logging::log;
use crate::cli::LogLevel;
use crate::config::{BenchArgs, OutputFormat};

8pub fn run_bench(args: BenchArgs, level: LogLevel) -> Result<(), String> {
9 log(level, LogLevel::Normal, &format!("Running benchmark: {}", args.input.display()));
10
11 let batch_sizes: Vec<usize> = args
13 .batch_sizes
14 .split(',')
15 .map(|s| s.trim().parse::<usize>())
16 .collect::<Result<Vec<_>, _>>()
17 .map_err(|e| format!("Invalid batch sizes: {e}"))?;
18
19 log(level, LogLevel::Normal, &format!(" Warmup: {} iterations", args.warmup));
20 log(level, LogLevel::Normal, &format!(" Iterations: {}", args.iterations));
21 log(level, LogLevel::Normal, &format!(" Batch sizes: {batch_sizes:?}"));
22
23 for batch_size in &batch_sizes {
25 log(level, LogLevel::Normal, &format!("\nBatch size: {batch_size}"));
26
27 for _ in 0..args.warmup {
29 std::thread::sleep(std::time::Duration::from_micros(100));
31 }
32
33 let mut latencies: Vec<f64> = Vec::with_capacity(args.iterations);
35 for _ in 0..args.iterations {
36 let start = Instant::now();
37 std::thread::sleep(std::time::Duration::from_micros(50 + *batch_size as u64 * 10));
39 let elapsed = start.elapsed().as_secs_f64() * 1000.0; latencies.push(elapsed);
41 }
42
43 latencies.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
45
46 let p50 = latencies[latencies.len() * 50 / 100];
47 let p95 = latencies[latencies.len() * 95 / 100];
48 let p99 = latencies[latencies.len() * 99 / 100];
49 let mean = latencies.iter().sum::<f64>() / latencies.len().max(1) as f64;
50 let throughput = 1000.0 / mean * *batch_size as f64;
51
52 if args.format == OutputFormat::Json {
53 let result = serde_json::json!({
54 "batch_size": batch_size,
55 "iterations": args.iterations,
56 "latency_ms": {
57 "p50": p50,
58 "p95": p95,
59 "p99": p99,
60 "mean": mean
61 },
62 "throughput_samples_per_sec": throughput
63 });
64 if let Ok(json_str) = serde_json::to_string_pretty(&result) {
65 println!("{json_str}");
66 }
67 } else {
68 log(level, LogLevel::Normal, &format!(" p50: {p50:.2}ms"));
69 log(level, LogLevel::Normal, &format!(" p95: {p95:.2}ms"));
70 log(level, LogLevel::Normal, &format!(" p99: {p99:.2}ms"));
71 log(level, LogLevel::Normal, &format!(" mean: {mean:.2}ms"));
72 log(level, LogLevel::Normal, &format!(" throughput: {throughput:.1} samples/sec"));
73 }
74 }
75
76 Ok(())
77}