oxibonsai_eval/throughput.rs
1//! Throughput benchmarking for LLM inference.
2//!
3//! [`ThroughputBenchmark`] accumulates timing information from repeated
4//! generation runs and produces a [`ThroughputResult`] with token-per-second
5//! statistics, latency breakdowns, and percentile metrics.
6
7use std::time::{Duration, Instant};
8
9use serde::Serialize;
10
11// ──────────────────────────────────────────────────────────────────────────────
12// Timing helpers
13// ──────────────────────────────────────────────────────────────────────────────
14
/// Run `f` once and report how long the call took.
///
/// Returns the closure's own return value paired with the wall-clock time that
/// elapsed between just before the call and just after it returned, measured
/// with [`Instant`].
pub fn time_fn<F, R>(f: F) -> (R, Duration)
where
    F: FnOnce() -> R,
{
    let clock = Instant::now();
    let output = f();
    // `output` is already computed, so the clock is read only after `f` returns.
    (output, clock.elapsed())
}
25
/// Compute the `p`-th percentile of `values`, with `p` clamped to `0.0..=100.0`.
///
/// The vector is taken by value and sorted internally, so the caller's data is
/// never reordered. The fractional rank is `p / 100 * (len - 1)`; when it falls
/// between two elements the result is linearly interpolated between them, and
/// when it lands exactly on an element that element is returned unchanged.
///
/// Elements are ordered with [`f32::total_cmp`], which is a total order even in
/// the presence of NaN (positive NaNs sort after `+inf`, negative NaNs before
/// `-inf`).
///
/// # Returns
///
/// * `0.0` when `values` is empty.
/// * NaN when `p` is NaN (and `values` is non-empty).
/// * Otherwise the (possibly interpolated) percentile value.
pub fn percentile(mut values: Vec<f32>, p: f32) -> f32 {
    if values.is_empty() {
        return 0.0;
    }
    if p.is_nan() {
        // `clamp` propagates NaN; make the resulting NaN explicit instead of
        // letting it fall out of the interpolation arithmetic.
        return f32::NAN;
    }

    // `partial_cmp` is not a total order once NaN is involved; `total_cmp` is.
    values.sort_unstable_by(|a, b| a.total_cmp(b));

    let last = values.len() - 1;
    let rank = p.clamp(0.0, 100.0) / 100.0 * last as f32;
    // `last as f32` can round upwards for very large inputs, so keep the index
    // inside the vector.
    let lo = (rank.floor() as usize).min(last);
    let frac = rank - lo as f32;

    // When the rank sits exactly on an element, return it directly: multiplying
    // a non-finite neighbour by a zero weight would otherwise produce NaN.
    if lo == last || frac == 0.0 {
        return values[lo];
    }

    values[lo] * (1.0 - frac) + values[lo + 1] * frac
}
42
43// ──────────────────────────────────────────────────────────────────────────────
44// ThroughputResult
45// ──────────────────────────────────────────────────────────────────────────────
46
47/// Statistics from a completed throughput benchmark.
48#[derive(Debug, Serialize)]
49pub struct ThroughputResult {
50 /// Mean tokens per second across all measurement runs.
51 pub tokens_per_second: f32,
52 /// Mean prefill latency in milliseconds.
53 pub prefill_ms: f32,
54 /// Mean per-token decode latency in milliseconds.
55 pub decode_ms_per_token: f32,
56 /// Total number of tokens generated across all runs.
57 pub total_tokens: usize,
58 /// Number of measurement runs performed.
59 pub runs: usize,
60 /// Minimum tokens per second observed.
61 pub min_tps: f32,
62 /// Maximum tokens per second observed.
63 pub max_tps: f32,
64 /// 50th-percentile tokens per second.
65 pub p50_tps: f32,
66 /// 95th-percentile tokens per second.
67 pub p95_tps: f32,
68}
69
70impl ThroughputResult {
71 /// One-line human-readable summary of throughput statistics.
72 pub fn summary(&self) -> String {
73 format!(
74 "Throughput: {:.1} t/s (p50: {:.1}, p95: {:.1})",
75 self.tokens_per_second, self.p50_tps, self.p95_tps
76 )
77 }
78
79 /// Return `true` if the mean throughput meets or exceeds `target_tps`.
80 pub fn meets_target(&self, target_tps: f32) -> bool {
81 self.tokens_per_second >= target_tps
82 }
83
84 /// Human-readable latency breakdown string.
85 pub fn latency_breakdown(&self) -> String {
86 format!(
87 "Prefill: {:.2} ms | Decode: {:.3} ms/token",
88 self.prefill_ms, self.decode_ms_per_token
89 )
90 }
91}
92
93// ──────────────────────────────────────────────────────────────────────────────
94// ThroughputBenchmark
95// ──────────────────────────────────────────────────────────────────────────────
96
/// Configuration for a throughput benchmark.
///
/// The benchmark does not drive the model itself: timing data is supplied by
/// the caller, which keeps this crate decoupled from the inference engine.
///
/// The type is plain configuration data, so it can be debug-printed, cloned,
/// and compared for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThroughputBenchmark {
    /// Number of warm-up runs (results discarded).
    pub warmup_runs: usize,
    /// Number of measurement runs (results aggregated).
    pub measurement_runs: usize,
    /// The prompt used for benchmarking.
    pub prompt: String,
    /// Maximum tokens to generate per run.
    pub max_tokens: usize,
}
111
112impl ThroughputBenchmark {
113 /// Create a benchmark with 3 warm-up runs and 10 measurement runs.
114 pub fn new(prompt: &str, max_tokens: usize) -> Self {
115 Self {
116 warmup_runs: 3,
117 measurement_runs: 10,
118 prompt: prompt.to_string(),
119 max_tokens,
120 }
121 }
122
123 /// Override the number of warm-up runs.
124 pub fn with_warmup(mut self, warmup: usize) -> Self {
125 self.warmup_runs = warmup;
126 self
127 }
128
129 /// Override the number of measurement runs.
130 pub fn with_runs(mut self, runs: usize) -> Self {
131 self.measurement_runs = runs;
132 self
133 }
134
135 /// Run the benchmark using caller-supplied timing data.
136 ///
137 /// `run_timings` is a slice of `(prefill_ms, decode_ms, tokens_generated)` tuples,
138 /// one per measurement run (warm-up timings should already be excluded by the caller).
139 ///
140 /// This method computes aggregate statistics from the provided data without
141 /// calling the inference engine itself, allowing flexible integration.
142 pub fn from_timings(&self, run_timings: &[(f32, f32, usize)]) -> ThroughputResult {
143 if run_timings.is_empty() {
144 return ThroughputResult {
145 tokens_per_second: 0.0,
146 prefill_ms: 0.0,
147 decode_ms_per_token: 0.0,
148 total_tokens: 0,
149 runs: 0,
150 min_tps: 0.0,
151 max_tps: 0.0,
152 p50_tps: 0.0,
153 p95_tps: 0.0,
154 };
155 }
156
157 let n = run_timings.len() as f32;
158 let mut tps_values: Vec<f32> = Vec::with_capacity(run_timings.len());
159 let mut total_prefill_ms = 0.0f32;
160 let mut total_decode_ms = 0.0f32;
161 let mut total_tokens = 0usize;
162
163 for &(prefill_ms, decode_ms, tokens) in run_timings {
164 total_prefill_ms += prefill_ms;
165 total_decode_ms += decode_ms;
166 total_tokens += tokens;
167
168 let total_ms = prefill_ms + decode_ms;
169 let tps = if total_ms > 0.0 {
170 tokens as f32 / (total_ms / 1000.0)
171 } else {
172 0.0
173 };
174 tps_values.push(tps);
175 }
176
177 let mean_tps = tps_values.iter().copied().sum::<f32>() / n;
178 let min_tps = tps_values.iter().cloned().fold(f32::INFINITY, f32::min);
179 let max_tps = tps_values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
180 let p50_tps = percentile(tps_values.clone(), 50.0);
181 let p95_tps = percentile(tps_values, 95.0);
182
183 let mean_prefill_ms = total_prefill_ms / n;
184 let mean_decode_ms_per_token = if total_tokens > 0 {
185 total_decode_ms / total_tokens as f32
186 } else {
187 0.0
188 };
189
190 ThroughputResult {
191 tokens_per_second: mean_tps,
192 prefill_ms: mean_prefill_ms,
193 decode_ms_per_token: mean_decode_ms_per_token,
194 total_tokens,
195 runs: run_timings.len(),
196 min_tps,
197 max_tps,
198 p50_tps,
199 p95_tps,
200 }
201 }
202}