vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
//! Specification and CPU reference for `workgroup.hashmap`.

use crate::ir::DataType;
use crate::ops::{AlgebraicLaw, Backend, IntrinsicDescriptor, OpSpec};

/// Operand types accepted by `workgroup.hashmap`: three `u32` values.
// NOTE(review): presumably (command, key, value) — confirm against the WGSL kernel.
pub const INPUTS: &[DataType] = &[DataType::U32, DataType::U32, DataType::U32];
/// Result types produced by `workgroup.hashmap`: two `u32` values.
// NOTE(review): presumably (status, value) — confirm against the WGSL kernel.
pub const OUTPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Algebraic laws declared for the op; the list is empty.
pub const LAWS: &[AlgebraicLaw] = &[];
/// Numeric status and error codes for `workgroup.hashmap`.
///
/// `HashmapStatus` discriminants and `HashmapError::code` are defined in
/// terms of these constants. The unit test in this file additionally checks
/// that each error code appears as a `HASHMAP_ERR_*` constant with the same
/// value in the WGSL lowering source.
pub mod codes {
    /// Operation completed.
    pub const OK: u32 = 0;
    /// No free slot was available for an insert.
    pub const ERR_OVERFLOW: u32 = 1;
    /// The requested key is not present.
    pub const ERR_NOT_FOUND: u32 = 2;
    /// An insert overwrote the value of an existing key.
    pub const REPLACED: u32 = 3;
    /// The key equals the reserved empty-slot sentinel.
    pub const ERR_RESERVED_KEY: u32 = 4;
    /// The requested capacity was zero or not a power of two.
    pub const ERR_INVALID_CAPACITY: u32 = 5;
}

