use crate::ir_inner::model::types::DataType;
use rustc_hash::FxHasher;
use std::borrow::Borrow;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::sync::Arc;
/// Identifies a generator by its `name`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct GeneratorRef {
    /// The generator's name; equality and hashing are based on this field.
    pub name: String,
}
/// Cheaply clonable identifier.
///
/// The text lives behind an `Arc<str>`, so clones share one allocation. `hash` holds a
/// hash of `text` that is computed once when the identifier is built.
#[derive(Clone, Eq)]
pub struct Ident {
    text: Arc<str>,
    hash: u64,
}

impl PartialEq for Ident {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        // Check the cached hash first. Most unequal identifiers are then rejected with a
        // single integer compare instead of a byte-wise string compare. Semantics match
        // a derived impl: two identifiers are equal only when both fields are equal.
        self.hash == other.hash && self.text == other.text
    }
}
impl Ident {
    /// Hashes `text` with `FxHasher`; `new` stores the result in the `hash` field.
    #[inline]
    fn prehash(text: &str) -> u64 {
        let mut state = FxHasher::default();
        Hash::hash(text, &mut state);
        state.finish()
    }

    /// Builds an identifier from already-shared text, hashing it once up front.
    #[must_use]
    #[inline]
    pub fn new(text: Arc<str>) -> Self {
        Self {
            hash: Self::prehash(&text),
            text,
        }
    }

    /// Returns another `Arc` handle to the same text allocation.
    #[must_use]
    #[inline]
    pub fn shared_text(&self) -> Arc<str> {
        Arc::clone(&self.text)
    }

    /// Borrows the identifier text.
    #[must_use]
    #[inline]
    pub fn as_str(&self) -> &str {
        &*self.text
    }

