use super::ProcError;
use crate::valid;
use crate::{Handle, UniqueArena};
use bit_set::BitSet;
/// How an out-of-bounds access should be handled.
///
/// In this file only `ReadZeroSkipWrite` receives special treatment (see
/// `find_checked_indexes`); the code that enforces each policy is not in
/// this file.
///
/// NOTE(review): the per-variant descriptions below follow the variant names
/// and naga's public documentation — confirm against the back ends.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum BoundsCheckPolicy {
/// Presumably: force the index into the valid range before accessing.
Restrict,
/// Presumably: out-of-bounds reads yield zero and out-of-bounds writes are
/// skipped.
ReadZeroSkipWrite,
/// Presumably: perform no bounds check at all. This is the variant the
/// `Default` impl in this file returns.
Unchecked,
}
/// A set of bounds-check policies, one per category of access.
///
/// `choose_policy` selects between `buffer` and `index`; `image` is only
/// consulted by `contains` within this file.
///
/// With the `deserialize` feature, each field carries `serde(default)`, so a
/// field missing from the input falls back to `BoundsCheckPolicy::default()`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct BoundsCheckPolicies {
/// Policy returned by `choose_policy` for every access whose base is not a
/// pointer in the `Storage` or `Uniform` storage class.
#[cfg_attr(feature = "deserialize", serde(default))]
pub index: BoundsCheckPolicy,
/// Policy returned by `choose_policy` when the base is a pointer in the
/// `Storage` or `Uniform` storage class.
#[cfg_attr(feature = "deserialize", serde(default))]
pub buffer: BoundsCheckPolicy,
/// Policy for image accesses.
/// NOTE(review): inferred from the field name; no code in this file applies
/// it to an access — confirm against the back ends.
#[cfg_attr(feature = "deserialize", serde(default))]
pub image: BoundsCheckPolicy,
}
impl Default for BoundsCheckPolicy {
fn default() -> Self {
BoundsCheckPolicy::Unchecked
}
}
impl BoundsCheckPolicies {
pub fn choose_policy(
&self,
access: Handle<crate::Expression>,
types: &UniqueArena<crate::Type>,
info: &valid::FunctionInfo,
) -> BoundsCheckPolicy {
match info[access].ty.inner_with(types).pointer_class() {
Some(crate::StorageClass::Storage { access: _ })
| Some(crate::StorageClass::Uniform) => self.buffer,
_ => self.index,
}
}
pub fn contains(&self, policy: BoundsCheckPolicy) -> bool {
self.index == policy || self.buffer == policy || self.image == policy
}
}
/// An index that may need a bounds check: either a value already known, or an
/// expression in the function's expression arena.
///
/// `try_resolve_to_constant` rewrites `Expression` into `Known` when the
/// expression is a constant whose value `to_array_length` accepts.
#[derive(Clone, Copy, Debug)]
pub enum GuardedIndex {
/// The index value itself.
Known(u32),
/// Handle of the expression that computes the index.
Expression(Handle<crate::Expression>),
}
pub fn find_checked_indexes(
module: &crate::Module,
function: &crate::Function,
info: &crate::valid::FunctionInfo,
policies: BoundsCheckPolicies,
) -> BitSet {
use crate::Expression as Ex;
let mut guarded_indices = BitSet::new();
if policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) {
for (_handle, expr) in function.expressions.iter() {
if let Ex::Access { base, index } = *expr {
if policies.choose_policy(base, &module.types, info)
== BoundsCheckPolicy::ReadZeroSkipWrite
&& access_needs_check(
base,
GuardedIndex::Expression(index),
module,
function,
info,
)
.is_some()
{
guarded_indices.insert(index.index());
}
}
}
}
guarded_indices
}
/// Decide whether indexing `base` by `index` requires a bounds check.
///
/// Returns `None` when both the index (after trying to resolve it to a
/// constant) and the length of `base` are known and the index is strictly
/// less than the length. Otherwise returns `Some(length)`, where `length` is
/// the indexable length of `base`'s type.
///
/// # Panics
///
/// Panics if `indexable_length` returns an error for the type of `base`, i.e.
/// the type is not indexable or its array size constant is invalid.
pub fn access_needs_check(
    base: Handle<crate::Expression>,
    mut index: GuardedIndex,
    module: &crate::Module,
    function: &crate::Function,
    info: &crate::valid::FunctionInfo,
) -> Option<IndexableLength> {
    let length = info[base]
        .ty
        .inner_with(&module.types)
        .indexable_length(module)
        .unwrap();

    index.try_resolve_to_constant(function, module);

    // The check can be skipped only when both sides are known and the index
    // is below the length.
    let statically_in_bounds = match (index, &length) {
        (GuardedIndex::Known(value), &IndexableLength::Known(limit)) => value < limit,
        _ => false,
    };

    if statically_in_bounds {
        None
    } else {
        Some(length)
    }
}
impl GuardedIndex {
    /// Replace `self` with `GuardedIndex::Known` when it refers to a constant
    /// expression whose value `to_array_length` can provide.
    ///
    /// `self` is left untouched if it is already `Known`, if the expression is
    /// not an `Expression::Constant`, or if `to_array_length` returns `None`.
    fn try_resolve_to_constant(&mut self, function: &crate::Function, module: &crate::Module) {
        let expr = match *self {
            GuardedIndex::Expression(expr) => expr,
            GuardedIndex::Known(_) => return,
        };

        let constant = match function.expressions[expr] {
            crate::Expression::Constant(handle) => &module.constants[handle],
            _ => return,
        };

        if let Some(value) = constant.to_array_length() {
            *self = GuardedIndex::Known(value);
        }
    }
}
impl crate::TypeInner {
    /// Return the number of elements that can be indexed in a value of this
    /// type.
    ///
    /// * Vectors yield their size and matrices their column count.
    /// * Arrays defer to `ArraySize::to_indexable_length`.
    /// * A `ValuePointer` with `size: Some(_)` yields that size.
    /// * A `Pointer` is looked through exactly one level: its base type is
    ///   handled like a vector, matrix or array; any other base type is
    ///   rejected.
    ///
    /// # Errors
    ///
    /// Returns `ProcError::TypeNotIndexable` for every other type, and
    /// propagates any error from `ArraySize::to_indexable_length`.
    pub fn indexable_length(&self, module: &crate::Module) -> Result<IndexableLength, ProcError> {
        use crate::TypeInner as Ti;

        match *self {
            Ti::Vector { size, .. } => Ok(IndexableLength::Known(size as u32)),
            Ti::Matrix { columns, .. } => Ok(IndexableLength::Known(columns as u32)),
            Ti::Array { size, .. } => size.to_indexable_length(module),
            Ti::ValuePointer {
                size: Some(size), ..
            } => Ok(IndexableLength::Known(size as u32)),
            // Deliberately a one-level sub-match rather than a recursive call:
            // only vector, matrix and array pointees are accepted.
            Ti::Pointer { base, .. } => match module.types[base].inner {
                Ti::Vector { size, .. } => Ok(IndexableLength::Known(size as u32)),
                Ti::Matrix { columns, .. } => Ok(IndexableLength::Known(columns as u32)),
                Ti::Array { size, .. } => size.to_indexable_length(module),
                _ => Err(ProcError::TypeNotIndexable),
            },
            _ => Err(ProcError::TypeNotIndexable),
        }
    }
}
/// The number of elements in an indexable type, as produced by
/// `TypeInner::indexable_length` and `ArraySize::to_indexable_length`.
#[derive(Debug)]
pub enum IndexableLength {
/// The length is known; the value is the element count.
Known(u32),
/// The length is not known here; produced for `ArraySize::Dynamic`.
Dynamic,
}
impl crate::ArraySize {
    /// Convert this array size into an [`IndexableLength`].
    ///
    /// `Dynamic` maps to `IndexableLength::Dynamic`. `Constant(k)` maps to
    /// `IndexableLength::Known` holding the value that `to_array_length`
    /// returns for the constant `k` in `module`.
    ///
    /// # Errors
    ///
    /// Returns `ProcError::InvalidArraySizeConstant(k)` if the constant has a
    /// specialization, or if `to_array_length` returns `None` for it.
    pub fn to_indexable_length(self, module: &crate::Module) -> Result<IndexableLength, ProcError> {
        let handle = match self {
            Self::Constant(handle) => handle,
            Self::Dynamic => return Ok(IndexableLength::Dynamic),
        };

        let constant = &module.constants[handle];

        // Constants carrying a specialization are rejected as array lengths.
        if constant.specialization.is_some() {
            return Err(ProcError::InvalidArraySizeConstant(handle));
        }

        constant
            .to_array_length()
            .map(IndexableLength::Known)
            .ok_or(ProcError::InvalidArraySizeConstant(handle))
    }
}