use crate::ir::DataType;
use crate::ir::transform::compiler::{U32X4_INPUTS, U32_OUTPUTS};
use crate::lower::wgsl::compiler::wgsl_backend;
use crate::ops::{AlgebraicLaw, IntrinsicDescriptor, OpSpec};
use thiserror::Error;
/// Returns the WGSL shader text embedded at compile time from
/// `lower/wgsl/compiler/string_interner.wgsl`.
#[must_use]
pub const fn source() -> &'static str {
    const WGSL: &str = include_str!("../../../lower/wgsl/compiler/string_interner.wgsl");
    WGSL
}
const _: &[DataType] = U32X4_INPUTS;
/// One occupied slot of the interner's hash table.
///
/// The interned bytes live in the interner's shared byte pool at
/// `offset..offset + len`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct Entry {
    /// FNV-1a hash of the interned bytes; compared before the bytes themselves.
    pub(crate) hash: u32,
    /// Start of the string inside the byte pool.
    pub(crate) offset: usize,
    /// Length of the string in bytes.
    pub(crate) len: usize,
    /// Id handed back to callers for this string.
    pub(crate) id: u32,
}
/// Computes the 32-bit FNV-1a hash of `bytes`.
///
/// Starts from the FNV offset basis and, for every byte, XORs the byte in and
/// then multiplies by the FNV prime modulo 2^32. The empty slice hashes to the
/// offset basis itself.
#[must_use]
pub fn fnv1a32(bytes: &[u8]) -> u32 {
    const OFFSET_BASIS: u32 = 0x811c_9dc5;
    const PRIME: u32 = 0x0100_0193;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |acc, &b| (acc ^ u32::from(b)).wrapping_mul(PRIME))
}
impl StringInterner {
    /// Creates an empty interner with `slot_capacity` hash-table slots and a
    /// byte pool that may hold at most `byte_capacity` bytes of string data.
    ///
    /// Id `0` is reserved for the empty string, so the first non-empty string
    /// interned receives id `1`.
    #[must_use]
    pub fn new(slot_capacity: usize, byte_capacity: usize) -> Self {
        let slots = vec![None; slot_capacity];
        let mut bytes: Vec<u8> = Vec::new();
        // `byte_capacity` is a limit, not an allocation request. Pre-size the
        // pool when the allocator can satisfy it, but do not panic (as
        // `Vec::with_capacity` would for e.g. `usize::MAX`) when it cannot;
        // the pool then simply grows on demand in `intern`.
        let _ = bytes.try_reserve_exact(byte_capacity);
        Self {
            slots,
            bytes,
            byte_capacity,
            next_id: 1,
        }
    }

    /// Interns `input` and returns its id.
    ///
    /// The empty string always maps to id `0` and consumes no storage.
    /// Interning equal bytes again returns the previously assigned id. Slots
    /// are probed linearly, starting at the FNV-1a hash of `input` modulo the
    /// slot count.
    ///
    /// On error the interner is left unchanged.
    ///
    /// # Errors
    ///
    /// - [`StringInternerError::TableFull`] if the table has no slots, or every
    ///   slot holds a different string.
    /// - [`StringInternerError::BytePoolFull`] if storing `input` would push the
    ///   byte pool past `byte_capacity`.
    /// - [`StringInternerError::IndexOverflow`] if the hash does not fit in
    ///   `usize`, or the id counter is exhausted (the largest id ever returned
    ///   is `u32::MAX - 1`).
    pub fn intern(&mut self, input: &[u8]) -> Result<u32, StringInternerError> {
        if input.is_empty() {
            return Ok(0);
        }
        let slot_count = self.slots.len();
        if slot_count == 0 {
            return Err(StringInternerError::TableFull);
        }
        let hash = fnv1a32(input);
        let start = usize::try_from(hash).map_err(|_| StringInternerError::IndexOverflow)?
            % slot_count;
        for probe in 0..slot_count {
            // `start < slot_count` and `probe < slot_count`, so the sum cannot
            // overflow for any length a `Vec` can have.
            let slot = (start + probe) % slot_count;
            match &self.slots[slot] {
                Some(entry) if entry.hash == hash && self.entry_bytes(entry) == input => {
                    return Ok(entry.id);
                }
                Some(_) => {}
                None => {
                    if self
                        .bytes
                        .len()
                        .checked_add(input.len())
                        .is_none_or(|total| total > self.byte_capacity)
                    {
                        return Err(StringInternerError::BytePoolFull);
                    }
                    // Claim the id *before* touching storage: if the counter is
                    // exhausted we must fail without leaving orphaned bytes in
                    // the pool.
                    let id = self.next_id;
                    self.next_id = id
                        .checked_add(1)
                        .ok_or(StringInternerError::IndexOverflow)?;
                    let offset = self.bytes.len();
                    self.bytes.extend_from_slice(input);
                    self.slots[slot] = Some(Entry {
                        hash,
                        offset,
                        len: input.len(),
                        id,
                    });
                    return Ok(id);
                }
            }
        }
        Err(StringInternerError::TableFull)
    }

