use crate::ir::DataType;
use crate::ops::{AlgebraicLaw, Backend, IntrinsicDescriptor, OpSpec};
/// Input types of the `workgroup.hashmap` op: three `u32` operands.
/// NOTE(review): the role of each operand is not visible in this file — confirm
/// against the WGSL lowering.
pub const INPUTS: &[DataType] = &[DataType::U32, DataType::U32, DataType::U32];
/// Output types of the op: two `u32` results.
/// NOTE(review): presumably a status code from `codes` plus a value — confirm.
pub const OUTPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Algebraic laws declared for the op (none).
pub const LAWS: &[AlgebraicLaw] = &[];
/// Numeric status codes for hashmap operations.
///
/// `OK` and `REPLACED` report success; the `ERR_*` values are failures. Each
/// `ERR_*` value must equal the matching `HASHMAP_ERR_*` constant in the WGSL
/// lowering source; the `error_codes_are_one_to_one_with_wgsl_constants` unit
/// test in this file checks that.
pub mod codes {
/// A new key was stored (`HashmapStatus::Ok`).
pub const OK: u32 = 0;
/// No vacant slot was left for a new key.
pub const ERR_OVERFLOW: u32 = 1;
/// The requested key is not present.
pub const ERR_NOT_FOUND: u32 = 2;
/// An existing key's value was overwritten (`HashmapStatus::Replaced`).
pub const REPLACED: u32 = 3;
/// The key equals the reserved empty-slot sentinel `EMPTY_KEY`.
pub const ERR_RESERVED_KEY: u32 = 4;
/// The capacity was zero or not a power of two.
pub const ERR_INVALID_CAPACITY: u32 = 5;
}
/// Operation spec registering `workgroup.hashmap` as an intrinsic.
///
/// Combines [`INPUTS`], [`OUTPUTS`] and [`LAWS`] with the [`wgsl_only`] backend
/// predicate and an [`IntrinsicDescriptor`] built from `"workgroup_hashmap_kernel"`,
/// `"workgroup-sram-linear-probe"` and `structured_intrinsic_cpu`.
/// NOTE(review): the three descriptor arguments look like (kernel entry point,
/// implementation-strategy tag, CPU implementation) — confirm against
/// `IntrinsicDescriptor::new`.
pub const SPEC: OpSpec = OpSpec::intrinsic(
"workgroup.hashmap",
INPUTS,
OUTPUTS,
LAWS,
wgsl_only,
IntrinsicDescriptor::new(
"workgroup_hashmap_kernel",
"workgroup-sram-linear-probe",
crate::ops::cpu_op::structured_intrinsic_cpu,
),
);
/// Status values reported by hashmap operations, numerically equal to the
/// constants in the `codes` module.
///
/// `#[repr(u32)]` makes the discriminants plain `u32`s (the type of the `codes`
/// constants), so `status as u32` yields the matching code with no `isize`
/// round-trip.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum HashmapStatus {
    /// A new key was stored; equals `codes::OK`.
    Ok = codes::OK,
    /// Equals `codes::ERR_OVERFLOW`. `WorkgroupHashmap` methods report this
    /// condition through `HashmapError::Overflow` instead.
    Overflow = codes::ERR_OVERFLOW,
    /// Equals `codes::ERR_NOT_FOUND`. `WorkgroupHashmap` methods report this
    /// condition through `HashmapError::NotFound` instead.
    NotFound = codes::ERR_NOT_FOUND,
    /// An existing key's value was overwritten; equals `codes::REPLACED`.
    Replaced = codes::REPLACED,
}

impl HashmapStatus {
    /// Numeric code of this status; identical to `self as u32`.
    #[must_use]
    pub const fn code(self) -> u32 {
        self as u32
    }
}
/// Failure reasons for `WorkgroupHashmap` operations.
///
/// Each variant maps to a stable numeric code via `HashmapError::code`. The type
/// implements `Display` and `std::error::Error`, so it can be printed, boxed as
/// `dyn Error`, or converted by `?` into broader error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HashmapError {
    /// The requested capacity was zero or not a power of two.
    InvalidCapacity,
    /// The key equals the reserved empty-slot sentinel.
    ReservedKey,
    /// No vacant slot was left for a new key.
    Overflow,
    /// The key is not present.
    NotFound,
}

