// cubecl_runtime/tune/tune_cache.rs

1#[cfg(std_io)]
2use cubecl_common::cache::Cache;
3#[cfg(std_io)]
4use cubecl_common::cache::CacheError;
5#[cfg(std_io)]
6use serde::{Deserialize, Serialize};
7
8use super::{AutotuneError, AutotuneKey, AutotuneOutcome};
9use alloc::string::String;
10use hashbrown::HashMap;
11
12/// In-memory cache entry
13#[derive(Debug)]
14pub(crate) enum CacheEntry {
15    Done {
16        checksum: ChecksumState,
17        fastest_index: usize,
18    },
19    Pending,
20}
21
/// Verification status of a cached entry's checksum.
#[derive(Debug)]
// `ToBeVerified` and `NoMatch` are only constructed when the persistent cache (`std_io`) is enabled.
#[allow(dead_code)]
pub(crate) enum ChecksumState {
    /// The checksum is valid; the entry can be used.
    Match,
    /// The checksum did not match; the entry must not be used.
    NoMatch,
    /// The entry carries this stored checksum, which has not been compared yet.
    ToBeVerified(String),
}

/// Key of the persistent autotune cache.
///
/// The checksum is part of the key, so the same autotune key paired with different
/// checksums yields distinct persistent keys.
#[cfg(std_io)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Hash)]
pub(crate) struct PersistentCacheKey<K> {
    /// The autotune key.
    key: K,
    /// Checksum stored alongside the result.
    checksum: String,
}

/// Value of the persistent autotune cache.
#[cfg(std_io)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub(crate) struct PersistentCacheValue {
    /// Index of the fastest operation.
    fastest_index: usize,
    /// Results of the autotune job that selected `fastest_index`.
    results: Vec<AutotuneResult>,
}

46#[cfg_attr(std_io, derive(Serialize, Deserialize))]
47#[derive(Debug, Clone)]
48/// The result of an autotune job.
49pub struct AutotuneResult {
50    pub(crate) outcome: Result<AutotuneOutcome, AutotuneError>,
51}
52
53impl AutotuneResult {
54    pub(crate) fn error(error: AutotuneError) -> Self {
55        Self {
56            outcome: Err(error),
57        }
58    }
59    pub(crate) fn success(outcome: AutotuneOutcome) -> Self {
60        Self {
61            outcome: Ok(outcome),
62        }
63    }
64}
65
66impl Eq for AutotuneResult {}
67impl PartialEq for AutotuneResult {
68    fn eq(&self, other: &Self) -> bool {
69        match (&self.outcome, &other.outcome) {
70            (Ok(lhs), Ok(rhs)) => lhs == rhs,
71            (Ok(_), Err(_)) => false,
72            (Err(_), Ok(_)) => false,
73            // We don't have to check the error
74            (Err(_), Err(_)) => true,
75        }
76    }
77}
78
79/// Use to find and reuse the best kernel for some input
80#[derive(Debug)]
81pub(crate) struct TuneCache<K> {
82    in_memory_cache: HashMap<K, CacheEntry>,
83    #[cfg(std_io)]
84    persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
85}
86
/// Result of a tune cache lookup.
#[derive(Debug)]
pub enum TuneCacheResult {
    /// An operation is found.
    Hit {
        /// The index of the fastest operation to execute.
        fastest_index: usize,
    },
    /// The operation might be cached, but it is not yet known whether its checksum is valid.
    Unchecked,
    /// The fastest operation is not known yet, but a result is expected to come in.
    Pending,
    /// No usable operation is cached yet.
    Miss,
}

