use crate::{valid, Handle, UniqueArena};
use bit_set::BitSet;
/// How generated code should treat an index that might be out of bounds.
///
/// [`BoundsCheckPolicies`] holds one of these per category of access.
/// [`BoundsCheckPolicies::choose_policy`] picks the category for an indexed base.
///
/// NOTE(review): the per-variant descriptions below follow the variant names.
/// The actual checks are emitted by back ends, which are not visible in this file.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum BoundsCheckPolicy {
    /// Presumably: force out-of-bounds indices back into range. Confirm against back ends.
    Restrict,
    /// Presumably: out-of-bounds reads produce zero and out-of-bounds writes are skipped.
    ///
    /// This is the only policy for which [`find_checked_indexes`] collects index expressions.
    ReadZeroSkipWrite,
    /// Presumably: no bounds checking at all.
    ///
    /// This is the value produced by `BoundsCheckPolicy::default()`.
    Unchecked,
}
/// The bounds-check policy to apply to each category of indexing operation.
///
/// When deserializing, every field is `serde(default)`. A missing field
/// therefore takes `BoundsCheckPolicy::default()`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct BoundsCheckPolicies {
    /// Policy for indexing bases that are neither binding arrays nor pointers into
    /// `Storage`/`Uniform` space. This is the fallback in `choose_policy`.
    #[cfg_attr(feature = "deserialize", serde(default))]
    pub index: BoundsCheckPolicy,
    /// Policy for indexing bases whose pointer space is `Storage` or `Uniform`.
    #[cfg_attr(feature = "deserialize", serde(default))]
    pub buffer: BoundsCheckPolicy,
    /// Policy for the operands of `ImageLoad` expressions (coordinate, array index,
    /// sample, level), as consulted by `find_checked_indexes`.
    #[cfg_attr(feature = "deserialize", serde(default))]
    pub image_load: BoundsCheckPolicy,
    /// Policy for image stores. It is not read by any code in this file; its use
    /// is presumably in back ends (TODO confirm).
    #[cfg_attr(feature = "deserialize", serde(default))]
    pub image_store: BoundsCheckPolicy,
    /// Policy for indexing bases whose type is `BindingArray`.
    #[cfg_attr(feature = "deserialize", serde(default))]
    pub binding_array: BoundsCheckPolicy,
}
impl Default for BoundsCheckPolicy {
fn default() -> Self {
BoundsCheckPolicy::Unchecked
}
}
impl BoundsCheckPolicies {
    /// Choose the policy that governs indexing into the expression `base`.
    ///
    /// The policy is chosen from the type of `base`, as recorded in `info` and
    /// resolved against `types`:
    /// - a `BindingArray` type uses [`binding_array`](Self::binding_array);
    /// - a type whose `pointer_space()` is `Storage` or `Uniform` uses
    ///   [`buffer`](Self::buffer);
    /// - anything else uses [`index`](Self::index).
    pub fn choose_policy(
        &self,
        base: Handle<crate::Expression>,
        types: &UniqueArena<crate::Type>,
        info: &valid::FunctionInfo,
    ) -> BoundsCheckPolicy {
        let ty = info[base].ty.inner_with(types);

        if let crate::TypeInner::BindingArray { .. } = *ty {
            return self.binding_array;
        }

        match ty.pointer_space() {
            Some(crate::AddressSpace::Storage { access: _ } | crate::AddressSpace::Uniform) => {
                self.buffer
            }
            _ => self.index,
        }
    }

