Skip to main content

cubecl_common/
benchmark.rs

1use alloc::format;
2use alloc::string::String;
3use alloc::vec;
4use alloc::vec::Vec;
5use core::fmt::Display;
6use core::time::Duration;
7
8pub use crate::profile::{Instant, TimingMethod};
9
10#[cfg(feature = "std")]
11pub use crate::profile::ProfileDuration;
12
13/// Results of a benchmark run.
14#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
15#[derive(new, Debug, Clone)]
16pub struct BenchmarkDurations {
17    /// How these durations were measured.
18    pub timing_method: TimingMethod,
19    /// All durations of the run, in the order they were benchmarked
20    pub durations: Vec<Duration>,
21}
22
23impl BenchmarkDurations {
24    /// Construct from a list of durations.
25    pub fn from_durations(timing_method: TimingMethod, durations: Vec<Duration>) -> Self {
26        Self {
27            timing_method,
28            durations,
29        }
30    }
31
32    /// Returns a tuple of durations: (min, max, median)
33    fn min_max_median_durations(&self) -> (Duration, Duration, Duration) {
34        let mut sorted = self.durations.clone();
35        sorted.sort();
36        let min = *sorted.first().unwrap();
37        let max = *sorted.last().unwrap();
38        let median = *sorted.get(sorted.len() / 2).unwrap();
39        (min, max, median)
40    }
41
42    /// Returns the median duration among all durations
43    pub(crate) fn mean_duration(&self) -> Duration {
44        self.durations.iter().sum::<Duration>() / self.durations.len() as u32
45    }
46
47    /// Returns the variance durations for the durations
48    pub(crate) fn variance_duration(&self, mean: Duration) -> Duration {
49        self.durations
50            .iter()
51            .map(|duration| {
52                let tmp = duration.as_secs_f64() - mean.as_secs_f64();
53                Duration::from_secs_f64(tmp * tmp)
54            })
55            .sum::<Duration>()
56            / self.durations.len() as u32
57    }
58}
59
60impl Display for BenchmarkDurations {
61    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
62        let computed = BenchmarkComputations::new(self);
63        let BenchmarkComputations {
64            mean,
65            median,
66            variance,
67            min,
68            max,
69        } = computed;
70        let num_sample = self.durations.len();
71        let timing_method = self.timing_method;
72
73        f.write_str(
74            format!(
75                "
76―――――――― Result ―――――――――
77  Timing      {timing_method}
78  Samples     {num_sample}
79  Mean        {mean:.3?}
80  Variance    {variance:.3?}
81  Median      {median:.3?}
82  Min         {min:.3?}
83  Max         {max:.3?}
84―――――――――――――――――――――――――"
85            )
86            .as_str(),
87        )
88    }
89}
90
/// Computed values from benchmark durations.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct BenchmarkComputations {
    /// Mean of all the durations.
    pub mean: Duration,
    /// Median of all the durations.
    pub median: Duration,
    /// Variance of all the durations.
    ///
    /// When produced by [`BenchmarkComputations::new`], this is the variance in seconds squared,
    /// encoded as a `Duration` whose value in seconds is that number.
    pub variance: Duration,
    /// Minimum duration amongst all durations.
    pub min: Duration,
    /// Maximum duration amongst all durations.
    pub max: Duration,
}
109
110impl BenchmarkComputations {
111    /// Compute duration values and return a `BenchmarkComputations` struct
112    pub fn new(durations: &BenchmarkDurations) -> Self {
113        let mean = durations.mean_duration();
114        let (min, max, median) = durations.min_max_median_durations();
115        Self {
116            mean,
117            median,
118            min,
119            max,
120            variance: durations.variance_duration(mean),
121        }
122    }
123
124    /// Returns the score of the current benchmark.
125    pub fn score(&self) -> u64 {
126        // How much optimism we have regarding the benchmark.
127        //
128        // The higher the value, the more we prioritize the fastest run regardless of variation.
129        const ALPHA: f64 = 0.8;
130
131        let min_ns = self.min.as_nanos() as f64;
132        let median_ns = self.median.as_nanos() as f64;
133        let variance_ns = self.variance.as_nanos() as f64;
134        let mean_ns = self.mean.as_nanos() as f64;
135
136        // The base score is based on the fastest run and the median duration.
137        let base_score = (min_ns * ALPHA) + (median_ns * (1.0 - ALPHA));
138
139        // If the standard deviation is high relative to the mean,
140        // we inflate the score (making it less desirable).
141        let std_dev = num_traits::Float::sqrt(variance_ns);
142
143        // Lower is better
144        let coefficient_of_variation = 1.0
145            + (std_dev
146                / (
147                    // The `1.0` is only for numerical stability with small numbers.
148                    // Since we work with nanos, this is negligible.
149                    1.0 + mean_ns
150                ));
151
152        // Return score (Lower is better)
153        (base_score * coefficient_of_variation) as u64
154    }
155}
156
157/// Benchmark trait.
158pub trait Benchmark {
159    /// Benchmark input arguments.
160    type Input: Clone;
161    /// The benchmark output.
162    type Output;
163
164    /// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't
165    /// count as included in the duration.
166    ///
167    /// # Notes
168    ///
169    /// This should not include warmup, the benchmark will be run at least one time without
170    /// measuring the execution time.
171    fn prepare(&self) -> Self::Input;
172
173    /// Execute the benchmark and returns the logical output of the task executed.
174    ///
175    /// It is important to return the output since otherwise deadcode optimization might optimize
176    /// away code that should be benchmarked.
177    fn execute(&self, input: Self::Input) -> Result<Self::Output, String>;
178
179    /// Number of samples per run required to have a statistical significance.
180    fn num_samples(&self) -> usize {
181        const DEFAULT: usize = 15;
182        #[cfg(feature = "std")]
183        {
184            std::env::var("BENCH_NUM_SAMPLES")
185                .map(|val| str::parse::<usize>(&val).unwrap_or(DEFAULT))
186                .unwrap_or(DEFAULT)
187        }
188
189        #[cfg(not(feature = "std"))]
190        {
191            DEFAULT
192        }
193    }
194
195    /// Name of the benchmark, should be short and it should match the name
196    /// defined in the crate Cargo.toml
197    fn name(&self) -> String;
198
199    /// The options passed to the benchmark.
200    fn options(&self) -> Option<String> {
201        None
202    }
203
204    /// Shapes dimensions
205    fn shapes(&self) -> Vec<Vec<usize>> {
206        vec![]
207    }
208
209    /// Wait for computation to complete.
210    fn sync(&self);
211
212    /// Start measuring the computation duration.
213    #[cfg(feature = "std")]
214    fn profile(&self, args: Self::Input) -> Result<ProfileDuration, String> {
215        self.profile_full(args)
216    }
217
218    /// Start measuring the computation duration. Use the full duration irregardless of whether
219    /// device duration is available or not.
220    #[cfg(feature = "std")]
221    fn profile_full(&self, args: Self::Input) -> Result<ProfileDuration, String> {
222        self.sync();
223        let start_time = Instant::now();
224        let out = self.execute(args)?;
225        self.sync();
226        core::mem::drop(out);
227        Ok(ProfileDuration::new_system_time(start_time, Instant::now()))
228    }
229
230    /// Run the benchmark a number of times.
231    #[allow(unused_variables)]
232    fn run(&self, timing_method: TimingMethod) -> Result<BenchmarkDurations, String> {
233        #[cfg(not(feature = "std"))]
234        panic!("Attempting to run benchmark in a no-std environment");
235
236        #[cfg(feature = "std")]
237        {
238            let execute = |args: &Self::Input| {
239                let profile: Result<ProfileDuration, String> = match timing_method {
240                    TimingMethod::System => self.profile_full(args.clone()),
241                    TimingMethod::Device => self.profile(args.clone()),
242                };
243                let profile = match profile {
244                    Ok(val) => val,
245                    Err(err) => return Err(err),
246                };
247                Ok(crate::future::block_on(profile.resolve()))
248            };
249            let args = self.prepare();
250
251            // Triggers JIT-compilation and perform a Warmup
252            //
253            // We are using 5 iterations, where the first one probably triggers the JIT-compilation
254            // and it is then followed by 4 warmup executions.
255            for _ in 0..5 {
256                let _duration: Result<crate::profile::ProfileTicks, _> = execute(&args);
257            }
258
259            // Real execution.
260            let mut durations = Vec::with_capacity(self.num_samples());
261            for _ in 0..self.num_samples() {
262                match execute(&args) {
263                    Ok(val) => durations.push(val.duration()),
264                    Err(err) => {
265                        return Err(err);
266                    }
267                }
268            }
269
270            Ok(BenchmarkDurations {
271                timing_method,
272                durations,
273            })
274        }
275    }
276}
277
278/// Result of a benchmark run, with metadata
279#[derive(Clone)]
280pub struct BenchmarkResult {
281    /// Individual raw results of the run
282    pub raw: BenchmarkDurations,
283    /// Computed values for the run
284    pub computed: BenchmarkComputations,
285    /// Git commit hash of the commit in which the run occurred
286    pub git_hash: String,
287    /// Name of the benchmark
288    pub name: String,
289    /// Options passed to the benchmark
290    pub options: Option<String>,
291    /// Shape dimensions
292    pub shapes: Vec<Vec<usize>>,
293    /// Time just before the run
294    pub timestamp: u128,
295}
296
297impl Display for BenchmarkResult {
298    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
299        f.write_str(
300            format!(
301                "
302        Timestamp: {}
303        Git Hash: {}
304        Benchmarking - {}{}
305        ",
306                self.timestamp, self.git_hash, self.name, self.raw
307            )
308            .as_str(),
309        )
310    }
311}
312
#[cfg(feature = "std")]
/// Runs the given benchmark and returns its measurements together with run metadata.
///
/// The benchmark is measured with [`TimingMethod::System`]. The returned result also carries
/// the current git commit hash (empty if it cannot be determined) and a timestamp in
/// milliseconds since the Unix epoch, taken just before the run.
///
/// # Errors
///
/// Propagates any error returned by [`Benchmark::run`].
pub fn run_benchmark<BM>(benchmark: BM) -> Result<BenchmarkResult, String>
where
    BM: Benchmark,
{
    // A system clock set before the Unix epoch yields 0 instead of panicking.
    let timestamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis())
        .unwrap_or(0);
    let git_hash = current_git_hash();
    let durations = benchmark.run(TimingMethod::System)?;
    let computed = BenchmarkComputations::new(&durations);

    Ok(BenchmarkResult {
        raw: durations,
        computed,
        git_hash,
        name: benchmark.name(),
        options: benchmark.options(),
        shapes: benchmark.shapes(),
        timestamp,
    })
}