impl std::fmt::Display for HashmapError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::InvalidCapacity => "hashmap capacity must be a non-zero power of two",
            Self::ReservedKey => "hashmap key is reserved as the empty-slot sentinel",
            Self::Overflow => "hashmap has no free slot for a new key",
            Self::NotFound => "hashmap key not found",
        };
        f.write_str(message)
    }
}

impl std::error::Error for HashmapError {}
impl HashmapError {
#[must_use]
pub const fn code(self) -> u32 {
match self {
Self::InvalidCapacity => codes::ERR_INVALID_CAPACITY,
Self::ReservedKey => codes::ERR_RESERVED_KEY,
Self::Overflow => codes::ERR_OVERFLOW,
Self::NotFound => codes::ERR_NOT_FOUND,
}
}
}
/// Sentinel key marking a vacant slot. `WorkgroupHashmap::insert` rejects it with
/// `HashmapError::ReservedKey`, so it can never be stored as a real key.
pub const EMPTY_KEY: u32 = u32::MAX;
/// Value written into vacant slots alongside [`EMPTY_KEY`].
pub const EMPTY_VALUE: u32 = u32::MAX;
/// Fixed-capacity open-addressing map from `u32` keys to `u32` values.
///
/// Collisions are resolved by linear probing from the home slot chosen by
/// [`mix_hash`]. The capacity is a power of two fixed at construction, and the
/// table never grows.
///
/// Note: the derived `PartialEq` compares slot layout, so two maps holding the
/// same entries in different slots compare unequal.
/// NOTE(review): given the `workgroup-sram-linear-probe` tag in `SPEC`, this appears
/// to model the GPU kernel's table on the CPU — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkgroupHashmap {
/// All slots (length == capacity); vacant ones hold `EMPTY_KEY`/`EMPTY_VALUE`.
slots: Vec<Slot>,
/// Number of occupied slots.
len: usize,
}
/// One table entry; the slot is vacant when `key == EMPTY_KEY`.
/// NOTE(review): the type is `pub` but its fields are private and no constructor is
/// visible in this file — confirm whether outside code needs access.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Slot {
/// Stored key, or `EMPTY_KEY` when vacant.
key: u32,
/// Stored value, or `EMPTY_VALUE` when vacant.
value: u32,
}
/// Backend predicate passed to [`SPEC`]: accepts only `Backend::Wgsl`.
pub fn wgsl_only(backend: &Backend) -> bool {
    // `WorkgroupHashmap::backend_is_wgsl()` is a constant `true`. OR-ing it in here
    // made this predicate accept every backend, contradicting its name and making the
    // `matches!` check dead.
    matches!(backend, Backend::Wgsl)
}
/// Scrambles a 32-bit key into a 32-bit hash.
///
/// Adds a fixed offset, then applies three xor-shift steps interleaved with two
/// multiplications by odd constants. All arithmetic wraps modulo 2^32. Being a
/// `const fn`, it can also be evaluated at compile time.
#[must_use]
pub const fn mix_hash(x: u32) -> u32 {
    const OFFSET: u32 = 0x9E37_79B1;
    const MUL_A: u32 = 0x7FEB_352D;
    const MUL_B: u32 = 0x846C_A68B;

    let seeded = x.wrapping_add(OFFSET);
    let first = (seeded ^ (seeded >> 16)).wrapping_mul(MUL_A);
    let second = (first ^ (first >> 15)).wrapping_mul(MUL_B);
    second ^ (second >> 16)
}
impl WorkgroupHashmap {
pub fn new(capacity: usize) -> Result<Self, HashmapError> {
Self::try_new(capacity).ok_or(HashmapError::InvalidCapacity)
}
pub fn try_new(capacity: usize) -> Option<Self> {
if capacity == 0 || !capacity.is_power_of_two() {
return None;
}
Some(Self {
slots: vec![
Slot {
key: EMPTY_KEY,
value: EMPTY_VALUE,
};
capacity
],
len: 0,
})
}
#[must_use]
pub fn len(&self) -> usize {
self.len
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[must_use]
pub fn capacity(&self) -> usize {
self.slots.len()
}
pub fn insert(&mut self, key: u32, value: u32) -> Result<HashmapStatus, HashmapError> {
if key == EMPTY_KEY {
return Err(HashmapError::ReservedKey);
}
let mask = self.slots.len() - 1;
let start = mix_hash(key) as usize & mask;
for probe in 0..self.slots.len() {
let idx = (start + probe) & mask;
let slot = &mut self.slots[idx];
if slot.key == EMPTY_KEY {
slot.key = key;
slot.value = value;
self.len += 1;
return Ok(HashmapStatus::Ok);
}
if slot.key == key {
slot.value = value;
return Ok(HashmapStatus::Replaced);
}
}
Err(HashmapError::Overflow)
}
pub fn get(&self, key: u32) -> Result<u32, HashmapError> {
let mask = self.slots.len() - 1;
let start = mix_hash(key) as usize & mask;
for probe in 0..self.slots.len() {
let idx = (start + probe) & mask;
let slot = &self.slots[idx];
if slot.key == EMPTY_KEY {
return Err(HashmapError::NotFound);
}
if slot.key == key {
return Ok(slot.value);
}
}
Err(HashmapError::NotFound)
}
#[must_use]
pub fn contains(&self, key: u32) -> bool {
self.get(key).is_ok()
}
pub fn remove(&mut self, key: u32) -> Result<u32, HashmapError> {
let mask = self.slots.len() - 1;
let start = mix_hash(key) as usize & mask;
let mut hit = None;
for probe in 0..self.slots.len() {
let idx = (start + probe) & mask;
let slot = self.slots[idx];
if slot.key == EMPTY_KEY {
return Err(HashmapError::NotFound);
}
if slot.key == key {
hit = Some((idx, slot.value));
break;
}
}
let (i, value) = hit.ok_or(HashmapError::NotFound)?;
self.slots[i] = Slot {
key: EMPTY_KEY,
value: EMPTY_VALUE,
};
self.len -= 1;
let mut cursor = (i + 1) & mask;
while self.slots[cursor].key != EMPTY_KEY {
let displaced = self.slots[cursor];
self.slots[cursor] = Slot {
key: EMPTY_KEY,
value: EMPTY_VALUE,
};
self.len -= 1;
self.insert(displaced.key, displaced.value)?;
cursor = (cursor + 1) & mask;
}
Ok(value)
}
pub(crate) fn backend_is_wgsl() -> bool {
true
}
}
#[cfg(test)]
mod tests {
    use super::{codes, HashmapError, HashmapStatus, WorkgroupHashmap, EMPTY_KEY};

