// runmat_runtime/builtins/common/spec.rs
use std::fmt;
/// Scalar element types a builtin's GPU path or a fusion template can operate on.
///
/// Listed in [`BuiltinGpuSpec::supported_precisions`] and
/// [`FusionKernelTemplate::scalar_precisions`], and reported back through
/// [`FusionError::UnsupportedPrecision`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarType {
    F32,
    F64,
    I32,
    Bool,
}
/// Broad category of GPU operation a builtin performs; stored in
/// [`BuiltinGpuSpec::op_kind`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuOpKind {
    Elementwise,
    Reduction,
    MatMul,
    Transpose,
    /// Any other kind of operation, identified by a free-form static label.
    Custom(&'static str),
}
/// How a builtin's GPU implementation treats operands whose shapes differ; stored in
/// [`BuiltinGpuSpec::broadcast`].
///
/// NOTE(review): variant meanings below are inferred from their names — confirm against
/// the providers that consume this field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BroadcastSemantics {
    /// Presumably MATLAB-style implicit expansion.
    Matlab,
    /// Presumably only scalar-with-array mixing is allowed.
    ScalarOnly,
    /// Presumably no broadcasting: operand shapes must already agree.
    None,
}
/// A named hook listed in [`BuiltinGpuSpec::provider_hooks`].
///
/// NOTE(review): the `name`s presumably identify entry points on the acceleration
/// provider that the builtin dispatches to — confirm against the provider interface.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderHook {
    /// Unary hook identified by `name`.
    Unary { name: &'static str },
    /// Binary hook identified by `name`.
    Binary {
        name: &'static str,
        /// Presumably whether the two operands may be swapped — confirm.
        commutative: bool,
    },
    /// Reduction hook identified by `name`.
    Reduction { name: &'static str },
    /// Any other hook, identified by a free-form static label.
    Custom(&'static str),
}
/// How constant operands are made available to generated kernels.
///
/// Used by both [`BuiltinGpuSpec::constant_strategy`] and
/// [`BuiltinFusionSpec::constant_strategy`]. The variant names suggest WGSL mechanisms
/// (inline literals, uniform buffers, workgroup memory) — presumably matching the WGSL
/// emitted via [`FusionKernelTemplate::wgsl_body`]; confirm in the code generator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConstantStrategy {
    InlineLiteral,
    UniformBuffer,
    WorkgroupMemory,
}
/// Where a builtin's output is expected to live after GPU execution; stored in
/// [`BuiltinGpuSpec::residency`].
///
/// NOTE(review): variant meanings are inferred from names — confirm with the runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResidencyPolicy {
    /// Presumably follows the residency of the inputs.
    InheritInputs,
    /// Presumably produces a fresh device handle.
    NewHandle,
    /// Presumably copies the result back to the host right away.
    GatherImmediately,
}
/// NaN handling for reductions; stored in [`BuiltinGpuSpec::nan_mode`].
///
/// NOTE(review): presumably mirrors MATLAB's `'includenan'` / `'omitnan'` options — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionNaN {
    Include,
    Omit,
}
/// Shape constraints a fusion spec places on its operands; stored in
/// [`BuiltinFusionSpec::shape`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShapeRequirements {
    /// Presumably operands must be broadcast-compatible with one another.
    BroadcastCompatible,
    /// A fixed list of dimension sizes (presumably the exact required shape — confirm).
    Exact(&'static [usize]),
    /// No shape constraint.
    Any,
}
76
77pub struct FusionExprContext<'a> {
79 pub scalar_ty: ScalarType,
80 pub inputs: &'a [&'a str],
81 pub constants: &'a [&'a str],
82}
83
84pub type FusionExprBuilder = fn(&FusionExprContext) -> Result<String, FusionError>;
86
87#[derive(Clone)]
89pub struct FusionKernelTemplate {
90 pub scalar_precisions: &'static [ScalarType],
91 pub wgsl_body: FusionExprBuilder,
92}
93
94#[derive(Debug)]
96pub enum FusionError {
97 MissingInput(usize),
98 UnsupportedPrecision(ScalarType),
99 Message(&'static str),
100}
101
102#[derive(Debug, Clone, Copy)]
104pub struct BuiltinGpuSpec {
105 pub name: &'static str,
106 pub op_kind: GpuOpKind,
107 pub supported_precisions: &'static [ScalarType],
108 pub broadcast: BroadcastSemantics,
109 pub provider_hooks: &'static [ProviderHook],
110 pub constant_strategy: ConstantStrategy,
111 pub residency: ResidencyPolicy,
112 pub nan_mode: ReductionNaN,
113 pub two_pass_threshold: Option<usize>,
115 pub workgroup_size: Option<u32>,
117 pub accepts_nan_mode: bool,
119 pub notes: &'static str,
120}
121
122impl fmt::Display for FusionError {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 match self {
125 FusionError::MissingInput(idx) => write!(f, "missing input {}", idx),
126 FusionError::UnsupportedPrecision(ty) => write!(f, "unsupported precision {:?}", ty),
127 FusionError::Message(msg) => write!(f, "{msg}"),
128 }
129 }
130}
131
132impl std::error::Error for FusionError {}
133
134#[derive(Clone)]
136pub struct BuiltinFusionSpec {
137 pub name: &'static str,
138 pub shape: ShapeRequirements,
139 pub constant_strategy: ConstantStrategy,
140 pub elementwise: Option<FusionKernelTemplate>,
141 pub reduction: Option<FusionKernelTemplate>,
142 pub emits_nan: bool,
143 pub notes: &'static str,
144}
145
146pub struct GpuSpecInventory {
148 pub spec: &'static BuiltinGpuSpec,
149}
150
151pub struct FusionSpecInventory {
153 pub spec: &'static BuiltinFusionSpec,
154}
155
156inventory::collect!(GpuSpecInventory);
157inventory::collect!(FusionSpecInventory);
158
159pub fn builtin_gpu_specs() -> impl Iterator<Item = &'static BuiltinGpuSpec> {
161 inventory::iter::<GpuSpecInventory>().map(|entry| entry.spec)
162}
163
164pub fn builtin_fusion_specs() -> impl Iterator<Item = &'static BuiltinFusionSpec> {
166 inventory::iter::<FusionSpecInventory>().map(|entry| entry.spec)
167}
168
169impl fmt::Debug for BuiltinFusionSpec {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 f.debug_struct("BuiltinFusionSpec")
172 .field("name", &self.name)
173 .field("shape", &self.shape)
174 .field("emits_nan", &self.emits_nan)
175 .finish()
176 }
177}
/// Registers a [`BuiltinGpuSpec`] so that `builtin_gpu_specs()` yields it.
///
/// The spec is stored as `&'static BuiltinGpuSpec`, so `$spec` should be a `static`/`const`
/// item or another constant-promotable expression. A trailing comma is accepted.
///
/// The expansion names `inventory::submit!` by path. The invoking crate must therefore be
/// able to resolve the `inventory` crate.
#[macro_export]
macro_rules! register_builtin_gpu_spec {
    ($spec:expr $(,)?) => {
        inventory::submit! {
            $crate::builtins::common::spec::GpuSpecInventory { spec: &$spec }
        }
    };
}
/// Registers a [`BuiltinFusionSpec`] so that `builtin_fusion_specs()` yields it.
///
/// The spec is stored as `&'static BuiltinFusionSpec`, so `$spec` should be a
/// `static`/`const` item or another constant-promotable expression. A trailing comma is
/// accepted.
///
/// The expansion names `inventory::submit!` by path. The invoking crate must therefore be
/// able to resolve the `inventory` crate.
#[macro_export]
macro_rules! register_builtin_fusion_spec {
    ($spec:expr $(,)?) => {
        inventory::submit! {
            $crate::builtins::common::spec::FusionSpecInventory { spec: &$spec }
        }
    };
}
/// Inventory entry pairing a builtin name with its documentation text.
///
/// Submitted by `register_builtin_doc_text!` (only when its `doc_export` feature gate is
/// active) and read back by [`builtin_doc_texts`].
pub struct DocTextInventory {
    /// Builtin name.
    pub name: &'static str,
    /// Documentation text for the builtin.
    pub text: &'static str,
}
202
203inventory::collect!(DocTextInventory);
204
205pub fn builtin_doc_texts() -> impl Iterator<Item = &'static DocTextInventory> {
206 inventory::iter::<DocTextInventory>()
207}
/// Registers documentation text for a builtin as a [`DocTextInventory`] entry.
///
/// The submission is gated on `#[cfg(feature = "doc_export")]`. That `cfg` is part of the
/// macro's expansion, so it is evaluated against the features of the crate that invokes
/// the macro, not necessarily this one. A trailing comma is accepted.
///
/// The expansion names `inventory::submit!` by path. The invoking crate must therefore be
/// able to resolve the `inventory` crate.
#[macro_export]
macro_rules! register_builtin_doc_text {
    ($name:expr, $text:expr $(,)?) => {
        #[cfg(feature = "doc_export")]
        inventory::submit! {
            $crate::builtins::common::spec::DocTextInventory { name: $name, text: $text }
        }
    };
}