// cubecl_runtime/tune/tune_cache.rs
use alloc::string::String;
use hashbrown::HashMap;

#[cfg(std_io)]
use cubecl_common::cache::{Cache, CacheError};
#[cfg(std_io)]
use serde::{Deserialize, Serialize};

use super::AutotuneKey;
#[cfg(std_io)]
use super::{AutotuneError, AutotuneOutcome};

16#[derive(Debug)]
18pub(crate) enum CacheEntry {
19 Done {
20 checksum: ChecksumState,
21 fastest_index: usize,
22 },
23 Pending,
24}
25
/// Validation status of the checksum attached to a finished cache entry.
#[derive(Debug)]
// Without `std_io` nothing is loaded or validated, so `NoMatch` and `ToBeVerified` are never
// constructed (and the `ToBeVerified` payload is never read) in that configuration.
#[allow(dead_code)]
pub(crate) enum ChecksumState {
    /// Considered valid: inserted by `TuneCache::cache_insert` or confirmed by
    /// `TuneCache::validate_checksum`.
    Match,
    /// `TuneCache::validate_checksum` found a different checksum than the stored one.
    NoMatch,
    /// Loaded from the persistent cache; holds the stored checksum, which has not been
    /// compared yet.
    ToBeVerified(String),
}

/// Key under which an autotune result is written to the persistent cache.
///
/// Field order is part of the serialized format; do not reorder.
#[cfg(std_io)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Hash)]
pub(crate) struct PersistentCacheKey<K> {
    /// The autotune key the result belongs to.
    key: K,
    /// Checksum stored with the entry; `TuneCache::load` turns it into
    /// `ChecksumState::ToBeVerified`.
    checksum: String,
}

/// Value written to the persistent cache for each autotuned key.
///
/// Field order is part of the serialized format; do not reorder.
#[cfg(std_io)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub(crate) struct PersistentCacheValue {
    /// Index of the fastest operation; restored into memory by `TuneCache::load`.
    fastest_index: usize,
    /// Autotune outcomes stored alongside the chosen index (not read back by `load`).
    results: Vec<Result<AutotuneOutcome, AutotuneError>>,
}

50#[derive(Debug)]
52pub(crate) struct TuneCache<K> {
53 in_memory_cache: HashMap<K, CacheEntry>,
54 #[cfg(std_io)]
55 persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
56}
57
/// Outcome of looking up a key with `TuneCache::fastest`.
///
/// The enum only carries a `usize`, so it is cheap to copy and compare; callers can match on
/// it or test it directly with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TuneCacheResult {
    /// A result is available and its checksum is considered valid.
    Hit {
        /// Index of the fastest operation for the key.
        fastest_index: usize,
    },
    /// A result was loaded from the persistent cache but its checksum has not been
    /// validated yet.
    Unchecked,
    /// The key is marked as being tuned; no result is available yet.
    Pending,
    /// No usable result: the key is unknown, or its stored checksum did not match.
    Miss,
}

74impl<K: AutotuneKey> TuneCache<K> {
75 pub(crate) fn new(
76 #[cfg_attr(not(std_io), allow(unused_variables))] name: &str,
77 #[cfg_attr(not(std_io), allow(unused_variables))] device_id: &str,
78 ) -> Self {
79 #[cfg(std_io)]
80 {
81 let root = crate::config::GlobalConfig::get().autotune.cache.root();
82 let options = cubecl_common::cache::CacheOption::default();
83 let mut cache = TuneCache {
84 in_memory_cache: HashMap::new(),
85 persistent_cache: Cache::new(
86 format!("{device_id}/{name}"),
87 options.root(root).name("autotune"),
88 ),
89 };
90 cache.load();
91 cache
92 }
93
94 #[cfg(not(std_io))]
95 {
96 TuneCache {
97 in_memory_cache: HashMap::new(),
98 }
99 }
100 }
101
102 pub fn fastest(&self, key: &K) -> TuneCacheResult {
103 let result = self.in_memory_cache.get(key);
104
105 let Some(val) = result else {
106 return TuneCacheResult::Miss;
107 };
108
109 match val {
110 CacheEntry::Done {
111 checksum,
112 fastest_index,
113 } => {
114 if cfg!(std_io) {
115 match checksum {
116 ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, ChecksumState::NoMatch => TuneCacheResult::Miss, ChecksumState::Match => TuneCacheResult::Hit {
119 fastest_index: *fastest_index,
120 },
121 }
122 } else {
123 let _ = checksum;
125 TuneCacheResult::Hit {
126 fastest_index: *fastest_index,
127 }
128 }
129 }
130 CacheEntry::Pending => TuneCacheResult::Pending,
131 }
132 }
133
134 #[cfg(std_io)]
135 pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
136 let result = self.in_memory_cache.get_mut(key);
137 let Some(val) = result else {
138 return;
139 };
140
141 if let CacheEntry::Done {
142 checksum: checksum_state,
143 ..
144 } = val
145 && let ChecksumState::ToBeVerified(checksum_expected) = checksum_state
146 {
147 if checksum_expected == checksum {
148 *checksum_state = ChecksumState::Match;
149 } else {
150 *checksum_state = ChecksumState::NoMatch;
151 }
152 }
153 }
154
155 #[allow(unused)]
156 pub(crate) fn mark_pending(&mut self, key: K) {
157 self.in_memory_cache.insert(key, CacheEntry::Pending);
158 }
159
160 pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
161 self.in_memory_cache.insert(
162 key,
163 CacheEntry::Done {
164 checksum: ChecksumState::Match,
165 fastest_index,
166 },
167 );
168 }
169}
170
#[cfg(std_io)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Writes an autotune result for `key` to the persistent cache.
    ///
    /// Errors are not propagated: a duplicated key is logged as a warning, and
    /// `CacheError::KeyOutOfSync` is deliberately ignored.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<Result<AutotuneOutcome, AutotuneError>>,
    ) {
        let entry_key = PersistentCacheKey { key, checksum };
        let entry_value = PersistentCacheValue {
            fastest_index,
            results,
        };

        match self.persistent_cache.insert(entry_key, entry_value) {
            Ok(_) => {}
            Err(CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            }) => {
                log::warn!(
                    "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
                );
            }
            Err(CacheError::KeyOutOfSync { .. }) => {
                // Intentionally ignored: persisting is best-effort.
                // NOTE(review): exact meaning of `KeyOutOfSync` lives in
                // `cubecl_common::cache` — confirm it is always safe to drop.
            }
        }
    }

    /// Copies every persisted entry into the in-memory cache.
    ///
    /// Each entry is inserted as `CacheEntry::Done` with `ChecksumState::ToBeVerified`
    /// holding its stored checksum, so `fastest` reports it as `Unchecked` until
    /// `validate_checksum` resolves it. Logs the number of entries loaded.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");
        let in_memory = &mut self.in_memory_cache;
        let mut loaded = 0;
        self.persistent_cache.for_each(|stored_key, stored_value| {
            let entry = CacheEntry::Done {
                checksum: ChecksumState::ToBeVerified(stored_key.checksum.clone()),
                fastest_index: stored_value.fastest_index,
            };
            in_memory.insert(stored_key.key.clone(), entry);
            loaded += 1;
        });
        log::info!("Loaded {loaded} autotune cached entries");
    }
}