#![forbid(unsafe_code)]
#![warn(missing_docs)]
#![allow(
// Auto-generated op wrappers replay derive attributes by design.
clippy::duplicated_attributes,
// GPU buffer layout types (bind-group slot tuples) are inherently complex.
clippy::type_complexity,
// Shader-side math and wire-format POD structs do intentional integer
// casts; the conform gate verifies byte-identity with the CPU reference.
clippy::cast_lossless,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_precision_loss,
clippy::cast_sign_loss,
// Explicit clones on Copy improve readability in serial layers where
// semantic ownership matters more than cycle count.
clippy::clone_on_copy,
// Three-branch comparisons are natural in range-check oracles.
clippy::comparison_chain,
// Vyre uses explicit invariant violations (expect/unwrap) with `Fix:`
// prose — not graceful degradation — per the engineering standard.
clippy::expect_used,
// Generic collections take external hashers by design.
clippy::implicit_hasher,
// SHA/hash compressors use the canonical single-letter state vars
// (a,b,c,d,e,f,g,h per FIPS 180-4).
clippy::many_single_char_names,
// Error prose is centralized in the `Error` enum; per-fn `# Errors`
// sections duplicate that contract.
clippy::missing_errors_doc,
// Panics document invariant violations with `Fix:` prose inline.
clippy::missing_panics_doc,
// Template-generated ops don't always merit `#[must_use]`.
clippy::must_use_candidate,
clippy::needless_pass_by_value,
clippy::needless_range_loop,
clippy::needless_raw_string_hashes,
clippy::module_name_repetitions,
clippy::module_inception,
clippy::similar_names,
clippy::should_implement_trait,
clippy::match_same_arms,
clippy::format_push_string,
clippy::too_many_arguments,
clippy::too_many_lines,
clippy::trivially_copy_pass_by_ref,
clippy::unnecessary_wraps,
clippy::unnested_or_patterns,
clippy::unreadable_literal,
clippy::doc_markdown
)]
#![cfg_attr(not(test), deny(clippy::todo, clippy::unimplemented))]
pub use vyre_foundation::ir;
/// Lowering entry points, re-exported from `vyre_lower`.
///
/// Everything public in `vyre_lower` is glob re-exported here. The explicit
/// `lower::lower` import takes precedence over any same-named item the glob
/// would bring in, so `lower::lower` in this crate always resolves to
/// `vyre_lower::lower::lower`.
pub mod lower {
    pub use vyre_lower::{lower::lower, *};
}
pub use vyre_foundation::optimizer;
pub use vyre_foundation::cpu_op;
pub use vyre_foundation::cpu_references;
pub use vyre_foundation::memory_model;
pub use vyre_foundation::MemoryOrdering;
pub use vyre_driver::routing;
pub use vyre_foundation::execution_plan;
pub use vyre_driver::error;
pub use vyre_driver::diagnostics;
pub use vyre_driver::backend;
pub use vyre_foundation::match_result;
pub use vyre_driver::pipeline;
pub use vyre_driver::{
BackendError, BackendRegistration, CompiledPipeline, DispatchConfig, Error, Executable, Memory,
MemoryRef, OutputBuffers, TypedDispatchExt, VyreBackend,
};
pub use vyre_driver::persistent::PersistentThreadMode;
pub use vyre_driver::speculate::SpeculationMode;
pub use ir::{validate, InterpCtx, NodeId, NodeStorage, OpId, Program, Value};
pub use vyre_foundation::match_result::Match;
pub use vyre_foundation::ByteRange;
/// Runs the pre-lowering optimizer over `program`, memoized by fingerprint.
///
/// The cache key is the fingerprint of the *input* program. On a hit, a clone
/// of the previously optimized program is returned and the pass pipeline is
/// skipped. On a miss, the optimizer runs and its output is stored in the
/// process-wide, bounded (see [`OPTIMIZE_CACHE_CAPACITY`]) optimize cache.
#[must_use]
pub fn optimize(program: Program) -> Program {
    let fingerprint = program.fingerprint();
    match optimize_cache::get(&fingerprint) {
        Some(hit) => hit,
        None => {
            let result = vyre_foundation::optimizer::pre_lowering::optimize(program);
            optimize_cache::put(fingerprint, &result);
            result
        }
    }
}
/// Specializes `program` for a device profile, then runs [`optimize`].
///
/// The autotune pass first rewrites the program against
/// `profile.adapter_caps()`. The tuned program then goes through the canonical
/// [`optimize`] path, so it also consults and populates the canonical cache.
///
/// The final result is memoized in a separate device cache. Its key mixes the
/// input program's fingerprint with the profile fields hashed by
/// `device_optimize_key`.
#[must_use]
pub fn optimize_for_device(program: Program, profile: &vyre_driver::DeviceProfile) -> Program {
    use vyre_foundation::optimizer::passes::autotune::Autotune;

    let device_key = device_optimize_key(&program, profile);
    match optimize_cache::get_device(&device_key) {
        Some(hit) => hit,
        None => {
            let autotuned =
                Autotune::transform_for_adapter(program, &profile.adapter_caps()).program;
            let result = optimize(autotuned);
            optimize_cache::put_device(device_key, &result);
            result
        }
    }
}
/// Optimizes `program` for the device behind `backend`.
///
/// This is a convenience wrapper. It asks the backend for its
/// `device_profile()` and defers to [`optimize_for_device`], sharing that
/// function's device-keyed cache.
#[must_use]
pub fn optimize_for_backend(program: Program, backend: &dyn vyre_driver::VyreBackend) -> Program {
    optimize_for_device(program, &backend.device_profile())
}
/// Derives the 32-byte device-cache key for `program` on `profile`.
///
/// The hash covers:
/// - a versioned, NUL-terminated domain prefix;
/// - the program fingerprint;
/// - the backend name;
/// - one byte per capability flag;
/// - the little-endian bytes of each numeric limit, in a fixed order.
///
/// Only one variable-length field (the backend name) is hashed, and everything
/// after it has a fixed width. The concatenated input is therefore unambiguous.
///
/// NOTE(review): if `DeviceProfile::adapter_caps()` reads profile fields that
/// are not hashed here, two profiles could share a cached result — confirm
/// this list stays in sync with that method.
fn device_optimize_key(program: &Program, profile: &vyre_driver::DeviceProfile) -> [u8; 32] {
    // One byte per capability flag. BLAKE3 is a streaming hash, so feeding the
    // five bytes as a single slice is byte-identical to five one-byte updates.
    let capability_flags = [
        u8::from(profile.supports_subgroup_ops),
        u8::from(profile.supports_indirect_dispatch),
        u8::from(profile.supports_f16),
        u8::from(profile.supports_bf16),
        u8::from(profile.supports_tensor_cores),
    ];

    let mut hasher = blake3::Hasher::new();
    hasher
        .update(b"vyre-core-optimize-device-v1\0")
        .update(&program.fingerprint())
        .update(profile.backend.as_bytes())
        .update(&capability_flags);
    // Workgroup extents in x, y, z order.
    for extent in &profile.max_workgroup_size {
        hasher.update(&extent.to_le_bytes());
    }
    hasher
        .update(&profile.max_invocations_per_workgroup.to_le_bytes())
        .update(&profile.max_shared_memory_bytes.to_le_bytes())
        .update(&profile.subgroup_size.to_le_bytes())
        .update(&profile.compute_units.to_le_bytes())
        .update(&profile.ideal_unroll_depth.to_le_bytes())
        .update(&profile.ideal_vector_pack_bits.to_le_bytes());

    let digest = hasher.finalize();
    *digest.as_bytes()
}
/// Maximum number of entries held by *each* of the in-process optimize caches:
/// the canonical cache used by [`optimize`] and the device-specific cache used
/// by [`optimize_for_device`].
///
/// When a cache is full, inserting a new key evicts the oldest inserted entry
/// (first in, first out).
pub const OPTIMIZE_CACHE_CAPACITY: usize = 256;
/// Process-wide memoization for [`optimize`] and [`optimize_for_device`].
///
/// There are two independent tiers behind one mutex:
/// - `canonical`, keyed by the input program's fingerprint;
/// - `device`, keyed by `device_optimize_key`.
///
/// Each tier holds at most [`OPTIMIZE_CACHE_CAPACITY`] entries with FIFO
/// eviction. The cache is best-effort: a miss only costs a recomputation.
mod optimize_cache {
    use super::Program;
    use super::OPTIMIZE_CACHE_CAPACITY;
    use std::collections::{HashMap, VecDeque};
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