    /// Return `true` if any of `self`'s policies are `policy`.
    pub fn contains(&self, policy: BoundsCheckPolicy) -> bool {
        // Every field must be considered, including `binding_array`.
        // `choose_policy` can return `binding_array`, and callers such as
        // `find_checked_indexes` use `contains` as a gate before asking
        // `choose_policy` per access. Omitting a field here would silently skip
        // the accesses that field governs.
        self.index == policy
            || self.buffer == policy
            || self.image_load == policy
            || self.image_store == policy
            || self.binding_array == policy
    }
}
/// An index whose bounds may need to be checked.
#[derive(Clone, Copy, Debug)]
pub enum GuardedIndex {
    /// The index value is a compile-time constant.
    Known(u32),
    /// The index is computed by this expression. `try_resolve_to_constant` turns
    /// it into `Known` when the expression evaluates to a constant `u32`.
    Expression(Handle<crate::Expression>),
}
/// Collect the expressions in `function` that are used as indices and will be
/// bounds-checked under [`BoundsCheckPolicy::ReadZeroSkipWrite`].
///
/// The returned set holds each expression's `Handle::index()`. Two kinds of
/// expression are recorded:
/// - the `index` operand of every `Access` whose base is governed by
///   `ReadZeroSkipWrite` (per `choose_policy`) and for which
///   [`access_needs_check`] reports that a check is required;
/// - every coordinate / array-index / sample / level operand of an `ImageLoad`,
///   when `policies.image_load` is `ReadZeroSkipWrite`.
///
/// NOTE(review): presumably back ends use this set to treat these expressions
/// specially, since a guarded index is referenced by both the check and the
/// access. Confirm against callers.
pub fn find_checked_indexes(
    module: &crate::Module,
    function: &crate::Function,
    info: &crate::valid::FunctionInfo,
    policies: BoundsCheckPolicies,
) -> BitSet {
    use crate::Expression as Ex;

    let mut checked = BitSet::new();

    // Nothing can be guarded unless some policy asks for ReadZeroSkipWrite.
    if !policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) {
        return checked;
    }

    for (_, expression) in function.expressions.iter() {
        match *expression {
            Ex::Access { base, index } => {
                let policy = policies.choose_policy(base, &module.types, info);
                // Short-circuit: only ask about lengths for RZSW-governed bases.
                if policy == BoundsCheckPolicy::ReadZeroSkipWrite
                    && access_needs_check(
                        base,
                        GuardedIndex::Expression(index),
                        module,
                        function,
                        info,
                    )
                    .is_some()
                {
                    checked.insert(index.index());
                }
            }
            Ex::ImageLoad {
                coordinate,
                array_index,
                sample,
                level,
                ..
            } if policies.image_load == BoundsCheckPolicy::ReadZeroSkipWrite => {
                checked.insert(coordinate.index());
                for operand in [array_index, sample, level].into_iter().flatten() {
                    checked.insert(operand.index());
                }
            }
            _ => {}
        }
    }

    checked
}
/// Decide whether indexing `base` with `index` needs a runtime bounds check.
///
/// First, `index` is resolved to a constant if possible. If it is then a known
/// value strictly less than `base`'s known length, no check is needed and
/// `None` is returned. Otherwise, `Some(length)` is returned, where `length` is
/// `base`'s indexable length (possibly [`IndexableLength::Dynamic`]).
///
/// # Panics
///
/// Panics if `base`'s type has no indexable length. That happens when it is
/// not an indexable type or has an invalid array length. The call is expected
/// to happen only on modules that validation has already accepted.
pub fn access_needs_check(
    base: Handle<crate::Expression>,
    mut index: GuardedIndex,
    module: &crate::Module,
    function: &crate::Function,
    info: &crate::valid::FunctionInfo,
) -> Option<IndexableLength> {
    let base_inner = info[base].ty.inner_with(&module.types);
    // `Err` means the base is not indexable or its length is invalid.
    // Validation should reject such modules before any back end gets here.
    let length = base_inner
        .indexable_length(module)
        .expect("access_needs_check: indexed base must have a valid indexable length");

    index.try_resolve_to_constant(function, module);

    let statically_in_bounds = match (index, &length) {
        (GuardedIndex::Known(value), &IndexableLength::Known(bound)) => value < bound,
        _ => false,
    };

    if statically_in_bounds {
        None
    } else {
        Some(length)
    }
}
impl GuardedIndex {
    /// If `self` is an `Expression` that evaluates to a constant `u32` in
    /// `function`'s expression arena, replace it with `Known(value)`.
    ///
    /// Otherwise, including when evaluation fails, leave `self` unchanged.
    fn try_resolve_to_constant(&mut self, function: &crate::Function, module: &crate::Module) {
        let handle = match *self {
            GuardedIndex::Expression(handle) => handle,
            GuardedIndex::Known(_) => return,
        };

        let global_ctx = module.to_ctx();
        if let Ok(value) = global_ctx.eval_expr_to_u32_from(handle, &function.expressions) {
            *self = GuardedIndex::Known(value);
        }
    }
}
/// Reasons an indexable length could not be determined.
#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)]
pub enum IndexableLengthError {
    /// Returned by `TypeInner::indexable_length` for types that cannot be indexed.
    #[error("Type is not indexable, and has no length (validation error)")]
    TypeNotIndexable,
    /// An array length given by the referenced expression is invalid.
    ///
    /// NOTE(review): not constructed by any code in this file.
    #[error("Array length constant {0:?} is invalid")]
    InvalidArrayLength(Handle<crate::Expression>),
}
impl crate::TypeInner {
    /// Return how many elements a value of this type can be indexed over.
    ///
    /// The length depends on the type:
    /// - Vectors report their size.
    /// - Matrices report their column count.
    /// - Arrays and binding arrays report their declared size, via
    ///   `ArraySize::to_indexable_length`.
    /// - A `ValuePointer` with a `size` reports that size.
    /// - A `Pointer` reports the length of the vector, matrix, array or binding
    ///   array it points to. It looks through exactly one level of pointer.
    ///
    /// # Errors
    ///
    /// Returns [`IndexableLengthError::TypeNotIndexable`] for every other type.
    pub fn indexable_length(
        &self,
        module: &crate::Module,
    ) -> Result<IndexableLength, IndexableLengthError> {
        /// Length of a vector, matrix, array, or binding array. Any other type
        /// is reported as not indexable.
        fn aggregate_length(
            inner: &crate::TypeInner,
            module: &crate::Module,
        ) -> Result<IndexableLength, IndexableLengthError> {
            use crate::TypeInner as Ti;
            match *inner {
                Ti::Vector { size, .. } => Ok(IndexableLength::Known(size as u32)),
                Ti::Matrix { columns, .. } => Ok(IndexableLength::Known(columns as u32)),
                Ti::Array { size, .. } | Ti::BindingArray { size, .. } => {
                    size.to_indexable_length(module)
                }
                _ => Err(IndexableLengthError::TypeNotIndexable),
            }
        }

        match *self {
            Self::ValuePointer {
                size: Some(size), ..
            } => Ok(IndexableLength::Known(size as u32)),
            Self::Pointer { base, .. } => aggregate_length(&module.types[base].inner, module),
            _ => aggregate_length(self, module),
        }
    }
}
/// The number of elements an indexable value has.
#[derive(Debug)]
pub enum IndexableLength {
    /// The length is known statically, e.g. a vector size, a matrix column
    /// count, or a constant array size.
    Known(u32),
    /// The length comes from an `ArraySize::Dynamic` array and is not known
    /// statically.
    Dynamic,
}
impl crate::ArraySize {
    /// Convert this array size into an [`IndexableLength`].
    ///
    /// A constant size becomes `Known`, and a dynamic size becomes `Dynamic`.
    /// This never fails. `_module` is unused here; it is presumably kept so the
    /// signature stays stable (TODO confirm).
    pub const fn to_indexable_length(
        self,
        _module: &crate::Module,
    ) -> Result<IndexableLength, IndexableLengthError> {
        let length = match self {
            Self::Dynamic => IndexableLength::Dynamic,
            Self::Constant(count) => IndexableLength::Known(count.get()),
        };
        Ok(length)
    }
}