/// Returns the trimmed stdout of `git rev-parse HEAD`.
///
/// Returns an empty string instead of panicking when `git` cannot be spawned or its output is
/// not valid UTF-8. The exit status is not checked, so a failing command yields whatever it
/// printed on stdout.
#[cfg(feature = "std")]
fn current_git_hash() -> String {
    std::process::Command::new("git")
        .args(["rev-parse", "HEAD"])
        .output()
        .ok()
        .and_then(|output| String::from_utf8(output.stdout).ok())
        .map(|hash| String::from(hash.trim()))
        .unwrap_or_default()
}
342
#[cfg(test)]
#[cfg(feature = "std")]
mod tests {
    use super::*;
    use alloc::vec;

    // Five samples: the median is the single middle element.
    #[test_log::test]
    fn test_min_max_median_durations_odd_number_of_samples() {
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::System,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
                Duration::new(50, 0),
            ],
        };
        let (min, max, median) = durations.min_max_median_durations();
        assert_eq!(min, Duration::from_secs(10));
        assert_eq!(max, Duration::from_secs(50));
        assert_eq!(median, Duration::from_secs(30));
    }

    // Four samples: the median is the element at index `len / 2`, i.e. the upper middle value.
    #[test_log::test]
    fn test_min_max_median_durations_even_number_of_samples() {
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::System,
            durations: vec![
                Duration::new(18, 5),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
            ],
        };
        let (min, max, median) = durations.min_max_median_durations();
        assert_eq!(min, Duration::from_nanos(18000000005_u64));
        assert_eq!(max, Duration::from_secs(40));
        assert_eq!(median, Duration::from_secs(30));
    }

    #[test_log::test]
    fn test_mean_duration() {
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::System,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
            ],
        };
        let mean = durations.mean_duration();
        assert_eq!(mean, Duration::from_secs(25));
    }

    #[test_log::test]
    fn test_variance_duration() {
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::System,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
                Duration::new(50, 0),
            ],
        };
        let mean = durations.mean_duration();
        let variance = durations.variance_duration(mean);
        assert_eq!(variance, Duration::from_secs(200));
    }
}