use std::collections::BTreeMap;
use std::fmt::Write as _;
use std::fs;
use std::path::{Path, PathBuf};
use vyre_foundation::ir::Program;
/// Workgroup sizes offered as tuning candidates; `Tuner::candidates_for`
/// filters this list against a caller-supplied invocation limit.
const CANDIDATES: &[u32] = &[32, 64, 128, 256, 512, 1024];
/// Name of the environment variable consulted by `Mode::from_env`.
const AUTOTUNER_ENV: &str = "VYRE_AUTOTUNER";
/// Operating mode of the workgroup-size autotuner.
///
/// Marked `#[non_exhaustive]`, so downstream crates must keep a wildcard arm
/// when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Mode {
/// Autotuning is enabled.
On,
/// Autotuning is disabled.
/// NOTE(review): the name suggests callers should use the default
/// workgroup size in this mode — confirm against the call sites.
OffUseDefault,
}
impl Mode {
    /// Reads the autotuner mode from the `AUTOTUNER_ENV` environment variable.
    ///
    /// Only the exact value `on` selects [`Mode::On`]. A missing variable, a
    /// value that is not valid Unicode, or any other text selects
    /// [`Mode::OffUseDefault`].
    #[must_use]
    pub fn from_env() -> Self {
        let enabled = matches!(std::env::var(AUTOTUNER_ENV), Ok(value) if value == "on");
        if enabled {
            Self::On
        } else {
            Self::OffUseDefault
        }
    }
}
/// Timing hook used by `Tuner::best_of` to measure candidate workgroup sizes.
pub trait BackendTimer {
/// Error produced when a measurement fails; `Tuner::best_of` propagates it
/// unchanged to its caller.
type Error;
/// Measures `program` when run with the given `workgroup_size`.
///
/// `Tuner::best_of` compares the returned values and keeps the smallest.
/// NOTE(review): the method name implies the unit is nanoseconds —
/// implementors should confirm.
fn measure_candidate_ns(
&mut self,
program: &Program,
workgroup_size: [u32; 3],
) -> Result<u64, Self::Error>;
}
/// In-memory table of tuning decisions that can be loaded from and saved to
/// a TOML file.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct TunerCache {
/// Chosen workgroup size (three axes) for each program key string.
pub entries: BTreeMap<String, [u32; 3]>,
}
/// Dispatch-shape facts that `TunerProgramKey::from_program` hashes together
/// with the program fingerprint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StaticProgramShape {
/// Workgroup size of the program, one value per axis.
pub workgroup_size: [u32; 3],
/// Workgroup count per axis, or `None` when the caller did not supply one.
pub workgroup_count: Option<[u32; 3]>,
/// Output size in bytes, as supplied by the caller.
pub output_bytes: u64,
}
impl StaticProgramShape {
    /// Builds a shape description for `program`.
    ///
    /// The workgroup size is read from `program.workgroup_size()`;
    /// `workgroup_count` and `output_bytes` are stored exactly as given.
    #[must_use]
    pub fn new(program: &Program, workgroup_count: Option<[u32; 3]>, output_bytes: u64) -> Self {
        let workgroup_size = program.workgroup_size();
        Self {
            workgroup_size,
            workgroup_count,
            output_bytes,
        }
    }
}
/// Cache key string identifying a program together with its static shape.
///
/// Values are produced by `TunerProgramKey::from_program`; the inner string
/// is what `TunerCache` uses as its map key.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TunerProgramKey(String);
impl TunerProgramKey {
#[must_use]
pub fn from_program(program: &Program, shape: StaticProgramShape) -> Self {
let mut hasher = blake3::Hasher::new();
hasher.update(b"vyre-driver-workgroup-tuner-v1\0program\0");
hasher.update(&program.fingerprint());
hasher.update(b"\0workgroup-size\0");
for axis in shape.workgroup_size {
hasher.update(&axis.to_le_bytes());
}
hasher.update(b"\0workgroup-count\0");
match shape.workgroup_count {
Some(count) => {
hasher.update(&[1]);
for axis in count {
hasher.update(&axis.to_le_bytes());
}
}
None => {
hasher.update(&[0]);
}
}
hasher.update(b"\0output-bytes\0");
hasher.update(&shape.output_bytes.to_le_bytes());
let digest = hasher.finalize();
let mut key = String::with_capacity(67);
key.push_str("v1-");
push_hex(digest.as_bytes(), &mut key);
Self(key)
}
#[must_use]
pub fn as_str(&self) -> &str {
&self.0
}
}
/// Appends the lowercase hexadecimal form of `bytes` to `out`, two characters
/// per byte with the high nibble first.
fn push_hex(bytes: &[u8], out: &mut String) {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    for byte in bytes.iter().copied() {
        let high = usize::from(byte >> 4);
        let low = usize::from(byte & 0x0f);
        out.push(DIGITS[high]);
        out.push(DIGITS[low]);
    }
}
/// Lets a `TunerProgramKey` be passed wherever an `AsRef<str>` is accepted.
impl AsRef<str> for TunerProgramKey {
    /// Returns the same string slice as `TunerProgramKey::as_str`.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
impl TunerCache {
    /// Returns the workgroup size stored under `program_fp`, if any.
    #[must_use]
    pub fn get(&self, program_fp: &str) -> Option<[u32; 3]> {
        self.entries.get(program_fp).copied()
    }

