// runmat_runtime/builtins/common/spec.rs

1use std::fmt;
2
3#[cfg(not(target_arch = "wasm32"))]
4use std::collections::HashMap;
5
6#[cfg(not(target_arch = "wasm32"))]
7use once_cell::sync::OnceCell;
8
#[cfg(target_arch = "wasm32")]
pub(crate) mod wasm_registry {
    //! Spec registry used on `wasm32` targets, where the `inventory`-based collection
    //! is not compiled in. Entries are added through the `submit_*` functions and read
    //! back as owned snapshots.
    //!
    //! Every accessor panics if the corresponding mutex has been poisoned.
    #![allow(dead_code)]
    use super::{BuiltinFusionSpec, BuiltinGpuSpec, ResidencyPolicy};
    use once_cell::sync::Lazy;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// GPU specs in submission order.
    static GPU_SPECS: Lazy<Mutex<Vec<&'static BuiltinGpuSpec>>> =
        Lazy::new(|| Mutex::new(Vec::new()));
    /// Fusion specs in submission order.
    static FUSION_SPECS: Lazy<Mutex<Vec<&'static BuiltinFusionSpec>>> =
        Lazy::new(|| Mutex::new(Vec::new()));
    /// Residency policy per lowercased builtin name; a later submission for the same
    /// name replaces an earlier one.
    static RESIDENCY_POLICIES: Lazy<Mutex<HashMap<String, ResidencyPolicy>>> =
        Lazy::new(|| Mutex::new(HashMap::new()));

    /// Record a GPU spec and index its residency policy by lowercased name.
    pub(crate) fn submit_gpu_spec(spec: &'static BuiltinGpuSpec) {
        let key = spec.name.to_ascii_lowercase();
        {
            // Scoped so the spec-list lock is released before the policy map is locked.
            let mut specs = GPU_SPECS.lock().expect("gpu spec registry poisoned");
            specs.push(spec);
        }
        let mut policies = RESIDENCY_POLICIES
            .lock()
            .expect("gpu spec registry poisoned");
        policies.insert(key, spec.residency);
    }

    /// Record a fusion spec.
    pub(crate) fn submit_fusion_spec(spec: &'static BuiltinFusionSpec) {
        let mut specs = FUSION_SPECS.lock().expect("fusion spec registry poisoned");
        specs.push(spec);
    }

