vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
//! Specification and CPU reference for `workgroup.union_find`.

use crate::ir::DataType;
use crate::ops::{AlgebraicLaw, Backend, IntrinsicDescriptor, OpSpec};

// NOTE(review): the role of each operand/result slot (e.g. command word vs.
// element indices) is not visible in this file — confirm against the lowering.

/// Operand types of `workgroup.union_find`: three `u32` values.
pub const INPUTS: &[DataType] = &[
    DataType::U32, // operand 0
    DataType::U32, // operand 1
    DataType::U32, // operand 2
];

/// Result types of `workgroup.union_find`: two `u32` values.
pub const OUTPUTS: &[DataType] = &[
    DataType::U32, // result 0
    DataType::U32, // result 1
];

/// Algebraic laws declared for this op — deliberately none.
pub const LAWS: &[AlgebraicLaw] = &[];

/// Specification record for the `workgroup.union_find` intrinsic.
///
/// Bundles the op name, the operand/result signatures above, the (empty) law
/// set, the `wgsl_only` backend filter, and an intrinsic descriptor. The
/// descriptor carries the strings `"workgroup_union_find_kernel"` and
/// `"workgroup-sram-atomic-parent"` — presumably the kernel entry-point name
/// and the lowering strategy tag (TODO confirm against `IntrinsicDescriptor`)
/// — and uses `structured_intrinsic_cpu` as the CPU fallback.
pub const SPEC: OpSpec = OpSpec::intrinsic(
    "workgroup.union_find",
    INPUTS,
    OUTPUTS,
    LAWS,
    wgsl_only,
    IntrinsicDescriptor::new(
        "workgroup_union_find_kernel",
        "workgroup-sram-atomic-parent",
        crate::ops::cpu_op::structured_intrinsic_cpu,
    ),
);
/// Union-find command status word shared by the CPU oracle and WGSL lowering.
///
/// Discriminants are pinned explicitly so that `status as u32` yields a stable
/// numeric code. The CPU reference in this module reports bad indices through
/// [`UnionFindError::OutOfRange`]; it never returns
/// [`UnionFindStatus::OutOfRange`] itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnionFindStatus {
    /// Operation completed.
    Ok = 0,
    /// The element index was out of range.
    OutOfRange = 1,
    /// Union was a no-op because both elements shared a representative.
    AlreadyUnified = 2,
}

/// Error returned by fallible union-find operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnionFindError {
    /// Element index is outside the declared universe.
    OutOfRange,
}

impl std::fmt::Display for UnionFindError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::OutOfRange => f.write_str("union-find element index is out of range"),
        }
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for UnionFindError {}

/// Bounded disjoint-set data structure used as the CPU reference.
///
/// Invariant: `parent.len() == rank.len()` and every element index fits in a
/// `u32`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkgroupUnionFind {
    /// `parent[i]` is the parent of element `i`; an element whose parent is
    /// itself is the root (representative) of its set.
    parent: Vec<u32>,
    /// Union-by-rank bookkeeping; only read and updated at root elements.
    rank: Vec<u32>,
}

impl WorkgroupUnionFind {
    /// Create a fresh DSU with `n` singleton sets.
    ///
    /// # Panics
    ///
    /// Panics if `n > u32::MAX`, because element indices are stored as `u32`.
    #[must_use]
    pub fn new(n: usize) -> Self {
        // Checked conversion: a plain `n as u32` would truncate and leave
        // `parent` shorter than `rank`, breaking the struct invariant.
        let count = u32::try_from(n).unwrap_or_else(|_| {
            panic!(
                "WorkgroupUnionFind::new: universe size {} exceeds u32::MAX",
                n
            )
        });
        let parent: Vec<u32> = (0..count).collect();
        let rank = vec![0u32; n];
        Self { parent, rank }
    }

    /// Number of elements in the universe.
    #[must_use]
    pub fn len(&self) -> usize {
        self.parent.len()
    }

    /// Whether the DSU is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.parent.is_empty()
    }

    /// Return the representative of the set containing `x`, applying
    /// path compression so subsequent lookups are amortized O(α(n)).
    ///
    /// # Errors
    ///
    /// Returns [`UnionFindError::OutOfRange`] when `x >= len()`.
    pub fn find(&mut self, x: u32) -> Result<u32, UnionFindError> {
        if x as usize >= self.parent.len() {
            return Err(UnionFindError::OutOfRange);
        }
        // Pass 1: walk parent links up to the root.
        let mut root = x;
        while self.parent[root as usize] != root {
            root = self.parent[root as usize];
        }
        // Pass 2: re-walk the same path, pointing every node directly at the
        // root. Iterative to avoid recursion depth limits.
        let mut cursor = x;
        while self.parent[cursor as usize] != root {
            let next = self.parent[cursor as usize];
            self.parent[cursor as usize] = root;
            cursor = next;
        }
        Ok(root)
    }

    /// Union the sets containing `a` and `b`. Uses union-by-rank so the
    /// resulting tree height is logarithmic.
    ///
    /// Returns [`UnionFindStatus::AlreadyUnified`] when `a` and `b` already
    /// share a representative, otherwise [`UnionFindStatus::Ok`].
    ///
    /// # Errors
    ///
    /// Returns [`UnionFindError::OutOfRange`] when either index is out
    /// of range.
    pub fn union(&mut self, a: u32, b: u32) -> Result<UnionFindStatus, UnionFindError> {
        let ra = self.find(a)?;
        let rb = self.find(b)?;
        if ra == rb {
            return Ok(UnionFindStatus::AlreadyUnified);
        }
        // `child` is the root that gets attached; `new_root` stays a root.
        let (child, new_root) = match self.rank[ra as usize].cmp(&self.rank[rb as usize]) {
            std::cmp::Ordering::Less => (ra, rb),
            std::cmp::Ordering::Greater => (rb, ra),
            std::cmp::Ordering::Equal => {
                // Equal ranks: the smaller index becomes the root so both CPU
                // and GPU paths agree on the canonical parent.
                let (child, new_root) = (ra.max(rb), ra.min(rb));
                self.rank[new_root as usize] += 1;
                (child, new_root)
            }
        };
        self.parent[child as usize] = new_root;
        Ok(UnionFindStatus::Ok)
    }

    /// Whether `a` and `b` are currently in the same set.
    ///
    /// Takes `&mut self` because the underlying [`Self::find`] compresses
    /// paths.
    ///
    /// # Errors
    ///
    /// Returns [`UnionFindError::OutOfRange`] when either index is out
    /// of range.
    pub fn connected(&mut self, a: u32, b: u32) -> Result<bool, UnionFindError> {
        Ok(self.find(a)? == self.find(b)?)
    }
}
/// Backend filter passed to `SPEC`: returns `true` only for [`Backend::Wgsl`].
///
/// Presumably this gates which backends may lower `workgroup.union_find`
/// natively — TODO confirm against `OpSpec::intrinsic`.
#[must_use]
pub fn wgsl_only(backend: &Backend) -> bool {
    matches!(*backend, Backend::Wgsl)
}