use std::ops::Range;
use std::sync::Arc;
use crate::ir_inner::model::types::{BufferAccess, DataType};
use super::{MemoryHints, MemoryKind};
/// Usage discipline attached to a buffer declaration.
///
/// Its whole observable contract in this module is the pair of queries
/// [`LinearType::forbids_drop`] and [`LinearType::forbids_reuse`]:
///
/// | variant        | forbids drop | forbids reuse |
/// |----------------|--------------|---------------|
/// | `Linear`       | yes          | yes           |
/// | `Affine`       | no           | yes           |
/// | `Relevant`     | yes          | no            |
/// | `Unrestricted` | no           | no            |
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum LinearType {
    /// Forbids both dropping and reuse.
    Linear,
    /// Forbids reuse only.
    Affine,
    /// Forbids dropping only.
    Relevant,
    /// No restriction; this is the `Default`.
    #[default]
    Unrestricted,
}

impl LinearType {
    /// Returns `true` for the disciplines that forbid dropping
    /// ([`LinearType::Linear`] and [`LinearType::Relevant`]).
    #[must_use]
    #[inline]
    pub const fn forbids_drop(self) -> bool {
        // Exhaustive on purpose: a new variant must decide its drop rule here.
        match self {
            Self::Linear | Self::Relevant => true,
            Self::Affine | Self::Unrestricted => false,
        }
    }

    /// Returns `true` for the disciplines that forbid reuse
    /// ([`LinearType::Linear`] and [`LinearType::Affine`]).
    #[must_use]
    #[inline]
    pub const fn forbids_reuse(self) -> bool {
        // Exhaustive on purpose: a new variant must decide its reuse rule here.
        match self {
            Self::Linear | Self::Affine => true,
            Self::Relevant | Self::Unrestricted => false,
        }
    }
}
/// A boolean predicate over a `u32` element count, evaluated by
/// [`ShapePredicate::holds`] and rendered by [`ShapePredicate::describe`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ShapePredicate {
    /// Holds when `count >= n`.
    AtLeast(u32),
    /// Holds when `count <= n`.
    AtMost(u32),
    /// Holds when `count == n`.
    Exactly(u32),
    /// Holds when `count % n == 0`; never holds for `n == 0`.
    MultipleOf(u32),
    /// Holds when `count % modulus == remainder`; never holds when
    /// `modulus == 0` or `remainder >= modulus`.
    ModEquals {
        modulus: u32,
        remainder: u32,
    },
    /// Holds when `min <= count * scale + offset <= max`, computed in `i128`.
    AffineRange {
        scale: i64,
        offset: i64,
        min: i64,
        max: i64,
    },
    /// Holds when both operands hold (short-circuiting).
    And(Box<ShapePredicate>, Box<ShapePredicate>),
    /// Holds when either operand holds (short-circuiting).
    Or(Box<ShapePredicate>, Box<ShapePredicate>),
    /// Holds when the operand does not.
    Not(Box<ShapePredicate>),
}

impl ShapePredicate {
    /// Evaluates the predicate for the given `count`.
    ///
    /// Degenerate forms evaluate to `false` rather than panicking:
    /// `MultipleOf(0)`, and `ModEquals` with a zero modulus or with
    /// `remainder >= modulus`. `AffineRange` widens every operand to `i128`,
    /// where `u32 * i64 + i64` cannot overflow.
    #[must_use]
    pub fn holds(&self, count: u32) -> bool {
        match *self {
            Self::AtLeast(lo) => lo <= count,
            Self::AtMost(hi) => hi >= count,
            Self::Exactly(want) => want == count,
            // The zero check short-circuits before `%`, avoiding a divide-by-zero panic.
            Self::MultipleOf(step) => step != 0 && count % step == 0,
            Self::ModEquals { modulus, remainder } => {
                modulus != 0 && remainder < modulus && count % modulus == remainder
            }
            Self::AffineRange {
                scale,
                offset,
                min,
                max,
            } => {
                let wide = i128::from(scale) * i128::from(count) + i128::from(offset);
                // An inverted interval (`min > max`) is empty, so it never holds.
                (i128::from(min)..=i128::from(max)).contains(&wide)
            }
            Self::And(ref lhs, ref rhs) => lhs.holds(count) && rhs.holds(count),
            Self::Or(ref lhs, ref rhs) => lhs.holds(count) || rhs.holds(count),
            Self::Not(ref inner) => !inner.holds(count),
        }
    }

