use core::iter::{self, zip};
use crate::arena::{Handle, HandleSet, UniqueArena};
use crate::{valid, FastHashSet};
/// Strategy selector for handling accesses that may fall outside the bounds
/// of the value being indexed.
///
/// The variants carry no data. The `Default` implementation for this type in
/// this file yields [`BoundsCheckPolicy::Unchecked`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum BoundsCheckPolicy {
/// NOTE(review): presumably forces out-of-range indexes back into range;
/// the exact behavior is implemented by the back ends — confirm there.
Restrict,
/// NOTE(review): going by the name, out-of-bounds reads produce zero and
/// out-of-bounds writes are skipped — confirm against the back ends.
/// This is the only policy for which `find_checked_indexes` and
/// `oob_local_types` in this file do any work.
ReadZeroSkipWrite,
/// NOTE(review): presumably no checks are emitted at all — confirm.
/// This is the default policy.
Unchecked,
}
/// One [`BoundsCheckPolicy`] per category of access.
///
/// `Default` is derived, so every field defaults to
/// `BoundsCheckPolicy::default()`. With the `deserialize` feature,
/// `serde(default)` lets fields missing from the input fall back to that
/// default as well.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(feature = "deserialize", serde(default))]
pub struct BoundsCheckPolicies {
/// Fallback policy: `choose_policy` returns this when the base is neither a
/// binding array nor a pointer into the `Storage`/`Uniform` address spaces.
pub index: BoundsCheckPolicy,
/// Policy `choose_policy` returns when the base is a pointer into the
/// `Storage` or `Uniform` address space.
pub buffer: BoundsCheckPolicy,
/// Policy consulted by `find_checked_indexes` for `ImageLoad` expressions.
pub image_load: BoundsCheckPolicy,
/// Policy `choose_policy` returns when the base's type is a `BindingArray`.
pub binding_array: BoundsCheckPolicy,
}
impl Default for BoundsCheckPolicy {
fn default() -> Self {
BoundsCheckPolicy::Unchecked
}
}
impl BoundsCheckPolicies {
    /// Select the policy that governs indexing into the expression `base`.
    ///
    /// The type of `base` is looked up in `info` and resolved against `types`:
    ///
    /// - a `BindingArray` type selects [`Self::binding_array`];
    /// - a pointer into the `Storage` or `Uniform` address space selects
    ///   [`Self::buffer`];
    /// - anything else selects [`Self::index`].
    pub fn choose_policy(
        &self,
        base: Handle<crate::Expression>,
        types: &UniqueArena<crate::Type>,
        info: &valid::FunctionInfo,
    ) -> BoundsCheckPolicy {
        let ty = info[base].ty.inner_with(types);

        // Binding arrays are checked first: they have their own policy
        // regardless of address space.
        if let crate::TypeInner::BindingArray { .. } = *ty {
            return self.binding_array;
        }

        match ty.pointer_space() {
            Some(crate::AddressSpace::Storage { access: _ } | crate::AddressSpace::Uniform) => {
                self.buffer
            }
            _ => self.index,
        }
    }

