naga/proc/
mod.rs

1/*!
2[`Module`](super::Module) processing functionality.
3*/
4
5mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod layouter;
9mod namer;
10mod overloads;
11mod terminator;
12mod type_methods;
13mod typifier;
14
15pub use constant_evaluator::{
16    ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
17};
18pub use emitter::Emitter;
19pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
20pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
21pub use namer::{EntryPointIndex, NameKey, Namer};
22pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
23pub use terminator::ensure_block_returns;
24use thiserror::Error;
25pub use type_methods::min_max_float_representable_by;
26pub use typifier::{ResolveContext, ResolveError, TypeResolution};
27
28impl From<super::StorageFormat> for super::Scalar {
29    fn from(format: super::StorageFormat) -> Self {
30        use super::{ScalarKind as Sk, StorageFormat as Sf};
31        let kind = match format {
32            Sf::R8Unorm => Sk::Float,
33            Sf::R8Snorm => Sk::Float,
34            Sf::R8Uint => Sk::Uint,
35            Sf::R8Sint => Sk::Sint,
36            Sf::R16Uint => Sk::Uint,
37            Sf::R16Sint => Sk::Sint,
38            Sf::R16Float => Sk::Float,
39            Sf::Rg8Unorm => Sk::Float,
40            Sf::Rg8Snorm => Sk::Float,
41            Sf::Rg8Uint => Sk::Uint,
42            Sf::Rg8Sint => Sk::Sint,
43            Sf::R32Uint => Sk::Uint,
44            Sf::R32Sint => Sk::Sint,
45            Sf::R32Float => Sk::Float,
46            Sf::Rg16Uint => Sk::Uint,
47            Sf::Rg16Sint => Sk::Sint,
48            Sf::Rg16Float => Sk::Float,
49            Sf::Rgba8Unorm => Sk::Float,
50            Sf::Rgba8Snorm => Sk::Float,
51            Sf::Rgba8Uint => Sk::Uint,
52            Sf::Rgba8Sint => Sk::Sint,
53            Sf::Bgra8Unorm => Sk::Float,
54            Sf::Rgb10a2Uint => Sk::Uint,
55            Sf::Rgb10a2Unorm => Sk::Float,
56            Sf::Rg11b10Ufloat => Sk::Float,
57            Sf::R64Uint => Sk::Uint,
58            Sf::Rg32Uint => Sk::Uint,
59            Sf::Rg32Sint => Sk::Sint,
60            Sf::Rg32Float => Sk::Float,
61            Sf::Rgba16Uint => Sk::Uint,
62            Sf::Rgba16Sint => Sk::Sint,
63            Sf::Rgba16Float => Sk::Float,
64            Sf::Rgba32Uint => Sk::Uint,
65            Sf::Rgba32Sint => Sk::Sint,
66            Sf::Rgba32Float => Sk::Float,
67            Sf::R16Unorm => Sk::Float,
68            Sf::R16Snorm => Sk::Float,
69            Sf::Rg16Unorm => Sk::Float,
70            Sf::Rg16Snorm => Sk::Float,
71            Sf::Rgba16Unorm => Sk::Float,
72            Sf::Rgba16Snorm => Sk::Float,
73        };
74        let width = match format {
75            Sf::R64Uint => 8,
76            _ => 4,
77        };
78        super::Scalar { kind, width }
79    }
80}
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
83pub enum HashableLiteral {
84    F64(u64),
85    F32(u32),
86    F16(u16),
87    U32(u32),
88    I32(i32),
89    U64(u64),
90    I64(i64),
91    Bool(bool),
92    AbstractInt(i64),
93    AbstractFloat(u64),
94}
95
96impl From<crate::Literal> for HashableLiteral {
97    fn from(l: crate::Literal) -> Self {
98        match l {
99            crate::Literal::F64(v) => Self::F64(v.to_bits()),
100            crate::Literal::F32(v) => Self::F32(v.to_bits()),
101            crate::Literal::F16(v) => Self::F16(v.to_bits()),
102            crate::Literal::U32(v) => Self::U32(v),
103            crate::Literal::I32(v) => Self::I32(v),
104            crate::Literal::U64(v) => Self::U64(v),
105            crate::Literal::I64(v) => Self::I64(v),
106            crate::Literal::Bool(v) => Self::Bool(v),
107            crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
108            crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
109        }
110    }
111}
112
113impl crate::Literal {
114    pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
115        match (value, scalar.kind, scalar.width) {
116            (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
117            (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
118            (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
119            (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
120            (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
121            (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
122            (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
123            (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
124            (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
125            (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
126            _ => None,
127        }
128    }
129
130    pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
131        Self::new(0, scalar)
132    }
133
134    pub const fn one(scalar: crate::Scalar) -> Option<Self> {
135        Self::new(1, scalar)
136    }
137
138    pub const fn width(&self) -> crate::Bytes {
139        match *self {
140            Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
141            Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
142            Self::F16(_) => 2,
143            Self::Bool(_) => crate::BOOL_WIDTH,
144            Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
145        }
146    }
147    pub const fn scalar(&self) -> crate::Scalar {
148        match *self {
149            Self::F64(_) => crate::Scalar::F64,
150            Self::F32(_) => crate::Scalar::F32,
151            Self::F16(_) => crate::Scalar::F16,
152            Self::U32(_) => crate::Scalar::U32,
153            Self::I32(_) => crate::Scalar::I32,
154            Self::U64(_) => crate::Scalar::U64,
155            Self::I64(_) => crate::Scalar::I64,
156            Self::Bool(_) => crate::Scalar::BOOL,
157            Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
158            Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
159        }
160    }
161    pub const fn scalar_kind(&self) -> crate::ScalarKind {
162        self.scalar().kind
163    }
164    pub const fn ty_inner(&self) -> crate::TypeInner {
165        crate::TypeInner::Scalar(self.scalar())
166    }
167}
168
169impl super::AddressSpace {
170    pub fn access(self) -> crate::StorageAccess {
171        use crate::StorageAccess as Sa;
172        match self {
173            crate::AddressSpace::Function
174            | crate::AddressSpace::Private
175            | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
176            crate::AddressSpace::Uniform => Sa::LOAD,
177            crate::AddressSpace::Storage { access } => access,
178            crate::AddressSpace::Handle => Sa::LOAD,
179            crate::AddressSpace::PushConstant => Sa::LOAD,
180        }
181    }
182}
183
184impl super::MathFunction {
185    pub const fn argument_count(&self) -> usize {
186        match *self {
187            // comparison
188            Self::Abs => 1,
189            Self::Min => 2,
190            Self::Max => 2,
191            Self::Clamp => 3,
192            Self::Saturate => 1,
193            // trigonometry
194            Self::Cos => 1,
195            Self::Cosh => 1,
196            Self::Sin => 1,
197            Self::Sinh => 1,
198            Self::Tan => 1,
199            Self::Tanh => 1,
200            Self::Acos => 1,
201            Self::Asin => 1,
202            Self::Atan => 1,
203            Self::Atan2 => 2,
204            Self::Asinh => 1,
205            Self::Acosh => 1,
206            Self::Atanh => 1,
207            Self::Radians => 1,
208            Self::Degrees => 1,
209            // decomposition
210            Self::Ceil => 1,
211            Self::Floor => 1,
212            Self::Round => 1,
213            Self::Fract => 1,
214            Self::Trunc => 1,
215            Self::Modf => 1,
216            Self::Frexp => 1,
217            Self::Ldexp => 2,
218            // exponent
219            Self::Exp => 1,
220            Self::Exp2 => 1,
221            Self::Log => 1,
222            Self::Log2 => 1,
223            Self::Pow => 2,
224            // geometry
225            Self::Dot => 2,
226            Self::Outer => 2,
227            Self::Cross => 2,
228            Self::Distance => 2,
229            Self::Length => 1,
230            Self::Normalize => 1,
231            Self::FaceForward => 3,
232            Self::Reflect => 2,
233            Self::Refract => 3,
234            // computational
235            Self::Sign => 1,
236            Self::Fma => 3,
237            Self::Mix => 3,
238            Self::Step => 2,
239            Self::SmoothStep => 3,
240            Self::Sqrt => 1,
241            Self::InverseSqrt => 1,
242            Self::Inverse => 1,
243            Self::Transpose => 1,
244            Self::Determinant => 1,
245            Self::QuantizeToF16 => 1,
246            // bits
247            Self::CountTrailingZeros => 1,
248            Self::CountLeadingZeros => 1,
249            Self::CountOneBits => 1,
250            Self::ReverseBits => 1,
251            Self::ExtractBits => 3,
252            Self::InsertBits => 4,
253            Self::FirstTrailingBit => 1,
254            Self::FirstLeadingBit => 1,
255            // data packing
256            Self::Pack4x8snorm => 1,
257            Self::Pack4x8unorm => 1,
258            Self::Pack2x16snorm => 1,
259            Self::Pack2x16unorm => 1,
260            Self::Pack2x16float => 1,
261            Self::Pack4xI8 => 1,
262            Self::Pack4xU8 => 1,
263            // data unpacking
264            Self::Unpack4x8snorm => 1,
265            Self::Unpack4x8unorm => 1,
266            Self::Unpack2x16snorm => 1,
267            Self::Unpack2x16unorm => 1,
268            Self::Unpack2x16float => 1,
269            Self::Unpack4xI8 => 1,
270            Self::Unpack4xU8 => 1,
271        }
272    }
273}
274
275impl crate::Expression {
276    /// Returns true if the expression is considered emitted at the start of a function.
277    pub const fn needs_pre_emit(&self) -> bool {
278        match *self {
279            Self::Literal(_)
280            | Self::Constant(_)
281            | Self::Override(_)
282            | Self::ZeroValue(_)
283            | Self::FunctionArgument(_)
284            | Self::GlobalVariable(_)
285            | Self::LocalVariable(_) => true,
286            _ => false,
287        }
288    }
289
290    /// Return true if this expression is a dynamic array/vector/matrix index,
291    /// for [`Access`].
292    ///
293    /// This method returns true if this expression is a dynamically computed
294    /// index, and as such can only be used to index matrices when they appear
295    /// behind a pointer. See the documentation for [`Access`] for details.
296    ///
297    /// Note, this does not check the _type_ of the given expression. It's up to
298    /// the caller to establish that the `Access` expression is well-typed
299    /// through other means, like [`ResolveContext`].
300    ///
301    /// [`Access`]: crate::Expression::Access
302    /// [`ResolveContext`]: crate::proc::ResolveContext
303    pub const fn is_dynamic_index(&self) -> bool {
304        match *self {
305            Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
306            _ => true,
307        }
308    }
309}
310
311impl crate::Function {
312    /// Return the global variable being accessed by the expression `pointer`.
313    ///
314    /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
315    /// expressions that ultimately access some part of a `GlobalVariable`,
316    /// return a handle for that global.
317    ///
318    /// If the expression does not ultimately access a global variable, return
319    /// `None`.
320    pub fn originating_global(
321        &self,
322        mut pointer: crate::Handle<crate::Expression>,
323    ) -> Option<crate::Handle<crate::GlobalVariable>> {
324        loop {
325            pointer = match self.expressions[pointer] {
326                crate::Expression::Access { base, .. } => base,
327                crate::Expression::AccessIndex { base, .. } => base,
328                crate::Expression::GlobalVariable(handle) => return Some(handle),
329                crate::Expression::LocalVariable(_) => return None,
330                crate::Expression::FunctionArgument(_) => return None,
331                // There are no other expressions that produce pointer values.
332                _ => unreachable!(),
333            }
334        }
335    }
336}
337
338impl crate::SampleLevel {
339    pub const fn implicit_derivatives(&self) -> bool {
340        match *self {
341            Self::Auto | Self::Bias(_) => true,
342            Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
343        }
344    }
345}
346
347impl crate::Binding {
348    pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
349        match *self {
350            crate::Binding::BuiltIn(built_in) => Some(built_in),
351            Self::Location { .. } => None,
352        }
353    }
354}
355
356impl super::SwizzleComponent {
357    pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
358
359    pub const fn index(&self) -> u32 {
360        match *self {
361            Self::X => 0,
362            Self::Y => 1,
363            Self::Z => 2,
364            Self::W => 3,
365        }
366    }
367    pub const fn from_index(idx: u32) -> Self {
368        match idx {
369            0 => Self::X,
370            1 => Self::Y,
371            2 => Self::Z,
372            _ => Self::W,
373        }
374    }
375}
376
377impl super::ImageClass {
378    pub const fn is_multisampled(self) -> bool {
379        match self {
380            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
381            crate::ImageClass::Storage { .. } => false,
382        }
383    }
384
385    pub const fn is_mipmapped(self) -> bool {
386        match self {
387            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
388            crate::ImageClass::Storage { .. } => false,
389        }
390    }
391
392    pub const fn is_depth(self) -> bool {
393        matches!(self, crate::ImageClass::Depth { .. })
394    }
395}
396
397impl crate::Module {
398    pub const fn to_ctx(&self) -> GlobalCtx<'_> {
399        GlobalCtx {
400            types: &self.types,
401            constants: &self.constants,
402            overrides: &self.overrides,
403            global_expressions: &self.global_expressions,
404        }
405    }
406}
407
408#[derive(Debug)]
409pub(super) enum U32EvalError {
410    NonConst,
411    Negative,
412}
413
414#[derive(Clone, Copy)]
415pub struct GlobalCtx<'a> {
416    pub types: &'a crate::UniqueArena<crate::Type>,
417    pub constants: &'a crate::Arena<crate::Constant>,
418    pub overrides: &'a crate::Arena<crate::Override>,
419    pub global_expressions: &'a crate::Arena<crate::Expression>,
420}
421
422impl GlobalCtx<'_> {
423    /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
424    #[allow(dead_code)]
425    pub(super) fn eval_expr_to_u32(
426        &self,
427        handle: crate::Handle<crate::Expression>,
428    ) -> Result<u32, U32EvalError> {
429        self.eval_expr_to_u32_from(handle, self.global_expressions)
430    }
431
432    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`.
433    pub(super) fn eval_expr_to_u32_from(
434        &self,
435        handle: crate::Handle<crate::Expression>,
436        arena: &crate::Arena<crate::Expression>,
437    ) -> Result<u32, U32EvalError> {
438        match self.eval_expr_to_literal_from(handle, arena) {
439            Some(crate::Literal::U32(value)) => Ok(value),
440            Some(crate::Literal::I32(value)) => {
441                value.try_into().map_err(|_| U32EvalError::Negative)
442            }
443            _ => Err(U32EvalError::NonConst),
444        }
445    }
446
447    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`.
448    #[allow(dead_code)]
449    pub(super) fn eval_expr_to_bool_from(
450        &self,
451        handle: crate::Handle<crate::Expression>,
452        arena: &crate::Arena<crate::Expression>,
453    ) -> Option<bool> {
454        match self.eval_expr_to_literal_from(handle, arena) {
455            Some(crate::Literal::Bool(value)) => Some(value),
456            _ => None,
457        }
458    }
459
460    #[allow(dead_code)]
461    pub(crate) fn eval_expr_to_literal(
462        &self,
463        handle: crate::Handle<crate::Expression>,
464    ) -> Option<crate::Literal> {
465        self.eval_expr_to_literal_from(handle, self.global_expressions)
466    }
467
468    pub(super) fn eval_expr_to_literal_from(
469        &self,
470        handle: crate::Handle<crate::Expression>,
471        arena: &crate::Arena<crate::Expression>,
472    ) -> Option<crate::Literal> {
473        fn get(
474            gctx: GlobalCtx,
475            handle: crate::Handle<crate::Expression>,
476            arena: &crate::Arena<crate::Expression>,
477        ) -> Option<crate::Literal> {
478            match arena[handle] {
479                crate::Expression::Literal(literal) => Some(literal),
480                crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
481                    crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
482                    _ => None,
483                },
484                _ => None,
485            }
486        }
487        match arena[handle] {
488            crate::Expression::Constant(c) => {
489                get(*self, self.constants[c].init, self.global_expressions)
490            }
491            _ => get(*self, handle, arena),
492        }
493    }
494}
495
496#[derive(Error, Debug, Clone, Copy, PartialEq)]
497pub enum ResolveArraySizeError {
498    #[error("array element count must be positive (> 0)")]
499    ExpectedPositiveArrayLength,
500    #[error("internal: array size override has not been resolved")]
501    NonConstArrayLength,
502}
503
504impl crate::ArraySize {
505    /// Return the number of elements that `size` represents, if known at code generation time.
506    ///
507    /// If `size` is override-based, return an error unless the override's
508    /// initializer is a fully evaluated constant expression. You can call
509    /// [`pipeline_constants::process_overrides`] to supply values for a
510    /// module's overrides and ensure their initializers are fully evaluated, as
511    /// this function expects.
512    ///
513    /// [`pipeline_constants::process_overrides`]: crate::back::pipeline_constants::process_overrides
514    pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
515        match *self {
516            crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
517            crate::ArraySize::Pending(handle) => {
518                let Some(expr) = gctx.overrides[handle].init else {
519                    return Err(ResolveArraySizeError::NonConstArrayLength);
520                };
521                let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err {
522                    U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength,
523                    U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength,
524                })?;
525
526                if length == 0 {
527                    return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
528                }
529
530                Ok(IndexableLength::Known(length))
531            }
532            crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
533        }
534    }
535}
536
537/// Return an iterator over the individual components assembled by a
538/// `Compose` expression.
539///
540/// Given `ty` and `components` from an `Expression::Compose`, return an
541/// iterator over the components of the resulting value.
542///
543/// Normally, this would just be an iterator over `components`. However,
544/// `Compose` expressions can concatenate vectors, in which case the i'th
545/// value being composed is not generally the i'th element of `components`.
546/// This function consults `ty` to decide if this concatenation is occurring,
547/// and returns an iterator that produces the components of the result of
548/// the `Compose` expression in either case.
549pub fn flatten_compose<'arenas>(
550    ty: crate::Handle<crate::Type>,
551    components: &'arenas [crate::Handle<crate::Expression>],
552    expressions: &'arenas crate::Arena<crate::Expression>,
553    types: &'arenas crate::UniqueArena<crate::Type>,
554) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
555    // Returning `impl Iterator` is a bit tricky. We may or may not
556    // want to flatten the components, but we have to settle on a
557    // single concrete type to return. This function returns a single
558    // iterator chain that handles both the flattening and
559    // non-flattening cases.
560    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
561        (size as usize, true)
562    } else {
563        (components.len(), false)
564    };
565
566    /// Flatten `Compose` expressions if `is_vector` is true.
567    fn flatten_compose<'c>(
568        component: &'c crate::Handle<crate::Expression>,
569        is_vector: bool,
570        expressions: &'c crate::Arena<crate::Expression>,
571    ) -> &'c [crate::Handle<crate::Expression>] {
572        if is_vector {
573            if let crate::Expression::Compose {
574                ty: _,
575                components: ref subcomponents,
576            } = expressions[*component]
577            {
578                return subcomponents;
579            }
580        }
581        core::slice::from_ref(component)
582    }
583
584    /// Flatten `Splat` expressions if `is_vector` is true.
585    fn flatten_splat<'c>(
586        component: &'c crate::Handle<crate::Expression>,
587        is_vector: bool,
588        expressions: &'c crate::Arena<crate::Expression>,
589    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
590        let mut expr = *component;
591        let mut count = 1;
592        if is_vector {
593            if let crate::Expression::Splat { size, value } = expressions[expr] {
594                expr = value;
595                count = size as usize;
596            }
597        }
598        core::iter::repeat_n(expr, count)
599    }
600
601    // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
602    // flatten up to two levels of `Compose` expressions.
603    //
604    // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
605    // `Splat` expressions. Fortunately, the operand of a `Splat` must
606    // be a scalar, so we can stop there.
607    components
608        .iter()
609        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
610        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
611        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
612        .take(size)
613}
614
#[test]
fn test_matrix_size() {
    // `mat3x3<f32>` has three `vec3<f32>` columns. Under WGSL layout rules a
    // `vec3<f32>` column occupies a 16-byte slot, so the expected size is
    // 3 * 16 = 48 bytes rather than 3 * 12 = 36.
    let mat3x3_f32 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    let module = crate::Module::default();
    assert_eq!(mat3x3_f32.size(module.to_ctx()), 48);
}