    /// Renders the predicate as a human-readable expression over `count`.
    ///
    /// Both operands of `&&` / `||`, and the operand of `!`, are wrapped in
    /// parentheses, so nesting stays unambiguous.
    #[must_use]
    pub fn describe(&self) -> String {
        /// Formats `(lhs) op (rhs)`.
        fn infix(lhs: &ShapePredicate, op: &str, rhs: &ShapePredicate) -> String {
            format!("({}) {op} ({})", lhs.describe(), rhs.describe())
        }

        match self {
            Self::AtLeast(lo) => format!("count >= {lo}"),
            Self::AtMost(hi) => format!("count <= {hi}"),
            Self::Exactly(want) => format!("count == {want}"),
            Self::MultipleOf(step) => format!("count % {step} == 0"),
            Self::ModEquals { modulus, remainder } => format!("count % {modulus} == {remainder}"),
            Self::AffineRange {
                scale,
                offset,
                min,
                max,
            } => format!("{min} <= count * {scale} + {offset} <= {max}"),
            Self::And(lhs, rhs) => infix(lhs, "&&", rhs),
            Self::Or(lhs, rhs) => infix(lhs, "||", rhs),
            Self::Not(inner) => format!("!({})", inner.describe()),
        }
    }
}
/// Declaration of a single buffer: name, binding slot, access mode, memory
/// kind, element type, plus a set of optional annotations.
///
/// Instances are built with the constructors in `impl BufferDecl`
/// (`storage`, `read`, `read_write`, `output`, `uniform`, `workgroup`) and
/// refined with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BufferDecl {
    /// Buffer name, stored as a shared `Arc<str>` so clones are cheap.
    pub name: Arc<str>,
    /// Binding index; the `workgroup` constructor always uses 0.
    pub binding: u32,
    /// Access mode requested for the buffer.
    pub access: BufferAccess,
    /// Memory kind. `storage` derives it from `access`; `with_kind` overrides it.
    pub kind: MemoryKind,
    /// Element data type.
    pub element: DataType,
    /// Element count. The bound-buffer constructors leave it at 0
    /// (NOTE(review): presumably "unspecified / runtime-sized" — confirm);
    /// set it via `with_count` or the `workgroup` constructor.
    pub count: u32,
    /// Output marker; `true` only when built via the `output` constructor.
    pub is_output: bool,
    /// Set to `true` by `output`, otherwise `false` unless changed via
    /// `with_pipeline_live_out`. NOTE(review): meaning (buffer stays live past
    /// this pipeline stage?) is inferred from the name — confirm.
    pub pipeline_live_out: bool,
    /// Optional byte range, `None` unless `with_output_byte_range` is used.
    /// NOTE(review): presumably the slice of the output that is read back — confirm.
    pub output_byte_range: Option<Range<usize>>,
    /// Memory hints; `MemoryHints::default()` unless overridden via `with_hints`.
    pub hints: MemoryHints,
    /// Bytes-extraction flag, `false` unless set via `with_bytes_extraction`;
    /// its semantics are defined by consumers not visible here.
    pub bytes_extraction: bool,
    /// Usage discipline; `LinearType::Unrestricted` by default.
    pub linear_type: LinearType,
    /// Optional count predicate, `None` by default. NOTE(review): presumably
    /// checked against `count` via `ShapePredicate::holds` — confirm at the use site.
    pub shape_predicate: Option<ShapePredicate>,
}
impl BufferDecl {
    /// Creates a declaration for a buffer with the given access mode.
    ///
    /// The memory kind is derived from `access`:
    /// `ReadOnly` → `Readonly`, `ReadWrite` → `Global`, `Uniform` → `Uniform`,
    /// `Workgroup` → `Shared`, and any other access mode → `Global`.
    ///
    /// Every annotation starts at its neutral value: `count` is 0, not an
    /// output, not pipeline-live-out, no output byte range, default hints,
    /// no bytes extraction, [`LinearType::Unrestricted`], no shape predicate.
    /// All other constructors route through here, so these defaults live in
    /// exactly one place.
    #[must_use]
    #[inline]
    pub fn storage(name: &str, binding: u32, access: BufferAccess, element: DataType) -> Self {
        let kind = match &access {
            BufferAccess::ReadOnly => MemoryKind::Readonly,
            BufferAccess::ReadWrite => MemoryKind::Global,
            BufferAccess::Uniform => MemoryKind::Uniform,
            BufferAccess::Workgroup => MemoryKind::Shared,
            // Any access mode not listed above is treated as global memory.
            _ => MemoryKind::Global,
        };
        Self {
            name: Arc::from(name),
            binding,
            access,
            kind,
            element,
            count: 0,
            is_output: false,
            pipeline_live_out: false,
            output_byte_range: None,
            hints: MemoryHints::default(),
            bytes_extraction: false,
            linear_type: LinearType::default(),
            shape_predicate: None,
        }
    }

