Skip to main content

cubecl_runtime/tune/
input_generator.rs

1use super::TuneInputs;
2
3/// Produces the benchmark inputs for a given key and reference inputs.
4///
5/// A tuner runs each candidate on the value `generate` returns, not on the real call
6/// inputs directly, so callers can synthesize test inputs (for example, allocate fresh
7/// output buffers) without mutating the real ones. The common case is just cloning the
8/// reference inputs; use [`CloneInputGenerator`] for that.
9///
10/// There's no `Fn`-based blanket impl for arbitrary [`TuneInputs`] families because
11/// `for<'a> Fn(&K, &I::At<'a>) -> I::At<'a>` runs into E0582 (`Fn`'s `Output` cannot
12/// depend on the higher-ranked lifetime). For the owned-input case the HRTB collapses
13/// and the `Fn` blanket below works; borrowed-input families implement this trait
14/// directly.
15#[diagnostic::on_unimplemented(
16    message = "`{Self}` is not a valid input generator",
17    label = "invalid input generator"
18)]
19pub trait InputGenerator<K, I: TuneInputs>: Send + Sync + 'static {
20    /// Generate a set of inputs for a given key and reference inputs.
21    fn generate<'a>(&self, key: &K, inputs: &I::At<'a>) -> I::At<'a>;
22}
23
/// An [`InputGenerator`] whose output is a verbatim clone of the reference inputs.
#[derive(Debug, Default, Clone, Copy)]
pub struct CloneInputGenerator;
27
28impl<K, I: TuneInputs> InputGenerator<K, I> for CloneInputGenerator {
29    fn generate<'a>(&self, _key: &K, inputs: &I::At<'a>) -> I::At<'a> {
30        inputs.clone()
31    }
32}
33
34/// `Fn(&K, &A) -> A` acts as an [`InputGenerator`] when `A` is an owned type. For
35/// multi-input kernels, `A` is a tuple that the closure destructures internally.
36impl<K, Func, A> InputGenerator<K, A> for Func
37where
38    A: Clone + Send + Sync + 'static,
39    K: 'static,
40    Func: Send + Sync + 'static + Fn(&K, &A) -> A,
41{
42    #[inline]
43    fn generate<'a>(
44        &self,
45        key: &K,
46        inputs: &<A as TuneInputs>::At<'a>,
47    ) -> <A as TuneInputs>::At<'a> {
48        (self)(key, inputs)
49    }
50}