use core::fmt::Debug;
/// Identifier for a data type contributed by an extension (see [`ExtensionDataType`]).
///
/// The wrapped `u32` is public, so any value can be constructed directly; IDs
/// produced by [`ExtensionDataTypeId::from_name`] always carry
/// [`ExtensionDataTypeId::EXTENSION_RANGE_MASK`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct ExtensionDataTypeId(pub u32);
impl ExtensionDataTypeId {
    /// Single bit (bit 31) marking an ID as belonging to the extension range.
    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;

    /// Derives an ID from `name` via 32-bit FNV-1a with bit 31 forced on.
    ///
    /// Deterministic and usable in `const` context. Distinct names can collide;
    /// no collision detection happens here.
    #[must_use]
    pub const fn from_name(name: &str) -> Self {
        let raw = fnv1a_with_high_bit(name);
        ExtensionDataTypeId(raw)
    }

    /// Returns the raw `u32` value of this ID.
    #[must_use]
    pub const fn as_u32(self) -> u32 {
        let raw = self.0;
        raw
    }

    /// Returns `true` when the extension-range bit is set.
    #[must_use]
    pub const fn is_extension(self) -> bool {
        // The mask is a single bit, so the masked value is either 0 or the mask itself.
        let masked = self.as_u32() & Self::EXTENSION_RANGE_MASK;
        masked == Self::EXTENSION_RANGE_MASK
    }
}
/// A data type contributed by an extension.
///
/// The `Send + Sync + Debug + 'static` bounds let trait objects of implementors be
/// shared across threads and printed for diagnostics.
pub trait ExtensionDataType: Send + Sync + Debug + 'static {
    /// Human-readable name of this data type.
    fn display_name(&self) -> &'static str;

    /// Identifier of this data type.
    fn id(&self) -> ExtensionDataTypeId;

    /// Exact size in bytes, if any.
    ///
    /// NOTE(review): `None` presumably means "variable-sized" — confirm with implementors.
    fn size_bytes(&self) -> Option<usize>;

    /// Lower bound on the size in bytes.
    fn min_bytes(&self) -> usize;

    /// Upper bound on the size in bytes, if any.
    ///
    /// NOTE(review): `None` presumably means "unbounded" — confirm with implementors.
    fn max_bytes(&self) -> Option<usize>;

    /// Whether this type belongs to the floating-point family. Defaults to `false`.
    fn is_float_family(&self) -> bool {
        false
    }

    /// Whether this type is host-shareable. Defaults to `true`.
    fn is_host_shareable(&self) -> bool {
        true
    }
}
/// A binary operator contributed by an extension.
pub trait ExtensionBinOp: Send + Sync + Debug + 'static {
    /// Human-readable name of this operator.
    fn display_name(&self) -> &'static str;

    /// Identifier of this operator.
    fn id(&self) -> ExtensionBinOpId;

    /// Applies the operator to two `u32` operands, when the extension supports it.
    ///
    /// The default implementation returns `None` for every input; implementors
    /// override it to supply a result. NOTE(review): presumably a host-side
    /// evaluation hook (e.g. constant folding) — confirm with callers.
    fn eval_u32(&self, _a: u32, _b: u32) -> Option<u32> {
        None
    }
}
/// A unary operator contributed by an extension.
pub trait ExtensionUnOp: Send + Sync + Debug + 'static {
    /// Human-readable name of this operator.
    fn display_name(&self) -> &'static str;

    /// Identifier of this operator.
    fn id(&self) -> ExtensionUnOpId;

    /// Applies the operator to a single `u32` operand, when the extension supports it.
    ///
    /// The default implementation returns `None` for every input; implementors
    /// override it to supply a result. NOTE(review): presumably a host-side
    /// evaluation hook (e.g. constant folding) — confirm with callers.
    fn eval_u32(&self, _a: u32) -> Option<u32> {
        None
    }
}
/// An atomic operation contributed by an extension.
pub trait ExtensionAtomicOp: Send + Sync + Debug + 'static {
    /// Human-readable name of this atomic operation.
    fn display_name(&self) -> &'static str;

    /// Identifier of this atomic operation.
    fn id(&self) -> ExtensionAtomicOpId;
}
/// A ternary operator contributed by an extension.
pub trait ExtensionTernaryOp: Send + Sync + Debug + 'static {
    /// Human-readable name of this operator.
    fn display_name(&self) -> &'static str;

    /// Identifier of this operator.
    fn id(&self) -> ExtensionTernaryOpId;
}
/// Identifier for a binary operator contributed by an extension (see [`ExtensionBinOp`]).
///
/// IDs produced by [`ExtensionBinOpId::from_name`] always carry
/// [`ExtensionBinOpId::EXTENSION_RANGE_MASK`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct ExtensionBinOpId(pub u32);
impl ExtensionBinOpId {
    /// Single bit (bit 31) marking an ID as belonging to the extension range.
    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;

