use proc_macro::TokenStream;
use quote::{ToTokens, quote};
mod cpu;
mod cuda;
/// A resource-binding parameter of the annotated shader function, i.e. one carrying
/// `#[spirv(uniform | storage_buffer, binding = N, ...)]`.
pub(super) struct ShaderBinding {
    /// Parameter name; reused as the field name in the generated `...Args` struct.
    pub name: syn::Ident,
    /// Value of `descriptor_set = N` (0 when the attribute omits it).
    pub descriptor_set: u32,
    /// Value of `binding = N`.
    pub binding: u32,
    /// `true` when `uniform` was given; otherwise the binding came from `storage_buffer`.
    pub is_uniform: bool,
    /// `true` when the parameter type is a `&mut` reference.
    pub is_mutable: bool,
    /// Referenced type, with a `[T]` slice unwrapped to its element type `T`.
    pub element_type: syn::Type,
    /// `#[cfg(...)]` / `#[cfg_attr(...)]` attributes copied from the parameter so the
    /// generated items can be gated identically.
    pub cfg_attrs: Vec<syn::Attribute>,
}
/// A `#[spirv(push_constant)]` parameter of the annotated shader function.
pub(super) struct PushConstantBinding {
    /// Parameter name; reused as the field / argument name in the generated code.
    pub name: syn::Ident,
    /// Pointee type of the `&T` parameter (the value the generated `call` takes by value).
    pub ty: syn::Type,
    /// `#[cfg(...)]` / `#[cfg_attr(...)]` attributes copied from the parameter.
    pub cfg_attrs: Vec<syn::Attribute>,
}
/// Which SPIR-V builtin a parameter is annotated with, as recognised by `parse_spirv_attr`.
#[derive(Clone, Copy, PartialEq)]
pub(super) enum BuiltinKind {
    /// `#[spirv(global_invocation_id)]`
    GlobalInvocationId,
    /// `#[spirv(local_invocation_id)]`
    LocalInvocationId,
    /// `#[spirv(workgroup_id)]`
    WorkgroupId,
    /// `#[spirv(num_workgroups)]`
    NumWorkgroups,
    /// `#[spirv(local_invocation_index)]`
    LocalInvocationIndex,
    /// `#[spirv(subgroup_id)]`
    SubgroupId,
    /// `#[spirv(subgroup_local_invocation_id)]`
    SubgroupLocalInvocationId,
    /// Any other recognised builtin (`vertex_index`, `instance_index`, `position`).
    Other,
}
/// Classification of a parameter's `#[spirv(...)]` attribute.
enum SpirvAttrKind {
    /// Resource binding: `(is_uniform, descriptor_set, binding)`.
    Binding(bool, u32, u32),
    /// `#[spirv(push_constant)]`
    PushConstant,
    /// One of the recognised builtins.
    Builtin(BuiltinKind),
    /// `#[spirv(workgroup)]`
    Workgroup,
}
/// A classified parameter of the annotated function, kept in declaration order so that
/// backends can reconstruct a call to the original function.
pub(super) struct OriginalParam {
    /// Parameter name as written in the signature.
    pub name: syn::Ident,
    /// What kind of shader input this parameter is.
    pub kind: OriginalParamKind,
    /// Declared parameter type, unchanged (e.g. the full `&mut [T]` for bindings).
    pub ty: syn::Type,
    /// `#[cfg(...)]` / `#[cfg_attr(...)]` attributes copied from the parameter.
    pub cfg_attrs: Vec<syn::Attribute>,
}
/// Kind of an [`OriginalParam`], mirroring `SpirvAttrKind` but with named binding fields.
pub(super) enum OriginalParamKind {
    /// A builtin input such as `global_invocation_id`.
    Builtin(BuiltinKind),
    /// A uniform or storage buffer; `is_mutable` reflects a `&mut` parameter type.
    Binding { is_uniform: bool, is_mutable: bool },
    /// A `#[spirv(push_constant)]` parameter.
    PushConstant,
    /// A `#[spirv(workgroup)]` parameter.
    Workgroup,
}
/// Result of `extract_element_type`: the pointee of a reference parameter.
struct ExtractedType {
    /// Referenced type, with `[T]` unwrapped to `T`.
    element_type: syn::Type,
    /// `true` for `&mut` references.
    is_mutable: bool,
}
fn parse_spirv_attr(attr: &syn::Attribute) -> Option<SpirvAttrKind> {
if !attr.path().is_ident("spirv") {
return None;
}
let mut is_uniform = false;
let mut is_storage = false;
let mut is_push_constant = false;
let mut builtin_kind: Option<BuiltinKind> = None;
let mut is_workgroup = false;
let mut descriptor_set: Option<u32> = None;
let mut binding: Option<u32> = None;
let _ = attr.parse_nested_meta(|meta| {
let ident_str = meta.path.get_ident().map(|i| i.to_string());
match ident_str.as_deref() {
Some("uniform") => {
is_uniform = true;
}
Some("storage_buffer") => {
is_storage = true;
}
Some("push_constant") => {
is_push_constant = true;
}
Some("descriptor_set") => {
let value: syn::LitInt = meta.value()?.parse()?;
descriptor_set = Some(value.base10_parse()?);
}
Some("binding") => {
let value: syn::LitInt = meta.value()?.parse()?;
binding = Some(value.base10_parse()?);
}
Some("global_invocation_id") => {
builtin_kind = Some(BuiltinKind::GlobalInvocationId);
}
Some("local_invocation_id") => {
builtin_kind = Some(BuiltinKind::LocalInvocationId);
}
Some("workgroup_id") => {
builtin_kind = Some(BuiltinKind::WorkgroupId);
}
Some("num_workgroups") => {
builtin_kind = Some(BuiltinKind::NumWorkgroups);
}
Some("local_invocation_index") => {
builtin_kind = Some(BuiltinKind::LocalInvocationIndex);
}
Some("subgroup_id") => {
builtin_kind = Some(BuiltinKind::SubgroupId);
}
Some("subgroup_local_invocation_id") => {
builtin_kind = Some(BuiltinKind::SubgroupLocalInvocationId);
}
Some("workgroup") => {
is_workgroup = true;
}
Some("vertex_index") | Some("instance_index") | Some("position") => {
builtin_kind = Some(BuiltinKind::Other);
}
Some("compute") => {
}
_ => {
}
}
Ok(())
});
if is_push_constant {
Some(SpirvAttrKind::PushConstant)
} else if (is_uniform || is_storage) && binding.is_some() {
Some(SpirvAttrKind::Binding(
is_uniform,
descriptor_set.unwrap_or(0),
binding.unwrap(),
))
} else if is_workgroup {
Some(SpirvAttrKind::Workgroup)
} else {
builtin_kind.map(SpirvAttrKind::Builtin)
}
}
/// Returns clones of the conditional-compilation attributes (`#[cfg(...)]` and
/// `#[cfg_attr(...)]`) found in `attrs`, preserving their order.
fn extract_cfg_attrs(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
    let mut gated = Vec::new();
    for candidate in attrs {
        let path = candidate.path();
        if path.is_ident("cfg") || path.is_ident("cfg_attr") {
            gated.push(candidate.clone());
        }
    }
    gated
}
/// Extracts the pointee of a reference type.
///
/// `&T` / `&mut T` yield `T`, and `&[T]` / `&mut [T]` yield the element type `T`.
/// Non-reference types yield `None`.
fn extract_element_type(ty: &syn::Type) -> Option<ExtractedType> {
    match ty {
        syn::Type::Reference(reference) => {
            let element_type = match &*reference.elem {
                syn::Type::Slice(slice) => (*slice.elem).clone(),
                pointee => pointee.clone(),
            };
            Some(ExtractedType {
                element_type,
                is_mutable: reference.mutability.is_some(),
            })
        }
        _ => None,
    }
}
fn parse_workgroup_size(attr: &syn::Attribute) -> Option<[u32; 3]> {
if !attr.path().is_ident("spirv") {
return None;
}
let mut workgroup_size: Option<[u32; 3]> = None;
let _ = attr.parse_nested_meta(|meta| {
if meta.path.is_ident("compute") {
meta.parse_nested_meta(|inner| {
if inner.path.is_ident("threads") {
let content;
syn::parenthesized!(content in inner.input);
let mut dims = [1u32, 1, 1];
let x: syn::LitInt = content.parse()?;
dims[0] = x.base10_parse()?;
if content.peek(syn::Token![,]) {
let _: syn::Token![,] = content.parse()?;
let y: syn::LitInt = content.parse()?;
dims[1] = y.base10_parse()?;
if content.peek(syn::Token![,]) {
let _: syn::Token![,] = content.parse()?;
let z: syn::LitInt = content.parse()?;
dims[2] = z.base10_parse()?;
}
}
workgroup_size = Some(dims);
}
Ok(())
})?;
}
Ok(())
});
workgroup_size
}
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// Each `_`-separated word has its first character upper-cased, and the rest of the word is
/// kept as-is. Empty words (from leading, trailing or doubled underscores) contribute
/// nothing.
fn snake_to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            // `to_uppercase` may expand to several chars (e.g. 'ß' -> "SS").
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
/// Returns `true` when `ty` is a reference to a slice (`&[T]` or `&mut [T]`).
pub(super) fn is_slice_reference(ty: &syn::Type) -> bool {
    match ty {
        syn::Type::Reference(reference) => {
            matches!(reference.elem.as_ref(), syn::Type::Slice(_))
        }
        _ => false,
    }
}
/// Expands `#[spirv_bindgen]` on a shader entry-point function.
///
/// `attr` may be empty, or a comma-separated list of identifiers. The list may contain at
/// most one struct name plus the optional `spirv_passthrough` flag. Without an explicit name,
/// the wrapper is named after the function in PascalCase (`my_kernel` -> `MyKernel`).
///
/// Each parameter is classified by its first `#[spirv(...)]` attribute that
/// `parse_spirv_attr` recognises. Parameters with no recognised attribute, or whose pattern
/// is not a plain identifier, are ignored.
///
/// The returned tokens contain, in order:
/// - the original function, unchanged;
/// - `<Name>Args<'a>`: a host-only (`not(any(spirv, nvptx64))`) struct deriving
///   `khal::ShaderArgs`. It has one field per binding, sorted by
///   `(descriptor_set, binding)`, followed by one field per push constant;
/// - `<Name>`: a host-only wrapper around a `khal::backend::GpuFunction`, with loaders
///   (`from_dir`, `from_dir_ptx`, `from_dir_with_entry_point`, `from_bytes`) and a typed
///   `call` method;
/// - the CUDA entry-point block produced by `cuda::generate_cuda_entry_block`.
///
/// On failure to parse the item or the attribute arguments, or when more than one struct
/// name is given, a `compile_error!` token stream is returned instead.
pub(crate) fn spirv_bindgen(attr: TokenStream, item: TokenStream) -> TokenStream {
    let func = syn::parse_macro_input!(item as syn::ItemFn);
    // Set by the `spirv_passthrough` flag; emitted as the wrapper's `SPIRV_PASSTHROUGH`
    // constant and forwarded to `GpuFunction::from_bytes_with_passthrough`.
    let mut spirv_passthrough = false;
    let struct_name: syn::Ident = if attr.is_empty() {
        let func_name = func.sig.ident.to_string();
        let pascal_name = snake_to_pascal_case(&func_name);
        syn::Ident::new(&pascal_name, func.sig.ident.span())
    } else {
        let args = syn::parse_macro_input!(attr with syn::punctuated::Punctuated::<syn::Ident, syn::Token![,]>::parse_terminated);
        let mut name = None;
        for ident in &args {
            if ident == "spirv_passthrough" {
                spirv_passthrough = true;
            } else {
                // Any identifier other than the known flag is taken as the struct name.
                if name.is_some() {
                    return syn::Error::new_spanned(
                        ident,
                        "Multiple struct names specified in #[spirv_bindgen] attribute",
                    )
                    .to_compile_error()
                    .into();
                }
                name = Some(ident.clone());
            }
        }
        // Only flags were given: fall back to the PascalCase function name.
        name.unwrap_or_else(|| {
            let func_name = func.sig.ident.to_string();
            let pascal_name = snake_to_pascal_case(&func_name);
            syn::Ident::new(&pascal_name, func.sig.ident.span())
        })
    };
    // The first function attribute that yields a `compute(threads(...))` size wins;
    // defaults to a 1x1x1 workgroup.
    let workgroup_size = func
        .attrs
        .iter()
        .find_map(parse_workgroup_size)
        .unwrap_or([1, 1, 1]);
    let mut bindings: Vec<ShaderBinding> = vec![];
    let mut push_constants: Vec<PushConstantBinding> = vec![];
    // Every classified parameter, in declaration order, for the CPU/CUDA code generators.
    let mut original_params: Vec<OriginalParam> = vec![];
    for param in &func.sig.inputs {
        if let syn::FnArg::Typed(pat_type) = param {
            // Only `name: Type` parameters are supported; destructuring patterns are skipped.
            let param_name = if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
                pat_ident.ident.clone()
            } else {
                continue;
            };
            let cfg_attrs = extract_cfg_attrs(&pat_type.attrs);
            for attr in &pat_type.attrs {
                if let Some(kind) = parse_spirv_attr(attr) {
                    match kind {
                        SpirvAttrKind::Binding(is_uniform, descriptor_set, binding_index) => {
                            // A non-reference type makes `extract_element_type` return None,
                            // and the parameter is then silently left out.
                            if let Some(extracted) = extract_element_type(&pat_type.ty) {
                                bindings.push(ShaderBinding {
                                    name: param_name.clone(),
                                    descriptor_set,
                                    binding: binding_index,
                                    is_uniform,
                                    is_mutable: extracted.is_mutable,
                                    element_type: extracted.element_type,
                                    cfg_attrs: cfg_attrs.clone(),
                                });
                                original_params.push(OriginalParam {
                                    name: param_name.clone(),
                                    kind: OriginalParamKind::Binding {
                                        is_uniform,
                                        is_mutable: extracted.is_mutable,
                                    },
                                    ty: (*pat_type.ty).clone(),
                                    cfg_attrs: cfg_attrs.clone(),
                                });
                            }
                        }
                        SpirvAttrKind::PushConstant => {
                            // Same as above: only `&T` push constants are picked up.
                            if let Some(extracted) = extract_element_type(&pat_type.ty) {
                                push_constants.push(PushConstantBinding {
                                    name: param_name.clone(),
                                    ty: extracted.element_type,
                                    cfg_attrs: cfg_attrs.clone(),
                                });
                                original_params.push(OriginalParam {
                                    name: param_name.clone(),
                                    kind: OriginalParamKind::PushConstant,
                                    ty: (*pat_type.ty).clone(),
                                    cfg_attrs: cfg_attrs.clone(),
                                });
                            }
                        }
                        SpirvAttrKind::Builtin(builtin_kind) => {
                            original_params.push(OriginalParam {
                                name: param_name.clone(),
                                kind: OriginalParamKind::Builtin(builtin_kind),
                                ty: (*pat_type.ty).clone(),
                                cfg_attrs: cfg_attrs.clone(),
                            });
                        }
                        SpirvAttrKind::Workgroup => {
                            original_params.push(OriginalParam {
                                name: param_name.clone(),
                                kind: OriginalParamKind::Workgroup,
                                ty: (*pat_type.ty).clone(),
                                cfg_attrs: cfg_attrs.clone(),
                            });
                        }
                    }
                    // Only the first recognised `#[spirv]` attribute classifies a parameter.
                    break;
                }
            }
        }
    }
    // Deterministic field order for the args struct, independent of declaration order.
    bindings.sort_by_key(|b| (b.descriptor_set, b.binding));
    // `#[uniform]` / `#[storage]` fields of the args struct: a read-only or mutable GPU
    // slice depending on whether the parameter was `&` or `&mut`.
    let binding_fields: Vec<proc_macro2::TokenStream> = bindings
        .iter()
        .map(|b| {
            let name = &b.name;
            let set = b.descriptor_set;
            let index = b.binding;
            let elem_ty = &b.element_type;
            let cfg_attrs = &b.cfg_attrs;
            let attr = if b.is_uniform {
                quote! { #[uniform(set = #set, index = #index)] }
            } else {
                quote! { #[storage(set = #set, index = #index)] }
            };
            if b.is_mutable {
                quote! {
                    #(#cfg_attrs)*
                    #attr
                    pub #name: khal::backend::GpuBufferSliceMut<'a, #elem_ty>,
                }
            } else {
                quote! {
                    #(#cfg_attrs)*
                    #attr
                    pub #name: khal::backend::GpuBufferSlice<'a, #elem_ty>,
                }
            }
        })
        .collect();
    // `#[push_constant]` fields of the args struct, holding the value itself.
    let push_constant_fields: Vec<proc_macro2::TokenStream> = push_constants
        .iter()
        .map(|pc| {
            let name = &pc.name;
            let ty = &pc.ty;
            let cfg_attrs = &pc.cfg_attrs;
            quote! {
                #(#cfg_attrs)*
                #[push_constant]
                pub #name: #ty,
            }
        })
        .collect();
    let wg_x = workgroup_size[0];
    let wg_y = workgroup_size[1];
    let wg_z = workgroup_size[2];
    let args_struct_name = syn::Ident::new(&format!("{}Args", struct_name), struct_name.span());
    let args_doc = format!(
        "Arguments the [`{}`] GPU kernel build and pass to its internal `GpuFunction`.",
        struct_name
    );
    // Host-only args struct; excluded when compiling for the SPIR-V or NVPTX targets.
    let args_struct_def = quote! {
        #[doc = #args_doc]
        #[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
        #[derive(khal::ShaderArgs)]
        #[workgroup_size(#wg_x, #wg_y, #wg_z)]
        pub struct #args_struct_name<'a> {
            #(#binding_fields)*
            #(#push_constant_fields)*
        }
    };
    // Parameters of the generated `call` method: bindings accept anything implementing
    // `AsGpuSlice` / `AsGpuSliceMut` for the element type.
    let call_params: Vec<proc_macro2::TokenStream> = bindings
        .iter()
        .map(|b| {
            let name = &b.name;
            let elem_ty = &b.element_type;
            let cfg_attrs = &b.cfg_attrs;
            if b.is_mutable {
                quote! {
                    #(#cfg_attrs)*
                    #name: &'a mut (impl khal::AsGpuSliceMut<#elem_ty>),
                }
            } else {
                quote! {
                    #(#cfg_attrs)*
                    #name: &'a (impl khal::AsGpuSlice<#elem_ty>),
                }
            }
        })
        .collect();
    // Push constants are taken by value in `call`.
    let call_push_constant_params: Vec<proc_macro2::TokenStream> = push_constants
        .iter()
        .map(|pc| {
            let name = &pc.name;
            let ty = &pc.ty;
            let cfg_attrs = &pc.cfg_attrs;
            quote! {
                #(#cfg_attrs)*
                #name: #ty,
            }
        })
        .collect();
    // Field initialisers that build the args struct from `call`'s parameters.
    let args_construction: Vec<proc_macro2::TokenStream> = bindings
        .iter()
        .map(|b| {
            let name = &b.name;
            let cfg_attrs = &b.cfg_attrs;
            if b.is_mutable {
                quote! {
                    #(#cfg_attrs)*
                    #name: #name.as_gpu_slice_mut(),
                }
            } else {
                quote! {
                    #(#cfg_attrs)*
                    #name: #name.as_gpu_slice(),
                }
            }
        })
        .collect();
    let push_constant_construction: Vec<proc_macro2::TokenStream> = push_constants
        .iter()
        .map(|pc| {
            let name = &pc.name;
            let cfg_attrs = &pc.cfg_attrs;
            quote! {
                #(#cfg_attrs)*
                #name,
            }
        })
        .collect();
    let func_ident = &func.sig.ident;
    // Spliced into `call` before the args struct is built and the GPU launch happens; its
    // contents are defined by the `cpu` module.
    let cpu_dispatch_block =
        cpu::generate_cpu_dispatch_block(&original_params, workgroup_size, func_ident);
    let func_name_str = func.sig.ident.to_string();
    // Hash of the whole function's tokens, used to make the CUDA entry symbol unique.
    // NOTE(review): `DefaultHasher` output is not guaranteed stable across Rust releases;
    // presumably the host crate and the nvptx shader crate must be built with the same
    // toolchain for `CUDA_ENTRY_POINT` to match the exported PTX symbol — confirm.
    let hash = {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        func.to_token_stream().to_string().hash(&mut hasher);
        format!("{:016x}", hasher.finish())
    };
    let cuda_entry_name = format!("{}_cuda_entry_{}", func_name_str, hash);
    let cuda_entry_ident = syn::Ident::new(&cuda_entry_name, func.sig.ident.span());
    let cuda_entry_block = cuda::generate_cuda_entry_block(
        &original_params,
        &bindings,
        workgroup_size,
        func_ident,
        &cuda_entry_ident,
    );
    let cuda_entry_name_str = &cuda_entry_name;
    // The function's doc comments are forwarded onto the generated wrapper struct.
    let doc_attrs: Vec<_> = func
        .attrs
        .iter()
        .filter(|attr| attr.path().is_ident("doc"))
        .collect();
    // Host-only wrapper. `from_dir` picks `shaders.ptx` for the PTX target. Otherwise it
    // loads `<module-path-without-crate, '::' -> '-'>-<entry>.spv` and uses
    // `<module>::<entry>` as the entry point (with `::` -> `_` on wasm32).
    let wrapper_def = quote! {
        #(#doc_attrs)*
        #[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
        pub struct #struct_name {
            pub function: khal::backend::GpuFunction<#args_struct_name<'static>>,
        }
        #[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
        impl #struct_name {
            pub const ENTRY_POINT: &'static str = #func_name_str;
            pub const CUDA_ENTRY_POINT: &'static str = #cuda_entry_name_str;
            pub const MODULE_PATH: &'static str = module_path!();
            pub const SPIRV_PASSTHROUGH: bool = #spirv_passthrough;
            #[doc(hidden)]
            #[cfg(feature = "cpu")]
            pub const __ERROR__SHADER_CRATE_IS_MISSING_FEATURE_NAMED____CPU: () = ();
            #[doc(hidden)]
            #[cfg(feature = "cpu-parallel")]
            pub const __ERROR__SHADER_CRATE_IS_MISSING_FEATURE_NAMED____CPU_PARALLEL: () = ();
            #[doc(hidden)]
            #[cfg(feature = "cuda")]
            pub const __ERROR__SHADER_CRATE_IS_MISSING_FEATURE_NAMED____CUDA: () = ();
            pub fn from_dir(
                backend: &khal::backend::GpuBackend,
                dir: &khal::re_exports::include_dir::Dir<'static>,
            ) -> Result<Self, khal::backend::GpuBackendError> {
                match backend.target() {
                    khal::backend::CompileTarget::Ptx => {
                        Self::from_dir_ptx(backend, dir, Self::CUDA_ENTRY_POINT)
                    }
                    _ => {
                        Self::from_dir_with_entry_point(backend, dir, Self::ENTRY_POINT)
                    }
                }
            }
            pub fn from_dir_ptx(
                backend: &khal::backend::GpuBackend,
                dir: &khal::re_exports::include_dir::Dir<'static>,
                entry_point: &str,
            ) -> Result<Self, khal::backend::GpuBackendError> {
                let file = dir.get_file("shaders.ptx")
                    .unwrap_or_else(|| panic!("PTX file 'shaders.ptx' not found in embedded dir"));
                Self::from_bytes(backend, file.contents(), entry_point)
            }
            pub fn from_dir_with_entry_point(
                backend: &khal::backend::GpuBackend,
                dir: &khal::re_exports::include_dir::Dir<'static>,
                entry_point: &str,
            ) -> Result<Self, khal::backend::GpuBackendError> {
                let module = Self::MODULE_PATH.split_once("::")
                    .map(|(_, rest)| rest)
                    .unwrap_or("");
                let filename = if module.is_empty() {
                    format!("{}.spv", entry_point)
                } else {
                    format!("{}-{}.spv", module.replace("::", "-"), entry_point)
                };
                let file = dir.get_file(&filename)
                    .unwrap_or_else(|| panic!("SPIR-V file not found in embedded dir: {}", filename));
                let full_entry = if module.is_empty() {
                    entry_point.to_string()
                } else {
                    format!("{}::{}", module, entry_point)
                };
                #[cfg(target_arch = "wasm32")]
                let full_entry = full_entry.replace("::", "_");
                Self::from_bytes(backend, file.contents(), &full_entry)
            }
            pub fn from_bytes(
                backend: &khal::backend::GpuBackend,
                bytes: &[u8],
                entry_point: &str,
            ) -> Result<Self, khal::backend::GpuBackendError> {
                Ok(Self {
                    function: khal::backend::GpuFunction::from_bytes_with_passthrough(
                        backend, bytes, entry_point, Self::SPIRV_PASSTHROUGH,
                    )?,
                })
            }
            pub fn call<'a>(
                &self,
                __pass: &mut khal::backend::GpuPass,
                __dispatch_grid: impl Into<khal::backend::DispatchGrid<'a, khal::backend::GpuBackend>>,
                #(#call_params)*
                #(#call_push_constant_params)*
            ) -> Result<(), khal::backend::GpuBackendError> {
                use khal::AsGpuSlice as _;
                use khal::AsGpuSliceMut as _;
                let __dispatch_grid = __dispatch_grid.into();
                #cpu_dispatch_block
                let args = #args_struct_name {
                    #(#args_construction)*
                    #(#push_constant_construction)*
                };
                self.function.launch_grid(__pass, &args, __dispatch_grid)
            }
        }
    };
    let output = quote! {
        #func
        #args_struct_def
        #wrapper_def
        #cuda_entry_block
    };
    output.into()
}