use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::transform::compiler::U32X4_INPUTS;
use rustc_hash::FxHashMap;
use thiserror::Error;
use vyre_spec::AlgebraicLaw;
/// Looks up the `"string_interner"` entry through `shader_provider::source`
/// and forwards whatever that lookup returns, unchanged.
#[must_use]
pub fn source() -> Option<&'static str> {
    use crate::transform::compiler::shader_provider;

    shader_provider::source("string_interner")
}
/// Id reserved for the empty string.
///
/// `fnv1a_program` stores this value when built with `len == 0`; the host-side
/// interner hands out real ids starting at 1, so 0 never collides with them.
pub const EMPTY_STRING_ID: u32 = 0;
#[must_use]
pub fn fnv1a_program(input: &str, out: &str, len: u32) -> Program {
let body = vec![
Node::let_bind("hash", Expr::u32(0x811c_9dc5)),
Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(len),
vec![
Node::let_bind(
"word",
Expr::load(input, Expr::div(Expr::var("i"), Expr::u32(4))),
),
Node::let_bind(
"shift",
Expr::mul(Expr::u32(8), Expr::rem(Expr::var("i"), Expr::u32(4))),
),
Node::let_bind(
"byte",
Expr::bitand(
Expr::shr(Expr::var("word"), Expr::var("shift")),
Expr::u32(0xff),
),
),
Node::assign(
"hash",
Expr::mul(
Expr::bitxor(Expr::var("hash"), Expr::var("byte")),
Expr::u32(0x0100_0193),
),
),
],
),
Node::store(
out,
Expr::u32(0),
if len == 0 {
Expr::u32(EMPTY_STRING_ID)
} else {
Expr::var("hash")
},
),
];
Program::wrapped(
vec![
BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::U32),
BufferDecl::storage(out, 1, BufferAccess::ReadWrite, DataType::U32).with_count(1),
],
[1, 1, 1],
body,
)
}
// Anonymous constant: compiles only if `U32X4_INPUTS` has type `&[DataType]`,
// and keeps the `U32X4_INPUTS` import referenced from this file.
const _: &[DataType] = U32X4_INPUTS;
/// One occupied slot of the interner table: where an interned string lives in
/// the byte pool and which id it was given.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Entry {
/// `fnv1a32` hash of the stored bytes; compared before the bytes themselves.
pub(crate) hash: u32,
/// Start of the string inside the interner's byte pool.
pub(crate) offset: usize,
/// Length of the string in bytes.
pub(crate) len: usize,
/// Intern id handed back to callers for this string.
pub(crate) id: u32,
}
/// Computes the 32-bit FNV-1a hash of `bytes`.
///
/// Starts from the offset basis `0x811c9dc5`; for every input byte the
/// accumulator is xor-ed with the byte and then multiplied (wrapping) by the
/// prime `0x01000193`. An empty slice therefore hashes to the offset basis.
#[must_use]
pub fn fnv1a32(bytes: &[u8]) -> u32 {
    const OFFSET_BASIS: u32 = 0x811c_9dc5;
    const PRIME: u32 = 0x0100_0193;

    bytes.iter().fold(OFFSET_BASIS, |acc, &octet| {
        (acc ^ u32::from(octet)).wrapping_mul(PRIME)
    })
}
impl StringInterner {
    /// Creates an empty interner with `slot_capacity` table slots and room for
    /// at most `byte_capacity` bytes of string data.
    ///
    /// Ids are handed out starting at 1; id 0 is reserved for the empty string.
    #[must_use]
    pub fn new(slot_capacity: usize, byte_capacity: usize) -> Self {
        Self {
            slots: vec![None; slot_capacity],
            bytes: Vec::with_capacity(byte_capacity),
            reverse: FxHashMap::default(),
            byte_capacity,
            next_id: 1,
        }
    }

