// burn_compute/tune/operation.rs

1use alloc::boxed::Box;
2use alloc::string::String;
3use alloc::vec::Vec;
4use core::fmt::{Debug, Display};
5use core::hash::Hash;
6
/// Default checksum for an operation set.
///
/// Joins the [`AutotuneOperation::name`] of every autotunable, in slice order and with no
/// separator, then returns the MD5 digest of that string formatted as lowercase hex (`{:x}`).
#[cfg(feature = "autotune-persistent-cache")]
pub fn compute_checksum(autotunables: &[Box<dyn AutotuneOperation>]) -> String {
    let joined_names: String = autotunables.iter().map(|operation| operation.name()).collect();
    let digest = md5::compute(joined_names);
    format!("{:x}", digest)
}
16
17/// Groups operations of the same type for autotune
18pub trait AutotuneOperationSet<K>: Send {
19    /// The key used in the tune cache
20    fn key(&self) -> K;
21
22    /// All candidate operations for autotuning this operation type
23    /// Operations can run on toy tensors of relevant size
24    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>>;
25
26    /// Returns the operation for the given index, matching the order
27    /// returned by autotunables. Operation obtained here runs on original tensors
28    fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation>;
29
30    /// Compute a checksum that can invalidate outdated cached auto-tune results.
31    #[cfg(feature = "autotune-persistent-cache")]
32    fn compute_checksum(&self) -> String {
33        compute_checksum(&self.autotunables())
34    }
35}
36
/// An operation to run together with the inputs on which to run it.
pub trait AutotuneOperation {
    /// Consumes the boxed operation and runs it.
    fn execute(self: Box<Self>);

    /// Clones the operation and its inputs into a new boxed trait object.
    fn clone(&self) -> Box<dyn AutotuneOperation>;

    /// The name of the operation.
    ///
    /// By default this is the implementing type's name as reported by [`core::any::type_name`].
    /// The standard library documents that string as a best-effort description that may differ
    /// between compiler versions, so treat it as informative rather than as a stable identifier.
    fn name(&self) -> &str {
        let type_name: &'static str = core::any::type_name::<Self>();
        type_name
    }
}
50
/// Bound set required of autotune keys when the persistent cache is enabled.
///
/// On top of the in-memory requirements, keys must be `serde`-serializable and deserializable
/// and be `Send + Sync`. This is not a true trait alias (no blanket impl is provided here):
/// each key type opts in with its own, typically empty, `impl`.
#[cfg(feature = "autotune-persistent-cache")]
pub trait AutotuneKey:
    Clone
    + Eq
    + PartialEq
    + Hash
    + Debug
    + Display
    + Send
    + Sync
    + serde::Serialize
    + serde::de::DeserializeOwned
{
}

/// Bound set required of autotune keys.
///
/// This is not a true trait alias (no blanket impl is provided here): each key type opts in with
/// its own, typically empty, `impl`.
#[cfg(not(feature = "autotune-persistent-cache"))]
pub trait AutotuneKey: Clone + Eq + PartialEq + Hash + Debug + Display {}

impl AutotuneKey for String {}