    /// Returns the workgroup size stored under `key`, if any.
    #[must_use]
    pub fn get_key(&self, key: &TunerProgramKey) -> Option<[u32; 3]> {
        self.get(key.as_str())
    }

    /// Stores `size` under `program_fp`, replacing any previous entry.
    pub fn set(&mut self, program_fp: impl Into<String>, size: [u32; 3]) {
        self.entries.insert(program_fp.into(), size);
    }

    /// Stores `size` under `key`, replacing any previous entry.
    pub fn set_key(&mut self, key: TunerProgramKey, size: [u32; 3]) {
        self.entries.insert(key.0, size);
    }

    /// Loads a cache from the TOML file at `path`.
    ///
    /// A file that cannot be read as text (including a missing file) yields an
    /// empty cache rather than an error. Top-level entries are kept only when
    /// their value is an array of exactly three integers that each fit in
    /// `u32`; every other entry is skipped so that a malformed line can never
    /// surface as a partially zero-filled workgroup size.
    ///
    /// # Errors
    ///
    /// Returns a message when the file was read but is not valid TOML.
    pub fn load(path: &Path) -> Result<Self, String> {
        let Ok(contents) = fs::read_to_string(path) else {
            return Ok(Self::default());
        };
        let parsed: toml::Value = toml::from_str(&contents).map_err(|error| {
            format!(
                "Fix: tuner cache `{}` is not valid TOML: {error}",
                path.display()
            )
        })?;
        let mut entries = BTreeMap::new();
        if let Some(table) = parsed.as_table() {
            for (key, value) in table {
                if let Some(triple) = Self::parse_triple(value) {
                    entries.insert(key.clone(), triple);
                }
            }
        }
        Ok(Self { entries })
    }

    /// Writes the cache to `path` as one `"key" = [x, y, z]` line per entry,
    /// creating the parent directory first when the path has one.
    ///
    /// Keys are escaped as TOML basic strings, so a key containing a quote,
    /// backslash, or control character still produces a file that `load` can
    /// parse.
    ///
    /// # Errors
    ///
    /// Returns a message when the directory cannot be created or the file
    /// cannot be written.
    pub fn save(&self, path: &Path) -> Result<(), String> {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).map_err(|error| {
                format!(
                    "Fix: could not create tuner cache directory {}: {error}",
                    parent.display()
                )
            })?;
        }
        let mut out = String::with_capacity(self.entries.len().saturating_mul(96));
        for (key, size) in &self.entries {
            out.push('"');
            Self::push_toml_basic_string(key, &mut out);
            // Formatting into a `String` cannot fail, so the result is ignored.
            let _ = writeln!(out, "\" = [{}, {}, {}]", size[0], size[1], size[2]);
        }
        fs::write(path, &out).map_err(|error| {
            format!(
                "Fix: could not write tuner cache {}: {error}",
                path.display()
            )
        })
    }

    /// Converts a TOML value into a workgroup-size triple, or `None` when it
    /// is not an array of exactly three integers that each fit in `u32`.
    fn parse_triple(value: &toml::Value) -> Option<[u32; 3]> {
        let array = value.as_array()?;
        if array.len() != 3 {
            return None;
        }
        let mut triple = [0u32; 3];
        for (slot, component) in triple.iter_mut().zip(array) {
            *slot = u32::try_from(component.as_integer()?).ok()?;
        }
        Some(triple)
    }

    /// Appends `raw` to `out` with the escapes required inside a TOML basic
    /// (double-quoted) string.
    fn push_toml_basic_string(raw: &str, out: &mut String) {
        for ch in raw.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\u{08}' => out.push_str("\\b"),
                '\t' => out.push_str("\\t"),
                '\n' => out.push_str("\\n"),
                '\u{0c}' => out.push_str("\\f"),
                '\r' => out.push_str("\\r"),
                // Remaining control characters have no short escape form.
                '\u{00}'..='\u{1f}' | '\u{7f}' => {
                    let _ = write!(out, "\\u{:04X}", u32::from(ch));
                }
                other => out.push(other),
            }
        }
    }
}
/// One timed candidate, as returned by `Tuner::best_of`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TuningMeasurement {
/// Workgroup size that was measured.
pub workgroup_size: [u32; 3],
/// Value reported by `BackendTimer::measure_candidate_ns` for that size.
pub elapsed_ns: u64,
}
/// Workgroup-size tuner backed by a per-adapter cache file.
pub struct Tuner {
// Mode passed to `Tuner::new`; exposed through `Tuner::mode`.
mode: Mode,
// Decisions loaded at construction plus any recorded since.
cache: TunerCache,
// File that `Tuner::persist` writes the cache to.
cache_path: PathBuf,
}
impl Tuner {
    /// Creates a tuner for the adapter identified by `adapter_fp`.
    ///
    /// The cache file for that adapter is loaded on a best-effort basis: if
    /// loading fails, the tuner starts with an empty cache.
    #[must_use]
    pub fn new(adapter_fp: &str, mode: Mode) -> Self {
        let cache_path = Self::cache_path_for_adapter(adapter_fp);
        let cache = match TunerCache::load(&cache_path) {
            Ok(loaded) => loaded,
            Err(_) => TunerCache::default(),
        };
        Self {
            mode,
            cache,
            cache_path,
        }
    }