    /// Returns the id for `input`, inserting it if it is not yet present.
    ///
    /// The empty slice always maps to id 0 and never touches the table.
    /// Otherwise the table is probed linearly, starting at
    /// `fnv1a32(input) % slots.len()`, until either a matching entry or a free
    /// slot is found.
    ///
    /// # Errors
    ///
    /// * [`StringInternerError::TableFull`] - the table has zero slots, or
    ///   every slot is occupied by a different string.
    /// * [`StringInternerError::BytePoolFull`] - appending `input` would exceed
    ///   `byte_capacity`.
    /// * [`StringInternerError::IndexOverflow`] - the hash does not fit in
    ///   `usize`, or the id counter cannot be advanced past `u32::MAX`.
    ///
    /// On every error path the interner is left unmodified.
    #[must_use]
    pub fn intern(&mut self, input: &[u8]) -> Result<u32, StringInternerError> {
        if input.is_empty() {
            return Ok(0);
        }
        if self.slots.is_empty() {
            return Err(StringInternerError::TableFull);
        }
        let hash = fnv1a32(input);
        let start = usize::try_from(hash).map_err(|_| StringInternerError::IndexOverflow)?
            % self.slots.len();
        for probe in 0..self.slots.len() {
            let slot = (start + probe) % self.slots.len();
            match &self.slots[slot] {
                // Cheap hash comparison first, byte comparison only on a hash match.
                Some(entry) if entry.hash == hash && self.entry_bytes(entry) == input => {
                    return Ok(entry.id);
                }
                // Occupied by a different string: keep probing.
                Some(_) => {}
                None => {
                    if self
                        .bytes
                        .len()
                        .checked_add(input.len())
                        .is_none_or(|total| total > self.byte_capacity)
                    {
                        return Err(StringInternerError::BytePoolFull);
                    }
                    // Reserve the id BEFORE mutating any state. If the counter
                    // cannot advance we must bail out without having appended
                    // bytes that no entry would ever reference.
                    let id = self.next_id;
                    let following_id = id
                        .checked_add(1)
                        .ok_or(StringInternerError::IndexOverflow)?;

                    // All fallible steps are done; commit the insertion.
                    let offset = self.bytes.len();
                    self.bytes.extend_from_slice(input);
                    self.next_id = following_id;
                    self.reverse.insert(id, slot);
                    self.slots[slot] = Some(Entry {
                        hash,
                        offset,
                        len: input.len(),
                        id,
                    });
                    return Ok(id);
                }
            }
        }
        Err(StringInternerError::TableFull)
    }

    /// Returns the bytes previously interned under `id`.
    ///
    /// Id 0 yields the empty slice. Any id that was never issued by this
    /// interner yields `None`.
    #[must_use]
    pub fn lookup(&self, id: u32) -> Option<&[u8]> {
        if id == 0 {
            return Some(&[]);
        }
        let slot = *self.reverse.get(&id)?;
        let entry = self.slots.get(slot)?.as_ref()?;
        Some(self.entry_bytes(entry))
    }

