use alloc::format;
use alloc::sync::Arc;
use alloc::vec::Vec;
use cubecl_common::profile::{ProfileDuration, TimingMethod};
use crate::channel::ComputeChannel;
use crate::client::ComputeClient;
use crate::server::ComputeServer;
use super::{AutotuneError, TuneFn};
#[derive(new)]
pub struct TuneBenchmark<S: ComputeServer, C, In: Clone + Send + 'static, Out: Send + 'static> {
operation: Arc<dyn TuneFn<Inputs = In, Output = Out>>,
inputs: In,
client: ComputeClient<S, C>,
}
pub trait AutotuneOutput: Send + 'static {
#[cfg(feature = "autotune-checks")]
fn check_equivalence(&self, other: Self);
}
impl AutotuneOutput for () {
#[cfg(feature = "autotune-checks")]
fn check_equivalence(&self, _other: Self) {
}
}
impl<
S: ComputeServer + 'static,
C: ComputeChannel<S> + 'static,
In: Clone + Send + 'static,
Out: AutotuneOutput,
> TuneBenchmark<S, C, In, Out>
{
pub fn profile(self) -> Result<Vec<ProfileDuration>, AutotuneError> {
match self.client.properties().timing_method {
TimingMethod::System => self.warmup_full_error_handling(),
TimingMethod::Device => self.warmup_minimal_error_handling(),
}?;
let operation = self.operation;
let num_samples = 10;
let durations: Vec<_> = (0..num_samples)
.filter_map(|_| {
let result: Result<ProfileDuration, crate::server::ProfileError> =
self.client.profile(
|| {
operation
.execute(self.inputs.clone())
.expect("Should not fail when previously tried during the warmup.")
},
operation.name(),
);
match result {
Ok(val) => Some(val),
Err(err) => {
log::warn!("Error while autotuning {err:?}");
None
}
}
})
.collect();
if durations.is_empty() {
Err(AutotuneError::InvalidSamples)
} else {
Ok(durations)
}
}
fn warmup_full_error_handling(&self) -> Result<(), AutotuneError> {
let mut error = None;
let result = self.client.profile(
|| {
if let Err(err) = self.operation.execute(self.inputs.clone()) {
error = Some(err);
}
},
self.operation.name(),
);
if let Err(err) = result {
return Err(AutotuneError::Unknown(format!("{err:?}")));
};
if let Some(err) = error {
return Err(err);
};
Ok(())
}
fn warmup_minimal_error_handling(&self) -> Result<(), AutotuneError> {
self.operation.execute(self.inputs.clone())?;
Ok(())
}
}