cubecl_runtime/tune/
operation.rs

1use alloc::string::String;
2use alloc::sync::Arc;
3use alloc::vec::Vec;
4use core::fmt::{Debug, Display};
5use core::hash::Hash;
6
7use super::{
8    AutotuneError,
9    input_generator::{InputGenerator, IntoInputGenerator},
10    key_generator::{IntoKeyGenerator, KeyGenerator},
11};
12
/// Default checksum for an operation set.
///
/// Concatenates the [`Tunable::name`] of every entry, in slice order and without any
/// separator, then returns the MD5 digest of that text formatted as lowercase hexadecimal.
#[cfg(autotune_persistent_cache)]
pub fn compute_checksum<In: Clone + Send + 'static, Out: 'static>(
    autotunables: &[Arc<dyn Tunable<Inputs = In, Output = Out>>],
) -> String {
    let names: String = autotunables.iter().map(|tunable| tunable.name()).collect();
    let digest = md5::compute(names);
    format!("{:x}", digest)
}
24
25/// Groups operations of the same type for autotune
26#[derive(Clone)]
27pub struct TunableSet<K: AutotuneKey, Inputs: Send + 'static, Output: 'static> {
28    tunables: Vec<Arc<dyn Tunable<Inputs = Inputs, Output = Output>>>,
29    key_gen: Arc<dyn KeyGenerator<K, Inputs>>,
30    input_gen: Arc<dyn InputGenerator<K, Inputs>>,
31    #[allow(clippy::type_complexity)]
32    checksum_override: Option<Arc<dyn Fn(&Self) -> String + Send + Sync>>,
33}
34
35impl<K: AutotuneKey, Inputs: Clone + Send + 'static, Output: 'static>
36    TunableSet<K, Inputs, Output>
37{
38    /// The number of tunables in the set.
39    pub fn len(&self) -> usize {
40        self.tunables.len()
41    }
42    /// If the tunable set is empty.
43    pub fn is_empty(&self) -> bool {
44        self.tunables.len() == 0
45    }
46    /// Create a tunable set from a key generator and an input generator
47    pub fn new<KMarker, IMarker>(
48        key_gen: impl IntoKeyGenerator<K, Inputs, KMarker>,
49        input_gen: impl IntoInputGenerator<K, Inputs, IMarker>,
50    ) -> Self {
51        Self {
52            tunables: Default::default(),
53            input_gen: Arc::new(input_gen.into_input_gen()),
54            key_gen: Arc::new(key_gen.into_key_gen()),
55            checksum_override: None,
56        }
57    }
58
59    /// Register a tunable with this tunable set
60    pub fn with_tunable<Marker>(
61        mut self,
62        tunable: impl IntoTunable<Inputs, Output, Marker>,
63    ) -> Self {
64        self.tunables.push(Arc::new(tunable.into_tunable()));
65        self
66    }
67
68    /// Override the checksum algorithm
69    pub fn with_custom_checksum(
70        mut self,
71        checksum: impl Fn(&Self) -> String + Send + Sync + 'static,
72    ) -> Self {
73        self.checksum_override = Some(Arc::new(checksum));
74        self
75    }
76
77    /// All candidate operations for autotuning this operation type
78    /// Operations can run on toy tensors of relevant size
79    pub fn autotunables(&self) -> Vec<Arc<dyn Tunable<Inputs = Inputs, Output = Output>>> {
80        self.tunables.clone()
81    }
82
83    /// Returns the operation for the given index, matching the order
84    /// returned by autotunables. Operation obtained here runs on original tensors
85    /// Nb: The 0 index is used a "good default".
86    pub fn fastest(
87        &self,
88        fastest_index: usize,
89    ) -> Arc<dyn Tunable<Inputs = Inputs, Output = Output>> {
90        self.tunables[fastest_index].clone()
91    }
92
93    /// Compute a checksum that can invalidate outdated cached auto-tune results.
94    #[cfg(autotune_persistent_cache)]
95    pub fn compute_checksum(&self) -> String {
96        if let Some(checksum_override) = &self.checksum_override {
97            checksum_override(self)
98        } else {
99            compute_checksum(&self.tunables)
100        }
101    }
102
103    /// Generate a key from a set of inputs
104    pub fn generate_key(&self, inputs: &Inputs) -> K {
105        self.key_gen.generate(inputs)
106    }
107
108    /// Generate a set of test inputs from a key and reference inputs
109    pub fn generate_inputs(&self, key: &K, inputs: &Inputs) -> Inputs {
110        self.input_gen.generate(key, inputs)
111    }
112}
113
114/// A tunable entry in a tunable set
115pub trait Tunable: Send + Sync + 'static {
116    /// Inputs to the tunable function
117    type Inputs: Clone;
118    /// Output from the tunable function
119    type Output;
120
121    /// Run a tuneable function
122    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
123
124    /// The name of the tuneable function
125    fn name(&self) -> &str {
126        core::any::type_name::<Self>()
127    }
128}
129
130/// Something that can be turned into a [Tunable]
131///
132/// # Marker
133/// The marker generic is used to work around limitations in the trait resolver that causes
134/// conflicting implementation errors.
135pub trait IntoTunable<In, Out, Marker> {
136    /// The output tunable type
137    type Tunable: Tunable<Inputs = In, Output = Out>;
138
139    /// Convert to a tunable
140    fn into_tunable(self) -> Self::Tunable;
141}
142
/// Marker type used as the `Marker` parameter of the [`IntoTunable`] implementation
/// that covers [`Tunable`]s themselves.
#[doc(hidden)]
pub struct IsIdentity;
146
147impl<T: Tunable> IntoTunable<T::Inputs, T::Output, IsIdentity> for T {
148    type Tunable = T;
149
150    fn into_tunable(self) -> Self::Tunable {
151        self
152    }
153}
154
/// Trait alias for autotune keys, with support for persistent caching.
///
/// Keys must additionally be serializable and deserializable through `serde`.
#[cfg(autotune_persistent_cache)]
pub trait AutotuneKey:
    Clone
    + PartialEq
    + Eq
    + Hash
    + Debug
    + Display
    + serde::Serialize
    + serde::de::DeserializeOwned
    + Send
    + Sync
    + 'static
{
}
/// Trait alias for autotune keys.
///
/// Bundles the comparison, hashing, formatting and thread-safety bounds a key type must
/// satisfy.
#[cfg(not(autotune_persistent_cache))]
pub trait AutotuneKey:
    Clone + PartialEq + Eq + Hash + Debug + Display + Send + Sync + 'static
{
}
177
178impl AutotuneKey for String {}