    /// Returns the hash stored at construction, without rehashing.
    #[must_use]
    #[inline]
    pub fn cached_hash(&self) -> u64 {
        let precomputed = self.hash;
        precomputed
    }
}
impl From<&str> for Ident {
#[inline]
fn from(value: &str) -> Self {
Self::new(Arc::from(value))
}
}
impl From<String> for Ident {
#[inline]
fn from(value: String) -> Self {
Self::new(Arc::from(value))
}
}
impl From<Arc<str>> for Ident {
#[inline]
fn from(value: Arc<str>) -> Self {
Self::new(value)
}
}
impl From<&String> for Ident {
#[inline]
fn from(value: &String) -> Self {
Self::from(value.as_str())
}
}
impl From<&Ident> for Ident {
#[inline]
fn from(value: &Ident) -> Self {
value.clone()
}
}
/// Debug-formats as a tuple struct holding only the text, e.g. `Ident("x")`;
/// the cached hash is deliberately omitted.
impl fmt::Debug for Ident {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text: &str = self.as_str();
        let mut tuple = f.debug_tuple("Ident");
        tuple.field(&text);
        tuple.finish()
    }
}
/// Hashes only the text, exactly as `str` does. This must stay in step with `str`'s
/// `Hash`, because `Ident` implements `Borrow<str>` and hashed collections keyed by
/// `Ident` may be queried with a plain `&str`. The cached FxHash is therefore not fed
/// into `state`.
impl Hash for Ident {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        <str as Hash>::hash(&self.text, state);
    }
}
impl Deref for Ident {
type Target = str;
#[inline]
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
impl AsRef<str> for Ident {
#[inline]
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl Borrow<str> for Ident {
#[inline]
fn borrow(&self) -> &str {
self.as_str()
}
}
/// Displays the identifier text the same way `str` does. Width, fill, alignment and
/// precision flags are honoured, so `{:<8}` pads and `{:.3}` truncates.
impl fmt::Display for Ident {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` applies the formatter's flags; `write_str` would silently ignore them.
        f.pad(self.as_str())
    }
}
/// Compares the identifier text with a string slice.
impl PartialEq<str> for Ident {
    #[inline]
    fn eq(&self, other: &str) -> bool {
        other == self.as_str()
    }
}

/// Allows `ident == "literal"` without explicit dereferencing.
impl PartialEq<&str> for Ident {
    #[inline]
    fn eq(&self, other: &&str) -> bool {
        <Self as PartialEq<str>>::eq(self, other)
    }
}
/// Total order consistent with `Ord`: always `Some`.
impl PartialOrd for Ident {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Orders identifiers by their text, like `str`; the cached hash plays no part.
impl Ord for Ident {
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let (lhs, rhs) = (self.as_str(), other.as_str());
        lhs.cmp(rhs)
    }
}
pub use crate::ir_inner::model::generated::Expr;
/// Behaviour required of custom expression nodes. `Expr::opaque` and
/// `Expr::opaque_arc` store them as `Arc<dyn ExprNode>`, so this trait must stay
/// object-safe.
pub trait ExprNode: fmt::Debug + Send + Sync + 'static {
    /// Static tag naming the kind of extension (presumably used for dispatch/diagnostics).
    fn extension_kind(&self) -> &'static str;

    /// Human-readable identity of this node; presumably for debugging output.
    fn debug_identity(&self) -> &str;

    /// Type the node evaluates to, if known. NOTE(review): meaning of `None` not
    /// visible here — confirm with consumers.
    fn result_type(&self) -> Option<DataType>;

    /// Whether the node may be shared by common-subexpression elimination
    /// (assumed from the name — confirm with the optimizer).
    fn cse_safe(&self) -> bool;

    /// 32-byte fingerprint of the node; presumably expected to be stable across
    /// runs — confirm.
    fn stable_fingerprint(&self) -> [u8; 32];

    /// Checks the node's own invariants; `Err` carries a description of the problem.
    fn validate_extension(&self) -> Result<(), String>;

    /// Exposes the concrete value as `Any`, allowing callers to downcast.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Extra bytes attached to the node's serialized form. The default is empty.
    fn wire_payload(&self) -> Vec<u8> {
        vec![]
    }
}
impl Expr {
    // ----- buffer access -------------------------------------------------------

    /// Builds an [`Expr::Load`] of `buffer` at `index`.
    #[must_use]
    #[inline]
    pub fn load(buffer: impl Into<Ident>, index: Self) -> Self {
        let buffer = buffer.into();
        let index = Box::new(index);
        Self::Load { buffer, index }
    }

    /// Builds an [`Expr::BufLen`] for `buffer`.
    #[must_use]
    #[inline]
    pub fn buf_len(buffer: impl Into<Ident>) -> Self {
        let buffer = buffer.into();
        Self::BufLen { buffer }
    }

    // ----- invocation ids (axes 0/1/2 = x/y/z) ---------------------------------

    /// `InvocationId` on axis 0.
    #[must_use]
    #[inline]
    pub fn gid_x() -> Self {
        Self::InvocationId { axis: 0 }
    }

    /// `InvocationId` on axis 1.
    #[must_use]
    #[inline]
    pub fn gid_y() -> Self {
        Self::InvocationId { axis: 1 }
    }

    /// `InvocationId` on axis 2.
    #[must_use]
    #[inline]
    pub fn gid_z() -> Self {
        Self::InvocationId { axis: 2 }
    }

    // ----- workgroup ids and their `parallel_region_*` aliases -----------------

    /// `WorkgroupId` on axis 0.
    #[must_use]
    #[inline]
    pub fn workgroup_x() -> Self {
        Self::WorkgroupId { axis: 0 }
    }

    /// `WorkgroupId` on axis 1.
    #[must_use]
    #[inline]
    pub fn workgroup_y() -> Self {
        Self::WorkgroupId { axis: 1 }
    }

    /// `WorkgroupId` on axis 2.
    #[must_use]
    #[inline]
    pub fn workgroup_z() -> Self {
        Self::WorkgroupId { axis: 2 }
    }

    /// Alias of [`Expr::workgroup_x`].
    #[must_use]
    #[inline]
    pub fn parallel_region_x() -> Self {
        Self::workgroup_x()
    }

    /// Alias of [`Expr::workgroup_y`].
    #[must_use]
    #[inline]
    pub fn parallel_region_y() -> Self {
        Self::workgroup_y()
    }

    /// Alias of [`Expr::workgroup_z`].
    #[must_use]
    #[inline]
    pub fn parallel_region_z() -> Self {
        Self::workgroup_z()
    }

    // ----- local ids and their `invocation_local_*` aliases --------------------

    /// `LocalId` on axis 0.
    #[must_use]
    #[inline]
    pub fn local_x() -> Self {
        Self::LocalId { axis: 0 }
    }

    /// `LocalId` on axis 1.
    #[must_use]
    #[inline]
    pub fn local_y() -> Self {
        Self::LocalId { axis: 1 }
    }

    /// `LocalId` on axis 2.
    #[must_use]
    #[inline]
    pub fn local_z() -> Self {
        Self::LocalId { axis: 2 }
    }

    /// Alias of [`Expr::local_x`].
    #[must_use]
    #[inline]
    pub fn invocation_local_x() -> Self {
        Self::local_x()
    }

    /// Alias of [`Expr::local_y`].
    #[must_use]
    #[inline]
    pub fn invocation_local_y() -> Self {
        Self::local_y()
    }

    /// Alias of [`Expr::local_z`].
    #[must_use]
    #[inline]
    pub fn invocation_local_z() -> Self {
        Self::local_z()
    }

    // ----- subgroup built-ins and operations -----------------------------------

    /// [`Expr::SubgroupLocalId`].
    #[must_use]
    #[inline]
    pub fn subgroup_local_id() -> Self {
        Self::SubgroupLocalId
    }

    /// [`Expr::SubgroupSize`].
    #[must_use]
    #[inline]
    pub fn subgroup_size() -> Self {
        Self::SubgroupSize
    }

    /// Builds an [`Expr::SubgroupAdd`] over `value`.
    #[must_use]
    #[inline]
    pub fn subgroup_add(value: Self) -> Self {
        let value = Box::new(value);
        Self::SubgroupAdd { value }
    }

    /// Builds an [`Expr::SubgroupShuffle`] of `value` from `lane`.
    #[must_use]
    #[inline]
    pub fn subgroup_shuffle(value: Self, lane: Self) -> Self {
        let (value, lane) = (Box::new(value), Box::new(lane));
        Self::SubgroupShuffle { value, lane }
    }

    /// Builds an [`Expr::SubgroupBallot`] over `cond`.
    #[must_use]
    #[inline]
    pub fn subgroup_ballot(cond: Self) -> Self {
        let cond = Box::new(cond);
        Self::SubgroupBallot { cond }
    }

    // ----- leaves: variables and literals --------------------------------------

    /// Reference to the variable `name`.
    #[must_use]
    #[inline]
    pub fn var(name: impl Into<Ident>) -> Self {
        let name: Ident = name.into();
        Self::Var(name)
    }

    /// `u32` literal.
    #[must_use]
    #[inline]
    pub fn u32(value: u32) -> Self {
        Self::LitU32(value)
    }

    /// `i32` literal.
    #[must_use]
    #[inline]
    pub fn i32(value: i32) -> Self {
        Self::LitI32(value)
    }

    /// `f32` literal.
    #[must_use]
    #[inline]
    pub fn f32(value: f32) -> Self {
        Self::LitF32(value)
    }

    /// `bool` literal.
    #[must_use]
    #[inline]
    pub fn bool(value: bool) -> Self {
        Self::LitBool(value)
    }

    // ----- compound expressions ------------------------------------------------

    /// Builds an [`Expr::Select`] choosing `true_val` or `false_val` by `cond`.
    #[must_use]
    #[inline]
    pub fn select(cond: Self, true_val: Self, false_val: Self) -> Self {
        let (cond, true_val, false_val) =
            (Box::new(cond), Box::new(true_val), Box::new(false_val));
        Self::Select {
            cond,
            true_val,
            false_val,
        }
    }

    /// Builds an [`Expr::Call`] of `op_id` with `args`.
    #[must_use]
    #[inline]
    pub fn call(op_id: impl Into<Ident>, args: Vec<Self>) -> Self {
        let op_id = op_id.into();
        Self::Call { op_id, args }
    }

    /// Builds an [`Expr::Fma`] from operands `a`, `b`, `c`.
    #[must_use]
    #[inline]
    pub fn fma(a: Self, b: Self, c: Self) -> Self {
        let (a, b, c) = (Box::new(a), Box::new(b), Box::new(c));
        Self::Fma { a, b, c }
    }

    /// Builds an [`Expr::Cast`] of `value` to `target`.
    #[must_use]
    #[inline]
    pub fn cast(target: DataType, value: Self) -> Self {
        let value = Box::new(value);
        Self::Cast { target, value }
    }

    // ----- extension nodes -----------------------------------------------------

    /// Wraps a custom [`ExprNode`] in [`Expr::Opaque`].
    #[must_use]
    #[inline]
    pub fn opaque(node: impl ExprNode) -> Self {
        let shared: Arc<dyn ExprNode> = Arc::new(node);
        Self::opaque_arc(shared)
    }

    /// Wraps an already-shared [`ExprNode`] in [`Expr::Opaque`].
    #[must_use]
    #[inline]
    pub fn opaque_arc(node: Arc<dyn ExprNode>) -> Self {
        Self::Opaque(node)
    }
}
mod atomics;
mod builders;
#[cfg(test)]
mod tests {
    use super::Expr;

    /// Largest `size_of::<Expr>()` this test tolerates, in bytes.
    const MAX_EXPR_BYTES: usize = 128;

    /// Guards against `Expr` silently growing when new variants or fields are added.
    #[test]
    fn expr_size_is_bounded() {
        let size = std::mem::size_of::<Expr>();
        eprintln!("Expr size: {size}");
        let within_budget = size <= MAX_EXPR_BYTES;
        assert!(
            within_budget,
            "Expr grew to {size} bytes. Fix: box the largest variant before adding more fields."
        );
    }
}