1use alloc::boxed::Box;
2use alloc::format;
3use alloc::string::String;
4use alloc::vec;
5use alloc::vec::Vec;
6use core::fmt::Display;
7
8use core::pin::Pin;
9use core::time::Duration;
10
/// The execution time of one benchmark sample, either already known or still pending.
pub enum ProfileDuration {
    /// A duration that has already been measured and is available immediately.
    Full(Duration),
    /// A duration that only becomes available once the boxed future completes.
    DeviceDuration(Pin<Box<dyn Future<Output = Duration> + Send + 'static>>),
}
18
19impl ProfileDuration {
20 pub fn from_duration(duration: Duration) -> Self {
22 ProfileDuration::Full(duration)
23 }
24
25 pub fn from_future(future: impl Future<Output = Duration> + Send + 'static) -> Self {
27 ProfileDuration::DeviceDuration(Box::pin(future))
28 }
29
30 pub fn timing_method(&self) -> TimingMethod {
32 match self {
33 ProfileDuration::Full(_) => TimingMethod::Full,
34 ProfileDuration::DeviceDuration(_) => TimingMethod::DeviceOnly,
35 }
36 }
37
38 pub async fn resolve(self) -> Duration {
40 match self {
41 ProfileDuration::Full(duration) => duration,
42 ProfileDuration::DeviceDuration(future) => future.await,
43 }
44 }
45}
46
/// How the samples of a benchmark were timed.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TimingMethod {
    /// Samples whose duration was already known when recorded (`ProfileDuration::Full`).
    Full,
    /// Samples whose duration came from a pending future (`ProfileDuration::DeviceDuration`).
    DeviceOnly,
}
58
59impl Display for TimingMethod {
60 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
61 match self {
62 TimingMethod::Full => f.write_str("full"),
63 TimingMethod::DeviceOnly => f.write_str("device_only"),
64 }
65 }
66}
67
68#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
70#[derive(new, Debug, Clone)]
71pub struct BenchmarkDurations {
72 pub timing_method: TimingMethod,
74 pub durations: Vec<Duration>,
76}
77
78impl BenchmarkDurations {
79 pub fn from_durations(timing_method: TimingMethod, durations: Vec<Duration>) -> Self {
81 Self {
82 timing_method,
83 durations,
84 }
85 }
86
87 pub async fn from_profiles(profiles: Vec<ProfileDuration>) -> Self {
89 let mut durations = Vec::new();
90 let mut types = Vec::new();
91
92 for profile in profiles {
93 types.push(profile.timing_method());
94 durations.push(profile.resolve().await);
95 }
96
97 let timing_method = *types.first().expect("need at least 1 profile");
98 if types.iter().any(|&t| t != timing_method) {
99 panic!("all profiles must use the same timing method");
100 }
101
102 Self {
103 timing_method,
104 durations,
105 }
106 }
107
108 fn min_max_median_durations(&self) -> (Duration, Duration, Duration) {
110 let mut sorted = self.durations.clone();
111 sorted.sort();
112 let min = *sorted.first().unwrap();
113 let max = *sorted.last().unwrap();
114 let median = *sorted.get(sorted.len() / 2).unwrap();
115 (min, max, median)
116 }
117
118 pub(crate) fn mean_duration(&self) -> Duration {
120 self.durations.iter().sum::<Duration>() / self.durations.len() as u32
121 }
122
123 pub(crate) fn variance_duration(&self, mean: Duration) -> Duration {
125 self.durations
126 .iter()
127 .map(|duration| {
128 let tmp = duration.as_secs_f64() - mean.as_secs_f64();
129 Duration::from_secs_f64(tmp * tmp)
130 })
131 .sum::<Duration>()
132 / self.durations.len() as u32
133 }
134}
135
136impl Display for BenchmarkDurations {
137 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
138 let computed = BenchmarkComputations::new(self);
139 let BenchmarkComputations {
140 mean,
141 median,
142 variance,
143 min,
144 max,
145 } = computed;
146 let num_sample = self.durations.len();
147 let timing_method = self.timing_method;
148
149 f.write_str(
150 format!(
151 "
152―――――――― Result ―――――――――
153 Timing {timing_method}
154 Samples {num_sample}
155 Mean {mean:.3?}
156 Variance {variance:.3?}
157 Median {median:.3?}
158 Min {min:.3?}
159 Max {max:.3?}
160―――――――――――――――――――――――――"
161 )
162 .as_str(),
163 )
164 }
165}
166
/// Summary statistics derived from a [`BenchmarkDurations`] sample set.
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize, PartialEq, Eq)
)]
#[derive(Clone, Debug, Default)]
pub struct BenchmarkComputations {
    /// Mean sample duration.
    pub mean: Duration,
    /// Median sample duration.
    pub median: Duration,
    /// Variance of the samples. It is computed from squared differences in seconds, so
    /// the stored value is in s² even though it is held in a `Duration`.
    pub variance: Duration,
    /// Shortest sample.
    pub min: Duration,
    /// Longest sample.
    pub max: Duration,
}
185
186impl BenchmarkComputations {
187 pub fn new(durations: &BenchmarkDurations) -> Self {
189 let mean = durations.mean_duration();
190 let (min, max, median) = durations.min_max_median_durations();
191 Self {
192 mean,
193 median,
194 min,
195 max,
196 variance: durations.variance_duration(mean),
197 }
198 }
199}
200
201pub trait Benchmark {
203 type Args: Clone;
205
206 fn prepare(&self) -> Self::Args;
214 fn execute(&self, args: Self::Args);
216 fn num_samples(&self) -> usize {
218 const DEFAULT: usize = 10;
219
220 #[cfg(feature = "std")]
221 {
222 std::env::var("BENCH_NUM_SAMPLES")
223 .map(|val| str::parse::<usize>(&val).unwrap_or(DEFAULT))
224 .unwrap_or(DEFAULT)
225 }
226
227 #[cfg(not(feature = "std"))]
228 {
229 DEFAULT
230 }
231 }
232 fn name(&self) -> String;
235 fn options(&self) -> Option<String> {
237 None
238 }
239 fn shapes(&self) -> Vec<Vec<usize>> {
241 vec![]
242 }
243
244 fn sync(&self);
246
247 #[cfg(feature = "std")]
249 fn profile(&self, args: Self::Args) -> ProfileDuration {
250 self.profile_full(args)
251 }
252
253 #[cfg(feature = "std")]
256 fn profile_full(&self, args: Self::Args) -> ProfileDuration {
257 self.sync();
258 let start_time = std::time::Instant::now();
259 self.execute(args);
260 self.sync();
261 ProfileDuration::from_duration(start_time.elapsed())
262 }
263
264 #[allow(unused_variables)]
266 fn run(&self, timing_method: TimingMethod) -> BenchmarkDurations {
267 #[cfg(not(feature = "std"))]
268 panic!("Attempting to run benchmark in a no-std environment");
269
270 #[cfg(feature = "std")]
271 {
272 let args = self.prepare();
274 for _ in 0..self.num_samples() {
275 self.execute(args.clone());
276 }
277 std::thread::sleep(Duration::from_secs(1));
278
279 let mut durations = Vec::with_capacity(self.num_samples());
280
281 for _ in 0..self.num_samples() {
282 let profile = match timing_method {
283 TimingMethod::Full => self.profile_full(args.clone()),
284 TimingMethod::DeviceOnly => self.profile(args.clone()),
285 };
286 let duration = crate::future::block_on(profile.resolve());
287 durations.push(duration);
288 }
289
290 BenchmarkDurations {
291 timing_method,
292 durations,
293 }
294 }
295 }
296}
297
298#[derive(Clone)]
300pub struct BenchmarkResult {
301 pub raw: BenchmarkDurations,
303 pub computed: BenchmarkComputations,
305 pub git_hash: String,
307 pub name: String,
309 pub options: Option<String>,
311 pub shapes: Vec<Vec<usize>>,
313 pub timestamp: u128,
315}
316
317impl Display for BenchmarkResult {
318 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
319 f.write_str(
320 format!(
321 "
322 Timestamp: {}
323 Git Hash: {}
324 Benchmarking - {}{}
325 ",
326 self.timestamp, self.git_hash, self.name, self.raw
327 )
328 .as_str(),
329 )
330 }
331}
332
/// Runs `benchmark` with [`TimingMethod::Full`] and packages the samples, their
/// statistics and metadata about the run into a [`BenchmarkResult`].
///
/// The git hash is read from `git rev-parse HEAD`. If `git` cannot be spawned or the
/// command fails (e.g. outside a repository), an empty hash is recorded instead of
/// aborting the benchmark. Non-UTF-8 output is converted lossily.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch, or if the
/// benchmark itself panics.
#[cfg(feature = "std")]
pub fn run_benchmark<BM>(benchmark: BM) -> BenchmarkResult
where
    BM: Benchmark,
{
    let timestamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_millis();
    let git_hash = std::process::Command::new("git")
        .args(["rev-parse", "HEAD"])
        .output()
        .ok()
        .filter(|output| output.status.success())
        .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_string())
        .unwrap_or_default();
    let durations = benchmark.run(TimingMethod::Full);
    // Compute the statistics before moving `durations` into the result, avoiding a clone.
    let computed = BenchmarkComputations::new(&durations);

    BenchmarkResult {
        raw: durations,
        computed,
        git_hash,
        name: benchmark.name(),
        options: benchmark.options(),
        shapes: benchmark.shapes(),
        timestamp,
    }
}
360
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_min_max_median_durations_odd_number_of_samples() {
        // Five samples: the median is the single middle element.
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
                Duration::new(50, 0),
            ],
        };
        let (min, max, median) = durations.min_max_median_durations();
        assert_eq!(min, Duration::from_secs(10));
        assert_eq!(max, Duration::from_secs(50));
        assert_eq!(median, Duration::from_secs(30));
    }

    #[test]
    fn test_min_max_median_durations_even_number_of_samples() {
        // Four samples: the reported median is the upper of the two middle elements
        // (index `len / 2` of the sorted samples), not their average.
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(18, 5),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
            ],
        };
        let (min, max, median) = durations.min_max_median_durations();
        assert_eq!(min, Duration::from_nanos(18000000005_u64));
        assert_eq!(max, Duration::from_secs(40));
        assert_eq!(median, Duration::from_secs(30));
    }

    #[test]
    fn test_mean_duration() {
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
            ],
        };
        let mean = durations.mean_duration();
        assert_eq!(mean, Duration::from_secs(25));
    }

    #[test]
    fn test_variance_duration() {
        // Squared differences from the mean (30s) are 400, 100, 0, 100, 400 s²;
        // their average is 200 s².
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
                Duration::new(50, 0),
            ],
        };
        let mean = durations.mean_duration();
        let variance = durations.variance_duration(mean);
        assert_eq!(variance, Duration::from_secs(200));
    }
}