cubecl_runtime/tune/
operation.rs

1use alloc::boxed::Box;
2use alloc::string::String;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::fmt::{Debug, Display};
6use core::hash::Hash;
7
8use super::{
9    AutotuneError,
10    input_generator::{InputGenerator, IntoInputGenerator},
11    key_generator::{IntoKeyGenerator, KeyGenerator},
12};
13use super::{Tunable, TunePlan};
14
/// Computes the default checksum for a collection of tunables.
///
/// The names of all tunables are concatenated in iteration order, and the MD5 digest of the
/// resulting string is returned in lowercase hexadecimal form.
#[cfg(std_io)]
pub fn compute_checksum<In: Clone + Send + 'static, Out: 'static>(
    autotunables: impl Iterator<Item = Arc<dyn TuneFn<Inputs = In, Output = Out>>>,
) -> String {
    // NOTE(review): names are appended without a separator, so two different name lists can
    // concatenate to the same string. Changing this would alter every existing checksum.
    let mut names = String::new();
    for tunable in autotunables {
        names.push_str(tunable.name());
    }

    let digest = md5::compute(names);
    format!("{:x}", digest)
}
26
27/// Groups operations of the same type for autotune
28pub struct TunableSet<K: AutotuneKey, Inputs: Send + 'static, Output: 'static> {
29    tunables: Vec<Tunable<K, Inputs, Output>>,
30    key_gen: Arc<dyn KeyGenerator<K, Inputs>>,
31    input_gen: Arc<dyn InputGenerator<K, Inputs>>,
32    #[allow(clippy::type_complexity)]
33    checksum_override: Option<Arc<dyn Fn(&Self) -> String + Send + Sync>>,
34}
35
36unsafe impl<K: AutotuneKey, In: Send, Out> Send for TunableSet<K, In, Out> {}
37unsafe impl<K: AutotuneKey, In: Send, Out> Sync for TunableSet<K, In, Out> {}
38
39impl<K: AutotuneKey, Inputs: Clone + Send + 'static, Output: 'static>
40    TunableSet<K, Inputs, Output>
41{
42    /// The number of tunables in the set.
43    pub fn len(&self) -> usize {
44        self.tunables.len()
45    }
46    /// If the tunable set is empty.
47    pub fn is_empty(&self) -> bool {
48        self.tunables.len() == 0
49    }
50    /// Create a tunable set from a key generator and an input generator
51    pub fn new<KMarker, IMarker>(
52        key_gen: impl IntoKeyGenerator<K, Inputs, KMarker>,
53        input_gen: impl IntoInputGenerator<K, Inputs, IMarker>,
54    ) -> Self {
55        Self {
56            tunables: Default::default(),
57            input_gen: Arc::new(input_gen.into_input_gen()),
58            key_gen: Arc::new(key_gen.into_key_gen()),
59            checksum_override: None,
60        }
61    }
62
63    /// Register a tunable with this tunable set
64    pub fn with(mut self, tunable: Tunable<K, Inputs, Output>) -> Self {
65        self.tunables.push(tunable);
66        self
67    }
68
69    /// Override the checksum algorithm
70    pub fn with_custom_checksum(
71        mut self,
72        checksum: impl Fn(&Self) -> String + Send + Sync + 'static,
73    ) -> Self {
74        self.checksum_override = Some(Arc::new(checksum));
75        self
76    }
77
78    /// All candidate operations for autotuning this operation type
79    /// Operations can run on toy tensors of relevant size
80    pub fn autotunables(&self) -> Vec<Arc<dyn TuneFn<Inputs = Inputs, Output = Output>>> {
81        self.tunables
82            .iter()
83            .map(|tunable| tunable.function.clone())
84            .collect()
85    }
86
87    /// Returns the [autotune plan](TunePlan) for the given set.
88    pub(crate) fn plan(&self, key: &K) -> TunePlan {
89        TunePlan::new(key, &self.tunables)
90    }
91
92    /// Returns the operation for the given index, matching the order
93    /// returned by autotunables. Operation obtained here runs on original tensors
94    /// Nb: The 0 index is used a "good default".
95    pub fn fastest(
96        &self,
97        fastest_index: usize,
98    ) -> Arc<dyn TuneFn<Inputs = Inputs, Output = Output>> {
99        self.tunables[fastest_index].function.clone()
100    }
101
102    /// Compute a checksum that can invalidate outdated cached auto-tune results.
103    #[cfg(std_io)]
104    pub fn compute_checksum(&self) -> String {
105        if let Some(checksum_override) = &self.checksum_override {
106            checksum_override(self)
107        } else {
108            compute_checksum(self.tunables.iter().map(|tune| tune.function.clone()))
109        }
110    }
111
112    /// Generate a key from a set of inputs
113    pub fn generate_key(&self, inputs: &Inputs) -> K {
114        self.key_gen.generate(inputs)
115    }
116
117    /// Generate a set of test inputs from a key and reference inputs
118    pub fn inputs_generator(&self, key: &K, inputs: &Inputs) -> Box<dyn FnOnce() -> Inputs> {
119        let generate = self.input_gen.clone();
120        let key = key.clone();
121        let inputs = inputs.clone();
122
123        Box::new(move || generate.generate(&key, &inputs))
124    }
125}
126
127/// A tunable entry in a tunable set
128pub trait TuneFn: Send + Sync + 'static {
129    /// Inputs to the tunable function
130    type Inputs: Clone;
131    /// Output from the tunable function
132    type Output;
133
134    /// Run a tuneable function
135    fn execute(&self, inputs: Self::Inputs) -> Result<Self::Output, AutotuneError>;
136
137    /// The name of the tuneable function
138    fn name(&self) -> &str {
139        core::any::type_name::<Self>()
140    }
141}
142
143/// Something that can be turned into a [Tunable]
144///
145/// # Marker
146/// The marker generic is used to work around limitations in the trait resolver that causes
147/// conflicting implementation errors.
148pub trait IntoTuneFn<In, Out, Marker> {
149    /// The output tunable type
150    type Tunable: TuneFn<Inputs = In, Output = Out>;
151
152    /// Convert to a tunable
153    fn into_tunable(self) -> Self::Tunable;
154}
155
156/// Dummy marker for [`IntoTunable`] on [`Tunable`]s
157#[doc(hidden)]
158pub struct IsIdentity;
159
160impl<T: TuneFn> IntoTuneFn<T::Inputs, T::Output, IsIdentity> for T {
161    type Tunable = T;
162
163    fn into_tunable(self) -> Self::Tunable {
164        self
165    }
166}
167
/// Trait alias for autotune keys, with the serialization bounds needed for persistent caching.
#[cfg(std_io)]
pub trait AutotuneKey:
    Clone
    + PartialEq
    + Eq
    + Hash
    + Debug
    + Display
    + serde::Serialize
    + serde::de::DeserializeOwned
    + Send
    + Sync
    + 'static
{
}

/// Trait alias for autotune keys.
#[cfg(not(std_io))]
pub trait AutotuneKey:
    Clone + PartialEq + Eq + Hash + Debug + Display + Send + Sync + 'static
{
}

/// Plain strings can be used as autotune keys.
impl AutotuneKey for String {}