#![allow(dead_code)]
use crate::device_caps::DeviceCapabilities;
use std::fmt;
/// Size of a compute workgroup along the X, Y and Z axes.
///
/// The product of the three extents is the invocation count reported by
/// `total_invocations`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WorkgroupSize {
/// Extent along the X axis.
pub x: u32,
/// Extent along the Y axis (the 1-D constructor sets this to 1).
pub y: u32,
/// Extent along the Z axis (the 1-D and 2-D constructors set this to 1).
pub z: u32,
}
impl WorkgroupSize {
    /// Creates a one-dimensional workgroup: `x` invocations wide, with `y` and `z` set to 1.
    #[must_use]
    pub fn new_1d(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Creates a two-dimensional workgroup of `x` by `y` invocations, with `z` set to 1.
    #[must_use]
    pub fn new_2d(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Creates a three-dimensional workgroup of `x` by `y` by `z` invocations.
    #[must_use]
    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Returns `x * y * z`, saturating at `u32::MAX` instead of overflowing.
    #[must_use]
    pub fn total_invocations(&self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Returns `true` when every axis is within its per-axis limit and the
    /// total invocation count is within `limits.max_workgroup_invocations`.
    ///
    /// The total is compared using the exact product rather than the saturated
    /// value from [`Self::total_invocations`]; otherwise a product that
    /// overflows `u32` would clamp to `u32::MAX` and wrongly pass against a
    /// limit of `u32::MAX`.
    ///
    /// Zero-sized axes are not rejected here: they satisfy every `<=` check.
    #[must_use]
    pub fn fits_within(&self, limits: &DeviceLimits) -> bool {
        // Three u32 factors need up to 96 bits, so u128 holds the product exactly.
        let exact_total = u128::from(self.x) * u128::from(self.y) * u128::from(self.z);

        self.x <= limits.max_workgroup_size_x
            && self.y <= limits.max_workgroup_size_y
            && self.z <= limits.max_workgroup_size_z
            && exact_total <= u128::from(limits.max_workgroup_invocations)
    }
}
/// Formats the workgroup using only as many axes as are needed:
/// `X` when both `y` and `z` are at most 1, `XxY` when only `z` is at most 1,
/// and `XxYxZ` whenever `z` exceeds 1.
impl fmt::Display for WorkgroupSize {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { x, y, z } = *self;
        match (y > 1, z > 1) {
            // A non-trivial Z forces the full three-axis form, whatever Y is.
            (_, true) => write!(f, "{}x{}x{}", x, y, z),
            (true, false) => write!(f, "{}x{}", x, y),
            (false, false) => write!(f, "{}", x),
        }
    }
}
/// Upper bounds used to validate workgroup and dispatch sizes.
///
/// The `max_workgroup_*` fields are checked by `WorkgroupSize::fits_within`;
/// the `max_dispatch_*` fields are checked by `DispatchSize::fits_within`.
#[derive(Debug, Clone, Copy)]
pub struct DeviceLimits {
/// Largest allowed workgroup extent along X.
pub max_workgroup_size_x: u32,
/// Largest allowed workgroup extent along Y.
pub max_workgroup_size_y: u32,
/// Largest allowed workgroup extent along Z.
pub max_workgroup_size_z: u32,
/// Largest allowed total number of invocations in one workgroup.
pub max_workgroup_invocations: u32,
/// Largest allowed number of workgroups dispatched along X.
pub max_dispatch_x: u32,
/// Largest allowed number of workgroups dispatched along Y.
pub max_dispatch_y: u32,
/// Largest allowed number of workgroups dispatched along Z.
pub max_dispatch_z: u32,
}
/// Default limits: workgroups of up to 1024 x 1024 x 64 with at most 1024
/// invocations in total, and up to 65535 workgroups per dispatch axis.
impl Default for DeviceLimits {
    fn default() -> Self {
        // Shared by the X/Y per-axis limits and the total-invocation cap.
        const MAX_PLANAR: u32 = 1024;
        const MAX_DEPTH: u32 = 64;
        const MAX_GRID_EXTENT: u32 = 65_535;

        Self {
            max_workgroup_size_x: MAX_PLANAR,
            max_workgroup_size_y: MAX_PLANAR,
            max_workgroup_size_z: MAX_DEPTH,
            max_workgroup_invocations: MAX_PLANAR,
            max_dispatch_x: MAX_GRID_EXTENT,
            max_dispatch_y: MAX_GRID_EXTENT,
            max_dispatch_z: MAX_GRID_EXTENT,
        }
    }
}
/// Number of workgroups to dispatch along the X, Y and Z axes.
///
/// Produced by the `compute_dispatch_*` helpers, which divide a problem
/// extent by a workgroup extent and round up.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DispatchSize {
/// Workgroup count along X.
pub x: u32,
/// Workgroup count along Y.
pub y: u32,
/// Workgroup count along Z.
pub z: u32,
}
impl DispatchSize {
    /// Returns the total number of workgroups, `x * y * z`.
    ///
    /// The product of three `u32` values can need up to 96 bits, so the result
    /// saturates at `u64::MAX` instead of overflowing (which would panic in
    /// debug builds and wrap in release builds).
    #[must_use]
    pub fn total_workgroups(&self) -> u64 {
        // x * y always fits in u64; only the multiplication by z can overflow.
        u64::from(self.x)
            .saturating_mul(u64::from(self.y))
            .saturating_mul(u64::from(self.z))
    }

