use crate::ir::model::types::{AtomicOp, BinOp, DataType, UnOp};
use std::borrow::Borrow;
use std::fmt;
use std::ops::Deref;
use std::sync::Arc;
/// Cheaply clonable, immutable name used by the IR for variables and buffers.
///
/// The text lives in an `Arc<str>`, so cloning an `Ident` only bumps a
/// reference count. `PartialEq`, `Eq` and `Hash` are derived from the wrapped
/// string, which keeps them consistent with the `Borrow<str>` impl below: a
/// `HashMap<Ident, _>` can be queried directly with a `&str`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Ident(Arc<str>);

impl Ident {
    /// Returns the identifier text as a string slice.
    #[must_use]
    #[inline]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

/// Copies a string slice into a new shared `Arc<str>`.
impl From<&str> for Ident {
    #[inline]
    fn from(value: &str) -> Self {
        Self(Arc::from(value))
    }
}

/// Converts an owned `String` into a shared `Arc<str>`.
impl From<String> for Ident {
    #[inline]
    fn from(value: String) -> Self {
        Self(Arc::from(value))
    }
}

/// Exposes the identifier as a `str`, so `str` methods can be called on it.
impl Deref for Ident {
    type Target = str;
    #[inline]
    fn deref(&self) -> &Self::Target {
        self.as_str()
    }
}

/// Lets an `Ident` be passed to APIs taking `impl AsRef<str>`.
impl AsRef<str> for Ident {
    #[inline]
    fn as_ref(&self) -> &str {
        self.as_str()
    }
}

/// Allows `&str` lookups in maps/sets keyed by `Ident`. This is sound because
/// `Eq` and `Hash` on `Ident` are derived from the same wrapped string.
impl Borrow<str> for Ident {
    #[inline]
    fn borrow(&self) -> &str {
        self.as_str()
    }
}

/// Formats the identifier text exactly like a `&str` would, including
/// width, fill, alignment and precision flags (e.g. `{:>8}`, `{:.3}`).
impl fmt::Display for Ident {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` applies the formatter's flags; `write_str` would ignore them
        // and make `format!("{:<8}", ident)` misalign in tabular output.
        f.pad(self.as_str())
    }
}

/// `ident == *"name"` compares the identifier text.
impl PartialEq<str> for Ident {
    #[inline]
    fn eq(&self, other: &str) -> bool {
        self.as_str() == other
    }
}

/// `ident == "name"` compares the identifier text.
impl PartialEq<&str> for Ident {
    #[inline]
    fn eq(&self, other: &&str) -> bool {
        self.as_str() == *other
    }
}

/// Mirror of `PartialEq<str> for Ident`, so `*"name" == ident` also compiles.
impl PartialEq<Ident> for str {
    #[inline]
    fn eq(&self, other: &Ident) -> bool {
        self == other.as_str()
    }
}

/// Mirror of `PartialEq<&str> for Ident`, so `"name" == ident` also compiles.
impl PartialEq<Ident> for &str {
    #[inline]
    fn eq(&self, other: &Ident) -> bool {
        *self == other.as_str()
    }
}
/// Expression node of the IR.
///
/// Recursive children are held in `Box`es; `tests::expr_size_is_bounded`
/// asserts that the enum stays at or below 128 bytes. Only `PartialEq` is
/// derived (no `Eq`) because `LitF32` carries an `f32`. The enum is
/// `#[non_exhaustive]`, so `match`es outside this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
/// `u32` literal.
LitU32(u32),
/// `i32` literal.
LitI32(i32),
/// `f32` literal. Equality follows `f32`'s `PartialEq`, so a `NaN` literal
/// never compares equal to itself.
LitF32(f32),
/// `bool` literal.
LitBool(bool),
/// Reference to a variable by name.
Var(Ident),
/// Read from the named `buffer` at `index`.
Load {
buffer: Ident,
index: Box<Expr>,
},
/// Length of the named `buffer` (NOTE(review): presumably in elements —
/// confirm against the backend lowering).
BufLen {
buffer: Ident,
},
/// Invocation id component. The `gid_x`/`gid_y`/`gid_z` constructors use
/// `axis` 0/1/2; nothing in this file rejects other values.
InvocationId {
axis: u8,
},
/// Workgroup id component; `workgroup_x`/`_y`/`_z` use `axis` 0/1/2.
WorkgroupId {
axis: u8,
},
/// Local id component; `local_x`/`_y`/`_z` use `axis` 0/1/2.
LocalId {
axis: u8,
},
/// Binary operation `op` applied to `left` and `right`.
BinOp {
op: BinOp,
left: Box<Expr>,
right: Box<Expr>,
},
/// Unary operation `op` applied to `operand`.
UnOp {
op: UnOp,
operand: Box<Expr>,
},
/// Call of the operation identified by `op_id` with `args`. Note that
/// `op_id` is a plain `String`, unlike the `Ident` used for other names.
Call {
op_id: String,
args: Vec<Expr>,
},
/// Picks `true_val` or `false_val` depending on `cond`.
/// NOTE(review): whether both arms are evaluated is not established here.
Select {
cond: Box<Expr>,
true_val: Box<Expr>,
false_val: Box<Expr>,
},
/// Conversion of `value` to the `target` data type.
Cast {
target: DataType,
value: Box<Expr>,
},
/// Fused multiply-add over `a`, `b`, `c` (presumably `a * b + c` — confirm
/// against the lowering).
Fma {
a: Box<Expr>,
b: Box<Expr>,
c: Box<Expr>,
},
/// Atomic `op` on element `index` of `buffer`, with operand `value`.
Atomic {
op: AtomicOp,
buffer: Ident,
index: Box<Expr>,
/// Optional comparand; presumably only set for compare-exchange style
/// ops — see the `atomics` submodule to confirm.
expected: Option<Box<Expr>>,
value: Box<Expr>,
},
}
impl Expr {
#[must_use]
#[inline]
pub fn load(buffer: &str, index: Self) -> Self {
Self::Load {
buffer: Ident::from(buffer),
index: Box::new(index),
}
}
#[must_use]
#[inline]
pub fn buf_len(buffer: &str) -> Self {
Self::BufLen {
buffer: Ident::from(buffer),
}
}
#[must_use]
#[inline]
pub fn gid_x() -> Self {
Self::InvocationId { axis: 0 }
}
#[must_use]
#[inline]
pub fn gid_y() -> Self {
Self::InvocationId { axis: 1 }
}
#[must_use]
#[inline]
pub fn gid_z() -> Self {
Self::InvocationId { axis: 2 }
}
#[must_use]
#[inline]
pub fn workgroup_x() -> Self {
Self::WorkgroupId { axis: 0 }
}
#[must_use]
#[inline]
pub fn workgroup_y() -> Self {
Self::WorkgroupId { axis: 1 }
}
#[must_use]
#[inline]
pub fn workgroup_z() -> Self {
Self::WorkgroupId { axis: 2 }
}
#[must_use]
#[inline]
pub fn local_x() -> Self {
Self::LocalId { axis: 0 }
}
#[must_use]
#[inline]
pub fn local_y() -> Self {
Self::LocalId { axis: 1 }
}
#[must_use]
#[inline]
pub fn local_z() -> Self {
Self::LocalId { axis: 2 }
}
#[must_use]
#[inline]
pub fn select(cond: Self, true_val: Self, false_val: Self) -> Self {
Self::Select {
cond: Box::new(cond),
true_val: Box::new(true_val),
false_val: Box::new(false_val),
}
}
#[must_use]
#[inline]
pub fn var(name: &str) -> Self {
Self::Var(Ident::from(name))
}
#[must_use]
#[inline]
pub fn u32(value: u32) -> Self {
Self::LitU32(value)
}
#[must_use]
#[inline]
pub fn i32(value: i32) -> Self {
Self::LitI32(value)
}
#[must_use]
#[inline]
pub fn f32(value: f32) -> Self {
Self::LitF32(value)
}
#[must_use]
#[inline]
pub fn bool(value: bool) -> Self {
Self::LitBool(value)
}
#[must_use]
#[inline]
pub fn call(op_id: &str, args: Vec<Self>) -> Self {
Self::Call {
op_id: op_id.to_string(),
args,
}
}
#[must_use]
#[inline]
pub fn fma(a: Self, b: Self, c: Self) -> Self {
Self::Fma {
a: Box::new(a),
b: Box::new(b),
c: Box::new(c),
}
}
#[must_use]
#[inline]
pub fn cast(target: DataType, value: Self) -> Self {
Self::Cast {
target,
value: Box::new(value),
}
}
}
mod atomics;
mod builders;
#[cfg(test)]
mod tests {
    use super::Expr;

    /// Largest size, in bytes, that `Expr` is allowed to reach.
    const MAX_EXPR_BYTES: usize = 128;

    /// Fails once `Expr` grows past `MAX_EXPR_BYTES`; the measured size is
    /// written to stderr so a regression is easy to diagnose.
    #[test]
    fn expr_size_is_bounded() {
        let size = std::mem::size_of::<Expr>();
        eprintln!("Expr size: {size}");
        let within_budget = size <= MAX_EXPR_BYTES;
        assert!(
            within_budget,
            "Expr grew to {size} bytes. Fix: box the largest variant before adding more fields."
        );
    }
}