    /// Bounded key → program map with first-in-first-out eviction.
    ///
    /// Invariant: `order` contains exactly the keys of `entries`, each once,
    /// oldest first.
    struct FifoMap {
        entries: HashMap<[u8; 32], Program>,
        order: VecDeque<[u8; 32]>,
    }

    impl FifoMap {
        fn new() -> Self {
            Self {
                entries: HashMap::with_capacity(OPTIMIZE_CACHE_CAPACITY),
                order: VecDeque::with_capacity(OPTIMIZE_CACHE_CAPACITY),
            }
        }

        /// Returns a clone of the program cached under `key`, if any.
        fn get(&self, key: &[u8; 32]) -> Option<Program> {
            self.entries.get(key).cloned()
        }

        /// Stores `program` under `key` unless the key is already present
        /// (first writer wins). When full, the oldest entry is evicted first.
        fn insert(&mut self, key: [u8; 32], program: Program) {
            if self.entries.contains_key(&key) {
                return;
            }
            if self.entries.len() >= OPTIMIZE_CACHE_CAPACITY {
                if let Some(evicted) = self.order.pop_front() {
                    self.entries.remove(&evicted);
                }
            }
            // Map first, then queue: the queue never references a key that
            // is missing from the map.
            self.entries.insert(key, program);
            self.order.push_back(key);
        }

