1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
use cubecl_common::benchmark::Benchmark;
use cubecl_common::sync_type::SyncType;

use crate::channel::ComputeChannel;
use crate::client::ComputeClient;
use crate::server::ComputeServer;

use super::AutotuneOperation;
use alloc::boxed::Box;
use alloc::string::{String, ToString};

/// Benchmark wrapper that times a single autotune candidate operation.
///
/// Every benchmark run works on its own clone of `operation`, so the stored operation is
/// never consumed; `client` is kept so the benchmark can wait for the compute server to
/// finish the submitted work.
#[derive(new)]
pub struct TuneBenchmark<S, C, Out = ()>
where
    S: ComputeServer,
{
    /// Candidate operation under test; cloned for each run rather than moved out.
    operation: Box<dyn AutotuneOperation<Out>>,
    /// Client used to synchronize with the server that executes the operation.
    client: ComputeClient<S, C>,
}

impl<Out> Clone for Box<dyn AutotuneOperation<Out>> {
    /// Produces a new boxed operation by delegating to the trait object's own `clone`.
    fn clone(&self) -> Self {
        // Reborrow the trait object inside the box. Calling `clone` on
        // `&dyn AutotuneOperation<Out>` dispatches to the operation's cloning method
        // (which yields a `Box<dyn AutotuneOperation<Out>>`), not to this `Box` impl,
        // so there is no self-recursion.
        let operation = &**self;
        operation.clone()
    }
}

impl<S, C, Out> Benchmark for TuneBenchmark<S, C, Out>
where
    S: ComputeServer,
    C: ComputeChannel<S>,
{
    /// Each run consumes its own boxed copy of the operation under test.
    type Args = Box<dyn AutotuneOperation<Out>>;

    /// Label reported for every autotune benchmark.
    fn name(&self) -> String {
        String::from("autotune")
    }

    /// Number of samples collected per candidate operation.
    fn num_samples(&self) -> usize {
        const SAMPLES: usize = 10;
        SAMPLES
    }

    /// Hands out a fresh clone so the stored operation stays available for later runs.
    fn prepare(&self) -> Self::Args {
        Clone::clone(&self.operation)
    }

    /// Runs one sample by consuming the prepared operation; its output is discarded.
    fn execute(&self, operation: Self::Args) {
        AutotuneOperation::execute(operation);
    }

    /// Blocks until every task submitted to the server has completed, so a measurement
    /// does not return while work is still pending.
    fn sync(&self) {
        self.client.sync(SyncType::Wait);
    }
}