Skip to main content

vyre_driver/
subgroup.rs

1//! Backend-neutral subgroup operation taxonomy.
2
/// Backend-independent identifier for a subgroup intrinsic.
///
/// The enum is `#[non_exhaustive]`, so matches written outside this crate must
/// keep a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum SubgroupOp {
    /// Copy the value held by one lane to every lane of the subgroup.
    Broadcast,
    /// Subgroup-wide sum reduction.
    Add,
    /// Subgroup-wide maximum reduction.
    Max,
    /// Subgroup-wide minimum reduction.
    Min,
    /// Inclusive prefix sum over the subgroup.
    InclusiveAdd,
    /// Exclusive prefix sum over the subgroup.
    ExclusiveAdd,
    /// Butterfly exchange performed with shuffle-xor.
    ShuffleXor,
}

impl SubgroupOp {
    /// Backing table for [`SubgroupOp::all`], listed in declaration order.
    const EVERY: &'static [Self] = &[
        Self::Broadcast,
        Self::Add,
        Self::Max,
        Self::Min,
        Self::InclusiveAdd,
        Self::ExclusiveAdd,
        Self::ShuffleXor,
    ];

    /// Return every canonical operation as a static slice, in declaration order.
    #[must_use]
    pub const fn all() -> &'static [SubgroupOp] {
        Self::EVERY
    }
}
38
/// Subgroup capabilities of a backend, consumed by validation and optimizers.
///
/// The `Default` value reports no subgroup support and an unknown lane count.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SubgroupCaps {
    /// Whether native subgroup operations are available for compute.
    pub supports_subgroup: bool,
    /// Whether subgroup operations are available in vertex-stage contexts.
    pub supports_subgroup_vertex: bool,
    /// Lane count of one subgroup; `0` means the size is unknown.
    pub subgroup_size: u32,
}

impl SubgroupCaps {
    /// Build a record that advertises native subgroup intrinsics.
    ///
    /// Vertex-stage support is always reported as unavailable, and
    /// `subgroup_size` is stored as given (including `0`).
    #[must_use]
    pub const fn native(subgroup_size: u32) -> Self {
        Self {
            subgroup_size,
            supports_subgroup: true,
            supports_subgroup_vertex: false,
        }
    }

    /// Build a record from a feature bit and the reported lane-size range.
    ///
    /// Subgroups count as supported only when `supports_feature` is set and
    /// the range is well formed (`0 < min_size <= max_size`). In that case the
    /// recorded size is `min_size` and vertex support mirrors
    /// `supports_vertex_stage`. Otherwise every field is cleared.
    #[must_use]
    pub const fn from_feature_range(
        supports_feature: bool,
        supports_vertex_stage: bool,
        min_size: u32,
        max_size: u32,
    ) -> Self {
        let range_is_valid = min_size != 0 && min_size <= max_size;
        if !(supports_feature && range_is_valid) {
            // Without compute support neither vertex support nor a size is reported.
            return Self {
                supports_subgroup: false,
                supports_subgroup_vertex: false,
                subgroup_size: 0,
            };
        }
        Self {
            supports_subgroup: true,
            supports_subgroup_vertex: supports_vertex_stage,
            subgroup_size: min_size,
        }
    }

    /// Report whether native subgroup operations can actually be used.
    ///
    /// Requires both the support flag and a known (non-zero) subgroup size.
    #[must_use]
    pub const fn is_usable(self) -> bool {
        if self.subgroup_size == 0 {
            return false;
        }
        self.supports_subgroup
    }
}
83
84/// Canonical lane offsets for a power-of-two full-subgroup tree reduction.
85#[must_use]
86pub fn reduction_offsets(subgroup_size: u32) -> Vec<u32> {
87    let mut offsets = Vec::new();
88    reduction_offsets_into(subgroup_size, &mut offsets);
89    offsets
90}
91
92/// Fallible canonical lane offsets for a full-subgroup tree reduction.
93///
94/// # Errors
95///
96/// Returns an error when the requested subgroup width cannot be rounded to a
97/// power-of-two reduction width or the output vector cannot reserve storage.
98pub fn try_reduction_offsets(subgroup_size: u32) -> Result<Vec<u32>, String> {
99    let mut offsets = Vec::new();
100    try_reduction_offsets_into(subgroup_size, &mut offsets)?;
101    Ok(offsets)
102}
103
104/// Write canonical reduction offsets into caller-owned storage.
105pub fn reduction_offsets_into(subgroup_size: u32, offsets: &mut Vec<u32>) {
106    if try_reduction_offsets_into(subgroup_size, offsets).is_err() {
107        offsets.clear();
108    }
109}
110
111/// Fallibly write canonical reduction offsets into caller-owned storage.
112///
113/// # Errors
114///
115/// Returns an error when the subgroup width overflows power-of-two rounding or
116/// the output vector cannot reserve the required offsets.
117pub fn try_reduction_offsets_into(
118    subgroup_size: u32,
119    offsets: &mut Vec<u32>,
120) -> Result<(), String> {
121    offsets.clear();
122    let Some(rounded_width) = subgroup_size.checked_next_power_of_two() else {
123        return Err(format!(
124            "subgroup reduction width {subgroup_size} cannot be rounded to a power of two. Fix: clamp subgroup size to a valid backend-reported hardware width."
125        ));
126    };
127    let offset_count = if subgroup_size == 0 {
128        0
129    } else {
130        rounded_width.ilog2() as usize
131    };
132    crate::allocation::try_reserve_vec_to_capacity(offsets, offset_count).map_err(|error| {
133        format!(
134            "subgroup reduction offsets could not reserve {offset_count} slot(s): {error}. Fix: reuse caller-owned offset storage or clamp subgroup size."
135        )
136    })?;
137    let mut width = rounded_width / 2;
138    while width > 0 {
139        offsets.push(width);
140        width /= 2;
141    }
142    Ok(())
143}
144
#[cfg(test)]
mod tests {
    //! Unit tests for the subgroup taxonomy, capability records, and
    //! reduction-offset helpers, including boundary widths.

