use crate::diagnostics::{Diagnostic, OpLocation};
use rustc_hash::FxHashMap;
/// Capability limits and tuning figures describing a compute adapter.
///
/// Three ready-made profiles exist: [`Default::default`] (`"unknown"`),
/// [`AdapterCaps::conservative`] and [`AdapterCaps::high_end`]. In the first
/// two, every field from `subgroup_size` downwards is zero (presumably meaning
/// "not reported" — confirm with the code that consumes these values).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdapterCaps {
    /// Label of the backend or profile these capabilities belong to.
    pub backend: &'static str,
    /// Whether subgroup operations are available.
    pub supports_subgroup_ops: bool,
    /// Whether indirect dispatch is available.
    pub supports_indirect_dispatch: bool,
    /// Whether specialization constants are available.
    pub supports_specialization_constants: bool,
    /// Maximum workgroup extent, one entry per dimension (presumably x, y, z).
    pub max_workgroup_size: [u32; 3],
    /// Maximum number of invocations in a single workgroup.
    pub max_invocations_per_workgroup: u32,
    /// Maximum shared memory, in bytes.
    pub max_shared_memory_bytes: u32,
    /// Maximum size of one storage-buffer binding (presumably bytes — confirm).
    pub max_storage_buffer_binding_size: u64,
    /// Subgroup width.
    pub subgroup_size: u32,
    /// Number of compute units.
    pub compute_units: u32,
    /// Maximum registers per thread.
    pub regs_per_thread_max: u32,
    /// L1 cache size, in bytes.
    pub l1_cache_bytes: u32,
    /// L2 cache size, in bytes.
    pub l2_cache_bytes: u32,
    /// Memory bandwidth; the name suggests GB/s — TODO confirm the unit.
    pub mem_bw_gbps: u32,
    /// Preferred loop-unroll depth hint.
    pub ideal_unroll_depth: u32,
    /// Preferred vector packing width hint, in bits.
    pub ideal_vector_pack_bits: u32,
    /// Preferred workgroup tile hint, one entry per dimension.
    pub ideal_workgroup_tile: [u32; 3],
    /// Number of shared-memory banks.
    pub shared_memory_bank_count: u32,
    /// Width of one shared-memory bank, in bytes.
    pub shared_memory_bank_width_bytes: u32,
}

impl Default for AdapterCaps {
    /// Returns the `"unknown"` profile.
    ///
    /// It carries the same values as [`AdapterCaps::conservative`] except for
    /// the `backend` label and a `[256, 256, 64]` workgroup-size limit.
    fn default() -> Self {
        Self {
            backend: "unknown",
            max_workgroup_size: [256, 256, 64],
            ..Self::conservative()
        }
    }
}

impl AdapterCaps {
    /// Returns the `"conservative"` profile: all optional features off, a
    /// one-dimensional 256-wide workgroup limit, and zeroed hardware figures.
    #[must_use]
    pub const fn conservative() -> Self {
        const KIB: u32 = 1024;
        const MIB: u64 = 1024 * 1024;

        Self {
            backend: "conservative",
            supports_subgroup_ops: false,
            supports_indirect_dispatch: false,
            supports_specialization_constants: false,
            max_workgroup_size: [256, 1, 1],
            max_invocations_per_workgroup: 256,
            max_shared_memory_bytes: 16 * KIB,
            max_storage_buffer_binding_size: 128 * MIB,
            subgroup_size: 0,
            compute_units: 0,
            regs_per_thread_max: 0,
            l1_cache_bytes: 0,
            l2_cache_bytes: 0,
            mem_bw_gbps: 0,
            ideal_unroll_depth: 0,
            ideal_vector_pack_bits: 0,
            ideal_workgroup_tile: [0, 0, 0],
            shared_memory_bank_count: 0,
            shared_memory_bank_width_bytes: 0,
        }
    }

