burn_compute/tune/
tuner.rs

1use core::marker::PhantomData;
2#[cfg(target_family = "wasm")]
3use web_time::Duration;
4
5#[cfg(not(target_family = "wasm"))]
6use core::time::Duration;
7
8use alloc::boxed::Box;
9use alloc::string::ToString;
10use alloc::vec::Vec;
11use burn_common::benchmark::{Benchmark, BenchmarkComputations, BenchmarkDurations};
12
13use crate::channel::ComputeChannel;
14use crate::client::ComputeClient;
15use crate::server::ComputeServer;
16use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCache};
17
18#[derive(Debug)]
19/// Executes autotune benchmarking and caching
20pub struct Tuner<S: ComputeServer, C> {
21    tune_cache: TuneCache<S::AutotuneKey>,
22    _channel: PhantomData<C>,
23}
24
25#[allow(clippy::new_without_default)]
26impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
27    /// Returns a tuner with cache initialized from persistent cache
28    pub fn new(device_id: &str) -> Self {
29        Self {
30            tune_cache: TuneCache::new(device_id),
31            _channel: PhantomData,
32        }
33    }
34
35    pub(crate) fn autotune_fastest(&self, key: &S::AutotuneKey) -> Option<usize> {
36        self.tune_cache.find_fastest(key)
37    }
38
39    pub(crate) fn execute_autotune(
40        &mut self,
41        autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
42        client: &ComputeClient<S, C>,
43    ) {
44        let operation = match self.tune_cache.try_cache(autotune_operation_set) {
45            super::TuneCacheResult::Hit(ops) => ops,
46            super::TuneCacheResult::Miss(set) => self.autotuning(set, client),
47        };
48
49        AutotuneOperation::execute(operation);
50    }
51
52    fn autotuning(
53        &mut self,
54        autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
55        client: &ComputeClient<S, C>,
56    ) -> Box<dyn AutotuneOperation> {
57        let key = autotune_operation_set.key();
58        let autotunables = autotune_operation_set.autotunables();
59        let mut names = Vec::with_capacity(autotunables.len());
60
61        let results: Vec<BenchmarkDurations> = autotunables
62            .into_iter()
63            .map(|op| {
64                names.push(op.name().to_string());
65                self.run_benchmark(op, client)
66            })
67            .collect();
68
69        // Finds the fastest operation, stores it and returns it
70        let fastest_index = self.find_fastest(results);
71        let fastest_name = names.get(fastest_index).unwrap();
72        log::info!("Fastest result {fastest_name}-{key}");
73
74        self.tune_cache.cache_insert(key.clone(), fastest_index);
75        #[cfg(feature = "autotune-persistent-cache")]
76        {
77            let checksum = autotune_operation_set.compute_checksum();
78            self.tune_cache
79                .persistent_cache_insert(key, checksum, fastest_index);
80            self.tune_cache.save();
81        }
82
83        match self.tune_cache.try_cache(autotune_operation_set) {
84            super::TuneCacheResult::Hit(ops) => ops,
85            super::TuneCacheResult::Miss(_) => panic!("We just inserted, should not miss"),
86        }
87    }
88
89    fn run_benchmark(
90        &mut self,
91        operation: Box<dyn AutotuneOperation>,
92        client: &ComputeClient<S, C>,
93    ) -> BenchmarkDurations {
94        TuneBenchmark::new(operation, client.clone()).run()
95    }
96
97    fn find_fastest(&self, results: Vec<BenchmarkDurations>) -> usize {
98        let mut smallest_duration = Duration::MAX;
99        let mut fastest_tunable = None;
100
101        for (i, result) in results.into_iter().enumerate() {
102            let computed = BenchmarkComputations::new(&result);
103
104            if computed.median < smallest_duration {
105                smallest_duration = computed.median;
106                fastest_tunable = Some(i);
107            }
108        }
109
110        fastest_tunable.expect("At least one kernel needed. ")
111    }
112}