runmat_runtime/builtins/common/
spec.rs1use std::fmt;
2
3#[cfg(not(target_arch = "wasm32"))]
4use std::collections::HashMap;
5
6#[cfg(not(target_arch = "wasm32"))]
7use once_cell::sync::OnceCell;
8
/// Spec registry used by `wasm32` builds.
///
/// The `inventory`-based collection elsewhere in this file is compiled only for
/// non-wasm targets, so on wasm the public accessors read from these lists instead.
/// Nothing in this module registers specs on its own: callers must invoke
/// `submit_gpu_spec` / `submit_fusion_spec`.
#[cfg(target_arch = "wasm32")]
pub(crate) mod wasm_registry {
    // NOTE(review): presumably silences warnings for accessors unused in some builds.
    #![allow(dead_code)]
    use super::{BuiltinFusionSpec, BuiltinGpuSpec};
    use once_cell::sync::Lazy;
    use std::collections::HashMap;
    use std::sync::{Mutex, MutexGuard};

    /// Panic message for the GPU spec list and the residency map.
    const GPU_POISONED: &str = "gpu spec registry poisoned";
    /// Panic message for the fusion spec list.
    const FUSION_POISONED: &str = "fusion spec registry poisoned";

    /// GPU specs in submission order; duplicates are kept.
    static GPU_SPECS: Lazy<Mutex<Vec<&'static BuiltinGpuSpec>>> =
        Lazy::new(|| Mutex::new(Vec::new()));
    /// Fusion specs in submission order; duplicates are kept.
    static FUSION_SPECS: Lazy<Mutex<Vec<&'static BuiltinFusionSpec>>> =
        Lazy::new(|| Mutex::new(Vec::new()));
    /// Lowercased builtin name -> residency policy. A later submission with the same
    /// name replaces the earlier entry.
    static RESIDENCY_POLICIES: Lazy<Mutex<HashMap<String, super::ResidencyPolicy>>> =
        Lazy::new(|| Mutex::new(HashMap::new()));

