use crate::{
device::{Device, DeviceInner, Features},
scalar::{ScalarElem, ScalarType},
};
use anyhow::{bail, Result};
#[cfg(feature = "device")]
use dry::macro_wrap;
#[cfg(feature = "device")]
use rspirv::{binary::Assemble, dr::Operand};
use std::{borrow::Cow, sync::Arc};
#[cfg(feature = "device")]
use std::{
collections::HashMap,
hash::Hash,
sync::atomic::{AtomicBool, Ordering},
};
/// Runtime kernel descriptor: decoded SPIR-V plus the metadata needed to
/// specialize and dispatch the kernel. Created from `__private::KernelDesc`
/// by decompressing its SPIR-V (see `KernelBuilder::from_desc`).
#[cfg_attr(not(feature = "device"), allow(dead_code))]
#[derive(Clone, Debug)]
pub(crate) struct KernelDesc {
// Kernel name; gains a `<spec, ...>` suffix when specialized.
pub(crate) name: Cow<'static, str>,
// Decoded SPIR-V words.
pub(crate) spirv: Vec<u32>,
// Device features the kernel requires.
features: Features,
// Threads per group; 0 until specialized.
pub(crate) threads: u32,
// Declared spec constants; cleared (set to `&[]`) after specialization.
spec_descs: &'static [SpecDesc],
// Declared slice (buffer) arguments.
pub(crate) slice_descs: &'static [SliceDesc],
// Declared push constants.
push_descs: &'static [PushDesc],
}
#[cfg(feature = "device")]
impl KernelDesc {
    /// Size in bytes of the kernel's push constant range: each declared push
    /// constant aligned to its scalar size, padded up to a multiple of 4,
    /// plus an (offset, len) pair of `u32`s appended per slice argument.
    pub(crate) fn push_consts_range(&self) -> u32 {
        let mut size = 0;
        for push_desc in self.push_descs.iter() {
            // Align the field to its scalar size.
            while size % push_desc.scalar_type.size() != 0 {
                size += 1;
            }
            size += push_desc.scalar_type.size()
        }
        // Pad to u32 granularity before the per-slice (offset, len) pairs.
        while size % 4 != 0 {
            size += 1;
        }
        size += self.slice_descs.len() * 2 * 4;
        size.try_into().unwrap()
    }
    /// Returns a specialized copy of this descriptor with `spec_consts` (and
    /// the implicit trailing `threads` spec constant) patched into the
    /// SPIR-V, and debug printf instructions stripped unless `debug_printf`.
    fn specialize(
        &self,
        threads: u32,
        spec_consts: &[ScalarElem],
        debug_printf: bool,
    ) -> Result<Self> {
        use rspirv::spirv::{Decoration, Op};
        let mut module = rspirv::dr::load_words(&self.spirv).unwrap();
        let mut spec_ids = HashMap::<u32, u32>::with_capacity(spec_consts.len());
        let mut spec_string = format!("threads={threads}");
        use std::fmt::Write;
        for (desc, spec) in self.spec_descs.iter().zip(spec_consts) {
            if !spec_string.is_empty() {
                spec_string.push_str(", ");
            }
            let n = desc.name;
            macro_wrap!(match spec {
                macro_for!($T in [U8, I8, U16, I16, F16, BF16, U32, I32, F32, U64, I64, F64] {
                    ScalarElem::$T(x) => write!(&mut spec_string, "{n}={x}").unwrap(),
                })
                _ => unreachable!("{spec:?}"),
            });
        }
        // e.g. `foo<threads=64, n=10>`
        let name = if !spec_string.is_empty() {
            format!("{}<{spec_string}>", self.name).into()
        } else {
            self.name.clone()
        };
        // Map spec constant result ids to their SpecId decorations.
        for inst in module.annotations.iter() {
            if inst.class.opcode == Op::Decorate {
                if let [Operand::IdRef(id), Operand::Decoration(Decoration::SpecId), Operand::LiteralInt32(spec_id)] =
                    inst.operands.as_slice()
                {
                    spec_ids.insert(*id, *spec_id);
                }
            }
        }
        for inst in module.types_global_values.iter_mut() {
            if inst.class.opcode == Op::SpecConstant {
                if let Some(result_id) = inst.result_id {
                    if let Some(spec_id) = spec_ids.get(&result_id).copied().map(|x| x as usize) {
                        // SpecIds index the user spec constants in order; the
                        // id one past the end is the implicit `threads`.
                        let value = if let Some(value) = spec_consts.get(spec_id).copied() {
                            value
                        } else if spec_id == spec_consts.len() {
                            ScalarElem::U32(threads)
                        } else {
                            unreachable!("{inst:?}")
                        };
                        match inst.operands.as_mut_slice() {
                            [Operand::LiteralInt32(a)] => {
                                bytemuck::bytes_of_mut(a).copy_from_slice(value.as_bytes());
                            }
                            [Operand::LiteralInt32(a), Operand::LiteralInt32(b)] => {
                                // 64-bit constants occupy two 32-bit literals;
                                // split the 8 value bytes evenly. (Was
                                // `[..8]` / `[9..]`, which panics: each
                                // literal holds exactly 4 bytes and index 9
                                // is past the end of an 8-byte value.)
                                bytemuck::bytes_of_mut(a).copy_from_slice(&value.as_bytes()[..4]);
                                bytemuck::bytes_of_mut(b).copy_from_slice(&value.as_bytes()[4..]);
                            }
                            _ => unreachable!("{:?}", inst.operands),
                        }
                    }
                }
            }
        }
        if !debug_printf {
            strip_debug_printf(&mut module);
        }
        let spirv = module.assemble();
        Ok(Self {
            name,
            spirv,
            // Spec constants are now baked in; clear the declarations.
            spec_descs: &[],
            threads,
            ..self.clone()
        })
    }
}
/// Removes debug printf instructions (and the extensions/imports they rely
/// on) from a SPIR-V module, along with Line/NoLine debug instructions.
#[cfg(feature = "device")]
fn strip_debug_printf(module: &mut rspirv::dr::Module) {
    use fxhash::FxHashSet;
    use rspirv::spirv::Op;
    // Drop the non-semantic-info extension declaration.
    module.extensions.retain(|inst| {
        let ext_name = inst.operands.first().unwrap().unwrap_literal_string();
        ext_name != "SPV_KHR_non_semantic_info"
    });
    // Remove DebugPrintf instruction-set imports, remembering their ids.
    let mut printf_ids = FxHashSet::default();
    module.ext_inst_imports.retain(|inst| {
        let set_name = inst.operands.first().unwrap().unwrap_literal_string();
        let is_printf = set_name.starts_with("NonSemantic.DebugPrintf");
        if is_printf {
            printf_ids.insert(inst.result_id.unwrap());
        }
        !is_printf
    });
    // Nothing referenced printf: no instructions to strip.
    if printf_ids.is_empty() {
        return;
    }
    module.debug_string_source.clear();
    for function in module.functions.iter_mut() {
        for block in function.blocks.iter_mut() {
            block.instructions.retain(|inst| {
                let opcode = inst.class.opcode;
                // Drop ExtInst calls into the printf instruction sets.
                if opcode == Op::ExtInst
                    && printf_ids.contains(&inst.operands.first().unwrap().unwrap_id_ref())
                {
                    return false;
                }
                // Also drop line-number debug info.
                opcode != Op::Line && opcode != Op::NoLine
            })
        }
    }
}
/// Cache key for a specialized kernel (see `RawKernel::cached` usage in
/// `KernelBuilder::build`).
#[cfg(feature = "device")]
#[derive(PartialEq, Eq, Hash, Debug)]
pub(crate) struct KernelKey {
// Builder identity: the address of the kernel's 'static name.
id: usize,
// Bytes of the spec constant values followed by the thread count.
spec_bytes: Vec<u8>,
}
#[doc(hidden)]
pub mod __private {
#[cfg(feature = "device")]
use num_traits::ToPrimitive;
use super::*;
#[cfg(feature = "device")]
use crate::device::{DeviceBuffer, RawKernel};
use crate::{
buffer::{ScalarSlice, ScalarSliceMut, Slice, SliceMut},
scalar::Scalar,
};
/// Kernel descriptor embedded in macro-generated code.
#[derive(Clone, Copy)]
pub struct KernelDesc {
// Kernel name.
name: &'static str,
// Gzip-compressed SPIR-V; decompressed by `decode_spirv`.
spirv: &'static [u8],
// Device features the kernel requires.
features: Features,
// Whether the kernel was declared safe.
safe: bool,
// Declared spec constants.
spec_descs: &'static [SpecDesc],
// Declared slice (buffer) arguments.
slice_descs: &'static [SliceDesc],
// Declared push constants.
push_descs: &'static [PushDesc],
}
/// Public-field mirror of [`KernelDesc`], consumed by
/// [`KernelDesc::from_args`] so generated code can use struct syntax.
#[derive(Clone, Copy)]
pub struct KernelDescArgs {
pub name: &'static str,
pub spirv: &'static [u8],
pub features: Features,
pub safe: bool,
pub spec_descs: &'static [SpecDesc],
pub slice_descs: &'static [SliceDesc],
pub push_descs: &'static [PushDesc],
}
/// Byte-slice equality usable in const fns (`==` on slices is not const).
const fn bytes_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // Compare back-to-front; order is irrelevant for equality.
    let mut idx = a.len();
    while idx > 0 {
        idx -= 1;
        if a[idx] != b[idx] {
            return false;
        }
    }
    true
}
/// Const lookup of a kernel descriptor by name; returns `None` if absent.
pub const fn find_kernel(name: &str, kernels: &[KernelDesc]) -> Option<KernelDesc> {
    let target = name.as_bytes();
    // Linear scan with a manual index: const fns cannot use iterators.
    let mut index = 0;
    while index < kernels.len() {
        let candidate = kernels[index];
        if bytes_eq(target, candidate.name.as_bytes()) {
            return Some(candidate);
        }
        index += 1;
    }
    None
}
/// Validates a cache lookup at compile time.
///
/// `kernel` is `None` when the cache itself is absent (passed through as
/// `None`); otherwise the inner option is the cache entry, which must exist
/// and match the declared descriptors or compilation fails with a panic
/// telling the user to recompile with krnlc.
pub const fn validate_kernel(
    kernel: Option<Option<KernelDesc>>,
    safety: Safety,
    spec_descs: &[SpecDesc],
    slice_descs: &[SliceDesc],
    push_descs: &[PushDesc],
) -> Option<KernelDesc> {
    let kernel = match kernel {
        Some(kernel) => kernel,
        None => return None,
    };
    let declaration_matches = match kernel.as_ref() {
        Some(desc) => desc.check_declaration(safety, spec_descs, slice_descs, push_descs),
        // Missing cache entry: always stale.
        None => false,
    };
    if !declaration_matches {
        panic!("recompile with krnlc");
    }
    kernel
}
impl KernelDesc {
    /// Const constructor used by macro-generated code.
    pub const fn from_args(args: KernelDescArgs) -> Self {
        let KernelDescArgs {
            name,
            spirv,
            features,
            safe,
            spec_descs,
            slice_descs,
            push_descs,
        } = args;
        Self {
            name,
            spirv,
            features,
            safe,
            spec_descs,
            slice_descs,
            push_descs,
        }
    }
    /// Checks that this cached descriptor matches the declared kernel:
    /// safety must agree, and the spec/slice/push descriptors must match
    /// element-wise. Returns `false` on any mismatch.
    const fn check_declaration(
        &self,
        safety: Safety,
        spec_descs: &[SpecDesc],
        slice_descs: &[SliceDesc],
        push_descs: &[PushDesc],
    ) -> bool {
        if self.safe != safety.is_safe() {
            return false;
        }
        {
            if self.spec_descs.len() != spec_descs.len() {
                return false;
            }
            // Fix: index starts at 0 (it was initialized to `len()`, so the
            // loop body never ran and element mismatches were not detected).
            let mut index = 0;
            while index < spec_descs.len() {
                if !self.spec_descs[index].const_eq(&spec_descs[index]) {
                    return false;
                }
                index += 1;
            }
        }
        {
            if self.slice_descs.len() != slice_descs.len() {
                return false;
            }
            // Fix: start at 0 (same skipped-loop bug as above).
            let mut index = 0;
            while index < slice_descs.len() {
                if !self.slice_descs[index].const_eq(&slice_descs[index]) {
                    return false;
                }
                index += 1;
            }
        }
        {
            if self.push_descs.len() != push_descs.len() {
                return false;
            }
            // Fix: start at 0 (same skipped-loop bug as above).
            let mut index = 0;
            while index < push_descs.len() {
                if !self.push_descs[index].const_eq(&push_descs[index]) {
                    return false;
                }
                index += 1;
            }
        }
        true
    }
}
/// Whether a kernel is declared safe or unsafe; compared against the cached
/// descriptor's `safe` flag in `check_declaration`.
#[derive(Clone, Copy)]
pub enum Safety {
Safe,
Unsafe,
}
impl Safety {
    /// True for `Safe`, false for `Unsafe`.
    const fn is_safe(&self) -> bool {
        match self {
            Self::Safe => true,
            Self::Unsafe => false,
        }
    }
}
/// Equality for `ScalarType` usable in const fns: compares the enum
/// discriminants, since `PartialEq::eq` is not callable in const context.
const fn scalar_type_const_eq(a: ScalarType, b: ScalarType) -> bool {
    (a as u32) == (b as u32)
}
/// Declaration of a kernel spec constant.
#[derive(Clone, Copy, Debug)]
pub struct SpecDesc {
// Spec constant name.
pub name: &'static str,
// Scalar type of the spec constant.
pub scalar_type: ScalarType,
}
impl SpecDesc {
    /// Const field-wise equality (name, then scalar type).
    const fn const_eq(&self, other: &Self) -> bool {
        if !bytes_eq(self.name.as_bytes(), other.name.as_bytes()) {
            return false;
        }
        scalar_type_const_eq(self.scalar_type, other.scalar_type)
    }
}
/// Declaration of a kernel slice (buffer) argument.
#[derive(Clone, Copy, Debug)]
pub struct SliceDesc {
// Argument name.
pub name: &'static str,
// Element scalar type.
pub scalar_type: ScalarType,
// Whether the kernel may write to the slice.
pub mutable: bool,
// Whether this slice's length participates in inferring dispatch groups
// (see `Kernel::dispatch`).
pub item: bool,
}
impl SliceDesc {
    /// Const field-wise equality over all four fields.
    const fn const_eq(&self, other: &Self) -> bool {
        // Check the cheap bool fields first, then the type, then the name.
        if self.mutable != other.mutable || self.item != other.item {
            return false;
        }
        if !scalar_type_const_eq(self.scalar_type, other.scalar_type) {
            return false;
        }
        bytes_eq(self.name.as_bytes(), other.name.as_bytes())
    }
}
/// Declaration of a kernel push constant.
#[derive(Clone, Copy, Debug)]
pub struct PushDesc {
// Push constant name.
pub name: &'static str,
// Scalar type of the push constant.
pub scalar_type: ScalarType,
}
impl PushDesc {
    /// Const field-wise equality (scalar type, then name).
    const fn const_eq(&self, other: &Self) -> bool {
        scalar_type_const_eq(self.scalar_type, other.scalar_type)
            && bytes_eq(self.name.as_bytes(), other.name.as_bytes())
    }
}
/// Decompresses gzip-encoded SPIR-V bytes into words.
///
/// Errors with a kernel-specific message if decompression fails. Any
/// trailing bytes that do not form a full word are discarded by
/// `chunks_exact`.
fn decode_spirv(name: &str, input: &[u8]) -> Result<Vec<u32>, String> {
    use flate2::read::GzDecoder;
    use std::io::Read;
    let mut bytes = Vec::new();
    let mut decoder = GzDecoder::new(bytemuck::cast_slice(input));
    if let Err(e) = decoder.read_to_end(&mut bytes) {
        return Err(format!("Kernel `{name}` failed to decode! {e}"));
    }
    let words = bytes
        .chunks_exact(4)
        .map(|chunk| u32::from_ne_bytes(chunk.try_into().unwrap()))
        .collect();
    Ok(words)
}
/// Uninhabited type-level marker; `S` presumably tracks whether the kernel
/// has been specialized (used by macro-generated code) — confirm against the
/// macro expansion.
pub enum Specialized<const S: bool> {}
/// Builder for [`Kernel`], created via [`KernelBuilder::from_desc`].
#[cfg_attr(not(feature = "device"), allow(dead_code))]
#[derive(Clone)]
pub struct KernelBuilder {
// Identity of the source descriptor: address of its 'static name.
id: usize,
// Decoded runtime descriptor.
desc: Arc<super::KernelDesc>,
// Spec constant values provided via `specialize`.
spec_consts: Vec<ScalarElem>,
// Threads per group; the device default is used when `None`.
threads: Option<u32>,
}
impl KernelBuilder {
/// Creates a builder from a macro-generated descriptor, decompressing its
/// SPIR-V. Errors with a message if decoding fails.
pub fn from_desc(desc: KernelDesc) -> Result<Self, String> {
let KernelDesc {
name,
spirv,
features,
safe: _,
spec_descs,
slice_descs,
push_descs,
} = desc;
let spirv = decode_spirv(name, spirv)?;
// `threads` is a placeholder here; it is chosen at build time.
let desc = super::KernelDesc {
name: name.into(),
spirv,
features,
threads: 0,
spec_descs,
slice_descs,
push_descs,
};
Ok(Self {
// The 'static name's address uniquely identifies this kernel and
// serves as the cache key id.
id: name.as_ptr() as usize,
desc: desc.into(),
spec_consts: Vec::new(),
threads: None,
})
}
/// Sets the threads per group.
pub fn with_threads(self, threads: u32) -> Self {
Self {
threads: Some(threads),
..self
}
}
/// Sets the spec constant values; length and scalar types must match the
/// declared spec descriptors (checked in debug builds only).
pub fn specialize(self, spec_consts: &[ScalarElem]) -> Self {
debug_assert_eq!(spec_consts.len(), self.desc.spec_descs.len());
#[cfg(debug_assertions)]
for (spec_const, spec_desc) in
spec_consts.iter().copied().zip(self.desc.spec_descs.iter())
{
assert_eq!(spec_const.scalar_type(), spec_desc.scalar_type);
}
Self {
spec_consts: spec_consts.to_vec(),
..self
}
}
/// Builds the kernel for `device`.
///
/// # Errors
/// - The device is the host.
/// - The device lacks the kernel's required features.
/// - The requested threads exceed the device's max.
/// - Specialization or kernel creation fails.
pub fn build(&self, device: Device) -> Result<Kernel> {
match device.inner() {
DeviceInner::Host => {
bail!("Kernel `{}` expected device, found host!", self.desc.name);
}
#[cfg(feature = "device")]
DeviceInner::Device(device) => {
let desc = &self.desc;
let name = &desc.name;
let features = desc.features;
let info = device.info();
let device_features = info.features();
if !device_features.contains(features) {
bail!("Kernel {name} requires {features:?}, {device:?} has {device_features:?}!");
}
let threads = self.threads.unwrap_or(info.default_threads());
let max_threads = info.max_threads();
if threads > max_threads {
bail!("Kernel {name} threads {threads} is greater than max_threads {max_threads}!");
}
// Cache key: spec constant bytes plus the thread count, so each
// distinct specialization is cached separately.
let spec_bytes = self
.spec_consts
.iter()
.flat_map(|x| x.as_bytes())
.copied()
.chain(threads.to_ne_bytes())
.collect();
let key = KernelKey {
id: self.id,
spec_bytes,
};
let debug_printf = info.debug_printf();
// Specialization only runs on a cache miss.
let inner = RawKernel::cached(device.clone(), key, || {
desc.specialize(threads, &self.spec_consts, debug_printf)
.map(Arc::new)
})?;
Ok(Kernel {
inner,
threads,
groups: None,
})
}
}
}
/// Device features required by the kernel.
pub fn features(&self) -> Features {
self.desc.features
}
}
/// Uninhabited type-level marker; `G` presumably tracks whether groups were
/// provided (used by macro-generated code) — confirm against the macro
/// expansion.
pub enum WithGroups<const G: bool> {}
/// A compiled, specialized kernel ready for dispatch.
#[derive(Clone)]
pub struct Kernel {
// Cached raw device kernel.
#[cfg(feature = "device")]
inner: RawKernel,
// Threads per group the kernel was built with.
threads: u32,
// Groups to dispatch; when `None`, inferred from `item` slices at dispatch.
#[cfg(feature = "device")]
groups: Option<u32>,
}
impl Kernel {
/// Threads per group the kernel was built with.
pub fn threads(&self) -> u32 {
self.threads
}
/// Sets groups from a global thread count, rounded up so at least
/// `global_threads` threads are launched.
pub fn with_global_threads(self, global_threads: u32) -> Self {
#[cfg(feature = "device")]
{
let desc = &self.inner.desc();
let threads = desc.threads;
// Ceiling division: one extra group covers any remainder.
let groups = global_threads / threads + u32::from(global_threads % threads != 0);
self.with_groups(groups)
}
#[cfg(not(feature = "device"))]
{
let _ = global_threads;
// A Kernel cannot be constructed without the "device" feature.
unreachable!()
}
}
/// Sets the number of groups to dispatch.
pub fn with_groups(self, groups: u32) -> Self {
#[cfg(feature = "device")]
{
Self {
groups: Some(groups),
..self
}
}
#[cfg(not(feature = "device"))]
{
let _ = groups;
unreachable!()
}
}
/// Dispatches the kernel.
///
/// Push constants are packed to 4-byte alignment, then each buffer's
/// (offset, len) in elements is appended as two `u32`s. If groups were not
/// set, they are inferred from the minimum length of the `item` slices.
/// When debug printf is enabled, waits for completion and reports a panic
/// raised by the kernel.
///
/// # Safety
///
/// NOTE(review): marked unsafe presumably because device execution of the
/// (possibly unsafe) kernel cannot be checked here — confirm against the
/// crate's public documentation.
///
/// # Errors
/// - A slice is empty, lives on the host, or on a different device.
/// - Groups exceed the device max, or were required but not provided.
/// - The kernel panicked (debug printf builds only).
pub unsafe fn dispatch(
&self,
slices: &[KernelSliceArg],
push_consts: &[ScalarElem],
) -> Result<()> {
#[cfg(feature = "device")]
{
let desc = &self.inner.desc();
let kernel_name = &desc.name;
let mut buffers = Vec::with_capacity(desc.slice_descs.len());
// Minimum length over all `item` slices; used to infer groups.
let mut items: Option<u32> = None;
let device = self.inner.device();
let mut push_bytes = Vec::with_capacity(desc.push_consts_range() as usize);
debug_assert_eq!(push_consts.len(), desc.push_descs.len());
for (push, push_desc) in push_consts.iter().zip(desc.push_descs.iter()) {
debug_assert_eq!(push.scalar_type(), push_desc.scalar_type);
debug_assert_eq!(push_bytes.len() % push.scalar_type().size(), 0);
push_bytes.extend_from_slice(push.as_bytes());
}
// Pad to u32 granularity before appending buffer offsets/lens.
while push_bytes.len() % 4 != 0 {
push_bytes.push(0);
}
for (slice, slice_desc) in slices.iter().zip(desc.slice_descs.iter()) {
debug_assert_eq!(slice.scalar_type(), slice_desc.scalar_type);
debug_assert!(!slice_desc.mutable || slice.mutable());
let slice_name = &slice_desc.name;
if slice.len() == 0 {
bail!("Kernel `{kernel_name}`.`{slice_name}` is empty!");
}
let buffer = if let Some(buffer) = slice.device_buffer() {
buffer
} else {
bail!("Kernel `{kernel_name}`.`{slice_name}` expected device, found host!");
};
// All buffers must live on the kernel's device.
let buffer_device = buffer.device();
if device != buffer_device {
bail!(
"Kernel `{kernel_name}`.`{slice_name}`, expected `{device:?}`, found {buffer_device:?}!"
);
}
buffers.push(buffer.clone());
if slice_desc.item {
items.replace(if let Some(items) = items {
items.min(slice.len() as u32)
} else {
slice.len() as u32
});
}
// Offset and length are passed in elements, not bytes.
let width = slice_desc.scalar_type.size();
let offset = buffer.offset() / width;
let len = buffer.len() / width;
push_bytes.extend_from_slice(&offset.to_u32().unwrap().to_ne_bytes());
push_bytes.extend_from_slice(&len.to_u32().unwrap().to_ne_bytes());
}
let info = self.inner.device().info().clone();
let max_groups = info.max_groups();
let groups = if let Some(groups) = self.groups {
if groups > max_groups {
bail!("Kernel `{kernel_name}` groups {groups} is greater than max_groups {max_groups}!");
}
groups
} else if let Some(items) = items {
// Ceiling division over items, capped at the device limit.
let threads = self.threads;
let groups = items / threads + u32::from(items % threads != 0);
groups.min(max_groups)
} else {
unreachable!("groups not provided!")
};
// Shared flag the device sets if the kernel panics via debug printf.
let debug_printf_panic = if info.debug_printf() {
Some(Arc::new(AtomicBool::default()))
} else {
None
};
unsafe {
self.inner.dispatch(
groups,
&buffers,
push_bytes,
debug_printf_panic.clone(),
)?;
}
if let Some(debug_printf_panic) = debug_printf_panic {
device.wait()?;
// Spin until the device drops its clone of the flag, so the value
// read below is final.
while Arc::strong_count(&debug_printf_panic) > 1 {
std::thread::yield_now();
}
if debug_printf_panic.load(Ordering::SeqCst) {
bail!("Kernel `{kernel_name}` panicked!");
}
}
Ok(())
}
#[cfg(not(feature = "device"))]
{
let _ = (slices, push_consts);
unreachable!()
}
}
/// Device features required by the kernel.
pub fn features(&self) -> Features {
#[cfg(feature = "device")]
{
return self.inner.desc().features;
}
#[cfg(not(feature = "device"))]
{
unreachable!()
}
}
}
/// Type-erased slice argument passed to [`Kernel::dispatch`].
#[doc(hidden)]
pub enum KernelSliceArg<'a> {
// Immutable scalar slice.
Slice(ScalarSlice<'a>),
// Mutable scalar slice.
SliceMut(ScalarSliceMut<'a>),
}
#[cfg(feature = "device")]
impl KernelSliceArg<'_> {
    /// Scalar type of the underlying slice.
    fn scalar_type(&self) -> ScalarType {
        match self {
            Self::Slice(slice) => slice.scalar_type(),
            Self::SliceMut(slice) => slice.scalar_type(),
        }
    }
    /// Whether the argument was passed mutably.
    fn mutable(&self) -> bool {
        matches!(self, Self::SliceMut(_))
    }
    /// Device buffer backing the slice, if it lives on a device.
    fn device_buffer(&self) -> Option<&DeviceBuffer> {
        match self {
            Self::Slice(slice) => slice.device_buffer(),
            Self::SliceMut(slice) => slice.device_buffer_mut(),
        }
    }
    /// Length in elements.
    fn len(&self) -> usize {
        match self {
            Self::Slice(slice) => slice.len(),
            Self::SliceMut(slice) => slice.len(),
        }
    }
}
// Wraps an immutable typed slice as a type-erased kernel argument.
impl<'a, T: Scalar> From<Slice<'a, T>> for KernelSliceArg<'a> {
fn from(slice: Slice<'a, T>) -> Self {
Self::Slice(slice.into())
}
}
// Wraps a mutable typed slice as a type-erased kernel argument.
impl<'a, T: Scalar> From<SliceMut<'a, T>> for KernelSliceArg<'a> {
fn from(slice: SliceMut<'a, T>) -> Self {
Self::SliceMut(slice.into())
}
}
}
pub(crate) use __private::{PushDesc, SliceDesc, SpecDesc};