use std::fmt;
use crate::ir::Program;
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct RequiredCapabilities {
pub subgroup_ops: bool,
pub f16: bool,
pub bf16: bool,
pub f64: bool,
pub async_dispatch: bool,
pub indirect_dispatch: bool,
pub tensor_ops: bool,
pub trap: bool,
pub max_workgroup_size: [u32; 3],
pub static_storage_bytes: u64,
}
impl RequiredCapabilities {
#[must_use]
pub fn none() -> Self {
Self::default()
}
#[must_use]
pub fn union(mut self, other: RequiredCapabilities) -> Self {
self.subgroup_ops |= other.subgroup_ops;
self.f16 |= other.f16;
self.bf16 |= other.bf16;
self.f64 |= other.f64;
self.async_dispatch |= other.async_dispatch;
self.indirect_dispatch |= other.indirect_dispatch;
self.tensor_ops |= other.tensor_ops;
self.trap |= other.trap;
for axis in 0..3 {
self.max_workgroup_size[axis] =
self.max_workgroup_size[axis].max(other.max_workgroup_size[axis]);
}
self.static_storage_bytes = self
.static_storage_bytes
.saturating_add(other.static_storage_bytes);
self
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MissingCapability {
pub backend: String,
pub missing: Vec<&'static str>,
}
impl fmt::Display for MissingCapability {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"backend `{}` is missing required capabilities: {}. \
Fix: pick a backend that advertises these capabilities, \
or run the program on the CPU reference.",
self.backend,
self.missing.join(", ")
)
}
}
impl std::error::Error for MissingCapability {}
#[must_use]
pub fn scan(program: &Program) -> RequiredCapabilities {
let stats = program.stats();
RequiredCapabilities {
subgroup_ops: stats.subgroup_ops(),
f16: stats.f16(),
bf16: stats.bf16(),
f64: stats.f64(),
async_dispatch: stats.async_dispatch(),
indirect_dispatch: stats.indirect_dispatch(),
tensor_ops: stats.tensor_ops(),
trap: stats.trap(),
max_workgroup_size: program.workgroup_size,
static_storage_bytes: stats.static_storage_bytes,
}
}
pub fn check_backend_capabilities(
backend_id: &str,
supports_subgroup_ops: bool,
supports_f16: bool,
supports_bf16: bool,
supports_indirect_dispatch: bool,
supports_trap_propagation: bool,
max_workgroup_size: [u32; 3],
required: &RequiredCapabilities,
) -> Result<(), MissingCapability> {
let mut missing = Vec::new();
if required.subgroup_ops && !supports_subgroup_ops {
missing.push("subgroup_ops");
}
if required.f16 && !supports_f16 {
missing.push("f16");
}
if required.bf16 && !supports_bf16 {
missing.push("bf16");
}
if required.indirect_dispatch && !supports_indirect_dispatch {
missing.push("indirect_dispatch");
}
if required.trap && !supports_trap_propagation {
missing.push("trap_propagation");
}
for (req_size, max_size) in required
.max_workgroup_size
.iter()
.zip(max_workgroup_size.iter())
{
if *req_size > *max_size && *max_size != 0 {
missing.push("workgroup_size");
break;
}
}
if missing.is_empty() {
Ok(())
} else {
Err(MissingCapability {
backend: backend_id.to_string(),
missing,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{BufferAccess, BufferDecl, DataType, Expr as IrExpr, Node as IrNode, Program};
fn empty_program() -> Program {
Program::wrapped(
vec![BufferDecl::storage(
"out",
0,
BufferAccess::ReadWrite,
DataType::U32,
)],
[1, 1, 1],
vec![IrNode::let_bind("x", IrExpr::u32(0))],
)
}
#[test]
fn scan_scalar_program_declares_no_capabilities() {
let caps = scan(&empty_program());
assert!(!caps.subgroup_ops);
assert!(!caps.f16);
assert!(!caps.async_dispatch);
}
#[test]
fn scan_subgroup_add_requires_subgroup_ops() {
let program = Program::wrapped(
vec![BufferDecl::storage(
"out",
0,
BufferAccess::ReadWrite,
DataType::U32,
)],
[1, 1, 1],
vec![IrNode::let_bind(
"s",
IrExpr::SubgroupAdd {
value: Box::new(IrExpr::u32(1)),
},
)],
);
let caps = scan(&program);
assert!(caps.subgroup_ops);
}
#[test]
fn scan_call_to_subgroup_intrinsic_requires_subgroup_ops() {
let program = Program::wrapped(
vec![BufferDecl::storage(
"out",
0,
BufferAccess::ReadWrite,
DataType::U32,
)],
[1, 1, 1],
vec![IrNode::let_bind(
"s",
IrExpr::call(
"vyre-intrinsics::math::subgroup_inclusive_add",
vec![IrExpr::u32(1)],
),
)],
);
let caps = scan(&program);
assert!(caps.subgroup_ops);
}
#[test]
fn check_backend_reports_every_missing_bit() {
let required = RequiredCapabilities {
subgroup_ops: true,
f16: true,
trap: true,
..RequiredCapabilities::default()
};
let error = check_backend_capabilities(
"test_backend",
false,
false,
false,
false,
false,
[64, 1, 1],
&required,
)
.unwrap_err();
assert_eq!(error.backend, "test_backend");
assert!(error.missing.contains(&"subgroup_ops"));
assert!(error.missing.contains(&"f16"));
assert!(error.missing.contains(&"trap_propagation"));
}
}