use std::fmt;
#[cfg(not(target_arch = "wasm32"))]
use std::collections::HashMap;
#[cfg(not(target_arch = "wasm32"))]
use once_cell::sync::OnceCell;
/// Runtime registry of builtin GPU and fusion specs used on `wasm32` builds.
///
/// Specs are pushed in through `submit_gpu_spec` / `submit_fusion_spec` and read back
/// through the snapshot accessors; the parent module's `builtin_*` functions delegate
/// here when compiled for `wasm32`.
#[cfg(target_arch = "wasm32")]
pub(crate) mod wasm_registry {
    // NOTE(review): presumably some of these helpers are only reached from registration
    // code outside this file, hence the blanket allow — confirm before narrowing it.
    #![allow(dead_code)]
    use super::{BuiltinFusionSpec, BuiltinGpuSpec};
    use once_cell::sync::Lazy;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// GPU specs in submission order.
    static GPU_SPECS: Lazy<Mutex<Vec<&'static BuiltinGpuSpec>>> =
        Lazy::new(|| Mutex::new(Vec::new()));

    /// Fusion specs in submission order.
    static FUSION_SPECS: Lazy<Mutex<Vec<&'static BuiltinFusionSpec>>> =
        Lazy::new(|| Mutex::new(Vec::new()));

    /// Residency policy of each submitted GPU spec, keyed by ASCII-lowercased spec name.
    static RESIDENCY_POLICIES: Lazy<Mutex<HashMap<String, super::ResidencyPolicy>>> =
        Lazy::new(|| Mutex::new(HashMap::new()));

    /// Records `spec` in the GPU spec list and indexes its residency policy by
    /// ASCII-lowercased name. A later spec with the same lowercased name replaces the
    /// earlier policy entry (the spec list itself keeps both).
    ///
    /// # Panics
    /// Panics if either registry mutex is poisoned.
    pub(crate) fn submit_gpu_spec(spec: &'static BuiltinGpuSpec) {
        // Each guard is a temporary dropped at the end of its statement, so the two
        // locks are never held at the same time.
        GPU_SPECS
            .lock()
            .expect("gpu spec registry poisoned")
            .push(spec);
        RESIDENCY_POLICIES
            .lock()
            .expect("residency policy registry poisoned")
            .insert(spec.name.to_ascii_lowercase(), spec.residency);
    }

    /// Appends `spec` to the fusion spec list.
    ///
    /// # Panics
    /// Panics if the fusion registry mutex is poisoned.
    pub(crate) fn submit_fusion_spec(spec: &'static BuiltinFusionSpec) {
        FUSION_SPECS
            .lock()
            .expect("fusion spec registry poisoned")
            .push(spec);
    }