    /// Shorthand for [`storage`](Self::storage) with `BufferAccess::ReadOnly`.
    #[must_use]
    #[inline]
    pub fn read(name: &str, binding: u32, element: DataType) -> Self {
        Self::storage(name, binding, BufferAccess::ReadOnly, element)
    }

    /// Shorthand for [`storage`](Self::storage) with `BufferAccess::ReadWrite`.
    #[must_use]
    #[inline]
    pub fn read_write(name: &str, binding: u32, element: DataType) -> Self {
        Self::storage(name, binding, BufferAccess::ReadWrite, element)
    }

    /// A read-write buffer marked as an output, with `pipeline_live_out` set to `true`.
    #[must_use]
    #[inline]
    pub fn output(name: &str, binding: u32, element: DataType) -> Self {
        Self {
            is_output: true,
            pipeline_live_out: true,
            ..Self::read_write(name, binding, element)
        }
    }

    /// Sets the `pipeline_live_out` flag.
    #[must_use]
    #[inline]
    pub fn with_pipeline_live_out(mut self, flag: bool) -> Self {
        self.pipeline_live_out = flag;
        self
    }

    /// Sets the output byte range (stored as given; not validated here).
    #[must_use]
    #[inline]
    pub fn with_output_byte_range(mut self, range: Range<usize>) -> Self {
        self.output_byte_range = Some(range);
        self
    }

    /// Sets the element count.
    #[must_use]
    #[inline]
    pub fn with_count(mut self, count: u32) -> Self {
        self.count = count;
        self
    }

    /// Shorthand for [`storage`](Self::storage) with `BufferAccess::Uniform`.
    #[must_use]
    #[inline]
    pub fn uniform(name: &str, binding: u32, element: DataType) -> Self {
        Self::storage(name, binding, BufferAccess::Uniform, element)
    }

    /// A `BufferAccess::Workgroup` declaration with `count` elements,
    /// binding 0 and [`MemoryKind::Shared`].
    ///
    /// Delegates to [`storage`](Self::storage), which already maps
    /// `Workgroup` to `Shared`, so the default for every other field stays
    /// in sync with the other constructors.
    #[must_use]
    #[inline]
    pub fn workgroup(name: &str, count: u32, element: DataType) -> Self {
        Self::storage(name, 0, BufferAccess::Workgroup, element).with_count(count)
    }

    /// Sets the bytes-extraction flag.
    #[must_use]
    #[inline]
    pub fn with_bytes_extraction(mut self, flag: bool) -> Self {
        self.bytes_extraction = flag;
        self
    }

    /// Sets the usage discipline.
    #[must_use]
    #[inline]
    pub fn with_linear_type(mut self, linear_type: LinearType) -> Self {
        self.linear_type = linear_type;
        self
    }

    /// Attaches a shape predicate, replacing any previous one.
    #[must_use]
    #[inline]
    pub fn with_shape_predicate(mut self, predicate: ShapePredicate) -> Self {
        self.shape_predicate = Some(predicate);
        self
    }

    /// Overrides the memory kind that the constructor derived from `access`.
    #[must_use]
    #[inline]
    pub fn with_kind(mut self, kind: MemoryKind) -> Self {
        self.kind = kind;
        self
    }

    /// Replaces the memory hints.
    #[must_use]
    #[inline]
    pub fn with_hints(mut self, hints: MemoryHints) -> Self {
        self.hints = hints;
        self
    }

    /// Buffer name.
    #[must_use]
    #[inline]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Binding index.
    #[must_use]
    #[inline]
    pub fn binding(&self) -> u32 {
        self.binding
    }

    /// Access mode (cloned).
    #[must_use]
    #[inline]
    pub fn access(&self) -> BufferAccess {
        self.access.clone()
    }

    /// Memory kind.
    #[must_use]
    #[inline]
    pub fn kind(&self) -> MemoryKind {
        self.kind
    }

    /// Memory hints.
    #[must_use]
    #[inline]
    pub fn hints(&self) -> MemoryHints {
        self.hints
    }