    /// Return `true` if any of `self`'s policies is `policy`.
    ///
    /// All four fields are consulted, including `binding_array`: since
    /// `choose_policy` can return `binding_array`, callers that use this
    /// method as a quick "is there any work to do?" gate must not miss it.
    pub fn contains(&self, policy: BoundsCheckPolicy) -> bool {
        [
            self.index,
            self.buffer,
            self.image_load,
            self.binding_array,
        ]
        .contains(&policy)
    }
}
/// An index operand that may need a bounds check: either a value already
/// known as a `u32`, or a handle to the expression that produces it.
#[derive(Clone, Copy, Debug)]
pub enum GuardedIndex {
/// The index value is known.
Known(u32),
/// The index is produced by this expression; `try_resolve_to_constant`
/// may turn it into `Known` if it can be evaluated to a `u32`.
Expression(Handle<crate::Expression>),
}
pub fn find_checked_indexes(
module: &crate::Module,
function: &crate::Function,
info: &valid::FunctionInfo,
policies: BoundsCheckPolicies,
) -> HandleSet<crate::Expression> {
use crate::Expression as Ex;
let mut guarded_indices = HandleSet::for_arena(&function.expressions);
if policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) {
for (_handle, expr) in function.expressions.iter() {
match *expr {
Ex::Access { base, index } => {
if policies.choose_policy(base, &module.types, info)
== BoundsCheckPolicy::ReadZeroSkipWrite
&& access_needs_check(
base,
GuardedIndex::Expression(index),
module,
&function.expressions,
info,
)
.is_some()
{
guarded_indices.insert(index);
}
}
Ex::ImageLoad {
coordinate,
array_index,
sample,
level,
..
} => {
if policies.image_load == BoundsCheckPolicy::ReadZeroSkipWrite {
guarded_indices.insert(coordinate);
if let Some(array_index) = array_index {
guarded_indices.insert(array_index);
}
if let Some(sample) = sample {
guarded_indices.insert(sample);
}
if let Some(level) = level {
guarded_indices.insert(level);
}
}
}
_ => {}
}
}
}
guarded_indices
}
/// Decide whether indexing `base` by `index` requires a bounds check.
///
/// Returns `None` when both the index and the length of `base` are known and
/// the index is strictly less than the length. Otherwise returns
/// `Some(length)`, the indexable length of `base` to check against (which may
/// be [`IndexableLength::Dynamic`]).
///
/// `index` is first resolved to a constant where possible, via
/// [`GuardedIndex::try_resolve_to_constant`].
///
/// # Panics
///
/// Panics if the length of `base`'s type cannot be determined, i.e. if
/// `indexable_length_resolved` returns an error.
pub fn access_needs_check(
    base: Handle<crate::Expression>,
    mut index: GuardedIndex,
    module: &crate::Module,
    expressions: &crate::Arena<crate::Expression>,
    info: &valid::FunctionInfo,
) -> Option<IndexableLength> {
    let length = info[base]
        .ty
        .inner_with(&module.types)
        .indexable_length_resolved(module)
        .unwrap();

    index.try_resolve_to_constant(expressions, module);

    // Only a known index compared against a known length can be proven
    // in bounds without a runtime check.
    let statically_in_bounds = matches!(
        (&index, &length),
        (&GuardedIndex::Known(position), &IndexableLength::Known(limit)) if position < limit
    );

    if statically_in_bounds {
        None
    } else {
        Some(length)
    }
}
/// A single bounds check to perform, as produced by `bounds_check_iter`.
///
/// Dead-code warnings are silenced when the `msl-out` feature is disabled.
#[cfg_attr(not(feature = "msl-out"), allow(dead_code))]
pub(crate) struct BoundsCheck {
/// The expression being indexed.
pub base: Handle<crate::Expression>,
/// The index applied to `base`.
pub index: GuardedIndex,
/// The length of `base` that `index` must be checked against, as returned
/// by `access_needs_check`.
pub length: IndexableLength,
}
/// Iterate over the bounds checks needed by the access chain ending at `chain`.
///
/// Starting at `chain`, the iterator follows `Access` and `AccessIndex`
/// expressions through their `base` operands and stops at the first expression
/// of any other kind. `AccessIndex` steps into a struct (directly or through a
/// pointer) are walked past without producing a candidate. Every other step is
/// passed to [`access_needs_check`], and a [`BoundsCheck`] is yielded for each
/// step that function says must be checked.
pub(crate) fn bounds_check_iter<'a>(
    mut chain: Handle<crate::Expression>,
    module: &'a crate::Module,
    function: &'a crate::Function,
    info: &'a valid::FunctionInfo,
) -> impl Iterator<Item = BoundsCheck> + 'a {
    // Each step yields `Some(candidate)` for an indexing step that might need
    // a check, or `None` for a step that never does (struct member access).
    let steps = iter::from_fn(move || {
        let (parent, candidate) = match function.expressions[chain] {
            crate::Expression::Access { base, index } => {
                (base, Some((base, GuardedIndex::Expression(index))))
            }
            crate::Expression::AccessIndex { base, index } => {
                // Look through one level of pointer to find what is indexed.
                let indexed = match *info[base].ty.inner_with(&module.types) {
                    crate::TypeInner::Pointer { base: pointee, .. } => {
                        &module.types[pointee].inner
                    }
                    ref direct => direct,
                };
                let candidate = if matches!(*indexed, crate::TypeInner::Struct { .. }) {
                    None
                } else {
                    Some((base, GuardedIndex::Known(index)))
                };
                (base, candidate)
            }
            _ => return None,
        };
        chain = parent;
        Some(candidate)
    });

    steps.flatten().filter_map(move |(base, index)| {
        let length = access_needs_check(base, index, module, &function.expressions, info)?;
        Some(BoundsCheck {
            base,
            index,
            length,
        })
    })
}
/// Collect the pointee types of pointer arguments, passed in calls made by
/// `function`, whose access chains need at least one bounds check.
///
/// The set is empty unless `policies.index` is
/// [`BoundsCheckPolicy::ReadZeroSkipWrite`]. For each `Call` statement, every
/// argument whose declared parameter type is a `Pointer` is examined with
/// [`bounds_check_iter`]; if any check is required, the pointer's base type
/// is added to the result.
///
/// NOTE(review): only the statements directly in `function.body` are
/// inspected; calls nested inside other statements are not visited here —
/// confirm that this is intended.
///
/// # Panics
///
/// Panics if a callee's parameter type in the arena is a `ValuePointer`.
pub fn oob_local_types(
    module: &crate::Module,
    function: &crate::Function,
    info: &valid::FunctionInfo,
    policies: BoundsCheckPolicies,
) -> FastHashSet<Handle<crate::Type>> {
    let mut result = FastHashSet::default();

    if policies.index == BoundsCheckPolicy::ReadZeroSkipWrite {
        for statement in &function.body {
            if let crate::Statement::Call {
                function: callee,
                ref arguments,
                ..
            } = *statement
            {
                let parameters = &module.functions[callee].arguments;
                for (parameter, &argument) in zip(parameters, arguments) {
                    // Only pointer parameters are of interest.
                    let pointee = match module.types[parameter.ty].inner {
                        crate::TypeInner::ValuePointer { .. } => {
                            unreachable!("`ValuePointer` found in arena")
                        }
                        crate::TypeInner::Pointer { base, .. } => base,
                        _ => continue,
                    };

                    let mut checks = bounds_check_iter(argument, module, function, info);
                    if checks.next().is_some() {
                        result.insert(pointee);
                    }
                }
            }
        }
    }

    result
}
impl GuardedIndex {
pub(crate) fn try_resolve_to_constant(
&mut self,
expressions: &crate::Arena<crate::Expression>,
module: &crate::Module,
) {
if let GuardedIndex::Expression(expr) = *self {
*self = GuardedIndex::from_expression(expr, expressions, module);
}
}
pub(crate) fn from_expression(
expr: Handle<crate::Expression>,
expressions: &crate::Arena<crate::Expression>,
module: &crate::Module,
) -> Self {
match module.to_ctx().eval_expr_to_u32_from(expr, expressions) {
Ok(value) => Self::Known(value),
Err(_) => Self::Expression(expr),
}
}
}
/// Errors returned by the `indexable_length*` methods on `TypeInner` and by
/// `ArraySize::to_indexable_length`.
#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)]
pub enum IndexableLengthError {
/// The type (or, for a pointer, its pointee) is not one of the indexable
/// kinds handled by `indexable_length`.
#[error("Type is not indexable, and has no length (validation error)")]
TypeNotIndexable,
/// Resolving a pending array size failed; wraps the underlying error.
#[error(transparent)]
ResolveArraySizeError(#[from] super::ResolveArraySizeError),
/// The array size is `ArraySize::Pending` and has not been resolved; the
/// offending size is carried along.
#[error("Array size is still pending")]
Pending(crate::ArraySize),
}
impl crate::TypeInner {
    /// Return the length of a value of this type when it is indexed.
    ///
    /// Vectors report their size and matrices their column count. Arrays and
    /// binding arrays defer to [`crate::ArraySize::to_indexable_length`]. A
    /// `ValuePointer` with a size reports that size. A `Pointer` reports the
    /// length of its pointee, provided the pointee is a vector, matrix, array
    /// or binding array.
    ///
    /// # Errors
    ///
    /// Returns [`IndexableLengthError::TypeNotIndexable`] for every other
    /// type, and [`IndexableLengthError::Pending`] for an array whose size is
    /// still pending.
    pub fn indexable_length(
        &self,
        module: &crate::Module,
    ) -> Result<IndexableLength, IndexableLengthError> {
        use crate::TypeInner as Ti;

        // Peel at most one level of pointer. `ValuePointer` is handled up
        // front because it is only indexable when it is `self`, never when it
        // is the pointee of a `Pointer`.
        let indexed = match *self {
            Ti::ValuePointer {
                size: Some(size), ..
            } => return Ok(IndexableLength::Known(size as u32)),
            Ti::Pointer { base, .. } => &module.types[base].inner,
            ref direct => direct,
        };

        match *indexed {
            Ti::Vector { size, .. } => Ok(IndexableLength::Known(size as u32)),
            Ti::Matrix { columns, .. } => Ok(IndexableLength::Known(columns as u32)),
            Ti::Array { size, .. } | Ti::BindingArray { size, .. } => {
                size.to_indexable_length(module)
            }
            _ => Err(IndexableLengthError::TypeNotIndexable),
        }
    }

