Skip to main content

cubecl_runtime/tune/
tune_cache.rs

1#[cfg(std_io)]
2use std::vec::Vec;
3
4#[cfg(std_io)]
5use cubecl_common::cache::Cache;
6#[cfg(std_io)]
7use cubecl_common::cache::CacheError;
8#[cfg(std_io)]
9use serde::{Deserialize, Serialize};
10
11use super::{AutotuneError, AutotuneKey, AutotuneOutcome};
12use alloc::string::String;
13use hashbrown::HashMap;
14
15#[derive(Debug)]
16pub(crate) enum CacheEntry {
17    Done {
18        checksum: ChecksumState,
19        fastest_index: usize,
20    },
21    Pending,
22}
23
/// Verification status of the checksum attached to a finished cache entry.
#[derive(Debug)]
#[allow(dead_code)] // Some variants are not created when the cache isn't saved.
pub(crate) enum ChecksumState {
    /// The entry is usable: the checksum was compared and matched, or the entry was
    /// inserted directly into the in-memory cache.
    Match,
    /// The checksum was compared and differs; the entry must not be used.
    NoMatch,
    /// The checksum (stored here) has not been compared yet.
    ToBeVerified(String),
}
31
/// Key under which an autotune result is stored in the persistent cache.
///
/// The autotune key is paired with a checksum string, so two entries with the same
/// autotune key but different checksums are distinct persistent keys.
#[cfg(std_io)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Hash)]
pub(crate) struct PersistentCacheKey<K> {
    /// The autotune key.
    key: K,
    /// Checksum stored alongside the key.
    checksum: String,
}
39
/// Value stored in the persistent cache for one [`PersistentCacheKey`].
#[cfg(std_io)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub(crate) struct PersistentCacheValue {
    /// Index of the fastest operation.
    fastest_index: usize,
    /// The autotune results saved together with the fastest index.
    results: Vec<AutotuneResult>,
}
47
48#[cfg_attr(std_io, derive(Serialize, Deserialize))]
49#[derive(Debug, Clone)]
50/// The result of an autotune job.
51pub struct AutotuneResult {
52    pub(crate) outcome: Result<AutotuneOutcome, AutotuneError>,
53}
54
55impl AutotuneResult {
56    pub(crate) fn error(error: AutotuneError) -> Self {
57        Self {
58            outcome: Err(error),
59        }
60    }
61    pub(crate) fn success(outcome: AutotuneOutcome) -> Self {
62        Self {
63            outcome: Ok(outcome),
64        }
65    }
66}
67
68impl Eq for AutotuneResult {}
69impl PartialEq for AutotuneResult {
70    fn eq(&self, other: &Self) -> bool {
71        match (&self.outcome, &other.outcome) {
72            (Ok(lhs), Ok(rhs)) => lhs == rhs,
73            (Ok(_), Err(_)) => false,
74            (Err(_), Ok(_)) => false,
75            // We don't have to check the error
76            (Err(_), Err(_)) => true,
77        }
78    }
79}
80
81/// Use to find and reuse the best kernel for some input
82#[derive(Debug)]
83pub(crate) struct TuneCache<K> {
84    in_memory_cache: HashMap<K, CacheEntry>,
85    #[cfg(std_io)]
86    persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
87}
88
/// Result of a cache lookup.
#[derive(Debug)]
pub enum TuneCacheResult {
    /// An operation is found.
    Hit {
        /// The index of the fastest operation to execute.
        fastest_index: usize,
    },
    /// The operation might be cached, but we don't know yet whether the checksum is valid.
    Unchecked,
    /// A tuning job is in flight for this key — the worker hasn't published a result yet.
    ///
    /// This variant carries no handle to the job.
    // NOTE(review): earlier wording described a receiver that native callers `block_on` and
    // wasm callers drop; no such receiver is part of this variant — confirm how callers wait.
    Pending,
    /// No usable operation is found yet.
    Miss,
}
106
107impl<K: AutotuneKey> TuneCache<K> {
108    pub(crate) fn new(
109        #[cfg_attr(not(std_io), allow(unused_variables))] name: &str,
110        #[cfg_attr(not(std_io), allow(unused_variables))] device_id: &str,
111    ) -> Self {
112        #[cfg(std_io)]
113        {
114            use crate::config::RuntimeConfig;
115            use std::format;
116
117            let root = crate::config::CubeClRuntimeConfig::get()
118                .autotune
119                .cache
120                .root();
121            let options = cubecl_common::cache::CacheOption::default();
122            let mut cache = TuneCache {
123                in_memory_cache: HashMap::new(),
124                persistent_cache: Cache::new(
125                    format!("{device_id}/{name}"),
126                    options.root(root).name("autotune"),
127                ),
128            };
129            cache.load();
130            cache
131        }
132
133        #[cfg(not(std_io))]
134        {
135            TuneCache {
136                in_memory_cache: HashMap::new(),
137            }
138        }
139    }
140
141    pub fn fastest(&self, key: &K) -> TuneCacheResult {
142        let Some(val) = self.in_memory_cache.get(key) else {
143            return TuneCacheResult::Miss;
144        };
145
146        let CacheEntry::Done {
147            checksum,
148            fastest_index,
149        } = val
150        else {
151            // Pending: clone the receiver so the caller can subscribe to the in-flight tune.
152            let CacheEntry::Pending = val else {
153                unreachable!()
154            };
155            return TuneCacheResult::Pending;
156        };
157
158        if cfg!(std_io) {
159            match checksum {
160                ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, // Don't know yet.
161                ChecksumState::NoMatch => TuneCacheResult::Miss,               // Can't use this.
162                ChecksumState::Match => TuneCacheResult::Hit {
163                    fastest_index: *fastest_index,
164                },
165            }
166        } else {
167            // Clippy;
168            let _ = checksum;
169            TuneCacheResult::Hit {
170                fastest_index: *fastest_index,
171            }
172        }
173    }
174
175    #[cfg(std_io)]
176    pub fn validate_checksum(&mut self, key: &K, checksum: &str) -> TuneCacheResult {
177        let Some(val) = self.in_memory_cache.get_mut(key) else {
178            return TuneCacheResult::Miss;
179        };
180
181        if let CacheEntry::Done {
182            checksum: checksum_state,
183            ..
184        } = val
185            && let ChecksumState::ToBeVerified(checksum_expected) = checksum_state
186        {
187            if checksum_expected == checksum {
188                *checksum_state = ChecksumState::Match;
189            } else {
190                *checksum_state = ChecksumState::NoMatch;
191            }
192        }
193
194        self.fastest(key)
195    }
196
197    /// Mark a key as being tuned. Used by [`Tuner::tune`] under the cache mutex so that
198    /// concurrent callers see [`TuneCacheResult::Pending`] and wait on the same job instead of
199    /// starting a second one. Returns `(Sender, Receiver)`:
200    pub(crate) fn mark_pending(&mut self, key: K) {
201        self.in_memory_cache.insert(key, CacheEntry::Pending);
202    }
203
204    pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
205        self.in_memory_cache.insert(
206            key,
207            CacheEntry::Done {
208                checksum: ChecksumState::Match,
209                fastest_index,
210            },
211        );
212    }
213}
214
#[cfg(std_io)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Saves an autotune result in the persistent cache under `(key, checksum)`.
    ///
    /// Insertion errors never propagate: a duplicated key is reported with a warning log,
    /// and an out-of-sync key is silently accepted.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<AutotuneResult>,
    ) {
        let persistent_key = PersistentCacheKey { key, checksum };
        let persistent_value = PersistentCacheValue {
            fastest_index,
            results,
        };

        let Err(err) = self.persistent_cache.insert(persistent_key, persistent_value) else {
            return;
        };

        match err {
            CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            } => {
                log::warn!(
                    "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
                );
            }
            // This is OK.
            CacheError::KeyOutOfSync { .. } => {}
        }
    }

    /// Load the persistent cache data into the in-memory cache.
    ///
    /// Every persistent entry becomes a finished in-memory entry whose checksum still has
    /// to be verified. The number of visited entries is logged.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");

        let Self {
            in_memory_cache,
            persistent_cache,
        } = self;

        let mut loaded = 0;
        persistent_cache.for_each(|key, value| {
            let entry = CacheEntry::Done {
                checksum: ChecksumState::ToBeVerified(key.checksum.clone()),
                fastest_index: value.fastest_index,
            };
            in_memory_cache.insert(key.key.clone(), entry);
            loaded += 1;
        });

        log::info!("Loaded {loaded} autotune cached entries");
    }
}