cubecl_common/
benchmark.rs

1use alloc::boxed::Box;
2use alloc::format;
3use alloc::string::String;
4use alloc::vec;
5use alloc::vec::Vec;
6use core::fmt::Display;
7
8use core::pin::Pin;
9use core::time::Duration;
10
/// Outcome of profiling the span between two measurements.
///
/// The elapsed time is either known immediately or only becomes available once
/// a future has been awaited.
pub enum ProfileDuration {
    /// The complete duration is already known.
    Full(Duration),
    /// The device duration is produced by a future that still has to be resolved.
    // `'static` is the default object lifetime bound for `Box<dyn Trait>`, so it is
    // not spelled out here.
    DeviceDuration(Pin<Box<dyn Future<Output = Duration> + Send>>),
}
18
19impl ProfileDuration {
20    /// Create a new `ProfileDuration` straight from a duration.
21    pub fn from_duration(duration: Duration) -> Self {
22        ProfileDuration::Full(duration)
23    }
24
25    /// Create a new `ProfileDuration` from a future that resolves to a duration.
26    pub fn from_future(future: impl Future<Output = Duration> + Send + 'static) -> Self {
27        ProfileDuration::DeviceDuration(Box::pin(future))
28    }
29
30    /// The method used to measure the execution time.
31    pub fn timing_method(&self) -> TimingMethod {
32        match self {
33            ProfileDuration::Full(_) => TimingMethod::Full,
34            ProfileDuration::DeviceDuration(_) => TimingMethod::DeviceOnly,
35        }
36    }
37
38    /// Resolve the actual duration of the profile, possibly by waiting for the future to complete.
39    pub async fn resolve(self) -> Duration {
40        match self {
41            ProfileDuration::Full(duration) => duration,
42            ProfileDuration::DeviceDuration(future) => future.await,
43        }
44    }
45}
46
/// The way execution times of a benchmark are obtained.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TimingMethod {
    /// Measurements cover the whole execution, including the sync calls.
    Full,
    /// Measurements are hardware-reported timestamps obtained from a sync call.
    DeviceOnly,
}

