cubecl_common/
benchmark.rs

1use alloc::format;
2use alloc::string::String;
3use alloc::vec;
4use alloc::vec::Vec;
5use core::fmt::Display;
6use serde::{Deserialize, Serialize};
7
8use super::stub::Duration;
9
10/// How a benchmark's execution times are measured.
11#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
12pub enum TimingMethod {
13    /// Time measurements come from full timing of execution + sync
14    /// calls.
15    #[default]
16    Full,
17    /// Time measurements come from hardware reported timestamps
18    /// coming from a sync call.
19    DeviceOnly,
20}
21
22impl Display for TimingMethod {
23    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
24        match self {
25            TimingMethod::Full => f.write_str("full"),
26            TimingMethod::DeviceOnly => f.write_str("device_only"),
27        }
28    }
29}
30
/// Error that can occur when collecting timestamps from a device.
#[derive(Debug)]
pub enum TimestampsError {
    /// Collecting timestamps is disabled; it must be enabled before timestamps
    /// can be collected.
    Disabled,
    /// Collecting timestamps isn't available.
    Unavailable,
    /// An unknown error occurred while collecting timestamps; the payload
    /// describes the error.
    Unknown(String),
}
41
42/// Result when collecting timestamps.
43pub type TimestampsResult = Result<Duration, TimestampsError>;
44
45/// Results of a benchmark run.
46#[derive(new, Debug, Default, Clone, Serialize, Deserialize)]
47pub struct BenchmarkDurations {
48    /// How these durations were measured.
49    pub timing_method: TimingMethod,
50    /// All durations of the run, in the order they were benchmarked
51    pub durations: Vec<Duration>,
52}
53
54impl BenchmarkDurations {
55    /// Returns a tuple of durations: (min, max, median)
56    fn min_max_median_durations(&self) -> (Duration, Duration, Duration) {
57        let mut sorted = self.durations.clone();
58        sorted.sort();
59        let min = *sorted.first().unwrap();
60        let max = *sorted.last().unwrap();
61        let median = *sorted.get(sorted.len() / 2).unwrap();
62        (min, max, median)
63    }
64
65    /// Returns the median duration among all durations
66    pub(crate) fn mean_duration(&self) -> Duration {
67        self.durations.iter().sum::<Duration>() / self.durations.len() as u32
68    }
69
70    /// Returns the variance durations for the durations
71    pub(crate) fn variance_duration(&self, mean: Duration) -> Duration {
72        let var = self
73            .durations
74            .iter()
75            .map(|duration| {
76                let tmp = duration.as_secs_f64() - mean.as_secs_f64();
77                Duration::from_secs_f64(tmp * tmp)
78            })
79            .sum::<Duration>()
80            / self.durations.len() as u32;
81        var
82    }
83}
84
85impl Display for BenchmarkDurations {
86    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
87        let computed = BenchmarkComputations::new(self);
88        let BenchmarkComputations {
89            mean,
90            median,
91            variance,
92            min,
93            max,
94        } = computed;
95        let num_sample = self.durations.len();
96        let timing_method = self.timing_method;
97
98        f.write_str(
99            format!(
100                "
101―――――――― Result ―――――――――
102  Timing      {timing_method}
103  Samples     {num_sample}
104  Mean        {mean:.3?}
105  Variance    {variance:.3?}
106  Median      {median:.3?}
107  Min         {min:.3?}
108  Max         {max:.3?}
109―――――――――――――――――――――――――"
110            )
111            .as_str(),
112        )
113    }
114}
115
116/// Computed values from benchmark durations.
117#[derive(Debug, Default, Clone, Serialize, Deserialize)]
118pub struct BenchmarkComputations {
119    /// Mean of all the durations.
120    pub mean: Duration,
121    /// Median of all the durations.
122    pub median: Duration,
123    /// Variance of all the durations.
124    pub variance: Duration,
125    /// Minimum duration amongst all durations.
126    pub min: Duration,
127    /// Maximum duration amongst all durations.
128    pub max: Duration,
129}
130
131impl BenchmarkComputations {
132    /// Compute duration values and return a BenchmarkComputations struct
133    pub fn new(durations: &BenchmarkDurations) -> Self {
134        let mean = durations.mean_duration();
135        let (min, max, median) = durations.min_max_median_durations();
136        Self {
137            mean,
138            median,
139            min,
140            max,
141            variance: durations.variance_duration(mean),
142        }
143    }
144}
145
146/// Benchmark trait.
147pub trait Benchmark {
148    /// Benchmark arguments.
149    type Args: Clone;
150
151    /// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't
152    /// count as included in the duration.
153    ///
154    /// # Notes
155    ///
156    /// This should not include warmup, the benchmark will be run at least one time without
157    /// measuring the execution time.
158    fn prepare(&self) -> Self::Args;
159    /// Execute the benchmark and returns the time it took to complete.
160    fn execute(&self, args: Self::Args);
161    /// Number of samples per run required to have a statistical significance.
162    fn num_samples(&self) -> usize {
163        const DEFAULT: usize = 10;
164
165        #[cfg(feature = "std")]
166        {
167            std::env::var("BENCH_NUM_SAMPLES")
168                .map(|val| str::parse::<usize>(&val).unwrap_or(DEFAULT))
169                .unwrap_or(DEFAULT)
170        }
171
172        #[cfg(not(feature = "std"))]
173        {
174            DEFAULT
175        }
176    }
177    /// Name of the benchmark, should be short and it should match the name
178    /// defined in the crate Cargo.toml
179    fn name(&self) -> String;
180    /// The options passed to the benchmark.
181    fn options(&self) -> Option<String> {
182        None
183    }
184    /// Shapes dimensions
185    fn shapes(&self) -> Vec<Vec<usize>> {
186        vec![]
187    }
188    /// Wait for computation to complete.
189    fn sync(&self);
190
191    /// Wait for computation to complete and return hardware reported
192    /// computation duration.
193    fn sync_elapsed(&self) -> TimestampsResult {
194        Err(TimestampsError::Unavailable)
195    }
196
197    /// Run the benchmark a number of times.
198    #[allow(unused_variables)]
199    fn run(&self, timing_method: TimingMethod) -> BenchmarkDurations {
200        #[cfg(not(feature = "std"))]
201        panic!("Attempting to run benchmark in a no-std environment");
202        #[cfg(feature = "std")]
203        {
204            // Warmup
205            let args = self.prepare();
206
207            for _ in 0..self.num_samples() {
208                self.execute(args.clone());
209            }
210
211            match timing_method {
212                TimingMethod::Full => self.sync(),
213                TimingMethod::DeviceOnly => {
214                    let _ = self.sync_elapsed();
215                }
216            }
217            std::thread::sleep(Duration::from_secs(1));
218
219            let mut durations = Vec::with_capacity(self.num_samples());
220
221            for _ in 0..self.num_samples() {
222                match timing_method {
223                    TimingMethod::Full => durations.push(self.run_one_full(args.clone())),
224                    TimingMethod::DeviceOnly => {
225                        durations.push(self.run_one_device_only(args.clone()))
226                    }
227                }
228            }
229
230            BenchmarkDurations {
231                timing_method,
232                durations,
233            }
234        }
235    }
236    #[cfg(feature = "std")]
237    /// Collect one sample directly measuring the full execute + sync
238    /// step.
239    fn run_one_full(&self, args: Self::Args) -> Duration {
240        let start = std::time::Instant::now();
241        self.execute(args);
242        self.sync();
243        start.elapsed()
244    }
245    /// Collect one sample using timing measurements reported by the
246    /// device.
247    #[cfg(feature = "std")]
248    fn run_one_device_only(&self, args: Self::Args) -> Duration {
249        let start = std::time::Instant::now();
250
251        self.execute(args);
252
253        let result = self.sync_elapsed();
254
255        match result {
256            Ok(time) => time,
257            Err(err) => match err {
258                TimestampsError::Disabled => {
259                    panic!("Collecting timestamps is deactivated, make sure to enable it before running the benchmark");
260                }
261                TimestampsError::Unavailable => start.elapsed(),
262                TimestampsError::Unknown(err) => {
263                    panic!("An unknown error occurred while collecting the timestamps when benchmarking: {err}");
264                }
265            },
266        }
267    }
268}
269
270/// Result of a benchmark run, with metadata
271#[derive(Default, Clone)]
272pub struct BenchmarkResult {
273    /// Individual raw results of the run
274    pub raw: BenchmarkDurations,
275    /// Computed values for the run
276    pub computed: BenchmarkComputations,
277    /// Git commit hash of the commit in which the run occurred
278    pub git_hash: String,
279    /// Name of the benchmark
280    pub name: String,
281    /// Options passed to the benchmark
282    pub options: Option<String>,
283    /// Shape dimensions
284    pub shapes: Vec<Vec<usize>>,
285    /// Time just before the run
286    pub timestamp: u128,
287}
288
289impl Display for BenchmarkResult {
290    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
291        f.write_str(
292            format!(
293                "
294        Timestamp: {}
295        Git Hash: {}
296        Benchmarking - {}{}
297        ",
298                self.timestamp, self.git_hash, self.name, self.raw
299            )
300            .as_str(),
301        )
302    }
303}
304
#[cfg(feature = "std")]
/// Runs the given benchmark on the device and returns its result with metadata.
///
/// Samples are collected with [`TimingMethod::Full`]. The result records:
/// - The time just before the run, in milliseconds since the Unix epoch.
/// - The git commit hash of the current working directory, or an empty string
///   when it cannot be determined (git missing, not a repository, ...).
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
pub fn run_benchmark<BM>(benchmark: BM) -> BenchmarkResult
where
    BM: Benchmark,
{
    let timestamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_millis();
    let git_hash = current_git_hash();
    let durations = benchmark.run(TimingMethod::Full);
    // Compute statistics before moving `durations` into the result (no clone).
    let computed = BenchmarkComputations::new(&durations);

    BenchmarkResult {
        raw: durations,
        computed,
        git_hash,
        name: benchmark.name(),
        options: benchmark.options(),
        shapes: benchmark.shapes(),
        timestamp,
    }
}