    use super::*;

    #[test]
    fn all_enumerates_seven_ops() {
        assert_eq!(SubgroupOp::all().len(), 7);
    }

    #[test]
    fn try_reduction_offsets_reuses_storage() {
        let mut offsets = Vec::with_capacity(8);
        let ptr = offsets.as_ptr();

        try_reduction_offsets_into(32, &mut offsets).unwrap();

        assert_eq!(offsets, [16, 8, 4, 2, 1]);
        // Five offsets fit in the pre-reserved capacity, so the buffer must not move.
        assert_eq!(offsets.as_ptr(), ptr);
    }

    #[test]
    fn try_reduction_offsets_rejects_overflowing_rounding() {
        let error = try_reduction_offsets(u32::MAX).unwrap_err();
        assert!(error.contains("cannot be rounded to a power of two"));
    }

    #[test]
    fn legacy_reduction_offset_wrapper_clears_invalid_width() {
        let mut offsets = vec![16, 8, 4];

        reduction_offsets_into(u32::MAX, &mut offsets);

        assert!(offsets.is_empty());
        assert!(reduction_offsets(u32::MAX).is_empty());
    }

    #[test]
    fn reduction_offsets_are_empty_for_degenerate_widths() {
        // Widths 0 and 1 both round to a single lane, which needs no reduction step.
        assert!(reduction_offsets(0).is_empty());
        assert!(reduction_offsets(1).is_empty());
    }

    #[test]
    fn reduction_offsets_round_non_power_of_two_widths_up() {
        // 48 rounds up to 64, so the first step is 32.
        assert_eq!(reduction_offsets(48), [32, 16, 8, 4, 2, 1]);
    }

    #[test]
    fn try_reduction_offsets_accepts_largest_power_of_two() {
        let offsets = try_reduction_offsets(1 << 31).unwrap();

        assert_eq!(offsets.len(), 31);
        assert_eq!(offsets.first().copied(), Some(1 << 30));
        assert_eq!(offsets.last().copied(), Some(1));
    }

    #[test]
    fn try_reduction_offsets_rejects_width_just_past_largest_power_of_two() {
        // The next power of two above 2^31 does not fit in a u32.
        assert!(try_reduction_offsets((1_u32 << 31) + 1).is_err());
    }

    #[test]
    fn try_reduction_offsets_into_discards_stale_contents() {
        let mut offsets = vec![99, 98, 97];

        try_reduction_offsets_into(4, &mut offsets).unwrap();

        assert_eq!(offsets, [2, 1]);
    }

    #[test]
    fn native_caps_report_compute_support_only() {
        let caps = SubgroupCaps::native(32);

        assert!(caps.supports_subgroup);
        assert!(!caps.supports_subgroup_vertex);
        assert_eq!(caps.subgroup_size, 32);
        assert!(caps.is_usable());
        // An unknown (zero) size keeps the flag but is not usable.
        assert!(!SubgroupCaps::native(0).is_usable());
    }

    #[test]
    fn from_feature_range_uses_min_size_for_valid_ranges() {
        let caps = SubgroupCaps::from_feature_range(true, true, 32, 64);

        assert_eq!(
            caps,
            SubgroupCaps {
                supports_subgroup: true,
                supports_subgroup_vertex: true,
                subgroup_size: 32,
            }
        );
        assert!(caps.is_usable());
    }

    #[test]
    fn from_feature_range_clears_everything_for_invalid_inputs() {
        let unsupported = SubgroupCaps::default();

        // Missing feature bit, zero minimum, and inverted range all disable support,
        // including vertex-stage support.
        assert_eq!(SubgroupCaps::from_feature_range(false, true, 32, 64), unsupported);
        assert_eq!(SubgroupCaps::from_feature_range(true, true, 0, 64), unsupported);
        assert_eq!(SubgroupCaps::from_feature_range(true, true, 64, 32), unsupported);
        assert!(!unsupported.is_usable());
    }
}