    /// Error codes are unique and mirrored by identically valued `HASHMAP_ERR_*`
    /// constants in the WGSL lowering; status discriminants match `codes`.
    #[test]
    fn error_codes_are_one_to_one_with_wgsl_constants() {
        let pairs = [
            (
                "HASHMAP_ERR_INVALID_CAPACITY",
                HashmapError::InvalidCapacity.code(),
            ),
            ("HASHMAP_ERR_RESERVED_KEY", HashmapError::ReservedKey.code()),
            ("HASHMAP_ERR_OVERFLOW", HashmapError::Overflow.code()),
            ("HASHMAP_ERR_NOT_FOUND", HashmapError::NotFound.code()),
        ];
        let mut seen = std::collections::BTreeSet::new();
        for (name, code) in pairs {
            assert!(seen.insert(code), "duplicate hashmap error code {code}");
            assert!(
                crate::ops::workgroup::primitives::hashmap::lowering::WGSL
                    .contains(&format!("const {name}: u32 = {code}u;")),
                "WGSL hashmap error constant {name} must equal Rust code {code}"
            );
        }
        assert_eq!(HashmapStatus::Ok as u32, codes::OK);
        assert_eq!(HashmapStatus::Overflow as u32, codes::ERR_OVERFLOW);
        assert_eq!(HashmapStatus::NotFound as u32, codes::ERR_NOT_FOUND);
        assert_eq!(HashmapStatus::Replaced as u32, codes::REPLACED);
    }

