use crate::ir::DataType;
use crate::ops::{AlgebraicLaw, Backend, IntrinsicDescriptor, OpSpec};
/// Operand types declared for the op: two `U32` values.
///
/// NOTE(review): the meaning of each operand is not visible in this file —
/// confirm against the intrinsic implementation.
pub const INPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Result types declared for the op: two `U32` values.
///
/// NOTE(review): the meaning of each result is not visible in this file —
/// confirm against the intrinsic implementation.
pub const OUTPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Algebraic laws declared for the op; the list is intentionally empty here.
pub const LAWS: &[AlgebraicLaw] = &[];
/// Op specification for `workgroup.string_interner`, built as an intrinsic.
///
/// Combines the [`INPUTS`]/[`OUTPUTS`] signature, the (empty) [`LAWS`] list,
/// the [`wgsl_only`] backend predicate and an intrinsic descriptor.
///
/// NOTE(review): the exact meaning of each `OpSpec::intrinsic` and
/// `IntrinsicDescriptor::new` argument is defined elsewhere in the crate —
/// confirm against those definitions before changing any of the strings.
pub const SPEC: OpSpec = OpSpec::intrinsic(
// Op identifier.
"workgroup.string_interner",
INPUTS,
OUTPUTS,
LAWS,
// Backend predicate: accepts only `Backend::Wgsl`.
wgsl_only,
IntrinsicDescriptor::new(
"workgroup_string_interner_intern",
"workgroup-sram-fnv1a-table",
crate::ops::cpu_op::structured_intrinsic_cpu,
),
);
/// Handle for a byte string stored in a [`WorkgroupStringInterner`].
///
/// The wrapped value is the zero-based index of the string's slot; slots are
/// assigned in insertion order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct InternedSymbol(pub u32);

/// Reasons an interner operation can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InternError {
    /// No free slot is left, or the next slot index would not fit in a `u32`.
    OutOfSlots,
    /// The byte arena cannot hold the new string, or its offset/length would
    /// not fit in a `u32`.
    OutOfBytes,
    /// The symbol does not refer to a slot of this interner.
    UnknownSymbol,
}

/// Fixed-capacity, deduplicating byte-string interner.
///
/// Strings are appended to a single byte arena and described by
/// `(offset, len)` slots; interning the same bytes twice yields the same
/// [`InternedSymbol`]. Both the number of slots and the arena size are bounded
/// by the capacities given to [`WorkgroupStringInterner::new`].
///
/// NOTE(review): presumably the host-side model of the
/// `workgroup.string_interner` op — confirm against the op implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkgroupStringInterner {
    /// Maximum number of distinct strings.
    slot_capacity: usize,
    /// Maximum total number of bytes stored across all strings.
    byte_capacity: usize,
    /// One `(offset, len)` pair into `bytes` per interned string; the index of
    /// a pair is the symbol id.
    slots: Vec<(u32, u32)>,
    /// Concatenated contents of every interned string, in insertion order.
    bytes: Vec<u8>,
}

impl WorkgroupStringInterner {
    /// Creates an empty interner that holds at most `slot_capacity` strings
    /// and `byte_capacity` bytes in total.
    ///
    /// Both buffers are pre-allocated eagerly via `Vec::with_capacity`, so
    /// this inherits that function's panic on capacities too large to allocate.
    #[must_use]
    pub fn new(slot_capacity: usize, byte_capacity: usize) -> Self {
        Self {
            slot_capacity,
            byte_capacity,
            slots: Vec::with_capacity(slot_capacity),
            bytes: Vec::with_capacity(byte_capacity),
        }
    }

    /// Maximum number of distinct strings this interner can hold.
    #[must_use]
    pub fn slot_capacity(&self) -> usize {
        self.slot_capacity
    }

    /// Maximum total number of bytes this interner can hold.
    #[must_use]
    pub fn byte_capacity(&self) -> usize {
        self.byte_capacity
    }

    /// Number of distinct strings interned so far.
    #[must_use]
    pub fn len(&self) -> usize {
        self.slots.len()
    }

    /// Returns `true` when nothing has been interned yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.slots.is_empty()
    }

    /// Interns `bytes` and returns its symbol.
    ///
    /// If the same byte sequence is already stored, its existing symbol is
    /// returned and nothing is modified — this succeeds even when the
    /// interner is full.
    ///
    /// # Errors
    ///
    /// * [`InternError::OutOfSlots`] — all `slot_capacity` slots are used, or
    ///   the new slot index cannot be represented as a `u32`.
    /// * [`InternError::OutOfBytes`] — storing `bytes` would exceed
    ///   `byte_capacity`, or the offset/length cannot be represented as a
    ///   `u32`.
    ///
    /// The slot check runs before the byte check. On error the interner is
    /// left unchanged.
    pub fn intern(&mut self, bytes: &[u8]) -> Result<InternedSymbol, InternError> {
        if let Some(existing) = self.lookup_by_bytes(bytes) {
            return Ok(existing);
        }

        if self.slots.len() >= self.slot_capacity {
            return Err(InternError::OutOfSlots);
        }
        // Slot ids are stored as `u32`; refuse rather than wrap when the
        // configured capacity is larger than that range.
        let id = u32::try_from(self.slots.len()).map_err(|_| InternError::OutOfSlots)?;

        let fits = self
            .bytes
            .len()
            .checked_add(bytes.len())
            .is_some_and(|total| total <= self.byte_capacity);
        if !fits {
            return Err(InternError::OutOfBytes);
        }
        // Offsets and lengths are stored as `u32` as well; a silent `as`
        // truncation here would make later lookups return the wrong bytes.
        let offset = u32::try_from(self.bytes.len()).map_err(|_| InternError::OutOfBytes)?;
        let len = u32::try_from(bytes.len()).map_err(|_| InternError::OutOfBytes)?;

        // All checks passed: only now mutate state.
        self.bytes.extend_from_slice(bytes);
        self.slots.push((offset, len));
        Ok(InternedSymbol(id))
    }

    /// Returns the bytes stored for `symbol`.
    ///
    /// # Errors
    ///
    /// [`InternError::UnknownSymbol`] if `symbol` does not index an occupied
    /// slot of this interner.
    pub fn lookup(&self, symbol: InternedSymbol) -> Result<&[u8], InternError> {
        let &(offset, len) = self
            .slots
            .get(symbol.0 as usize)
            .ok_or(InternError::UnknownSymbol)?;
        Ok(self.slot_bytes(offset, len))
    }

    /// Finds the symbol of an already-interned byte sequence, if any.
    ///
    /// Performs a linear scan over the slots in insertion order and returns
    /// the first match.
    pub(crate) fn lookup_by_bytes(&self, bytes: &[u8]) -> Option<InternedSymbol> {
        self.slots
            .iter()
            .position(|&(offset, len)| self.slot_bytes(offset, len) == bytes)
            // `intern` only ever creates slots whose index fits in a `u32`.
            .and_then(|idx| u32::try_from(idx).ok())
            .map(InternedSymbol)
    }

    /// Borrows the arena range described by one `(offset, len)` slot.
    fn slot_bytes(&self, offset: u32, len: u32) -> &[u8] {
        let start = offset as usize;
        // Slots are only written by `intern`, which records ranges that lie
        // inside `bytes`, so this index cannot go out of bounds.
        &self.bytes[start..start + len as usize]
    }
}
/// Backend predicate: returns `true` only for [`Backend::Wgsl`].
pub fn wgsl_only(backend: &Backend) -> bool {
    matches!(*backend, Backend::Wgsl)
}