    /// Locks `registry`, panicking with `context` if a previous holder panicked.
    #[track_caller]
    fn acquire<'a, T>(registry: &'a Mutex<T>, context: &str) -> MutexGuard<'a, T> {
        registry.lock().expect(context)
    }

    /// Records `spec` and indexes its residency policy under its ASCII-lowercased name.
    pub(crate) fn submit_gpu_spec(spec: &'static BuiltinGpuSpec) {
        acquire(&*GPU_SPECS, GPU_POISONED).push(spec);
        // The list guard above is a temporary dropped at the end of its statement,
        // so the two locks are never held at the same time.
        let mut policies = acquire(&*RESIDENCY_POLICIES, GPU_POISONED);
        policies.insert(spec.name.to_ascii_lowercase(), spec.residency);
    }

    /// Records `spec` in the fusion registry.
    pub(crate) fn submit_fusion_spec(spec: &'static BuiltinFusionSpec) {
        acquire(&*FUSION_SPECS, FUSION_POISONED).push(spec);
    }

    /// Copy of every GPU spec submitted so far. Later submissions do not affect
    /// the returned iterator.
    pub(crate) fn gpu_specs() -> std::vec::IntoIter<&'static BuiltinGpuSpec> {
        let snapshot = acquire(&*GPU_SPECS, GPU_POISONED).to_vec();
        snapshot.into_iter()
    }

    /// Residency policy submitted for `name`, matched case-insensitively (ASCII).
    pub(crate) fn residency_policy(name: &str) -> Option<super::ResidencyPolicy> {
        let policies = acquire(&*RESIDENCY_POLICIES, GPU_POISONED);
        let key = name.to_ascii_lowercase();
        policies.get(&key).copied()
    }

    /// Copy of every fusion spec submitted so far. Later submissions do not affect
    /// the returned iterator.
    pub(crate) fn fusion_specs() -> std::vec::IntoIter<&'static BuiltinFusionSpec> {
        let snapshot = acquire(&*FUSION_SPECS, FUSION_POISONED).to_vec();
        snapshot.into_iter()
    }
}
66
/// Scalar precision that GPU specs and fusion kernel templates declare support for.
///
/// Variant names mirror the Rust primitives `f32`, `f64`, `i32` and `bool`.
/// NOTE(review): presumably each maps onto the matching WGSL scalar type — confirm.
// `Hash` is derived so precisions can be used as set/map keys (e.g. deduplicating
// the precision lists declared by specs).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalarType {
    F32,
    F64,
    I32,
    Bool,
}
75
/// Broad category of work a builtin performs on the GPU (`BuiltinGpuSpec::op_kind`).
// `Hash` is derived so op kinds can be used as set/map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuOpKind {
    Elementwise,
    Reduction,
    MatMul,
    Transpose,
    PlotRender,
    /// Operation outside the predefined categories. The string presumably names it.
    Custom(&'static str),
}
86
/// Broadcasting rule a builtin follows for operands of differing shapes
/// (`BuiltinGpuSpec::broadcast`).
///
/// Refer to variants qualified (`BroadcastSemantics::None`). A glob import of this
/// enum would bring `None` into scope and shadow `Option::None`.
// `Hash` is derived so the enum can be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BroadcastSemantics {
    /// NOTE(review): presumably MATLAB-style implicit expansion — confirm.
    Matlab,
    /// NOTE(review): presumably only scalars broadcast against arrays — confirm.
    ScalarOnly,
    /// No broadcasting.
    None,
}
94
/// Acceleration-provider entry point a builtin can dispatch through
/// (`BuiltinGpuSpec::provider_hooks`).
///
/// NOTE(review): `name` presumably identifies the provider operation to call — confirm.
// `Hash` is derived so hooks can be used as set/map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderHook {
    /// Single-operand hook.
    Unary {
        name: &'static str,
    },
    /// Two-operand hook. `commutative` records whether operand order may be swapped.
    Binary {
        name: &'static str,
        commutative: bool,
    },
    /// Reduction hook.
    Reduction {
        name: &'static str,
    },
    /// Hook that does not fit the shapes above.
    Custom(&'static str),
}
110
/// How constant operands are supplied to generated kernels. It is chosen per builtin
/// in both `BuiltinGpuSpec::constant_strategy` and `BuiltinFusionSpec::constant_strategy`.
///
/// NOTE(review): the variant names suggest literals inlined into shader source, a
/// uniform buffer, and workgroup memory respectively — confirm against the codegen.
// `Hash` is derived so the enum can be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstantStrategy {
    InlineLiteral,
    UniformBuffer,
    WorkgroupMemory,
}
118
/// Residency policy declared for a builtin's result (`BuiltinGpuSpec::residency`).
/// It can be looked up by builtin name through `builtin_residency_policy`.
///
/// NOTE(review): the meanings are inferred from the names — `InheritInputs` follows
/// the inputs, `NewHandle` allocates a fresh handle, `GatherImmediately` gathers the
/// result right away. Confirm against the dispatcher.
// `Hash` is derived so the enum can be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ResidencyPolicy {
    InheritInputs,
    NewHandle,
    GatherImmediately,
}
126
/// Whether NaN values are included in or omitted from a reduction
/// (`BuiltinGpuSpec::nan_mode`).
// `Hash` is derived so the enum can be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionNaN {
    Include,
    Omit,
}
133
/// Operand-shape constraint a fusion spec imposes (`BuiltinFusionSpec::shape`).
// `Hash` is derived so the enum can be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShapeRequirements {
    /// Operands must be broadcast-compatible with each other.
    BroadcastCompatible,
    /// Operands must match exactly this shape (presumably dimension sizes in order).
    Exact(&'static [usize]),
    /// No shape constraint.
    Any,
}
141
142pub struct FusionExprContext<'a> {
144 pub scalar_ty: ScalarType,
145 pub inputs: &'a [&'a str],
146 pub constants: &'a [&'a str],
147}
148
149pub type FusionExprBuilder = fn(&FusionExprContext) -> Result<String, FusionError>;
151
/// Code-generation template for one fused-kernel form. A `BuiltinFusionSpec` holds
/// one optionally for `elementwise` and one for `reduction`.
#[derive(Clone)]
pub struct FusionKernelTemplate {
    /// Scalar precisions this template can generate code for.
    pub scalar_precisions: &'static [ScalarType],
    /// Builder that renders the kernel body (presumably WGSL, per the field name)
    /// from a `FusionExprContext`.
    pub wgsl_body: FusionExprBuilder,
}
158
159#[derive(Debug)]
161pub enum FusionError {
162 MissingInput(usize),
163 UnsupportedPrecision(ScalarType),
164 Message(&'static str),
165}
166
/// GPU capabilities and execution hints declared by one builtin.
///
/// Non-wasm builds collect these through `GpuSpecInventory`. Wasm builds collect them
/// through `wasm_registry::submit_gpu_spec`.
#[derive(Debug, Clone, Copy)]
pub struct BuiltinGpuSpec {
    /// Builtin name. Residency lookups match it case-insensitively (ASCII).
    pub name: &'static str,
    /// Category of GPU operation.
    pub op_kind: GpuOpKind,
    /// Scalar precisions the GPU path supports.
    pub supported_precisions: &'static [ScalarType],
    /// Broadcasting rule for operands of differing shapes.
    pub broadcast: BroadcastSemantics,
    /// Provider entry points this builtin can dispatch through.
    pub provider_hooks: &'static [ProviderHook],
    /// How constant operands are supplied to kernels.
    pub constant_strategy: ConstantStrategy,
    /// Residency policy for the result. This value is what `builtin_residency_policy` returns.
    pub residency: ResidencyPolicy,
    /// NaN handling for reductions.
    pub nan_mode: ReductionNaN,
    /// NOTE(review): presumably the size above which a two-pass reduction is used;
    /// `None` meaning no threshold — confirm.
    pub two_pass_threshold: Option<usize>,
    /// NOTE(review): presumably overrides the default workgroup size when `Some` — confirm.
    pub workgroup_size: Option<u32>,
    /// NOTE(review): presumably whether callers may pass an explicit NaN-mode option — confirm.
    pub accepts_nan_mode: bool,
    /// Free-form notes.
    pub notes: &'static str,
}
186
187impl fmt::Display for FusionError {
188 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189 match self {
190 FusionError::MissingInput(idx) => write!(f, "missing input {}", idx),
191 FusionError::UnsupportedPrecision(ty) => write!(f, "unsupported precision {:?}", ty),
192 FusionError::Message(msg) => write!(f, "{msg}"),
193 }
194 }
195}
196
197impl std::error::Error for FusionError {}
198
/// Fusion metadata declared by one builtin.
///
/// `Debug` is implemented by hand rather than derived. `FusionKernelTemplate` only
/// derives `Clone`, so a derived `Debug` would not compile.
#[derive(Clone)]
pub struct BuiltinFusionSpec {
    /// Builtin name.
    pub name: &'static str,
    /// Operand-shape constraint for fusion.
    pub shape: ShapeRequirements,
    /// How constant operands are supplied to fused kernels.
    pub constant_strategy: ConstantStrategy,
    /// Template for elementwise fusion; `None` when not provided.
    pub elementwise: Option<FusionKernelTemplate>,
    /// Template for reduction fusion; `None` when not provided.
    pub reduction: Option<FusionKernelTemplate>,
    /// NOTE(review): presumably whether the fused code can produce NaN — confirm.
    pub emits_nan: bool,
    /// Free-form notes.
    pub notes: &'static str,
}
210
211pub struct GpuSpecInventory {
213 pub spec: &'static BuiltinGpuSpec,
214}
215
216pub struct FusionSpecInventory {
218 pub spec: &'static BuiltinFusionSpec,
219}
220
221#[cfg(not(target_arch = "wasm32"))]
222inventory::collect!(GpuSpecInventory);
223#[cfg(not(target_arch = "wasm32"))]
224inventory::collect!(FusionSpecInventory);
225
226#[cfg(not(target_arch = "wasm32"))]
228pub fn builtin_gpu_specs() -> impl Iterator<Item = &'static BuiltinGpuSpec> {
229 inventory::iter::<GpuSpecInventory>().map(|entry| entry.spec)
230}
231
232#[cfg(target_arch = "wasm32")]
233pub fn builtin_gpu_specs() -> std::vec::IntoIter<&'static BuiltinGpuSpec> {
234 wasm_registry::gpu_specs()
235}
236
237#[cfg(not(target_arch = "wasm32"))]
239pub fn builtin_fusion_specs() -> impl Iterator<Item = &'static BuiltinFusionSpec> {
240 inventory::iter::<FusionSpecInventory>().map(|entry| entry.spec)
241}
242
243#[cfg(target_arch = "wasm32")]
244pub fn builtin_fusion_specs() -> std::vec::IntoIter<&'static BuiltinFusionSpec> {
245 wasm_registry::fusion_specs()
246}
247
248#[cfg(not(target_arch = "wasm32"))]
249static RESIDENCY_POLICY_MAP: OnceCell<HashMap<String, ResidencyPolicy>> = OnceCell::new();
250
251#[cfg(not(target_arch = "wasm32"))]
252fn build_residency_policy_map() -> HashMap<String, ResidencyPolicy> {
253 let mut map = HashMap::new();
254 for spec in builtin_gpu_specs() {
255 map.insert(spec.name.to_ascii_lowercase(), spec.residency);
256 }
257 map
258}
259
260pub fn builtin_residency_policy(name: &str) -> Option<ResidencyPolicy> {
265 #[cfg(target_arch = "wasm32")]
266 {
267 return wasm_registry::residency_policy(name);
268 }
269
270 #[cfg(not(target_arch = "wasm32"))]
271 {
272 let map = RESIDENCY_POLICY_MAP.get_or_init(build_residency_policy_map);
273 map.get(&name.to_ascii_lowercase()).copied()
274 }
275}
276
277impl fmt::Debug for BuiltinFusionSpec {
278 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279 f.debug_struct("BuiltinFusionSpec")
280 .field("name", &self.name)
281 .field("shape", &self.shape)
282 .field("emits_nan", &self.emits_nan)
283 .finish()
284 }
285}