        #[cfg(test)]
        fn clear(&mut self) {
            self.entries.clear();
            self.order.clear();
        }

        #[cfg(test)]
        fn len(&self) -> usize {
            self.entries.len()
        }
    }

    /// Both cache tiers, guarded together by a single mutex.
    struct Cache {
        canonical: FifoMap,
        device: FifoMap,
    }

    /// Locks the lazily-initialized global cache.
    ///
    /// A poisoned mutex is recovered rather than treated as "cache disabled".
    /// No code below clones or otherwise runs caller logic while mutating a
    /// tier, so a panic on another thread cannot leave a tier half-updated.
    fn lock() -> MutexGuard<'static, Cache> {
        static CACHE: OnceLock<Mutex<Cache>> = OnceLock::new();
        CACHE
            .get_or_init(|| {
                Mutex::new(Cache {
                    canonical: FifoMap::new(),
                    device: FifoMap::new(),
                })
            })
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Looks up an optimized program in the canonical (fingerprint-keyed) tier.
    pub(super) fn get(key: &[u8; 32]) -> Option<Program> {
        lock().canonical.get(key)
    }

    /// Records an optimized program in the canonical tier.
    pub(super) fn put(key: [u8; 32], program: &Program) {
        // Clone before taking the lock so the critical section stays short
        // and contains no potentially panicking user-type code.
        let program = program.clone();
        lock().canonical.insert(key, program);
    }

    /// Looks up an optimized program in the device-specific tier.
    pub(super) fn get_device(key: &[u8; 32]) -> Option<Program> {
        lock().device.get(key)
    }

    /// Records an optimized program in the device-specific tier.
    pub(super) fn put_device(key: [u8; 32], program: &Program) {
        let program = program.clone();
        lock().device.insert(key, program);
    }

    /// Empties both tiers (test isolation only).
    #[cfg(test)]
    pub(super) fn clear() {
        let mut cache = lock();
        cache.canonical.clear();
        cache.device.clear();
    }

    /// Number of entries in the canonical tier.
    #[cfg(test)]
    pub(super) fn len() -> usize {
        lock().canonical.len()
    }

