runmat_runtime/builtins/common/
spec.rs

1use std::fmt;
2
/// Scalar element precisions a GPU kernel can be generated for.
///
/// GPU and fusion specs each list the subset of these they support.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScalarType {
    /// 32-bit floating point.
    F32,
    /// 64-bit floating point.
    F64,
    /// 32-bit signed integer.
    I32,
    /// Boolean element.
    Bool,
}
11
/// Coarse category of GPU work a builtin performs, used to group builtins.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuOpKind {
    /// Element-by-element operation.
    Elementwise,
    /// Reduction along one or more dimensions.
    Reduction,
    /// Matrix multiplication.
    MatMul,
    /// Transposition.
    Transpose,
    /// Any other kind, identified by a free-form static label.
    Custom(&'static str),
}
21
/// How a builtin lets operands of differing shapes combine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BroadcastSemantics {
    /// MATLAB-style broadcasting (presumably implicit expansion of singleton dims).
    Matlab,
    /// Broadcasting restricted to scalar operands (presumably scalar-with-array only).
    ScalarOnly,
    /// No broadcasting is supported.
    None,
}
29
/// Named entry points a GPU provider may implement to supply a specialised kernel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderHook {
    /// Single-operand hook, identified by `name`.
    Unary { name: &'static str },
    /// Two-operand hook; `commutative` records whether operand order is irrelevant.
    Binary { name: &'static str, commutative: bool },
    /// Reduction hook, identified by `name`.
    Reduction { name: &'static str },
    /// Hook of any other shape, identified by a free-form static label.
    Custom(&'static str),
}
45
/// Where constants are placed when they are embedded into a fused kernel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConstantStrategy {
    /// Written directly into the generated kernel source as literals.
    InlineLiteral,
    /// Supplied through a uniform buffer.
    UniformBuffer,
    /// Placed in workgroup memory.
    WorkgroupMemory,
}
53
/// Where a builtin's outputs should reside once it has run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResidencyPolicy {
    /// Follow the residency of the inputs (presumably: stay on device if they are).
    InheritInputs,
    /// Produce a fresh handle for the output (presumably a new device buffer).
    NewHandle,
    /// Gather the output right away (presumably back to host memory).
    GatherImmediately,
}
61
/// NaN handling mode for reductions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionNaN {
    /// NaN values take part in the reduction.
    Include,
    /// NaN values are skipped ("omitnan").
    Omit,
}
68
/// Shape constraints a fused kernel places on its inputs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShapeRequirements {
    /// Input shapes must be mutually broadcast-compatible.
    BroadcastCompatible,
    /// Inputs must match this exact shape (a static list of dimension sizes).
    Exact(&'static [usize]),
    /// No shape constraint.
    Any,
}
76
77/// Context provided to fusion expression builders.
78pub struct FusionExprContext<'a> {
79    pub scalar_ty: ScalarType,
80    pub inputs: &'a [&'a str],
81    pub constants: &'a [&'a str],
82}
83
84/// Builder used to generate WGSL expressions.
85pub type FusionExprBuilder = fn(&FusionExprContext) -> Result<String, FusionError>;
86
87/// Description of a fusion kernel template.
88#[derive(Clone)]
89pub struct FusionKernelTemplate {
90    pub scalar_precisions: &'static [ScalarType],
91    pub wgsl_body: FusionExprBuilder,
92}
93
94/// Possible errors emitted by a fusion builder.
95#[derive(Debug)]
96pub enum FusionError {
97    MissingInput(usize),
98    UnsupportedPrecision(ScalarType),
99    Message(&'static str),
100}
101
102/// GPU metadata registered alongside builtin functions.
103#[derive(Debug, Clone, Copy)]
104pub struct BuiltinGpuSpec {
105    pub name: &'static str,
106    pub op_kind: GpuOpKind,
107    pub supported_precisions: &'static [ScalarType],
108    pub broadcast: BroadcastSemantics,
109    pub provider_hooks: &'static [ProviderHook],
110    pub constant_strategy: ConstantStrategy,
111    pub residency: ResidencyPolicy,
112    pub nan_mode: ReductionNaN,
113    /// If set, reductions with reduce_len greater than this should prefer a two-pass kernel.
114    pub two_pass_threshold: Option<usize>,
115    /// Optional workgroup size hint for generated kernels.
116    pub workgroup_size: Option<u32>,
117    /// Whether the provider hook (if used) supports device-side omitnan handling.
118    pub accepts_nan_mode: bool,
119    pub notes: &'static str,
120}
121
122impl fmt::Display for FusionError {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        match self {
125            FusionError::MissingInput(idx) => write!(f, "missing input {}", idx),
126            FusionError::UnsupportedPrecision(ty) => write!(f, "unsupported precision {:?}", ty),
127            FusionError::Message(msg) => write!(f, "{msg}"),
128        }
129    }
130}
131
132impl std::error::Error for FusionError {}
133
134/// Fusion metadata registered alongside builtin functions.
135#[derive(Clone)]
136pub struct BuiltinFusionSpec {
137    pub name: &'static str,
138    pub shape: ShapeRequirements,
139    pub constant_strategy: ConstantStrategy,
140    pub elementwise: Option<FusionKernelTemplate>,
141    pub reduction: Option<FusionKernelTemplate>,
142    pub emits_nan: bool,
143    pub notes: &'static str,
144}
145
146/// Inventory wrapper for GPU specs.
147pub struct GpuSpecInventory {
148    pub spec: &'static BuiltinGpuSpec,
149}
150
151/// Inventory wrapper for fusion specs.
152pub struct FusionSpecInventory {
153    pub spec: &'static BuiltinFusionSpec,
154}
155
156inventory::collect!(GpuSpecInventory);
157inventory::collect!(FusionSpecInventory);
158
159/// Iterate all registered GPU specs.
160pub fn builtin_gpu_specs() -> impl Iterator<Item = &'static BuiltinGpuSpec> {
161    inventory::iter::<GpuSpecInventory>().map(|entry| entry.spec)
162}
163
164/// Iterate all registered fusion specs.
165pub fn builtin_fusion_specs() -> impl Iterator<Item = &'static BuiltinFusionSpec> {
166    inventory::iter::<FusionSpecInventory>().map(|entry| entry.spec)
167}
168
169impl fmt::Debug for BuiltinFusionSpec {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        f.debug_struct("BuiltinFusionSpec")
172            .field("name", &self.name)
173            .field("shape", &self.shape)
174            .field("emits_nan", &self.emits_nan)
175            .finish()
176    }
177}
178
/// Registers a [`BuiltinGpuSpec`](crate::builtins::common::spec::BuiltinGpuSpec)
/// so that `builtin_gpu_specs()` yields it.
///
/// The macro stores `&$gpu_spec` in a `GpuSpecInventory`, whose field is
/// `&'static`. `$gpu_spec` must therefore yield a `'static` reference when
/// borrowed, typically a `static` or `const` item. `inventory::submit!` is
/// named unqualified, so the expanding crate must be able to refer to the
/// `inventory` crate directly.
#[macro_export]
macro_rules! register_builtin_gpu_spec {
    ($gpu_spec:expr) => {
        inventory::submit! {
            $crate::builtins::common::spec::GpuSpecInventory { spec: &$gpu_spec }
        }
    };
}
187
/// Registers a [`BuiltinFusionSpec`](crate::builtins::common::spec::BuiltinFusionSpec)
/// so that `builtin_fusion_specs()` yields it.
///
/// The macro stores `&$fusion_spec` in a `FusionSpecInventory`, whose field is
/// `&'static`. `$fusion_spec` must therefore yield a `'static` reference when
/// borrowed, typically a `static` or `const` item. `inventory::submit!` is
/// named unqualified, so the expanding crate must be able to refer to the
/// `inventory` crate directly.
#[macro_export]
macro_rules! register_builtin_fusion_spec {
    ($fusion_spec:expr) => {
        inventory::submit! {
            $crate::builtins::common::spec::FusionSpecInventory { spec: &$fusion_spec }
        }
    };
}
196
197// Documentation text inventory (only populated when doc_export feature is enabled)
198pub struct DocTextInventory {
199    pub name: &'static str,
200    pub text: &'static str,
201}
202
203inventory::collect!(DocTextInventory);
204
205pub fn builtin_doc_texts() -> impl Iterator<Item = &'static DocTextInventory> {
206    inventory::iter::<DocTextInventory>()
207}
208
/// Registers documentation text for a builtin under the given name.
///
/// The submission is compiled out unless the `doc_export` feature is enabled.
/// Because this is a `macro_rules!` expansion, that `cfg` is evaluated in the
/// crate that invokes the macro. `inventory::submit!` is named unqualified, so
/// that crate must be able to refer to `inventory` directly.
#[macro_export]
macro_rules! register_builtin_doc_text {
    ($builtin_name:expr, $doc_text:expr) => {
        #[cfg(feature = "doc_export")]
        inventory::submit! {
            $crate::builtins::common::spec::DocTextInventory {
                name: $builtin_name,
                text: $doc_text,
            }
        }
    };
}