Skip to main content

cubecl_runtime/tune/
operation.rs

1use alloc::boxed::Box;
2use alloc::string::String;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::fmt::{Debug, Display};
6use core::hash::Hash;
7
8#[cfg(std_io)]
9use alloc::format;
10
11use super::{
12    AutotuneError,
13    input_generator::{InputGenerator, IntoInputGenerator},
14    key_generator::{IntoKeyGenerator, KeyGenerator},
15};
16use super::{Tunable, TunePlan};
17
/// Default checksum for a set of tunable operations.
///
/// The names of all operations are concatenated in iteration order, and the
/// resulting string is hashed with MD5. The digest is returned formatted as
/// lowercase hexadecimal.
#[cfg(std_io)]
pub fn compute_checksum<In: Clone + Send + 'static, Out: 'static>(
    autotunables: impl Iterator<Item = Arc<dyn TuneFn<Inputs = In, Output = Out>>>,
) -> String {
    let mut names = String::new();
    for tunable in autotunables {
        names.push_str(tunable.name());
    }

    let digest = md5::compute(names);
    format!("{:x}", digest)
}
29
30/// Groups operations of the same type for autotune
31pub struct TunableSet<K: AutotuneKey, Inputs: Send + 'static, Output: 'static> {
32    tunables: Vec<Tunable<K, Inputs, Output>>,
33    key_gen: Arc<dyn KeyGenerator<K, Inputs>>,
34    input_gen: Arc<dyn InputGenerator<K, Inputs>>,
35    #[allow(clippy::type_complexity)]
36    checksum_override: Option<Arc<dyn Fn(&Self) -> String + Send + Sync>>,
37}
38
39unsafe impl<K: AutotuneKey, In: Send, Out> Send for TunableSet<K, In, Out> {}
40unsafe impl<K: AutotuneKey, In: Send, Out> Sync for TunableSet<K, In, Out> {}
41
42impl<K: AutotuneKey, Inputs: Clone + Send + 'static, Output: 'static>
43    TunableSet<K, Inputs, Output>
44{
45    /// The number of tunables in the set.
46    pub fn len(&self) -> usize {
47        self.tunables.len()
48    }
49    /// If the tunable set is empty.
50    pub fn is_empty(&self) -> bool {
51        self.tunables.len() == 0
52    }
53    /// Create a tunable set from a key generator and an input generator
54    pub fn new<KMarker, IMarker>(
55        key_gen: impl IntoKeyGenerator<K, Inputs, KMarker>,
56        input_gen: impl IntoInputGenerator<K, Inputs, IMarker>,
57    ) -> Self {
58        Self {
59            tunables: Default::default(),
60            input_gen: Arc::new(input_gen.into_input_gen()),
61            key_gen: Arc::new(key_gen.into_key_gen()),
62            checksum_override: None,
63        }
64    }
65
66    /// Register a tunable with this tunable set
67    pub fn with(mut self, tunable: Tunable<K, Inputs, Output>) -> Self {
68        self.tunables.push(tunable);
69        self
70    }
71
72    /// Override the checksum algorithm
73    pub fn with_custom_checksum(
74        mut self,
75        checksum: impl Fn(&Self) -> String + Send + Sync + 'static,
76    ) -> Self {
77        self.checksum_override = Some(Arc::new(checksum));
78        self
79    }
80
81    /// All candidate operations for autotuning this operation type
82    /// Operations can run on toy tensors of relevant size
83    ///
84    /// TODO: Add name.
85    pub fn autotunables(&self) -> Vec<Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>> {
86        self.tunables
87            .iter()
88            .map(|tunable| tunable.function.clone())
89            .collect()
90    }
91
92    /// Returns the [autotune plan](TunePlan) for the given set.
93    pub(crate) fn plan(&self, key: &K) -> TunePlan {
94        TunePlan::new(key, &self.tunables)
95    }
96
97    /// Returns the operation for the given index, matching the order
98    /// returned by autotunables. Operation obtained here runs on original tensors
99    /// Nb: The 0 index is used a "good default".
100    pub fn fastest(
101        &self,
102        fastest_index: usize,
103    ) -> Arc<dyn TuneFn<Inputs = Inputs, Output = Output>> {
104        self.tunables[fastest_index].function.clone()
105    }
106
107    /// Compute a checksum that can invalidate outdated cached auto-tune results.
108    #[cfg(std_io)]
109    pub fn compute_checksum(&self) -> String {
110        if let Some(checksum_override) = &self.checksum_override {
111            checksum_override(self)
112        } else {
113            compute_checksum(self.tunables.iter().map(|tune| tune.function.clone()))
114        }
115    }
116
117    /// Generate a key from a set of inputs
118    pub fn generate_key(&self, inputs: &Inputs) -> K {
119        self.key_gen.generate(inputs)
120    }
121
122    /// Generate a set of test inputs from a key and reference inputs
123    pub fn inputs_generator(&self, key: &K, inputs: &Inputs) -> Box<dyn FnOnce() -> Inputs> {
124        let generate = self.input_gen.clone();
125        let key = key.clone();
126        let inputs = inputs.clone();
127
128        Box::new(move || generate.generate(&key, &inputs))
129    }
130}
131
132/// A tunable entry in a tunable set
133pub trait TuneFn: Send + Sync + 'static {
134    /// Inputs to the tunable function
135    type Inputs: Clone;
136    /// Output from the tunable function
137    type Output;
138
139    /// Run a tuneable function
140    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
141
142    /// The name of the tuneable function
143    fn name(&self) -> &str;
144}
145
146/// Something that can be turned into a [Tunable]
147///
148/// # Marker
149/// The marker generic is used to work around limitations in the trait resolver that causes
150/// conflicting implementation errors.
151pub trait IntoTuneFn<In, Out, Marker> {
152    /// The output tunable type
153    type Tunable: TuneFn<Inputs = In, Output = Out>;
154
155    /// Convert to a tunable
156    fn into_tunable(self, name: String) -> Self::Tunable;
157}
158
/// Dummy marker for the identity [`IntoTuneFn`] implementation on types that
/// already implement [`TuneFn`].
#[doc(hidden)]
pub struct IsIdentity;
162
163impl<T: TuneFn> IntoTuneFn<T::Inputs, T::Output, IsIdentity> for T {
164    type Tunable = T;
165
166    fn into_tunable(self, _name: String) -> Self::Tunable {
167        self
168    }
169}
170
/// Trait alias for autotune keys, with support for persistent caching.
///
/// It declares no methods and only bundles the bounds a key must satisfy;
/// with `std_io` enabled this includes serde serialization and deserialization.
#[cfg(std_io)]
pub trait AutotuneKey:
    Clone
    + Debug
    + Display
    + PartialEq
    + Eq
    + Hash
    + Send
    + Sync
    + serde::Serialize
    + serde::de::DeserializeOwned
    + 'static
{
}

/// Trait alias for autotune keys.
///
/// It declares no methods and only bundles the bounds a key must satisfy.
#[cfg(not(std_io))]
pub trait AutotuneKey:
    Clone + Debug + Display + PartialEq + Eq + Hash + Send + Sync + 'static
{
}

/// Plain strings can be used directly as autotune keys.
impl AutotuneKey for String {}