/// Returns the trimmed output of `git rev-parse HEAD`, or an empty string if
/// git cannot be run or exits unsuccessfully.
#[cfg(feature = "std")]
fn current_git_hash() -> String {
    std::process::Command::new("git")
        .args(["rev-parse", "HEAD"])
        .output()
        .ok()
        .filter(|output| output.status.success())
        .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_string())
        .unwrap_or_default()
}
332
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    // Five samples: the median is the single middle element.
    #[test]
    fn test_min_max_median_durations_odd_number_of_samples() {
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
                Duration::new(50, 0),
            ],
        };
        let (min, max, median) = durations.min_max_median_durations();
        assert_eq!(min, Duration::from_secs(10));
        assert_eq!(max, Duration::from_secs(50));
        assert_eq!(median, Duration::from_secs(30));
    }

    // Four samples: the median is the upper of the two middle elements
    // (index `len / 2` of the sorted samples).
    #[test]
    fn test_min_max_median_durations_even_number_of_samples() {
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(18, 5),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
            ],
        };
        let (min, max, median) = durations.min_max_median_durations();
        assert_eq!(min, Duration::from_nanos(18000000005_u64));
        assert_eq!(max, Duration::from_secs(40));
        assert_eq!(median, Duration::from_secs(30));
    }

    #[test]
    fn test_mean_duration() {
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
            ],
        };
        let mean = durations.mean_duration();
        assert_eq!(mean, Duration::from_secs(25));
    }

    #[test]
    fn test_variance_duration() {
        let durations = BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations: vec![
                Duration::new(10, 0),
                Duration::new(20, 0),
                Duration::new(30, 0),
                Duration::new(40, 0),
                Duration::new(50, 0),
            ],
        };
        let mean = durations.mean_duration();
        let variance = durations.variance_duration(mean);
        assert_eq!(variance, Duration::from_secs(200));
    }
}