    /// Like [`Self::indexable_length`], but a pending array size is reported
    /// as [`IndexableLength::Dynamic`] instead of an error.
    pub fn indexable_length_pending(
        &self,
        module: &crate::Module,
    ) -> Result<IndexableLength, IndexableLengthError> {
        match self.indexable_length(module) {
            Err(IndexableLengthError::Pending(_)) => Ok(IndexableLength::Dynamic),
            outcome => outcome,
        }
    }

    /// Like [`Self::indexable_length`], but a pending array size is resolved
    /// against `module` first.
    ///
    /// # Errors
    ///
    /// Propagates any error from resolving the pending size. If resolution
    /// succeeds but does not produce a known length, the original
    /// [`IndexableLengthError::Pending`] error is returned.
    pub fn indexable_length_resolved(
        &self,
        module: &crate::Module,
    ) -> Result<IndexableLength, IndexableLengthError> {
        match self.indexable_length(module) {
            Err(IndexableLengthError::Pending(size)) => match size.resolve(module.to_ctx())? {
                known @ IndexableLength::Known(_) => Ok(known),
                _ => Err(IndexableLengthError::Pending(size)),
            },
            outcome => outcome,
        }
    }
}
/// The number of elements available when indexing a value.
#[derive(Debug)]
pub enum IndexableLength {
/// The length is known; `access_needs_check` compares known indexes
/// against it.
Known(u32),
/// The length is not known here. This is what `ArraySize::Dynamic` maps to,
/// and what `indexable_length_pending` reports for pending sizes.
Dynamic,
}
impl crate::ArraySize {
pub const fn to_indexable_length(
self,
_module: &crate::Module,
) -> Result<IndexableLength, IndexableLengthError> {
match self {
Self::Constant(length) => Ok(IndexableLength::Known(length.get())),
Self::Pending(_) => Err(IndexableLengthError::Pending(self)),
Self::Dynamic => Ok(IndexableLength::Dynamic),
}
}
}