    /// Snapshot of every GPU spec submitted so far.
    pub(crate) fn gpu_specs() -> std::vec::IntoIter<&'static BuiltinGpuSpec> {
        let snapshot = GPU_SPECS.lock().expect("gpu spec registry poisoned").clone();
        snapshot.into_iter()
    }

    /// Residency policy registered for `name`, compared ASCII case-insensitively.
    pub(crate) fn residency_policy(name: &str) -> Option<ResidencyPolicy> {
        let key = name.to_ascii_lowercase();
        let policies = RESIDENCY_POLICIES
            .lock()
            .expect("gpu spec registry poisoned");
        policies.get(&key).copied()
    }

    /// Snapshot of every fusion spec submitted so far.
    pub(crate) fn fusion_specs() -> std::vec::IntoIter<&'static BuiltinFusionSpec> {
        let snapshot = FUSION_SPECS
            .lock()
            .expect("fusion spec registry poisoned")
            .clone();
        snapshot.into_iter()
    }
}
66
/// Supported scalar precisions that GPU kernels may target.
///
/// Appears in GPU spec precision lists, fusion template precision lists, the fusion
/// builder context, and `FusionError::UnsupportedPrecision`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScalarType {
    F32,
    F64,
    I32,
    Bool,
}
75
/// High-level GPU operation kind for builtin categorisation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuOpKind {
    Elementwise,
    Reduction,
    MatMul,
    Transpose,
    PlotRender,
    /// A kind not covered above, identified by a static label.
    Custom(&'static str),
}
86
/// Broadcast semantics supported by the builtin (see `BuiltinGpuSpec::broadcast`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BroadcastSemantics {
    Matlab,
    ScalarOnly,
    None,
}
94
/// Hook names that providers may implement for specialised kernels.
///
/// Listed per builtin in `BuiltinGpuSpec::provider_hooks`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderHook {
    /// Single-operand hook.
    Unary { name: &'static str },
    /// Two-operand hook; `commutative` records whether operand order may be swapped.
    Binary {
        name: &'static str,
        commutative: bool,
    },
    /// Reduction hook.
    Reduction { name: &'static str },
    /// Hook identified only by name.
    Custom(&'static str),
}
110
/// Strategy used when embedding constants in fused kernels.
///
/// Declared on both `BuiltinGpuSpec` and `BuiltinFusionSpec`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConstantStrategy {
    InlineLiteral,
    UniformBuffer,
    WorkgroupMemory,
}
118
/// Residency policy for builtin outputs.
///
/// Looked up by builtin name through `builtin_residency_policy`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResidencyPolicy {
    InheritInputs,
    NewHandle,
    /// GPU-resident arguments must be gathered to host at the runtime/auto-offload boundary.
    GatherImmediately,
}
126
/// How reductions should treat NaN values.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionNaN {
    /// NaN values take part in the reduction.
    Include,
    /// NaN values are skipped (`omitnan`).
    Omit,
}
133
/// Shape requirements for fused kernels.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShapeRequirements {
    BroadcastCompatible,
    /// A fixed shape given as a static slice (presumably per-dimension extents — confirm).
    Exact(&'static [usize]),
    Any,
}
141
142/// Context provided to fusion expression builders.
143pub struct FusionExprContext<'a> {
144    pub scalar_ty: ScalarType,
145    pub inputs: &'a [&'a str],
146    pub constants: &'a [&'a str],
147}
148
149/// Builder used to generate WGSL expressions.
150pub type FusionExprBuilder = fn(&FusionExprContext) -> Result<String, FusionError>;
151
152/// Description of a fusion kernel template.
153#[derive(Clone)]
154pub struct FusionKernelTemplate {
155    pub scalar_precisions: &'static [ScalarType],
156    pub wgsl_body: FusionExprBuilder,
157}
158
159/// Possible errors emitted by a fusion builder.
160#[derive(Debug)]
161pub enum FusionError {
162    MissingInput(usize),
163    UnsupportedPrecision(ScalarType),
164    Message(&'static str),
165}
166
167/// GPU metadata registered alongside builtin functions.
168#[derive(Debug, Clone, Copy)]
169pub struct BuiltinGpuSpec {
170    pub name: &'static str,
171    pub op_kind: GpuOpKind,
172    pub supported_precisions: &'static [ScalarType],
173    pub broadcast: BroadcastSemantics,
174    pub provider_hooks: &'static [ProviderHook],
175    pub constant_strategy: ConstantStrategy,
176    pub residency: ResidencyPolicy,
177    pub nan_mode: ReductionNaN,
178    /// If set, reductions with reduce_len greater than this should prefer a two-pass kernel.
179    pub two_pass_threshold: Option<usize>,
180    /// Optional workgroup size hint for generated kernels.
181    pub workgroup_size: Option<u32>,
182    /// Whether the provider hook (if used) supports device-side omitnan handling.
183    pub accepts_nan_mode: bool,
184    pub notes: &'static str,
185}
186
187impl fmt::Display for FusionError {
188    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189        match self {
190            FusionError::MissingInput(idx) => write!(f, "missing input {}", idx),
191            FusionError::UnsupportedPrecision(ty) => write!(f, "unsupported precision {:?}", ty),
192            FusionError::Message(msg) => write!(f, "{msg}"),
193        }
194    }
195}
196
197impl std::error::Error for FusionError {}
198
199/// Fusion metadata registered alongside builtin functions.
200#[derive(Clone)]
201pub struct BuiltinFusionSpec {
202    pub name: &'static str,
203    pub shape: ShapeRequirements,
204    pub constant_strategy: ConstantStrategy,
205    pub elementwise: Option<FusionKernelTemplate>,
206    pub reduction: Option<FusionKernelTemplate>,
207    pub emits_nan: bool,
208    pub notes: &'static str,
209}
210
211/// Inventory wrapper for GPU specs.
212pub struct GpuSpecInventory {
213    pub spec: &'static BuiltinGpuSpec,
214}
215
216/// Inventory wrapper for fusion specs.
217pub struct FusionSpecInventory {
218    pub spec: &'static BuiltinFusionSpec,
219}
220
221#[cfg(not(target_arch = "wasm32"))]
222inventory::collect!(GpuSpecInventory);
223#[cfg(not(target_arch = "wasm32"))]
224inventory::collect!(FusionSpecInventory);
225
226/// Iterate all registered GPU specs.
227#[cfg(not(target_arch = "wasm32"))]
228pub fn builtin_gpu_specs() -> impl Iterator<Item = &'static BuiltinGpuSpec> {
229    inventory::iter::<GpuSpecInventory>().map(|entry| entry.spec)
230}
231
232#[cfg(target_arch = "wasm32")]
233pub fn builtin_gpu_specs() -> std::vec::IntoIter<&'static BuiltinGpuSpec> {
234    wasm_registry::gpu_specs()
235}
236
237/// Iterate all registered fusion specs.
238#[cfg(not(target_arch = "wasm32"))]
239pub fn builtin_fusion_specs() -> impl Iterator<Item = &'static BuiltinFusionSpec> {
240    inventory::iter::<FusionSpecInventory>().map(|entry| entry.spec)
241}
242
243#[cfg(target_arch = "wasm32")]
244pub fn builtin_fusion_specs() -> std::vec::IntoIter<&'static BuiltinFusionSpec> {
245    wasm_registry::fusion_specs()
246}
247
248#[cfg(not(target_arch = "wasm32"))]
249static RESIDENCY_POLICY_MAP: OnceCell<HashMap<String, ResidencyPolicy>> = OnceCell::new();
250
251#[cfg(not(target_arch = "wasm32"))]
252fn build_residency_policy_map() -> HashMap<String, ResidencyPolicy> {
253    let mut map = HashMap::new();
254    for spec in builtin_gpu_specs() {
255        map.insert(spec.name.to_ascii_lowercase(), spec.residency);
256    }
257    map
258}
259
260/// Return the declared residency policy for a builtin's GPU implementation.
261///
262/// This is used at the runtime/auto-offload boundary to decide whether GPU-resident
263/// arguments must be gathered to host (`GatherImmediately`) or can remain device-resident.
264pub fn builtin_residency_policy(name: &str) -> Option<ResidencyPolicy> {
265    #[cfg(target_arch = "wasm32")]
266    {
267        return wasm_registry::residency_policy(name);
268    }
269
270    #[cfg(not(target_arch = "wasm32"))]
271    {
272        let map = RESIDENCY_POLICY_MAP.get_or_init(build_residency_policy_map);
273        map.get(&name.to_ascii_lowercase()).copied()
274    }
275}
276
277impl fmt::Debug for BuiltinFusionSpec {
278    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279        f.debug_struct("BuiltinFusionSpec")
280            .field("name", &self.name)
281            .field("shape", &self.shape)
282            .field("emits_nan", &self.emits_nan)
283            .finish()
284    }
285}