    /// Returns the `"high-end-dispatch"` profile: every optional feature on,
    /// with populated hardware figures and tuning hints.
    #[must_use]
    pub const fn high_end() -> Self {
        const KIB: u32 = 1024;
        const MIB: u32 = 1024 * KIB;
        const GIB: u64 = 1024 * 1024 * 1024;

        Self {
            backend: "high-end-dispatch",
            supports_subgroup_ops: true,
            supports_indirect_dispatch: true,
            supports_specialization_constants: true,
            max_workgroup_size: [1024, 1024, 64],
            max_invocations_per_workgroup: 1024,
            max_shared_memory_bytes: 128 * KIB,
            max_storage_buffer_binding_size: 2 * GIB,
            subgroup_size: 32,
            compute_units: 128,
            regs_per_thread_max: 255,
            l1_cache_bytes: 128 * KIB,
            l2_cache_bytes: 64 * MIB,
            mem_bw_gbps: 1700,
            ideal_unroll_depth: 8,
            ideal_vector_pack_bits: 128,
            ideal_workgroup_tile: [16, 16, 1],
            shared_memory_bank_count: 32,
            shared_memory_bank_width_bytes: 4,
        }
    }
}
/// Store of type-erased values keyed by a static string name.
///
/// Each value is boxed as `dyn Any + Send + Sync`, so entries of different
/// concrete types can share one map. `Default` yields an empty cache.
#[derive(Default)]
pub struct AnalysisCache {
// One boxed, type-erased value per key; private so access goes through the
// methods of this type.
entries: FxHashMap<&'static str, Box<dyn std::any::Any + Send + Sync>>,
}
impl AnalysisCache {
    /// Creates an empty cache.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under `key`.
    ///
    /// Any value previously stored under the same key is replaced and dropped,
    /// even if it had a different concrete type.
    pub fn insert<T: std::any::Any + Send + Sync>(&mut self, key: &'static str, value: T) {
        self.entries.insert(key, Box::new(value));
    }

    /// Returns a reference to the value stored under `key`, viewed as a `T`.
    ///
    /// Returns `None` when there is no entry for `key`, or when the stored
    /// value's concrete type is not `T`.
    ///
    /// The key only needs to live for the duration of the call: the map is
    /// keyed by `&'static str`, but lookups go through `Borrow<str>`, so a
    /// name built at runtime works just as well as a string literal.
    #[must_use]
    pub fn get<T: std::any::Any>(&self, key: &str) -> Option<&T> {
        self.entries.get(key)?.downcast_ref::<T>()
    }

