use alloc::sync::Arc;
use alloc::vec::Vec;
use crate::channel::ComputeChannel;
use crate::client::ComputeClient;
use crate::server::ComputeServer;
use cubecl_common::benchmark::ProfileDuration;
use super::{AutotuneError, Tunable};
/// A benchmark pairing one tunable operation with the inputs it should be
/// executed on and the compute client used to profile it.
///
/// The constructor is generated by the `new` derive (presumably from the
/// `derive-new` crate — confirm against the crate's imports) and takes the
/// fields in declaration order.
#[derive(new)]
pub struct TuneBenchmark<S, C, In, Out>
where
    S: ComputeServer,
    In: Clone + Send + 'static,
    Out: Send + 'static,
{
    /// Operation under test; shared so it can be executed repeatedly.
    operation: Arc<dyn Tunable<Inputs = In, Output = Out>>,
    /// Inputs handed (as a fresh clone) to every execution of the operation.
    inputs: In,
    /// Client whose `profile` method measures each sample.
    client: ComputeClient<S, C>,
}
/// Bound required of the output type of an operation benchmarked through
/// `TuneBenchmark`.
///
/// Implementors must be `Send` and `'static`.
pub trait AutotuneOutput
where
    Self: Send + 'static,
{
    /// Compares `self` against `other`, another output of the same type.
    ///
    /// Only part of the trait when the `autotune-checks` feature is enabled.
    /// NOTE(review): how a mismatch is reported (panic, log, ...) is left to
    /// each implementation — nothing in this file constrains it.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self);
}
/// Lets operations that return nothing (`()`) satisfy the `AutotuneOutput`
/// bound.
impl AutotuneOutput for () {
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, _other: Self) {
        // The unit type has a single value, so there is nothing to compare.
    }
}
impl<S, C, In, Out> TuneBenchmark<S, C, In, Out>
where
    S: ComputeServer + 'static,
    C: ComputeChannel<S> + 'static,
    In: Clone + Send + 'static,
    Out: AutotuneOutput,
{
    /// Executes the operation once on a clone of the stored inputs and
    /// returns its output (or the error the operation reported).
    ///
    /// Only compiled when the `autotune-checks` feature is enabled.
    #[cfg(feature = "autotune-checks")]
    pub(crate) fn output_for_checks(&self) -> Result<Out, AutotuneError> {
        let operation = Arc::clone(&self.operation);
        let inputs = self.inputs.clone();
        operation.execute(inputs)
    }

    /// Consumes the benchmark and returns one `ProfileDuration` per sample.
    ///
    /// The operation is first executed once outside of any profiling scope
    /// (the warmup); afterwards it is executed ten more times, each run
    /// wrapped in `client.profile`. Every execution receives a fresh clone of
    /// the stored inputs.
    ///
    /// # Errors
    ///
    /// Returns the operation's error if the warmup execution fails.
    ///
    /// # Panics
    ///
    /// Panics if a profiled execution fails even though the warmup succeeded.
    pub fn profile(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
        /// Number of profiled executions collected after the warmup.
        const NUM_SAMPLES: usize = 10;

        let operation = &self.operation;

        // Warmup: surfaces a failing operation as an `Err` here, before any
        // sample is taken, rather than as a panic inside the profiled closure.
        operation.execute(self.inputs.clone())?;

        let mut durations = Vec::with_capacity(NUM_SAMPLES);
        for _ in 0..NUM_SAMPLES {
            let sample = self.client.profile(|| {
                operation
                    .execute(self.inputs.clone())
                    .expect("Should not fail when previously tried during the warmup.");
            });
            durations.push(sample);
        }

        Ok(durations)
    }
}