/// Operation spec for `workgroup.hashmap`.
///
/// Built with `OpSpec::intrinsic` from the op name, [`INPUTS`], [`OUTPUTS`],
/// [`LAWS`], the [`wgsl_only`] backend predicate and an
/// [`IntrinsicDescriptor`] whose CPU entry is
/// `crate::ops::cpu_op::structured_intrinsic_cpu`.
// NOTE(review): the two descriptor strings look like the kernel entry-point
// name and an implementation-strategy tag — confirm against
// `IntrinsicDescriptor::new`.
pub const SPEC: OpSpec = OpSpec::intrinsic(
    "workgroup.hashmap",
    INPUTS,
    OUTPUTS,
    LAWS,
    wgsl_only,
    IntrinsicDescriptor::new(
        "workgroup_hashmap_kernel",
        "workgroup-sram-linear-probe",
        crate::ops::cpu_op::structured_intrinsic_cpu,
    ),
);
/// Hashmap command status word shared by the CPU oracle and WGSL lowering.
///
/// Each discriminant is taken from the matching constant in [`codes`], so
/// `status as u32` yields the wire value. The CPU reference in this file
/// only ever returns [`HashmapStatus::Ok`] and [`HashmapStatus::Replaced`]
/// (from `insert`); overflow and missing-key conditions are reported through
/// [`HashmapError`] instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HashmapStatus {
    /// Operation completed.
    Ok = codes::OK as isize,
    /// The map was full before insert could reserve a slot.
    Overflow = codes::ERR_OVERFLOW as isize,
    /// The key was not present for a `get` or `remove`.
    NotFound = codes::ERR_NOT_FOUND as isize,
    /// The insert replaced an existing entry.
    Replaced = codes::REPLACED as isize,
}
/// Error returned by fallible hashmap operations.
///
/// Use [`HashmapError::code`] to obtain the numeric code from [`codes`]
/// that corresponds to a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HashmapError {
    /// Capacity was zero or not a power of two (returned by
    /// `WorkgroupHashmap::new`).
    InvalidCapacity,
    /// The key is reserved as the empty-slot sentinel ([`EMPTY_KEY`]) and
    /// cannot be inserted.
    ReservedKey,
    /// Capacity exhausted — no empty slot found after a full probe walk.
    Overflow,
    /// Key not present (returned by `get` and `remove`).
    NotFound,
}
impl HashmapError {
    /// Numeric code for this error, as declared in the [`codes`] module.
    ///
    /// Every variant maps to a distinct `codes::ERR_*` constant; the arms
    /// below are listed in ascending code order.
    #[must_use]
    pub const fn code(self) -> u32 {
        let code = match self {
            Self::Overflow => codes::ERR_OVERFLOW,
            Self::NotFound => codes::ERR_NOT_FOUND,
            Self::ReservedKey => codes::ERR_RESERVED_KEY,
            Self::InvalidCapacity => codes::ERR_INVALID_CAPACITY,
        };
        code
    }
}
/// A reserved sentinel key that may never be inserted. The GPU kernel
/// uses it to mark empty slots; the CPU oracle mirrors that choice so
/// parity is byte-exact.
///
/// `WorkgroupHashmap::insert` rejects this key with
/// [`HashmapError::ReservedKey`]; looking it up or removing it always
/// reports [`HashmapError::NotFound`].
pub const EMPTY_KEY: u32 = u32::MAX;
/// Sentinel stored in the value half of every empty slot.
///
/// In this CPU reference, `get` and `remove` report a miss through
/// [`HashmapError::NotFound`] and never return this value. Nothing prevents
/// a caller from inserting it as an ordinary value, so it carries no meaning
/// unless it accompanies a not-found result.
// NOTE(review): presumably the WGSL kernel writes this to the value output
// alongside `HashmapStatus::NotFound` — confirm against `hashmap.wgsl`.
pub const EMPTY_VALUE: u32 = u32::MAX;
/// Bounded linear-probe hashmap used as the CPU reference for
/// `workgroup.hashmap`. Capacity must be a non-zero power of two —
/// the GPU kernel relies on bit-masking for the probe start index and
/// wrapping, and the CPU oracle follows the same rule so the two
/// implementations agree byte-for-byte.
///
/// Equality is derived, so two maps compare equal only when both their
/// slot contents (including slot positions) and their live-entry counts
/// match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkgroupHashmap {
    /// Slot storage; its length is what `capacity()` reports. A slot whose
    /// key is [`EMPTY_KEY`] is unoccupied.
    slots: Vec<Slot>,
    /// Number of live entries, as reported by `len()`.
    len: usize,
}
/// One table slot of a [`WorkgroupHashmap`]: a key together with its value.
///
/// An unoccupied slot holds [`EMPTY_KEY`] / [`EMPTY_VALUE`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Slot {
    /// Stored key, or [`EMPTY_KEY`] when the slot is unoccupied.
    key: u32,
    /// Stored value, or [`EMPTY_VALUE`] when the slot is unoccupied.
    value: u32,
}
/// Backend-support predicate passed to `OpSpec::intrinsic` for this op.
///
/// Accepts `Backend::Wgsl` directly; for any other backend the answer is
/// delegated to `WorkgroupHashmap::backend_is_wgsl`.
// NOTE(review): that helper unconditionally returns `true` in this file, so
// despite its name this predicate currently accepts every backend — confirm
// that this is intended.
pub fn wgsl_only(backend: &Backend) -> bool {
    if matches!(backend, Backend::Wgsl) {
        return true;
    }
    WorkgroupHashmap::backend_is_wgsl()
}
/// 32-bit integer mixer shared by the CPU oracle and the WGSL kernel.
///
/// Must stay in lock-step with the `hash_u32` helper declared in
/// `hashmap.wgsl`. The input is offset by an additive constant and then
/// passed through two xor-shift/multiply rounds and a final xor-shift; all
/// arithmetic wraps modulo 2^32.
#[must_use]
pub const fn mix_hash(x: u32) -> u32 {
    let seeded = x.wrapping_add(0x9E37_79B1);
    let first_round = (seeded ^ (seeded >> 16)).wrapping_mul(0x7FEB_352D);
    let second_round = (first_round ^ (first_round >> 15)).wrapping_mul(0x846C_A68B);
    second_round ^ (second_round >> 16)
}

impl WorkgroupHashmap {
    /// Create an empty hashmap with the given capacity.
    ///
    /// # Errors
    ///
    /// Returns [`HashmapError::InvalidCapacity`] when `capacity` is zero or
    /// not a power of two.
    pub fn new(capacity: usize) -> Result<Self, HashmapError> {
        Self::try_new(capacity).ok_or(HashmapError::InvalidCapacity)
    }

    /// Fallible constructor for runtime-derived capacities.
    pub fn try_new(capacity: usize) -> Option<Self> {
        if capacity == 0 || !capacity.is_power_of_two() {
            return None;
        }
        Some(Self {
            slots: vec![
                Slot {
                    key: EMPTY_KEY,
                    value: EMPTY_VALUE,
                };
                capacity
            ],
            len: 0,
        })
    }

