1use alloc::format;
2use alloc::string::String;
3use alloc::vec;
4use alloc::vec::Vec;
5use core::fmt::Display;
6use serde::{Deserialize, Serialize};
7
8use super::stub::Duration;
9
/// How the duration of each benchmark sample is measured.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
pub enum TimingMethod {
    /// Wall-clock time around `execute` followed by `sync` (the default).
    #[default]
    Full,
    /// Time reported by [`Benchmark::sync_elapsed`]; wall-clock time is used instead when it
    /// reports timestamps as unavailable.
    DeviceOnly,
}
21
22impl Display for TimingMethod {
23 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
24 match self {
25 TimingMethod::Full => f.write_str("full"),
26 TimingMethod::DeviceOnly => f.write_str("device_only"),
27 }
28 }
29}
30
/// Reasons why [`Benchmark::sync_elapsed`] could not report a device-measured duration.
#[derive(Debug)]
pub enum TimestampsError {
    /// Timestamp collection is turned off; device-only timing panics when it sees this.
    Disabled,
    /// The implementation cannot provide timestamps (the default `sync_elapsed` result);
    /// device-only timing falls back to wall-clock time.
    Unavailable,
    /// Any other failure, carrying a description of the error.
    Unknown(String),
}

/// Outcome of [`Benchmark::sync_elapsed`]: the measured duration, or why it is missing.
pub type TimestampsResult = Result<Duration, TimestampsError>;
44
/// Raw per-sample durations produced by [`Benchmark::run`].
#[derive(new, Debug, Default, Clone, Serialize, Deserialize)]
pub struct BenchmarkDurations {
    /// How each sample was timed.
    pub timing_method: TimingMethod,
    /// One duration per timed sample, in the order the samples were run.
    pub durations: Vec<Duration>,
}
53
54impl BenchmarkDurations {
55 fn min_max_median_durations(&self) -> (Duration, Duration, Duration) {
57 let mut sorted = self.durations.clone();
58 sorted.sort();
59 let min = *sorted.first().unwrap();
60 let max = *sorted.last().unwrap();
61 let median = *sorted.get(sorted.len() / 2).unwrap();
62 (min, max, median)
63 }
64
65 pub(crate) fn mean_duration(&self) -> Duration {
67 self.durations.iter().sum::<Duration>() / self.durations.len() as u32
68 }
69
70 pub(crate) fn variance_duration(&self, mean: Duration) -> Duration {
72 let var = self
73 .durations
74 .iter()
75 .map(|duration| {
76 let tmp = duration.as_secs_f64() - mean.as_secs_f64();
77 Duration::from_secs_f64(tmp * tmp)
78 })
79 .sum::<Duration>()
80 / self.durations.len() as u32;
81 var
82 }
83}
84
85impl Display for BenchmarkDurations {
86 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
87 let computed = BenchmarkComputations::new(self);
88 let BenchmarkComputations {
89 mean,
90 median,
91 variance,
92 min,
93 max,
94 } = computed;
95 let num_sample = self.durations.len();
96 let timing_method = self.timing_method;
97
98 f.write_str(
99 format!(
100 "
101―――――――― Result ―――――――――
102 Timing {timing_method}
103 Samples {num_sample}
104 Mean {mean:.3?}
105 Variance {variance:.3?}
106 Median {median:.3?}
107 Min {min:.3?}
108 Max {max:.3?}
109―――――――――――――――――――――――――"
110 )
111 .as_str(),
112 )
113 }
114}
115
/// Summary statistics derived from a [`BenchmarkDurations`].
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct BenchmarkComputations {
    /// Arithmetic mean of the samples.
    pub mean: Duration,
    /// Sorted sample at index `len / 2` (the upper-middle one for an even count).
    pub median: Duration,
    /// Population variance in seconds², carried in a `Duration` (read its seconds as s²).
    pub variance: Duration,
    /// Smallest sample.
    pub min: Duration,
    /// Largest sample.
    pub max: Duration,
}
130
131impl BenchmarkComputations {
132 pub fn new(durations: &BenchmarkDurations) -> Self {
134 let mean = durations.mean_duration();
135 let (min, max, median) = durations.min_max_median_durations();
136 Self {
137 mean,
138 median,
139 min,
140 max,
141 variance: durations.variance_duration(mean),
142 }
143 }
144}
145
146pub trait Benchmark {
148 type Args: Clone;
150
151 fn prepare(&self) -> Self::Args;
159 fn execute(&self, args: Self::Args);
161 fn num_samples(&self) -> usize {
163 const DEFAULT: usize = 10;
164
165 #[cfg(feature = "std")]
166 {
167 std::env::var("BENCH_NUM_SAMPLES")
168 .map(|val| str::parse::<usize>(&val).unwrap_or(DEFAULT))
169 .unwrap_or(DEFAULT)
170 }
171
172 #[cfg(not(feature = "std"))]
173 {
174 DEFAULT
175 }
176 }
177 fn name(&self) -> String;
180 fn options(&self) -> Option<String> {
182 None
183 }
184 fn shapes(&self) -> Vec<Vec<usize>> {
186 vec![]
187 }
188 fn sync(&self);
190
191 fn sync_elapsed(&self) -> TimestampsResult {
194 Err(TimestampsError::Unavailable)
195 }
196
197 #[allow(unused_variables)]
199 fn run(&self, timing_method: TimingMethod) -> BenchmarkDurations {
200 #[cfg(not(feature = "std"))]
201 panic!("Attempting to run benchmark in a no-std environment");
202 #[cfg(feature = "std")]
203 {
204 let args = self.prepare();
206
207 for _ in 0..self.num_samples() {
208 self.execute(args.clone());
209 }
210
211 match timing_method {
212 TimingMethod::Full => self.sync(),
213 TimingMethod::DeviceOnly => {
214 let _ = self.sync_elapsed();
215 }
216 }
217 std::thread::sleep(Duration::from_secs(1));
218
219 let mut durations = Vec::with_capacity(self.num_samples());
220
221 for _ in 0..self.num_samples() {
222 match timing_method {
223 TimingMethod::Full => durations.push(self.run_one_full(args.clone())),
224 TimingMethod::DeviceOnly => {
225 durations.push(self.run_one_device_only(args.clone()))
226 }
227 }
228 }
229
230 BenchmarkDurations {
231 timing_method,
232 durations,
233 }
234 }
235 }
236 #[cfg(feature = "std")]
237 fn run_one_full(&self, args: Self::Args) -> Duration {
240 let start = std::time::Instant::now();
241 self.execute(args);
242 self.sync();
243 start.elapsed()
244 }
245 #[cfg(feature = "std")]
248 fn run_one_device_only(&self, args: Self::Args) -> Duration {
249 let start = std::time::Instant::now();
250
251 self.execute(args);
252
253 let result = self.sync_elapsed();
254
255 match result {
256 Ok(time) => time,
257 Err(err) => match err {
258 TimestampsError::Disabled => {
259 panic!("Collecting timestamps is deactivated, make sure to enable it before running the benchmark");
260 }
261 TimestampsError::Unavailable => start.elapsed(),
262 TimestampsError::Unknown(err) => {
263 panic!("An unknown error occurred while collecting the timestamps when benchmarking: {err}");
264 }
265 },
266 }
267 }
268}
269
/// A complete benchmark report: raw samples, their statistics, and run metadata.
#[derive(Default, Clone)]
pub struct BenchmarkResult {
    /// Raw per-sample durations.
    pub raw: BenchmarkDurations,
    /// Statistics computed from `raw`.
    pub computed: BenchmarkComputations,
    /// Commit hash of the code under test; `run_benchmark` fills it from `git rev-parse HEAD`.
    pub git_hash: String,
    /// Benchmark name, from [`Benchmark::name`].
    pub name: String,
    /// Benchmark options, from [`Benchmark::options`].
    pub options: Option<String>,
    /// Input shapes, from [`Benchmark::shapes`].
    pub shapes: Vec<Vec<usize>>,
    /// Start time of the run; `run_benchmark` stores milliseconds since the Unix epoch.
    pub timestamp: u128,
}
288
289impl Display for BenchmarkResult {
290 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
291 f.write_str(
292 format!(
293 "
294 Timestamp: {}
295 Git Hash: {}
296 Benchmarking - {}{}
297 ",
298 self.timestamp, self.git_hash, self.name, self.raw
299 )
300 .as_str(),
301 )
302 }
303}
304
/// Runs `benchmark` with [`TimingMethod::Full`] and packages the samples, their summary
/// statistics, and run metadata into a [`BenchmarkResult`].
///
/// The git hash is the trimmed output of `git rev-parse HEAD` in the current working
/// directory. It is left empty when git cannot be spawned, the command fails (e.g. not a
/// repository), or its output is not valid UTF-8. The benchmark does not abort in those
/// cases.
#[cfg(feature = "std")]
pub fn run_benchmark<BM>(benchmark: BM) -> BenchmarkResult
where
    BM: Benchmark,
{
    // Milliseconds since the Unix epoch; a clock set before 1970 yields 0 instead of panicking.
    let timestamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    let git_hash = std::process::Command::new("git")
        .args(["rev-parse", "HEAD"])
        .output()
        .ok()
        .filter(|output| output.status.success())
        .and_then(|output| String::from_utf8(output.stdout).ok())
        .map(|hash| hash.trim().to_string())
        .unwrap_or_default();
    let durations = benchmark.run(TimingMethod::Full);
    // Compute before moving `durations` into the result, so no clone is needed.
    let computed = BenchmarkComputations::new(&durations);

    BenchmarkResult {
        raw: durations,
        computed,
        git_hash,
        name: benchmark.name(),
        options: benchmark.options(),
        shapes: benchmark.shapes(),
        timestamp,
    }
}
332
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_min_max_median_durations_odd_number_of_samples() {
        // Five samples: the median is the exact middle element.
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
                Duration::new(50, 0),
            ],
        };
        let (min, max, median) = durations.min_max_median_durations();
        assert_eq!(min, Duration::from_secs(10));
        assert_eq!(max, Duration::from_secs(50));
        assert_eq!(median, Duration::from_secs(30));
    }

    #[test]
    fn test_min_max_median_durations_even_number_of_samples() {
        // Four samples: the median is the upper of the two middle elements (index len / 2),
        // not their average.
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(18, 5),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
            ],
        };
        let (min, max, median) = durations.min_max_median_durations();
        assert_eq!(min, Duration::from_nanos(18_000_000_005_u64));
        assert_eq!(max, Duration::from_secs(40));
        assert_eq!(median, Duration::from_secs(30));
    }

    #[test]
    fn test_mean_duration() {
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
            ],
        };
        let mean = durations.mean_duration();
        assert_eq!(mean, Duration::from_secs(25));
    }

    #[test]
    fn test_variance_duration() {
        // Mean is 30 s; squared deviations sum to 400 + 100 + 0 + 100 + 400 = 1000 s²,
        // so the population variance is 1000 / 5 = 200 (expressed as a Duration of 200 s).
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
                Duration::new(50, 0),
            ],
        };
        let mean = durations.mean_duration();
        let variance = durations.variance_duration(mean);
        assert_eq!(variance, Duration::from_secs(200));
    }
}