    /// Number of entries in the device-specific tier.
    #[cfg(test)]
    pub(super) fn len_device() -> usize {
        lock().device.len()
    }
}
#[cfg(test)]
mod optimize_tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, OnceLock};
    use vyre_foundation::ir::{BufferDecl, DataType, Expr, Node};

    /// Serializes the tests in this module.
    ///
    /// They all mutate the same process-wide optimize cache, while the test
    /// harness runs tests on parallel threads.
    fn serial() -> MutexGuard<'static, ()> {
        static M: OnceLock<Mutex<()>> = OnceLock::new();
        M.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(|e| e.into_inner())
    }

    /// Builds a single-invocation program that stores `value` into `out[0]`.
    ///
    /// The eviction test relies on distinct `value`s producing distinct
    /// fingerprints.
    fn program_storing(value: u32) -> Program {
        Program::wrapped(
            vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)],
            [1, 1, 1],
            vec![Node::store("out", Expr::u32(0), Expr::u32(value))],
        )
    }

    fn sample_program() -> Program {
        program_storing(42)
    }

    #[test]
    fn optimize_is_cached_by_fingerprint() {
        let _g = serial();
        optimize_cache::clear();
        let key = sample_program().fingerprint();
        let _ = optimize(sample_program());
        assert!(
            optimize_cache::get(&key).is_some(),
            "optimize must cache its result under the input fingerprint"
        );
        let before = optimize_cache::len();
        let _ = optimize(sample_program());
        let after = optimize_cache::len();
        assert_eq!(
            before, after,
            "second optimize on identical fingerprint must hit the cache"
        );
        assert_eq!(before, 1, "cache must contain exactly one entry");
    }

    #[test]
    fn optimize_returns_equivalent_program_on_cache_hit() {
        let _g = serial();
        optimize_cache::clear();
        let p = sample_program();
        let first = optimize(p.clone());
        let second = optimize(p);
        assert_eq!(
            first.fingerprint(),
            second.fingerprint(),
            "cache hit must return a Program with identical fingerprint"
        );
    }

    #[test]
    fn optimize_cache_evicts_at_capacity() {
        let _g = serial();
        optimize_cache::clear();
        let oldest = program_storing(0).fingerprint();
        let newest = program_storing(OPTIMIZE_CACHE_CAPACITY as u32).fingerprint();
        for i in 0..=OPTIMIZE_CACHE_CAPACITY {
            let _ = optimize(program_storing(i as u32));
        }
        assert_eq!(
            optimize_cache::len(),
            OPTIMIZE_CACHE_CAPACITY,
            "cache must cap at OPTIMIZE_CACHE_CAPACITY entries"
        );
        assert!(
            optimize_cache::get(&oldest).is_none(),
            "eviction must be FIFO: the first-inserted entry is evicted first"
        );
        assert!(
            optimize_cache::get(&newest).is_some(),
            "the most recently inserted entry must survive eviction"
        );
    }

    #[test]
    fn optimize_cache_deduplicates_entries_by_fingerprint() {
        let _g = serial();
        optimize_cache::clear();
        let program = sample_program();
        let key = program.fingerprint();
        optimize_cache::put(key, &program);
        optimize_cache::put(key, &program);
        assert_eq!(
            optimize_cache::len(),
            1,
            "re-inserting an existing fingerprint must not add a second entry"
        );
    }

    #[test]
    fn optimize_for_device_uses_device_specific_cache() {
        let _g = serial();
        optimize_cache::clear();
        let mut profile = vyre_driver::DeviceProfile::conservative("test");
        profile.max_workgroup_size = [256, 1, 1];
        profile.max_invocations_per_workgroup = 256;
        let first = optimize_for_device(sample_program(), &profile);
        let second = optimize_for_device(sample_program(), &profile);
        assert_eq!(first.fingerprint(), second.fingerprint());
        assert_eq!(
            optimize_cache::len_device(),
            1,
            "same program+device profile must hit the device optimize cache"
        );
        assert_eq!(
            optimize_cache::len(),
            1,
            "device optimization should still reuse the canonical optimize cache after tuning"
        );
    }
}