    /// Number of live entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Whether the map contains zero live entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Total slot capacity.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.slots.len()
    }

    /// Insert `(key, value)`. Returns [`HashmapStatus::Ok`] for a fresh
    /// insert, [`HashmapStatus::Replaced`] when an existing entry was
    /// overwritten, and [`HashmapError::Overflow`] when capacity is
    /// exhausted.
    ///
    /// # Errors
    ///
    /// Returns [`HashmapError::Overflow`] when every probe slot is
    /// occupied by a different key.
    pub fn insert(&mut self, key: u32, value: u32) -> Result<HashmapStatus, HashmapError> {
        if key == EMPTY_KEY {
            return Err(HashmapError::ReservedKey);
        }
        let mask = self.slots.len() - 1;
        let start = mix_hash(key) as usize & mask;
        for probe in 0..self.slots.len() {
            let idx = (start + probe) & mask;
            let slot = &mut self.slots[idx];
            if slot.key == EMPTY_KEY {
                slot.key = key;
                slot.value = value;
                self.len += 1;
                return Ok(HashmapStatus::Ok);
            }
            if slot.key == key {
                slot.value = value;
                return Ok(HashmapStatus::Replaced);
            }
        }
        Err(HashmapError::Overflow)
    }

    /// Look up `key`. Returns the stored value or [`HashmapError::NotFound`].
    ///
    /// # Errors
    ///
    /// Returns [`HashmapError::NotFound`] when the key is not present.
    pub fn get(&self, key: u32) -> Result<u32, HashmapError> {
        let mask = self.slots.len() - 1;
        let start = mix_hash(key) as usize & mask;
        for probe in 0..self.slots.len() {
            let idx = (start + probe) & mask;
            let slot = &self.slots[idx];
            if slot.key == EMPTY_KEY {
                return Err(HashmapError::NotFound);
            }
            if slot.key == key {
                return Ok(slot.value);
            }
        }
        Err(HashmapError::NotFound)
    }

    /// Whether `key` is currently a live entry.
    #[must_use]
    pub fn contains(&self, key: u32) -> bool {
        self.get(key).is_ok()
    }

    /// Remove `key` and return its value, preserving the probe chain
    /// via backward-shift deletion.
    ///
    /// # Errors
    ///
    /// Returns [`HashmapError::NotFound`] when the key is not present.
    pub fn remove(&mut self, key: u32) -> Result<u32, HashmapError> {
        let mask = self.slots.len() - 1;
        let start = mix_hash(key) as usize & mask;
        let mut hit = None;
        for probe in 0..self.slots.len() {
            let idx = (start + probe) & mask;
            let slot = self.slots[idx];
            if slot.key == EMPTY_KEY {
                return Err(HashmapError::NotFound);
            }
            if slot.key == key {
                hit = Some((idx, slot.value));
                break;
            }
        }
        let (i, value) = hit.ok_or(HashmapError::NotFound)?;
        self.slots[i] = Slot {
            key: EMPTY_KEY,
            value: EMPTY_VALUE,
        };
        self.len -= 1;
        // Backward-shift: walk forward and rehash any displaced entries
        // so later lookups still find them.
        let mut cursor = (i + 1) & mask;
        while self.slots[cursor].key != EMPTY_KEY {
            let displaced = self.slots[cursor];
            self.slots[cursor] = Slot {
                key: EMPTY_KEY,
                value: EMPTY_VALUE,
            };
            self.len -= 1;
            self.insert(displaced.key, displaced.value)?;
            cursor = (cursor + 1) & mask;
        }
        Ok(value)
    }

    pub(crate) fn backend_is_wgsl() -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    use super::{codes, HashmapError, HashmapStatus};

    /// Every `HashmapError` code must be unique and must appear, with the
    /// same value, as a `HASHMAP_ERR_*` constant in the WGSL lowering source;
    /// every `HashmapStatus` discriminant must equal its `codes` constant.
    #[test]
    fn error_codes_are_one_to_one_with_wgsl_constants() {
        let wgsl_constants = [
            ("HASHMAP_ERR_INVALID_CAPACITY", HashmapError::InvalidCapacity),
            ("HASHMAP_ERR_RESERVED_KEY", HashmapError::ReservedKey),
            ("HASHMAP_ERR_OVERFLOW", HashmapError::Overflow),
            ("HASHMAP_ERR_NOT_FOUND", HashmapError::NotFound),
        ];
        let mut seen = std::collections::BTreeSet::new();
        for (name, error) in wgsl_constants {
            let code = error.code();
            assert!(seen.insert(code), "duplicate hashmap error code {code}");
            let declaration = format!("const {name}: u32 = {code}u;");
            assert!(
                crate::ops::workgroup::primitives::hashmap::lowering::WGSL.contains(&declaration),
                "WGSL hashmap error constant {name} must equal Rust code {code}"
            );
        }

        let status_codes = [
            (HashmapStatus::Ok, codes::OK),
            (HashmapStatus::Overflow, codes::ERR_OVERFLOW),
            (HashmapStatus::NotFound, codes::ERR_NOT_FOUND),
            (HashmapStatus::Replaced, codes::REPLACED),
        ];
        for (status, expected) in status_codes {
            assert_eq!(status as u32, expected);
        }
    }
}