Skip to main content

oxibonsai_eval/
throughput.rs

1//! Throughput benchmarking for LLM inference.
2//!
3//! [`ThroughputBenchmark`] accumulates timing information from repeated
4//! generation runs and produces a [`ThroughputResult`] with token-per-second
5//! statistics, latency breakdowns, and percentile metrics.
6
7use std::time::{Duration, Instant};
8
9use serde::Serialize;
10
11// ──────────────────────────────────────────────────────────────────────────────
12// Timing helpers
13// ──────────────────────────────────────────────────────────────────────────────
14
/// Run `f` exactly once and report how long the call took.
///
/// Returns `(value, elapsed)`, where `value` is whatever `f` produced and
/// `elapsed` is the wall-clock time measured with [`Instant`] around the call.
pub fn time_fn<F, R>(f: F) -> (R, Duration)
where
    F: FnOnce() -> R,
{
    let timer = Instant::now();
    let output = f();
    (output, timer.elapsed())
}
25
/// Compute the `p`-th percentile of `values` (`0.0 ≤ p ≤ 100.0`).
///
/// `values` is consumed and sorted internally; `p` is clamped to `[0, 100]`.
/// The rank is `p / 100 * (len - 1)`. When that rank is fractional, the result
/// is linearly interpolated between the two neighbouring elements. When it is
/// an integer, the element at that rank is returned as-is.
///
/// Returns `0.0` for an empty vector. A `NaN` `p` yields `NaN`.
///
/// Sorting uses [`f32::total_cmp`], so `NaN` entries get a deterministic
/// position: positive `NaN` sorts after `+inf` and negative `NaN` before `-inf`.
pub fn percentile(mut values: Vec<f32>, p: f32) -> f32 {
    if values.is_empty() {
        return 0.0;
    }
    if p.is_nan() {
        return f32::NAN;
    }

    // `total_cmp` is a genuine total order. A `partial_cmp` fallback to `Equal`
    // is not, and std's sort routines may panic on inconsistent comparators.
    values.sort_unstable_by(f32::total_cmp);

    let last = values.len() - 1;
    let index = p.clamp(0.0, 100.0) / 100.0 * last as f32;
    // `index <= last` for clamped `p`; `min` is purely defensive.
    let lo = (index.floor() as usize).min(last);
    let frac = index - lo as f32;

    if frac == 0.0 || lo == last {
        // Exact rank: return the element directly, so an infinite neighbour
        // cannot turn `inf * 0.0` into NaN.
        return values[lo];
    }

    let hi = lo + 1;
    values[lo] * (1.0 - frac) + values[hi] * frac
}
42
43// ──────────────────────────────────────────────────────────────────────────────
44// ThroughputResult
45// ──────────────────────────────────────────────────────────────────────────────
46
47/// Statistics from a completed throughput benchmark.
48#[derive(Debug, Serialize)]
49pub struct ThroughputResult {
50    /// Mean tokens per second across all measurement runs.
51    pub tokens_per_second: f32,
52    /// Mean prefill latency in milliseconds.
53    pub prefill_ms: f32,
54    /// Mean per-token decode latency in milliseconds.
55    pub decode_ms_per_token: f32,
56    /// Total number of tokens generated across all runs.
57    pub total_tokens: usize,
58    /// Number of measurement runs performed.
59    pub runs: usize,
60    /// Minimum tokens per second observed.
61    pub min_tps: f32,
62    /// Maximum tokens per second observed.
63    pub max_tps: f32,
64    /// 50th-percentile tokens per second.
65    pub p50_tps: f32,
66    /// 95th-percentile tokens per second.
67    pub p95_tps: f32,
68}
69
70impl ThroughputResult {
71    /// One-line human-readable summary of throughput statistics.
72    pub fn summary(&self) -> String {
73        format!(
74            "Throughput: {:.1} t/s (p50: {:.1}, p95: {:.1})",
75            self.tokens_per_second, self.p50_tps, self.p95_tps
76        )
77    }
78
79    /// Return `true` if the mean throughput meets or exceeds `target_tps`.
80    pub fn meets_target(&self, target_tps: f32) -> bool {
81        self.tokens_per_second >= target_tps
82    }
83
84    /// Human-readable latency breakdown string.
85    pub fn latency_breakdown(&self) -> String {
86        format!(
87            "Prefill: {:.2} ms | Decode: {:.3} ms/token",
88            self.prefill_ms, self.decode_ms_per_token
89        )
90    }
91}
92
93// ──────────────────────────────────────────────────────────────────────────────
94// ThroughputBenchmark
95// ──────────────────────────────────────────────────────────────────────────────
96
/// Builder for throughput benchmark runs.
///
/// Aggregates timing data supplied by the caller (see `from_timings`) rather
/// than running the model directly, keeping this crate decoupled from the
/// inference engine.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThroughputBenchmark {
    /// Number of warm-up runs (results discarded).
    pub warmup_runs: usize,
    /// Number of measurement runs (results aggregated).
    pub measurement_runs: usize,
    /// The prompt used for benchmarking.
    pub prompt: String,
    /// Maximum tokens to generate per run.
    pub max_tokens: usize,
}
111
112impl ThroughputBenchmark {
113    /// Create a benchmark with 3 warm-up runs and 10 measurement runs.
114    pub fn new(prompt: &str, max_tokens: usize) -> Self {
115        Self {
116            warmup_runs: 3,
117            measurement_runs: 10,
118            prompt: prompt.to_string(),
119            max_tokens,
120        }
121    }
122
123    /// Override the number of warm-up runs.
124    pub fn with_warmup(mut self, warmup: usize) -> Self {
125        self.warmup_runs = warmup;
126        self
127    }
128
129    /// Override the number of measurement runs.
130    pub fn with_runs(mut self, runs: usize) -> Self {
131        self.measurement_runs = runs;
132        self
133    }
134
135    /// Run the benchmark using caller-supplied timing data.
136    ///
137    /// `run_timings` is a slice of `(prefill_ms, decode_ms, tokens_generated)` tuples,
138    /// one per measurement run (warm-up timings should already be excluded by the caller).
139    ///
140    /// This method computes aggregate statistics from the provided data without
141    /// calling the inference engine itself, allowing flexible integration.
142    pub fn from_timings(&self, run_timings: &[(f32, f32, usize)]) -> ThroughputResult {
143        if run_timings.is_empty() {
144            return ThroughputResult {
145                tokens_per_second: 0.0,
146                prefill_ms: 0.0,
147                decode_ms_per_token: 0.0,
148                total_tokens: 0,
149                runs: 0,
150                min_tps: 0.0,
151                max_tps: 0.0,
152                p50_tps: 0.0,
153                p95_tps: 0.0,
154            };
155        }
156
157        let n = run_timings.len() as f32;
158        let mut tps_values: Vec<f32> = Vec::with_capacity(run_timings.len());
159        let mut total_prefill_ms = 0.0f32;
160        let mut total_decode_ms = 0.0f32;
161        let mut total_tokens = 0usize;
162
163        for &(prefill_ms, decode_ms, tokens) in run_timings {
164            total_prefill_ms += prefill_ms;
165            total_decode_ms += decode_ms;
166            total_tokens += tokens;
167
168            let total_ms = prefill_ms + decode_ms;
169            let tps = if total_ms > 0.0 {
170                tokens as f32 / (total_ms / 1000.0)
171            } else {
172                0.0
173            };
174            tps_values.push(tps);
175        }
176
177        let mean_tps = tps_values.iter().copied().sum::<f32>() / n;
178        let min_tps = tps_values.iter().cloned().fold(f32::INFINITY, f32::min);
179        let max_tps = tps_values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
180        let p50_tps = percentile(tps_values.clone(), 50.0);
181        let p95_tps = percentile(tps_values, 95.0);
182
183        let mean_prefill_ms = total_prefill_ms / n;
184        let mean_decode_ms_per_token = if total_tokens > 0 {
185            total_decode_ms / total_tokens as f32
186        } else {
187            0.0
188        };
189
190        ThroughputResult {
191            tokens_per_second: mean_tps,
192            prefill_ms: mean_prefill_ms,
193            decode_ms_per_token: mean_decode_ms_per_token,
194            total_tokens,
195            runs: run_timings.len(),
196            min_tps,
197            max_tps,
198            p50_tps,
199            p95_tps,
200        }
201    }
202}