cubecl_runtime/tune/
tune_cache.rs

1#[cfg(std_io)]
2use super::AutotuneError;
3#[cfg(std_io)]
4use super::AutotuneOutcome;
5#[cfg(std_io)]
6use cubecl_common::cache::Cache;
7#[cfg(std_io)]
8use cubecl_common::cache::CacheError;
9#[cfg(std_io)]
10use serde::{Deserialize, Serialize};
11
12use super::AutotuneKey;
13use alloc::string::String;
14use hashbrown::HashMap;
15
16/// In-memory cache entry
17#[derive(Debug)]
18pub(crate) enum CacheEntry {
19    Done {
20        checksum: ChecksumState,
21        fastest_index: usize,
22    },
23    Pending,
24}
25
/// Validation state of the checksum attached to a completed cache entry.
///
/// Results produced by tuning in this process are trusted immediately (`Match`).
/// Results restored from the persistent cache carry the checksum they were saved
/// with and must be compared against the current checksum before they are used.
#[derive(Debug)]
// `NoMatch` and `ToBeVerified` are only constructed by the persistent-cache code,
// which is compiled out without `std_io`. Scope the lint exemption to that build
// so genuinely dead variants are still reported when `std_io` is enabled.
#[cfg_attr(not(std_io), allow(dead_code))]
pub(crate) enum ChecksumState {
    /// The checksum is known to be valid; the entry can be used.
    Match,
    /// The stored checksum differs from the current one; the entry must not be used.
    NoMatch,
    /// Restored with this checksum; not yet compared against the current one.
    ToBeVerified(String),
}
33
/// Key under which an autotune result is stored in the persistent cache.
///
/// The checksum is part of the key, so results recorded under different
/// checksums for the same autotune key are stored as separate entries.
#[cfg(std_io)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub(crate) struct PersistentCacheKey<K> {
    /// The autotune key this result belongs to.
    key: K,
    /// Checksum the result was recorded with.
    checksum: String,
}
41
/// Value stored in the persistent cache for one `(key, checksum)` pair.
#[cfg(std_io)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub(crate) struct PersistentCacheValue {
    /// Index of the fastest operation found when this entry was recorded.
    fastest_index: usize,
    /// Per-operation autotune outcomes recorded alongside `fastest_index`.
    results: Vec<Result<AutotuneOutcome, AutotuneError>>,
}
49
50/// Use to find and reuse the best kernel for some input
51#[derive(Debug)]
52pub(crate) struct TuneCache<K> {
53    in_memory_cache: HashMap<K, CacheEntry>,
54    #[cfg(std_io)]
55    persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
56}
57
/// Outcome of looking up an autotune key in the tune cache.
#[derive(Debug)]
pub enum TuneCacheResult {
    /// A usable result exists for the key.
    Hit {
        /// Index of the fastest operation, which should be executed.
        fastest_index: usize,
    },
    /// A result may exist for the key, but its checksum has not been validated
    /// yet, so it cannot be trusted until that check happens.
    Unchecked,
    /// Tuning for this key is in progress; its result has not arrived yet.
    Pending,
    /// No usable result exists for the key (never tuned, or its checksum did not match).
    Miss,
}
73
74impl<K: AutotuneKey> TuneCache<K> {
75    pub(crate) fn new(
76        #[cfg_attr(not(std_io), allow(unused_variables))] name: &str,
77        #[cfg_attr(not(std_io), allow(unused_variables))] device_id: &str,
78    ) -> Self {
79        #[cfg(std_io)]
80        {
81            let root = crate::config::GlobalConfig::get().autotune.cache.root();
82            let options = cubecl_common::cache::CacheOption::default();
83            let mut cache = TuneCache {
84                in_memory_cache: HashMap::new(),
85                persistent_cache: Cache::new(
86                    format!("{device_id}/{name}"),
87                    options.root(root).name("autotune"),
88                ),
89            };
90            cache.load();
91            cache
92        }
93
94        #[cfg(not(std_io))]
95        {
96            TuneCache {
97                in_memory_cache: HashMap::new(),
98            }
99        }
100    }
101
102    pub fn fastest(&self, key: &K) -> TuneCacheResult {
103        let result = self.in_memory_cache.get(key);
104
105        let Some(val) = result else {
106            return TuneCacheResult::Miss;
107        };
108
109        match val {
110            CacheEntry::Done {
111                checksum,
112                fastest_index,
113            } => {
114                if cfg!(std_io) {
115                    match checksum {
116                        ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, // Don't know yet.
117                        ChecksumState::NoMatch => TuneCacheResult::Miss, // Can't use this.
118                        ChecksumState::Match => TuneCacheResult::Hit {
119                            fastest_index: *fastest_index,
120                        },
121                    }
122                } else {
123                    // Clippy;
124                    let _ = checksum;
125                    TuneCacheResult::Hit {
126                        fastest_index: *fastest_index,
127                    }
128                }
129            }
130            CacheEntry::Pending => TuneCacheResult::Pending,
131        }
132    }
133
134    #[cfg(std_io)]
135    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
136        let result = self.in_memory_cache.get_mut(key);
137        let Some(val) = result else {
138            return;
139        };
140
141        if let CacheEntry::Done {
142            checksum: checksum_state,
143            ..
144        } = val
145            && let ChecksumState::ToBeVerified(checksum_expected) = checksum_state
146        {
147            if checksum_expected == checksum {
148                *checksum_state = ChecksumState::Match;
149            } else {
150                *checksum_state = ChecksumState::NoMatch;
151            }
152        }
153    }
154
155    #[allow(unused)]
156    pub(crate) fn mark_pending(&mut self, key: K) {
157        self.in_memory_cache.insert(key, CacheEntry::Pending);
158    }
159
160    pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
161        self.in_memory_cache.insert(
162            key,
163            CacheEntry::Done {
164                checksum: ChecksumState::Match,
165                fastest_index,
166            },
167        );
168    }
169}
170
#[cfg(std_io)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Stores a tuning result in the persistent cache under `(key, checksum)`.
    ///
    /// `results` is stored alongside `fastest_index`. Insertion failures are not
    /// propagated to the caller:
    /// - a duplicated key is reported with a warning;
    /// - an out-of-sync key is ignored.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<Result<AutotuneOutcome, AutotuneError>>,
    ) {
        let cache_key = PersistentCacheKey { key, checksum };
        let cache_value = PersistentCacheValue {
            fastest_index,
            results,
        };

        let Err(err) = self.persistent_cache.insert(cache_key, cache_value) else {
            return;
        };

        match err {
            CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            } => {
                log::warn!(
                    "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
                );
            }
            CacheError::KeyOutOfSync { .. } => {
                // Deliberately ignored: tuning already produced an in-memory result,
                // so failing to persist it is not treated as an error.
                // NOTE(review): presumably the on-disk entry changed underneath this
                // process — confirm against `cubecl_common::cache::CacheError`.
            }
        }
    }

    /// Loads every persistent cache entry into the in-memory cache.
    ///
    /// Loaded entries are marked `ToBeVerified` with the checksum they were saved
    /// with. `fastest` reports them as `Unchecked` until `validate_checksum`
    /// resolves them. In-memory entries with the same key are overwritten.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");
        let mut loaded = 0;
        self.persistent_cache.for_each(|key, value| {
            loaded += 1;
            // NOTE(review): the persistent key includes the checksum, so several
            // stored entries can share `key.key`; the last one visited wins here.
            // If `for_each` order is not deterministic, a stale checksum could
            // shadow a current one — confirm against `Cache::for_each`.
            self.in_memory_cache.insert(
                key.key.clone(),
                CacheEntry::Done {
                    checksum: ChecksumState::ToBeVerified(key.checksum.clone()),
                    fastest_index: value.fastest_index,
                },
            );
        });
        log::info!("Loaded {loaded} autotune cached entries");
    }
}