vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
//! Specification and CPU reference for `workgroup.queue_priority`.

use crate::ir::DataType;
use crate::ops::{AlgebraicLaw, Backend, IntrinsicDescriptor, OpSpec};

/// Input data types handed to [`SPEC`]: two `DataType::U32` entries.
///
/// NOTE(review): presumably the `(value, priority)` pair accepted by
/// `WorkgroupPriorityQueue::push` — confirm against the backend lowering.
pub const INPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Output data types handed to [`SPEC`]: two `DataType::U32` entries.
///
/// NOTE(review): presumably the `(value, priority)` pair of a `PriorityItem`
/// read back from the queue — confirm against the backend lowering.
pub const OUTPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Algebraic laws handed to [`SPEC`]; the list is empty for this operation.
pub const LAWS: &[AlgebraicLaw] = &[];

/// Operation specification for `workgroup.queue_priority`.
///
/// Built with [`OpSpec::intrinsic`] from this module's [`INPUTS`], [`OUTPUTS`]
/// and [`LAWS`] constants, the [`wgsl_only`] backend predicate, and an
/// [`IntrinsicDescriptor`] whose CPU entry is
/// `crate::ops::cpu_op::structured_intrinsic_cpu`.
pub const SPEC: OpSpec = OpSpec::intrinsic(
    "workgroup.queue_priority",
    INPUTS,
    OUTPUTS,
    LAWS,
    // Backend predicate: `wgsl_only` returns true only for `Backend::Wgsl`.
    wgsl_only,
    // NOTE(review): the two strings look like an intrinsic name and a hardware
    // unit tag — confirm against the `IntrinsicDescriptor::new` signature.
    IntrinsicDescriptor::new(
        "workgroup_queue_priority_push",
        "workgroup-sram-binary-heap",
        crate::ops::cpu_op::structured_intrinsic_cpu,
    ),
);
/// Value and priority returned by priority queue reads.
///
/// This is also the element type stored inside the queue's backing heap.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PriorityItem<T> {
    /// Stored payload. Also acts as the tie-breaker between items of equal
    /// priority: the larger value pops first.
    pub value: T,
    /// Max-heap priority. Larger values pop first.
    pub priority: u32,
}
/// Error returned by fallible priority queue operations.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so callers can
/// propagate it with `?` into boxed or wrapping error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PriorityError {
    /// The bounded heap has no remaining capacity.
    Overflow,
    /// The heap contains no value to read.
    Underflow,
}

impl std::fmt::Display for PriorityError {
    /// Write a short human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Overflow => {
                "priority queue overflow: the bounded heap has no remaining capacity"
            }
            Self::Underflow => "priority queue underflow: the heap contains no value to read",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PriorityError {}
/// Bounded binary max-heap used as the CPU reference.
///
/// Items are ordered by `priority`, with equal priorities ordered by the
/// larger `value`. The queue never holds more than `capacity` items; pushing
/// into a full queue reports [`PriorityError::Overflow`].
///
/// The derived equality compares the capacity and the internal heap layout,
/// not just the set of stored items.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkgroupPriorityQueue<T> {
    // Maximum number of items `push` will accept.
    capacity: usize,
    // Implicit binary heap: slot 0 holds the maximum and the children of
    // slot `i` live at `2 * i + 1` and `2 * i + 2`.
    heap: Vec<PriorityItem<T>>,
}
impl<T: Copy + Ord> WorkgroupPriorityQueue<T> {
    /// Build an empty queue that accepts at most `capacity` items.
    ///
    /// Backing storage for `capacity` items is reserved up front.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        let heap = Vec::with_capacity(capacity);
        Self { capacity, heap }
    }

    /// Insert a `(value, priority)` pair and restore the heap order.
    ///
    /// # Errors
    ///
    /// Returns [`PriorityError::Overflow`] when the queue already holds
    /// `capacity` items; the queue is left unchanged in that case.
    pub fn push(&mut self, value: T, priority: u32) -> Result<(), PriorityError> {
        let slot = self.heap.len();
        if slot >= self.capacity {
            return Err(PriorityError::Overflow);
        }
        // The new item lands in the last slot and bubbles up from there.
        self.heap.push(PriorityItem { value, priority });
        self.sift_up(slot);
        Ok(())
    }

    /// Remove and return the item that currently ranks highest.
    ///
    /// # Errors
    ///
    /// Returns [`PriorityError::Underflow`] when the queue holds no items.
    pub fn pop_max(&mut self) -> Result<PriorityItem<T>, PriorityError> {
        if self.heap.is_empty() {
            return Err(PriorityError::Underflow);
        }
        // `swap_remove` moves the last item into the root slot and hands back
        // the old root in one step.
        let top = self.heap.swap_remove(0);
        if !self.heap.is_empty() {
            self.sift_down(0);
        }
        Ok(top)
    }

    /// Return a copy of the highest-ranked item, leaving the queue untouched.
    ///
    /// # Errors
    ///
    /// Returns [`PriorityError::Underflow`] when the queue holds no items.
    pub fn peek_max(&self) -> Result<PriorityItem<T>, PriorityError> {
        match self.heap.first() {
            Some(top) => Ok(*top),
            None => Err(PriorityError::Underflow),
        }
    }

    /// Number of items currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// Whether the queue currently stores no items.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Strict ordering used by the heap: `left` outranks `right` when its
    /// priority is larger, or when priorities tie and its value is larger.
    pub(crate) fn higher_priority(left: PriorityItem<T>, right: PriorityItem<T>) -> bool {
        match left.priority.cmp(&right.priority) {
            std::cmp::Ordering::Greater => true,
            std::cmp::Ordering::Less => false,
            std::cmp::Ordering::Equal => left.value > right.value,
        }
    }

    /// Move the item at `index` towards the root until its parent is not
    /// outranked by it.
    pub(crate) fn sift_up(&mut self, mut index: usize) {
        // The root (index 0) has no parent, which ends the walk.
        while let Some(parent) = index.checked_sub(1).map(|below| below / 2) {
            if Self::higher_priority(self.heap[index], self.heap[parent]) {
                self.heap.swap(index, parent);
                index = parent;
            } else {
                break;
            }
        }
    }

    /// Move the item at `index` towards the leaves until neither child
    /// outranks it.
    pub(crate) fn sift_down(&mut self, mut index: usize) {
        // Swapping never changes the length, so it can be read once.
        let live = self.heap.len();
        loop {
            let first_child = index * 2 + 1;
            let mut best = index;
            // Left child is examined before the right one, so the right child
            // only wins when it outranks the better of parent and left child.
            for child in [first_child, first_child + 1] {
                if child < live && Self::higher_priority(self.heap[child], self.heap[best]) {
                    best = child;
                }
            }
            if best == index {
                return;
            }
            self.heap.swap(index, best);
            index = best;
        }
    }
}
/// Backend predicate: returns `true` only when `backend` is [`Backend::Wgsl`].
///
/// Passed to [`OpSpec::intrinsic`] when building [`SPEC`].
pub fn wgsl_only(backend: &Backend) -> bool {
    matches!(backend, Backend::Wgsl)
}