    /// Borrows the slice of the byte pool described by `entry`.
    ///
    /// Panics if `entry`'s range lies outside the pool, which cannot happen
    /// for entries created by [`StringInterner::intern`] on this interner.
    pub(crate) fn entry_bytes(&self, entry: &Entry) -> &[u8] {
        &self.bytes[entry.offset..entry.offset + entry.len]
    }
}
// Empty inherent impl: no methods are defined on `StringInternerOp` here.
impl StringInternerOp {}
/// Interns every slice in `inputs`, in order, into a fresh [`StringInterner`]
/// and returns the resulting ids.
///
/// The interner gets `slot_capacity` table slots and a byte pool limited to
/// `slot_capacity * 64` bytes (saturating).
///
/// # Errors
///
/// Stops at, and returns, the first error reported by
/// [`StringInterner::intern`]; ids collected before that point are discarded.
#[must_use]
pub fn intern_all(inputs: &[&[u8]], slot_capacity: usize) -> Result<Vec<u32>, StringInternerError> {
    let pool_limit = slot_capacity.saturating_mul(64);
    let mut table = StringInterner::new(slot_capacity, pool_limit);

    let mut ids = Vec::with_capacity(inputs.len());
    for &bytes in inputs {
        ids.push(table.intern(bytes)?);
    }
    Ok(ids)
}
/// Algebraic laws declared for this op: a single `Bounded` law spanning
/// `0..=u32::MAX`.
// NOTE(review): presumably this bounds the ids the op may produce — confirm
// against the `vyre_spec::AlgebraicLaw` documentation.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
lo: 0,
hi: u32::MAX,
}];
/// Host-side string interner: a fixed-size, linearly probed hash table whose
/// strings are stored back to back in a single byte pool.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StringInterner {
/// Hash-table slots; `None` marks a free slot.
pub(crate) slots: Vec<Option<Entry>>,
/// Concatenated bytes of every interned string, in insertion order.
pub(crate) bytes: Vec<u8>,
/// Maps an issued id to the index of its slot in `slots`.
pub(crate) reverse: FxHashMap<u32, usize>,
/// Upper limit on `bytes.len()`; insertions that would exceed it are rejected.
pub(crate) byte_capacity: usize,
/// Id that the next successful insertion will receive (starts at 1).
pub(crate) next_id: u32,
}
/// Failures reported by `StringInterner::intern` (and therefore `intern_all`).
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum StringInternerError {
/// The table has zero slots, or every slot was probed and none was free or
/// held the requested string.
#[error(
"InternerTableFull: no SRAM slot accepted the string. Fix: increase table slots or split the lexing batch."
)]
TableFull,
/// Appending the string would push the byte pool past `byte_capacity`.
#[error(
"InternerBytePoolFull: byte storage is exhausted. Fix: raise byte_capacity or split the lexing batch."
)]
BytePoolFull,
/// The hash could not be converted to `usize`, or the id counter could not
/// be advanced without overflowing `u32`.
#[error(
"InternerIndexOverflow: table slot cannot fit u32 intern id. Fix: lower the workgroup interner capacity."
)]
IndexOverflow,
}
/// Unit struct carrying no data.
// NOTE(review): presumably the marker/handle type under which this op is
// registered elsewhere in the crate — confirm against its users.
#[derive(Debug, Default, Clone, Copy)]
pub struct StringInternerOp;
/// Workgroup dimensions `[x, y, z]` exported for this op.
// NOTE(review): `fnv1a_program` builds its program with `[1, 1, 1]`, not this
// constant — presumably this value is for the shader dispatch path; confirm.
pub const WORKGROUP_SIZE: [u32; 3] = [64, 1, 1];
#[cfg(test)]
mod ir_program_tests {
    use super::*;
    use crate::validate::validate::validate;

    /// A program hashing 8 bytes must pass IR validation.
    #[test]
    fn fnv1a_program_validates() {
        let prog = fnv1a_program("input", "out", 8);
        let errors = validate(&prog);
        assert!(
            errors.is_empty(),
            "string_interner fnv1a IR must validate: {errors:?}"
        );
    }

    /// Encoding to the wire format and decoding again keeps both buffer
    /// declarations and the `[1, 1, 1]` workgroup size.
    #[test]
    fn fnv1a_program_wire_round_trips() {
        let encoded = fnv1a_program("input", "out", 16)
            .to_wire()
            .expect("Fix: serialize; restore this invariant before continuing.");
        let decoded = Program::from_wire(&encoded)
            .expect("Fix: decode; restore this invariant before continuing.");

        assert_eq!(decoded.buffers().len(), 2);
        assert_eq!(decoded.workgroup_size(), [1, 1, 1]);
    }

    /// The zero-length variant (which stores the sentinel id) must also pass
    /// IR validation.
    #[test]
    fn fnv1a_program_empty_input_short_circuits() {
        let prog = fnv1a_program("input", "out", 0);
        let errors = validate(&prog);
        assert!(
            errors.is_empty(),
            "empty-input IR must validate: {errors:?}"
        );
    }

    /// Programs built for different lengths must not serialize identically.
    #[test]
    fn fnv1a_program_different_lens_produce_different_wire() {
        let wire_for = |len: u32| fnv1a_program("input", "out", len).to_wire().unwrap();

        let a = wire_for(4);
        let b = wire_for(16);
        assert_ne!(a, b, "loop bound is part of the canonical wire");
    }

    /// Guards the numeric value of the empty-string sentinel.
    #[test]
    fn empty_string_sentinel_is_zero() {
        assert_eq!(EMPTY_STRING_ID, 0);
    }
}