1mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod layouter;
9mod namer;
10mod overloads;
11mod terminator;
12mod type_methods;
13mod typifier;
14
15pub use constant_evaluator::{
16 ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
17};
18pub use emitter::Emitter;
19pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
20pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
21pub use namer::{EntryPointIndex, NameKey, Namer};
22pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
23pub use terminator::ensure_block_returns;
24use thiserror::Error;
25pub use type_methods::min_max_float_representable_by;
26pub use typifier::{ResolveContext, ResolveError, TypeResolution};
27
28impl From<super::StorageFormat> for super::Scalar {
29 fn from(format: super::StorageFormat) -> Self {
30 use super::{ScalarKind as Sk, StorageFormat as Sf};
31 let kind = match format {
32 Sf::R8Unorm => Sk::Float,
33 Sf::R8Snorm => Sk::Float,
34 Sf::R8Uint => Sk::Uint,
35 Sf::R8Sint => Sk::Sint,
36 Sf::R16Uint => Sk::Uint,
37 Sf::R16Sint => Sk::Sint,
38 Sf::R16Float => Sk::Float,
39 Sf::Rg8Unorm => Sk::Float,
40 Sf::Rg8Snorm => Sk::Float,
41 Sf::Rg8Uint => Sk::Uint,
42 Sf::Rg8Sint => Sk::Sint,
43 Sf::R32Uint => Sk::Uint,
44 Sf::R32Sint => Sk::Sint,
45 Sf::R32Float => Sk::Float,
46 Sf::Rg16Uint => Sk::Uint,
47 Sf::Rg16Sint => Sk::Sint,
48 Sf::Rg16Float => Sk::Float,
49 Sf::Rgba8Unorm => Sk::Float,
50 Sf::Rgba8Snorm => Sk::Float,
51 Sf::Rgba8Uint => Sk::Uint,
52 Sf::Rgba8Sint => Sk::Sint,
53 Sf::Bgra8Unorm => Sk::Float,
54 Sf::Rgb10a2Uint => Sk::Uint,
55 Sf::Rgb10a2Unorm => Sk::Float,
56 Sf::Rg11b10Ufloat => Sk::Float,
57 Sf::R64Uint => Sk::Uint,
58 Sf::Rg32Uint => Sk::Uint,
59 Sf::Rg32Sint => Sk::Sint,
60 Sf::Rg32Float => Sk::Float,
61 Sf::Rgba16Uint => Sk::Uint,
62 Sf::Rgba16Sint => Sk::Sint,
63 Sf::Rgba16Float => Sk::Float,
64 Sf::Rgba32Uint => Sk::Uint,
65 Sf::Rgba32Sint => Sk::Sint,
66 Sf::Rgba32Float => Sk::Float,
67 Sf::R16Unorm => Sk::Float,
68 Sf::R16Snorm => Sk::Float,
69 Sf::Rg16Unorm => Sk::Float,
70 Sf::Rg16Snorm => Sk::Float,
71 Sf::Rgba16Unorm => Sk::Float,
72 Sf::Rgba16Snorm => Sk::Float,
73 };
74 let width = match format {
75 Sf::R64Uint => 8,
76 _ => 4,
77 };
78 super::Scalar { kind, width }
79 }
80}
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
83pub enum HashableLiteral {
84 F64(u64),
85 F32(u32),
86 F16(u16),
87 U32(u32),
88 I32(i32),
89 U64(u64),
90 I64(i64),
91 Bool(bool),
92 AbstractInt(i64),
93 AbstractFloat(u64),
94}
95
96impl From<crate::Literal> for HashableLiteral {
97 fn from(l: crate::Literal) -> Self {
98 match l {
99 crate::Literal::F64(v) => Self::F64(v.to_bits()),
100 crate::Literal::F32(v) => Self::F32(v.to_bits()),
101 crate::Literal::F16(v) => Self::F16(v.to_bits()),
102 crate::Literal::U32(v) => Self::U32(v),
103 crate::Literal::I32(v) => Self::I32(v),
104 crate::Literal::U64(v) => Self::U64(v),
105 crate::Literal::I64(v) => Self::I64(v),
106 crate::Literal::Bool(v) => Self::Bool(v),
107 crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
108 crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
109 }
110 }
111}
112
113impl crate::Literal {
114 pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
115 match (value, scalar.kind, scalar.width) {
116 (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
117 (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
118 (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
119 (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
120 (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
121 (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
122 (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
123 (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
124 (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
125 (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
126 _ => None,
127 }
128 }
129
130 pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
131 Self::new(0, scalar)
132 }
133
134 pub const fn one(scalar: crate::Scalar) -> Option<Self> {
135 Self::new(1, scalar)
136 }
137
138 pub const fn width(&self) -> crate::Bytes {
139 match *self {
140 Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
141 Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
142 Self::F16(_) => 2,
143 Self::Bool(_) => crate::BOOL_WIDTH,
144 Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
145 }
146 }
147 pub const fn scalar(&self) -> crate::Scalar {
148 match *self {
149 Self::F64(_) => crate::Scalar::F64,
150 Self::F32(_) => crate::Scalar::F32,
151 Self::F16(_) => crate::Scalar::F16,
152 Self::U32(_) => crate::Scalar::U32,
153 Self::I32(_) => crate::Scalar::I32,
154 Self::U64(_) => crate::Scalar::U64,
155 Self::I64(_) => crate::Scalar::I64,
156 Self::Bool(_) => crate::Scalar::BOOL,
157 Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
158 Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
159 }
160 }
161 pub const fn scalar_kind(&self) -> crate::ScalarKind {
162 self.scalar().kind
163 }
164 pub const fn ty_inner(&self) -> crate::TypeInner {
165 crate::TypeInner::Scalar(self.scalar())
166 }
167}
168
169impl super::AddressSpace {
170 pub fn access(self) -> crate::StorageAccess {
171 use crate::StorageAccess as Sa;
172 match self {
173 crate::AddressSpace::Function
174 | crate::AddressSpace::Private
175 | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
176 crate::AddressSpace::Uniform => Sa::LOAD,
177 crate::AddressSpace::Storage { access } => access,
178 crate::AddressSpace::Handle => Sa::LOAD,
179 crate::AddressSpace::PushConstant => Sa::LOAD,
180 }
181 }
182}
183
184impl super::MathFunction {
185 pub const fn argument_count(&self) -> usize {
186 match *self {
187 Self::Abs => 1,
189 Self::Min => 2,
190 Self::Max => 2,
191 Self::Clamp => 3,
192 Self::Saturate => 1,
193 Self::Cos => 1,
195 Self::Cosh => 1,
196 Self::Sin => 1,
197 Self::Sinh => 1,
198 Self::Tan => 1,
199 Self::Tanh => 1,
200 Self::Acos => 1,
201 Self::Asin => 1,
202 Self::Atan => 1,
203 Self::Atan2 => 2,
204 Self::Asinh => 1,
205 Self::Acosh => 1,
206 Self::Atanh => 1,
207 Self::Radians => 1,
208 Self::Degrees => 1,
209 Self::Ceil => 1,
211 Self::Floor => 1,
212 Self::Round => 1,
213 Self::Fract => 1,
214 Self::Trunc => 1,
215 Self::Modf => 1,
216 Self::Frexp => 1,
217 Self::Ldexp => 2,
218 Self::Exp => 1,
220 Self::Exp2 => 1,
221 Self::Log => 1,
222 Self::Log2 => 1,
223 Self::Pow => 2,
224 Self::Dot => 2,
226 Self::Outer => 2,
227 Self::Cross => 2,
228 Self::Distance => 2,
229 Self::Length => 1,
230 Self::Normalize => 1,
231 Self::FaceForward => 3,
232 Self::Reflect => 2,
233 Self::Refract => 3,
234 Self::Sign => 1,
236 Self::Fma => 3,
237 Self::Mix => 3,
238 Self::Step => 2,
239 Self::SmoothStep => 3,
240 Self::Sqrt => 1,
241 Self::InverseSqrt => 1,
242 Self::Inverse => 1,
243 Self::Transpose => 1,
244 Self::Determinant => 1,
245 Self::QuantizeToF16 => 1,
246 Self::CountTrailingZeros => 1,
248 Self::CountLeadingZeros => 1,
249 Self::CountOneBits => 1,
250 Self::ReverseBits => 1,
251 Self::ExtractBits => 3,
252 Self::InsertBits => 4,
253 Self::FirstTrailingBit => 1,
254 Self::FirstLeadingBit => 1,
255 Self::Pack4x8snorm => 1,
257 Self::Pack4x8unorm => 1,
258 Self::Pack2x16snorm => 1,
259 Self::Pack2x16unorm => 1,
260 Self::Pack2x16float => 1,
261 Self::Pack4xI8 => 1,
262 Self::Pack4xU8 => 1,
263 Self::Unpack4x8snorm => 1,
265 Self::Unpack4x8unorm => 1,
266 Self::Unpack2x16snorm => 1,
267 Self::Unpack2x16unorm => 1,
268 Self::Unpack2x16float => 1,
269 Self::Unpack4xI8 => 1,
270 Self::Unpack4xU8 => 1,
271 }
272 }
273}
274
275impl crate::Expression {
276 pub const fn needs_pre_emit(&self) -> bool {
278 match *self {
279 Self::Literal(_)
280 | Self::Constant(_)
281 | Self::Override(_)
282 | Self::ZeroValue(_)
283 | Self::FunctionArgument(_)
284 | Self::GlobalVariable(_)
285 | Self::LocalVariable(_) => true,
286 _ => false,
287 }
288 }
289
290 pub const fn is_dynamic_index(&self) -> bool {
304 match *self {
305 Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
306 _ => true,
307 }
308 }
309}
310
311impl crate::Function {
312 pub fn originating_global(
321 &self,
322 mut pointer: crate::Handle<crate::Expression>,
323 ) -> Option<crate::Handle<crate::GlobalVariable>> {
324 loop {
325 pointer = match self.expressions[pointer] {
326 crate::Expression::Access { base, .. } => base,
327 crate::Expression::AccessIndex { base, .. } => base,
328 crate::Expression::GlobalVariable(handle) => return Some(handle),
329 crate::Expression::LocalVariable(_) => return None,
330 crate::Expression::FunctionArgument(_) => return None,
331 _ => unreachable!(),
333 }
334 }
335 }
336}
337
338impl crate::SampleLevel {
339 pub const fn implicit_derivatives(&self) -> bool {
340 match *self {
341 Self::Auto | Self::Bias(_) => true,
342 Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
343 }
344 }
345}
346
347impl crate::Binding {
348 pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
349 match *self {
350 crate::Binding::BuiltIn(built_in) => Some(built_in),
351 Self::Location { .. } => None,
352 }
353 }
354}
355
356impl super::SwizzleComponent {
357 pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
358
359 pub const fn index(&self) -> u32 {
360 match *self {
361 Self::X => 0,
362 Self::Y => 1,
363 Self::Z => 2,
364 Self::W => 3,
365 }
366 }
367 pub const fn from_index(idx: u32) -> Self {
368 match idx {
369 0 => Self::X,
370 1 => Self::Y,
371 2 => Self::Z,
372 _ => Self::W,
373 }
374 }
375}
376
377impl super::ImageClass {
378 pub const fn is_multisampled(self) -> bool {
379 match self {
380 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
381 crate::ImageClass::Storage { .. } => false,
382 }
383 }
384
385 pub const fn is_mipmapped(self) -> bool {
386 match self {
387 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
388 crate::ImageClass::Storage { .. } => false,
389 }
390 }
391
392 pub const fn is_depth(self) -> bool {
393 matches!(self, crate::ImageClass::Depth { .. })
394 }
395}
396
397impl crate::Module {
398 pub const fn to_ctx(&self) -> GlobalCtx<'_> {
399 GlobalCtx {
400 types: &self.types,
401 constants: &self.constants,
402 overrides: &self.overrides,
403 global_expressions: &self.global_expressions,
404 }
405 }
406}
407
/// Failure reasons for `GlobalCtx::eval_expr_to_u32` and
/// `GlobalCtx::eval_expr_to_u32_from`.
#[derive(Debug)]
pub(super) enum U32EvalError {
    /// The expression did not reduce to a `U32` or `I32` literal.
    NonConst,
    /// The expression reduced to a negative `I32` literal.
    Negative,
}

/// Shared borrows of a module's type, constant, override and global
/// expression arenas, as produced by `Module::to_ctx`.
#[derive(Clone, Copy)]
pub struct GlobalCtx<'a> {
    pub types: &'a crate::UniqueArena<crate::Type>,
    pub constants: &'a crate::Arena<crate::Constant>,
    pub overrides: &'a crate::Arena<crate::Override>,
    pub global_expressions: &'a crate::Arena<crate::Expression>,
}
421
422impl GlobalCtx<'_> {
423 #[allow(dead_code)]
425 pub(super) fn eval_expr_to_u32(
426 &self,
427 handle: crate::Handle<crate::Expression>,
428 ) -> Result<u32, U32EvalError> {
429 self.eval_expr_to_u32_from(handle, self.global_expressions)
430 }
431
432 pub(super) fn eval_expr_to_u32_from(
434 &self,
435 handle: crate::Handle<crate::Expression>,
436 arena: &crate::Arena<crate::Expression>,
437 ) -> Result<u32, U32EvalError> {
438 match self.eval_expr_to_literal_from(handle, arena) {
439 Some(crate::Literal::U32(value)) => Ok(value),
440 Some(crate::Literal::I32(value)) => {
441 value.try_into().map_err(|_| U32EvalError::Negative)
442 }
443 _ => Err(U32EvalError::NonConst),
444 }
445 }
446
447 #[allow(dead_code)]
449 pub(super) fn eval_expr_to_bool_from(
450 &self,
451 handle: crate::Handle<crate::Expression>,
452 arena: &crate::Arena<crate::Expression>,
453 ) -> Option<bool> {
454 match self.eval_expr_to_literal_from(handle, arena) {
455 Some(crate::Literal::Bool(value)) => Some(value),
456 _ => None,
457 }
458 }
459
460 #[allow(dead_code)]
461 pub(crate) fn eval_expr_to_literal(
462 &self,
463 handle: crate::Handle<crate::Expression>,
464 ) -> Option<crate::Literal> {
465 self.eval_expr_to_literal_from(handle, self.global_expressions)
466 }
467
468 pub(super) fn eval_expr_to_literal_from(
469 &self,
470 handle: crate::Handle<crate::Expression>,
471 arena: &crate::Arena<crate::Expression>,
472 ) -> Option<crate::Literal> {
473 fn get(
474 gctx: GlobalCtx,
475 handle: crate::Handle<crate::Expression>,
476 arena: &crate::Arena<crate::Expression>,
477 ) -> Option<crate::Literal> {
478 match arena[handle] {
479 crate::Expression::Literal(literal) => Some(literal),
480 crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
481 crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
482 _ => None,
483 },
484 _ => None,
485 }
486 }
487 match arena[handle] {
488 crate::Expression::Constant(c) => {
489 get(*self, self.constants[c].init, self.global_expressions)
490 }
491 _ => get(*self, handle, arena),
492 }
493 }
494}
495
/// Errors produced by `ArraySize::resolve`.
#[derive(Error, Debug, Clone, Copy, PartialEq)]
pub enum ResolveArraySizeError {
    /// The size override evaluated to zero or to a negative `i32`.
    #[error("array element count must be positive (> 0)")]
    ExpectedPositiveArrayLength,
    /// The size override has no initializer, or its initializer did not
    /// evaluate to a `u32`/`i32` literal.
    #[error("internal: array size override has not been resolved")]
    NonConstArrayLength,
}
503
504impl crate::ArraySize {
505 pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
515 match *self {
516 crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
517 crate::ArraySize::Pending(handle) => {
518 let Some(expr) = gctx.overrides[handle].init else {
519 return Err(ResolveArraySizeError::NonConstArrayLength);
520 };
521 let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err {
522 U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength,
523 U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength,
524 })?;
525
526 if length == 0 {
527 return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
528 }
529
530 Ok(IndexableLength::Known(length))
531 }
532 crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
533 }
534 }
535}
536
537pub fn flatten_compose<'arenas>(
550 ty: crate::Handle<crate::Type>,
551 components: &'arenas [crate::Handle<crate::Expression>],
552 expressions: &'arenas crate::Arena<crate::Expression>,
553 types: &'arenas crate::UniqueArena<crate::Type>,
554) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
555 let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
561 (size as usize, true)
562 } else {
563 (components.len(), false)
564 };
565
566 fn flatten_compose<'c>(
568 component: &'c crate::Handle<crate::Expression>,
569 is_vector: bool,
570 expressions: &'c crate::Arena<crate::Expression>,
571 ) -> &'c [crate::Handle<crate::Expression>] {
572 if is_vector {
573 if let crate::Expression::Compose {
574 ty: _,
575 components: ref subcomponents,
576 } = expressions[*component]
577 {
578 return subcomponents;
579 }
580 }
581 core::slice::from_ref(component)
582 }
583
584 fn flatten_splat<'c>(
586 component: &'c crate::Handle<crate::Expression>,
587 is_vector: bool,
588 expressions: &'c crate::Arena<crate::Expression>,
589 ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
590 let mut expr = *component;
591 let mut count = 1;
592 if is_vector {
593 if let crate::Expression::Splat { size, value } = expressions[expr] {
594 expr = value;
595 count = size as usize;
596 }
597 }
598 core::iter::repeat_n(expr, count)
599 }
600
601 components
608 .iter()
609 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
610 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
611 .flat_map(move |component| flatten_splat(component, is_vector, expressions))
612 .take(size)
613}
614
#[test]
fn test_matrix_size() {
    let module = crate::Module::default();
    let mat3x3_f32 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    // Expect 48 bytes: three columns at a 16-byte stride, even though each
    // column holds only 12 bytes of f32 data.
    let expected = 48;
    assert_eq!(mat3x3_f32.size(module.to_ctx()), expected);
}