    /// Element data type (cloned).
    #[must_use]
    #[inline]
    pub fn element(&self) -> DataType {
        self.element.clone()
    }

    /// Element count.
    #[must_use]
    #[inline]
    pub fn count(&self) -> u32 {
        self.count
    }

    /// Whether this buffer is marked as an output.
    #[must_use]
    #[inline]
    pub fn is_output(&self) -> bool {
        self.is_output
    }

    /// Whether the `pipeline_live_out` flag is set.
    #[must_use]
    #[inline]
    pub fn is_pipeline_live_out(&self) -> bool {
        self.pipeline_live_out
    }

    /// Output byte range, if one was set (cloned).
    #[must_use]
    #[inline]
    pub fn output_byte_range(&self) -> Option<Range<usize>> {
        self.output_byte_range.clone()
    }

    /// Usage discipline.
    #[must_use]
    #[inline]
    pub fn linear_type(&self) -> LinearType {
        self.linear_type
    }

    /// Attached shape predicate, if any.
    #[must_use]
    #[inline]
    pub fn shape_predicate(&self) -> Option<&ShapePredicate> {
        self.shape_predicate.as_ref()
    }
}
#[cfg(test)]
mod linear_type_tests {
    use super::*;

    /// Every variant, in declaration order, for table-driven checks.
    const EVERY_VARIANT: [LinearType; 4] = [
        LinearType::Linear,
        LinearType::Affine,
        LinearType::Relevant,
        LinearType::Unrestricted,
    ];

    /// `(forbids_drop, forbids_reuse)` for a discipline.
    fn restrictions(lt: LinearType) -> (bool, bool) {
        (lt.forbids_drop(), lt.forbids_reuse())
    }

    #[test]
    fn default_is_unrestricted() {
        let decl = BufferDecl::read("a", 0, DataType::U32);
        assert_eq!(decl.linear_type(), LinearType::Unrestricted);
        assert_eq!(restrictions(LinearType::Unrestricted), (false, false));
    }

    #[test]
    fn linear_forbids_both() {
        assert_eq!(restrictions(LinearType::Linear), (true, true));
    }

    #[test]
    fn affine_forbids_only_reuse() {
        assert_eq!(restrictions(LinearType::Affine), (false, true));
    }

    #[test]
    fn relevant_forbids_only_drop() {
        assert_eq!(restrictions(LinearType::Relevant), (true, false));
    }

    #[test]
    fn with_linear_type_is_round_trip() {
        EVERY_VARIANT.iter().copied().for_each(|expected| {
            let decl = BufferDecl::read("a", 0, DataType::U32).with_linear_type(expected);
            assert_eq!(decl.linear_type(), expected);
        });
    }

    #[test]
    fn workgroup_constructor_defaults_to_unrestricted() {
        let scratch = BufferDecl::workgroup("scratch", 64, DataType::U32);
        assert_eq!(scratch.linear_type(), LinearType::Unrestricted);
    }
}
#[cfg(test)]
mod shape_predicate_tests {
    use super::*;

    #[test]
    fn at_least_holds_when_count_meets_minimum() {
        let p = ShapePredicate::AtLeast(64);
        assert!(p.holds(64));
        assert!(p.holds(128));
        assert!(!p.holds(32));
    }

    #[test]
    fn at_most_holds_when_count_within_bound() {
        let p = ShapePredicate::AtMost(64);
        assert!(p.holds(0));
        assert!(p.holds(64));
        assert!(!p.holds(65));
    }

    #[test]
    fn exactly_holds_only_for_match() {
        let p = ShapePredicate::Exactly(7);
        assert!(p.holds(7));
        assert!(!p.holds(6));
        assert!(!p.holds(8));
    }

    #[test]
    fn multiple_of_holds_for_aligned_count() {
        let p = ShapePredicate::MultipleOf(64);
        assert!(p.holds(0));
        assert!(p.holds(64));
        assert!(p.holds(128));
        assert!(!p.holds(63));
        assert!(!p.holds(65));
    }

    #[test]
    fn multiple_of_zero_never_holds() {
        let p = ShapePredicate::MultipleOf(0);
        assert!(!p.holds(0));
        assert!(!p.holds(64));
    }