    /// Removes every entry from the cache.
    pub fn clear(&mut self) {
        self.entries.clear();
    }
}
impl std::fmt::Debug for AnalysisCache {
    /// Formats the cache as `AnalysisCache { entries: <count> }`.
    ///
    /// Only the number of entries is printed; the stored values are
    /// type-erased and are not rendered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.entries.len();
        let mut builder = f.debug_struct("AnalysisCache");
        builder.field("entries", &entry_count);
        builder.finish()
    }
}
/// Bundle of borrows handed around together, all sharing the lifetime `'a`.
///
/// NOTE(review): the name suggests this is the context given to each
/// optimizer pass — confirm against the pass trait that consumes it.
pub struct PassCtx<'a> {
/// Exclusive borrow of the program.
pub program: &'a mut crate::ir_inner::model::program::Program,
/// Shared (read-only) borrow of the adapter capabilities.
pub adapter_caps: &'a AdapterCaps,
/// Exclusive borrow of the analysis cache.
pub analyses: &'a mut AnalysisCache,
/// Exclusive borrow of the fact substrate.
pub fact_substrate: &'a mut crate::optimizer::fact_substrate::FactSubstrate,
/// Exclusive borrow of a diagnostics list (presumably appended to by the
/// holder of this context — confirm).
pub diagnostics: &'a mut Vec<Diagnostic>,
}
#[must_use]
pub fn scheduling_error_to_diagnostic(err: &crate::optimizer::PassSchedulingError) -> Diagnostic {
use crate::optimizer::PassSchedulingError as E;
match err {
E::UnknownRequire { pass, missing } => Diagnostic::error(format!(
"OPTSCHED001: pass `{pass}` requires unknown pass `{missing}`. Fix: register `{missing}` or drop the requirement."
))
.with_location(OpLocation::op(pass.to_string())),
E::Cycle { pass_ids, fix } => Diagnostic::error(format!(
"OPTSCHED002: cycle among passes {pass_ids:?}. Fix: {fix}"
)),
E::DuplicateId { id } => Diagnostic::error(format!(
"OPTSCHED003: duplicate pass id `{id}`. Fix: assign every pass a unique stable id."
)),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizer::PassSchedulingError;

    /// `Default` yields the `"unknown"` profile with subgroup and indirect
    /// dispatch support off and a zero subgroup size.
    #[test]
    fn default_caps_conservative() {
        let AdapterCaps {
            backend,
            supports_subgroup_ops,
            supports_indirect_dispatch,
            subgroup_size,
            ..
        } = AdapterCaps::default();

        assert_eq!(backend, "unknown");
        assert!(!supports_subgroup_ops);
        assert!(!supports_indirect_dispatch);
        assert_eq!(subgroup_size, 0);
    }

    /// The conservative profile is labelled as such and is one-dimensional.
    #[test]
    fn conservative_profile() {
        let profile = AdapterCaps::conservative();

        assert_eq!(profile.backend, "conservative");
        assert_eq!(profile.max_workgroup_size, [256, 1, 1]);
    }

    /// The high-end profile turns every optional feature on.
    #[test]
    fn high_end_profile() {
        let profile = AdapterCaps::high_end();

        assert_eq!(profile.backend, "high-end-dispatch");
        assert!(profile.supports_subgroup_ops);
        assert!(profile.supports_indirect_dispatch);
        assert!(profile.supports_specialization_constants);
        assert_eq!(profile.subgroup_size, 32);
        assert_eq!(profile.max_invocations_per_workgroup, 1024);
    }

    /// A stored value can be read back with its original type.
    #[test]
    fn analysis_cache_insert_and_get() {
        let mut store = AnalysisCache::new();
        store.insert("node_count", 42u32);

        assert_eq!(store.get::<u32>("node_count").copied(), Some(42));
    }

    /// Looking up a key that was never inserted yields `None`.
    #[test]
    fn analysis_cache_get_missing() {
        let store = AnalysisCache::new();

        assert!(store.get::<u32>("nonexistent").is_none());
    }

    /// Asking for the wrong type yields `None` rather than panicking.
    #[test]
    fn analysis_cache_type_mismatch() {
        let mut store = AnalysisCache::new();
        store.insert("node_count", 42u32);

        assert!(store.get::<String>("node_count").is_none());
    }

    /// `clear` removes previously inserted entries.
    #[test]
    fn analysis_cache_clear() {
        let mut store = AnalysisCache::new();
        store.insert("a", 1u32);
        store.clear();

        assert!(store.get::<u32>("a").is_none());
    }

    /// The `Debug` rendering names the type.
    #[test]
    fn analysis_cache_debug() {
        let store = AnalysisCache::new();
        let rendered = format!("{:?}", store);

        assert!(rendered.contains("AnalysisCache"));
    }

    /// `UnknownRequire` maps to OPTSCHED001 and mentions the requiring pass.
    #[test]
    fn scheduling_error_unknown_require() {
        let failure = PassSchedulingError::UnknownRequire {
            pass: "fusion",
            missing: "dead_buffer_elim",
        };
        let diag = scheduling_error_to_diagnostic(&failure);

        for needle in ["OPTSCHED001", "fusion"] {
            assert!(diag.message.contains(needle));
        }
    }

    /// `Cycle` maps to OPTSCHED002.
    #[test]
    fn scheduling_error_cycle() {
        let failure = PassSchedulingError::Cycle {
            pass_ids: vec!["a", "b"],
            fix: "break the cycle",
        };
        let diag = scheduling_error_to_diagnostic(&failure);

        assert!(diag.message.contains("OPTSCHED002"));
    }

    /// `DuplicateId` maps to OPTSCHED003 and mentions the duplicated id.
    #[test]
    fn scheduling_error_duplicate_id() {
        let failure = PassSchedulingError::DuplicateId { id: "dup_pass" };
        let diag = scheduling_error_to_diagnostic(&failure);

        for needle in ["OPTSCHED003", "dup_pass"] {
            assert!(diag.message.contains(needle));
        }
    }
}