impl Display for TimingMethod {
    /// Writes the lowercase label of the variant (`full` or `device_only`).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = match self {
            TimingMethod::DeviceOnly => "device_only",
            TimingMethod::Full => "full",
        };
        f.write_str(label)
    }
}
67
68/// Results of a benchmark run.
69#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
70#[derive(new, Debug, Clone)]
71pub struct BenchmarkDurations {
72    /// How these durations were measured.
73    pub timing_method: TimingMethod,
74    /// All durations of the run, in the order they were benchmarked
75    pub durations: Vec<Duration>,
76}
77
78impl BenchmarkDurations {
79    /// Construct from a list of durations.
80    pub fn from_durations(timing_method: TimingMethod, durations: Vec<Duration>) -> Self {
81        Self {
82            timing_method,
83            durations,
84        }
85    }
86
87    /// Construct from a list of profiles.
88    pub async fn from_profiles(profiles: Vec<ProfileDuration>) -> Self {
89        let mut durations = Vec::new();
90        let mut types = Vec::new();
91
92        for profile in profiles {
93            types.push(profile.timing_method());
94            durations.push(profile.resolve().await);
95        }
96
97        let timing_method = *types.first().expect("need at least 1 profile");
98        if types.iter().any(|&t| t != timing_method) {
99            panic!("all profiles must use the same timing method");
100        }
101
102        Self {
103            timing_method,
104            durations,
105        }
106    }
107
108    /// Returns a tuple of durations: (min, max, median)
109    fn min_max_median_durations(&self) -> (Duration, Duration, Duration) {
110        let mut sorted = self.durations.clone();
111        sorted.sort();
112        let min = *sorted.first().unwrap();
113        let max = *sorted.last().unwrap();
114        let median = *sorted.get(sorted.len() / 2).unwrap();
115        (min, max, median)
116    }
117
118    /// Returns the median duration among all durations
119    pub(crate) fn mean_duration(&self) -> Duration {
120        self.durations.iter().sum::<Duration>() / self.durations.len() as u32
121    }
122
123    /// Returns the variance durations for the durations
124    pub(crate) fn variance_duration(&self, mean: Duration) -> Duration {
125        self.durations
126            .iter()
127            .map(|duration| {
128                let tmp = duration.as_secs_f64() - mean.as_secs_f64();
129                Duration::from_secs_f64(tmp * tmp)
130            })
131            .sum::<Duration>()
132            / self.durations.len() as u32
133    }
134}
135
136impl Display for BenchmarkDurations {
137    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
138        let computed = BenchmarkComputations::new(self);
139        let BenchmarkComputations {
140            mean,
141            median,
142            variance,
143            min,
144            max,
145        } = computed;
146        let num_sample = self.durations.len();
147        let timing_method = self.timing_method;
148
149        f.write_str(
150            format!(
151                "
152―――――――― Result ―――――――――
153  Timing      {timing_method}
154  Samples     {num_sample}
155  Mean        {mean:.3?}
156  Variance    {variance:.3?}
157  Median      {median:.3?}
158  Min         {min:.3?}
159  Max         {max:.3?}
160―――――――――――――――――――――――――"
161            )
162            .as_str(),
163        )
164    }
165}
166
/// Summary statistics derived from a set of benchmark durations.
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize, PartialEq, Eq)
)]
#[derive(Clone, Debug, Default)]
pub struct BenchmarkComputations {
    /// Mean over all durations.
    pub mean: Duration,
    /// Median over all durations.
    pub median: Duration,
    /// Variance over all durations.
    pub variance: Duration,
    /// Smallest of the durations.
    pub min: Duration,
    /// Largest of the durations.
    pub max: Duration,
}
185
186impl BenchmarkComputations {
187    /// Compute duration values and return a BenchmarkComputations struct
188    pub fn new(durations: &BenchmarkDurations) -> Self {
189        let mean = durations.mean_duration();
190        let (min, max, median) = durations.min_max_median_durations();
191        Self {
192            mean,
193            median,
194            min,
195            max,
196            variance: durations.variance_duration(mean),
197        }
198    }
199}
200
201/// Benchmark trait.
202pub trait Benchmark {
203    /// Benchmark arguments.
204    type Args: Clone;
205
206    /// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't
207    /// count as included in the duration.
208    ///
209    /// # Notes
210    ///
211    /// This should not include warmup, the benchmark will be run at least one time without
212    /// measuring the execution time.
213    fn prepare(&self) -> Self::Args;
214    /// Execute the benchmark and returns the time it took to complete.
215    fn execute(&self, args: Self::Args);
216    /// Number of samples per run required to have a statistical significance.
217    fn num_samples(&self) -> usize {
218        const DEFAULT: usize = 10;
219
220        #[cfg(feature = "std")]
221        {
222            std::env::var("BENCH_NUM_SAMPLES")
223                .map(|val| str::parse::<usize>(&val).unwrap_or(DEFAULT))
224                .unwrap_or(DEFAULT)
225        }
226
227        #[cfg(not(feature = "std"))]
228        {
229            DEFAULT
230        }
231    }
232    /// Name of the benchmark, should be short and it should match the name
233    /// defined in the crate Cargo.toml
234    fn name(&self) -> String;
235    /// The options passed to the benchmark.
236    fn options(&self) -> Option<String> {
237        None
238    }
239    /// Shapes dimensions
240    fn shapes(&self) -> Vec<Vec<usize>> {
241        vec![]
242    }
243
244    /// Wait for computation to complete.
245    fn sync(&self);
246
247    /// Start measuring the computation duration.
248    #[cfg(feature = "std")]
249    fn profile(&self, args: Self::Args) -> ProfileDuration {
250        self.profile_full(args)
251    }
252
253    /// Start measuring the computation duration. Use the full duration irregardless of whether
254    /// device duration is available or not.
255    #[cfg(feature = "std")]
256    fn profile_full(&self, args: Self::Args) -> ProfileDuration {
257        self.sync();
258        let start_time = std::time::Instant::now();
259        self.execute(args);
260        self.sync();
261        ProfileDuration::from_duration(start_time.elapsed())
262    }
263
264    /// Run the benchmark a number of times.
265    #[allow(unused_variables)]
266    fn run(&self, timing_method: TimingMethod) -> BenchmarkDurations {
267        #[cfg(not(feature = "std"))]
268        panic!("Attempting to run benchmark in a no-std environment");
269
270        #[cfg(feature = "std")]
271        {
272            // Warmup
273            let args = self.prepare();
274            for _ in 0..self.num_samples() {
275                self.execute(args.clone());
276            }
277            std::thread::sleep(Duration::from_secs(1));
278
279            let mut durations = Vec::with_capacity(self.num_samples());
280
281            for _ in 0..self.num_samples() {
282                let profile = match timing_method {
283                    TimingMethod::Full => self.profile_full(args.clone()),
284                    TimingMethod::DeviceOnly => self.profile(args.clone()),
285                };
286                let duration = crate::future::block_on(profile.resolve());
287                durations.push(duration);
288            }
289
290            BenchmarkDurations {
291                timing_method,
292                durations,
293            }
294        }
295    }
296}
297
298/// Result of a benchmark run, with metadata
299#[derive(Clone)]
300pub struct BenchmarkResult {
301    /// Individual raw results of the run
302    pub raw: BenchmarkDurations,
303    /// Computed values for the run
304    pub computed: BenchmarkComputations,
305    /// Git commit hash of the commit in which the run occurred
306    pub git_hash: String,
307    /// Name of the benchmark
308    pub name: String,
309    /// Options passed to the benchmark
310    pub options: Option<String>,
311    /// Shape dimensions
312    pub shapes: Vec<Vec<usize>>,
313    /// Time just before the run
314    pub timestamp: u128,
315}
316
317impl Display for BenchmarkResult {
318    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
319        f.write_str(
320            format!(
321                "
322        Timestamp: {}
323        Git Hash: {}
324        Benchmarking - {}{}
325        ",
326                self.timestamp, self.git_hash, self.name, self.raw
327            )
328            .as_str(),
329        )
330    }
331}
332
/// Runs the given benchmark with [`TimingMethod::Full`] and returns the result
/// together with its metadata.
///
/// The timestamp is taken before the run, in milliseconds since the Unix epoch.
/// The git hash is the trimmed output of `git rev-parse HEAD`; it is left empty
/// when `git` cannot be executed or prints nothing (for example outside of a
/// repository), so a missing `git` binary does not prevent benchmarking.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
#[cfg(feature = "std")]
pub fn run_benchmark<BM>(benchmark: BM) -> BenchmarkResult
where
    BM: Benchmark,
{
    let timestamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_millis();

    // Best effort: a failed spawn is treated like an empty output.
    let git_hash = std::process::Command::new("git")
        .args(["rev-parse", "HEAD"])
        .output()
        .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_string())
        .unwrap_or_default();

    let raw = benchmark.run(TimingMethod::Full);
    // Compute the statistics before `raw` is moved into the result, so no clone
    // of the samples is needed.
    let computed = BenchmarkComputations::new(&raw);

    BenchmarkResult {
        raw,
        computed,
        git_hash,
        name: benchmark.name(),
        options: benchmark.options(),
        shapes: benchmark.shapes(),
        timestamp,
    }
}
360
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Builds a `BenchmarkDurations` with full timing from the given samples.
    fn full_timing(durations: Vec<Duration>) -> BenchmarkDurations {
        BenchmarkDurations {
            timing_method: TimingMethod::Full,
            durations,
        }
    }

    #[test]
    fn test_min_max_median_durations_odd_number_of_samples() {
        // Five samples: the median is the single middle element.
        let durations = full_timing(vec![
            Duration::new(10, 0),
            Duration::new(20, 0),
            Duration::new(30, 0),
            Duration::new(40, 0),
            Duration::new(50, 0),
        ]);
        let (min, max, median) = durations.min_max_median_durations();
        assert_eq!(min, Duration::from_secs(10));
        assert_eq!(max, Duration::from_secs(50));
        assert_eq!(median, Duration::from_secs(30));
    }

    #[test]
    fn test_min_max_median_durations_even_number_of_samples() {
        // Four samples: the median is the upper of the two middle elements.
        let durations = full_timing(vec![
            Duration::new(18, 5),
            Duration::new(20, 0),
            Duration::new(30, 0),
            Duration::new(40, 0),
        ]);
        let (min, max, median) = durations.min_max_median_durations();
        assert_eq!(min, Duration::from_nanos(18000000005_u64));
        assert_eq!(max, Duration::from_secs(40));
        assert_eq!(median, Duration::from_secs(30));
    }

    #[test]
    fn test_mean_duration() {
        let durations = full_timing(vec![
            Duration::new(10, 0),
            Duration::new(20, 0),
            Duration::new(30, 0),
            Duration::new(40, 0),
        ]);
        let mean = durations.mean_duration();
        assert_eq!(mean, Duration::from_secs(25));
    }

    #[test]
    fn test_variance_duration() {
        // Mean is 30s; squared deviations are 400, 100, 0, 100, 400 -> 1000 / 5.
        let durations = full_timing(vec![
            Duration::new(10, 0),
            Duration::new(20, 0),
            Duration::new(30, 0),
            Duration::new(40, 0),
            Duration::new(50, 0),
        ]);
        let mean = durations.mean_duration();
        let variance = durations.variance_duration(mean);
        assert_eq!(variance, Duration::from_secs(200));
    }
}