103impl<K: AutotuneKey> TuneCache<K> {
104    pub(crate) fn new(
105        #[cfg_attr(not(std_io), allow(unused_variables))] name: &str,
106        #[cfg_attr(not(std_io), allow(unused_variables))] device_id: &str,
107    ) -> Self {
108        #[cfg(std_io)]
109        {
110            let root = crate::config::GlobalConfig::get().autotune.cache.root();
111            let options = cubecl_common::cache::CacheOption::default();
112            let mut cache = TuneCache {
113                in_memory_cache: HashMap::new(),
114                persistent_cache: Cache::new(
115                    format!("{device_id}/{name}"),
116                    options.root(root).name("autotune"),
117                ),
118            };
119            cache.load();
120            cache
121        }
122
123        #[cfg(not(std_io))]
124        {
125            TuneCache {
126                in_memory_cache: HashMap::new(),
127            }
128        }
129    }
130
131    pub fn fastest(&self, key: &K) -> TuneCacheResult {
132        let result = self.in_memory_cache.get(key);
133
134        let Some(val) = result else {
135            return TuneCacheResult::Miss;
136        };
137
138        match val {
139            CacheEntry::Done {
140                checksum,
141                fastest_index,
142            } => {
143                if cfg!(std_io) {
144                    match checksum {
145                        ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, // Don't know yet.
146                        ChecksumState::NoMatch => TuneCacheResult::Miss, // Can't use this.
147                        ChecksumState::Match => TuneCacheResult::Hit {
148                            fastest_index: *fastest_index,
149                        },
150                    }
151                } else {
152                    // Clippy;
153                    let _ = checksum;
154                    TuneCacheResult::Hit {
155                        fastest_index: *fastest_index,
156                    }
157                }
158            }
159            CacheEntry::Pending => TuneCacheResult::Pending,
160        }
161    }
162
163    #[cfg(std_io)]
164    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
165        let result = self.in_memory_cache.get_mut(key);
166        let Some(val) = result else {
167            return;
168        };
169
170        if let CacheEntry::Done {
171            checksum: checksum_state,
172            ..
173        } = val
174            && let ChecksumState::ToBeVerified(checksum_expected) = checksum_state
175        {
176            if checksum_expected == checksum {
177                *checksum_state = ChecksumState::Match;
178            } else {
179                *checksum_state = ChecksumState::NoMatch;
180            }
181        }
182    }
183
184    #[allow(unused)]
185    pub(crate) fn mark_pending(&mut self, key: K) {
186        self.in_memory_cache.insert(key, CacheEntry::Pending);
187    }
188
189    pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
190        self.in_memory_cache.insert(
191            key,
192            CacheEntry::Done {
193                checksum: ChecksumState::Match,
194                fastest_index,
195            },
196        );
197    }
198}
199
#[cfg(std_io)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Records an autotune result in the persistent cache.
    ///
    /// The persistent key combines `key` and `checksum`, so different checksums for the same
    /// `key` produce distinct persistent keys.
    ///
    /// Insertion errors are never propagated; this is best-effort by design:
    /// - `DuplicatedKey` is logged as a warning, with the previous and new values;
    /// - `KeyOutOfSync` is ignored.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<AutotuneResult>,
    ) {
        let entry_key = PersistentCacheKey { key, checksum };
        let entry_value = PersistentCacheValue {
            fastest_index,
            results,
        };

        let Err(err) = self.persistent_cache.insert(entry_key, entry_value) else {
            return;
        };

        match err {
            CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            } => {
                log::warn!(
                    "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
                );
            }
            CacheError::KeyOutOfSync { .. } => {
                // This is OK.
                // NOTE(review): treated as benign by the original author; the exact cause is
                // not visible here — confirm.
            }
        }
    }

    /// Loads every persistent cache entry into the in-memory cache.
    ///
    /// Loaded entries are marked [`ChecksumState::ToBeVerified`] with their stored checksum.
    /// [`Self::fastest`] reports them as [`TuneCacheResult::Unchecked`] until
    /// [`Self::validate_checksum`] resolves them.
    ///
    /// NOTE(review): the in-memory map is keyed by `K` alone, while the persistent cache is
    /// keyed by `(K, checksum)`. If the persistent cache holds the same `K` under several
    /// checksums, only the entry visited last by `for_each` survives. That entry may be a stale
    /// one — confirm whether this can happen in practice.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");
        let mut loaded: usize = 0;
        self.persistent_cache.for_each(|key, value| {
            loaded += 1;
            self.in_memory_cache.insert(
                key.key.clone(),
                CacheEntry::Done {
                    checksum: ChecksumState::ToBeVerified(key.checksum.clone()),
                    fastest_index: value.fastest_index,
                },
            );
        });
        log::info!("Loaded {loaded} autotune cached entries");
    }
}