use crate::ir::DataType;
use crate::ops::{AlgebraicLaw, Backend, IntrinsicDescriptor, OpSpec};
/// Input signature of the `workgroup.queue_priority` op: two `u32` operands.
///
/// NOTE(review): presumably `(value, priority)`, mirroring
/// `WorkgroupPriorityQueue::push(value, priority)` — confirm against the WGSL lowering.
pub const INPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Output signature of the `workgroup.queue_priority` op: two `u32` results.
///
/// NOTE(review): the meaning of each output is not visible in this file — confirm.
pub const OUTPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Algebraic laws for this op: none are declared.
pub const LAWS: &[AlgebraicLaw] = &[];
/// Intrinsic op spec for `workgroup.queue_priority`.
///
/// Built with `OpSpec::intrinsic` from [`INPUTS`], [`OUTPUTS`], [`LAWS`] and the
/// [`wgsl_only`] backend predicate, plus an `IntrinsicDescriptor`.
///
/// NOTE(review): the argument roles below are inferred from names only and should be
/// confirmed against `OpSpec::intrinsic` / `IntrinsicDescriptor::new`:
/// - `"workgroup_queue_priority_push"` looks like the emitted WGSL entry name,
/// - `"workgroup-sram-binary-heap"` looks like a lowering/strategy tag,
/// - `structured_intrinsic_cpu` looks like the CPU reference implementation.
pub const SPEC: OpSpec = OpSpec::intrinsic(
"workgroup.queue_priority",
INPUTS,
OUTPUTS,
LAWS,
wgsl_only,
IntrinsicDescriptor::new(
"workgroup_queue_priority_push",
"workgroup-sram-binary-heap",
crate::ops::cpu_op::structured_intrinsic_cpu,
),
);
/// A value stored in a [`WorkgroupPriorityQueue`] together with its priority.
///
/// Inside the queue, items are ordered by `priority` first (larger wins). Ties are
/// broken by `value` (larger wins), so the pop order is fully deterministic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PriorityItem<T> {
    /// Payload carried by this entry.
    pub value: T,
    /// Priority key; larger priorities are popped first.
    pub priority: u32,
}

/// Errors returned by [`WorkgroupPriorityQueue`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PriorityError {
    /// `push` was called while the queue already held `capacity` items.
    Overflow,
    /// `pop_max` or `peek_max` was called on an empty queue.
    Underflow,
}

impl std::fmt::Display for PriorityError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::Overflow => "workgroup priority queue overflow: capacity exhausted",
            Self::Underflow => "workgroup priority queue underflow: queue is empty",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate queue errors through `Box<dyn Error>` / `?`.
impl std::error::Error for PriorityError {}

/// Fixed-capacity binary max-heap of [`PriorityItem`]s.
///
/// Items are kept in `heap` as an implicit binary tree. The children of slot `i` are
/// `2i + 1` and `2i + 2`, and `heap[0]` is always the highest-priority item (see
/// `higher_priority` for the ordering).
///
/// NOTE(review): presumably a host-side model of the `workgroup.queue_priority`
/// intrinsic's heap — confirm against the CPU reference implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkgroupPriorityQueue<T> {
    /// Maximum number of items; `push` fails with `Overflow` once reached.
    capacity: usize,
    /// Heap-ordered storage.
    heap: Vec<PriorityItem<T>>,
}

// Operations that never compare or copy items need no bounds on `T`.
impl<T> WorkgroupPriorityQueue<T> {
    /// Creates an empty queue that can hold at most `capacity` items.
    ///
    /// Storage for `capacity` items is reserved up front. Because `push` refuses to
    /// grow past `capacity`, the backing vector never reallocates.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            capacity,
            heap: Vec::with_capacity(capacity),
        }
    }

    /// Maximum number of items this queue can hold.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of items currently queued.
    #[must_use]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// Returns `true` if no items are queued.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }
}

impl<T: Copy + Ord> WorkgroupPriorityQueue<T> {
    /// Inserts `value` with the given `priority`.
    ///
    /// # Errors
    ///
    /// Returns [`PriorityError::Overflow`] if the queue already holds `capacity`
    /// items. The queue is left unchanged in that case.
    pub fn push(&mut self, value: T, priority: u32) -> Result<(), PriorityError> {
        if self.heap.len() >= self.capacity {
            return Err(PriorityError::Overflow);
        }
        self.heap.push(PriorityItem { value, priority });
        // The new leaf may outrank its ancestors; bubble it up.
        self.sift_up(self.heap.len() - 1);
        Ok(())
    }

    /// Removes and returns the highest-priority item.
    ///
    /// # Errors
    ///
    /// Returns [`PriorityError::Underflow`] if the queue is empty.
    pub fn pop_max(&mut self) -> Result<PriorityItem<T>, PriorityError> {
        if self.heap.is_empty() {
            return Err(PriorityError::Underflow);
        }
        // `swap_remove` hands back the old root and moves the last leaf into slot 0
        // in one step. That leaf then sinks to restore heap order. `sift_down(0)` on
        // a now-empty heap is a no-op, so no extra emptiness check is needed.
        let max = self.heap.swap_remove(0);
        self.sift_down(0);
        Ok(max)
    }

    /// Returns a copy of the highest-priority item without removing it.
    ///
    /// # Errors
    ///
    /// Returns [`PriorityError::Underflow`] if the queue is empty.
    pub fn peek_max(&self) -> Result<PriorityItem<T>, PriorityError> {
        self.heap.first().copied().ok_or(PriorityError::Underflow)
    }

    /// Returns `true` if `left` must sit above `right` in the heap.
    ///
    /// A larger `priority` wins. On equal priority, a larger `value` wins.
    pub(crate) fn higher_priority(left: PriorityItem<T>, right: PriorityItem<T>) -> bool {
        left.priority > right.priority
            || (left.priority == right.priority && left.value > right.value)
    }

    /// Moves the item at `index` toward the root until its parent outranks or ties it.
    ///
    /// # Panics
    ///
    /// Panics if `index > 0` and `index >= self.len()`.
    pub(crate) fn sift_up(&mut self, mut index: usize) {
        while index > 0 {
            let parent = (index - 1) / 2;
            if !Self::higher_priority(self.heap[index], self.heap[parent]) {
                break;
            }
            self.heap.swap(index, parent);
            index = parent;
        }
    }

    /// Moves the item at `index` toward the leaves until neither child outranks it.
    ///
    /// An out-of-range `index` (including any index on an empty heap) is a no-op.
    pub(crate) fn sift_down(&mut self, mut index: usize) {
        let len = self.heap.len();
        loop {
            let left = index * 2 + 1;
            let right = left + 1;
            let mut best = index;
            if left < len && Self::higher_priority(self.heap[left], self.heap[best]) {
                best = left;
            }
            if right < len && Self::higher_priority(self.heap[right], self.heap[best]) {
                best = right;
            }
            if best == index {
                break;
            }
            self.heap.swap(index, best);
            index = best;
        }
    }
}
/// Backend filter used by [`SPEC`]: accepts only [`Backend::Wgsl`].
///
/// Every other backend variant yields `false`.
pub fn wgsl_only(backend: &Backend) -> bool {
    // Matching on the dereferenced place binds nothing, so no move/copy is required.
    matches!(*backend, Backend::Wgsl)
}