    /// Returns `<cache root>/vyre/tuner/<adapter_fp>.toml`.
    ///
    /// NOTE(review): `adapter_fp` is used as a path component without
    /// sanitising; presumably it is always a plain fingerprint with no path
    /// separators — confirm against the callers.
    #[must_use]
    pub fn cache_path_for_adapter(adapter_fp: &str) -> PathBuf {
        let file_name = format!("{adapter_fp}.toml");
        let mut path = dirs_cache_root();
        path.push("vyre");
        path.push("tuner");
        path.push(file_name);
        path
    }

    /// Returns the built-in candidate sizes that do not exceed
    /// `max_invocations`, in their declared order.
    #[must_use]
    pub fn candidates_for(&self, max_invocations: u32) -> Vec<u32> {
        let mut fitting = Vec::new();
        for &candidate in CANDIDATES {
            if candidate <= max_invocations {
                fitting.push(candidate);
            }
        }
        fitting
    }

    /// Returns the workgroup size used when the cache has no entry.
    #[must_use]
    pub const fn default_workgroup_size() -> [u32; 3] {
        crate::pipeline::DEFAULT_1D_WORKGROUP_SIZE
    }

    /// Returns the mode this tuner was created with.
    #[must_use]
    pub const fn mode(&self) -> Mode {
        self.mode
    }

    /// Returns the cached workgroup size for `program_fp`, or the default
    /// size when nothing is cached under that key.
    #[must_use]
    pub fn resolve(&self, program_fp: &str) -> [u32; 3] {
        match self.cache.get(program_fp) {
            Some(cached) => cached,
            None => Self::default_workgroup_size(),
        }
    }

    /// Same as [`Tuner::resolve`], taking a typed key.
    #[must_use]
    pub fn resolve_key(&self, key: &TunerProgramKey) -> [u32; 3] {
        self.resolve(key.as_str())
    }

    /// Records `size` for `program_fp` in the in-memory cache; nothing is
    /// written to disk until [`Tuner::persist`] is called.
    pub fn record_decision(&mut self, program_fp: impl Into<String>, size: [u32; 3]) {
        self.cache.set(program_fp, size);
    }

    /// Same as [`Tuner::record_decision`], taking a typed key.
    pub fn record_key_decision(&mut self, key: TunerProgramKey, size: [u32; 3]) {
        self.cache.set_key(key, size);
    }

    /// Measures every candidate with `timer` and returns the one with the
    /// smallest reported time.
    ///
    /// Ties keep the earlier candidate. Returns `Ok(None)` when `candidates`
    /// is empty.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error reported by `timer`.
    pub fn best_of<T: BackendTimer>(
        &self,
        program: &Program,
        candidates: impl IntoIterator<Item = [u32; 3]>,
        timer: &mut T,
    ) -> Result<Option<TuningMeasurement>, T::Error> {
        let mut best: Option<TuningMeasurement> = None;
        for workgroup_size in candidates {
            let elapsed_ns = timer.measure_candidate_ns(program, workgroup_size)?;
            // Strict comparison: an equal time does not displace the current best.
            let improves = match best {
                Some(current) => elapsed_ns < current.elapsed_ns,
                None => true,
            };
            if improves {
                best = Some(TuningMeasurement {
                    workgroup_size,
                    elapsed_ns,
                });
            }
        }
        Ok(best)
    }