    /// Returns the total number of invocations launched when every workgroup
    /// has the size `wg`, saturating at `u64::MAX`.
    ///
    /// The per-workgroup count comes from `wg.total_invocations()`, which is
    /// itself saturated at `u32::MAX`.
    #[must_use]
    pub fn total_invocations(&self, wg: &WorkgroupSize) -> u64 {
        self.total_workgroups()
            .saturating_mul(u64::from(wg.total_invocations()))
    }

    /// Returns `true` when the workgroup count along every axis is within the
    /// matching `max_dispatch_*` limit.
    #[must_use]
    pub fn fits_within(&self, limits: &DeviceLimits) -> bool {
        self.x <= limits.max_dispatch_x
            && self.y <= limits.max_dispatch_y
            && self.z <= limits.max_dispatch_z
    }
}
/// Formats the dispatch as `dispatch(XxYxZ)`, always printing all three axes.
impl fmt::Display for DispatchSize {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { x, y, z } = self;
        write!(f, "dispatch({}x{}x{})", x, y, z)
    }
}
/// Computes how many workgroups of width `workgroup_x` are needed to cover
/// `problem_size` elements, rounding up.
///
/// A `workgroup_x` of zero is treated as 1. The returned dispatch always has
/// `y == 1` and `z == 1`.
#[must_use]
pub fn compute_dispatch_1d(problem_size: u32, workgroup_x: u32) -> DispatchSize {
    // Clamp so a zero workgroup width cannot reach the division.
    let group_width = workgroup_x.max(1);
    DispatchSize {
        x: div_ceil(problem_size, group_width),
        y: 1,
        z: 1,
    }
}
/// Computes the workgroup counts needed to cover a `width` x `height` domain
/// with workgroups of size `wg`, rounding up along each axis.
///
/// Only `wg.x` and `wg.y` are used; a zero extent is treated as 1. The
/// returned dispatch always has `z == 1`.
#[must_use]
pub fn compute_dispatch_2d(width: u32, height: u32, wg: &WorkgroupSize) -> DispatchSize {
    // Groups needed along one axis, guarding against a zero group extent.
    let groups_along = |extent: u32, group: u32| div_ceil(extent, group.max(1));

    DispatchSize {
        x: groups_along(width, wg.x),
        y: groups_along(height, wg.y),
        z: 1,
    }
}
/// Computes the workgroup counts needed to cover a `width` x `height` x
/// `depth` domain with workgroups of size `wg`, rounding up along each axis.
///
/// A zero workgroup extent on any axis is treated as 1.
#[must_use]
pub fn compute_dispatch_3d(
    width: u32,
    height: u32,
    depth: u32,
    wg: &WorkgroupSize,
) -> DispatchSize {
    // Pair each problem extent with its workgroup extent and apply the same
    // clamped ceiling division to all three axes.
    let [x, y, z] = [(width, wg.x), (height, wg.y), (depth, wg.z)]
        .map(|(extent, group)| div_ceil(extent, group.max(1)));

    DispatchSize { x, y, z }
}
/// Picks a 2-D workgroup size that fits within `limits`.
///
/// A fixed, ordered list of shapes is tried and the first one accepted by
/// `WorkgroupSize::fits_within` is returned; if none fits, the result is
/// `1x1`. The `_width` and `_height` arguments are currently unused.
#[must_use]
pub fn optimal_2d_workgroup(_width: u32, _height: u32, limits: &DeviceLimits) -> WorkgroupSize {
    // Ordered by preference: earlier shapes win when several fit.
    const SHAPES: [(u32, u32); 7] = [(16, 16), (32, 8), (8, 32), (16, 8), (8, 16), (8, 8), (4, 4)];

    SHAPES
        .iter()
        .map(|&(x, y)| WorkgroupSize::new_2d(x, y))
        .find(|shape| shape.fits_within(limits))
        .unwrap_or_else(|| WorkgroupSize::new_2d(1, 1))
}
/// Returns the workgroup's invocation count as a fraction of
/// `limits.max_workgroup_invocations`, capped at `1.0`.
///
/// Returns `0.0` when the limit is zero, so no division by zero occurs.
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn estimate_occupancy(wg: &WorkgroupSize, limits: &DeviceLimits) -> f64 {
    match limits.max_workgroup_invocations {
        0 => 0.0,
        cap => {
            // `f64::from(u32)` is exact, so the ratio is computed without precision loss.
            let fraction = f64::from(wg.total_invocations()) / f64::from(cap);
            fraction.min(1.0)
        }
    }
}
/// Divides `a` by `b`, rounding up; returns 0 when `b` is zero instead of panicking.
#[must_use]
fn div_ceil(a: u32, b: u32) -> u32 {
    match b {
        0 => 0,
        divisor => a.div_ceil(divisor),
    }
}
/// Kind of operation a workgroup size is being chosen for.
///
/// `compute_optimal_workgroup` uses this to select which list of candidate
/// workgroup shapes to try.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpType {
/// Handled with 1-D workgroups; candidate sizes are tried from 512 downward.
Buffer1D,
/// Handled with 2-D workgroups; candidate shapes are tried from 32x32 downward.
Image2D,
/// Handled with the same 1-D candidate list as `Buffer1D`.
ColorConversion,
/// Handled with the same 2-D candidate list as `Image2D`.
MotionEstimation,
/// Handled with 2-D workgroups; candidate shapes are tried from 16x16 downward.
Convolution,
/// Handled with 1-D workgroups; candidate sizes are tried from 256 downward.
ToneMapping,
/// Handled with the same 1-D candidate list as `Buffer1D`.
AlphaBlend,
}
#[must_use]
pub fn compute_optimal_workgroup(caps: &DeviceCapabilities, op_type: OpType) -> WorkgroupSize {
let cl = &caps.compute_limits;
let limits = DeviceLimits {
max_workgroup_size_x: cl.max_workgroup_size_x,
max_workgroup_size_y: cl.max_workgroup_size_y,
max_workgroup_size_z: cl.max_workgroup_size_z,
max_workgroup_invocations: cl.max_workgroup_invocations,
max_dispatch_x: cl.max_dispatch_x,
max_dispatch_y: cl.max_dispatch_y,
max_dispatch_z: cl.max_dispatch_z,
};
let prefer_wave_multiple = match caps.vendor {
crate::device_caps::GpuVendor::Amd => 64,
crate::device_caps::GpuVendor::Nvidia => 32,
_ => 32,
};
match op_type {
OpType::Buffer1D | OpType::ColorConversion | OpType::AlphaBlend => {
for candidate in [512u32, 256, 128, 64, prefer_wave_multiple, 32, 16, 8, 1] {
let wg = WorkgroupSize::new_1d(candidate);
if wg.fits_within(&limits) {
return wg;
}
}
WorkgroupSize::new_1d(1)
}
OpType::Image2D | OpType::MotionEstimation => {
let candidates: &[(u32, u32)] = &[
(32, 32),
(16, 16),
(32, 8),
(8, 32),
(16, 8),
(8, 16),
(8, 8),
(4, 4),
];
for &(x, y) in candidates {
let wg = WorkgroupSize::new_2d(x, y);
if wg.fits_within(&limits) {
return wg;
}
}
WorkgroupSize::new_2d(1, 1)
}
OpType::Convolution => {
let candidates: &[(u32, u32)] =
&[(16, 16), (8, 8), (16, 4), (4, 16), (8, 4), (4, 8), (4, 4)];
for &(x, y) in candidates {
let wg = WorkgroupSize::new_2d(x, y);
if wg.fits_within(&limits) {
return wg;
}
}
WorkgroupSize::new_2d(1, 1)
}
OpType::ToneMapping => {
for candidate in [256u32, 128, 64, prefer_wave_multiple, 32, 16, 8, 1] {
let wg = WorkgroupSize::new_1d(candidate);
if wg.fits_within(&limits) {
return wg;
}
}
WorkgroupSize::new_1d(1)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device_caps::{ComputeLimits, DeviceCapabilities};

    /// Every operation type, in declaration-independent sweep order.
    const ALL_OPS: [OpType; 7] = [
        OpType::Buffer1D,
        OpType::Image2D,
        OpType::ColorConversion,
        OpType::MotionEstimation,
        OpType::Convolution,
        OpType::ToneMapping,
        OpType::AlphaBlend,
    ];

    /// Capabilities of the CPU fallback device, unmodified.
    fn cpu_caps() -> DeviceCapabilities {
        DeviceCapabilities::cpu_fallback()
    }

    /// CPU fallback capabilities with the workgroup limits overridden; all
    /// other compute limits come from `ComputeLimits::default()`.
    fn caps_with_workgroup_limits(x: u32, y: u32, z: u32, invocations: u32) -> DeviceCapabilities {
        let mut caps = DeviceCapabilities::cpu_fallback();
        caps.compute_limits = ComputeLimits {
            max_workgroup_size_x: x,
            max_workgroup_size_y: y,
            max_workgroup_size_z: z,
            max_workgroup_invocations: invocations,
            ..ComputeLimits::default()
        };
        caps
    }

    /// Capabilities roomy enough for every candidate shape.
    fn wide_caps() -> DeviceCapabilities {
        caps_with_workgroup_limits(1024, 1024, 64, 1024)
    }

    /// Mirrors the limit conversion done inside `compute_optimal_workgroup`.
    fn limits_of(caps: &DeviceCapabilities) -> DeviceLimits {
        let cl = &caps.compute_limits;
        DeviceLimits {
            max_workgroup_size_x: cl.max_workgroup_size_x,
            max_workgroup_size_y: cl.max_workgroup_size_y,
            max_workgroup_size_z: cl.max_workgroup_size_z,
            max_workgroup_invocations: cl.max_workgroup_invocations,
            max_dispatch_x: cl.max_dispatch_x,
            max_dispatch_y: cl.max_dispatch_y,
            max_dispatch_z: cl.max_dispatch_z,
        }
    }

    #[test]
    fn test_workgroup_1d() {
        let line = WorkgroupSize::new_1d(256);
        assert_eq!(line.total_invocations(), 256);
        assert_eq!(line.to_string(), "256");
    }

    #[test]
    fn test_workgroup_2d() {
        let tile = WorkgroupSize::new_2d(16, 16);
        assert_eq!(tile.total_invocations(), 256);
        assert_eq!(tile.to_string(), "16x16");
    }

    #[test]
    fn test_workgroup_3d() {
        let brick = WorkgroupSize::new_3d(8, 8, 4);
        assert_eq!(brick.total_invocations(), 256);
        assert_eq!(brick.to_string(), "8x8x4");
    }

    #[test]
    fn test_fits_within() {
        let limits = DeviceLimits::default();
        assert!(WorkgroupSize::new_2d(16, 16).fits_within(&limits));
        // X extent of 2048 exceeds the default per-axis limit.
        assert!(!WorkgroupSize::new_2d(2048, 1).fits_within(&limits));
    }

    #[test]
    fn test_dispatch_1d() {
        let dispatch = compute_dispatch_1d(1000, 256);
        assert_eq!(dispatch.x, 4);
        assert_eq!(dispatch.y, 1);
    }

    #[test]
    fn test_dispatch_2d() {
        let tile = WorkgroupSize::new_2d(16, 16);
        let dispatch = compute_dispatch_2d(1920, 1080, &tile);
        assert_eq!(dispatch.x, 120);
        // 1080 / 16 = 67.5, rounded up.
        assert_eq!(dispatch.y, 68);
    }

    #[test]
    fn test_dispatch_3d() {
        let brick = WorkgroupSize::new_3d(8, 8, 4);
        let dispatch = compute_dispatch_3d(64, 64, 16, &brick);
        assert_eq!(dispatch.x, 8);
        assert_eq!(dispatch.y, 8);
        assert_eq!(dispatch.z, 4);
    }

    #[test]
    fn test_dispatch_total_invocations() {
        let line = WorkgroupSize::new_1d(64);
        let dispatch = compute_dispatch_1d(256, 64);
        assert_eq!(dispatch.total_invocations(&line), 256);
    }

    #[test]
    fn test_dispatch_display() {
        let dispatch = DispatchSize { x: 3, y: 4, z: 1 };
        assert_eq!(dispatch.to_string(), "dispatch(3x4x1)");
    }

    #[test]
    fn test_optimal_2d_workgroup() {
        let limits = DeviceLimits::default();
        let chosen = optimal_2d_workgroup(1920, 1080, &limits);
        assert!(chosen.fits_within(&limits));
        assert_eq!(chosen.x, 16);
        assert_eq!(chosen.y, 16);
    }

    #[test]
    fn test_optimal_2d_restricted_limits() {
        let limits = DeviceLimits {
            max_workgroup_invocations: 32,
            ..DeviceLimits::default()
        };
        let chosen = optimal_2d_workgroup(1920, 1080, &limits);
        assert!(chosen.fits_within(&limits));
        assert!(chosen.total_invocations() <= 32);
    }

    #[test]
    fn test_estimate_occupancy() {
        let limits = DeviceLimits::default();
        let occupancy = estimate_occupancy(&WorkgroupSize::new_1d(512), &limits);
        assert!((occupancy - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_estimate_occupancy_capped() {
        let limits = DeviceLimits {
            max_workgroup_invocations: 128,
            ..DeviceLimits::default()
        };
        let occupancy = estimate_occupancy(&WorkgroupSize::new_1d(256), &limits);
        assert!((occupancy - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_div_ceil_basic() {
        assert_eq!(div_ceil(10, 3), 4);
        assert_eq!(div_ceil(9, 3), 3);
        assert_eq!(div_ceil(0, 5), 0);
        // A zero divisor yields 0 rather than panicking.
        assert_eq!(div_ceil(5, 0), 0);
    }

    #[test]
    fn test_optimal_wg_buffer1d_fits() {
        let caps = wide_caps();
        let chosen = compute_optimal_workgroup(&caps, OpType::Buffer1D);
        assert!(chosen.fits_within(&limits_of(&caps)));
        assert_eq!(chosen.y, 1, "Buffer1D should be 1-D");
    }

    #[test]
    fn test_optimal_wg_image2d_is_2d() {
        let wg = compute_optimal_workgroup(&wide_caps(), OpType::Image2D);
        assert!(wg.x > 1 && wg.y > 1, "Image2D should be 2-D; got {wg}");
    }

    #[test]
    fn test_optimal_wg_motion_estimation_is_2d() {
        let wg = compute_optimal_workgroup(&wide_caps(), OpType::MotionEstimation);
        assert!(
            wg.x > 1 && wg.y > 1,
            "MotionEstimation should be 2-D; got {wg}"
        );
    }

    #[test]
    fn test_optimal_wg_convolution_is_2d() {
        let wg = compute_optimal_workgroup(&wide_caps(), OpType::Convolution);
        assert!(wg.y > 1, "Convolution should use 2-D workgroup; got {wg}");
    }

    #[test]
    fn test_optimal_wg_tone_mapping_1d() {
        let wg = compute_optimal_workgroup(&wide_caps(), OpType::ToneMapping);
        assert_eq!(wg.y, 1, "ToneMapping should be 1-D; got {wg}");
    }

    #[test]
    fn test_optimal_wg_alpha_blend_1d() {
        let wg = compute_optimal_workgroup(&wide_caps(), OpType::AlphaBlend);
        assert_eq!(wg.y, 1, "AlphaBlend should be 1-D; got {wg}");
    }

    #[test]
    fn test_optimal_wg_restricted_caps_fits() {
        let caps = caps_with_workgroup_limits(4, 4, 1, 8);
        let limits = limits_of(&caps);
        for op in ALL_OPS {
            let wg = compute_optimal_workgroup(&caps, op);
            assert!(
                wg.fits_within(&limits),
                "op {op:?}: {wg} does not fit within limits"
            );
        }
    }

    #[test]
    fn test_optimal_wg_color_conversion_is_1d() {
        let wg = compute_optimal_workgroup(&cpu_caps(), OpType::ColorConversion);
        assert_eq!(wg.y, 1, "ColorConversion should be 1-D; got {wg}");
    }
}