    /// Derives an ID from `name` via 32-bit FNV-1a with bit 31 forced on.
    ///
    /// Deterministic and usable in `const` context. Distinct names can collide;
    /// no collision detection happens here.
    #[must_use]
    pub const fn from_name(name: &str) -> Self {
        Self(fnv1a_with_high_bit(name))
    }

    /// Returns the raw `u32` value of this ID.
    #[must_use]
    pub const fn as_u32(self) -> u32 {
        self.0
    }

    /// Returns `true` when the extension-range bit is set.
    ///
    /// Mirrors `ExtensionDataTypeId::is_extension` so every ID kind can be queried
    /// the same way.
    #[must_use]
    pub const fn is_extension(self) -> bool {
        (self.0 & Self::EXTENSION_RANGE_MASK) != 0
    }
}
/// Identifier for a unary operator contributed by an extension (see [`ExtensionUnOp`]).
///
/// IDs produced by [`ExtensionUnOpId::from_name`] always carry
/// [`ExtensionUnOpId::EXTENSION_RANGE_MASK`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct ExtensionUnOpId(pub u32);
impl ExtensionUnOpId {
    /// Single bit (bit 31) marking an ID as belonging to the extension range.
    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;

    /// Derives an ID from `name` via 32-bit FNV-1a with bit 31 forced on.
    ///
    /// Deterministic and usable in `const` context. Distinct names can collide;
    /// no collision detection happens here.
    #[must_use]
    pub const fn from_name(name: &str) -> Self {
        Self(fnv1a_with_high_bit(name))
    }

    /// Returns the raw `u32` value of this ID.
    #[must_use]
    pub const fn as_u32(self) -> u32 {
        self.0
    }

    /// Returns `true` when the extension-range bit is set.
    ///
    /// Mirrors `ExtensionDataTypeId::is_extension` so every ID kind can be queried
    /// the same way.
    #[must_use]
    pub const fn is_extension(self) -> bool {
        (self.0 & Self::EXTENSION_RANGE_MASK) != 0
    }
}
/// Identifier for an atomic operation contributed by an extension (see [`ExtensionAtomicOp`]).
///
/// IDs produced by [`ExtensionAtomicOpId::from_name`] always carry
/// [`ExtensionAtomicOpId::EXTENSION_RANGE_MASK`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct ExtensionAtomicOpId(pub u32);
impl ExtensionAtomicOpId {
    /// Single bit (bit 31) marking an ID as belonging to the extension range.
    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;

    /// Derives an ID from `name` via 32-bit FNV-1a with bit 31 forced on.
    ///
    /// Deterministic and usable in `const` context. Distinct names can collide;
    /// no collision detection happens here.
    #[must_use]
    pub const fn from_name(name: &str) -> Self {
        Self(fnv1a_with_high_bit(name))
    }

    /// Returns the raw `u32` value of this ID.
    #[must_use]
    pub const fn as_u32(self) -> u32 {
        self.0
    }

    /// Returns `true` when the extension-range bit is set.
    ///
    /// Mirrors `ExtensionDataTypeId::is_extension` so every ID kind can be queried
    /// the same way.
    #[must_use]
    pub const fn is_extension(self) -> bool {
        (self.0 & Self::EXTENSION_RANGE_MASK) != 0
    }
}
/// Identifier for a ternary operator contributed by an extension (see [`ExtensionTernaryOp`]).
///
/// IDs produced by [`ExtensionTernaryOpId::from_name`] always carry
/// [`ExtensionTernaryOpId::EXTENSION_RANGE_MASK`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct ExtensionTernaryOpId(pub u32);
impl ExtensionTernaryOpId {
    /// Single bit (bit 31) marking an ID as belonging to the extension range.
    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;

    /// Derives an ID from `name` via 32-bit FNV-1a with bit 31 forced on.
    ///
    /// Deterministic and usable in `const` context. Distinct names can collide;
    /// no collision detection happens here.
    #[must_use]
    pub const fn from_name(name: &str) -> Self {
        Self(fnv1a_with_high_bit(name))
    }

    /// Returns the raw `u32` value of this ID.
    #[must_use]
    pub const fn as_u32(self) -> u32 {
        self.0
    }

    /// Returns `true` when the extension-range bit is set.
    ///
    /// Mirrors `ExtensionDataTypeId::is_extension` so every ID kind can be queried
    /// the same way.
    #[must_use]
    pub const fn is_extension(self) -> bool {
        (self.0 & Self::EXTENSION_RANGE_MASK) != 0
    }
}
/// Identifier for an extension-provided rule condition.
///
/// IDs produced by [`ExtensionRuleConditionId::from_name`] always carry
/// [`ExtensionRuleConditionId::EXTENSION_RANGE_MASK`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct ExtensionRuleConditionId(pub u32);
impl ExtensionRuleConditionId {
    /// Single bit (bit 31) marking an ID as belonging to the extension range.
    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;