    /// Writes the in-memory cache to this tuner's cache file.
    ///
    /// # Errors
    ///
    /// Returns the message produced by `TunerCache::save` on failure.
    pub fn persist(&self) -> Result<(), String> {
        self.cache.save(&self.cache_path)
    }
}
/// Runtime observations handed to `DefaultPolicy::suggest_resize`.
///
/// Only `idle_us`, `observed_workgroup_size_x` and
/// `observed_throughput_per_us` are read by the policy in this file.
#[derive(Debug, Clone)]
pub struct TunerFeedback {
/// Pairs of counters per opcode.
/// NOTE(review): presumably `(opcode id, count)` — confirm with the producer.
pub per_opcode_counts: Vec<(u32, u32)>,
/// Wall-clock time; microseconds per the field name.
pub wall_time_us: u64,
/// Idle time; compared against `DefaultPolicy::idle_shrink_us`.
pub idle_us: u64,
/// Workgroup size on the X axis that was in effect during the observation.
pub observed_workgroup_size_x: u32,
/// Observed throughput; compared against
/// `DefaultPolicy::saturation_threshold_per_us`.
pub observed_throughput_per_us: f64,
}
/// Thresholds used by `DefaultPolicy::suggest_resize` to decide whether to
/// halve or double the X workgroup size.
#[derive(Debug, Clone)]
pub struct DefaultPolicy {
/// Largest X workgroup size the policy will suggest when growing.
pub adapter_max_workgroup_size_x: u32,
/// Smallest X workgroup size the policy will suggest when shrinking.
pub minimum_workgroup_size_x: u32,
/// Growth is considered only while observed throughput is below this value.
pub saturation_threshold_per_us: f64,
/// Shrinking is considered only while observed idle time exceeds this value.
pub idle_shrink_us: u64,
}
impl Default for DefaultPolicy {
    /// Returns the stock thresholds: sizes between 32 and 1024, a saturation
    /// threshold of 1.0, and an idle limit of 100 000.
    fn default() -> Self {
        let minimum_workgroup_size_x = 32;
        let adapter_max_workgroup_size_x = 1024;
        let idle_shrink_us = 100_000;
        let saturation_threshold_per_us = 1.0;
        Self {
            adapter_max_workgroup_size_x,
            minimum_workgroup_size_x,
            saturation_threshold_per_us,
            idle_shrink_us,
        }
    }
}
impl DefaultPolicy {
    /// Suggests a new X workgroup size from `feedback`, or `None` to keep the
    /// current one.
    ///
    /// The observed size is treated as at least 1. The rules are applied in
    /// this order:
    ///
    /// 1. If idle time exceeds `idle_shrink_us`, suggest half the current size
    ///    when that is still at least `minimum_workgroup_size_x`; otherwise
    ///    suggest nothing. Growth is not considered in this case.
    /// 2. Otherwise, if throughput is below `saturation_threshold_per_us`,
    ///    suggest double the current size when that does not exceed
    ///    `adapter_max_workgroup_size_x`.
    ///
    /// A suggestion is never zero, even when `minimum_workgroup_size_x` is
    /// configured as 0.
    #[must_use]
    pub fn suggest_resize(&self, feedback: &TunerFeedback) -> Option<u32> {
        let current = feedback.observed_workgroup_size_x.max(1);

        if feedback.idle_us > self.idle_shrink_us {
            // Clamp the floor to 1 so halving a size of 1 cannot yield 0.
            let floor = self.minimum_workgroup_size_x.max(1);
            let shrunk = current / 2;
            return if shrunk >= floor { Some(shrunk) } else { None };
        }

        if feedback.observed_throughput_per_us < self.saturation_threshold_per_us {
            let grown = current.saturating_mul(2);
            // `grown == current` only happens when the doubling saturated.
            if grown != current && grown <= self.adapter_max_workgroup_size_x {
                return Some(grown);
            }
        }

        None
    }
}
/// Returns the directory under which cache files are placed.
///
/// Resolution order:
///
/// 1. `XDG_CACHE_HOME`, when it is set to an absolute path. An empty or
///    relative value is ignored, as the XDG Base Directory specification
///    requires.
/// 2. `$HOME/.cache`, when `HOME` is set and non-empty.
/// 3. The current directory (`.`) as a last resort.
fn dirs_cache_root() -> PathBuf {
    let xdg_root = std::env::var_os("XDG_CACHE_HOME")
        .map(PathBuf::from)
        .filter(|candidate| candidate.is_absolute());
    if let Some(root) = xdg_root {
        return root;
    }

    match std::env::var_os("HOME").filter(|home| !home.is_empty()) {
        Some(home) => {
            let mut root = PathBuf::from(home);
            root.push(".cache");
            root
        }
        None => PathBuf::from("."),
    }
}