    #[test]
    fn and_combines_two_predicates() {
        let p = ShapePredicate::And(
            Box::new(ShapePredicate::AtLeast(64)),
            Box::new(ShapePredicate::MultipleOf(32)),
        );
        assert!(p.holds(64));
        assert!(p.holds(96));
        // Aligned but below the minimum.
        assert!(!p.holds(32));
        // Above the minimum but not aligned.
        assert!(!p.holds(80));
    }

    #[test]
    fn or_accepts_either_predicate() {
        let p = ShapePredicate::Or(
            Box::new(ShapePredicate::Exactly(8)),
            Box::new(ShapePredicate::Exactly(16)),
        );
        assert!(p.holds(8));
        assert!(p.holds(16));
        assert!(!p.holds(12));
    }

    #[test]
    fn not_inverts_predicate() {
        let p = ShapePredicate::Not(Box::new(ShapePredicate::AtMost(64)));
        assert!(!p.holds(64));
        assert!(p.holds(65));
    }

    #[test]
    fn mod_equals_requires_valid_modular_form() {
        assert!(ShapePredicate::ModEquals {
            modulus: 16,
            remainder: 4,
        }
        .holds(20));
        assert!(!ShapePredicate::ModEquals {
            modulus: 16,
            remainder: 4,
        }
        .holds(21));
        assert!(!ShapePredicate::ModEquals {
            modulus: 0,
            remainder: 0,
        }
        .holds(0));
        assert!(!ShapePredicate::ModEquals {
            modulus: 4,
            remainder: 4,
        }
        .holds(4));
    }

    #[test]
    fn affine_range_uses_wide_arithmetic() {
        let p = ShapePredicate::AffineRange {
            scale: 4,
            offset: -8,
            min: 24,
            max: 40,
        };
        assert!(!p.holds(7));
        assert!(p.holds(8));
        assert!(p.holds(12));
        assert!(!p.holds(13));
        assert!(!ShapePredicate::AffineRange {
            scale: i64::MAX,
            offset: i64::MAX,
            min: i64::MIN,
            max: i64::MAX,
        }
        .holds(u32::MAX));
    }

    #[test]
    fn affine_range_with_inverted_bounds_never_holds() {
        let p = ShapePredicate::AffineRange {
            scale: 1,
            offset: 0,
            min: 5,
            max: 1,
        };
        assert!((0..10).all(|count| !p.holds(count)));
    }

    #[test]
    fn buffer_decl_default_shape_predicate_is_none() {
        let buf = BufferDecl::read("a", 0, DataType::U32);
        assert_eq!(buf.shape_predicate(), None);
    }

    #[test]
    fn with_shape_predicate_round_trip() {
        let buf = BufferDecl::read("a", 0, DataType::U32)
            .with_shape_predicate(ShapePredicate::MultipleOf(32));
        assert_eq!(buf.shape_predicate(), Some(&ShapePredicate::MultipleOf(32)));
    }

    #[test]
    fn describe_renders_human_readable() {
        assert_eq!(
            ShapePredicate::And(
                Box::new(ShapePredicate::AtLeast(64)),
                Box::new(ShapePredicate::MultipleOf(32)),
            )
            .describe(),
            "(count >= 64) && (count % 32 == 0)"
        );
    }

    #[test]
    fn describe_renders_leaf_variants() {
        assert_eq!(ShapePredicate::AtMost(3).describe(), "count <= 3");
        assert_eq!(ShapePredicate::Exactly(5).describe(), "count == 5");
        assert_eq!(
            ShapePredicate::ModEquals {
                modulus: 16,
                remainder: 4,
            }
            .describe(),
            "count % 16 == 4"
        );
        assert_eq!(
            ShapePredicate::AffineRange {
                scale: 2,
                offset: 1,
                min: 0,
                max: 9,
            }
            .describe(),
            "0 <= count * 2 + 1 <= 9"
        );
    }

    #[test]
    fn describe_renders_or_and_not() {
        assert_eq!(
            ShapePredicate::Or(
                Box::new(ShapePredicate::Exactly(8)),
                Box::new(ShapePredicate::Exactly(16)),
            )
            .describe(),
            "(count == 8) || (count == 16)"
        );
        assert_eq!(
            ShapePredicate::Not(Box::new(ShapePredicate::AtMost(64))).describe(),
            "!(count <= 64)"
        );
        assert_eq!(
            ShapePredicate::Not(Box::new(ShapePredicate::And(
                Box::new(ShapePredicate::AtLeast(1)),
                Box::new(ShapePredicate::AtMost(2)),
            )))
            .describe(),
            "!((count >= 1) && (count <= 2))"
        );
    }
}