use core::{fmt::Display, hash::Hash};
use hashbrown::HashMap;
use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
use super::{AutotuneKey, AutotuneOperationSet, Tuner};
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::ToString};
/// Registry of [`Tuner`]s, one per `ID`, guarded by a spin read/write lock.
///
/// The map starts out as `None` and is only allocated on first use. Because the
/// constructor is `const`, a `LocalTuner` can be stored in a `static`.
/// NOTE(review): presumably `ID` identifies a device/backend instance — confirm with callers.
pub struct LocalTuner<AK: AutotuneKey, ID> {
    /// Prefix from which each tuner's name is derived.
    name: &'static str,
    /// `None` until first use; afterwards maps each id to its tuner.
    state: spin::RwLock<Option<HashMap<ID, Tuner<AK>>>>,
}
/// Builds a `LocalTuner` whose name is derived from the invoking module's path.
///
/// * `local_tuner!()` names the tuner after `module_path!()`.
/// * `local_tuner!("suffix")` names it `"<module_path>-suffix"`, so one module can own
///   several tuners. The argument must be a literal because it is fed to `concat!`.
///
/// The expansion is a single expression with no trailing `;`, so it is valid anywhere an
/// expression is expected. Because `LocalTuner::new` is a `const fn`, that includes a
/// `static` initialiser. `LocalTuner` must be in scope at the call site.
#[macro_export]
macro_rules! local_tuner {
    ($name:expr) => {
        LocalTuner::new(concat!(module_path!(), "-", $name))
    };
    () => {
        LocalTuner::new(module_path!())
    };
}
pub use local_tuner;
impl<AK: AutotuneKey, ID: Hash + PartialEq + Eq + Clone + Display> LocalTuner<AK, ID> {
    /// Creates a registry called `name` with no tuners allocated yet.
    ///
    /// This is a `const fn`, so it can initialise a `static`. The inner map is only created
    /// on the first call to [`Self::execute`].
    pub const fn new(name: &'static str) -> Self {
        Self {
            name,
            state: spin::RwLock::new(None),
        }
    }

    /// Drops every registered tuner, returning the registry to its freshly-constructed state.
    pub fn clear(&self) {
        *self.state.write() = None;
    }

    /// Executes the fastest operation of `autotune_operation_set` for the tuner under `id`.
    ///
    /// If that tuner already reports a fastest index for the set's key, the matching
    /// operation runs while only the shared (read) lock is held. Otherwise the exclusive
    /// (write) lock is taken and a tuner for `id` is created if missing. The set is then
    /// passed to `Tuner::execute_autotune` with `client`, and the write lock stays held for
    /// that entire call.
    ///
    /// NOTE(review): because a lock is held while operations run, an operation that
    /// re-enters this same registry on a cache miss would spin on `write()` forever.
    /// Confirm that this cannot happen.
    pub fn execute<S, C, Out>(
        &self,
        id: &ID,
        client: &ComputeClient<S, C>,
        autotune_operation_set: Box<dyn AutotuneOperationSet<AK, Out>>,
    ) -> Out
    where
        S: ComputeServer,
        C: ComputeChannel<S>,
    {
        // Fast path: a shared lock suffices when the fastest operation is already known.
        {
            let guard = self.state.read();
            let cached_index = match &*guard {
                Some(tuners) => tuners
                    .get(id)
                    .and_then(|tuner| tuner.autotune_fastest(&autotune_operation_set.key())),
                None => None,
            };
            if let Some(index) = cached_index {
                // `guard` is still alive here, so the operation runs under the read lock.
                return autotune_operation_set.fastest(index).execute();
            }
        } // read lock released before the write lock is requested below

        // Slow path: exclusive access to lazily build the map and this id's tuner.
        let mut guard = self.state.write();
        let tuners = guard.get_or_insert_with(Default::default);
        let tuner = match tuners.get_mut(id) {
            Some(existing) => existing,
            None => {
                // Tuner names use '-' in place of the '::' module separators.
                let tuner_name = self.name.replace("::", "-");
                let fresh = Tuner::new(&tuner_name, &id.to_string());
                tuners.insert(id.clone(), fresh);
                tuners.get_mut(id).unwrap()
            }
        };
        // NOTE(review): another thread may have finished autotuning between the two lock
        // acquisitions; presumably `execute_autotune` re-checks its own cache — confirm.
        tuner.execute_autotune(autotune_operation_set, client)
    }

    /// Looks up the cached fastest-operation index for `key` on the tuner under `id`.
    ///
    /// Returns `None` when the registry is empty or has no tuner for `id`. Otherwise it
    /// returns whatever `Tuner::autotune_fastest` reports for `key`. Only the read lock is
    /// taken.
    pub fn autotune_result(&self, id: &ID, key: &AK) -> Option<usize> {
        let guard = self.state.read();
        match &*guard {
            Some(tuners) => tuners.get(id).and_then(|tuner| tuner.autotune_fastest(key)),
            None => None,
        }
    }
}