cubecl_runtime/tune/
local.rs1use super::{AutotuneKey, AutotuneOutput, TunableSet, Tuner};
2use crate::{
3 channel::ComputeChannel, client::ComputeClient, server::ComputeServer, tune::TuneCacheResult,
4};
5use core::{fmt::Display, hash::Hash};
6use hashbrown::HashMap;
7
8#[cfg(not(feature = "std"))]
9use alloc::string::ToString;
10
11pub struct LocalTuner<AK: AutotuneKey, ID> {
14 state: spin::RwLock<Option<HashMap<ID, Tuner<AK>>>>,
15 name: &'static str,
16}
17
/// Create a [`LocalTuner`] named after the module that invokes the macro.
///
/// * `local_tuner!()` uses `module_path!()` as the name.
/// * `local_tuner!("suffix")` uses `module_path!()`, a `-`, then the suffix.
///
/// The expansion is a plain expression (no trailing `;`), so the macro can be
/// used directly as an initializer, e.g. `static TUNER: LocalTuner<K, I> = local_tuner!();`.
/// `LocalTuner` is referenced unqualified and must be in scope at the call site.
#[macro_export]
macro_rules! local_tuner {
    ($name:expr) => {
        LocalTuner::new(concat!(module_path!(), "-", $name))
    };
    () => {
        LocalTuner::new(module_path!())
    };
}
28
29pub use local_tuner;
30
31impl<AK: AutotuneKey + 'static, ID: Hash + PartialEq + Eq + Clone + Display> LocalTuner<AK, ID> {
32 pub const fn new(name: &'static str) -> Self {
34 Self {
35 state: spin::RwLock::new(None),
36 name,
37 }
38 }
39
40 pub fn clear(&self) {
42 let mut state = self.state.write();
43 *state = None;
44 }
45
46 #[cfg(feature = "autotune-checks")]
47 fn checks<In: Send + Clone + 'static, Out: AutotuneOutput>(
48 &self,
49 operations: &TunableSet<AK, In, Out>,
50 inputs: &In,
51 ) {
52 let mut checks_outputs = Vec::new();
53 for i in 0..operations.len() {
54 let op = operations.fastest(i);
55 let result = op.execute(inputs.clone());
56 checks_outputs.push(result);
57 }
58 super::check_autotune_outputs(checks_outputs);
59 }
60
61 pub fn execute<S, C, In: Send + Clone + 'static, Out: AutotuneOutput>(
63 &self,
64 id: &ID,
65 client: &ComputeClient<S, C>,
66 operations: &TunableSet<AK, In, Out>,
67 inputs: In,
68 ) -> Out
69 where
70 S: ComputeServer + 'static,
71 C: ComputeChannel<S> + 'static,
72 {
73 let key = operations.generate_key(&inputs);
74
75 if let Some(map) = self.state.read().as_ref() {
77 if let Some(tuner) = map.get(id) {
78 if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
79 #[cfg(feature = "autotune-checks")]
80 self.checks(operations, &inputs);
81
82 let op = operations.fastest(fastest_index);
83 let result = op
84 .execute(inputs)
85 .expect("Should run when selected by autotune.");
86
87 return result;
88 }
89 }
90 }
91
92 let (fastest, run_autotune) = {
95 let mut state = self.state.write();
96 let map = state.get_or_insert_with(Default::default);
97 let tuner = map.entry(id.clone()).or_insert_with(move || {
98 let name = self.name.replace("::", "-");
99 Tuner::new(&name, &id.to_string())
100 });
101
102 #[allow(unused_mut)]
103 let mut fastest = tuner.fastest(&key);
104
105 #[cfg(autotune_persistent_cache)]
107 if matches!(fastest, TuneCacheResult::Unchecked) {
108 let checksum = operations.compute_checksum();
109 tuner.validate_checksum(&key, &checksum);
110 fastest = tuner.fastest(&key);
111 }
112 let mut run_autotune = false;
113
114 if matches!(fastest, TuneCacheResult::Miss) && !tuner.autotuning.contains(&key) {
115 tuner.autotuning.insert(key.clone());
116 run_autotune = true;
117 }
118 (fastest, run_autotune)
119 };
120
121 match fastest {
122 TuneCacheResult::Hit { fastest_index } => {
123 #[cfg(feature = "autotune-checks")]
124 self.checks(operations, &inputs);
125
126 return operations
127 .fastest(fastest_index)
128 .execute(inputs)
129 .expect("Should run when selected by autotune.");
130 }
131 TuneCacheResult::Miss => {
132 if run_autotune {
133 let state = self.state.read();
147 let tuner = state
148 .as_ref()
149 .and_then(|s| s.get(id))
150 .expect("Should be initialized");
151 tuner.execute_autotune(key.clone(), &inputs, operations, client);
152 } else {
153 }
155 }
156 TuneCacheResult::Pending => {
157 }
159 TuneCacheResult::Unchecked => {
160 panic!("Should have checked the cache already.")
161 }
162 };
163
164 let fastest = {
165 let mut state = self.state.write();
166 let tuner = state
167 .as_mut()
168 .and_then(|s| s.get_mut(id))
169 .expect("Should be initialized");
170
171 tuner.handle_results();
173
174 match tuner.fastest(&key) {
176 TuneCacheResult::Hit { fastest_index } => {
177 fastest_index
179 }
180 TuneCacheResult::Pending => {
181 0
183 }
184 TuneCacheResult::Miss => {
185 if run_autotune {
186 panic!("Should have at least started autotuning");
187 } else {
188 0
193 }
194 }
195 TuneCacheResult::Unchecked => {
196 panic!("Should have checked the cache.")
197 }
198 }
199 };
200
201 #[cfg(feature = "autotune-checks")]
202 self.checks(operations, &inputs);
203
204 operations
205 .fastest(fastest)
206 .execute(inputs)
207 .expect("Should run when selected by autotune.")
208 }
209}