cubecl_runtime/tune/input_generator.rs
1use super::TuneInputs;
2
3/// Produces the benchmark inputs for a given key and reference inputs.
4///
5/// A tuner runs each candidate on the value `generate` returns, not on the real call
6/// inputs directly, so callers can synthesize test inputs (for example, allocate fresh
7/// output buffers) without mutating the real ones. The common case is just cloning the
8/// reference inputs; use [`CloneInputGenerator`] for that.
9///
10/// There's no `Fn`-based blanket impl for arbitrary [`TuneInputs`] families because
11/// `for<'a> Fn(&K, &I::At<'a>) -> I::At<'a>` runs into E0582 (`Fn`'s `Output` cannot
12/// depend on the higher-ranked lifetime). For the owned-input case the HRTB collapses
13/// and the `Fn` blanket below works; borrowed-input families implement this trait
14/// directly.
15#[diagnostic::on_unimplemented(
16 message = "`{Self}` is not a valid input generator",
17 label = "invalid input generator"
18)]
19pub trait InputGenerator<K, I: TuneInputs>: Send + Sync + 'static {
20 /// Generate a set of inputs for a given key and reference inputs.
21 fn generate<'a>(&self, key: &K, inputs: &I::At<'a>) -> I::At<'a>;
22}
23
/// The trivial [`InputGenerator`]: it ignores the key and returns a clone of the
/// reference inputs.
#[derive(Debug, Default, Clone, Copy)]
pub struct CloneInputGenerator;
27
28impl<K, I: TuneInputs> InputGenerator<K, I> for CloneInputGenerator {
29 fn generate<'a>(&self, _key: &K, inputs: &I::At<'a>) -> I::At<'a> {
30 inputs.clone()
31 }
32}
33
34/// `Fn(&K, &A) -> A` acts as an [`InputGenerator`] when `A` is an owned type. For
35/// multi-input kernels, `A` is a tuple that the closure destructures internally.
36impl<K, Func, A> InputGenerator<K, A> for Func
37where
38 A: Clone + Send + Sync + 'static,
39 K: 'static,
40 Func: Send + Sync + 'static + Fn(&K, &A) -> A,
41{
42 #[inline]
43 fn generate<'a>(
44 &self,
45 key: &K,
46 inputs: &<A as TuneInputs>::At<'a>,
47 ) -> <A as TuneInputs>::At<'a> {
48 (self)(key, inputs)
49 }
50}