    /// Builds a map whose capacity the test knows to be valid.
    fn map_with_capacity(capacity: usize) -> WorkgroupHashmap {
        WorkgroupHashmap::new(capacity).expect("test capacities are powers of two")
    }

    #[test]
    fn rejects_capacities_that_are_not_nonzero_powers_of_two() {
        for bad in [0, 3, 6, 12] {
            assert_eq!(WorkgroupHashmap::new(bad), Err(HashmapError::InvalidCapacity));
            assert!(WorkgroupHashmap::try_new(bad).is_none());
        }
        for good in [1, 2, 8, 64] {
            let map = map_with_capacity(good);
            assert_eq!(map.capacity(), good);
            assert!(map.is_empty());
        }
    }

    #[test]
    fn insert_get_replace_and_remove() {
        let mut map = map_with_capacity(8);
        assert_eq!(map.insert(1, 10), Ok(HashmapStatus::Ok));
        assert_eq!(map.insert(2, 20), Ok(HashmapStatus::Ok));
        assert_eq!(map.insert(1, 11), Ok(HashmapStatus::Replaced));
        assert_eq!(map.len(), 2);
        assert_eq!(map.get(1), Ok(11));
        assert!(map.contains(2));
        assert_eq!(map.remove(1), Ok(11));
        assert_eq!(map.get(1), Err(HashmapError::NotFound));
        assert_eq!(map.remove(1), Err(HashmapError::NotFound));
        assert_eq!(map.get(2), Ok(20));
        assert_eq!(map.len(), 1);
    }

    #[test]
    fn reserved_key_and_full_table_are_reported() {
        let mut map = map_with_capacity(4);
        assert_eq!(map.insert(EMPTY_KEY, 0), Err(HashmapError::ReservedKey));
        for key in 0..4 {
            assert_eq!(map.insert(key, key + 100), Ok(HashmapStatus::Ok));
        }
        assert_eq!(map.insert(99, 0), Err(HashmapError::Overflow));
        // Overwriting an existing key still works when the table is full.
        assert_eq!(map.insert(2, 7), Ok(HashmapStatus::Replaced));
        assert_eq!(map.get(EMPTY_KEY), Err(HashmapError::NotFound));
        // Freeing one slot makes room for a new key again.
        assert_eq!(map.remove(0), Ok(100));
        assert_eq!(map.insert(99, 5), Ok(HashmapStatus::Ok));
        assert_eq!(map.get(99), Ok(5));
        assert_eq!(map.len(), 4);
    }

    #[test]
    fn removal_keeps_remaining_keys_reachable() {
        // In a full table every slot belongs to one wrapping cluster. Each removal
        // must repair the cluster so that all remaining keys stay findable.
        let capacity = 16;
        let mut map = map_with_capacity(capacity);
        let keys: Vec<u32> = (0..16).map(|k| k * 7 + 3).collect();
        for &key in &keys {
            assert_eq!(map.insert(key, key ^ 0xFF), Ok(HashmapStatus::Ok));
        }
        for (removed, &key) in keys.iter().enumerate() {
            assert_eq!(map.remove(key), Ok(key ^ 0xFF));
            assert!(!map.contains(key));
            assert_eq!(map.len(), capacity - removed - 1);
            for &rest in &keys[removed + 1..] {
                assert_eq!(map.get(rest), Ok(rest ^ 0xFF));
            }
        }
        assert!(map.is_empty());
    }
}