    /// Returns the bytes interned under `id`, or `None` if no string has that id.
    ///
    /// Id `0` always resolves to the empty slice. Other ids are found by
    /// scanning every slot, so this is linear in the slot count.
    #[must_use]
    pub fn lookup(&self, id: u32) -> Option<&[u8]> {
        if id == 0 {
            return Some(&[]);
        }
        self.slots
            .iter()
            .flatten()
            .find(|entry| entry.id == id)
            .map(|entry| self.entry_bytes(entry))
    }

    /// Returns the slice of the byte pool described by `entry`.
    ///
    /// # Panics
    ///
    /// Panics if `entry` does not describe a range inside this interner's pool
    /// (e.g. an entry taken from a different interner).
    pub(crate) fn entry_bytes(&self, entry: &Entry) -> &[u8] {
        &self.bytes[entry.offset..entry.offset + entry.len]
    }
}
impl StringInternerOp {
    /// Op spec registered under `compiler_primitives.string_interner`.
    ///
    /// It is built from the shared `U32X4_INPUTS` / `U32_OUTPUTS` signatures,
    /// this module's `LAWS`, the WGSL backend, and an intrinsic descriptor named
    /// `compiler_primitives_string_interner` tagged `workgroup_atomic_sram`.
    pub const SPEC: OpSpec = {
        let intrinsic = IntrinsicDescriptor::new(
            "compiler_primitives_string_interner",
            "workgroup_atomic_sram",
            crate::ops::cpu_op::structured_intrinsic_cpu,
        );
        OpSpec::intrinsic(
            "compiler_primitives.string_interner",
            U32X4_INPUTS,
            U32_OUTPUTS,
            LAWS,
            wgsl_backend,
            intrinsic,
        )
    };
}
/// Interns every string in `inputs` with a fresh interner and returns the ids
/// in input order.
///
/// The interner gets `slot_capacity` slots and a byte pool of
/// `64 * slot_capacity` bytes (saturating at `usize::MAX`).
///
/// # Errors
///
/// Returns the first [`StringInternerError`] produced by
/// [`StringInterner::intern`]; inputs after the failing one are not processed.
pub fn intern_all(inputs: &[&[u8]], slot_capacity: usize) -> Result<Vec<u32>, StringInternerError> {
    const BYTES_PER_SLOT: usize = 64;
    let mut interner =
        StringInterner::new(slot_capacity, slot_capacity.saturating_mul(BYTES_PER_SLOT));
    let mut ids = Vec::with_capacity(inputs.len());
    for input in inputs {
        ids.push(interner.intern(input)?);
    }
    Ok(ids)
}
/// Algebraic laws attached to the op spec: a single `Bounded` law spanning
/// `0..=u32::MAX` (presumably the permitted output range — confirm against
/// `AlgebraicLaw::Bounded`).
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded { lo: 0, hi: u32::MAX }];
/// String interner: an open-addressing (linear probing) hash table whose
/// entries point into a shared, size-bounded byte pool.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct StringInterner {
    /// Hash-table slots; `None` marks an empty slot.
    pub(crate) slots: Vec<Option<Entry>>,
    /// Concatenated bytes of every interned string.
    pub(crate) bytes: Vec<u8>,
    /// Maximum number of bytes `bytes` is allowed to hold.
    pub(crate) byte_capacity: usize,
    /// Id the next newly interned string will receive.
    pub(crate) next_id: u32,
}
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum StringInternerError {
#[error("InternerTableFull: no SRAM slot accepted the string. Fix: increase table slots or split the lexing batch.")]
TableFull,
#[error("InternerBytePoolFull: byte storage is exhausted. Fix: raise byte_capacity or split the lexing batch.")]
BytePoolFull,
#[error("InternerIndexOverflow: table slot cannot fit u32 intern id. Fix: lower the workgroup interner capacity.")]
IndexOverflow,
}
/// Zero-sized marker type on which the string-interner op's `SPEC` is defined.
#[derive(Debug, Default)]
#[derive(Clone, Copy)]
pub struct StringInternerOp;
/// Workgroup dimensions `[x, y, z]`: 64 invocations along `x`, 1 along `y` and `z`.
pub const WORKGROUP_SIZE: [u32; 3] = [64_u32, 1_u32, 1_u32];