// burn_compute/tune/tuner.rs
use core::marker::PhantomData;
2#[cfg(target_family = "wasm")]
3use web_time::Duration;
4
5#[cfg(not(target_family = "wasm"))]
6use core::time::Duration;
7
8use alloc::boxed::Box;
9use alloc::string::ToString;
10use alloc::vec::Vec;
11use burn_common::benchmark::{Benchmark, BenchmarkComputations, BenchmarkDurations};
12
13use crate::channel::ComputeChannel;
14use crate::client::ComputeClient;
15use crate::server::ComputeServer;
16use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCache};
17
18#[derive(Debug)]
19pub struct Tuner<S: ComputeServer, C> {
21 tune_cache: TuneCache<S::AutotuneKey>,
22 _channel: PhantomData<C>,
23}
24
25#[allow(clippy::new_without_default)]
26impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
27 pub fn new(device_id: &str) -> Self {
29 Self {
30 tune_cache: TuneCache::new(device_id),
31 _channel: PhantomData,
32 }
33 }
34
35 pub(crate) fn autotune_fastest(&self, key: &S::AutotuneKey) -> Option<usize> {
36 self.tune_cache.find_fastest(key)
37 }
38
39 pub(crate) fn execute_autotune(
40 &mut self,
41 autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
42 client: &ComputeClient<S, C>,
43 ) {
44 let operation = match self.tune_cache.try_cache(autotune_operation_set) {
45 super::TuneCacheResult::Hit(ops) => ops,
46 super::TuneCacheResult::Miss(set) => self.autotuning(set, client),
47 };
48
49 AutotuneOperation::execute(operation);
50 }
51
52 fn autotuning(
53 &mut self,
54 autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
55 client: &ComputeClient<S, C>,
56 ) -> Box<dyn AutotuneOperation> {
57 let key = autotune_operation_set.key();
58 let autotunables = autotune_operation_set.autotunables();
59 let mut names = Vec::with_capacity(autotunables.len());
60
61 let results: Vec<BenchmarkDurations> = autotunables
62 .into_iter()
63 .map(|op| {
64 names.push(op.name().to_string());
65 self.run_benchmark(op, client)
66 })
67 .collect();
68
69 let fastest_index = self.find_fastest(results);
71 let fastest_name = names.get(fastest_index).unwrap();
72 log::info!("Fastest result {fastest_name}-{key}");
73
74 self.tune_cache.cache_insert(key.clone(), fastest_index);
75 #[cfg(feature = "autotune-persistent-cache")]
76 {
77 let checksum = autotune_operation_set.compute_checksum();
78 self.tune_cache
79 .persistent_cache_insert(key, checksum, fastest_index);
80 self.tune_cache.save();
81 }
82
83 match self.tune_cache.try_cache(autotune_operation_set) {
84 super::TuneCacheResult::Hit(ops) => ops,
85 super::TuneCacheResult::Miss(_) => panic!("We just inserted, should not miss"),
86 }
87 }
88
89 fn run_benchmark(
90 &mut self,
91 operation: Box<dyn AutotuneOperation>,
92 client: &ComputeClient<S, C>,
93 ) -> BenchmarkDurations {
94 TuneBenchmark::new(operation, client.clone()).run()
95 }
96
97 fn find_fastest(&self, results: Vec<BenchmarkDurations>) -> usize {
98 let mut smallest_duration = Duration::MAX;
99 let mut fastest_tunable = None;
100
101 for (i, result) in results.into_iter().enumerate() {
102 let computed = BenchmarkComputations::new(&result);
103
104 if computed.median < smallest_duration {
105 smallest_duration = computed.median;
106 fastest_tunable = Some(i);
107 }
108 }
109
110 fastest_tunable.expect("At least one kernel needed. ")
111 }
112}