    /// Derives an ID from `name` via 32-bit FNV-1a with bit 31 forced on.
    ///
    /// Deterministic and usable in `const` context. Distinct names can collide;
    /// no collision detection happens here.
    #[must_use]
    pub const fn from_name(name: &str) -> Self {
        Self(fnv1a_with_high_bit(name))
    }

    /// Returns the raw `u32` value of this ID.
    #[must_use]
    pub const fn as_u32(self) -> u32 {
        self.0
    }

    /// Returns `true` when the extension-range bit is set.
    ///
    /// Mirrors `ExtensionDataTypeId::is_extension` so every ID kind can be queried
    /// the same way.
    #[must_use]
    pub const fn is_extension(self) -> bool {
        (self.0 & Self::EXTENSION_RANGE_MASK) != 0
    }
}
/// Hashes the UTF-8 bytes of `name` with 32-bit FNV-1a, then sets bit 31.
///
/// `const` so that IDs can be computed at compile time. Forcing bit 31 places every
/// result in the extension range, leaving 31 bits of hash entropy.
#[must_use]
const fn fnv1a_with_high_bit(name: &str) -> u32 {
    // Standard 32-bit FNV-1a parameters.
    const OFFSET_BASIS: u32 = 0x811c_9dc5;
    const PRIME: u32 = 0x0100_0193;

    let mut remaining = name.as_bytes();
    let mut acc = OFFSET_BASIS;
    // Iterators are not usable in `const fn`, so peel one byte off the slice per step.
    while let [byte, rest @ ..] = remaining {
        acc = (acc ^ *byte as u32).wrapping_mul(PRIME);
        remaining = rest;
    }
    acc | 0x8000_0000
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Straightforward (non-const) 32-bit FNV-1a used as an independent oracle.
    fn reference_fnv1a(name: &str) -> u32 {
        let mut hash: u32 = 0x811c_9dc5;
        for &byte in name.as_bytes() {
            hash ^= u32::from(byte);
            hash = hash.wrapping_mul(0x0100_0193);
        }
        hash
    }

    #[test]
    fn id_from_name_is_deterministic() {
        assert_eq!(
            ExtensionDataTypeId::from_name("tensor.gather"),
            ExtensionDataTypeId::from_name("tensor.gather"),
        );
    }

    #[test]
    fn id_from_different_names_differ() {
        let a = ExtensionDataTypeId::from_name("tensor.gather");
        let b = ExtensionDataTypeId::from_name("tensor.scatter");
        assert_ne!(a, b);
    }

    #[test]
    fn every_id_is_in_extension_range() {
        for name in ["", "a", "anything", "tensor.gather"] {
            let id = ExtensionDataTypeId::from_name(name);
            assert!(id.is_extension(), "{:#010x} missing high bit", id.as_u32());
            assert!(id.as_u32() & ExtensionDataTypeId::EXTENSION_RANGE_MASK != 0);
        }
    }

    #[test]
    fn hash_matches_fnv1a_reference_vectors() {
        // Published 32-bit FNV-1a outputs; each already has bit 31 set, so forcing
        // the bit leaves them unchanged.
        assert_eq!(ExtensionDataTypeId::from_name("").as_u32(), 0x811c_9dc5);
        assert_eq!(ExtensionDataTypeId::from_name("a").as_u32(), 0xe40c_292c);
        assert_eq!(ExtensionDataTypeId::from_name("foobar").as_u32(), 0xbf9c_f968);
    }

    #[test]
    fn high_bit_is_forced_on_top_of_plain_fnv1a() {
        let mask = ExtensionDataTypeId::EXTENSION_RANGE_MASK;
        let mut saw_clear_high_bit = false;
        for i in 0..64u32 {
            let name = format!("op{}", i);
            let plain = reference_fnv1a(&name);
            saw_clear_high_bit |= (plain & mask) == 0;
            assert_eq!(ExtensionDataTypeId::from_name(&name).as_u32(), plain | mask);
        }
        // Guards against the test passing vacuously if no sample needed the bit forced.
        assert!(saw_clear_high_bit, "no sampled name exercised the high-bit OR");
    }

    #[test]
    fn all_id_kinds_share_the_same_hash() {
        let name = "tensor.gather";
        let expected = ExtensionDataTypeId::from_name(name).as_u32();
        assert_eq!(ExtensionBinOpId::from_name(name).as_u32(), expected);
        assert_eq!(ExtensionUnOpId::from_name(name).as_u32(), expected);
        assert_eq!(ExtensionAtomicOpId::from_name(name).as_u32(), expected);
        assert_eq!(ExtensionTernaryOpId::from_name(name).as_u32(), expected);
        assert_eq!(ExtensionRuleConditionId::from_name(name).as_u32(), expected);
    }

    #[test]
    fn from_name_is_usable_in_const_context() {
        const ID: ExtensionDataTypeId = ExtensionDataTypeId::from_name("tensor.gather");
        assert_eq!(ID, ExtensionDataTypeId::from_name("tensor.gather"));
    }
}