use crate::ir::DataType;
use crate::ops::{AlgebraicLaw, Backend, IntrinsicDescriptor, OpSpec};
/// Input signature of the op: three `U32` values.
// NOTE(review): which operand each slot carries is not stated in this file — confirm against the kernel.
pub const INPUTS: &[DataType] = &[DataType::U32, DataType::U32, DataType::U32];
/// Output signature of the op: two `U32` values.
// NOTE(review): presumably a result value plus a status code — confirm against the kernel.
pub const OUTPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Algebraic laws declared for the op; the list is empty.
pub const LAWS: &[AlgebraicLaw] = &[];
/// Op specification for `"workgroup.union_find"`, built with `OpSpec::intrinsic`.
///
/// It ties together the `INPUTS` / `OUTPUTS` signatures, the (empty) `LAWS` list,
/// the `wgsl_only` backend predicate and an `IntrinsicDescriptor`.
pub const SPEC: OpSpec = OpSpec::intrinsic(
"workgroup.union_find",
INPUTS,
OUTPUTS,
LAWS,
// Backend predicate: `wgsl_only` returns true only for `Backend::Wgsl`.
wgsl_only,
IntrinsicDescriptor::new(
// NOTE(review): presumably the kernel entry-point name — confirm against `IntrinsicDescriptor::new`.
"workgroup_union_find_kernel",
// NOTE(review): presumably a tag naming the required hardware capability — confirm.
"workgroup-sram-atomic-parent",
// NOTE(review): presumably the CPU-side implementation hook for this intrinsic — confirm.
crate::ops::cpu_op::structured_intrinsic_cpu,
),
);
/// Status reported by a successful [`WorkgroupUnionFind::union`] call.
///
/// The explicit discriminants give every variant a fixed integer value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnionFindStatus {
    /// Two distinct sets were merged.
    Ok = 0,
    /// An index was outside the structure.
    ///
    /// NOTE(review): nothing in this file returns this variant — out-of-range
    /// indices surface as [`UnionFindError::OutOfRange`] instead. Presumably it
    /// is kept for parity with an external status encoding — confirm.
    OutOfRange = 1,
    /// Both elements already shared a representative; nothing was changed.
    AlreadyUnified = 2,
}

/// Error returned by the fallible [`WorkgroupUnionFind`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnionFindError {
    /// An element index was greater than or equal to the number of elements.
    OutOfRange,
}

/// Disjoint-set forest over the elements `0..len`, stored in two `Vec<u32>`s.
///
/// Lookups compress paths and merges are ordered by rank.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkgroupUnionFind {
    /// `parent[i]` is the parent of element `i`; a root is its own parent.
    parent: Vec<u32>,
    /// Per-element rank, consulted only for roots when merging two sets.
    rank: Vec<u32>,
}

impl WorkgroupUnionFind {
    /// Creates `n` singleton sets, one per element `0..n`.
    ///
    /// # Panics
    ///
    /// Panics if `n` does not fit in a `u32`, because elements are addressed
    /// by `u32` indices.
    #[must_use]
    pub fn new(n: usize) -> Self {
        // Checked up front, before allocating: a bare `n as u32` would wrap
        // silently and leave `parent` shorter than `rank`.
        assert!(
            n <= u32::MAX as usize,
            "WorkgroupUnionFind::new: element count must fit in u32"
        );
        Self {
            // Lossless cast thanks to the assertion above.
            parent: (0..n as u32).collect(),
            rank: vec![0u32; n],
        }
    }

    /// Returns the number of elements tracked by the structure.
    #[must_use]
    pub fn len(&self) -> usize {
        self.parent.len()
    }

    /// Returns `true` when the structure tracks no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.parent.is_empty()
    }

    /// Returns the representative (root) of the set containing `x`.
    ///
    /// Every element visited on the way to the root is re-pointed directly at
    /// the root, which is why this takes `&mut self`.
    ///
    /// # Errors
    ///
    /// Returns [`UnionFindError::OutOfRange`] if `x >= self.len()`.
    pub fn find(&mut self, x: u32) -> Result<u32, UnionFindError> {
        if x as usize >= self.parent.len() {
            return Err(UnionFindError::OutOfRange);
        }

        // Pass 1: climb parent links until reaching a self-parented node.
        let mut root = x;
        while self.parent[root as usize] != root {
            root = self.parent[root as usize];
        }

        // Pass 2: walk the same path again, pointing each node at the root.
        // The loop stops at the first node whose parent is already the root
        // (which includes the root itself).
        let mut cursor = x;
        while self.parent[cursor as usize] != root {
            let next = self.parent[cursor as usize];
            self.parent[cursor as usize] = root;
            cursor = next;
        }

        Ok(root)
    }

    /// Merges the sets containing `a` and `b`.
    ///
    /// Returns [`UnionFindStatus::AlreadyUnified`] when both elements already
    /// share a root, otherwise [`UnionFindStatus::Ok`] after linking the
    /// lower-ranked root under the higher-ranked one. On equal ranks the root
    /// with the smaller index becomes the parent and its rank is incremented.
    ///
    /// # Errors
    ///
    /// Returns [`UnionFindError::OutOfRange`] if either index is
    /// `>= self.len()`. `a` is looked up first, so its path may already have
    /// been compressed when `b` turns out to be invalid; set membership is
    /// unaffected.
    pub fn union(&mut self, a: u32, b: u32) -> Result<UnionFindStatus, UnionFindError> {
        let ra = self.find(a)?;
        let rb = self.find(b)?;
        if ra == rb {
            return Ok(UnionFindStatus::AlreadyUnified);
        }

        let rank_a = self.rank[ra as usize];
        let rank_b = self.rank[rb as usize];
        let (child, parent) = if rank_a < rank_b {
            (ra, rb)
        } else if rank_a > rank_b {
            (rb, ra)
        } else {
            // Tie: the lower-indexed root wins, so the result does not depend
            // on argument order; the surviving root's rank goes up by one.
            let (child, parent) = if ra < rb { (rb, ra) } else { (ra, rb) };
            self.rank[parent as usize] += 1;
            (child, parent)
        };
        self.parent[child as usize] = parent;
        Ok(UnionFindStatus::Ok)
    }

    /// Reports whether `a` and `b` currently belong to the same set.
    ///
    /// Takes `&mut self` because the underlying [`find`](Self::find) calls
    /// compress paths.
    ///
    /// # Errors
    ///
    /// Returns [`UnionFindError::OutOfRange`] if either index is
    /// `>= self.len()`.
    pub fn connected(&mut self, a: u32, b: u32) -> Result<bool, UnionFindError> {
        Ok(self.find(a)? == self.find(b)?)
    }
}
/// Backend predicate: returns `true` only when `backend` is `Backend::Wgsl`.
pub fn wgsl_only(backend: &Backend) -> bool {
    match backend {
        Backend::Wgsl => true,
        _ => false,
    }
}