Skip to main content

cubecl_runtime/tune/
operation.rs

1use alloc::boxed::Box;
2use alloc::string::String;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::fmt::{Debug, Display};
6use core::hash::Hash;
7
8use alloc::format;
9
10use super::{
11    AutotuneError, input_generator::InputGenerator, key_generator::KeyGenerator,
12    tune_inputs::TuneInputs,
13};
14use super::{Tunable, TunePlan};
15
16/// A type-erased delegate for a tunable function.
17///
18/// The lifetime `'inp` is the lifetime of the input data, the function must be defined such that
19/// it can be called for any lifetime `inp` and produce a `Result<Out, AutotuneError>`.
20type TuneDelegate<I, Out> =
21    dyn for<'inp> Fn(<I as TuneInputs>::At<'inp>) -> Result<Out, AutotuneError> + Send + Sync;
22
23/// A named, type-erased tunable function stored in a [`TunableSet`]. Constructed via
24/// [`Tunable::new`](super::Tunable::new); callers don't name this type directly.
25#[derive(new)]
26pub struct TuneFn<I: TuneInputs, Out> {
27    pub(crate) name: String,
28    func: Box<TuneDelegate<I, Out>>,
29}
30
31impl<I: TuneInputs, Out: 'static> TuneFn<I, Out> {
32    /// Run the wrapped function on the given inputs.
33    pub fn execute<'a>(&self, inputs: <I as TuneInputs>::At<'a>) -> Result<Out, AutotuneError> {
34        (self.func)(inputs)
35    }
36}
37
38/// A set of candidate tunable functions for autotune, sharing a key generator and an
39/// input generator. See [`TuneInputs`] for the `F` parameter.
40pub struct TunableSet<K: AutotuneKey, F: TuneInputs, Output: 'static> {
41    tunables: Vec<Tunable<K, F, Output>>,
42    key_gen: Arc<dyn KeyGenerator<K, F> + Send + Sync>,
43    input_gen: Arc<dyn InputGenerator<K, F> + Send + Sync>,
44}
45
46impl<K: AutotuneKey, F: TuneInputs, Output: 'static> TunableSet<K, F, Output> {
47    /// The number of tunables in the set.
48    pub fn len(&self) -> usize {
49        self.tunables.len()
50    }
51
52    /// Whether this set contains no tunables.
53    pub fn is_empty(&self) -> bool {
54        self.tunables.is_empty()
55    }
56
57    /// Create a tunable set from a key generator and an input generator.
58    pub fn new(key_gen: impl KeyGenerator<K, F>, input_gen: impl InputGenerator<K, F>) -> Self {
59        Self {
60            tunables: Default::default(),
61            input_gen: Arc::new(input_gen),
62            key_gen: Arc::new(key_gen),
63        }
64    }
65
66    /// Shorthand for [`new`](Self::new) with a [`CloneInputGenerator`]: benchmarks run
67    /// on clones of the real call inputs.
68    pub fn new_cloning_inputs(key_gen: impl KeyGenerator<K, F>) -> Self {
69        Self::new(key_gen, super::CloneInputGenerator)
70    }
71
72    /// Register a tunable with this tunable set.
73    pub fn with(mut self, tunable: Tunable<K, F, Output>) -> Self {
74        self.tunables.push(tunable);
75        self
76    }
77
78    /// All candidate operations in this set, in registration order.
79    pub fn autotunables(&self) -> impl Iterator<Item = &TuneFn<F, Output>> {
80        self.tunables.iter().map(|tunable| &tunable.function)
81    }
82
83    /// Returns the [autotune plan](TunePlan) for the given set.
84    pub(crate) fn plan(&self, key: &K) -> TunePlan {
85        TunePlan::new(key, &self.tunables)
86    }
87
88    /// Returns the operation for the given index, matching the order returned by
89    /// `autotunables`. Tunables are tried in order, so index 0 should be a good default.
90    pub fn fastest(&self, fastest_index: usize) -> &TuneFn<F, Output> {
91        &self.tunables[fastest_index].function
92    }
93
94    /// Compute a checksum that invalidates outdated cached auto-tune results when the
95    /// set of tunable names changes.
96    pub fn compute_checksum(&self) -> String {
97        let mut checksum = String::new();
98        for tune in &self.tunables {
99            checksum += &tune.function.name;
100        }
101        format!("{:x}", md5::compute(checksum))
102    }
103
104    /// Generate a key from a set of inputs
105    pub fn generate_key<'a>(&self, inputs: &F::At<'a>) -> K {
106        self.key_gen.generate(inputs)
107    }
108
109    /// Generate a set of test inputs from a key and reference inputs.
110    pub fn generate_inputs<'a>(&self, key: &K, inputs: &F::At<'a>) -> F::At<'a> {
111        self.input_gen.generate(key, inputs)
112    }
113}
114
/// The bundle of bounds every autotune key must satisfy, including the `serde` bounds
/// needed to persist cached autotune results.
///
/// This is a marker trait rather than a true trait alias: there is no blanket
/// implementation, so each key type opts in with an empty `impl AutotuneKey for T {}`.
#[cfg(std_io)]
pub trait AutotuneKey:
    Send
    + Sync
    + 'static
    + Clone
    + Debug
    + Display
    + PartialEq
    + Eq
    + Hash
    + serde::Serialize
    + serde::de::DeserializeOwned
{
}

/// The bundle of bounds every autotune key must satisfy (no serialization required in
/// this configuration).
///
/// This is a marker trait rather than a true trait alias: there is no blanket
/// implementation, so each key type opts in with an empty `impl AutotuneKey for T {}`.
#[cfg(not(std_io))]
pub trait AutotuneKey:
    Send + Sync + 'static + Clone + Debug + Display + PartialEq + Eq + Hash
{
}

// Plain strings can be used directly as autotune keys.
impl AutotuneKey for String {}