cubecl_runtime/tune/
operation.rs

1use alloc::boxed::Box;
2use alloc::string::String;
3use alloc::vec::Vec;
4use core::fmt::{Debug, Display};
5use core::hash::Hash;
6
7use super::AutotuneError;
8
/// Default checksum for an operation set.
///
/// Takes the [`AutotuneOperation::name`] of every candidate, in the order given,
/// joins them with no separator, and returns the lowercase hexadecimal MD5
/// digest of the joined string.
#[cfg(autotune_persistent_cache)]
pub fn compute_checksum<Out: Send + 'static>(
    autotunables: &[Box<dyn AutotuneOperation<Out>>],
) -> String {
    let joined_names: String = autotunables.iter().map(|op| op.name()).collect();
    let digest = md5::compute(joined_names);
    format!("{:x}", digest)
}
20
21/// Groups operations of the same type for autotune
22pub trait AutotuneOperationSet<K: Send + 'static, Output: Send + 'static = ()>: Send {
23    /// The key used in the tune cache
24    fn key(&self) -> K;
25
26    /// All candidate operations for autotuning this operation type
27    /// Operations can run on toy tensors of relevant size
28    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation<Output>>>;
29
30    /// Returns the operation for the given index, matching the order
31    /// returned by autotunables. Operation obtained here runs on original tensors
32    /// Nb: The 0 index is used a "good default".
33    fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation<Output>>;
34
35    /// Compute a checksum that can invalidate outdated cached auto-tune results.
36    #[cfg(autotune_persistent_cache)]
37    fn compute_checksum(&self) -> String {
38        compute_checksum(&self.autotunables())
39    }
40
41    /// Enable or disable certain indices from being benchmarked based on the key
42    #[allow(unused)]
43    fn should_run(&self, key: &K, index: usize) -> bool {
44        true
45    }
46}
47
48/// Contains operation to run and inputs on which to run it
49pub trait AutotuneOperation<Output: Send + 'static = ()>: Send + core::fmt::Debug {
50    /// Runs the operation
51    fn execute(self: Box<Self>) -> Result<Output, AutotuneError>;
52
53    /// The name of the operation.
54    fn name(&self) -> &str {
55        core::any::type_name::<Self>()
56    }
57
58    /// Clones the operation and inputs
59    fn clone(&self) -> Box<dyn AutotuneOperation<Output>>;
60}
61
/// Marker trait bundling the bounds expected of autotune key types.
///
/// This is not a true trait alias: there is no blanket implementation, so every
/// key type must opt in explicitly with an empty `impl AutotuneKey for MyKey {}`.
///
/// With the `autotune_persistent_cache` cfg enabled, key types must additionally
/// implement `serde::Serialize` and `serde::de::DeserializeOwned`.
#[cfg(autotune_persistent_cache)]
pub trait AutotuneKey:
    Clone
    + Debug
    + PartialEq
    + Eq
    + Hash
    + Display
    + serde::Serialize
    + serde::de::DeserializeOwned
    + Send
    + Sync
    + 'static
{
}

/// Marker trait bundling the bounds expected of autotune key types.
///
/// This is not a true trait alias: there is no blanket implementation, so every
/// key type must opt in explicitly with an empty `impl AutotuneKey for MyKey {}`.
/// This variant omits the serde bounds because `autotune_persistent_cache` is
/// disabled.
#[cfg(not(autotune_persistent_cache))]
pub trait AutotuneKey:
    Clone + Debug + PartialEq + Eq + Hash + Display + Send + Sync + 'static
{
}

// `String` keys are supported under both cfgs. With the persistent cache this
// relies on serde's own `Serialize`/`Deserialize` impls for `String`.
impl AutotuneKey for String {}