    /// Returns an iterator over a snapshot of the GPU specs submitted so far.
    ///
    /// The list is cloned while the lock is held, so the returned iterator does not
    /// keep the registry locked.
    ///
    /// # Panics
    /// Panics if the GPU registry mutex is poisoned.
    pub(crate) fn gpu_specs() -> std::vec::IntoIter<&'static BuiltinGpuSpec> {
        GPU_SPECS
            .lock()
            .expect("gpu spec registry poisoned")
            .clone()
            .into_iter()
    }

    /// Looks up the residency policy recorded for `name`, compared ASCII
    /// case-insensitively. Returns `None` when no GPU spec with that name was submitted.
    ///
    /// # Panics
    /// Panics if the residency policy mutex is poisoned.
    pub(crate) fn residency_policy(name: &str) -> Option<super::ResidencyPolicy> {
        RESIDENCY_POLICIES
            .lock()
            .expect("residency policy registry poisoned")
            .get(&name.to_ascii_lowercase())
            .copied()
    }

    /// Returns an iterator over a snapshot of the fusion specs submitted so far.
    ///
    /// # Panics
    /// Panics if the fusion registry mutex is poisoned.
    pub(crate) fn fusion_specs() -> std::vec::IntoIter<&'static BuiltinFusionSpec> {
        FUSION_SPECS
            .lock()
            .expect("fusion spec registry poisoned")
            .clone()
            .into_iter()
    }
}
/// Scalar element types that GPU and fusion specs can declare support for.
///
/// Referenced by `BuiltinGpuSpec::supported_precisions`,
/// `FusionKernelTemplate::scalar_precisions`, `FusionExprContext::scalar_ty` and
/// `FusionError::UnsupportedPrecision`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarType {
F32,
F64,
I32,
Bool,
}
/// Category tag stored in `BuiltinGpuSpec::op_kind`.
///
/// NOTE(review): how each kind is dispatched is decided by consumers outside this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuOpKind {
Elementwise,
Reduction,
MatMul,
Transpose,
PlotRender,
// Carries a static string; presumably a label for kinds not listed above — confirm.
Custom(&'static str),
}
/// Broadcasting mode tag stored in `BuiltinGpuSpec::broadcast`.
///
/// NOTE(review): the precise rules behind each variant live with the consumers of the
/// spec, not in this file — confirm before relying on the names alone.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BroadcastSemantics {
Matlab,
ScalarOnly,
None,
}
/// Hook descriptor listed in `BuiltinGpuSpec::provider_hooks`.
///
/// Every variant carries a static string; how that string is resolved to an actual
/// provider entry point is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderHook {
Unary {
name: &'static str,
},
Binary {
name: &'static str,
// presumably whether the two operands may be swapped — TODO confirm with consumers
commutative: bool,
},
Reduction {
name: &'static str,
},
Custom(&'static str),
}
/// Constant-handling strategy tag stored in both `BuiltinGpuSpec::constant_strategy`
/// and `BuiltinFusionSpec::constant_strategy`.
///
/// NOTE(review): presumably selects how constants reach a generated kernel — confirm
/// against the code that consumes these specs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConstantStrategy {
InlineLiteral,
UniformBuffer,
WorkgroupMemory,
}
/// Residency policy attached to a GPU spec (`BuiltinGpuSpec::residency`) and returned
/// by `builtin_residency_policy`.
///
/// NOTE(review): the behaviour each variant triggers is implemented by callers outside
/// this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResidencyPolicy {
InheritInputs,
NewHandle,
GatherImmediately,
}
/// NaN-handling mode tag stored in `BuiltinGpuSpec::nan_mode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionNaN {
Include,
Omit,
}
/// Shape constraint tag stored in `BuiltinFusionSpec::shape`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShapeRequirements {
BroadcastCompatible,
// Carries a static slice of `usize`; presumably the required dimensions — confirm.
Exact(&'static [usize]),
Any,
}
/// Borrowed input handed to a `FusionExprBuilder`.
pub struct FusionExprContext<'a> {
/// Scalar type the expression is being built for.
pub scalar_ty: ScalarType,
/// Input strings; presumably identifiers/expressions naming the kernel inputs — confirm.
pub inputs: &'a [&'a str],
/// Constant strings; presumably identifiers/literals for the constants — confirm.
pub constants: &'a [&'a str],
}
/// Function pointer that turns a `FusionExprContext` into a `String`, or fails with a
/// `FusionError`. Stored in `FusionKernelTemplate::wgsl_body`.
pub type FusionExprBuilder = fn(&FusionExprContext) -> Result<String, FusionError>;
/// Pairs a list of scalar types with the builder function that produces the kernel body.
///
/// Does not derive `Debug`; `BuiltinFusionSpec` implements `Debug` by hand and leaves
/// its template fields out.
#[derive(Clone)]
pub struct FusionKernelTemplate {
/// Scalar types listed for this template.
pub scalar_precisions: &'static [ScalarType],
/// Builder producing the body text (WGSL, going by the field name).
pub wgsl_body: FusionExprBuilder,
}
/// Error type returned by `FusionExprBuilder` functions.
#[derive(Debug)]
pub enum FusionError {
/// Displayed as `missing input <n>`; the payload is presumably the input index — confirm.
MissingInput(usize),
/// Displayed as `unsupported precision <ty:?>`.
UnsupportedPrecision(ScalarType),
/// Free-form static message, displayed verbatim.
Message(&'static str),
}
/// Static GPU metadata record for one builtin.
///
/// Instances are referenced as `&'static` from `GpuSpecInventory` (native) or the
/// `wasm_registry` (wasm32) and enumerated through `builtin_gpu_specs`.
#[derive(Debug, Clone, Copy)]
pub struct BuiltinGpuSpec {
/// Builtin name; lowercased (ASCII) to form the key of the residency-policy lookup.
pub name: &'static str,
pub op_kind: GpuOpKind,
pub supported_precisions: &'static [ScalarType],
pub broadcast: BroadcastSemantics,
pub provider_hooks: &'static [ProviderHook],
pub constant_strategy: ConstantStrategy,
/// Policy returned by `builtin_residency_policy` for this builtin's name.
pub residency: ResidencyPolicy,
pub nan_mode: ReductionNaN,
// NOTE(review): units and meaning of the two tuning fields below are not established
// in this file — presumably an element-count threshold and a workgroup size; confirm.
pub two_pass_threshold: Option<usize>,
pub workgroup_size: Option<u32>,
pub accepts_nan_mode: bool,
/// Free-form notes.
pub notes: &'static str,
}
impl fmt::Display for FusionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
FusionError::MissingInput(idx) => write!(f, "missing input {}", idx),
FusionError::UnsupportedPrecision(ty) => write!(f, "unsupported precision {:?}", ty),
FusionError::Message(msg) => write!(f, "{msg}"),
}
}
}
// Marker impl with no overridden methods; relies on the `Debug` derive and the
// `Display` impl of `FusionError`.
impl std::error::Error for FusionError {}
/// Static fusion metadata record for one builtin.
///
/// `Debug` is implemented by hand for this type and prints only `name`, `shape` and
/// `emits_nan`.
#[derive(Clone)]
pub struct BuiltinFusionSpec {
/// Builtin name.
pub name: &'static str,
pub shape: ShapeRequirements,
pub constant_strategy: ConstantStrategy,
/// Optional template; presumably used when the builtin is fused elementwise — confirm.
pub elementwise: Option<FusionKernelTemplate>,
/// Optional template; presumably used when the builtin is fused as a reduction — confirm.
pub reduction: Option<FusionKernelTemplate>,
pub emits_nan: bool,
/// Free-form notes.
pub notes: &'static str,
}
/// Wrapper around a `&'static BuiltinGpuSpec`; on non-wasm targets this is the type
/// collected via `inventory` and iterated by `builtin_gpu_specs`.
pub struct GpuSpecInventory {
pub spec: &'static BuiltinGpuSpec,
}
/// Wrapper around a `&'static BuiltinFusionSpec`; on non-wasm targets this is the type
/// collected via `inventory` and iterated by `builtin_fusion_specs`.
pub struct FusionSpecInventory {
pub spec: &'static BuiltinFusionSpec,
}
// Non-wasm targets only: declare the `inventory` collections that
// `builtin_gpu_specs` and `builtin_fusion_specs` iterate over.
#[cfg(not(target_arch = "wasm32"))]
inventory::collect!(GpuSpecInventory);
#[cfg(not(target_arch = "wasm32"))]
inventory::collect!(FusionSpecInventory);
#[cfg(not(target_arch = "wasm32"))]
pub fn builtin_gpu_specs() -> impl Iterator<Item = &'static BuiltinGpuSpec> {
inventory::iter::<GpuSpecInventory>().map(|entry| entry.spec)
}
/// Iterates over the GPU specs submitted to `wasm_registry` so far (wasm32 targets).
///
/// The iterator walks a snapshot cloned from the registry at call time.
#[cfg(target_arch = "wasm32")]
pub fn builtin_gpu_specs() -> std::vec::IntoIter<&'static BuiltinGpuSpec> {
wasm_registry::gpu_specs()
}
#[cfg(not(target_arch = "wasm32"))]
pub fn builtin_fusion_specs() -> impl Iterator<Item = &'static BuiltinFusionSpec> {
inventory::iter::<FusionSpecInventory>().map(|entry| entry.spec)
}
/// Iterates over the fusion specs submitted to `wasm_registry` so far (wasm32 targets).
///
/// The iterator walks a snapshot cloned from the registry at call time.
#[cfg(target_arch = "wasm32")]
pub fn builtin_fusion_specs() -> std::vec::IntoIter<&'static BuiltinFusionSpec> {
wasm_registry::fusion_specs()
}
/// Lookup table from ASCII-lowercased builtin name to residency policy (non-wasm
/// targets). Filled on the first call to `builtin_residency_policy` via
/// `build_residency_policy_map` and reused afterwards.
#[cfg(not(target_arch = "wasm32"))]
static RESIDENCY_POLICY_MAP: OnceCell<HashMap<String, ResidencyPolicy>> = OnceCell::new();
/// Builds the name → residency-policy table from every builtin GPU spec.
///
/// Keys are the spec names lowercased with ASCII rules; if two specs share a lowercased
/// name, the one iterated later wins.
#[cfg(not(target_arch = "wasm32"))]
fn build_residency_policy_map() -> HashMap<String, ResidencyPolicy> {
    builtin_gpu_specs()
        .map(|spec| (spec.name.to_ascii_lowercase(), spec.residency))
        .collect()
}
/// Returns the residency policy registered for the builtin called `name`, or `None`
/// when no GPU spec with that name is known.
///
/// The name is matched ASCII case-insensitively. On `wasm32` the lookup goes through
/// `wasm_registry`; elsewhere it uses a table built once from `builtin_gpu_specs`.
pub fn builtin_residency_policy(name: &str) -> Option<ResidencyPolicy> {
    #[cfg(target_arch = "wasm32")]
    let policy = wasm_registry::residency_policy(name);

    #[cfg(not(target_arch = "wasm32"))]
    let policy = RESIDENCY_POLICY_MAP
        .get_or_init(build_residency_policy_map)
        .get(&name.to_ascii_lowercase())
        .copied();

    policy
}
/// Hand-written `Debug` that prints only `name`, `shape` and `emits_nan`; the remaining
/// fields (including the kernel templates, which have no `Debug` impl) are left out.
impl fmt::Debug for BuiltinFusionSpec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            name,
            shape,
            emits_nan,
            ..
        } = self;
        let mut out = f.debug_struct("BuiltinFusionSpec");
        out.field("name", name);
        out.field("shape", shape);
        out.field("emits_nan", emits_nan);
        out.finish()
    }
}